Files
cobol-java-v3/parametrized/matching.py
T
hangshuo652 bc1d56d1a4 feat: Phase 2 complete — 13 Phases of COBOL type classification and test benchmark
P0.6: gcov infrastructure
P1: extract_structure output expansion (11 new feature fields)
P2: Confusion group rule engine (8 pairs + contradiction + backtrack)
P3: 4-factor confidence calculation + quality gate update
P4: 33+2 COBOL program type test samples (22 files, 7 categories)
P5: parametrized/ test data generation engine
P6: japanese_data.py lookup tables
P7-10: Type-specific test suites (~159 parametrized tests)
P11: Full classification pipeline (classify_program) + orchestrator integration
P12: Documentation (module-interfaces, test-plan v3.0, coverage-matrix)

Architecture decisions:
- classification_pipeline/ merged to hina/pipeline/
- parametrized/ as independent module
- japanese_data.py as root-level file
- hina/__all__ only exports classify_program()

Co-Authored-By: Claude <noreply@anthropic.com>
2026-06-19 23:51:55 +08:00

195 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""匹配系测试数据生成模块。
提供两种生成器:
- generate_matching_data() — 生成主/从匹配测试数据
- generate_keybreak_data() — 生成 KEY 切中断测试数据
"""
from __future__ import annotations
import random
from typing import Any
def _scaled_count(total: int, ratio: float) -> int:
    """Return ``floor(total * ratio)`` while tolerating binary float error.

    A plain ``int(total * ratio)`` truncates products that land a hair below
    the intended integer (``100 * 0.29`` evaluates to 28.999999999999996).
    A tiny epsilon is added before truncating; the result is clamped to
    ``total`` so it can never exceed the number of available records.
    """
    return min(total, int(total * ratio + 1e-9))


def _matched_sub(main_index: int, sub_index: int) -> dict[str, Any]:
    """Build a sub record whose KEY equals the key of main record ``main_index``."""
    return {
        "KEY": f"MAIN-{main_index:04d}",
        "DATA": f"sub_data_{sub_index}",
        "SEQ": sub_index + 1,
    }


def _unmatched_sub(ordinal: int, seq: int) -> dict[str, Any]:
    """Build a sub record whose KEY matches no main record."""
    return {
        "KEY": f"UNMATCHED-SUB-{ordinal:04d}",
        "DATA": f"sub_unmatched_{ordinal}",
        "SEQ": seq,
    }


def generate_matching_data(
    matching_type: str = "1:1",
    record_count_r01: int = 10,
    record_count_r02: int = 10,
    key_match_ratio: float = 1.0,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Generate main/sub test records for matching-type programs.

    Main records always carry the unique keys ``MAIN-0000``, ``MAIN-0001``, ...
    Sub records either reuse one of those keys (matched) or carry an
    ``UNMATCHED-SUB-nnnn`` key (unmatched).

    Parameters
    ----------
    matching_type : str
        Matching mode, one of:
        - "1:1"  the first ``floor(min(r01, r02) * ratio)`` sub records each
                 reuse the key of the main record at the same position; the
                 remaining sub records are unmatched.
        - "1:N"  the first ``floor(r01 * ratio)`` main records each receive
                 ``max(1, r02 // match_count)`` consecutive sub records; every
                 remaining main record is paired with one unmatched sub
                 record while sub records are left. Sub records still left
                 over after that are emitted as unmatched.
        - "N:1"  the first ``floor(r02 * ratio)`` sub records cycle through
                 the main keys (``i % r01``); the rest are unmatched.
    record_count_r01 : int
        Number of main-file (R01) records.
    record_count_r02 : int
        Number of sub-file (R02) records.
    key_match_ratio : float
        Fraction of keys that match, between 0.0 and 1.0 inclusive.

    Returns
    -------
    tuple[list[dict], list[dict]]
        ``(main_records, sub_records)``; every record has the fields
        ``KEY``, ``DATA`` and ``SEQ``.

    Raises
    ------
    ValueError
        If ``matching_type`` is not supported, ``key_match_ratio`` is outside
        0.0..1.0, or either record count is negative.
    """
    if matching_type not in ("1:1", "1:N", "N:1"):
        raise ValueError(f"不支持的 matching_type{matching_type!r},应为 '1:1' / '1:N' / 'N:1'")
    if not 0.0 <= key_match_ratio <= 1.0:
        raise ValueError(f"key_match_ratio 必须在 0.0~1.0 之间,收到 {key_match_ratio}")
    if record_count_r01 < 0 or record_count_r02 < 0:
        raise ValueError("记录数不能为负数")

    main_records: list[dict[str, Any]] = [
        {"KEY": f"MAIN-{i:04d}", "DATA": f"main_data_{i}", "SEQ": i + 1}
        for i in range(record_count_r01)
    ]
    sub_records: list[dict[str, Any]] = []
    unmatched = 0  # running ordinal used in UNMATCHED-SUB keys

    if matching_type == "1:1":
        # At most min(r01, r02) pairs can exist, so the ratio scales that bound.
        match_count = _scaled_count(
            min(record_count_r01, record_count_r02), key_match_ratio
        )
        for i in range(record_count_r02):
            if i < match_count:
                sub_records.append(_matched_sub(i, i))
            else:
                # SEQ of unmatched records is offset by the main record count.
                sub_records.append(
                    _unmatched_sub(unmatched, record_count_r01 + unmatched + 1)
                )
                unmatched += 1

    elif matching_type == "1:N":
        match_count = _scaled_count(record_count_r01, key_match_ratio)
        # Loop-invariant: how many sub records each matched main record gets.
        n_per_main = max(1, record_count_r02 // max(1, match_count))
        idx = 0  # number of sub records emitted so far
        for i in range(record_count_r01):
            if idx >= record_count_r02:
                break  # sub file is full; later main records get nothing
            if i < match_count:
                for _ in range(min(n_per_main, record_count_r02 - idx)):
                    sub_records.append(_matched_sub(i, idx))
                    idx += 1
            else:
                sub_records.append(_unmatched_sub(unmatched, idx + 1))
                idx += 1
                unmatched += 1
        # Pad the sub file up to the requested size with unmatched records.
        while idx < record_count_r02:
            sub_records.append(_unmatched_sub(unmatched, idx + 1))
            idx += 1
            unmatched += 1

    else:  # "N:1"
        # Without main records there is no key to match against, so every
        # sub record must be unmatched regardless of the requested ratio.
        match_count = (
            _scaled_count(record_count_r02, key_match_ratio)
            if record_count_r01
            else 0
        )
        for i in range(record_count_r02):
            if i < match_count:
                sub_records.append(_matched_sub(i % record_count_r01, i))
            else:
                sub_records.append(_unmatched_sub(unmatched, i + 1))
                unmatched += 1

    return main_records, sub_records
def generate_keybreak_data(
    group_count: int = 3,
    records_per_group: int = 2,
    sum_type: str = "accumulate",
) -> list[dict[str, Any]]:
    """Generate key-break test records whose KEY changes between groups.

    All records inside one group share the same KEY; each new group gets a
    different KEY, which is what triggers the break in the program under
    test. The first 26 groups are labelled ``A``..``Z``; later groups are
    labelled with their zero-based index (``26``, ``27``, ...).

    Parameters
    ----------
    group_count : int
        Number of groups, default 3.
    records_per_group : int
        Number of records in each group, default 2.
    sum_type : str
        How the FIELD value is filled:
        - "accumulate"  ``(group + 1) * 100 + position`` — distinct per record
        - "aggregate"   ``(group + 1) * 100`` — identical within a group
        - "mark"        ``"MARK-<label>"`` — a fixed marker per group

    Returns
    -------
    list[dict]
        Records with the fields ``KEY``, ``FIELD``, ``GROUP`` (1-based group
        number) and ``SEQ`` (1-based running record number).

    Raises
    ------
    ValueError
        If ``group_count`` or ``records_per_group`` is below 1, or
        ``sum_type`` is not one of the supported values.
    """
    if group_count < 1:
        raise ValueError(f"group_count 必须 >= 1,收到 {group_count}")
    if records_per_group < 1:
        raise ValueError(f"records_per_group 必须 >= 1,收到 {records_per_group}")
    if sum_type not in ("accumulate", "aggregate", "mark"):
        raise ValueError(f"不支持的 sum_type{sum_type!r}")

    records: list[dict[str, Any]] = []
    seq = 0
    for g in range(group_count):
        # One label shared by KEY and the "mark" FIELD, so the two stay in
        # step once the alphabet is exhausted (KEY-26 pairs with MARK-26
        # rather than with the punctuation that chr(65 + 26) would give).
        label = chr(65 + g) if g < 26 else str(g)
        group_key = f"KEY-{label}"  # KEY-A, KEY-B, ..., KEY-Z, KEY-26, ...
        for r in range(records_per_group):
            seq += 1
            if sum_type == "accumulate":
                field_val: Any = (g + 1) * 100 + r + 1
            elif sum_type == "aggregate":
                field_val = (g + 1) * 100
            else:  # mark
                field_val = f"MARK-{label}"
            records.append({
                "KEY": group_key,
                "FIELD": field_val,
                "GROUP": g + 1,
                "SEQ": seq,
            })
    return records