diff --git a/hina/pipeline/pipeline.py b/hina/pipeline/pipeline.py index 590054f..99a1056 100644 --- a/hina/pipeline/pipeline.py +++ b/hina/pipeline/pipeline.py @@ -157,11 +157,13 @@ def _path_rule_engine( # 2. 运行所有混淆组解析器 resolved_types: dict[str, str] = {} + resolved_confidences: dict[str, float] = {} for pair_name in _PAIR_NAMES: try: result = resolve_confusion_pair(features, pair_name) if result["resolved_type"] != "unknown" and result["confidence"] > 0: resolved_types[pair_name] = result["resolved_type"] + resolved_confidences[pair_name] = result["confidence"] except Exception as e: logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e) @@ -192,16 +194,14 @@ def _path_rule_engine( final_base_confidence = keyword_info["confidence"] # 如果规则引擎有更高置信度的结果, 则采纳 + # 使用第一轮缓存的结果(M1: 消除冗余重复调用) best_resolved_type = None best_resolved_conf = 0.0 for pair_name, rtype in resolved_types.items(): - try: - rr = resolve_confusion_pair(features, pair_name) - if rr["confidence"] > best_resolved_conf: - best_resolved_conf = rr["confidence"] - best_resolved_type = rtype - except Exception: - continue + cached_conf = resolved_confidences.get(pair_name, 0.0) + if cached_conf > best_resolved_conf: + best_resolved_conf = cached_conf + best_resolved_type = rtype if best_resolved_type and best_resolved_conf > final_base_confidence: final_category = best_resolved_type @@ -272,7 +272,7 @@ def _path_llm_assisted( except Exception: continue - # 3. 矛盾检测 + # 3. 矛盾检测与解决 (M2: 消除硬编码 resolved_count=0) resolved_types: dict[str, str] = {} for pair_name in _PAIR_NAMES: try: @@ -285,6 +285,19 @@ def _path_llm_assisted( features["resolved_types"] = resolved_types contradictions = detect_contradictions(features) + resolution_map: dict[str, Any] = { + "resolved_count": 0, + "total_count": len(contradictions), + } + for c in contradictions: + try: + winner = resolve_contradiction(features, c) + if winner: + resolution_map[c.get("name", "unknown")] = winner + resolution_map["resolved_count"] += 1 + except Exception as e: + logger.debug("[pipeline] Path C 矛盾解决异常: %s", e) + # 4. 确信度计算 keyword_result_v2 = _build_keyword_result_for_v2(keyword_info) keyword_result_v2["base_confidence"] = validated_confidence @@ -295,7 +308,7 @@ def _path_llm_assisted( keyword_result=keyword_result_v2, structure_features=structure_features, contradictions=contradictions, - resolution={"resolved_count": 0, "total_count": len(contradictions)}, + resolution=resolution_map, ) return { diff --git a/hina/rule_engine/contradiction.py b/hina/rule_engine/contradiction.py index f8deb1d..842ffa4 100644 --- a/hina/rule_engine/contradiction.py +++ b/hina/rule_engine/contradiction.py @@ -45,6 +45,16 @@ CONTRADICTION_PAIRS: list[dict[str, str]] = [ "type_a": "DIVIDE_50", "type_b": "DIVIDE_100", }, + { + "name": "division_50_25_100", + "type_a": "DIVIDE_50", + "type_b": "DIVIDE_25", + }, + { + "name": "division_50_25_100", + "type_a": "DIVIDE_100", + "type_b": "DIVIDE_25", + }, { "name": "mn_output_mode", "type_a": "M:N", diff --git a/tests/hina/test_rule_engine.py b/tests/hina/test_rule_engine.py index db9c09f..3173d67 100644 --- a/tests/hina/test_rule_engine.py +++ b/tests/hina/test_rule_engine.py @@ -327,15 +327,26 @@ def test_resolve_contradiction_csv(): # ═══════════════════════════════════════════════════════════════════════════ def test_contradiction_pairs_defined(): - """CONTRADICTION_PAIRS 包含所有 8 个混淆对""" - assert len(CONTRADICTION_PAIRS) == 8 - names = {p["name"] for p in CONTRADICTION_PAIRS} - expected = { + """CONTRADICTION_PAIRS 包含所有混淆对,DIVIDE 全部 3 种变体""" + assert len(CONTRADICTION_PAIRS) >= 8 + pairs_by_name: dict[str, list[dict]] = {} + for p in CONTRADICTION_PAIRS: + pairs_by_name.setdefault(p["name"], []).append(p) + + expected_names = { "matching_vs_keybreak", "dedup_vs_nodedup", "validation_vs_keybreak", "csv_merge_vs_split", "simple_vs_two_stage", "pure_vs_mixed", "division_50_25_100", "mn_output_mode", } - assert names == expected + assert set(pairs_by_name.keys()) >= expected_names + + # division 应有 3 个矛盾对 (50-100, 50-25, 100-25) 覆盖所有变体 + div_pairs = pairs_by_name.get("division_50_25_100", []) + assert len(div_pairs) == 3, f"DIVIDE 应覆盖全部 3 组变体,当前 {len(div_pairs)} 组" + div_types = {(p["type_a"], p["type_b"]) for p in div_pairs} + assert ("DIVIDE_50", "DIVIDE_100") in div_types + assert ("DIVIDE_50", "DIVIDE_25") in div_types + assert ("DIVIDE_100", "DIVIDE_25") in div_types # ═══════════════════════════════════════════════════════════════════════════