Files
cobol-java-v3/cobol_testgen/agents.py
T
2026-06-08 21:07:16 +08:00

309 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI智能体接口 — 基于DeepSeek的PROCEDURE DIVISION解析"""
import json
import os
import re
from pathlib import Path
from .models import BrSeq, BrIf, BrEval, BrPerform, Assign, CallNode
# Name of the environment variable that must hold the DeepSeek API key.
DEEPSEEK_API_KEY_ENV = "DEEPSEEK_API_KEY"
# Endpoint and model handed to the OpenAI client for every LLM call in this module.
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
DEEPSEEK_MODEL = "deepseek-chat"
# System prompt for PROCEDURE DIVISION parsing, stored in a "prompts" folder beside this file.
PROMPT_FILE = Path(__file__).parent.joinpath("prompts", "parse_proc_division.txt")
def parse_proc_division_ai(proc_text: str, fields: list | None = None, spec_doc: str = ""):
    """Parse a COBOL PROCEDURE DIVISION by asking the DeepSeek LLM.

    Args:
        proc_text: Source text of the PROCEDURE DIVISION.
        fields: Optional DATA DIVISION field descriptors; serialized to JSON
            and embedded in the prompt.
        spec_doc: Accepted for interface compatibility; currently unused.

    Returns:
        ``(branch_tree, assignments)``. ``branch_tree`` is built from the
        reply's ``"tree"`` object (an empty sequence when absent or null), and
        ``assignments`` is the reply's ``"assignments"`` object (``{}`` when
        absent or null).

    Raises:
        NotImplementedError: if the API-key environment variable is unset,
            the ``openai`` package is missing, or the reply does not contain
            a JSON object of the expected shape.
    """
    api_key = os.environ.get(DEEPSEEK_API_KEY_ENV)
    if not api_key:
        raise NotImplementedError(
            f"AI agent requires {DEEPSEEK_API_KEY_ENV} environment variable"
        )
    prompt = _build_prompt(proc_text, fields)
    response_text = _call_llm(prompt, api_key)
    data = _extract_json(response_text)
    # Anything but a non-empty JSON object (e.g. a bare list) would otherwise
    # fail below with AttributeError. Report it with the same
    # NotImplementedError as the other failure modes of this function.
    if not isinstance(data, dict) or not data:
        raise NotImplementedError("AI returned no parsable JSON")
    tree_data = data.get("tree") or {}
    assignments = data.get("assignments") or {}
    if not isinstance(tree_data, dict) or not isinstance(assignments, dict):
        raise NotImplementedError("AI returned JSON with unexpected structure")
    branch_tree = _json_to_tree(tree_data)
    return branch_tree, assignments
def _build_prompt(proc_text: str, fields: list = None) -> list[dict]:
    """Assemble the chat messages for PROCEDURE DIVISION parsing.

    The system message is the text of ``PROMPT_FILE``. The user message
    embeds *proc_text* in a code fence, followed by *fields* rendered as
    indented JSON (the literal ``"[]"`` when *fields* is empty or ``None``).
    """
    system_content = PROMPT_FILE.read_text(encoding="utf-8")
    if fields:
        fields_json = json.dumps(fields, ensure_ascii=False, indent=2)
    else:
        fields_json = "[]"
    user_content = f"""## PROCEDURE DIVISION 源码
