fix: resolve 3 MEDIUM code review findings
M1: Cache confusion-pair confidences in Path B so that _path_rule_engine
reuses them instead of calling resolve_confusion_pair a second time per pair
M2: Resolve each detected contradiction in Path C (_path_llm_assisted) via
resolve_contradiction and count the successes, instead of hardcoding
resolved_count=0
M4: Add DIVIDE_25 to contradiction pair coverage (new pairs 50-25 and 100-25)
and update test_contradiction_pairs_defined to verify all 3 DIVIDE pairings
(50-100, 50-25, 100-25)
This commit is contained in:
@@ -157,11 +157,13 @@ def _path_rule_engine(
|
|||||||
|
|
||||||
# 2. 运行所有混淆组解析器
|
# 2. 运行所有混淆组解析器
|
||||||
resolved_types: dict[str, str] = {}
|
resolved_types: dict[str, str] = {}
|
||||||
|
resolved_confidences: dict[str, float] = {}
|
||||||
for pair_name in _PAIR_NAMES:
|
for pair_name in _PAIR_NAMES:
|
||||||
try:
|
try:
|
||||||
result = resolve_confusion_pair(features, pair_name)
|
result = resolve_confusion_pair(features, pair_name)
|
||||||
if result["resolved_type"] != "unknown" and result["confidence"] > 0:
|
if result["resolved_type"] != "unknown" and result["confidence"] > 0:
|
||||||
resolved_types[pair_name] = result["resolved_type"]
|
resolved_types[pair_name] = result["resolved_type"]
|
||||||
|
resolved_confidences[pair_name] = result["confidence"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)
|
logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)
|
||||||
|
|
||||||
@@ -192,16 +194,14 @@ def _path_rule_engine(
|
|||||||
final_base_confidence = keyword_info["confidence"]
|
final_base_confidence = keyword_info["confidence"]
|
||||||
|
|
||||||
# 如果规则引擎有更高置信度的结果, 则采纳
|
# 如果规则引擎有更高置信度的结果, 则采纳
|
||||||
|
# 使用第一轮缓存的结果(M1: 消除冗余重复调用)
|
||||||
best_resolved_type = None
|
best_resolved_type = None
|
||||||
best_resolved_conf = 0.0
|
best_resolved_conf = 0.0
|
||||||
for pair_name, rtype in resolved_types.items():
|
for pair_name, rtype in resolved_types.items():
|
||||||
try:
|
cached_conf = resolved_confidences.get(pair_name, 0.0)
|
||||||
rr = resolve_confusion_pair(features, pair_name)
|
if cached_conf > best_resolved_conf:
|
||||||
if rr["confidence"] > best_resolved_conf:
|
best_resolved_conf = cached_conf
|
||||||
best_resolved_conf = rr["confidence"]
|
|
||||||
best_resolved_type = rtype
|
best_resolved_type = rtype
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if best_resolved_type and best_resolved_conf > final_base_confidence:
|
if best_resolved_type and best_resolved_conf > final_base_confidence:
|
||||||
final_category = best_resolved_type
|
final_category = best_resolved_type
|
||||||
@@ -272,7 +272,7 @@ def _path_llm_assisted(
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 3. 矛盾检测
|
# 3. 矛盾检测与解决 (M2: 消除硬编码 resolved_count=0)
|
||||||
resolved_types: dict[str, str] = {}
|
resolved_types: dict[str, str] = {}
|
||||||
for pair_name in _PAIR_NAMES:
|
for pair_name in _PAIR_NAMES:
|
||||||
try:
|
try:
|
||||||
@@ -285,6 +285,19 @@ def _path_llm_assisted(
|
|||||||
features["resolved_types"] = resolved_types
|
features["resolved_types"] = resolved_types
|
||||||
contradictions = detect_contradictions(features)
|
contradictions = detect_contradictions(features)
|
||||||
|
|
||||||
|
resolution_map: dict[str, Any] = {
|
||||||
|
"resolved_count": 0,
|
||||||
|
"total_count": len(contradictions),
|
||||||
|
}
|
||||||
|
for c in contradictions:
|
||||||
|
try:
|
||||||
|
winner = resolve_contradiction(features, c)
|
||||||
|
if winner:
|
||||||
|
resolution_map[c.get("name", "unknown")] = winner
|
||||||
|
resolution_map["resolved_count"] += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("[pipeline] Path C 矛盾解决异常: %s", e)
|
||||||
|
|
||||||
# 4. 确信度计算
|
# 4. 确信度计算
|
||||||
keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
|
keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
|
||||||
keyword_result_v2["base_confidence"] = validated_confidence
|
keyword_result_v2["base_confidence"] = validated_confidence
|
||||||
@@ -295,7 +308,7 @@ def _path_llm_assisted(
|
|||||||
keyword_result=keyword_result_v2,
|
keyword_result=keyword_result_v2,
|
||||||
structure_features=structure_features,
|
structure_features=structure_features,
|
||||||
contradictions=contradictions,
|
contradictions=contradictions,
|
||||||
resolution={"resolved_count": 0, "total_count": len(contradictions)},
|
resolution=resolution_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -45,6 +45,16 @@ CONTRADICTION_PAIRS: list[dict[str, str]] = [
|
|||||||
"type_a": "DIVIDE_50",
|
"type_a": "DIVIDE_50",
|
||||||
"type_b": "DIVIDE_100",
|
"type_b": "DIVIDE_100",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "division_50_25_100",
|
||||||
|
"type_a": "DIVIDE_50",
|
||||||
|
"type_b": "DIVIDE_25",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "division_50_25_100",
|
||||||
|
"type_a": "DIVIDE_100",
|
||||||
|
"type_b": "DIVIDE_25",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "mn_output_mode",
|
"name": "mn_output_mode",
|
||||||
"type_a": "M:N",
|
"type_a": "M:N",
|
||||||
|
|||||||
@@ -327,15 +327,26 @@ def test_resolve_contradiction_csv():
|
|||||||
# ═══════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
def test_contradiction_pairs_defined():
|
def test_contradiction_pairs_defined():
|
||||||
"""CONTRADICTION_PAIRS 包含所有 8 个混淆对"""
|
"""CONTRADICTION_PAIRS 包含所有混淆对,DIVIDE 全部 3 种变体"""
|
||||||
assert len(CONTRADICTION_PAIRS) == 8
|
assert len(CONTRADICTION_PAIRS) >= 8
|
||||||
names = {p["name"] for p in CONTRADICTION_PAIRS}
|
pairs_by_name: dict[str, list[dict]] = {}
|
||||||
expected = {
|
for p in CONTRADICTION_PAIRS:
|
||||||
|
pairs_by_name.setdefault(p["name"], []).append(p)
|
||||||
|
|
||||||
|
expected_names = {
|
||||||
"matching_vs_keybreak", "dedup_vs_nodedup", "validation_vs_keybreak",
|
"matching_vs_keybreak", "dedup_vs_nodedup", "validation_vs_keybreak",
|
||||||
"csv_merge_vs_split", "simple_vs_two_stage", "pure_vs_mixed",
|
"csv_merge_vs_split", "simple_vs_two_stage", "pure_vs_mixed",
|
||||||
"division_50_25_100", "mn_output_mode",
|
"division_50_25_100", "mn_output_mode",
|
||||||
}
|
}
|
||||||
assert names == expected
|
assert set(pairs_by_name.keys()) >= expected_names
|
||||||
|
|
||||||
|
# division 应有 3 个矛盾对 (50-100, 50-25, 100-25) 覆盖所有变体
|
||||||
|
div_pairs = pairs_by_name.get("division_50_25_100", [])
|
||||||
|
assert len(div_pairs) == 3, f"DIVIDE 应覆盖全部 3 组变体,当前 {len(div_pairs)} 组"
|
||||||
|
div_types = {(p["type_a"], p["type_b"]) for p in div_pairs}
|
||||||
|
assert ("DIVIDE_50", "DIVIDE_100") in div_types
|
||||||
|
assert ("DIVIDE_50", "DIVIDE_25") in div_types
|
||||||
|
assert ("DIVIDE_100", "DIVIDE_25") in div_types
|
||||||
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════════════
|
# ═══════════════════════════════════════════════════════════════════════════
|
||||||
|
|||||||
Reference in New Issue
Block a user