feat: Phase 2 complete — 13 Phases of COBOL type classification and test benchmark
P0.6: gcov infrastructure P1: extract_structure output expansion (11 new feature fields) P2: Confusion group rule engine (8 pairs + contradiction + backtrack) P3: 4-factor confidence calculation + quality gate update P4: 33+2 COBOL program type test samples (22 files, 7 categories) P5: parametrized/ test data generation engine P6: japanese_data.py lookup tables P7-10: Type-specific test suites (~159 parametrized tests) P11: Full classification pipeline (classify_program) + orchestrator integration P12: Documentation (module-interfaces, test-plan v3.0, coverage-matrix) Architecture decisions: - classification_pipeline/ merged to hina/pipeline/ - parametrized/ as independent module - japanese_data.py as root-level file - hina/__all__ only exports classify_program() Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
+25
-1
@@ -1 +1,25 @@
|
||||
# HINA 程序分类与质量门禁包
|
||||
"""HINA 程序分类与质量门禁包
|
||||
|
||||
公开 API:
|
||||
classify_program() — 完整类型判定管道(唯一外部入口)
|
||||
|
||||
内部模块(不直接导出,但保留模块级导入以维持向后兼容):
|
||||
gate_check() — 质量门禁判定
|
||||
get_strategy() — 策略模板获取
|
||||
supplement() — 策略补充
|
||||
RetryHandler — 分层重试处理器
|
||||
collect_gcov() — gcov 覆盖率采集
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .pipeline.pipeline import classify_program
|
||||
from .gate import check as gate_check
|
||||
from .strategy import get_strategy, supplement, supplement_only
|
||||
from .retry import RetryHandler
|
||||
from .gcov_collector import collect_gcov
|
||||
|
||||
__all__ = [
|
||||
# ═══ 唯一外部入口 ═══
|
||||
"classify_program", # (source: str, llm?: object) -> dict
|
||||
]
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
确信度 4 因子计算。
|
||||
|
||||
公式: confidence = base × context_factor × consistency_factor × structure_factor
|
||||
|
||||
判定:
|
||||
>= 0.90 auto — 自动通过
|
||||
0.70-0.89 review — 需要人工审核
|
||||
0.50-0.69 manual — 需要人工介入
|
||||
< 0.50 impossible — 无法判定
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def compute_confidence_v2(
    keyword_result: dict[str, Any],
    structure_features: dict[str, Any],
    contradictions: list[dict[str, Any]] | None = None,
    resolution: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Compute the four-factor classification confidence.

    confidence = base * context_factor * consistency_factor * structure_factor

    Args:
        keyword_result: L1 keyword result. Reads ``base_confidence``
            (default 0.7) and ``match_count`` (default 0).
        structure_features: Structure analysis result. Reads
            ``structure_match_score`` (default 0).
        contradictions: Detected contradictions. An entry may carry a
            ``resolved`` flag; entries without one count as unresolved
            unless ``resolution`` says otherwise (see below).
        resolution: Optional resolution summary. Its ``resolved_count`` is
            consulted only when no contradiction entry carries its own
            ``resolved`` flag, so callers that track resolution separately
            are still credited for it.

    Returns:
        dict with ``confidence`` (rounded to 4 places), the four factors
        (``base``, ``context_factor``, ``consistency_factor``,
        ``structure_factor``), ``judgment`` (auto / review / manual /
        impossible) and ``needs_review``.
    """
    # 1. Base confidence, clamped so the product stays inside 0.0 ~ 1.0.
    base = keyword_result.get("base_confidence", 0.7)
    base = min(max(base, 0.0), 1.0)

    # 2. Context factor: more keyword matches -> more trust.
    match_count = keyword_result.get("match_count", 0)
    if match_count >= 3:
        context_factor = 1.0
    elif match_count == 2:
        context_factor = 0.95
    elif match_count == 1:
        context_factor = 0.90
    else:
        context_factor = 0.50

    # 3. Consistency factor: penalise contradictions, unresolved ones more.
    contradictions = contradictions or []
    total_contradictions = len(contradictions)
    has_flags = any("resolved" in c for c in contradictions)
    if resolution is not None and not has_flags:
        # Entries carry no per-item flag; fall back to the summary counter.
        resolved_count = resolution.get("resolved_count", 0) or 0
        unresolved_count = max(total_contradictions - resolved_count, 0)
    else:
        unresolved_count = sum(
            1 for c in contradictions if not c.get("resolved", False)
        )

    if total_contradictions == 0:
        consistency_factor = 1.0
    elif unresolved_count == 0:
        # Contradictions existed but every one was resolved.
        consistency_factor = 0.90
    elif total_contradictions >= 3:
        consistency_factor = 0.50
    else:
        # Fewer than three contradictions, at least one unresolved.
        consistency_factor = 0.80

    # 4. Structure factor from the 0-5 structure match score.
    structure_score = structure_features.get("structure_match_score", 0)
    if structure_score == 5:
        structure_factor = 1.0
    elif structure_score >= 3:
        structure_factor = 0.7
    elif structure_score >= 1:
        structure_factor = 0.5
    else:
        structure_factor = 0.3

    confidence = round(
        base * context_factor * consistency_factor * structure_factor, 4
    )

    # Map the score onto the judgment bands.
    if confidence >= 0.90:
        judgment = "auto"
    elif confidence >= 0.70:
        judgment = "review"
    elif confidence >= 0.50:
        judgment = "manual"
    else:
        judgment = "impossible"

    return {
        "confidence": confidence,
        "base": base,
        "context_factor": context_factor,
        "consistency_factor": consistency_factor,
        "structure_factor": structure_factor,
        "judgment": judgment,
        # Only the "auto" band skips human review.
        "needs_review": judgment != "auto",
    }
