a6c454692a
M1: Cache confusion-pair confidences in Path B (eliminate redundant
resolve_confusion_pair re-calls in _path_rule_engine)
M2: Resolve contradictions in Path C instead of hardcoding
resolved_count=0 in _path_llm_assisted
M4: Add DIVIDE_25 to contradiction pair coverage (50-25, 100-25)
and update test_contradiction_pairs_defined to verify all 3 variants
433 lines
16 KiB
Python
433 lines
16 KiB
Python
"""
Complete program-type classification pipeline — classify_program()

Flow:
1. In parallel: detect_keyword() + extract_structure()
2. keyword confidence >= 90% -> output directly
3. keyword 50-89% -> rule engine + confidence computation + contradiction backtracking
4. keyword < 50% -> LLM assistance + rule-engine validation
5. Emit the final JSON
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import Any
|
|
|
|
from hina.classifier import detect_keyword
|
|
from hina.confidence import compute_confidence_v2
|
|
from hina.rule_engine.confusion_groups import resolve_confusion_pair
|
|
from hina.rule_engine.contradiction import (
|
|
CONTRADICTION_PAIRS,
|
|
detect_contradictions,
|
|
resolve_contradiction,
|
|
)
|
|
from cobol_testgen import extract_structure
|
|
|
|
logger = logging.getLogger(__name__)


# Names of every confusion pair that the path functions hand to
# resolve_confusion_pair(); they are tried in this order.
_PAIR_NAMES = [
    "matching_vs_keybreak",
    "dedup_vs_nodedup",
    "validation_vs_keybreak",
    "csv_merge_vs_split",
    "simple_vs_two_stage",
    "pure_vs_mixed",
    "division_50_25_100",
    "mn_output_mode",
]
|
|
|
|
|
|
# ── 内部工具 ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _get_best_keyword_match(matches: list) -> dict | None:
    """Select the highest-confidence entry from the L1 keyword matches.

    Args:
        matches: Result of detect_keyword(); entries are indexed as
            (category, confidence, keyword).

    Returns:
        None when ``matches`` is empty or falsy. Otherwise a dict with the
        keys "category", "confidence", "keyword" (taken from the entry with
        the largest confidence; the first such entry wins a tie) and
        "all_matches" (the ``matches`` object itself).
    """
    if matches:
        top = max(matches, key=lambda entry: entry[1])
        return dict(
            category=top[0],
            confidence=top[1],
            keyword=top[2],
            all_matches=matches,
        )
    return None
|
|
|
|
|
|
def _compute_structure_match_score(structure: dict) -> int:
    """Score how much structural evidence is present (0-5).

    One point is awarded for each of the following that is truthy/present in
    ``structure``: "total_paragraphs", "file_count", a non-empty
    "decision_points", "if_types"["total"], and a "branch_tree_obj" that is
    not None. The result feeds compute_confidence_v2 via
    _build_structure_features.
    """
    signals = (
        structure.get("total_paragraphs", 0),            # has paragraphs
        structure.get("file_count", 0),                  # has files
        len(structure.get("decision_points", [])),       # has decision points
        structure.get("if_types", {}).get("total", 0),   # has IF statements
        structure.get("branch_tree_obj") is not None,    # has a branch tree
    )
    score = sum(1 for signal in signals if signal)
    # The cap mirrors the documented 0-5 range.
    return min(5, score)
|
|
|
|
|
|
def _build_structure_summary(structure: dict) -> dict:
    """Condense the full structure dict into a small debugging summary.

    Returns a dict with "paragraph_count", "file_count", "decision_count",
    "has_call" and "has_divide"; missing inputs fall back to 0 / False.
    """
    summary = {
        out_key: structure.get(src_key, 0)
        for out_key, src_key in (
            ("paragraph_count", "total_paragraphs"),
            ("file_count", "file_count"),
        )
    }
    summary["decision_count"] = len(structure.get("decision_points", []))
    for flag in ("has_call", "has_divide"):
        summary[flag] = structure.get(flag, False)
    return summary
|
|
|
|
|
|
def _build_keyword_result_for_v2(keyword_info: dict | None) -> dict:
    """Build the ``keyword_result`` argument for compute_confidence_v2.

    Args:
        keyword_info: Output of _get_best_keyword_match(), or None.

    Returns:
        {"base_confidence", "match_count"}; zeros when there is no keyword
        information.
    """
    if not keyword_info:
        return {"base_confidence": 0.0, "match_count": 0}

    base_confidence = keyword_info["confidence"]
    match_count = len(keyword_info["all_matches"])
    return {"base_confidence": base_confidence, "match_count": match_count}
|
|
|
|
|
|
def _build_structure_features(structure: dict) -> dict:
    """Build the ``structure_features`` argument for compute_confidence_v2.

    Combines the 0-5 score from _compute_structure_match_score() with the
    raw "total_paragraphs" count (0 when absent).
    """
    match_score = _compute_structure_match_score(structure)
    paragraph_total = structure.get("total_paragraphs", 0)
    return {
        "structure_match_score": match_score,
        "total_paragraphs": paragraph_total,
    }
|
|
|
|
|
|
# ── 分路径逻辑 ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _path_keyword_direct(
    keyword_info: dict,
    structure: dict,
) -> dict:
    """Path A: keyword confidence >= 90%, emit the keyword result directly.

    The v2 confidence is still computed (with no contradictions) so that the
    reported confidence / needs_review / judgment come from the same scorer
    as the other paths, but the category is taken from the keyword match and
    the result is labelled method="keyword", source="l1".

    Args:
        keyword_info: Best keyword match from _get_best_keyword_match().
        structure: Structure dict produced by extract_structure().

    Returns:
        The result dict for this path.
    """
    kw_input = _build_keyword_result_for_v2(keyword_info)
    struct_input = _build_structure_features(structure)
    no_resolution = {"resolved_count": 0, "total_count": 0}

    v2_conf = compute_confidence_v2(
        keyword_result=kw_input,
        structure_features=struct_input,
        contradictions=[],
        resolution=no_resolution,
    )

    outcome = {
        "category": keyword_info["category"],
        "confidence": v2_conf["confidence"],
        "needs_review": v2_conf["needs_review"],
        "method": "keyword",
        "source": "l1",
        "judgment": v2_conf["judgment"],
        "matches": keyword_info["all_matches"],
        "contradictions": [],
        "v2_confidence": v2_conf,
    }
    outcome["structure"] = _build_structure_summary(structure)
    return outcome
|
|
|
|
|
|
def _path_rule_engine(
    keyword_info: dict | None,
    structure: dict,
) -> dict:
    """Path B: keyword 50-89%, rule engine + confidence + contradiction pass.

    Steps:
        1. Copy ``structure`` into a features dict.
        2. Run every confusion-pair resolver and collect the resolved types
           together with their confidences.
        3. Detect contradictions and try to resolve each one.
        4. Choose the final category: the keyword category by default, or
           the best resolver result when its confidence is strictly higher.
        5. Compute the v2 confidence from the chosen base confidence.

    Args:
        keyword_info: Best keyword match, or None when there was none.
        structure: Structure dict produced by extract_structure().

    Returns:
        The result dict for this path (method="rule_engine").
    """
    # 1. Shallow copy: "resolved_types" is added below without touching the
    #    caller's dict.
    features = dict(structure)

    # 2. Run all confusion-pair resolvers, keeping type and confidence.
    resolved_types: dict[str, str] = {}
    resolved_confidences: dict[str, float] = {}
    for pair_name in _PAIR_NAMES:
        try:
            outcome = resolve_confusion_pair(features, pair_name)
            if outcome["resolved_type"] == "unknown" or not (outcome["confidence"] > 0):
                continue
            resolved_types[pair_name] = outcome["resolved_type"]
            resolved_confidences[pair_name] = outcome["confidence"]
        except Exception as e:
            logger.debug("[pipeline] 混淆对 %s 解析异常: %s", pair_name, e)

    features["resolved_types"] = resolved_types

    # 3. Contradiction detection and resolution.
    contradictions = detect_contradictions(features)
    resolution_map: dict[str, Any] = {
        "resolved_count": 0,
        "total_count": len(contradictions),
    }
    for contradiction in contradictions:
        try:
            winner = resolve_contradiction(features, contradiction)
            if not winner:
                continue
            resolution_map[contradiction.get("name", "unknown")] = winner
            resolution_map["resolved_count"] += 1
        except Exception as e:
            logger.debug("[pipeline] 矛盾解决异常: %s", e)

    # 4. Final category and base confidence: keyword verdict first ...
    if keyword_info:
        final_category = keyword_info["category"]
        final_base_confidence = keyword_info["confidence"]
    else:
        final_category = "unknown"
        final_base_confidence = 0.0

    # ... overridden by the strongest cached resolver result if it is
    # strictly more confident (no resolver is re-run here).
    if resolved_confidences:
        top_pair = max(resolved_confidences, key=resolved_confidences.get)
        top_type = resolved_types[top_pair]
        top_conf = resolved_confidences[top_pair]
        if top_type and top_conf > final_base_confidence:
            final_category = top_type
            final_base_confidence = top_conf

    # 5. Four-factor confidence, seeded with the chosen base confidence.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = final_base_confidence
    structure_features = _build_structure_features(structure)

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=structure_features,
        contradictions=contradictions,
        resolution=resolution_map,
    )

    # 6. Assemble the result.
    keyword_matches = keyword_info["all_matches"] if keyword_info else []
    return {
        "category": final_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "rule_engine",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_matches,
        "contradictions": contradictions,
        "contradiction_resolution": resolution_map,
        "resolved_types": resolved_types,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
|
|
|
|
|
def _path_llm_assisted(
    keyword_info: dict | None,
    structure: dict,
    llm: Any,
) -> dict:
    """Path C: keyword < 50%, LLM classification validated by the rule engine.

    Steps:
        1. Ask classify_with_llm() for a category and confidence.
        2. Run every confusion-pair resolver once. A resolver result replaces
           the current category when its confidence is strictly higher than
           the running best; every non-"unknown" resolved type is recorded
           for contradiction detection.
        3. Detect contradictions and try to resolve each one.
        4. Compute the v2 confidence from the validated confidence.

    Args:
        keyword_info: Best keyword match, or None when there was none.
        structure: Structure dict produced by extract_structure().
        llm: LLM client instance, passed through to classify_with_llm().

    Returns:
        The result dict for this path (method="llm"). It carries the raw LLM
        answer under "llm_raw" and, like Path B, the contradiction
        resolution map and the resolved types.
    """
    # Imported lazily, as in the original code.
    from hina.hina_agent import classify_with_llm

    # 1. LLM classification (defaults apply when the LLM omits a field).
    llm_result = classify_with_llm(structure, llm)
    validated_category = llm_result.get("category", "unknown")
    validated_confidence = llm_result.get("confidence", 0.5)

    # 2. Single pass over the resolvers: validate the LLM verdict and collect
    #    the resolved types at the same time, instead of calling every
    #    resolver twice.
    features = dict(structure)
    resolved_types: dict[str, str] = {}
    for pair_name in _PAIR_NAMES:
        try:
            pair_result = resolve_confusion_pair(features, pair_name)
            resolved_type = pair_result["resolved_type"]
            if resolved_type == "unknown":
                continue
            resolved_types[pair_name] = resolved_type
            if pair_result["confidence"] > validated_confidence:
                validated_category = resolved_type
                validated_confidence = pair_result["confidence"]
        except Exception as e:
            # Best effort: a failing resolver must not abort classification,
            # but it is logged like in Path B rather than silently dropped.
            logger.debug("[pipeline] Path C 混淆对 %s 解析异常: %s", pair_name, e)

    # 3. Contradiction detection and resolution.
    features["resolved_types"] = resolved_types
    contradictions = detect_contradictions(features)

    resolution_map: dict[str, Any] = {
        "resolved_count": 0,
        "total_count": len(contradictions),
    }
    for c in contradictions:
        try:
            winner = resolve_contradiction(features, c)
            if winner:
                resolution_map[c.get("name", "unknown")] = winner
                resolution_map["resolved_count"] += 1
        except Exception as e:
            logger.debug("[pipeline] Path C 矛盾解决异常: %s", e)

    # 4. Confidence computation, seeded with the validated confidence.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = validated_confidence

    structure_features = _build_structure_features(structure)

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=structure_features,
        contradictions=contradictions,
        resolution=resolution_map,
    )

    return {
        "category": validated_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "llm",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_info["all_matches"] if keyword_info else [],
        "contradictions": contradictions,
        # Exposed for parity with Path B so callers can inspect how the
        # contradictions were settled.
        "contradiction_resolution": resolution_map,
        "resolved_types": resolved_types,
        "llm_raw": llm_result,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
|
|
|
|
|
# ── 主入口 ────────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def classify_program(cobol_source: str, llm: Any = None) -> dict:
    """Run the complete program-type classification pipeline.

    Flow:
        1. In parallel: detect_keyword() + extract_structure().
        2. keyword confidence >= 90% -> output directly (path A).
        3. keyword 50-89% -> rule engine + confidence + contradiction pass
           (path B).
        4. keyword < 50% -> LLM assistance + rule-engine validation (path C);
           when no LLM is supplied the rule engine is used as a fallback.
        5. Return the final result dict.

    Args:
        cobol_source: COBOL program source text.
        llm: Optional LLM client instance, used only on the keyword < 50%
            path. When None on that path, the rule engine is used instead
            and the result is labelled "rule_engine_fallback".

    Returns:
        dict: {
            "category": str,        # program category name
            "confidence": float,    # overall confidence
            "needs_review": bool,   # whether manual review is needed
            "method": str,          # "keyword" | "rule_engine" | "llm"
                                    # | "rule_engine_fallback" | "none"
            "source": str,          # "l1" | "pipeline" | "error"
            "judgment": str,        # auto / review / manual / impossible
            "matches": list,        # L1 keyword match details
            "contradictions": list, # detected contradictions
            "v2_confidence": dict,  # detailed v2 confidence result
            "structure": dict,      # structure summary (for debugging)
        }
        Paths B and C add further path-specific keys.

    Note:
        Empty or whitespace-only ``cobol_source`` does not raise; it yields
        an error result (category "unknown", method "none", source "error",
        judgment "impossible"). Failures of detect_keyword() or
        extract_structure() are logged and treated as "no matches" /
        "empty structure".
    """
    if not cobol_source or not cobol_source.strip():
        return {
            "category": "unknown",
            "confidence": 0.0,
            "needs_review": True,
            "method": "none",
            "source": "error",
            "judgment": "impossible",
            "matches": [],
            "contradictions": [],
            "v2_confidence": {},
            "structure": {},
        }

    # ── Step 1: run keyword detection and structure extraction in parallel ──
    keyword_matches: list = []
    structure: dict = {}

    with ThreadPoolExecutor(max_workers=2) as executor:
        future_keyword = executor.submit(detect_keyword, cobol_source)
        future_structure = executor.submit(extract_structure, cobol_source)

        for future in as_completed([future_keyword, future_structure]):
            if future is future_keyword:
                try:
                    # "or []" guards against a None return, which would
                    # otherwise break len(keyword_matches) below.
                    keyword_matches = future.result() or []
                except Exception as e:
                    logger.warning("[pipeline] detect_keyword 失败: %s", e)
            elif future is future_structure:
                try:
                    # "or {}" guards against a None return, which would
                    # otherwise break structure.get(...) below.
                    structure = future.result() or {}
                except Exception as e:
                    logger.warning("[pipeline] extract_structure 失败: %s", e)

    # ── Step 2: analyse the keyword result to pick a path ──
    keyword_info = _get_best_keyword_match(keyword_matches)
    max_keyword_confidence = keyword_info["confidence"] if keyword_info else 0.0

    logger.info(
        "[pipeline] keyword matches=%d, max_confidence=%.2f, paragraphs=%d, files=%d",
        len(keyword_matches),
        max_keyword_confidence,
        structure.get("total_paragraphs", 0),
        structure.get("file_count", 0),
    )

    # ── Step 3: branch on the keyword confidence ──

    # Path A: keyword >= 90% -> output directly.
    if max_keyword_confidence >= 0.90:
        logger.info("[pipeline] 路径 A: keyword 高确信度 (%.2f)", max_keyword_confidence)
        return _path_keyword_direct(keyword_info, structure)

    # Path B: keyword 50-89% -> rule engine.
    if max_keyword_confidence >= 0.50:
        logger.info("[pipeline] 路径 B: keyword 中确信度 (%.2f) -> 规则引擎", max_keyword_confidence)
        return _path_rule_engine(keyword_info, structure)

    # Path C: keyword < 50% -> LLM assistance.
    if llm is not None:
        logger.info("[pipeline] 路径 C: keyword 低确信度 (%.2f) -> LLM 辅助", max_keyword_confidence)
        return _path_llm_assisted(keyword_info, structure, llm)

    # No LLM available: fall back to the rule engine and relabel the method.
    logger.info("[pipeline] 路径 C(fallback): keyword 低确信度 (%.2f) -> 规则引擎兜底", max_keyword_confidence)
    result = _path_rule_engine(keyword_info, structure)
    result["method"] = "rule_engine_fallback"
    return result
|