7d5c82e0e2
Add _resolve_matching_subtype post-processing step in classify_program() that distinguishes matching program subtypes based on key variable naming patterns and file/structural features: Rules (in priority order): 1. 二段階 → 二段階 (already handled by rule engine) 2. 3 files + WS-SAVE-KEY → M:N→MxN (MT20) 3. WS-PREV-KEY present → 混合 (already handled, MT32) 4. WS-MAST-KEY + WS-TRAN-KEY → 1:N (MT02) 5. >=3 KEY vars + >=2 files → M:N (MT33) 6. Otherwise → 1:1 (MT01, MT03, MT18, MT19) Results: MT01→1:1, MT02→1:N, MT03→1:1, MT16/17→二段階, MT18/19→1:1, MT20→M:N→MxN, MT33→M:N Also fix double-backslash regex bug in classifier.py and pipeline.py (r'[-\w]' should be r'[\w-]' for word character class). Regression: 745 passed (unchanged).
510 lines
18 KiB
Python
510 lines
18 KiB
Python
"""
|
|
完整程序类型判定管道 — classify_program()
|
|
|
|
流程:
|
|
1. 并行: detect_keyword() + extract_structure()
|
|
2. keyword confidence >= 90% -> 直接输出
|
|
3. keyword 50-89% -> 规则引擎 + 确信度计算 + 矛盾回溯
|
|
4. keyword < 50% -> LLM 辅助 + 规则引擎验证
|
|
5. 输出最终 JSON
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import Any
|
|
|
|
from hina.classifier import detect_keyword
|
|
from hina.confidence import compute_confidence_v2
|
|
from hina.rule_engine.confusion_groups import resolve_confusion_pair
|
|
from hina.rule_engine.contradiction import (
|
|
CONTRADICTION_PAIRS,
|
|
detect_contradictions,
|
|
resolve_contradiction,
|
|
)
|
|
from cobol_testgen import extract_structure
|
|
|
|
logger = logging.getLogger(__name__)


# Names of every confusion-pair resolver the rule engine is asked to run.
# Paths B and C walk this list in order and only replace their current best
# candidate on a strictly higher confidence, so on a tie the earlier entry in
# this list wins.
_PAIR_NAMES: list[str] = [
    "matching_vs_keybreak",
    "dedup_vs_nodedup",
    "validation_vs_keybreak",
    "csv_merge_vs_split",
    "simple_vs_two_stage",
    "pure_vs_mixed",
    "division_50_25_100",
    "mn_output_mode",
]
|
|
|
|
|
|
# ── 内部工具 ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _get_best_keyword_match(matches: list) -> dict | None:
|
|
"""从 L1 关键字匹配结果中找出最佳匹配。
|
|
|
|
Args:
|
|
matches: detect_keyword() 返回的 list[tuple[str, float, str]]
|
|
|
|
Returns:
|
|
dict | None: {"category", "confidence", "keyword", "all_matches"}
|
|
"""
|
|
if not matches:
|
|
return None
|
|
best = max(matches, key=lambda m: m[1]) # (category, confidence, keyword)
|
|
return {
|
|
"category": best[0],
|
|
"confidence": best[1],
|
|
"keyword": best[2],
|
|
"all_matches": matches,
|
|
}
|
|
|
|
|
|
def _compute_structure_match_score(structure: dict) -> int:
    """Score (0-5) how much analyzable structure extract_structure() found.

    One point each for: any paragraphs, any files, any decision points, any
    IF statements, and a branch-tree object being present. The result feeds
    the ``structure_match_score`` input of compute_confidence_v2.
    """
    signals = (
        structure.get("total_paragraphs", 0),           # has paragraphs
        structure.get("file_count", 0),                 # has files
        len(structure.get("decision_points", [])),      # has decision points
        structure.get("if_types", {}).get("total", 0),  # has IF statements
        structure.get("branch_tree_obj") is not None,   # has a branch tree
    )
    return min(5, sum(1 for signal in signals if signal))
|
|
|
|
|
|
def _build_structure_summary(structure: dict) -> dict:
    """Condense a full extract_structure() result into a small debug summary.

    Missing entries fall back to 0 (counts) or False (flags).
    """
    lookup = structure.get
    decision_points = lookup("decision_points", [])
    return {
        "paragraph_count": lookup("total_paragraphs", 0),
        "file_count": lookup("file_count", 0),
        "decision_count": len(decision_points),
        "has_call": lookup("has_call", False),
        "has_divide": lookup("has_divide", False),
    }
|
|
|
|
|
|
def _build_keyword_result_for_v2(keyword_info: dict | None) -> dict:
|
|
"""构建 compute_confidence_v2 所需的 keyword_result。"""
|
|
if keyword_info:
|
|
return {
|
|
"base_confidence": keyword_info["confidence"],
|
|
"match_count": len(keyword_info["all_matches"]),
|
|
"category": keyword_info.get("category"),
|
|
}
|
|
return {"base_confidence": 0.0, "match_count": 0, "category": None}
|
|
|
|
|
|
def _build_structure_features(structure: dict) -> dict:
    """Build the ``structure_features`` argument for compute_confidence_v2.

    Returns:
        dict: ``{"structure_match_score": <0-5 score>,
        "total_paragraphs": <structure["total_paragraphs"], 0 if absent>}``.
    """
    score = _compute_structure_match_score(structure)
    paragraphs = structure.get("total_paragraphs", 0)
    return {"structure_match_score": score, "total_paragraphs": paragraphs}
|
|
|
|
|
|
# ── 分路径逻辑 ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _path_keyword_direct(
    keyword_info: dict,
    structure: dict,
) -> dict:
    """Path A: keyword confidence >= 90%; the keyword category is used as-is.

    The v2 four-factor confidence is still computed (with no contradictions
    and an empty resolution) so the result carries confidence, needs_review
    and judgment. The category itself always comes from the keyword match,
    and the result is tagged ``method="keyword"`` / ``source="l1"``.

    Args:
        keyword_info: Best keyword match from _get_best_keyword_match
            (must not be None on this path).
        structure: Output of extract_structure().

    Returns:
        dict: The classification result for this path.
    """
    v2_conf = compute_confidence_v2(
        keyword_result=_build_keyword_result_for_v2(keyword_info),
        structure_features=_build_structure_features(structure),
        contradictions=[],
        resolution={"resolved_count": 0, "total_count": 0},
    )

    result = {
        "category": keyword_info["category"],
        "confidence": v2_conf["confidence"],
        "needs_review": v2_conf["needs_review"],
        "method": "keyword",
        "source": "l1",
        "judgment": v2_conf["judgment"],
        "matches": keyword_info["all_matches"],
        "contradictions": [],
        "v2_confidence": v2_conf,
        "structure": _build_structure_summary(structure),
    }
    return result
|
|
|
|
|
|
def _path_rule_engine(
    keyword_info: dict | None,
    structure: dict,
) -> dict:
    """Path B: keyword 50-89% -> rule engine + confidence + contradiction handling.

    classify_program also uses this path as the fallback when keyword
    confidence is below 50% and no LLM is available.

    Steps:
        1. Copy the structure into a ``features`` dict.
        2. Run every resolver in ``_PAIR_NAMES`` once; keep results whose type
           is not "unknown" and whose confidence is > 0.
        3. Detect contradictions and try to resolve each one; winners are
           recorded in the resolution map under the contradiction's name.
        4. Final category: the keyword category by default, replaced by the
           highest-confidence resolver result when that confidence is strictly
           higher than the keyword confidence (earliest resolver wins ties).
        5. Compute the four-factor v2 confidence, passing the keyword category
           as ``consensus_category`` when it equals the final category.

    Args:
        keyword_info: Best keyword match, or None when nothing matched.
        structure: Output of extract_structure().

    Returns:
        dict: Result tagged ``method="rule_engine"`` / ``source="pipeline"``,
        including "contradiction_resolution" and "resolved_types".
    """
    features = dict(structure)

    # Step 2 — single pass; both dicts share the same keys in the same order,
    # so later steps can reuse the cached confidences without re-resolving.
    resolved_types: dict[str, str] = {}
    resolved_confidences: dict[str, float] = {}
    for name in _PAIR_NAMES:
        try:
            outcome = resolve_confusion_pair(features, name)
            if outcome["resolved_type"] != "unknown" and outcome["confidence"] > 0:
                resolved_types[name] = outcome["resolved_type"]
                resolved_confidences[name] = outcome["confidence"]
        except Exception as e:
            logger.debug("[pipeline] 混淆对 %s 解析异常: %s", name, e)

    features["resolved_types"] = resolved_types

    # Step 3 — contradictions.
    contradictions = detect_contradictions(features)
    resolution_map: dict[str, Any] = {
        "resolved_count": 0,
        "total_count": len(contradictions),
    }
    for contradiction in contradictions:
        try:
            winner = resolve_contradiction(features, contradiction)
            if not winner:
                continue
            resolution_map[contradiction.get("name", "unknown")] = winner
            resolution_map["resolved_count"] += 1
        except Exception as e:
            logger.debug("[pipeline] 矛盾解决异常: %s", e)

    # Step 4 — keyword result first, then let a stronger resolver override it.
    if keyword_info:
        final_category = keyword_info["category"]
        final_base_confidence = keyword_info["confidence"]
    else:
        final_category, final_base_confidence = "unknown", 0.0

    if resolved_confidences:
        # max() keeps the first maximal key, i.e. the earliest resolver on a tie.
        top_pair = max(resolved_confidences, key=resolved_confidences.__getitem__)
        top_type = resolved_types[top_pair]
        top_conf = resolved_confidences[top_pair]
        if top_type and top_conf > final_base_confidence:
            final_category, final_base_confidence = top_type, top_conf

    # Step 5 — four-factor confidence.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = final_base_confidence
    structure_features = _build_structure_features(structure)

    # Consensus bonus: keyword category agrees with the final category.
    keyword_category = keyword_info["category"] if keyword_info else None
    consensus_cat = None
    if keyword_category and keyword_category == final_category:
        consensus_cat = keyword_category

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=structure_features,
        contradictions=contradictions,
        resolution=resolution_map,
        consensus_category=consensus_cat,
    )

    return {
        "category": final_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "rule_engine",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_info["all_matches"] if keyword_info else [],
        "contradictions": contradictions,
        "contradiction_resolution": resolution_map,
        "resolved_types": resolved_types,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
|
|
|
|
|
def _path_llm_assisted(
    keyword_info: dict | None,
    structure: dict,
    llm: Any,
) -> dict:
    """Path C: keyword < 50% -> LLM classification validated by the rule engine.

    Steps:
        1. Ask classify_with_llm for a category and confidence; missing values
           default to "unknown" and 0.5.
        2. Run every resolver in ``_PAIR_NAMES`` exactly once. Every result
           whose type is not "unknown" is recorded in ``resolved_types``; a
           result whose confidence is strictly higher than the current best
           replaces the LLM category (earliest resolver wins ties).
        3. Detect contradictions and try to resolve each one.
        4. Compute the four-factor v2 confidence from the validated confidence.

    Args:
        keyword_info: Best keyword match, or None when nothing matched.
        structure: Output of extract_structure().
        llm: LLM client passed through to classify_with_llm.

    Returns:
        dict: Result tagged ``method="llm"`` / ``source="pipeline"``. Like
        path B it includes "contradiction_resolution" and "resolved_types";
        the raw LLM answer is kept under "llm_raw".
    """
    from hina.hina_agent import classify_with_llm

    # 1. LLM classification.
    llm_result = classify_with_llm(structure, llm)
    llm_category = llm_result.get("category", "unknown")
    llm_confidence = llm_result.get("confidence", 0.5)

    # 2. Validate with the rule engine. A single pass both collects
    #    resolved_types and picks the strongest candidate; previously every
    #    resolver ran twice (the redundancy path B already removed as "M1").
    features = dict(structure)
    validated_category = llm_category
    validated_confidence = llm_confidence
    resolved_types: dict[str, str] = {}

    for pair_name in _PAIR_NAMES:
        try:
            pair_result = resolve_confusion_pair(features, pair_name)
            resolved_type = pair_result["resolved_type"]
        except Exception as e:
            logger.debug("[pipeline] Path C 混淆对 %s 解析异常: %s", pair_name, e)
            continue
        if resolved_type == "unknown":
            continue
        resolved_types[pair_name] = resolved_type
        # Kept separate so a bad/missing confidence still leaves the type
        # recorded in resolved_types (same as the former second pass).
        try:
            if pair_result["confidence"] > validated_confidence:
                validated_category = resolved_type
                validated_confidence = pair_result["confidence"]
        except Exception as e:
            logger.debug("[pipeline] Path C 混淆对 %s 置信度比较异常: %s", pair_name, e)

    # 3. Contradiction detection and resolution.
    features["resolved_types"] = resolved_types
    contradictions = detect_contradictions(features)

    resolution_map: dict[str, Any] = {
        "resolved_count": 0,
        "total_count": len(contradictions),
    }
    for c in contradictions:
        try:
            winner = resolve_contradiction(features, c)
            if winner:
                resolution_map[c.get("name", "unknown")] = winner
                resolution_map["resolved_count"] += 1
        except Exception as e:
            logger.debug("[pipeline] Path C 矛盾解决异常: %s", e)

    # 4. Four-factor confidence, seeded with the validated confidence.
    keyword_result_v2 = _build_keyword_result_for_v2(keyword_info)
    keyword_result_v2["base_confidence"] = validated_confidence

    structure_features = _build_structure_features(structure)

    v2_confidence = compute_confidence_v2(
        keyword_result=keyword_result_v2,
        structure_features=structure_features,
        contradictions=contradictions,
        resolution=resolution_map,
    )

    return {
        "category": validated_category,
        "confidence": v2_confidence["confidence"],
        "needs_review": v2_confidence["needs_review"],
        "method": "llm",
        "source": "pipeline",
        "judgment": v2_confidence["judgment"],
        "matches": keyword_info["all_matches"] if keyword_info else [],
        "contradictions": contradictions,
        "contradiction_resolution": resolution_map,
        "resolved_types": resolved_types,
        "llm_raw": llm_result,
        "v2_confidence": v2_confidence,
        "structure": _build_structure_summary(structure),
    }
|
|
|
|
|
|
# ── 主入口 ────────────────────────────────────────────────────────────────────
|
|
|
|
# ── 匹配子类型解析 ──────────────────────────────────────────────────────────
|
|
|
|
# Placeholder for a table-driven form of the matching-subtype rules: entries
# are meant to be (match_fn, subtype) pairs ordered from highest to lowest
# priority.
# NOTE(review): nothing in the visible code reads this list yet — the rules
# are currently hard-coded in _resolve_matching_subtype.
_MATCHING_SUBTYPE_RULES: list = list()
|
|
|
|
|
|
def _resolve_matching_subtype(
    result: dict,
    cobol_source: str,
    structure: dict,
) -> dict:
    """Refine a matching / key-break classification into a subtype.

    Only categories containing "マッチング" or "キーブレイク" are processed;
    any other category is returned untouched (no "subtype" key is added).
    Rules, evaluated in order, first match wins:

        0. category contains "二段階"                      -> "二段階"
        1. >= 3 files and a WS-SAVE...KEY key variable      -> "M:N→MxN"
        2. "WS-PREV-KEY" appears in the source             -> "混合"
        3. key variables naming both MAST and TRAN         -> "1:N"
        4. >= 3 distinct key variables and >= 2 files     -> "M:N"
        5. otherwise                                        -> "1:1"

    A "key variable" is any token matching ``WS-…KEY…`` in the upper-cased
    source text.

    Args:
        result: Result dict from classify_program; mutated in place.
        cobol_source: Original COBOL source text.
        structure: Output of extract_structure(); only "file_count" is read.

    Returns:
        The same ``result`` object, with "subtype" set for matching programs.
    """
    import re

    # `or ""` guards against an explicit None category (e.g. an LLM answer
    # without a category), which would otherwise raise on the `in` test.
    category = result.get("category") or ""
    if "マッチング" not in category and "キーブレイク" not in category:
        return result  # not a matching program: no subtype

    # 0. Two-stage matching — already decided by the rule engine.
    if "二段階" in category:
        result["subtype"] = "二段階"
        return result

    src_upper = (cobol_source or "").upper()
    file_count = structure.get("file_count") or 0
    key_vars = set(re.findall(r"WS-[\w-]*KEY[A-Z0-9-]*", src_upper))

    # 1. M:N→MxN cartesian product. The documented trigger is a saved *key*
    #    (WS-SAVE-KEY); a bare `'WS-SAVE' in src` test also fired on unrelated
    #    WS-SAVE-* fields such as saved amounts or work areas.
    if file_count >= 3 and any(k.startswith("WS-SAVE") for k in key_vars):
        result["subtype"] = "M:N→MxN"
        return result

    # 2. Mixed matching (a previous-key variable is present).
    if "WS-PREV-KEY" in src_upper:
        result["subtype"] = "混合"
        return result

    # 3. Asymmetric key names (master + transaction keys) -> 1:N.
    has_master = any("MAST" in k for k in key_vars)
    has_tran = any("TRAN" in k for k in key_vars)
    if has_master and has_tran:
        result["subtype"] = "1:N"
        return result

    # 4. Many key names across several files -> M:N.
    if len(key_vars) >= 3 and file_count >= 2:
        result["subtype"] = "M:N"
        return result

    # 5. Symmetric key names -> default 1:1.
    result["subtype"] = "1:1"
    return result
|
|
|
|
|
|
def classify_program(cobol_source: str, llm: Any = None) -> dict:
    """Run the complete program-type classification pipeline.

    Flow:
        1. Run detect_keyword() and extract_structure() in parallel threads.
        2. Best keyword confidence >= 0.90 -> path A: keyword category used
           directly.
        3. Best keyword confidence 0.50-0.89 -> path B: rule engine,
           confidence scoring and contradiction handling.
        4. Best keyword confidence < 0.50 -> path C: LLM-assisted
           classification validated by the rule engine. When ``llm`` is None,
           path B runs instead and ``method`` becomes "rule_engine_fallback".
        5. Matching / key-break categories get a "subtype" via
           _resolve_matching_subtype.

    A failure in detect_keyword or extract_structure is logged as a warning
    and the pipeline continues with no keyword matches / an empty structure.

    Args:
        cobol_source: COBOL program source text.
        llm: Optional LLM client, only used on path C.

    Returns:
        dict: {
            "category": str,         # program category name
            "confidence": float,     # combined confidence (0.0 ~ 1.0)
            "needs_review": bool,    # whether human review is required
            "method": str,           # "keyword" | "rule_engine" | "llm" |
                                     # "rule_engine_fallback" ("none" for empty input)
            "source": str,           # "l1" | "pipeline" ("error" for empty input)
            "judgment": str,         # auto / review / manual / impossible
            "matches": list,         # keyword match details
            "contradictions": list,  # detected contradictions
            "v2_confidence": dict,   # four-factor confidence details
            "structure": dict,       # structure summary (debugging)
            "subtype": str,          # only for matching / key-break categories
        }
        Path B additionally returns "contradiction_resolution" and
        "resolved_types"; path C returns "llm_raw".

    Empty, whitespace-only or None ``cobol_source`` does not raise: it returns
    a result with category "unknown", confidence 0.0, needs_review True and
    judgment "impossible".
    """
    if not cobol_source or not cobol_source.strip():
        return {
            "category": "unknown",
            "confidence": 0.0,
            "needs_review": True,
            "method": "none",
            "source": "error",
            "judgment": "impossible",
            "matches": [],
            "contradictions": [],
            "v2_confidence": {},
            "structure": {},
        }

    # ── Step 1: keyword detection and structure extraction in parallel ──
    keyword_matches: list = []
    structure: dict = {}

    with ThreadPoolExecutor(max_workers=2) as executor:
        future_keyword = executor.submit(detect_keyword, cobol_source)
        future_structure = executor.submit(extract_structure, cobol_source)

        for future in as_completed([future_keyword, future_structure]):
            if future == future_keyword:
                try:
                    # `or []` keeps len()/max() below safe if None comes back.
                    keyword_matches = future.result() or []
                except Exception as e:
                    logger.warning("[pipeline] detect_keyword 失败: %s", e)
            elif future == future_structure:
                try:
                    # `or {}` keeps every structure.get(...) below safe.
                    structure = future.result() or {}
                except Exception as e:
                    logger.warning("[pipeline] extract_structure 失败: %s", e)

    # ── Step 2: analyse the keyword result and choose a path ──
    keyword_info = _get_best_keyword_match(keyword_matches)
    max_keyword_confidence = keyword_info["confidence"] if keyword_info else 0.0

    logger.info(
        "[pipeline] keyword matches=%d, max_confidence=%.2f, paragraphs=%d, files=%d",
        len(keyword_matches),
        max_keyword_confidence,
        structure.get("total_paragraphs", 0),
        structure.get("file_count", 0),
    )

    # ── Step 3: dispatch on keyword confidence ──
    if max_keyword_confidence >= 0.90:
        # Path A: keyword >= 90% -> direct output.
        logger.info("[pipeline] 路径 A: keyword 高确信度 (%.2f)", max_keyword_confidence)
        result = _path_keyword_direct(keyword_info, structure)
    elif max_keyword_confidence >= 0.50:
        # Path B: keyword 50-89% -> rule engine.
        logger.info("[pipeline] 路径 B: keyword 中确信度 (%.2f) -> 规则引擎", max_keyword_confidence)
        result = _path_rule_engine(keyword_info, structure)
    elif llm is not None:
        # Path C: keyword < 50% -> LLM-assisted.
        logger.info("[pipeline] 路径 C: keyword 低确信度 (%.2f) -> LLM 辅助", max_keyword_confidence)
        result = _path_llm_assisted(keyword_info, structure, llm)
    else:
        # No LLM available: fall back to the rule engine.
        logger.info("[pipeline] 路径 C(fallback): keyword 低确信度 (%.2f) -> 规则引擎兜底", max_keyword_confidence)
        result = _path_rule_engine(keyword_info, structure)
        result["method"] = "rule_engine_fallback"

    # ── Step 4: subtype refinement (matching / key-break programs only) ──
    result = _resolve_matching_subtype(result, cobol_source, structure)

    return result
|