|
||||
@@ -5,6 +5,10 @@ Phase 1 可用: 决策点覆盖、段落覆盖
|
||||
Phase 2 启用: HINA 必须项、字段覆盖
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def check(
|
||||
complete_tests: list,
|
||||
@@ -60,3 +64,43 @@ def _compute_score(coverage: dict, hina_result: dict) -> float:
|
||||
boundary_quality = 1.0
|
||||
|
||||
return round(coverage_quality * 0.6 + boundary_quality * 0.4, 2)
|
||||
|
||||
|
||||
def compute_quality_score(
    static_coverage: dict[str, Any],
    gcov_coverage: dict[str, Any] | None = None,
    confidence: float = 0.5,
) -> float:
    """Compute the dual-mode quality score.

    Mode 1 (gcov not enabled): ``static_cov + confidence * 0.4``, where
    confidence acts as a bonus of at most +0.4.

    Mode 2 (gcov enabled):
    ``static_cov * 0.3 + gcov_cov * 0.4 + confidence * 0.3``.

    In both modes ``static_cov = branch_rate * 0.5 + paragraph_rate * 0.5``.

    Args:
        static_coverage: Static coverage data; reads ``branch_rate`` and
            ``paragraph_rate`` (both default 0.0).
        gcov_coverage: Dynamic coverage data; reads ``gcov_cov``
            (default 0.0). ``None`` or a dict whose ``available`` entry is
            falsy (the shape ``collect_gcov`` returns on failure) selects
            mode 1.
        confidence: Classification confidence.

    Returns:
        The score clamped to 0.0 ~ 1.0 and rounded to 4 places.
    """
    branch_rate = static_coverage.get("branch_rate", 0.0)
    paragraph_rate = static_coverage.get("paragraph_rate", 0.0)
    static_cov = branch_rate * 0.5 + paragraph_rate * 0.5

    # A failure marker carries no coverage figure; scoring it as mode 2
    # would count it as 0% dynamic coverage and unfairly lower the score.
    gcov_enabled = gcov_coverage is not None and bool(
        gcov_coverage.get("available", True)
    )

    if gcov_enabled:
        gcov_cov = gcov_coverage.get("gcov_cov", 0.0)
        score = static_cov * 0.3 + gcov_cov * 0.4 + confidence * 0.3
    else:
        score = static_cov + confidence * 0.4

    return round(min(max(score, 0.0), 1.0), 4)
|
||||
|
||||
@@ -7,7 +7,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def collect_gcov(cobol_src: Path, work_dir: Path) -> dict:
|
||||
try:
|
||||
gcda_files = list(work_dir.glob("*.gcda"))
|
||||
cd = str(work_dir)
|
||||
gcda_files = list(Path(cd).glob("*.gcda"))
|
||||
if not gcda_files:
|
||||
logger.warning("[gcov] 未找到 .gcda 文件,可能未启用插桩编译")
|
||||
return {"available": False, "reason": "no_gcda_files"}
|
||||
@@ -15,16 +16,16 @@ def collect_gcov(cobol_src: Path, work_dir: Path) -> dict:
|
||||
result = subprocess.run(
|
||||
["gcov", cobol_src.name],
|
||||
capture_output=True, text=True, timeout=30,
|
||||
cwd=work_dir,
|
||||
cwd=cd,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.warning(f"[gcov] gcov 执行失败: {result.stderr[:200]}")
|
||||
return {"available": False, "reason": "gcov_failed"}
|
||||
|
||||
gcov_file = work_dir / f"{cobol_src.stem}.cbl.gcov"
|
||||
gcov_file = Path(cd) / f"{cobol_src.stem}.cbl.gcov"
|
||||
if not gcov_file.exists():
|
||||
gcov_file = work_dir / f"{cobol_src.stem}.gcov"
|
||||
gcov_file = Path(cd) / f"{cobol_src.stem}.gcov"
|
||||
|
||||
if not gcov_file.exists():
|
||||
logger.warning("[gcov] .gcov 文件未生成")
|
||||
@@ -32,7 +33,7 @@ def collect_gcov(cobol_src: Path, work_dir: Path) -> dict:
|
||||
|
||||
total_lines = 0
|
||||
executed_lines = 0
|
||||
with open(gcov_file) as f:
|
||||
with open(str(gcov_file), encoding="utf-8", errors="replace") as f:
|
||||
for line in f:
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("-"):
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
"""HINA 完整类型判定管道。"""
|
||||
@@ -0,0 +1,419 @@
|
||||
"""
|
||||
完整程序类型判定管道 — classify_program()
|
||||
|
||||
流程:
|
||||
1. 并行: detect_keyword() + extract_structure()
|
||||
2. keyword confidence >= 90% -> 直接输出
|
||||
3. keyword 50-89% -> 规则引擎 + 确信度计算 + 矛盾回溯
|
||||
4. keyword < 50% -> LLM 辅助 + 规则引擎验证
|
||||
5. 输出最终 JSON
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
from hina.classifier import detect_keyword
|
||||
from hina.confidence import compute_confidence_v2
|
||||
from hina.rule_engine.confusion_groups import resolve_confusion_pair
|
||||
from hina.rule_engine.contradiction import (
|
||||
CONTRADICTION_PAIRS,
|
||||
detect_contradictions,
|
||||
resolve_contradiction,
|
||||
)
|
||||
from cobol_testgen import extract_structure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 所有可尝试的混淆对名称
|
||||
_PAIR_NAMES = [
|
||||
"matching_vs_keybreak",
|
||||
"dedup_vs_nodedup",
|
||||
"validation_vs_keybreak",
|
||||
"csv_merge_vs_split",
|
||||
"simple_vs_two_stage",
|
||||
"pure_vs_mixed",
|
||||
"division_50_25_100",
|
||||
"mn_output_mode",
|
||||
]
|
||||
|
||||
|
||||
# ── 内部工具 ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_best_keyword_match(matches: list) -> dict | None:
|
||||
"""从 L1 关键字匹配结果中找出最佳匹配。
|
||||
|
||||
Args:
|
||||
matches: detect_keyword() 返回的 list[tuple[str, float, str]]
|
||||
|
||||
Returns:
|
||||
dict | None: {"category", "confidence", "keyword", "all_matches"}
|
||||
"""
|
||||
if not matches:
|
||||
return None
|
||||
best = max(matches, key=lambda m: m[1]) # (category, confidence, keyword)
|
||||
return {
|
||||
"category": best[0],
|
||||
"confidence": best[1],
|
||||
"keyword": best[2],
|
||||
"all_matches": matches,
|
||||
}
|
||||
|
||||
|
||||
def _compute_structure_match_score(structure: dict) -> int:
|
||||
"""计算结构匹配度评分 (0-5),供 compute_confidence_v2 使用。"""
|
||||
return min(
|
||||
5,
|
||||
bool(structure.get("total_paragraphs", 0)) # 有段落
|
||||
+ bool(structure.get("file_count", 0)) # 有文件
|
||||
+ bool(len(structure.get("decision_points", []))) # 有决策点
|
||||
+ bool(structure.get("if_types", {}).get("total", 0)) # 有 IF
|
||||
+ bool(structure.get("branch_tree_obj") is not None), # 有分支树
|
||||
)
|
||||
|
||||
|
||||
def _build_structure_summary(structure: dict) -> dict:
|
||||
"""从完整结构中提取调试摘要。"""
|
||||
return {
|
||||
"paragraph_count": structure.get("total_paragraphs", 0),
|
||||
"file_count": structure.get("file_count", 0),
|
||||
"decision_count": len(structure.get("decision_points", [])),
|
||||
"has_call": structure.get("has_call", False),
|
||||
"has_divide": structure.get("has_divide", False),
|
||||
}
|
||||
|
||||
|
||||
def _build_keyword_result_for_v2(keyword_info: dict | None) -> dict:
|
||||
"""构建 compute_confidence_v2 所需的 keyword_result。"""
|
||||
if keyword_info:
|
||||
return {
|
||||
"base_confidence": keyword_info["confidence"],
|
||||
"match_count": len(keyword_info["all_matches"]),
|
||||
}
|
||||
return {"base_confidence": 0.0, "match_count": 0}
|
||||
|
||||
|
||||
def _build_structure_features(structure: dict) -> dict:
    """Shape the structure dict into the input for compute_confidence_v2."""
    match_score = _compute_structure_match_score(structure)
    paragraph_total = structure.get("total_paragraphs", 0)
    return {
        "structure_match_score": match_score,
        "total_paragraphs": paragraph_total,
    }
|
||||
|
||||
|
||||
# ── 分路径逻辑 ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _path_keyword_direct(
    keyword_info: dict,
    structure: dict,
) -> dict:
    """Path A: emit the keyword category directly.

    The four-factor confidence is still computed (with no contradictions)
    and supplies the reported confidence, judgment and review flag; the
    result is tagged method "keyword" / source "l1".
    """
    scored = compute_confidence_v2(
        keyword_result=_build_keyword_result_for_v2(keyword_info),
        structure_features=_build_structure_features(structure),
        contradictions=[],
        resolution={"resolved_count": 0, "total_count": 0},
    )

    outcome: dict = {"category": keyword_info["category"]}
    outcome.update(
        confidence=scored["confidence"],
        needs_review=scored["needs_review"],
        method="keyword",
        source="l1",
        judgment=scored["judgment"],
        matches=keyword_info["all_matches"],
        contradictions=[],
        v2_confidence=scored,
        structure=_build_structure_summary(structure),
    )
    return outcome
|
||||
|
||||
|
||||
def _path_rule_engine(
    keyword_info: dict | None,
    structure: dict,
) -> dict:
    """Path B: rule engine, contradiction handling and confidence scoring.

    Steps:
        1. Copy ``structure`` into a features dict.
        2. Run every confusion-pair resolver once, keeping the resolved
           type and its confidence.
        3. Detect contradictions and try to resolve each.
        4. Choose the final category: the keyword result, unless a resolver
           reported a strictly higher confidence.
        5. Compute the four-factor confidence.

    NOTE(review): some resolvers visible in confusion_groups always return
    a concrete type (e.g. the dedup resolver answers with 0.85 when no
    previous-key variable is found), so a keyword result below that
    confidence is overridden even for programs unrelated to that confusion
    pair. Confirm whether resolvers should be limited to pairs relevant to
    the keyword category.
    """
    # 1. Structure features double as rule-engine features (shallow copy).
    features = dict(structure)

    # 2. One pass over all resolvers. Confidence is cached so step 4 does
    #    not have to call each resolver a second time.
    resolved_types: dict[str, str] = {}
    resolved_confidences: dict[str, float] = {}
    for pair_name in _PAIR_NAMES:
        try:
            result = resolve_confusion_pair(features, pair_name)
            if result["resolved_type"] != "unknown" and result["confidence"] > 0:
                resolved_types[pair_name] = result["resolved_type"]
                resolved_confidences[pair_name] = result["confidence"]
        except Exception as e:
            # Best effort: one failing resolver must not abort the pipeline.
            logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)

    features["resolved_types"] = resolved_types

    # 3. Contradiction detection and resolution.
    contradictions = detect_contradictions(features)
    resolution_map: dict[str, Any] = {
        "resolved_count": 0,
        "total_count": len(contradictions),
    }
    for c in contradictions:
        try:
            winner = resolve_contradiction(features, c)
            if winner:
                resolution_map[c.get("name", "unknown")] = winner
                resolution_map["resolved_count"] += 1
        except Exception as e:
            logger.debug("[pipeline] 矛盾解决异常: %s", e)

    # 4. Final category: start from the keyword result ...
    final_category = "unknown"
    final_base_confidence = 0.0
    if keyword_info:
        final_category = keyword_info["category"]
        final_base_confidence = keyword_info["confidence"]

    # ... and let the most confident resolver override it. Ties keep the
    # earliest pair in _PAIR_NAMES order.
    best_resolved_type = None
    best_resolved_conf = 0.0
    for pair_name, rtype in resolved_types.items():
        pair_conf = resolved_confidences[pair_name]
        if pair_conf > best_resolved_conf:
            best_resolved_conf = pair_conf
            best_resolved_type = rtype

    if best_resolved_type and best_resolved_conf > final_base_confidence:
        final_category = best_resolved_type
        final_base_confidence = best_resolved_conf

    # 5. Four-factor confidence, based on whichever source won step 4.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = final_base_confidence

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=_build_structure_features(structure),
        contradictions=contradictions,
        resolution=resolution_map,
    )

    # 6. Assemble the result.
    return {
        "category": final_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "rule_engine",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_info["all_matches"] if keyword_info else [],
        "contradictions": contradictions,
        "contradiction_resolution": resolution_map,
        "resolved_types": resolved_types,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
||||
|
||||
|
||||
def _path_llm_assisted(
    keyword_info: dict | None,
    structure: dict,
    llm: Any,
) -> dict:
    """Path C: LLM classification checked against the rule engine.

    Steps:
        1. Ask ``classify_with_llm`` for a category and confidence.
        2. Run every confusion-pair resolver once; a resolver that is
           strictly more confident than the current candidate replaces it,
           and every non-"unknown" answer is recorded in resolved_types.
        3. Detect contradictions on the collected resolved types.
        4. Compute the four-factor confidence.
    """
    # Imported lazily, as in the original code.
    from hina.hina_agent import classify_with_llm

    # 1. LLM classification.
    llm_result = classify_with_llm(structure, llm)
    validated_category = llm_result.get("category", "unknown")
    validated_confidence = llm_result.get("confidence", 0.5)

    # 2. A single resolver sweep serves both validation and contradiction
    #    input (previously two identical sweeps).
    features = dict(structure)
    resolved_types: dict[str, str] = {}
    for pair_name in _PAIR_NAMES:
        try:
            pair_result = resolve_confusion_pair(features, pair_name)
            resolved_type = pair_result["resolved_type"]
            if resolved_type == "unknown":
                continue
            resolved_types[pair_name] = resolved_type
            if pair_result["confidence"] > validated_confidence:
                validated_category = resolved_type
                validated_confidence = pair_result["confidence"]
        except Exception as e:
            # Best effort: log instead of silently dropping the failure.
            logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)

    # 3. Contradiction detection.
    features["resolved_types"] = resolved_types
    contradictions = detect_contradictions(features)

    # 4. Confidence, based on the validated candidate.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = validated_confidence

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=_build_structure_features(structure),
        contradictions=contradictions,
        resolution={"resolved_count": 0, "total_count": len(contradictions)},
    )

    return {
        "category": validated_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "llm",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_info["all_matches"] if keyword_info else [],
        "contradictions": contradictions,
        "llm_raw": llm_result,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
||||
|
||||
|
||||
# ── 主入口 ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def classify_program(cobol_source: str, llm: Any = None) -> dict:
    """Run the full program-type classification pipeline.

    Flow:
        1. Run detect_keyword() and extract_structure() concurrently.
        2. Best keyword confidence >= 0.90 -> path A (direct output).
        3. Best keyword confidence >= 0.50 -> path B (rule engine).
        4. Otherwise -> path C (LLM assisted), or the rule engine as a
           fallback when ``llm`` is None.

    Args:
        cobol_source: COBOL program source text.
        llm: Optional LLM client, used only on path C.

    Returns:
        dict with at least:
            "category": str        classification name
            "confidence": float    overall confidence
            "needs_review": bool   whether human review is required
            "method": str          "keyword" | "rule_engine" | "llm" |
                                   "rule_engine_fallback" | "none"
            "source": str          "l1" | "pipeline" | "error"
            "judgment": str        auto / review / manual / impossible
            "matches": list        L1 keyword match details
            "contradictions": list detected contradictions
            "v2_confidence": dict  four-factor confidence details
            "structure": dict      structure summary (for debugging)

        Empty or whitespace-only source does not raise: it returns an
        "unknown" / "impossible" result with method "none" and source
        "error".
    """
    if not cobol_source or not cobol_source.strip():
        return {
            "category": "unknown",
            "confidence": 0.0,
            "needs_review": True,
            "method": "none",
            "source": "error",
            "judgment": "impossible",
            "matches": [],
            "contradictions": [],
            "v2_confidence": {},
            "structure": {},
        }

    # Step 1: keyword detection and structure extraction run concurrently.
    keyword_matches: list = []
    structure: dict = {}

    with ThreadPoolExecutor(max_workers=2) as executor:
        future_keyword = executor.submit(detect_keyword, cobol_source)
        future_structure = executor.submit(extract_structure, cobol_source)

        # Each result is awaited directly; a failure in one step degrades
        # to an empty default rather than aborting the pipeline. ``or``
        # also covers a step that returns None.
        try:
            keyword_matches = future_keyword.result() or []
        except Exception as e:
            logger.warning("[pipeline] detect_keyword 失败: %s", e)
        try:
            structure = future_structure.result() or {}
        except Exception as e:
            logger.warning("[pipeline] extract_structure 失败: %s", e)

    # Step 2: pick the best keyword match to choose the path.
    keyword_info = _get_best_keyword_match(keyword_matches)
    max_keyword_confidence = keyword_info["confidence"] if keyword_info else 0.0

    logger.info(
        "[pipeline] keyword matches=%d, max_confidence=%.2f, paragraphs=%d, files=%d",
        len(keyword_matches),
        max_keyword_confidence,
        structure.get("total_paragraphs", 0),
        structure.get("file_count", 0),
    )

    # Step 3: dispatch on keyword confidence.

    # Path A: keyword >= 90% -> direct output.
    if max_keyword_confidence >= 0.90:
        logger.info("[pipeline] 路径 A: keyword 高确信度 (%.2f)", max_keyword_confidence)
        return _path_keyword_direct(keyword_info, structure)

    # Path B: keyword 50-89% -> rule engine.
    if max_keyword_confidence >= 0.50:
        logger.info("[pipeline] 路径 B: keyword 中确信度 (%.2f) -> 规则引擎", max_keyword_confidence)
        return _path_rule_engine(keyword_info, structure)

    # Path C: keyword < 50% -> LLM assisted.
    if llm is not None:
        logger.info("[pipeline] 路径 C: keyword 低确信度 (%.2f) -> LLM 辅助", max_keyword_confidence)
        return _path_llm_assisted(keyword_info, structure, llm)

    # No LLM available: fall back to the rule engine and tag the result.
    logger.info("[pipeline] 路径 C(fallback): keyword 低确信度 (%.2f) -> 规则引擎兜底", max_keyword_confidence)
    result = _path_rule_engine(keyword_info, structure)
    result["method"] = "rule_engine_fallback"
    return result
|
||||
@@ -0,0 +1,47 @@
|
||||
"""HINA 混淆组判定规则引擎
|
||||
|
||||
公开 API:
|
||||
resolve_confusion_pair() — 根据 pair_name 调度对应函数
|
||||
detect_contradictions() — 检测可能矛盾的类型对
|
||||
resolve_contradiction() — 解决矛盾,返回胜出的类型名
|
||||
BacktrackResolver — 多轮回溯判定
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .confusion_groups import (
|
||||
resolve_confusion_pair,
|
||||
resolve_matching_vs_keybreak,
|
||||
resolve_dedup_vs_nodedup,
|
||||
resolve_validation_vs_keybreak,
|
||||
resolve_csv_merge_vs_split,
|
||||
resolve_simple_vs_two_stage,
|
||||
resolve_pure_vs_mixed,
|
||||
resolve_division_50_25_100,
|
||||
resolve_mn_output_mode,
|
||||
)
|
||||
from .contradiction import (
|
||||
CONTRADICTION_PAIRS,
|
||||
detect_contradictions,
|
||||
resolve_contradiction,
|
||||
)
|
||||
from .backtrack import BacktrackResolver
|
||||
|
||||
__all__ = [
|
||||
# 混淆组判定
|
||||
"resolve_confusion_pair",
|
||||
"resolve_matching_vs_keybreak",
|
||||
"resolve_dedup_vs_nodedup",
|
||||
"resolve_validation_vs_keybreak",
|
||||
"resolve_csv_merge_vs_split",
|
||||
"resolve_simple_vs_two_stage",
|
||||
"resolve_pure_vs_mixed",
|
||||
"resolve_division_50_25_100",
|
||||
"resolve_mn_output_mode",
|
||||
# 矛盾检测与解决
|
||||
"CONTRADICTION_PAIRS",
|
||||
"detect_contradictions",
|
||||
"resolve_contradiction",
|
||||
# 回溯
|
||||
"BacktrackResolver",
|
||||
]
|
||||
@@ -0,0 +1,96 @@
|
||||
"""回溯机制 — 多轮判定,必要时重新提取特征以化解矛盾。
|
||||
|
||||
BacktrackResolver 封装了多轮判定的核心逻辑:
|
||||
1. 用当前 features 检测矛盾。
|
||||
2. 对有矛盾的对调用 resolve_contradiction。
|
||||
3. 如果仍然存在矛盾,重新提取特征再判定。
|
||||
4. 超过 max_rounds 轮或 30s 超时后降级。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any, Callable
|
||||
|
||||
from .contradiction import detect_contradictions, resolve_contradiction
|
||||
|
||||
|
||||
class BacktrackResolver:
    """Multi-round contradiction resolver with re-extraction.

    Each round detects contradictions in the current features, resolves
    them, and re-extracts features from the source before the next round.
    The loop stops when no contradictions remain, the time budget is
    exceeded, re-extraction fails, or ``max_rounds`` is exhausted.

    Args:
        structure_extractor: Callable taking COBOL source text and
            returning a features dict.
        max_rounds: Maximum number of rounds (default 3).
        timeout: Time budget in seconds, checked at the start of each
            round (default 30).
    """

    # Class-level default so instances created without __init__ still work.
    timeout: float = 30.0

    def __init__(
        self,
        structure_extractor: Callable[[str], dict[str, Any]],
        *,
        max_rounds: int = 3,
        timeout: float = 30.0,
    ) -> None:
        self.extract = structure_extractor
        self.max_rounds = max_rounds
        self.timeout = timeout

    def _needs_backtrack(self, contradictions: list[dict]) -> bool:
        """Return True when any contradiction was detected (non-empty list)."""
        return len(contradictions) > 0

    def resolve(self, cobol_source: str, initial_features: dict) -> dict[str, Any]:
        """Run the multi-round resolution loop.

        Args:
            cobol_source: COBOL program source.
            initial_features: Features from the first extraction. Neither
                this dict nor its ``resolved_types`` entry is mutated.

        Returns:
            The final features dict. It always contains
            ``backtrack_rounds``, ``backtrack_timeout``,
            ``backtrack_resolved``, ``backtrack_degraded`` and
            ``backtrack_elapsed`` (seconds, 3 decimals), and may contain
            ``backtrack_extract_error``.
        """
        # Monotonic clock: immune to wall-clock adjustments.
        start = time.monotonic()
        features: dict[str, Any] = dict(initial_features)
        # dict() is shallow; copy the nested mapping written to below so
        # the caller's data is left untouched.
        if isinstance(features.get("resolved_types"), dict):
            features["resolved_types"] = dict(features["resolved_types"])
        features["backtrack_rounds"] = 0

        for round_num in range(1, self.max_rounds + 1):
            # Time budget check.
            if time.monotonic() - start > self.timeout:
                features["backtrack_timeout"] = True
                break

            contradictions = detect_contradictions(features)
            if not contradictions:
                # Nothing left to resolve.
                features["backtrack_resolved"] = True
                break

            # Record the resolution of each contradiction in the features.
            for c in contradictions:
                resolution = resolve_contradiction(features, c)
                resolved_types = features.setdefault("resolved_types", {})
                resolved_types[f"resolved_{c['name']}"] = resolution

            features["backtrack_rounds"] = round_num

            if self._needs_backtrack(contradictions):
                try:
                    new_features = self.extract(cobol_source)
                    # Merge fresh features, keeping the backtrack state and
                    # the resolutions recorded so far.
                    preserved_keys = ("backtrack_rounds", "backtrack_timeout", "resolved_types")
                    preserved = {k: features[k] for k in preserved_keys if k in features}
                    features.update(new_features)
                    features.update(preserved)
                except Exception:
                    features["backtrack_extract_error"] = True
                    break
        else:
            # Loop ended without a break: rounds exhausted, mark degraded.
            features["backtrack_degraded"] = True

        # Make sure every status field exists.
        elapsed = time.monotonic() - start
        features.setdefault("backtrack_timeout", False)
        features.setdefault("backtrack_resolved", False)
        features.setdefault("backtrack_degraded", False)
        features["backtrack_elapsed"] = round(elapsed, 3)

        return features
|
||||
@@ -0,0 +1,235 @@
|
||||
"""混淆组判定规则引擎 — 8 个混淆对的化解函数。
|
||||
|
||||
每个函数接收 features dict,返回:
|
||||
{
|
||||
"resolved_type": str,
|
||||
"confidence": float,
|
||||
"evidence": list[str],
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def resolve_matching_vs_keybreak(features: dict) -> dict:
    """Distinguish 「マッチング」 from 「キーブレイク」.

    Rules, checked in order:
        1. comparison IFs >= 2 and file count >= 2 -> マッチング (0.90)
        2. at least one IF, a previous-key variable and an accumulator
           -> キーブレイク (0.85)
        3. file count >= 2 and comparison IFs >= 1 -> マッチング (0.75)
        4. otherwise -> "unknown" (0.0)

    The file count is the size of ``features["select_files"]`` when that is
    a non-empty dict, otherwise ``features["file_count"]``.

    Returns:
        dict with ``resolved_type``, ``confidence`` and ``evidence``.
    """
    if_types = features.get("if_types", {})
    total_ifs = if_types.get("total", 0)
    comparison_ifs = if_types.get("comparison", 0)

    # Fall back to file_count when select_files is missing or empty.
    # Defaulting the lookup to {} would make this fallback unreachable.
    select_files = features.get("select_files")
    if isinstance(select_files, dict) and select_files:
        file_count = len(select_files)
    else:
        file_count = features.get("file_count", 0)

    variable_patterns = features.get("variable_patterns", {})
    has_prev_key = variable_patterns.get("has_prev_key", False)
    has_accumulator = variable_patterns.get("has_accumulator", False)

    evidence: list[str] = []

    # Rule 1: comparison IFs + multiple files -> matching.
    if comparison_ifs >= 2 and file_count >= 2:
        evidence.append(f"三路 IF 分支 (comparison={comparison_ifs}) + SELECT 文件数 >=2 ({file_count}) → マッチング")
        return {"resolved_type": "マッチング", "confidence": 0.90, "evidence": evidence}

    # Rule 2: previous key + accumulator + any IF -> key break.
    if total_ifs >= 1 and has_prev_key and has_accumulator:
        evidence.append("WS-PREV-KEY 存在 + 累加器存在 + IF 分支 → キーブレイク")
        return {"resolved_type": "キーブレイク", "confidence": 0.85, "evidence": evidence}

    # Rule 3: weaker matching signal.
    if file_count >= 2 and comparison_ifs >= 1:
        evidence.append("SELECT 文件数 >=2 + comparison IF >=1 → マッチング")
        return {"resolved_type": "マッチング", "confidence": 0.75, "evidence": evidence}

    # Fallback: not enough evidence either way.
    evidence.append(f"特征不足: total_ifs={total_ifs}, comparison={comparison_ifs}, "
                    f"file_count={file_count}, has_prev_key={has_prev_key}, "
                    f"has_accumulator={has_accumulator}")
    return {"resolved_type": "unknown", "confidence": 0.0, "evidence": evidence}
|
||||
|
||||
|
||||
def resolve_dedup_vs_nodedup(features: dict) -> dict:
    """Distinguish the item check with duplicates from the one without.

    Decided solely by ``variable_patterns["has_prev_key"]``: truthy gives
    the "with duplicates" type (0.90), anything else the "without
    duplicates" type (0.85). This resolver never answers "unknown".
    """
    patterns = features.get("variable_patterns", {})

    if not patterns.get("has_prev_key", False):
        return {
            "resolved_type": "項目チェック(重複含まず)",
            "confidence": 0.85,
            "evidence": ["未检测到 WS-PREV-KEY → 不含重复"],
        }

    return {
        "resolved_type": "項目チェック(重複含む)",
        "confidence": 0.90,
        "evidence": ["WS-PREV-KEY 存在 → 含重复"],
    }
|
||||
|
||||
|
||||
def resolve_validation_vs_keybreak(features: dict) -> dict:
    """Distinguish the validation edit type from 「キーブレイク」.

    Checked in order against ``variable_patterns``:
        - ``has_error_flag`` -> validation (0.85)
        - ``has_counter``    -> key break (0.80)
        - neither            -> "unknown" (0.0)
    """
    patterns = features.get("variable_patterns", {})

    # (flag name, resolved type, confidence, evidence text), by priority.
    rules = (
        ("has_error_flag", "編集処理(校验)", 0.85, "WS-ERR* 错误字段存在 → 校验"),
        ("has_counter", "キーブレイク", 0.80, "WS-*CNT 计数器存在 → キーブレイク"),
    )
    for flag, resolved_type, confidence, reason in rules:
        if patterns.get(flag, False):
            return {
                "resolved_type": resolved_type,
                "confidence": confidence,
                "evidence": [reason],
            }

    return {
        "resolved_type": "unknown",
        "confidence": 0.0,
        "evidence": ["既无错误字段也无计数器,无法判定"],
    }
|
||||
|
||||
|
||||
def resolve_csv_merge_vs_split(features: dict) -> dict:
    """Distinguish CSV merge from CSV split.

    Checked in order:
        - ``has_string``  -> CSV merge (0.85)
        - ``has_inspect`` -> CSV split (0.85)
        - neither         -> "unknown" (0.0)
    """
    # (feature flag, resolved type, evidence text), by priority.
    rules = (
        ("has_string", "CSV合并", "STRING 语句存在 → CSV 合并 (无换行)"),
        ("has_inspect", "CSV拆分", "INSPECT REPLACING 存在 → CSV 拆分 (有换行)"),
    )
    for flag, resolved_type, reason in rules:
        if features.get(flag, False):
            return {
                "resolved_type": resolved_type,
                "confidence": 0.85,
                "evidence": [reason],
            }

    return {
        "resolved_type": "unknown",
        "confidence": 0.0,
        "evidence": ["既无 STRING 也无 INSPECT REPLACING"],
    }
|
||||
|
||||
|
||||
def resolve_simple_vs_two_stage(features: dict) -> dict:
    """Decide between "単純マッチング" (simple) and "二段階マッチング" (two-stage).

    Rules:

    - ``features["open_pattern"] == "open-close-open"`` -> two-stage
      matching, confidence 0.90
    - any other value (including a missing key, which defaults to ``""``)
      -> simple matching, confidence 0.80

    Args:
        features: Feature dictionary; only ``open_pattern`` is read.

    Returns:
        A dict with the keys ``resolved_type``, ``confidence`` and
        ``evidence`` (a one-element list explaining the decision).
    """
    pattern = features.get("open_pattern", "")

    # Guard clause: everything that is not the exact two-stage marker is simple.
    if pattern != "open-close-open":
        return {
            "resolved_type": "単純マッチング",
            "confidence": 0.80,
            "evidence": [f"OPEN 模式为 '{pattern}' → 简单匹配"],
        }

    return {
        "resolved_type": "二段階マッチング",
        "confidence": 0.90,
        "evidence": ["OPEN→CLOSE→再OPEN 模式 → 二级匹配"],
    }
|
||||
|
||||
|
||||
def resolve_pure_vs_mixed(features: dict) -> dict:
    """Decide between "純粋マッチング" (pure) and "混合マッチング" (mixed).

    Implemented rule:

    - ``variable_patterns["has_switch"]`` and
      ``variable_patterns["has_counter"]`` are both truthy and
      ``if_types["total"] >= 3`` -> mixed matching, confidence 0.70
    - otherwise -> ``"unknown"`` with confidence 0.0

    NOTE(review): the fallback evidence text speaks of pure matching that
    needs data verification, but the returned ``resolved_type`` is
    ``"unknown"``; this function never returns "純粋マッチング".

    Args:
        features: Feature dictionary; ``variable_patterns`` and ``if_types``
            are read, both defaulting to empty dicts.

    Returns:
        A dict with the keys ``resolved_type``, ``confidence`` and
        ``evidence`` (a one-element list explaining the decision).
    """
    patterns = features.get("variable_patterns", {})
    branch_stats = features.get("if_types", {})

    switch_seen = patterns.get("has_switch", False)
    counter_seen = patterns.get("has_counter", False)
    total_ifs = branch_stats.get("total", 0)

    # Guard clause for the common case: any missing signal means no verdict.
    if not (switch_seen and counter_seen and total_ifs >= 3):
        return {
            "resolved_type": "unknown",
            "confidence": 0.0,
            "evidence": ["无明确混合特征 → 纯粹匹配(需数据验证)"],
        }

    return {
        "resolved_type": "混合マッチング",
        "confidence": 0.70,
        "evidence": ["多个变量模式和 IF 分支 → 可能混合匹配"],
    }
|
||||
|
||||
|
||||
def resolve_division_50_25_100(features: dict) -> dict:
    """Classify a program by the known DIVIDE constant 50, 25 or 100.

    ``features["divide_constants"]`` is scanned in order and the first
    entry equal to one of the known constants decides the result.

    Every match is normalised to a canonical integer before the type name
    is built, so the result is always exactly ``"DIVIDE_50"``,
    ``"DIVIDE_25"`` or ``"DIVIDE_100"``. Without this, a float such as
    ``50.0`` would yield ``"DIVIDE_50.0"``. Decimal digit strings such as
    ``"50"`` or ``" 25 "`` are accepted as well.

    Args:
        features: Feature dictionary; only ``divide_constants`` is read
            (expected to be a list or tuple, defaulting to ``[]``).

    Returns:
        A dict with the keys ``resolved_type``, ``confidence`` and
        ``evidence``. Confidence is 0.95 on a match. ``resolved_type`` is
        ``"unknown"`` with confidence 0.0 when the value is not a
        list/tuple or when no entry matches.
    """
    divide_constants = features.get("divide_constants", [])
    evidence: list[str] = []

    if not isinstance(divide_constants, (list, tuple)):
        evidence.append("divide_constants 格式无效")
        return {"resolved_type": "unknown", "confidence": 0.0, "evidence": evidence}

    # A tuple (not a set) keeps membership an equality test, so unhashable
    # entries cannot raise.
    known_constants = (50, 25, 100)

    for raw in divide_constants:
        candidate = raw
        if isinstance(raw, str):
            # Textual constants: accept plain decimal digits only.
            text = raw.strip()
            candidate = int(text) if text.isdecimal() else None

        if candidate in known_constants:
            # int() canonicalises integral floats and similar numeric types.
            constant = int(candidate)
            evidence.append(f"DIVIDE 被除数 = {constant}")
            return {
                "resolved_type": f"DIVIDE_{constant}",
                "confidence": 0.95,
                "evidence": evidence,
            }

    evidence.append(f"未匹配已知常量 (50/25/100),当前值: {divide_constants}")
    return {"resolved_type": "unknown", "confidence": 0.0, "evidence": evidence}
|
||||
|
||||
|
||||
def resolve_mn_output_mode(features: dict) -> dict:
    """Heuristically detect an M:N output mode.

    Rules, checked in this order:

    - at least 2 SELECT entries and ``total_branches >= 3`` -> "M:N",
      confidence 0.65
    - file count >= 3 -> "M:N", confidence 0.60
    - otherwise -> ``"unknown"``, confidence 0.0 (data verification needed)

    When ``features["select_files"]`` is a dict, its length is used as both
    the SELECT count and the file count. For any other value the file count
    comes from ``features["file_count"]`` (default 0). The SELECT count is
    the length of a list/tuple/set value, and 0 for anything else. The
    unguarded ``len(select_files)`` used previously raised ``TypeError``
    for ``None`` or other unsized values, so the ``file_count`` fallback
    could never be reached.

    Args:
        features: Feature dictionary; ``select_files``, ``file_count`` and
            ``total_branches`` are read.

    Returns:
        A dict with the keys ``resolved_type``, ``confidence`` and
        ``evidence`` (a one-element list explaining the decision).
    """
    select_files = features.get("select_files", {})
    total_branches = features.get("total_branches", 0)
    evidence: list[str] = []

    if isinstance(select_files, dict):
        select_count = file_count = len(select_files)
    else:
        file_count = features.get("file_count", 0)
        # Only count containers that plausibly hold SELECT entries; None,
        # numbers and strings must not crash or be counted by character.
        if isinstance(select_files, (list, tuple, set, frozenset)):
            select_count = len(select_files)
        else:
            select_count = 0

    # Try to infer M:N from the structural features that are available.
    if select_count >= 2 and total_branches >= 3:
        evidence.append(f"SELECT={select_count}, 分支={total_branches} → 可能 M:N")
        return {"resolved_type": "M:N", "confidence": 0.65, "evidence": evidence}

    if file_count >= 3:
        evidence.append(f"文件数 {file_count} >= 3, 可能为 M:N 关系")
        return {"resolved_type": "M:N", "confidence": 0.60, "evidence": evidence}

    evidence.append("需数据验证确定 M:N 输出模式")
    return {"resolved_type": "unknown", "confidence": 0.0, "evidence": evidence}
|
||||
|
||||
|
||||
# ── Dispatch table ──────────────────────────────────────────────────────────
# Maps each confusion-pair name to the resolver function that handles it.
# Every resolver takes the feature dict and returns a dict with the keys
# "resolved_type", "confidence" and "evidence".

_RESOLVER_MAP = dict(
    matching_vs_keybreak=resolve_matching_vs_keybreak,
    dedup_vs_nodedup=resolve_dedup_vs_nodedup,
    validation_vs_keybreak=resolve_validation_vs_keybreak,
    csv_merge_vs_split=resolve_csv_merge_vs_split,
    simple_vs_two_stage=resolve_simple_vs_two_stage,
    pure_vs_mixed=resolve_pure_vs_mixed,
    division_50_25_100=resolve_division_50_25_100,
    mn_output_mode=resolve_mn_output_mode,
)
|
||||
|
||||
|
||||
def resolve_confusion_pair(features: dict, pair_name: str) -> dict:
    """Run the resolver registered for ``pair_name`` in ``_RESOLVER_MAP``.

    Args:
        features: Feature dictionary passed unchanged to the resolver.
        pair_name: Confusion-pair name used as the lookup key.

    Returns:
        The resolver's result dict. For an unregistered name, a dict with
        ``resolved_type`` ``"unknown"``, confidence 0.0 and one evidence
        entry naming the unknown pair.
    """
    if pair_name not in _RESOLVER_MAP:
        return {
            "resolved_type": "unknown",
            "confidence": 0.0,
            "evidence": [f"未知混淆对名称: {pair_name}"],
        }

    handler = _RESOLVER_MAP[pair_name]
    return handler(features)
|
||||
@@ -0,0 +1,153 @@
|
||||
"""矛盾检测与解决 — 检测来自不同混淆组的类型冲突。
|
||||
|
||||
CONTRADICTION_PAIRS 定义了可能会矛盾的分类类型对。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# ── Contradiction pair definitions ──────────────────────────────────────────
# Each entry names a confusion group and the two classification types that
# are treated as contradicting each other.
# NOTE(review): "DIVIDE_25" has a priority entry but is not covered by any
# pair here (only DIVIDE_50 vs DIVIDE_100) — confirm whether that is intended.

CONTRADICTION_PAIRS: list[dict[str, str]] = [
    {"name": pair_name, "type_a": first_type, "type_b": second_type}
    for pair_name, first_type, second_type in (
        ("matching_vs_keybreak", "マッチング", "キーブレイク"),
        ("dedup_vs_nodedup", "項目チェック(重複含む)", "項目チェック(重複含まず)"),
        ("validation_vs_keybreak", "編集処理(校验)", "キーブレイク"),
        ("csv_merge_vs_split", "CSV合并", "CSV拆分"),
        ("simple_vs_two_stage", "単純マッチング", "二段階マッチング"),
        ("pure_vs_mixed", "純粋マッチング", "混合マッチング"),
        ("division_50_25_100", "DIVIDE_50", "DIVIDE_100"),
        ("mn_output_mode", "M:N", "1:1"),
    )
]
|
||||
|
||||
# ── Conflict priority ───────────────────────────────────────────────────────
# When two contradicting types come from different confusion groups, the type
# with the higher number wins. Types sharing a rank are listed together.

TYPE_PRIORITY: dict[str, int] = {
    type_name: rank
    for rank, type_names in (
        (10, ("マッチング",)),
        (9, ("キーブレイク",)),
        (8, ("項目チェック(重複含む)", "項目チェック(重複含まず)")),
        (7, ("編集処理(校验)",)),
        (6, ("CSV合并", "CSV拆分")),
        (5, ("単純マッチング", "二段階マッチング")),
        (4, ("純粋マッチング", "混合マッチング")),
        (3, ("DIVIDE_50", "DIVIDE_100", "DIVIDE_25")),
        (2, ("M:N", "1:1")),
    )
    for type_name in type_names
}
|
||||
|
||||
|
||||
def detect_contradictions(features: dict) -> list[dict]:
    """Report contradicting type pairs found in ``features["resolved_types"]``.

    For every entry of ``CONTRADICTION_PAIRS``, the first key of
    ``resolved_types`` whose value equals ``type_a`` is located, then the
    first *different* key whose value equals ``type_b``. If both exist, one
    contradiction record is emitted, so each pair yields at most one record.

    Args:
        features: Feature dictionary; ``resolved_types`` maps a source key
            to the type name resolved for it.

    Returns:
        A list of dicts with the keys ``name``, ``type_a``, ``type_b``,
        ``source_a`` and ``source_b``. Empty when ``resolved_types`` is
        missing or empty, or when nothing contradicts.
    """
    resolved = features.get("resolved_types", {})
    if not resolved:
        return []

    # Private sentinel so that any real key (even None) stays distinguishable
    # from "no key found".
    not_found = object()
    found: list[dict] = []

    for pair in CONTRADICTION_PAIRS:
        type_a = pair["type_a"]
        type_b = pair["type_b"]

        holder_a = next(
            (key for key, value in resolved.items() if value == type_a),
            not_found,
        )
        if holder_a is not_found:
            continue

        holder_b = next(
            (key for key, value in resolved.items() if key != holder_a and value == type_b),
            not_found,
        )
        if holder_b is not_found:
            continue

        found.append({
            "name": pair["name"],
            "type_a": type_a,
            "type_b": type_b,
            "source_a": holder_a,
            "source_b": holder_b,
        })

    return found
|
||||
|
||||
|
||||
def resolve_contradiction(features: dict, contradiction: dict) -> str:
    """Pick the winning type name for one contradiction record.

    Strategy:
        1. The type with the higher ``TYPE_PRIORITY`` value wins; types
           missing from the table count as priority 0.
        2. On a tie, the confusion pair named by ``contradiction["name"]``
           is re-resolved from ``features``. Its ``resolved_type`` is
           returned when the confidence is at least 0.80.
        3. Otherwise ``type_a`` is returned as the final fallback.

    Args:
        features: Full feature dictionary, forwarded to the pair resolver.
        contradiction: One record as produced by ``detect_contradictions``;
            must contain ``type_a`` and ``type_b``, ``name`` is optional.

    Returns:
        The name of the winning type.
    """
    type_a = contradiction["type_a"]
    type_b = contradiction["type_b"]

    rank_a = TYPE_PRIORITY.get(type_a, 0)
    rank_b = TYPE_PRIORITY.get(type_b, 0)
    if rank_a != rank_b:
        return type_a if rank_a > rank_b else type_b

    # Equal priority: fall back to re-running the confusion-group resolver.
    # The import is kept function-local, as in the original code.
    from .confusion_groups import resolve_confusion_pair

    pair_name = contradiction.get("name", "")
    if pair_name:
        verdict = resolve_confusion_pair(features, pair_name)
        if verdict.get("confidence", 0) >= 0.80:
            return verdict["resolved_type"]

    # Last resort: prefer type_a.
    return type_a
|
||||
Reference in New Issue
Block a user