309 lines
9.9 KiB
Python
309 lines
9.9 KiB
Python
"""AI智能体接口 — 基于DeepSeek的PROCEDURE DIVISION解析"""
|
||
|
||
import json
|
||
import os
|
||
import re
|
||
from pathlib import Path
|
||
|
||
from .models import BrSeq, BrIf, BrEval, BrPerform, Assign, CallNode
|
||
|
||
|
||
# Environment variable that must hold the DeepSeek API key.
DEEPSEEK_API_KEY_ENV = "DEEPSEEK_API_KEY"
# DeepSeek endpoint and model name, passed to the openai SDK client.
DEEPSEEK_BASE_URL = "https://api.deepseek.com/v1"
DEEPSEEK_MODEL = "deepseek-chat"
# System prompt for PROCEDURE DIVISION parsing, stored beside this module.
PROMPT_FILE = Path(__file__).parent.joinpath("prompts", "parse_proc_division.txt")
|
||
|
||
|
||
def parse_proc_division_ai(proc_text: str, fields: list | None = None, spec_doc: str = ""):
    """Parse a COBOL PROCEDURE DIVISION by calling the DeepSeek API.

    Args:
        proc_text: Raw PROCEDURE DIVISION source embedded in the prompt.
        fields: Optional DATA DIVISION field list, serialized to JSON in the prompt.
        spec_doc: Currently unused; accepted for interface compatibility.

    Returns:
        ``(branch_tree, assignments)`` where ``branch_tree`` is the response's
        ``"tree"`` object converted by ``_json_to_tree`` (a missing or null tree
        is converted as ``{}``, i.e. an empty sequence) and ``assignments`` is the
        response's ``"assignments"`` value (``{}`` when missing or null).

    Raises:
        NotImplementedError: if the API key environment variable is unset, or
            the model reply contains no parsable JSON object.
    """
    api_key = os.environ.get(DEEPSEEK_API_KEY_ENV)
    if not api_key:
        raise NotImplementedError(
            f"AI agent requires {DEEPSEEK_API_KEY_ENV} environment variable"
        )

    prompt = _build_prompt(proc_text, fields)
    response_text = _call_llm(prompt, api_key)
    data = _extract_json(response_text)
    # Anything other than a non-empty JSON object cannot supply "tree"/"assignments".
    if not isinstance(data, dict) or not data:
        raise NotImplementedError("AI returned no parsable JSON")

    tree_data = data.get("tree")
    # A null or non-object "tree" is treated like a missing one instead of crashing.
    branch_tree = _json_to_tree(tree_data if isinstance(tree_data, dict) else {})

    assignments = data.get("assignments")
    if assignments is None:
        assignments = {}
    return branch_tree, assignments
|
||
|
||
|
||
def _build_prompt(proc_text: str, fields: list = None) -> list[dict]:
    """Assemble the chat messages for PROCEDURE DIVISION parsing.

    The system message is the UTF-8 text of ``PROMPT_FILE``. The user message
    embeds *proc_text* in a fenced code block, followed by *fields* pretty-printed
    as JSON (``"[]"`` when *fields* is empty or None).
    """
    system_content = PROMPT_FILE.read_text(encoding="utf-8")

    if fields:
        fields_json = json.dumps(fields, ensure_ascii=False, indent=2)
    else:
        fields_json = "[]"

    user_content = (
        "## PROCEDURE DIVISION 源码\n"
        "\n"
        "```\n"
        f"{proc_text}\n"
        "```\n"
        "\n"
        "## DATA DIVISION 字段列表\n"
        "\n"
        "```json\n"
        f"{fields_json}\n"
        "```\n"
    )

    system_msg = {"role": "system", "content": system_content}
    user_msg = {"role": "user", "content": user_content}
    return [system_msg, user_msg]
