fix: resolve 3 MEDIUM code review findings

M1: Cache confusion-pair confidences in Path B (eliminate redundant
    resolve_confusion_pair re-calls in _path_rule_engine)
M2: Resolve contradictions in Path C instead of hardcoding
    resolved_count=0 in _path_llm_assisted
M4: Add DIVIDE_25 to contradiction pair coverage (50-25, 100-25)
    and update test_contradiction_pairs_defined to verify all 3 variants
This commit is contained in:
NB-076
2026-06-21 11:25:59 +08:00
parent bc1d56d1a4
commit a6c454692a
3 changed files with 48 additions and 14 deletions
+22 -9
View File
@@ -157,11 +157,13 @@ def _path_rule_engine(
# 2. 运行所有混淆组解析器 # 2. 运行所有混淆组解析器
resolved_types: dict[str, str] = {} resolved_types: dict[str, str] = {}
resolved_confidences: dict[str, float] = {}
for pair_name in _PAIR_NAMES: for pair_name in _PAIR_NAMES:
try: try:
result = resolve_confusion_pair(features, pair_name) result = resolve_confusion_pair(features, pair_name)
if result["resolved_type"] != "unknown" and result["confidence"] > 0: if result["resolved_type"] != "unknown" and result["confidence"] > 0:
resolved_types[pair_name] = result["resolved_type"] resolved_types[pair_name] = result["resolved_type"]
resolved_confidences[pair_name] = result["confidence"]
except Exception as e: except Exception as e:
logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e) logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)
@@ -192,16 +194,14 @@ def _path_rule_engine(
final_base_confidence = keyword_info["confidence"] final_base_confidence = keyword_info["confidence"]
# 如果规则引擎有更高置信度的结果, 则采纳 # 如果规则引擎有更高置信度的结果, 则采纳
# 使用第一轮缓存的结果(M1: 消除冗余重复调用)
best_resolved_type = None best_resolved_type = None
best_resolved_conf = 0.0 best_resolved_conf = 0.0
for pair_name, rtype in resolved_types.items(): for pair_name, rtype in resolved_types.items():
try: cached_conf = resolved_confidences.get(pair_name, 0.0)
rr = resolve_confusion_pair(features, pair_name) if cached_conf > best_resolved_conf:
if rr["confidence"] > best_resolved_conf: best_resolved_conf = cached_conf
best_resolved_conf = rr["confidence"] best_resolved_type = rtype
best_resolved_type = rtype
except Exception:
continue
if best_resolved_type and best_resolved_conf > final_base_confidence: if best_resolved_type and best_resolved_conf > final_base_confidence:
final_category = best_resolved_type final_category = best_resolved_type
@@ -272,7 +272,7 @@ def _path_llm_assisted(
except Exception: except Exception:
continue continue
# 3. 矛盾检测 # 3. 矛盾检测与解决 (M2: 消除硬编码 resolved_count=0)
resolved_types: dict[str, str] = {} resolved_types: dict[str, str] = {}
for pair_name in _PAIR_NAMES: for pair_name in _PAIR_NAMES:
try: try:
@@ -285,6 +285,19 @@ def _path_llm_assisted(
features["resolved_types"] = resolved_types features["resolved_types"] = resolved_types
contradictions = detect_contradictions(features) contradictions = detect_contradictions(features)
resolution_map: dict[str, Any] = {
"resolved_count": 0,
"total_count": len(contradictions),
}
for c in contradictions:
try:
winner = resolve_contradiction(features, c)
if winner:
resolution_map[c.get("name", "unknown")] = winner
resolution_map["resolved_count"] += 1
except Exception as e:
logger.debug("[pipeline] Path C 矛盾解决异常: %s", e)
# 4. 确信度计算 # 4. 确信度计算
keyword_result_v2 = _build_keyword_result_for_v2(keyword_info) keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
keyword_result_v2["base_confidence"] = validated_confidence keyword_result_v2["base_confidence"] = validated_confidence
@@ -295,7 +308,7 @@ def _path_llm_assisted(
keyword_result=keyword_result_v2, keyword_result=keyword_result_v2,
structure_features=structure_features, structure_features=structure_features,
contradictions=contradictions, contradictions=contradictions,
resolution={"resolved_count": 0, "total_count": len(contradictions)}, resolution=resolution_map,
) )
return { return {
+10
View File
@@ -45,6 +45,16 @@ CONTRADICTION_PAIRS: list[dict[str, str]] = [
"type_a": "DIVIDE_50", "type_a": "DIVIDE_50",
"type_b": "DIVIDE_100", "type_b": "DIVIDE_100",
}, },
{
"name": "division_50_25_100",
"type_a": "DIVIDE_50",
"type_b": "DIVIDE_25",
},
{
"name": "division_50_25_100",
"type_a": "DIVIDE_100",
"type_b": "DIVIDE_25",
},
{ {
"name": "mn_output_mode", "name": "mn_output_mode",
"type_a": "M:N", "type_a": "M:N",
+16 -5
View File
@@ -327,15 +327,26 @@ def test_resolve_contradiction_csv():
# ═══════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════
def test_contradiction_pairs_defined(): def test_contradiction_pairs_defined():
"""CONTRADICTION_PAIRS 包含所有 8 个混淆对""" """CONTRADICTION_PAIRS 包含所有混淆对DIVIDE 全部 3 种变体"""
assert len(CONTRADICTION_PAIRS) == 8 assert len(CONTRADICTION_PAIRS) >= 8
names = {p["name"] for p in CONTRADICTION_PAIRS} pairs_by_name: dict[str, list[dict]] = {}
expected = { for p in CONTRADICTION_PAIRS:
pairs_by_name.setdefault(p["name"], []).append(p)
expected_names = {
"matching_vs_keybreak", "dedup_vs_nodedup", "validation_vs_keybreak", "matching_vs_keybreak", "dedup_vs_nodedup", "validation_vs_keybreak",
"csv_merge_vs_split", "simple_vs_two_stage", "pure_vs_mixed", "csv_merge_vs_split", "simple_vs_two_stage", "pure_vs_mixed",
"division_50_25_100", "mn_output_mode", "division_50_25_100", "mn_output_mode",
} }
assert names == expected assert set(pairs_by_name.keys()) >= expected_names
# division 应有 3 个矛盾对 (50-100, 50-25, 100-25) 覆盖所有变体
div_pairs = pairs_by_name.get("division_50_25_100", [])
assert len(div_pairs) == 3, f"DIVIDE 应覆盖全部 3 组变体,当前 {len(div_pairs)}"
div_types = {(p["type_a"], p["type_b"]) for p in div_pairs}
assert ("DIVIDE_50", "DIVIDE_100") in div_types
assert ("DIVIDE_50", "DIVIDE_25") in div_types
assert ("DIVIDE_100", "DIVIDE_25") in div_types
# ═══════════════════════════════════════════════════════════════════════════ # ═══════════════════════════════════════════════════════════════════════════