From de506d9c31c7fc5941051af4683a578931f17ae7 Mon Sep 17 00:00:00 2001
From: hangshuo652
Date: Thu, 18 Jun 2026 16:10:38 +0800
Subject: [PATCH] feat: Phase 2 - HINA Agent + Strategy Agent + classifier

---
 hina/classifier.py | 134 ++++++++++++++++++++
 hina/hina_agent.py | 297 +++++++++++++++++++++++++++++++++++++++++++++
 hina/strategy.py   | 103 +++++++++++++++
 orchestrator.py    |  30 ++++-
 4 files changed, 561 insertions(+), 3 deletions(-)
 create mode 100644 hina/classifier.py
 create mode 100644 hina/hina_agent.py
 create mode 100644 hina/strategy.py

diff --git a/hina/classifier.py b/hina/classifier.py
new file mode 100644
index 0000000..b972ae9
--- /dev/null
+++ b/hina/classifier.py
@@ -0,0 +1,134 @@
+"""
+HINA program classifier - L1 keyword rules + confidence computation.
+
+Classifies a program by matching keywords in its COBOL source and supports
+multi-level confidence decisions.
+"""
+
+from __future__ import annotations
+
+import re
+from typing import Any
+
+# ── L1 rules ─────────────────────────────────────────────────────────────
+# Format: (category name, [keywords], confidence assigned on a hit)
+L1_RULES: list[tuple[str, list[str], float]] = [
+    ("DB操作", ["EXEC SQL"], 0.95),
+    ("子程序调用", ["CALL", "LINKAGE SECTION"], 0.90),
+    ("IS INITIAL", ["IS INITIAL"], 0.99),
+    ("SYSIN", ["SYSIN"], 0.90),
+    ("编码转换", ["ALPHABETIC", "ASCII", "EBCDIC"], 0.85),
+    ("online", ["DFHCOMMAREA", "MAP"], 0.95),
+    ("SORT", ["SORT ON KEY"], 0.95),
+    ("MERGE", ["MERGE ON KEY"], 0.95),
+    ("编辑输出", ["WRITE AFTER", "WRITE BEFORE"], 0.80),
+    ("文件编成", ["ORGANIZATION IS"], 0.99),
+    ("替代索引", ["ALTERNATE RECORD KEY"], 0.99),
+]
+
+# ── Conflict resolution rules ────────────────────────────────────────────
+# Disambiguation strategy when L1 matches more than one category:
+#   value = "file_count"      -> pick the category with more tests
+#   value = "has_accumulator" -> pick the category containing an accumulator
+CONFLICT_RULES: dict[tuple[str, str], str] = {
+    ("マッチング", "キーブレイク"): "file_count",
+    ("編集処理", "項目チェック"): "file_count",
+    ("キーブレイク", "項目チェック(重複)"): "has_accumulator",
+}
+
+
+def _keyword_pattern(keyword: str) -> re.Pattern[str]:
+    """Compile *keyword* into a whole-word, whitespace-tolerant regex.
+
+    COBOL words consist of letters, digits and hyphens, so a plain substring
+    test would also report "CALL" inside "RECALL" or "CALL-COUNT".  The words
+    of a multi-word keyword may be separated by any run of whitespace.
+    """
+    body = r"\s+".join(re.escape(word) for word in keyword.split())
+    return re.compile(rf"(?<![A-Z0-9-]){body}(?![A-Z0-9-])")
+
+
+# ── Keyword detection ────────────────────────────────────────────────────
+def detect_keyword(source: str) -> list[tuple[str, float, str]]:
+    """Search COBOL source for the keywords defined in L1_RULES.
+
+    Matching is case-insensitive and whole-word.
+
+    Args:
+        source: COBOL program source text.
+
+    Returns:
+        list[tuple[str, float, str]]:
+            One (category, confidence, matched keyword) tuple per category
+            that hit, in L1_RULES order.
+    """
+    results: list[tuple[str, float, str]] = []
+    source_upper = source.upper()
+
+    for category, keywords, confidence in L1_RULES:
+        for kw in keywords:
+            if _keyword_pattern(kw).search(source_upper):
+                results.append((category, confidence, kw))
+                break  # record each category at most once
+
+    return results
+
+
+# ── Confidence computation ───────────────────────────────────────────────
+def compute_confidence(
+    source: str,
+    structure: dict[str, Any] | None = None,
+    llm_result: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Compute the classification confidence for a program.
+
+    Priority:
+      1. L1 keyword hit whose best confidence is >= 0.90 -> return the L1 result.
+      2. An LLM result is supplied -> use the LLM classification.
+      3. Otherwise -> return "unknown".
+
+    Args:
+        source: COBOL program source text.
+        structure: Optional program structure info (currently unused).
+        llm_result: Optional LLM classification result.
+            Expected shape: {"category": str, "confidence": float, ...}
+
+    Returns:
+        dict:
+            - "category": str - category name or "unknown"
+            - "confidence": float - confidence (0.0 ~ 1.0)
+            - "source": str - origin of the result ("l1" / "llm" / "unknown")
+            - "matches": list - every L1 keyword hit, including sub-threshold ones
+    """
+    # ── 1. L1 keyword detection ──
+    matches = detect_keyword(source)
+
+    if matches:
+        # max() keeps the first of equally confident hits (L1_RULES order).
+        category, confidence, _ = max(matches, key=lambda m: m[1])
+
+        if confidence >= 0.90:
+            return {
+                "category": category,
+                "confidence": confidence,
+                "source": "l1",
+                "matches": matches,
+            }
+
+    # ── 2. LLM result ──
+    if llm_result is not None:
+        return {
+            "category": llm_result.get("category", "unknown"),
+            "confidence": llm_result.get("confidence", 0.0),
+            "source": "llm",
+            "matches": matches,
+        }
+
+    # ── 3. Unknown ──
+    # Keep sub-threshold L1 hits so callers can still inspect them.
+    return {
+        "category": "unknown",
+        "confidence": 0.0,
+        "source": "unknown",
+        "matches": matches,
+    }
diff --git a/hina/hina_agent.py b/hina/hina_agent.py
new file mode 100644
index 0000000..f94d09a
--- /dev/null
+++ b/hina/hina_agent.py
@@ -0,0 +1,297 @@
+"""
+HINA confusion-group classification - LLM-based COBOL structure classification.
+
+Uses the structural features produced by extract_structure() to ask an LLM to
+assign the program to a confusion group; returns the classification and
+strategy parameters.
+"""
+
+import json
+import logging
+import math
+
+logger = logging.getLogger(__name__)
+
+CONFUSION_PROMPT = """你是一个 COBOL 程序混淆组分类专家。请根据以下程序结构特征,将其归类到合适的混淆组中。
+
+程序结构特征:
+- 段落数: {paragraph_count}
+- 决策点总数: {decision_count}
+- IF 语句数: {if_count}
+- EVALUATE 语句数: {evaluate_count}
+- 关联文件数: {file_count}
+- OPEN 方向: {open_directions}
+- SEARCH ALL: {has_search_all}
+- CALL 语句: {has_call}
+- KEY BREAK 关键词: {has_break}
+- 总分支数: {total_branches}
+
+混淆组定义:
+1. simple_sequential — 极少决策点(<=2),无 EVALUATE/SEARCH ALL/CALL,直接顺序执行
+2. condition_heavy — IF 语句占比高(>60% 的决策点),嵌套深,逻辑复杂
+3. evaluate_driven — EVALUATE 主导,多分支选择结构
+4. data_file_centric — 文件操作密集(>=2 文件),OPEN 方向多样(I-O/OUTPUT/INPUT)
+5. search_intensive — 包含 SEARCH ALL,表/数组查找为主
+6. call_based — 包含 CALL 语句,模块间调用为主
+7. mixed_complex — 同时具备多种复杂特征(决策点多且文件多且含 CALL/SEARCH 等)
+
+请按 JSON 格式输出分类结果,不要包含其他文字:
+
+```json
+{{
+  "category": "<混淆组类别>",
+  "subtype": "<子类别,如 nested_if / flat_evaluate / multi_file 等>",
+  "confidence": <0~1 置信度>,
+  "features": {{
+    "paragraph_count": {paragraph_count},
+    "decision_count": {decision_count},
+    "if_count": {if_count},
+    "evaluate_count": {evaluate_count},
+    "file_count": {file_count},
+    "has_search_all": {has_search_all},
+    "has_call": {has_call},
+    "has_break": {has_break},
+    "total_branches": {total_branches}
+  }},
+  "required_tests": <建议测试用例数,整数>,
+  "strategy_params": {{
+    "max_nesting_depth": <最大嵌套深度建议>,
+    "coverage_target": "branch" 或 "path",
+    "file_isolation": true 或 false,
+    "supplement_strategy": "incremental" 或 "full" 或 "skip"
+  }}
+}}
+```"""
+
+
+def classify_with_llm(structure: dict, llm) -> dict:
+    """Classify a program structure into a confusion group via the LLM.
+
+    Builds CONFUSION_PROMPT from the dict returned by extract_structure() and
+    sends it to the LLM.  If the call or the parsing of its reply fails, the
+    rule-based _fallback_classification() result is returned instead.
+
+    Args:
+        structure: dict returned by extract_structure(); the keys read here are
+            decision_points, total_paragraphs / paragraphs, file_count,
+            open_directions, has_search_all, has_call, has_break and
+            total_branches.
+        llm: LLMClient instance whose call method has the signature
+            llm.call([{"role":"system","content":"..."},
+                      {"role":"user","content":prompt}]) -> str
+
+    Returns:
+        dict: {
+            "category": str,
+            "subtype": str,
+            "confidence": float,
+            "features": dict,
+            "required_tests": int,
+            "strategy_params": dict
+        }
+    """
+    decision_points = structure.get("decision_points", [])
+    if_count = sum(1 for dp in decision_points if dp.get("kind") == "IF")
+    evaluate_count = sum(1 for dp in decision_points if dp.get("kind") == "EVALUATE")
+
+    paragraph_count = structure.get("total_paragraphs", len(structure.get("paragraphs", [])))
+    open_dirs = structure.get("open_directions", {})
+
+    # Lower-case so the booleans read as JSON literals inside the prompt.
+    has_search_all = str(structure.get("has_search_all", False)).lower()
+    has_call = str(structure.get("has_call", False)).lower()
+    has_break = str(structure.get("has_break", False)).lower()
+
+    prompt = CONFUSION_PROMPT.format(
+        paragraph_count=paragraph_count,
+        decision_count=len(decision_points),
+        if_count=if_count,
+        evaluate_count=evaluate_count,
+        file_count=structure.get("file_count", 0),
+        open_directions=json.dumps(open_dirs, ensure_ascii=False),
+        has_search_all=has_search_all,
+        has_call=has_call,
+        has_break=has_break,
+        total_branches=structure.get("total_branches", 0),
+    )
+
+    messages = [
+        {"role": "system", "content": "你是一个 COBOL 程序混淆组分类专家。只输出 JSON,不要输出解释。"},
+        {"role": "user", "content": prompt},
+    ]
+
+    try:
+        raw = llm.call(messages)
+        result = _parse_llm_response(raw)
+        logger.info(
+            "HINA classification: %s/%s (confidence=%.2f, tests=%s)",
+            result.get("category", "?"),
+            result.get("subtype", "?"),
+            result.get("confidence", 0.0),
+            result.get("required_tests", "?"),
+        )
+        return result
+    except Exception as e:  # deliberate best-effort boundary: fall back to rules
+        logger.warning("HINA LLM classification failed: %s", e)
+        return _fallback_classification(structure)
+
+
+def _parse_llm_response(raw: str) -> dict:
+    """Extract and parse the JSON object contained in an LLM response.
+
+    Handles JSON wrapped in a ```json ... ``` or bare ``` ... ``` fence.
+
+    Raises:
+        ValueError: the payload is not valid JSON (json.JSONDecodeError is a
+            ValueError) or is valid JSON but not an object.
+    """
+    text = raw.strip()
+
+    # Prefer a ```json fence, then fall back to a bare ``` fence.
+    for fence in ("```json", "```"):
+        if fence in text:
+            start = text.index(fence) + len(fence)
+            end = text.find("```", start)
+            text = text[start:end if end != -1 else len(text)].strip()
+            break
+
+    parsed = json.loads(text)
+    if not isinstance(parsed, dict):
+        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
+    return _validate_result(parsed)
+
+
+def _validate_result(parsed: dict) -> dict:
+    """Validate and normalise the classification returned by the LLM.
+
+    Missing keys are filled from defaults, "confidence" is clamped to
+    [0.0, 1.0] (non-numeric or NaN becomes 0.0), "required_tests" is forced to
+    an integer >= 1, and a partial "strategy_params" is completed from the
+    default strategy.
+    """
+    default_strategy = {
+        "max_nesting_depth": 1,
+        "coverage_target": "branch",
+        "file_isolation": False,
+        "supplement_strategy": "full",
+    }
+    defaults = {
+        "category": "unknown",
+        "subtype": "",
+        "confidence": 0.0,
+        "features": {},
+        "required_tests": 1,
+        "strategy_params": default_strategy,
+    }
+
+    result = {}
+    for key, default_value in defaults.items():
+        value = parsed.get(key, default_value)
+        if key == "confidence":
+            try:
+                value = float(value)
+            except (ValueError, TypeError):
+                value = 0.0
+            # min()/max() would turn NaN into 1.0, so reject it explicitly.
+            value = 0.0 if math.isnan(value) else max(0.0, min(1.0, value))
+        elif key == "required_tests":
+            try:
+                value = max(1, int(value))
+            except (ValueError, TypeError, OverflowError):
+                value = 1
+        elif key == "strategy_params":
+            # The LLM may omit keys or return a non-object; fill the gaps.
+            supplied = value if isinstance(value, dict) else {}
+            value = {**default_strategy, **supplied}
+        result[key] = value
+
+    return result
+
+
+def _fallback_classification(structure: dict) -> dict:
+    """Rule-based fallback classification used when the LLM call fails."""
+    decision_points = structure.get("decision_points", [])
+    if_count = sum(1 for dp in decision_points if dp.get("kind") == "IF")
+    evaluate_count = sum(1 for dp in decision_points if dp.get("kind") == "EVALUATE")
+    total_decisions = len(decision_points)
+    file_count = structure.get("file_count", 0) or 0
+    has_search_all = bool(structure.get("has_search_all", False))
+    has_call = bool(structure.get("has_call", False))
+    has_break = bool(structure.get("has_break", False))
+
+    # Rules in priority order, highest first.
+    if total_decisions == 0:
+        category, subtype = "simple_sequential", "no_branch"
+        required_tests = 1
+        strategy = {"max_nesting_depth": 0, "coverage_target": "branch",
+                    "file_isolation": False, "supplement_strategy": "skip"}
+    elif has_search_all:
+        category, subtype = "search_intensive", "table_lookup"
+        required_tests = max(total_decisions, 3)
+        strategy = {"max_nesting_depth": 3, "coverage_target": "path",
+                    "file_isolation": True, "supplement_strategy": "incremental"}
+    elif has_call:
+        category, subtype = "call_based", "external_call"
+        required_tests = max(total_decisions, 3)
+        strategy = {"max_nesting_depth": 2, "coverage_target": "branch",
+                    "file_isolation": False, "supplement_strategy": "full"}
+    elif evaluate_count > if_count and evaluate_count >= 2:
+        category, subtype = "evaluate_driven", "multi_way"
+        required_tests = total_decisions + 1
+        strategy = {"max_nesting_depth": evaluate_count, "coverage_target": "path",
+                    "file_isolation": False, "supplement_strategy": "full"}
+    elif file_count >= 2:
+        category, subtype = "data_file_centric", "multi_file"
+        required_tests = max(total_decisions, file_count * 2)
+        strategy = {"max_nesting_depth": 2, "coverage_target": "branch",
+                    "file_isolation": True, "supplement_strategy": "incremental"}
+    elif if_count >= 5 or total_decisions >= 8:
+        category, subtype = "condition_heavy", "nested_if"
+        required_tests = total_decisions + 2
+        strategy = {"max_nesting_depth": 4, "coverage_target": "path",
+                    "file_isolation": False, "supplement_strategy": "incremental"}
+    elif if_count >= 2:
+        category, subtype = "condition_heavy", "simple_if"
+        required_tests = total_decisions + 1
+        strategy = {"max_nesting_depth": 2, "coverage_target": "branch",
+                    "file_isolation": False, "supplement_strategy": "incremental"}
+    else:
+        category, subtype = "simple_sequential", "minimal"
+        required_tests = 1
+        strategy = {"max_nesting_depth": 0, "coverage_target": "branch",
+                    "file_isolation": False, "supplement_strategy": "skip"}
+
+    # Promote to mixed_complex when several complexity signals coincide.
+    complexity_flags = sum([
+        has_search_all,
+        has_call,
+        has_break,
+        file_count >= 2,
+        if_count >= 5,
+        evaluate_count >= 3,
+    ])
+    if complexity_flags >= 3:
+        category, subtype = "mixed_complex", f"{subtype}_plus"
+        required_tests = max(required_tests, 10)
+        strategy["max_nesting_depth"] = max(strategy.get("max_nesting_depth", 2), 5)
+        strategy["coverage_target"] = "path"
+        strategy["supplement_strategy"] = "full"
+
+    return {
+        "category": category,
+        "subtype": subtype,
+        "confidence": 0.6,
+        "features": {
+            "paragraph_count": structure.get("total_paragraphs", len(structure.get("paragraphs", []))),
+            "decision_count": total_decisions,
+            "if_count": if_count,
+            "evaluate_count": evaluate_count,
+            "file_count": file_count,
+            "has_search_all": has_search_all,
+            "has_call": has_call,
+            "has_break": has_break,
+            "total_branches": structure.get("total_branches", 0),
+        },
+        "required_tests": required_tests,
+        "strategy_params": strategy,
+    }
diff --git a/hina/strategy.py b/hina/strategy.py
new file mode 100644
index 0000000..e5b6351
--- /dev/null
+++ b/hina/strategy.py
@@ -0,0 +1,103 @@
+"""
+HINA strategy templates - required test items and boundary conditions per category.
+
+Task 2.2: required-item templates + supplement functions
+"""
+
+STRATEGY_TEMPLATES: dict[str, dict] = {
+    "マッチング": {
+        "required": [
+            "COM-N001", "COM-N002", "COM-A002", "COM-A003",
+            "MT-N001", "MT-N002", "MT-N004", "MT-N005", "MT-N006",
+        ],
+        "boundary": ["MT-B001", "MT-B002"],
+    },
+    "キーブレイク": {
+        "required": [
+            "COM-N001", "COM-A002",
+            "KB-N001", "KB-N004", "KB-N005", "KB-A001",
+        ],
+        "boundary": ["KB-B001", "KB-B002"],
+    },
+    "条件分岐": {
+        "required": [
+            "B-N001", "B-N003", "B-N006", "B-N009",
+        ],
+    },
+    "内部表検索": {
+        "required": [
+            "T-N001", "T-N002", "T-A001", "T-A002",
+        ],
+    },
+    "項目チェック": {
+        "required": [
+            "VF-N001", "VF-N002", "VF-N004", "VF-A001",
+        ],
+    },
+}
+
+
+def get_strategy(hina_type: str) -> dict:
+    """Return the strategy template for a HINA type.
+
+    Args:
+        hina_type: HINA program category name (e.g. "マッチング").
+
+    Returns:
+        dict: always carries "required" and "boundary" lists; both are empty
+            for an unknown type and "boundary" is empty when the template
+            omits it.  The lists are copies, so callers may mutate them
+            without altering STRATEGY_TEMPLATES.
+    """
+    template = STRATEGY_TEMPLATES.get(hina_type, {})
+    return {
+        "required": list(template.get("required", [])),
+        "boundary": list(template.get("boundary", [])),
+    }
+
+
+def _make_marker(code: str, prefix: str = "REQ") -> dict:
+    """Build one marker record for a required/boundary item code."""
+    return {
+        "id": f"{prefix}-{code}",
+        "coverage_targets": [code],
+        "fields": {},
+    }
+
+
+def supplement(base_tests: list[dict], hina_result: dict) -> list[dict]:
+    """Append the template's required-item markers for the HINA type.
+
+    Reads the category from ``hina_result["category"]``, looks up its strategy
+    template and appends every required (REQ-) and boundary (BND-) item as a
+    marker record after the existing tests.
+
+    Args:
+        base_tests: existing test data; not modified.
+        hina_result: HINA classification result; "category" is read and
+            defaults to "unknown" when absent.
+
+    Returns:
+        list[dict]: a new list holding base_tests followed by the markers.
+    """
+    template = get_strategy(hina_result.get("category", "unknown"))
+    result = list(base_tests)
+    result.extend(_make_marker(code) for code in template["required"])
+    result.extend(_make_marker(code, prefix="BND") for code in template["boundary"])
+    return result
+
+
+def supplement_only(base_tests: list[dict], hina_gaps: list[str]) -> list[dict]:
+    """Append marker records for an explicit list of missing item codes.
+
+    Unlike supplement(), the codes come from the caller rather than a template.
+
+    Args:
+        base_tests: existing test data; not modified.
+        hina_gaps: HINA required-item codes to add.
+
+    Returns:
+        list[dict]: a new list holding base_tests followed by one REQ- marker
+            per code.
+    """
+    return [*base_tests, *(_make_marker(code) for code in hina_gaps)]
diff --git a/orchestrator.py b/orchestrator.py
index 6e93708..87743e9 100644
--- a/orchestrator.py
+++ b/orchestrator.py
@@ -21,6 +21,9 @@ from config import Config
 from cobol_testgen import extract_structure, generate_data, incremental_supplement
 from cobol_testgen.coverage import check_coverage
 from hina.gate import check as gate_check
+from hina.classifier import compute_confidence
+from hina.hina_agent import classify_with_llm
+from hina.strategy import supplement as strategy_supplement
 
 logger = logging.getLogger(__name__)
 
@@ -45,10 +48,27 @@ def run_pipeline(cfg: Config, cpath: str, cbl: str, java: str, map_path: str) ->
     if vr.llm_cost > cfg.max_llm_cost:
         return _done(vr, t0, "BLOCKED", 3)
 
-    # ── Phase 1: cobol_testgen 结构提取 + 路径覆盖 + 质量门禁 ──
+    # ── Phase 1+2: cobol_testgen + HINA Agent + Strategy Agent + quality gate ──
     try:
         cobol_src_text = Path(cbl).read_text(encoding="utf-8")
         structure = extract_structure(cobol_src_text)
+
+        # HINA Agent type classification
+        hina_result = {}
+        try:
+            hina_result = compute_confidence(cobol_src_text, structure)
+            if hina_result.get("confidence", 0) < 0.7 and structure:
+                llm_hina = classify_with_llm(structure, llm)  # NOTE(review): confirm `llm` is defined in this scope
+                if llm_hina.get("confidence", 0) > hina_result.get("confidence", 0):
+                    hina_result = llm_hina
+            vr.hina_type = hina_result.get("category", "")
+            vr.hina_confidence = hina_result.get("confidence", 0.0)
+            vr.debug["hina_result"] = hina_result
+        except Exception as e:
+            vr.debug["hina_agent_error"] = str(e)
+            logger.warning(f"[orchestrator] HINA Agent 判定失败: {e}")
+
+        # cobol_testgen path enumeration + base data generation
         base_records = generate_data(cobol_src_text, structure)
         vr.debug["cobol_testgen_records"] = len(base_records)
         vr.debug["total_branches"] = structure.get("total_branches", 0)
@@ -57,11 +77,15 @@ def run_pipeline(cfg: Config, cpath: str, cbl: str, java: str, map_path: str) ->
         for i, rec in enumerate(base_records):
             base_testcases.append(TestCase(id=f"CTG-{i+1:04d}", fields=dict(rec)))
 
+        # Strategy Agent supplement
+        # supplement() already returns base tests + markers; re-adding base_testcases would duplicate them.
+        complete_tests = strategy_supplement(base_testcases, hina_result)
+
+        # Quality gate loop
         cov = check_coverage(structure, base_records)
         for attempt in range(cfg.max_quality_retries):
             gate_result = gate_check(
-                base_testcases, {},
-                cov,
+                complete_tests, hina_result, cov,
                 decision_threshold=cfg.quality_gate_decision_threshold,
                 paragraph_threshold=cfg.quality_gate_paragraph_threshold,
             )