"""Tests for HINA rule engine: confusion groups, contradiction, backtrack.""" from __future__ import annotations import sys import os import json sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) from hina.rule_engine.confusion_groups import ( resolve_matching_vs_keybreak, resolve_dedup_vs_nodedup, resolve_validation_vs_keybreak, resolve_csv_merge_vs_split, resolve_simple_vs_two_stage, resolve_pure_vs_mixed, resolve_division_50_25_100, resolve_mn_output_mode, resolve_confusion_pair, ) from hina.rule_engine.contradiction import ( CONTRADICTION_PAIRS, detect_contradictions, resolve_contradiction, ) from hina.rule_engine.backtrack import BacktrackResolver # ═══════════════════════════════════════════════════════════════════════════ # 1. confusion_groups — matching_vs_keybreak # ═══════════════════════════════════════════════════════════════════════════ def test_matching_vs_keybreak_matching(): """3路 IF + SELECT>=2 → マッチング""" features = { "if_types": {"total": 5, "comparison": 3, "equality": 1, "compound": 1, "nested_depth": 2}, "select_files": {"file1": {"organization": "SEQUENTIAL"}, "file2": {"organization": "SEQUENTIAL"}}, "variable_patterns": {"has_prev_key": False, "has_accumulator": False, "has_error_field": False}, } result = resolve_matching_vs_keybreak(features) assert result["resolved_type"] == "マッチング" assert result["confidence"] >= 0.75 assert len(result["evidence"]) > 0 def test_matching_vs_keybreak_keybreak(): """2路 IF + WS-PREV-KEY + 累加器 → キーブレイク""" features = { "if_types": {"total": 2, "comparison": 0, "equality": 2, "compound": 0, "nested_depth": 1}, "select_files": {"file1": {"organization": "SEQUENTIAL"}}, "variable_patterns": {"has_prev_key": True, "has_accumulator": True, "has_error_field": False}, } result = resolve_matching_vs_keybreak(features) assert result["resolved_type"] == "キーブレイク" assert result["confidence"] >= 0.70 assert len(result["evidence"]) > 0 def test_matching_vs_keybreak_unknown(): """特征不足 → 
unknown""" features = { "if_types": {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0}, "select_files": {}, "variable_patterns": {"has_prev_key": False, "has_accumulator": False, "has_error_field": False}, } result = resolve_matching_vs_keybreak(features) assert result["resolved_type"] == "unknown" assert result["confidence"] == 0.0 # ═══════════════════════════════════════════════════════════════════════════ # 2. confusion_groups — dedup_vs_nodedup # ═══════════════════════════════════════════════════════════════════════════ def test_dedup_vs_nodedup_dedup(): """WS-PREV-KEY 存在 → 含重复""" features = {"variable_patterns": {"has_prev_key": True, "has_accumulator": False, "has_error_field": False}} result = resolve_dedup_vs_nodedup(features) assert result["resolved_type"] == "項目チェック(重複含む)" assert result["confidence"] >= 0.85 def test_dedup_vs_nodedup_nodedup(): """WS-PREV-KEY 不存在 → 不含重复(低确信度:无 WS-PREV-KEY 不代表一定是项目检查)""" features = {"variable_patterns": {"has_prev_key": False, "has_accumulator": False, "has_error_field": False}} result = resolve_dedup_vs_nodedup(features) assert result["resolved_type"] == "項目チェック(重複含まず)" assert result["confidence"] >= 0.30 # ═══════════════════════════════════════════════════════════════════════════ # 3. 
# ═══════════════════════════════════════════════════════════════════════════
# 3. confusion_groups — validation_vs_keybreak
# ═══════════════════════════════════════════════════════════════════════════


def test_validation_vs_keybreak_validation():
    """A WS-ERR* error indicator resolves to 編集処理(校验), i.e. validation."""
    # NOTE(review): this fixture uses the key "has_error_flag", whereas the two
    # sibling tests below (and the fixtures in sections 1-2) use
    # "has_error_field". Confirm which key resolve_validation_vs_keybreak reads
    # and make the fixtures agree.
    patterns = {"has_error_flag": True, "has_counter": False, "has_prev_key": False}
    outcome = resolve_validation_vs_keybreak({"variable_patterns": patterns})
    assert outcome["resolved_type"] == "編集処理(校验)"
    assert outcome["confidence"] >= 0.70


def test_validation_vs_keybreak_keybreak():
    """A WS-*CNT counter alone resolves to キーブレイク at low confidence.

    Counters are a generic pattern, so only a weak signal is expected.
    """
    patterns = {"has_error_field": False, "has_counter": True, "has_prev_key": False}
    outcome = resolve_validation_vs_keybreak({"variable_patterns": patterns})
    assert outcome["resolved_type"] == "キーブレイク"
    assert outcome["confidence"] >= 0.40


def test_validation_vs_keybreak_unknown():
    """Neither an error field nor a counter yields unknown."""
    patterns = {"has_error_field": False, "has_counter": False, "has_prev_key": False}
    outcome = resolve_validation_vs_keybreak({"variable_patterns": patterns})
    assert outcome["resolved_type"] == "unknown"
# ═══════════════════════════════════════════════════════════════════════════
# 4. confusion_groups — csv_merge_vs_split
# ═══════════════════════════════════════════════════════════════════════════


def test_csv_merge_vs_split_merge():
    """STRING with comma delimiting resolves to CSV合并 (CSV merge)."""
    outcome = resolve_csv_merge_vs_split(
        {"has_string": True, "has_csv_merge": True, "has_inspect": False}
    )
    assert outcome["resolved_type"] == "CSV合并"
    assert outcome["confidence"] >= 0.70


def test_csv_merge_vs_split_split():
    """INSPECT REPLACING with commas resolves to CSV拆分 (CSV split)."""
    outcome = resolve_csv_merge_vs_split(
        {"has_string": False, "has_csv_split": True, "has_inspect": True}
    )
    assert outcome["resolved_type"] == "CSV拆分"
    assert outcome["confidence"] >= 0.70


def test_csv_merge_vs_split_both():
    """When merge and split evidence coexist, merge evidence takes precedence."""
    outcome = resolve_csv_merge_vs_split(
        {"has_string": True, "has_csv_merge": True, "has_inspect": True, "has_csv_split": True}
    )
    assert outcome["resolved_type"] == "CSV合并"


def test_csv_merge_vs_split_unknown():
    """Neither STRING nor INSPECT present yields unknown."""
    outcome = resolve_csv_merge_vs_split({"has_string": False, "has_inspect": False})
    assert outcome["resolved_type"] == "unknown"


# ═══════════════════════════════════════════════════════════════════════════
# 5. confusion_groups — simple_vs_two_stage
# ═══════════════════════════════════════════════════════════════════════════


def test_simple_vs_two_stage_two_stage():
    """An OPEN → CLOSE → re-OPEN sequence resolves to 二段階マッチング (two-stage matching)."""
    outcome = resolve_simple_vs_two_stage({"open_pattern": "open-close-open"})
    assert outcome["resolved_type"] == "二段階マッチング"
    assert outcome["confidence"] >= 0.85


def test_simple_vs_two_stage_simple():
    """A sequential OPEN with no matching evidence yields unknown at zero confidence.

    Since 2.2 the resolver no longer falls back to 単純マッチング without evidence.
    """
    outcome = resolve_simple_vs_two_stage({"open_pattern": "sequential", "file_count": 0})
    assert outcome["resolved_type"] == "unknown"
    assert outcome["confidence"] == 0.0
# ═══════════════════════════════════════════════════════════════════════════
# 6. confusion_groups — pure_vs_mixed
# ═══════════════════════════════════════════════════════════════════════════


def test_pure_vs_mixed_mixed():
    """has_switch + has_counter + IF >= 3 resolves to 混合マッチング (mixed matching)."""
    features = {"variable_patterns": {"has_switch": True, "has_counter": True}, "if_types": {"total": 3}}
    result = resolve_pure_vs_mixed(features)
    assert result["resolved_type"] == "混合マッチング"
    assert result["confidence"] >= 0.70


def test_pure_vs_mixed_pure():
    """No mixed-matching features yields unknown (cannot be decided statically)."""
    features = {"variable_patterns": {"has_switch": False, "has_counter": False}, "if_types": {"total": 1}}
    result = resolve_pure_vs_mixed(features)
    assert result["resolved_type"] == "unknown"


# ═══════════════════════════════════════════════════════════════════════════
# 7. confusion_groups — division_50_25_100
# ═══════════════════════════════════════════════════════════════════════════


def test_division_50():
    """A DIVIDE constant of 50 resolves to DIVIDE_50."""
    result = resolve_division_50_25_100({"divide_constants": [50]})
    assert result["resolved_type"] == "DIVIDE_50"
    assert result["confidence"] >= 0.90


def test_division_25():
    """A DIVIDE constant of 25 resolves to DIVIDE_25.

    The 25 variant was previously untested even though test_contradiction_pairs_defined
    requires DIVIDE_25 to participate in the division contradiction pairs.
    """
    result = resolve_division_50_25_100({"divide_constants": [25]})
    assert result["resolved_type"] == "DIVIDE_25"


def test_division_100():
    """A DIVIDE constant of 100 resolves to DIVIDE_100."""
    result = resolve_division_50_25_100({"divide_constants": [100]})
    assert result["resolved_type"] == "DIVIDE_100"
    assert result["confidence"] >= 0.90


def test_division_unknown():
    """Constants outside {50, 25, 100} yield unknown at zero confidence."""
    result = resolve_division_50_25_100({"divide_constants": [10, 20]})
    assert result["resolved_type"] == "unknown"
    assert result["confidence"] == 0.0


def test_division_empty():
    """An empty constant list yields unknown."""
    result = resolve_division_50_25_100({"divide_constants": []})
    assert result["resolved_type"] == "unknown"
# ═══════════════════════════════════════════════════════════════════════════
# 8. confusion_groups — mn_output_mode
# ═══════════════════════════════════════════════════════════════════════════


def test_mn_output_mode_known():
    """Enough SELECTed files plus >= 3 branches resolves to M:N."""
    sample = {"select_files": {"a": {}, "b": {}, "c": {}}, "total_branches": 3}
    verdict = resolve_mn_output_mode(sample)
    assert verdict["resolved_type"] == "M:N"
    assert verdict["confidence"] >= 0.60


def test_mn_output_mode_unknown():
    """No hint and fewer than 3 files yields unknown (needs data verification)."""
    sample = {"has_mn_output_hint": False, "select_files": {"a": {}, "b": {}}}
    verdict = resolve_mn_output_mode(sample)
    assert verdict["resolved_type"] == "unknown"
    assert verdict["confidence"] == 0.0


def test_mn_output_mode_many_files():
    """3+ files, IF branches and key evidence resolve to M:N."""
    sample = {
        "has_mn_output_hint": False,
        "select_files": {"a": {}, "b": {}, "c": {}},
        "if_types": {"total": 2, "comparison": 1, "equality": 1, "compound": 0, "nested_depth": 0},
        "variable_patterns": {"has_prev_key": True, "has_accumulator": False},
    }
    verdict = resolve_mn_output_mode(sample)
    assert verdict["resolved_type"] == "M:N"
    assert verdict["confidence"] >= 0.55


# ═══════════════════════════════════════════════════════════════════════════
# 9. resolve_confusion_pair — dispatcher
# ═══════════════════════════════════════════════════════════════════════════


def test_resolve_confusion_pair_dispatch():
    """The dispatcher routes known pair names and reports unknown for others."""
    sample = {
        "variable_patterns": {"has_prev_key": True, "has_accumulator": False, "has_error_field": False},
    }

    dispatched = resolve_confusion_pair(sample, "dedup_vs_nodedup")
    assert dispatched["resolved_type"] == "項目チェック(重複含む)"

    # An unregistered pair name must not raise; it reports unknown with an explanation.
    missing = resolve_confusion_pair(sample, "nonexistent_pair")
    assert missing["resolved_type"] == "unknown"
    assert "未知混淆对名称" in missing["evidence"][0]
# ═══════════════════════════════════════════════════════════════════════════
# 10. contradiction — detect_contradictions
# ═══════════════════════════════════════════════════════════════════════════


def test_detect_contradictions_empty():
    """An empty resolved_types mapping yields no contradictions."""
    assert detect_contradictions({"resolved_types": {}}) == []


def test_detect_contradictions_no_contradiction():
    """A single resolved type cannot contradict anything."""
    single = {"resolved_types": {"pair_1": "マッチング"}}
    assert detect_contradictions(single) == []


def test_detect_contradictions_found():
    """マッチング and キーブレイク resolved together are reported as a contradiction."""
    conflicting = {
        "resolved_types": {
            "pair_1": "マッチング",
            "pair_2": "キーブレイク",
        }
    }
    found = detect_contradictions(conflicting)
    assert len(found) >= 1
    assert any(
        entry["type_a"] == "マッチング" and entry["type_b"] == "キーブレイク"
        for entry in found
    )


# ═══════════════════════════════════════════════════════════════════════════
# 11. contradiction — resolve_contradiction
# ═══════════════════════════════════════════════════════════════════════════


def test_resolve_contradiction_priority():
    """マッチング (prio=10) wins over キーブレイク (prio=9)."""
    pair = {"name": "matching_vs_keybreak", "type_a": "マッチング", "type_b": "キーブレイク"}
    assert resolve_contradiction({}, pair) == "マッチング"


def test_resolve_contradiction_csv():
    """CSV合并 and CSV拆分 share priority 6, so the winner comes from re-judging the features."""
    pair = {"name": "csv_merge_vs_split", "type_a": "CSV合并", "type_b": "CSV拆分"}
    evidence = {"has_string": True, "has_inspect": False}
    assert resolve_contradiction(evidence, pair) == "CSV合并"
# ═══════════════════════════════════════════════════════════════════════════
# 12. contradiction — CONTRADICTION_PAIRS constant
# ═══════════════════════════════════════════════════════════════════════════


def test_contradiction_pairs_defined():
    """CONTRADICTION_PAIRS covers every confusion pair and all three DIVIDE variants."""
    assert len(CONTRADICTION_PAIRS) >= 8

    # Group entries by pair name; one name may own several (type_a, type_b) entries.
    pairs_by_name: dict[str, list[dict]] = {}
    for p in CONTRADICTION_PAIRS:
        pairs_by_name.setdefault(p["name"], []).append(p)

    expected_names = {
        "matching_vs_keybreak",
        "dedup_vs_nodedup",
        "validation_vs_keybreak",
        "csv_merge_vs_split",
        "simple_vs_two_stage",
        "pure_vs_mixed",
        "division_50_25_100",
        "mn_output_mode",
    }
    assert set(pairs_by_name.keys()) >= expected_names

    # Division needs 3 contradiction pairs (50-100, 50-25, 100-25) so that
    # every two-way combination of the variants is covered.
    div_pairs = pairs_by_name.get("division_50_25_100", [])
    assert len(div_pairs) == 3, f"DIVIDE 应覆盖全部 3 组变体,当前 {len(div_pairs)} 组"
    div_types = {(p["type_a"], p["type_b"]) for p in div_pairs}
    assert ("DIVIDE_50", "DIVIDE_100") in div_types
    assert ("DIVIDE_50", "DIVIDE_25") in div_types
    assert ("DIVIDE_100", "DIVIDE_25") in div_types
# ═══════════════════════════════════════════════════════════════════════════
# 13. backtrack — BacktrackResolver
# ═══════════════════════════════════════════════════════════════════════════


def test_backtrack_no_contradiction():
    """No contradiction: resolved without any backtracking round (backtrack_resolved=True)."""

    def extractor(src: str) -> dict:
        return {"resolved_types": {"pair_1": "マッチング"}, "if_types": {}}

    resolver = BacktrackResolver(extractor)
    result = resolver.resolve("some source", {"resolved_types": {"pair_1": "マッチング"}})
    assert result["backtrack_resolved"] is True
    assert result["backtrack_rounds"] == 0


def test_backtrack_with_contradiction():
    """A contradiction is resolved and the number of rounds is recorded."""

    def extractor(src: str) -> dict:
        return {"resolved_types": {"pair_1": "マッチング"}, "if_types": {}}

    features = {
        "resolved_types": {
            "pair_1": "マッチング",
            "pair_2": "キーブレイク",
        }
    }
    resolver = BacktrackResolver(extractor)
    result = resolver.resolve("some source", features)
    # Core assertion: the contradiction was resolved (resolved_* keys appear).
    resolved_keys = [k for k in result if k.startswith("resolved_")]
    assert len(resolved_keys) >= 1
    assert result["backtrack_rounds"] >= 1


def test_backtrack_max_rounds_degraded():
    """A persistent contradiction degrades the result once max_rounds is exhausted."""
    round_count = 0

    def extractor(src: str) -> dict:
        nonlocal round_count
        round_count += 1
        # Always hand back features that still contain the contradiction.
        return {
            "resolved_types": {
                "pair_1": "マッチング",
                "pair_2": "キーブレイク",
            }
        }

    features = {
        "resolved_types": {
            "pair_1": "マッチング",
            "pair_2": "キーブレイク",
        }
    }
    resolver = BacktrackResolver(extractor)
    resolver.max_rounds = 2
    result = resolver.resolve("some source", features)
    assert result["backtrack_degraded"] is True
    # At least one backtracking round should have been attempted ...
    assert result["backtrack_rounds"] >= 1
    # ... and each such round re-runs the extractor. Previously this counter
    # was maintained but never checked.
    assert round_count >= 1


def test_backtrack_extract_error():
    """An extractor that raises sets backtrack_extract_error."""

    def extractor(src: str) -> dict:
        raise ValueError("extraction failed")

    features = {
        "resolved_types": {
            "pair_1": "マッチング",
            "pair_2": "キーブレイク",
        }
    }
    resolver = BacktrackResolver(extractor)
    result = resolver.resolve("some source", features)
    assert result.get("backtrack_extract_error") is True


# Distinct name on purpose: this test used to be a second
# `test_backtrack_no_contradiction`, which rebound the module attribute so
# pytest silently collected only one of the two tests.
def test_backtrack_empty_features_returns_directly():
    """With nothing to resolve, resolve() returns a dict directly (no timeout)."""

    def fast_extractor(src: str) -> dict:
        return {"resolved_types": {}}

    resolver = BacktrackResolver(fast_extractor)
    result = resolver.resolve("source", {"resolved_types": {}})
    assert isinstance(result, dict)


# ═══════════════════════════════════════════════════════════════════════════
# 14. Integration — full round-trip via resolve_confusion_pair
# ═══════════════════════════════════════════════════════════════════════════


def test_integration_matching_roundtrip():
    """Full flow: resolve_confusion_pair dispatching to resolve_matching_vs_keybreak."""
    features = {
        "if_types": {"total": 5, "comparison": 3, "equality": 1, "compound": 1, "nested_depth": 2},
        "select_files": {"f1": {}, "f2": {}},
        "variable_patterns": {"has_prev_key": False, "has_accumulator": False, "has_error_field": False},
    }
    result = resolve_confusion_pair(features, "matching_vs_keybreak")
    assert result["resolved_type"] in ("マッチング", "キーブレイク", "unknown")
    assert "confidence" in result
    assert "evidence" in result


def test_integration_contradiction_resolve_cycle():
    """Closed loop: detect a contradiction, then resolve it to one of its two types."""
    features = {
        "resolved_types": {
            "from_keyword": "マッチング",
            "from_llm": "キーブレイク",
        }
    }
    contradictions = detect_contradictions(features)
    assert len(contradictions) >= 1
    winner = resolve_contradiction(features, contradictions[0])
    assert winner in ("マッチング", "キーブレイク")