|
||
|
||
|
||
def _call_llm(
    messages: list[dict],
    api_key: str,
    *,
    temperature: float = 0.1,
    max_tokens: int = 8192,
) -> str:
    """Send chat messages to DeepSeek and return the reply text.

    Uses the ``openai`` SDK client pointed at ``DEEPSEEK_BASE_URL`` with model
    ``DEEPSEEK_MODEL``.

    Args:
        messages: Chat-completions style ``[{"role": ..., "content": ...}, ...]``.
        api_key: DeepSeek API key.
        temperature: Sampling temperature (default keeps output near-deterministic).
        max_tokens: Upper bound on generated tokens.

    Returns:
        The first choice's message content, or ``""`` when the response has no
        choices or the content is empty/None.

    Raises:
        NotImplementedError: if the ``openai`` package is not installed (the
            original ImportError is chained as ``__cause__``).
    """
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise NotImplementedError(
            "openai package not installed. Run: pip install openai"
        ) from exc

    client = OpenAI(api_key=api_key, base_url=DEEPSEEK_BASE_URL)
    response = client.chat.completions.create(
        model=DEEPSEEK_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # Guard against an empty choices list instead of raising IndexError.
    if not response.choices:
        return ""
    return response.choices[0].message.content or ""
|
||
|
||
|
||
def _extract_json(text: str) -> dict | None:
|
||
stripped = text.strip()
|
||
# Try extracting from markdown code block first
|
||
m = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", stripped, re.DOTALL)
|
||
if m:
|
||
stripped = m.group(1).strip()
|
||
try:
|
||
return json.loads(stripped)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
|
||
def _json_to_tree(data: dict):
|
||
node_type = data.get("type", "seq")
|
||
|
||
if node_type == "seq":
|
||
node = BrSeq()
|
||
for child_data in data.get("children", []):
|
||
child = _json_to_tree(child_data)
|
||
if child is not None:
|
||
node.add(child)
|
||
return node
|
||
|
||
if node_type == "if":
|
||
node = BrIf(data.get("condition", ""))
|
||
node.true_seq = _json_to_tree(data.get("true_seq", {"type": "seq", "children": []}))
|
||
node.false_seq = _json_to_tree(data.get("false_seq", {"type": "seq", "children": []}))
|
||
return node
|
||
|
||
if node_type == "eval":
|
||
node = BrEval(data.get("subject", ""))
|
||
for w in data.get("when_list", []):
|
||
node.when_list.append((w.get("value", ""), _json_to_tree(w.get("seq", {"type": "seq", "children": []}))))
|
||
node.other_seq = _json_to_tree(data.get("other_seq", {"type": "seq", "children": []}))
|
||
node.has_other = data.get("has_other", False)
|
||
return node
|
||
|
||
if node_type == "perform":
|
||
perf_type = data.get("perf_type", "para")
|
||
kw = {"perf_type": perf_type}
|
||
for k in ("condition", "target", "thru", "times",
|
||
"varying_var", "varying_from", "varying_by"):
|
||
if k in data:
|
||
kw[k] = data[k]
|
||
node = BrPerform(**kw)
|
||
if "body_seq" in data:
|
||
node.body_seq = _json_to_tree(data["body_seq"])
|
||
return node
|
||
|
||
if node_type == "assign":
|
||
return Assign(
|
||
target=data.get("target", ""),
|
||
source_info=data.get("source_info", {}),
|
||
)
|
||
|
||
if node_type == "call":
|
||
return CallNode(
|
||
program_name=data.get("program_name", ""),
|
||
using_params=data.get("using_params", []),
|
||
)
|
||
|
||
return None
|
||
|
||
|
||
# ── LLM 路径生成 ──
|
||
|
||
|
||
def llm_generate_all_paths(tree_root, fields) -> list | None:
    """Ask the LLM for an MC/DC test-path set covering the whole control-flow tree.

    Best-effort: returns None instead of raising in any of these cases:

    - no API key is configured;
    - the tree has nothing to serialize;
    - the reply lacks a ``"paths"`` key;
    - the LLM call or response parsing fails (the exception is logged).

    Args:
        tree_root: Root model node of the control-flow tree.
        fields: Field dicts; None is treated as an empty list.

    Returns:
        ``[(constraints, assignments), ...]`` or None.
    """
    import logging

    api_key = os.environ.get(DEEPSEEK_API_KEY_ENV)
    if not api_key:
        return None

    tree_json = _serialize_tree_for_llm(tree_root)
    if tree_json is None:
        return None

    # Normalise here: these calls run outside the try below and iterate fields.
    fields = fields or []
    level88_map = _extract_88_mapping(fields)
    messages = _build_path_prompt(tree_json, fields, level88_map)

    try:
        response = _call_llm(messages, api_key)
        data = _extract_json(response)
        if data and "paths" in data:
            return _parse_llm_paths(data["paths"])
    except Exception:
        # Deliberate best-effort boundary: callers get None, but the cause is recorded.
        logging.getLogger(__name__).warning(
            "LLM path generation failed", exc_info=True
        )
    return None
|
||
|
||
|
||
def _serialize_tree_for_llm(node):
|
||
if node is None:
|
||
return None
|
||
from .models import BrSeq, BrIf, BrEval, BrPerform, Assign, CallNode, ExitNode, GoTo
|
||
|
||
if isinstance(node, BrSeq):
|
||
children = []
|
||
for child in node.children:
|
||
s = _serialize_tree_for_llm(child)
|
||
if s is not None:
|
||
children.append(s)
|
||
return {"type": "seq", "children": children} if children else None
|
||
|
||
if isinstance(node, BrIf):
|
||
return {
|
||
"type": "if",
|
||
"condition": node.condition,
|
||
"true_seq": _serialize_tree_for_llm(node.true_seq) or {"type": "seq", "children": []},
|
||
"false_seq": _serialize_tree_for_llm(node.false_seq) or {"type": "seq", "children": []},
|
||
}
|
||
|
||
if isinstance(node, BrEval):
|
||
when_list = []
|
||
for val, seq in node.when_list:
|
||
s = _serialize_tree_for_llm(seq)
|
||
when_list.append({"value": val, "seq": s or {"type": "seq", "children": []}})
|
||
return {
|
||
"type": "eval",
|
||
"subject": node.subject,
|
||
"when_list": when_list,
|
||
"other_seq": _serialize_tree_for_llm(node.other_seq) or {"type": "seq", "children": []},
|
||
"has_other": node.has_other,
|
||
}
|
||
|
||
if isinstance(node, BrPerform):
|
||
result = {"type": "perform", "perf_type": node.perf_type}
|
||
for attr in ("condition", "target", "thru", "times",
|
||
"varying_var", "varying_from", "varying_by"):
|
||
val = getattr(node, attr, None)
|
||
if val is not None:
|
||
result[attr] = val
|
||
if node.body_seq:
|
||
bs = _serialize_tree_for_llm(node.body_seq)
|
||
if bs:
|
||
result["body_seq"] = bs
|
||
return result
|
||
|
||
# Assign / CallNode / ExitNode / GoTo — 不影响路径生成,可省略
|
||
return None
|
||
|
||
|
||
def _extract_88_mapping(fields):
|
||
mapping = {}
|
||
for f in fields:
|
||
if f.get('is_88'):
|
||
mapping[f['name']] = {
|
||
"parent": f['parent'],
|
||
"value": f['value'],
|
||
"pic_info": f.get('pic_info', {}),
|
||
}
|
||
return mapping
|
||
|
||
|
||
def _build_path_prompt(tree_json, fields, level88_map):
|
||
system = ("你是 COBOL 测试路径生成专家。"
|
||
"请为给定的控制流树生成满足 MC/DC 覆盖的测试路径集。"
|
||
"只输出 JSON,不要多余文字。")
|
||
|
||
reduced_fields = []
|
||
for f in fields:
|
||
entry = {"name": f["name"], "pic": f.get("pic", "")}
|
||
pi = f.get("pic_info", {})
|
||
if pi:
|
||
entry["pic_info"] = {
|
||
"type": pi.get("type"), "digits": pi.get("digits"),
|
||
"decimal": pi.get("decimal"), "length": pi.get("length"),
|
||
}
|
||
if f.get("is_88"):
|
||
entry["is_88"] = True
|
||
entry["value"] = f.get("value")
|
||
entry["parent"] = f.get("parent")
|
||
reduced_fields.append(entry)
|
||
|
||
user = (
|
||
"## 控制流树(JSON)\n\n"
|
||
f"```json\n{json.dumps(tree_json, ensure_ascii=False, indent=2)}\n```\n\n"
|
||
"## 字段定义\n\n"
|
||
f"```json\n{json.dumps(reduced_fields, ensure_ascii=False, indent=2)}\n```\n\n"
|
||
"## 要求\n"
|
||
"1. 每个 IF/EVALUATE/PERFORM UNTIL 的每个分支至少被覆盖一次\n"
|
||
"2. 复合条件(AND/OR/NOT)需要满足 MC/DC:每个叶条件的独立影响对\n"
|
||
"3. 路径数尽量少(最小集优先)\n"
|
||
"4. 88-level 条件名要展开为实际字段比较(如 CUST-VIP → WS-CUST-LEVEL='V')\n"
|
||
"5. 同一路径中的约束不能自相矛盾(同一字段不能同时等于 'A' 和等于 'B')\n"
|
||
"6. 数值边界值合理(>5000 → 5001, <100 → 99)\n"
|
||
"7. AND 优先级高于 OR\n\n"
|
||
"## 输出格式\n\n"
|
||
"```json\n"
|
||
"{\n"
|
||
' "paths": [\n'
|
||
" {\n"
|
||
' "constraints": [\n'
|
||
' {"field": "WS-AMOUNT", "op": ">", "value": "5000", "want_true": true}\n'
|
||
" ],\n"
|
||
' "assignments": {}\n'
|
||
" }\n"
|
||
" ]\n"
|
||
"}\n"
|
||
"```"
|
||
)
|
||
|
||
return [
|
||
{"role": "system", "content": system},
|
||
{"role": "user", "content": user},
|
||
]
|
||
|
||
|
||
def _parse_llm_paths(paths_data):
|
||
result = []
|
||
for p in paths_data:
|
||
constraints = []
|
||
for c in p.get("constraints", []):
|
||
constraints.append((c["field"], c["op"], str(c["value"]), c["want_true"]))
|
||
assignments = p.get("assignments", {})
|
||
result.append((constraints, assignments))
|
||
return result
|
||
|
||
|
||
def resolve_constraints_ai(paths, fields=None, assignments=None):
    """Placeholder for AI-based constraint resolution (future work).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)
|
||
|
||
|
||
def enhance_metadata_ai(records, fields=None, spec_doc: str = ""):
    """Placeholder for AI-generated test-case metadata (future work).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)
|
||
|
||
|
||
def analyze_spec_ai(spec_doc: str = ""):
    """Placeholder for AI-based specification-document analysis (future work).

    Raises:
        NotImplementedError: always.
    """
    reason = "AI agent not yet implemented"
    raise NotImplementedError(reason)
|