```
{proc_text}
```
## DATA DIVISION 字段列表
```json
{fields_json}
```
"""
    messages = [{"role": "system", "content": system_content}]
    messages.append({"role": "user", "content": user_content})
    return messages
def _call_llm(messages: list[dict], api_key: str) -> str:
try:
from openai import OpenAI
except ImportError:
raise NotImplementedError(
"openai package not installed. Run: pip install openai"
)
client = OpenAI(api_key=api_key, base_url=DEEPSEEK_BASE_URL)
response = client.chat.completions.create(
model=DEEPSEEK_MODEL,
messages=messages,
temperature=0.1,
max_tokens=8192,
)
return response.choices[0].message.content or ""
def _extract_json(text: str) -> dict | None:
stripped = text.strip()
# Try extracting from markdown code block first
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", stripped, re.DOTALL)
if m:
stripped = m.group(1).strip()
try:
return json.loads(stripped)
except json.JSONDecodeError:
return None
def _json_to_tree(data: dict):
    """Build a branch-tree node from one node of the LLM's JSON tree.

    Recognised ``"type"`` values (default ``"seq"``): ``seq`` -> BrSeq,
    ``if`` -> BrIf, ``eval`` -> BrEval, ``perform`` -> BrPerform,
    ``assign`` -> Assign, ``call`` -> CallNode. An unknown type yields
    ``None``, and ``seq`` nodes drop such children.

    A ``null`` or other non-object node, which the model may emit for an
    empty branch, is treated as an empty ``seq`` instead of raising
    AttributeError. Null or malformed entries inside ``children`` and
    ``when_list`` are skipped.
    """
    if not isinstance(data, dict):
        data = {"type": "seq", "children": []}
    node_type = data.get("type", "seq")
    if node_type == "seq":
        node = BrSeq()
        for child_data in data.get("children") or []:
            if not isinstance(child_data, dict):
                continue  # null / malformed child: nothing to add
            child = _json_to_tree(child_data)
            if child is not None:
                node.add(child)
        return node
    if node_type == "if":
        node = BrIf(data.get("condition", ""))
        # Missing or null branches become empty sequences via the guard above.
        node.true_seq = _json_to_tree(data.get("true_seq"))
        node.false_seq = _json_to_tree(data.get("false_seq"))
        return node
    if node_type == "eval":
        node = BrEval(data.get("subject", ""))
        for when in data.get("when_list") or []:
            if not isinstance(when, dict):
                continue
            node.when_list.append((when.get("value", ""), _json_to_tree(when.get("seq"))))
        node.other_seq = _json_to_tree(data.get("other_seq"))
        node.has_other = data.get("has_other", False)
        return node
    if node_type == "perform":
        kw = {"perf_type": data.get("perf_type", "para")}
        for k in ("condition", "target", "thru", "times",
                  "varying_var", "varying_from", "varying_by"):
            if k in data:
                kw[k] = data[k]
        node = BrPerform(**kw)
        if "body_seq" in data:
            node.body_seq = _json_to_tree(data["body_seq"])
        return node
    if node_type == "assign":
        return Assign(
            target=data.get("target", ""),
            source_info=data.get("source_info") or {},
        )
    if node_type == "call":
        return CallNode(
            program_name=data.get("program_name", ""),
            using_params=data.get("using_params") or [],
        )
    return None
# ── LLM 路径生成 ──
def llm_generate_all_paths(tree_root, fields) -> list | None:
    """Ask the LLM for an MC/DC test-path set covering the whole control-flow tree.

    This is best effort: every failure mode returns ``None`` rather than raising.
    Failures during the LLM round trip are logged at WARNING level with a
    traceback instead of being silently discarded.

    Args:
        tree_root: Root node of the branch tree (BrSeq/BrIf/BrEval/BrPerform nodes).
        fields: DATA DIVISION field dicts; ``None`` is treated as an empty list.

    Returns:
        ``[(constraints, assignments), ...]`` as built by ``_parse_llm_paths``,
        or ``None`` in any of these cases:
            - no API key is configured;
            - the tree serializes to nothing;
            - the reply lacks a ``"paths"`` key;
            - prompt building, the LLM call, or reply parsing raises.
    """
    import logging

    api_key = os.environ.get(DEEPSEEK_API_KEY_ENV)
    if not api_key:
        return None
    tree_json = _serialize_tree_for_llm(tree_root)
    if tree_json is None:
        return None
    fields = fields or []
    try:
        level88_map = _extract_88_mapping(fields)
        messages = _build_path_prompt(tree_json, fields, level88_map)
        response = _call_llm(messages, api_key)
        data = _extract_json(response)
        if isinstance(data, dict) and "paths" in data:
            return _parse_llm_paths(data["paths"])
    except Exception:
        # Best effort: keep the None contract, but leave a trace for debugging.
        logging.getLogger(__name__).warning(
            "LLM path generation failed; returning None", exc_info=True
        )
    return None
def _serialize_tree_for_llm(node):
    """Convert a branch-tree node into the JSON-ready dict sent to the path-generation LLM.

    Only control-flow nodes (BrSeq, BrIf, BrEval, BrPerform) are emitted.
    Every other node (Assign, CallNode, ExitNode, GoTo, ...) maps to ``None``,
    since it does not affect path generation. A BrSeq with no surviving
    children also maps to ``None``. Wherever a branch slot needs a sequence,
    an explicit empty ``seq`` dict is substituted.
    """
    if node is None:
        return None
    from .models import BrSeq, BrIf, BrEval, BrPerform

    def as_branch(sub):
        # Branch slots always carry a seq dict, even for an empty branch.
        return _serialize_tree_for_llm(sub) or {"type": "seq", "children": []}

    if isinstance(node, BrSeq):
        kept = [
            serialized
            for serialized in map(_serialize_tree_for_llm, node.children)
            if serialized is not None
        ]
        if not kept:
            return None
        return {"type": "seq", "children": kept}

    if isinstance(node, BrIf):
        return {
            "type": "if",
            "condition": node.condition,
            "true_seq": as_branch(node.true_seq),
            "false_seq": as_branch(node.false_seq),
        }

    if isinstance(node, BrEval):
        whens = [{"value": value, "seq": as_branch(seq)} for value, seq in node.when_list]
        return {
            "type": "eval",
            "subject": node.subject,
            "when_list": whens,
            "other_seq": as_branch(node.other_seq),
            "has_other": node.has_other,
        }

    if isinstance(node, BrPerform):
        optional_attrs = ("condition", "target", "thru", "times",
                          "varying_var", "varying_from", "varying_by")
        payload = {"type": "perform", "perf_type": node.perf_type}
        payload.update({
            attr: value
            for attr in optional_attrs
            if (value := getattr(node, attr, None)) is not None
        })
        body = _serialize_tree_for_llm(node.body_seq) if node.body_seq else None
        if body:
            payload["body_seq"] = body
        return payload

    # Non-control-flow nodes are irrelevant to path generation.
    return None
def _extract_88_mapping(fields):
mapping = {}
for f in fields:
if f.get('is_88'):
mapping[f['name']] = {
"parent": f['parent'],
"value": f['value'],
"pic_info": f.get('pic_info', {}),
}
return mapping
def _build_path_prompt(tree_json, fields, level88_map):
system = ("你是 COBOL 测试路径生成专家。"
"请为给定的控制流树生成满足 MC/DC 覆盖的测试路径集。"
"只输出 JSON,不要多余文字。")
reduced_fields = []
for f in fields:
entry = {"name": f["name"], "pic": f.get("pic", "")}
pi = f.get("pic_info", {})
if pi:
entry["pic_info"] = {
"type": pi.get("type"), "digits": pi.get("digits"),
"decimal": pi.get("decimal"), "length": pi.get("length"),
}
if f.get("is_88"):
entry["is_88"] = True
entry["value"] = f.get("value")
entry["parent"] = f.get("parent")
reduced_fields.append(entry)
user = (
"## 控制流树(JSON\n\n"
f"```json\n{json.dumps(tree_json, ensure_ascii=False, indent=2)}\n```\n\n"
"## 字段定义\n\n"
f"```json\n{json.dumps(reduced_fields, ensure_ascii=False, indent=2)}\n```\n\n"
"## 要求\n"
"1. 每个 IF/EVALUATE/PERFORM UNTIL 的每个分支至少被覆盖一次\n"
"2. 复合条件(AND/OR/NOT)需要满足 MC/DC:每个叶条件的独立影响对\n"
"3. 路径数尽量少(最小集优先)\n"
"4. 88-level 条件名要展开为实际字段比较(如 CUST-VIP → WS-CUST-LEVEL='V'\n"
"5. 同一路径中的约束不能自相矛盾(同一字段不能同时等于 'A' 和等于 'B'\n"
"6. 数值边界值合理(>5000 → 5001, <100 → 99\n"
"7. AND 优先级高于 OR\n\n"
"## 输出格式\n\n"
"```json\n"
"{\n"
' "paths": [\n'
" {\n"
' "constraints": [\n'
' {"field": "WS-AMOUNT", "op": ">", "value": "5000", "want_true": true}\n'
" ],\n"
' "assignments": {}\n'
" }\n"
" ]\n"
"}\n"
"```"
)
return [
{"role": "system", "content": system},
{"role": "user", "content": user},
]
def _parse_llm_paths(paths_data):
result = []
for p in paths_data:
constraints = []
for c in p.get("constraints", []):
constraints.append((c["field"], c["op"], str(c["value"]), c["want_true"]))
assignments = p.get("assignments", {})
result.append((constraints, assignments))
return result
def resolve_constraints_ai(paths, fields=None, assignments=None):
    """Resolve path constraints with the AI agent (placeholder for a future implementation).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)
def enhance_metadata_ai(records, fields=None, spec_doc: str = ""):
    """Generate test-case metadata with the AI agent (placeholder for a future implementation).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)
def analyze_spec_ai(spec_doc: str = ""):
    """Parse a specification document with the AI agent (placeholder for a future implementation).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)