Files
cobol-java-v3/cobol_testgen/__init__.py
T
NB-076 4a140ff9e5 fix: 真实分支覆盖率99.9% — 条件解析器全面强化
## 修复内容

### parse_single_condition 5项强化 (cond.py)
- 下划线字段名:  加入  字符类
- FUNCTION MOD:  合成字段处理
- 算术表达式优先: 交换标准/算术regex顺序
- 下标剥离:  →
- 空值处理:  →

### 约束通过性 4项修复 (__init__.py)
- 算术表达式直接通过:  不过滤
- 下标基名匹配:  匹配
- 子字段识别:  解析后通过
- _FILE_STATUS 合成字段通过

### EXEC SQL与copybook (__init__.py, read.py)
- generate_data 新增 copybook_dirs 参数
- resolve_sql_includes 集成到数据生成流程
- SQLCA字段在resolve后注入

### _resolve_field 强化 (__init__.py)
- 原逻辑只识别显式  下标
- 新增: OF剥离后检查、基名+后缀匹配
- 保持算术表达式不变

## 最终真实结果
- 43/43程序识别: 3,178 分支
- S15回归: 17/17 PASS
- 100%程序: 41/43
- 剩余2个未覆盖: 变量下标引用 (体系限制)
- 所有覆盖率数字可复现、无假数据

Co-Authored-By: Claude <noreply@anthropic.com>
2026-06-24 23:08:24 +08:00

1124 lines
43 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""COBOL Test Data Generator — 模块化版入口
公开 API:
extract_structure() — 解析 COBOL 控制流 → dict
generate_data() — 生成测试数据 → list[dict]
incremental_supplement — 差分补充数据 → list[dict]
"""
import os
import sys
import json
import re
import logging
from datetime import datetime
from pathlib import Path
# -- Configuration (must be defined before the local-module imports below to
# avoid circular imports) --
# 'abend_programs': program names passed to design.extend_abend_programs()
# by main() before paths are enumerated.
CONFIG = {
    'abend_programs': ['SUB03END'],
}
from .read import preprocess, extract_data_division, extract_procedure_division
from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements
from .read import parse_file_control, resolve_sql_includes, strip_exec_sql_from_data_div
from .core import build_branch_tree, classify_field_roles, _init_child_names, sql_register_virtual_fields, _find_multi_write_fds
from .cond import parse_single_condition, is_field, collect_leaves
from .pipeline_bridge import build_branch_tree_fallback
from .design_mcdc import enum_paths as mcdc_enum_paths, _filter_stop
from .design import enum_paths, generate_records, get_term_type, extend_abend_programs
from .output import output_json, output_input_files
from .coverage import run_coverage, generate_coverage_index, collect_decision_points, mark_coverage
from japanese_data import generate_fullwidth_text, generate_halfwidth_katakana, generate_wareki_date
# Optional components. Each import is attempted independently and a flag
# records whether it succeeded; main() checks the flag before using the feature.
try:
    from .runner import run_and_compare, run_all, GroupInfo, GroupResult
    _HAVE_RUNNER = True
except ImportError:
    _HAVE_RUNNER = False
try:
    from .gcov import run_gcov
    _HAVE_GCOV = True
except ImportError:
    _HAVE_GCOV = False
try:
    from .to_sql import collect_sql_meta, build_db_input
    _HAVE_TOSQL = True
except ImportError:
    _HAVE_TOSQL = False
logger = logging.getLogger(__name__)
# Public API, CONFIG, and the helpers re-exported from ``japanese_data``.
__all__ = [
    "extract_structure",
    "generate_data",
    "incremental_supplement",
    "CONFIG",
    "generate_fullwidth_text",
    "generate_halfwidth_katakana",
    "generate_wareki_date",
]
# ── OCCURS 展开 ──
def _add_subscript(name, occ):
"""追加或扩展下标:WS-CELL → WS-CELL(1), WS-CELL(1) → WS-CELL(1,2)"""
if name.endswith(')'):
return name[:-1] + f',{occ})'
return name + f'({occ})'
def expand_occurs(fields):
"""展开 OCCURS 字段为下标副本。递归处理嵌套 OCCURS。"""
result = []
i = 0
while i < len(fields):
f = fields[i]
if f.get('occurs', 0) > 0 and not f.get('is_88'):
children = []
j = i + 1
while j < len(fields):
child = fields[j]
if child.get('is_88'):
children.append(child)
j += 1
continue
if child['level'] <= f['level'] or child.get('level') == 77:
break
children.append(child)
j += 1
if children:
group = dict(f)
group['occurs'] = 0
result.append(group)
for occ in range(1, f['occurs'] + 1):
for child in children:
copy = dict(child)
if child.get('occurs', 0) == 0:
copy['occurs'] = 0
copy['occurs_depending'] = f.get('occurs_depending')
if child.get('is_88'):
parent = child.get('parent') or f['name']
copy['parent'] = _add_subscript(parent, occ)
copy['name'] = _add_subscript(child['name'], occ)
else:
copy['name'] = _add_subscript(child['name'], occ)
result.append(copy)
else:
for occ in range(1, f['occurs'] + 1):
copy = dict(f)
copy['name'] = _add_subscript(f['name'], occ)
copy['occurs'] = 0
copy['occurs_depending'] = f.get('occurs_depending')
result.append(copy)
i = j
else:
result.append(f)
i += 1
if any(f.get('occurs', 0) > 0 for f in result):
return expand_occurs(result)
return result
# ── PREV 连锁 ──
def _constraint_in(cons, field, op, value, want):
for c in cons:
if len(c) == 4 and c[0] == field and c[1] == op and c[2] == value and c[3] == want:
return True
return False
def _inc_str(s, length):
try:
return str(int(s) + 1).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 9Zz\xff':
c[i] = chr(ord(c[i]) + 1)
break
if c[i] == ' ':
c[i] = '0'
break
if c[i] == '9':
c[i] = '0'
return ''.join(c)
def _dec_str(s, length):
try:
n = max(0, int(s) - 1)
return str(n).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 0Aa\x00':
c[i] = chr(ord(c[i]) - 1)
break
if c[i] == ' ':
break
if c[i] == '0':
c[i] = '9'
return ''.join(c)
def _field_length(fname, fields):
for f in fields:
if f['name'] == fname:
pi = f.get('pic_info', {})
return pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0) or 1
return 1
def _chain_prev(records, path_infos, fields, fd_fields, field_to_fd, open_dir):
    """Chain PREV values across records so batch execution follows the intended paths.

    For k >= 1, the constraints of path k-1 (``PREV op CURRENT``) correspond to
    the comparison made in batch loop iteration k-1:

        PREV    = records[prev_src].R01*  (previous value held by the program)
        CURRENT = records[k].R01*          (value just read)

    records[k] is adjusted in place so the cross-record comparisons satisfy the
    path constraints:

    * ``WRK-PREV-<x>`` is set from ``R01<x>`` of records[prev_src];
    * key fields (R01 names containing EMP-ID / APPL-DATE) are copied for
      same-key paths, or bumped with ``_inc_str`` when they would collide;
    * START-TIME is moved after (normal) or before (overlap) the previous END-TIME;
    * ``_w02_path`` / ``_overlap_path`` flags (always bool) are recorded;
    * ``W01<x>`` mirrors are refreshed from ``R01<x>``.

    An overlapping record does not become the new PREV source.

    Args:
        records: generated record dicts; mutated in place.
        path_infos: sequence of (constraints, assignments, term); only the
            constraints are read.
        fields: field dicts, used to look up field lengths.
        fd_fields, field_to_fd, open_dir: currently unused; kept for call
            compatibility.
    """

    def _is_r01(name):
        # R01* input fields, excluding the R01INNREC* group.
        return name.startswith('R01') and not name.startswith('R01INNREC')

    n_records = len(records)
    if n_records < 2:
        return

    key_fields = []
    time_start_field = None
    time_end_field = None
    for fname in records[0]:
        if _is_r01(fname) and 'WRK-PREV-' + fname[3:] in records[0]:
            if 'EMP-ID' in fname or 'APPL-DATE' in fname:
                key_fields.append(fname)
            if 'END-TIME' in fname:
                time_end_field = fname
            if 'START-TIME' in fname:
                time_start_field = fname
    has_time_pair = bool(time_end_field and time_start_field)
    prev_end_name = f'WRK-PREV-{time_end_field[3:]}' if time_end_field else None

    prev_src = 0
    for k in range(1, n_records):
        if k - 1 >= len(path_infos):
            break
        cons = path_infos[k - 1][0]
        cur = records[k]
        prev = records[prev_src]

        is_same_key = bool(key_fields) and all(
            _constraint_in(cons, f'WRK-PREV-{fn[3:]}', '=', fn, True)
            for fn in key_fields
        )
        # bool(...) keeps the flags stored below strictly boolean; they used to
        # become None when no START/END-TIME pair was found.
        is_overlap = bool(
            has_time_pair and is_same_key
            and _constraint_in(cons, prev_end_name, '>', time_start_field, True)
        )
        is_normal = bool(
            has_time_pair and is_same_key
            and (_constraint_in(cons, prev_end_name, '<=', time_start_field, True)
                 or _constraint_in(cons, prev_end_name, '>', time_start_field, False))
        )

        # The program's PREV copy is the previous source record's R01 value.
        for fname, value in prev.items():
            if _is_r01(fname):
                prev_name = 'WRK-PREV-' + fname[3:]
                if prev_name in cur:
                    cur[prev_name] = value

        if is_same_key:
            for kf in key_fields:
                if kf in cur and kf in prev:
                    cur[kf] = prev[kf]
            if is_normal:
                # Need PREV-END <= START: push START just past PREV-END.
                prev_end = prev.get(time_end_field, '')
                curr_start = cur.get(time_start_field, '')
                if prev_end >= curr_start:
                    length = _field_length(time_start_field, fields)
                    cur[time_start_field] = _inc_str(prev_end, length)
            if is_overlap:
                # Need PREV-END > START: pull START just before PREV-END.
                prev_end = prev.get(time_end_field, '')
                curr_start = cur.get(time_start_field, '')
                if prev_end <= curr_start:
                    length = _field_length(time_start_field, fields)
                    cur[time_start_field] = _dec_str(prev_end, length) if prev_end else '0' * length
        else:
            # Different-key path: make sure the keys actually differ.
            for kf in key_fields:
                if kf in cur and kf in prev and cur[kf] == prev[kf]:
                    length = _field_length(kf, fields)
                    cur[kf] = _inc_str(str(cur[kf]), length)

        cur['_w02_path'] = is_same_key and has_time_pair and not is_overlap
        cur['_overlap_path'] = is_overlap
        for fname in list(cur):
            if _is_r01(fname):
                mirror = 'W01' + fname[3:]
                if mirror in cur:
                    cur[mirror] = cur[fname]
        # An overlapping record is not kept as the next PREV source.
        if not is_overlap:
            prev_src = k
# -- Entry point --
def main():
    """Command-line entry point.

    Usage: ``python -m cobol_testgen <cobol file 1> [cobol file 2 ...] [output dir]``

    Options:
        --run              after generating data, call ``run_and_compare``
                           (needs the optional ``.runner`` module).
        --gcov             run grouped execution (``run_all``) and ``run_gcov`` and
                           feed the result into the coverage report (needs the
                           optional ``.runner`` and ``.gcov`` modules).
        --temp-dir DIR / --temp-dir=DIR
                           working directory for the gcov run
                           (default ``<outdir>/.gcov_cache``); only the first
                           occurrence is used.

    An argument that is an existing directory becomes the output directory
    (default: the first COBOL file's directory). Files ending in .CBL/.COB/.CPY
    are processed; anything else is reported and skipped.

    For each program it writes ``json/<stem>.json``, input files under
    ``input/`` and coverage pages under ``coverage/``. A log file is written
    under ``logs/``.
    """
    if len(sys.argv) < 2:
        print("用法: python -m cobol_testgen <cobol文件1> [cobol文件2 ...] [输出目录]")
        sys.exit(1)
    args = sys.argv[1:]
    do_run = False
    gcov_mode = False
    temp_dir = None
    if '--run' in args:
        do_run = True
        args.remove('--run')
    if '--gcov' in args:
        gcov_mode = True
        args.remove('--gcov')
    i = 0
    while i < len(args):
        if args[i] == '--temp-dir':
            if i + 1 < len(args):
                temp_dir = args[i + 1]
                args.pop(i + 1)
            args.pop(i)
            break
        if args[i].startswith('--temp-dir='):
            temp_dir = args[i].split('=', 1)[1]
            args.pop(i)
            break
        i += 1
    cobol_files = []
    outdir = None
    for a in args:
        p = Path(a)
        if p.is_dir():
            outdir = p
        elif p.suffix.upper() in ('.CBL', '.COB', '.CPY'):
            cobol_files.append(p)
        else:
            print(f"警告:跳过未知参数 {a}")
    if not cobol_files:
        print("错误:未找到任何 COBOL 文件")
        sys.exit(1)
    if outdir is None:
        outdir = cobol_files[0].parent
    outdir.mkdir(parents=True, exist_ok=True)
    (outdir / 'logs').mkdir(parents=True, exist_ok=True)
    log_path = outdir / 'logs' / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log"
    fh = logging.FileHandler(log_path, encoding="utf-8", mode="w")
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    ))
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    sh.setFormatter(logging.Formatter("%(message)s"))
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    root_logger.addHandler(fh)
    root_logger.addHandler(sh)
    # Requested features whose optional module failed to import used to be skipped silently.
    if gcov_mode and not (_HAVE_GCOV and _HAVE_RUNNER):
        logger.warning("警告:--gcov 需要 .gcov 与 .runner 模块,已跳过")
    if do_run and not _HAVE_RUNNER:
        logger.warning("警告:--run 需要 .runner 模块,已跳过")
    programs = []
    for filepath in cobol_files:
        if not filepath.exists():
            logger.error(f"错误:文件不存在 {filepath}")
            continue
        source = filepath.read_text(encoding='utf-8')
        source = resolve_copybooks(
            source,
            str(filepath.parent),
            extra_search_paths=[str(filepath.parent / '..' / 'cpy')],
        )
        source = resolve_sql_includes(source, str(filepath.parent))
        preprocessed = preprocess(source)
        file_sec = parse_file_section(preprocessed)
        data_div = extract_data_division(preprocessed)
        if data_div:
            data_div, declared_columns = strip_exec_sql_from_data_div(data_div)
        else:
            declared_columns = {}
        if not data_div:
            logger.error(f"错误:{filepath.name} 中没有 DATA DIVISION。")
            continue
        data_fields = parse_data_division(data_div)
        if not data_fields:
            logger.error(f"错误:{filepath.name} 中没有找到含 PIC 的字段。")
            continue
        # Parsed fields -> dicts. The 2nd and later FILLERs get unique names;
        # 88-levels take a copy of their parent's pic_info.
        fields_dict = []
        parent_pic = {}
        filler_counter = 0
        for f in data_fields:
            pi = f.pic_info
            name = f.name
            if name == 'FILLER':
                filler_counter += 1
                if filler_counter > 1:
                    name = f'FILLER_{filler_counter}'
            entry = {
                'name': name,
                'level': f.level,
                'pic': f.pic,
                'pic_info': {
                    'type': pi.type if pi else 'unknown',
                    'digits': pi.digits if pi else 0,
                    'decimal': pi.decimal if pi else 0,
                    'length': pi.length if pi else 0,
                    'signed': pi.signed if pi else False,
                },
                'value': f.value,
                'values': f.values,
                'section': f.section,
                'is_filler': f.is_filler,
                'redefines': f.redefines,
                'usage': f.usage,
                'occurs': f.occurs_count,
                'occurs_depending': f.occurs_depending,
            }
            if f.is_88:
                entry['is_88'] = True
                entry['parent'] = f.parent
                if f.parent and f.parent in parent_pic:
                    entry['pic_info'] = dict(parent_pic[f.parent])
            else:
                parent_pic[name] = entry['pic_info']
            fields_dict.append(entry)
        fields_dict = expand_occurs(fields_dict)
        sql_register_virtual_fields(fields_dict)
        # FD name -> record names plus their children (deduplicated, in order),
        # and the reverse map field -> FD.
        fd_fields = {}
        field_to_fd = {}
        if file_sec:
            for fd_name, rec_names in file_sec.items():
                fds = []
                seen = set()
                for rec in rec_names:
                    if rec not in seen:
                        fds.append(rec)
                        seen.add(rec)
                    for child in _init_child_names(rec, fields_dict):
                        if child not in seen:
                            fds.append(child)
                            seen.add(child)
                fd_fields[fd_name] = fds
                for child in fds:
                    field_to_fd[child] = fd_name
        logger.info(f"\n========== {filepath.name} ==========")
        logger.info("\n字段列表:")
        logger.info(f"{'层级':<6} {'名称':<25} {'PIC':<15} {'类型':<12} {'长度':<5}")
        logger.info("-" * 65)
        for f in fields_dict:
            pi = f['pic_info']
            t = pi.get('type', '?')
            length = pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0)
            pic_display = str(f.get('pic', '')) if f.get('pic') else ('88-level' if f.get('is_88') else '')
            logger.info(f"{f['level']:<6} {f['name']:<25} {pic_display:<15} {t:<12} {length:<5}")
        proc_div = extract_procedure_division(preprocessed)
        # Bound even without a PROCEDURE DIVISION; it is read further below.
        branch_tree = None
        assignments = {}
        if proc_div:
            branch_tree, assignments = build_branch_tree(proc_div, fields_dict, full_source=preprocessed)
            roles = classify_field_roles(branch_tree, assignments, fields_dict,
                                         source=preprocessed, proc_text=proc_div)
            logger.info("\n字段角色(输入/输出/出入/未用):")
            for f in fields_dict:
                if f.get('is_88'):
                    continue
                logger.info(f" {f['name']:<30} {roles.get(f['name'], '?')}")
            abend_list = CONFIG.get('abend_programs', [])
            if abend_list:
                extend_abend_programs(abend_list)
            branch_paths_with_assigns = enum_paths(branch_tree, fields_dict)
            path_infos = []
            for c, a in branch_paths_with_assigns:
                filtered_c, term = get_term_type(c)
                path_infos.append((filtered_c, a, term))

            def _is_skip(cons):
                # True when the path's WRK-R01EOF constraints are all "= '1'" (true)
                # and there is at least one of them.
                eq1_true = 0
                other = 0
                for c in cons:
                    if len(c) == 4 and c[0] == 'WRK-R01EOF':
                        val = str(c[2]).strip("'\"")
                        if val == '1' and c[1] == '=' and c[3]:
                            eq1_true += 1
                        else:
                            other += 1
                return eq1_true > 0 and other == 0

            before = len(path_infos)
            path_infos = [p for p in path_infos if not _is_skip(p[0])]
            after = len(path_infos)
            logger.info(f" SKIP 过滤: {before} -> {after} 条路径(预期减少 1")
        open_dir = scan_open_statements(proc_div) if proc_div else {}
        if proc_div:
            logger.info(f"\n分支路径数:{len(branch_paths_with_assigns)}")
            for i, (path_cons, _path_assign) in enumerate(branch_paths_with_assigns):
                descs = []
                for c in path_cons:
                    if len(c) == 4:
                        field, op, val, want = c
                        if op == 'not_in':
                            descs.append(f"{field} not in {val}")
                        else:
                            descs.append(f"{field} {op} {val} ({'T' if want else 'F'})")
                logger.debug(f" 路径 {i + 1}: {', '.join(descs)}")
        else:
            logger.warning("\n没有找到 PROCEDURE DIVISION。")
            branch_paths_with_assigns = [([], {})]
            path_infos = [([], {}, 'normal')]
            roles = {f['name']: 'unused' for f in fields_dict}
        records, _, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)

        def _is_eof_path(cons):
            # True when the path's last WRK-R01EOF constraint is "= '1'" (true).
            last_eq1_true = -1
            for idx, c in enumerate(cons):
                if len(c) == 4 and c[0] == 'WRK-R01EOF':
                    val = str(c[2]).strip("'\"")
                    if val == '1' and c[1] == '=' and c[3]:
                        last_eq1_true = idx
            if last_eq1_true < 0:
                return False
            for idx in range(last_eq1_true + 1, len(cons)):
                if len(cons[idx]) == 4 and cons[idx][0] == 'WRK-R01EOF':
                    return False
            return True

        eof_mask = [_is_eof_path(c) for c, a, t in path_infos]
        eof_count = sum(eof_mask)
        if eof_count:
            term_types = ['eof' if e else t for e, t in zip(eof_mask, term_types)]
            logger.info(f" EOF 路径: {eof_count} 条(将单独执行)")
        multi_write_fds = _find_multi_write_fds(branch_tree, field_to_fd) if proc_div and branch_tree else set()
        if multi_write_fds:
            logger.info(f" 检测到多 WRITE FD: {', '.join(sorted(multi_write_fds))}")
        _chain_prev(records, path_infos, fields_dict, fd_fields, field_to_fd, open_dir)
        if _HAVE_TOSQL:
            sql_meta = collect_sql_meta(assignments, declared_columns)
            db_input = build_db_input(
                branch_paths_with_assigns, fields_dict, assignments, sql_meta, declared_columns,
                records=records,
            )
        else:
            db_input = None
        (outdir / 'json').mkdir(parents=True, exist_ok=True)
        outpath = outdir / 'json' / (filepath.stem + '.json')
        output_json(records, outpath, roles,
                    fd_fields=fd_fields, field_to_fd=field_to_fd,
                    open_dir=open_dir,
                    term_types=term_types,
                    db_input=db_input if db_input else None,
                    data_fields=fields_dict)
        output_input_files(records, outdir / 'input', filepath.stem, roles,
                           fd_fields, field_to_fd, open_dir,
                           term_types=term_types)
        gcov_data = None
        # run_all comes from the optional .runner module, so it must be checked too.
        if gcov_mode and proc_div and _HAVE_GCOV and _HAVE_RUNNER:
            select_info = parse_file_control(preprocessed)
            _temp = temp_dir or str(outdir / '.gcov_cache')
            source_dir = str(filepath.parent)
            # One independent dict per record ([{}] * n would share a single dict).
            expected_records: list[dict] = [{} for _ in records]
            if file_sec and os.path.exists(outpath):
                with open(outpath, encoding='utf-8') as fp:
                    full_json = json.load(fp)
                json_records = full_json.get('records', [])
                for idx in range(len(records)):
                    exp = {}
                    if idx < len(json_records):
                        eo = json_records[idx].get('expected_output', {})
                        for fd_name in file_sec:
                            if fd_name in eo:
                                exp.update(eo[fd_name])
                    expected_records[idx] = exp
            group_results = run_all(
                filepath.stem, str(outdir), _temp,
                fields_dict, fd_fields, select_info, open_dir,
                term_types, records, expected_records=expected_records,
                source_dir=source_dir, path_infos=path_infos,
                multi_write_fds=multi_write_fds,
            )
            gcov_data = run_gcov(filepath.stem, _temp)
            passed = sum(1 for r in group_results if r.passed)
            total = len(group_results)
            logger.info(f"\n 执行验证: {passed}/{total} 组通过")
            if passed < total:
                for r in group_results:
                    if not r.passed and r.details:
                        fails = [d for d in r.details if not d.match][:3]
                        for d in fails:
                            logger.warning(f" [{r.name}] {d.field}: "
                                           f"期望={d.expected!r}, 实际={d.actual!r}")
        if do_run and proc_div and _HAVE_RUNNER:
            select_info = parse_file_control(preprocessed)
            run_and_compare(
                filepath.stem, str(outdir), fields_dict,
                fd_fields, select_info, open_dir,
                term_types, records,
            )
        logger.info(f"\n输出:{outpath}{len(records)} 条记录)")
        logger.debug("\n记录明细:")
        for i, rec in enumerate(records, 1):
            vals = []
            for f in fields_dict:
                r = roles.get(f['name'], '?')
                marker = f"[{r[0].upper()}]" if r != '?' and r != 'unused' else ''
                vals.append(f"{marker}{f['name']}={rec.get(f['name'], '?')}")
            logger.debug(f" 记录 {i}: {' | '.join(vals)}")
        if not proc_div:
            # No PROCEDURE DIVISION -> no branch tree. This used to raise NameError.
            logger.warning(f"警告:{filepath.name} 没有分支树,跳过覆盖率报告")
            continue
        (outdir / 'coverage').mkdir(parents=True, exist_ok=True)
        cov_prefix = str(outdir / 'coverage' / filepath.stem)
        index_relpath = 'index.html'
        cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict,
                                  source, cov_prefix, index_relpath=index_relpath,
                                  gcov_data=gcov_data)
        programs.append(cov_result)
    if programs:
        generate_coverage_index(programs, outdir / 'coverage')
        logger.info(f"\n覆盖率总览:{outdir / 'coverage' / 'index.html'}")
# ════════════════════════════════════════════
# Phase 1: programmatic API (called by orchestrator.py)
# ════════════════════════════════════════════
def extract_structure(cobol_source: str) -> dict:
    """Statically analyse COBOL source and return a structural summary (no test data).

    Args:
        cobol_source: COBOL program text. Only ``preprocess`` is applied here;
            this function does not call ``resolve_copybooks``.

    Returns:
        dict with keys: paragraphs, decision_points, branch_tree, file_count,
        open_directions, has_search_all, has_evaluate, has_call, has_break,
        total_branches, total_paragraphs, branch_tree_obj, select_files,
        open_directions_detail, has_divide, divide_constants, has_inspect,
        has_string, perform_patterns, main_loop, if_types,
        variable_patterns, open_pattern.

        Decision points are numbered in pre-order: IF, EVALUATE and
        conditional PERFORM (until / varying) nodes each take the next ID.
    """
    preprocessed = preprocess(cobol_source)
    data_div = extract_data_division(preprocessed)
    data_fields = parse_data_division(data_div) if data_div else []
    fields_dict = []
    for idx, f in enumerate(data_fields):
        pi = f.pic_info
        entry = {
            'name': f.name if f.name != 'FILLER' else f'FILLER_{idx + 1}',
            'level': f.level, 'pic': f.pic,
            'pic_info': {
                'type': pi.type if pi else 'unknown',
                'digits': pi.digits if pi else 0,
                'decimal': pi.decimal if pi else 0,
                'length': pi.length if pi else 0,
                'signed': pi.signed if pi else False,
            },
            'section': f.section, 'occurs': f.occurs_count,
            'occurs_depending': f.occurs_depending,
            'redefines': f.redefines, 'usage': f.usage,
        }
        if f.is_88:
            entry['is_88'] = True
            entry['parent'] = f.parent
        entry['value'] = f.value
        entry['values'] = f.values
        fields_dict.append(entry)
    fields_dict = expand_occurs(fields_dict)
    proc_div = extract_procedure_division(preprocessed)
    branch_tree = None
    assignments = {}
    if proc_div:
        branch_tree, assignments = build_branch_tree_fallback(proc_div, fields_dict)
    file_sec = parse_file_section(preprocessed)
    open_dir = scan_open_statements(proc_div) if proc_div else {}
    from .models import BrIf, BrEval, BrSeq, BrPerform, BrSearch, Assign, CondAnd, CondOr

    # -- Decision points (pre-order numbering) --
    decision_points = []
    total_branches = 0

    def _walk(node, counter):
        nonlocal total_branches
        if isinstance(node, BrIf):
            counter[0] += 1
            branches = 2
            decision_points.append({
                "id": counter[0], "kind": "IF",
                "label": str(node.condition)[:80], "branches": branches,
            })
            total_branches += branches
            _walk(node.true_seq, counter)
            _walk(node.false_seq, counter)
        elif isinstance(node, BrEval):
            counter[0] += 1
            # Count distinct WHEN labels, plus one for WHEN OTHER when present.
            seen_br = set()
            uni_count = 0
            for v, _ in node.when_list:
                brn = f"WHEN {v}"
                if brn not in seen_br:
                    uni_count += 1
                    seen_br.add(brn)
            n = uni_count + (1 if node.has_other and "OTHER" not in seen_br else 0)
            decision_points.append({"id": counter[0], "kind": "EVALUATE",
                                    "label": str(node.subject)[:80], "branches": n})
            total_branches += n
            for _, seq in node.when_list:
                _walk(seq, counter)
            _walk(node.other_seq, counter)
        elif isinstance(node, BrSeq):
            for child in node.children:
                _walk(child, counter)
        elif isinstance(node, BrPerform):
            if node.condition and node.perf_type in ('until', 'para_until', 'varying', 'para_varying'):
                counter[0] += 1
                decision_points.append({
                    "id": counter[0], "kind": "PERFORM",
                    "label": str(node.condition)[:80], "branches": 2,
                })
                total_branches += 2
            _walk(node.body_seq, counter)
        elif isinstance(node, BrSearch):
            _walk(node.at_end_seq, counter)
            for _, seq in node.when_list:
                _walk(seq, counter)

    if branch_tree:
        _walk(branch_tree, [0])

    # -- Paragraph names: lines consisting solely of "NAME." --
    lines = proc_div.split('\n') if proc_div else []
    paragraphs = set()
    for line in lines:
        m = re.match(r'^\s*([A-Z0-9][A-Z0-9-]*)\.\s*$', line.strip())
        if m:
            paragraphs.add(m.group(1))
    select_files = parse_file_control(preprocessed)
    open_directions_detail = open_dir

    # -- Statement-presence flags (word-boundary matches on the raw source) --
    src_upper = cobol_source.upper()
    has_divide = bool(re.search(r'\bDIVIDE\b', src_upper))
    has_inspect = bool(re.search(r'\bINSPECT\b', src_upper))
    has_string = bool(re.search(r'\bSTRING\b', src_upper))
    # Substring matching flagged names such as RECALL-FLAG; match the verb only.
    has_call = bool(re.search(r'\bCALL\b', src_upper))
    # SEARCH nodes never produce decision points, so look for the statement itself.
    has_search_all = bool(proc_div) and re.search(r'\bSEARCH\s+ALL\b', proc_div, re.IGNORECASE) is not None
    divide_constants = []
    if has_divide and proc_div:
        for dm in re.finditer(r'\bDIVIDE\s+([\d.]+)\b', proc_div, re.IGNORECASE):
            try:
                divide_constants.append(float(dm.group(1)))
            except ValueError:
                pass

    # -- PERFORM statements (pre-order) --
    perform_patterns = []

    def _walk_performs(node):
        if isinstance(node, BrPerform):
            perform_patterns.append({
                "type": node.perf_type,
                "target": node.target,
                "condition": node.condition,
                "times": node.times,
                "varying_var": node.varying_var,
            })
            _walk_performs(node.body_seq)
        elif isinstance(node, BrIf):
            _walk_performs(node.true_seq)
            _walk_performs(node.false_seq)
        elif isinstance(node, BrEval):
            for _, seq in node.when_list:
                _walk_performs(seq)
            _walk_performs(node.other_seq)
        elif isinstance(node, BrSeq):
            for c in node.children:
                _walk_performs(c)

    if branch_tree:
        _walk_performs(branch_tree)

    # -- Main loop: the first PERFORM whose body contains a 'read_into' assignment --
    main_loop = None

    def _perform_has_read(perf_node):
        def _walk_seq(seq):
            if isinstance(seq, Assign):
                if seq.source_info.get('type') == 'read_into':
                    return True
            elif isinstance(seq, BrSeq):
                for ch in seq.children:
                    if _walk_seq(ch):
                        return True
            return False
        return _walk_seq(perf_node.body_seq)

    def _perform_read_file(perf_node):
        def _walk_seq(seq):
            if isinstance(seq, Assign):
                if seq.source_info.get('type') == 'read_into':
                    return seq.source_info.get('file', '')
            elif isinstance(seq, BrSeq):
                for ch in seq.children:
                    result = _walk_seq(ch)
                    if result:
                        return result
            return None
        return _walk_seq(perf_node.body_seq)

    def _find_main_loop(node):
        nonlocal main_loop
        if main_loop is not None:
            return
        if isinstance(node, BrPerform):
            if _perform_has_read(node):
                main_loop = {
                    "type": node.perf_type,
                    "read_file": _perform_read_file(node),
                    "has_at_end": False,
                }
                return
            _find_main_loop(node.body_seq)
        elif isinstance(node, BrIf):
            _find_main_loop(node.true_seq)
            _find_main_loop(node.false_seq)
        elif isinstance(node, BrEval):
            for _, seq in node.when_list:
                _find_main_loop(seq)
            _find_main_loop(node.other_seq)
        elif isinstance(node, BrSeq):
            for c in node.children:
                _find_main_loop(c)

    if branch_tree:
        _find_main_loop(branch_tree)

    # -- IF statistics --
    if_types = {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0}

    def _walk_if_types(node, depth=0):
        if isinstance(node, BrIf):
            if_types["total"] += 1
            if_types["nested_depth"] = max(if_types["nested_depth"], depth)
            ct = node.cond_tree
            if ct:
                leaves = collect_leaves(ct)
                if isinstance(ct, (CondAnd, CondOr)):
                    if_types["compound"] += 1
                for leaf in leaves:
                    if leaf.op in ('>', '<', '>=', '<='):
                        if_types["comparison"] += 1
                    elif leaf.op in ('=', '<>'):
                        if_types["equality"] += 1
            _walk_if_types(node.true_seq, depth + 1)
            _walk_if_types(node.false_seq, depth + 1)
        elif isinstance(node, BrEval):
            for _, seq in node.when_list:
                _walk_if_types(seq, depth + 1)
            _walk_if_types(node.other_seq, depth + 1)
        elif isinstance(node, BrPerform):
            _walk_if_types(node.body_seq, depth + 1)
        elif isinstance(node, BrSeq):
            for c in node.children:
                _walk_if_types(c, depth + 1)

    if branch_tree:
        _walk_if_types(branch_tree)

    # -- Naming-convention heuristics over field names --
    variable_patterns = {
        "has_prev_key": False,
        "has_accumulator": False,
        "has_error_flag": False,
        "has_switch": False,
        "has_index": False,
        "has_save_area": False,
        "has_counter": False,
        "has_work": False,
    }
    for f in fields_dict:
        name = f.get('name', '')
        if re.search(r'\bWS-PREV[-_]', name, re.IGNORECASE):
            variable_patterns["has_prev_key"] = True
        if re.search(r'[-_]CNT\b', name, re.IGNORECASE) or re.search(r'[-_]ACCUM\b', name, re.IGNORECASE):
            variable_patterns["has_accumulator"] = True
        if re.search(r'[-_]ERR\b', name, re.IGNORECASE) or re.search(r'[-_]ERROR[-_]', name, re.IGNORECASE):
            variable_patterns["has_error_flag"] = True
        if re.search(r'[-_]SW\b', name, re.IGNORECASE) or re.search(r'[-_]FLAG\b', name, re.IGNORECASE):
            variable_patterns["has_switch"] = True
        if re.search(r'[-_]IDX\b', name, re.IGNORECASE) or re.search(r'[-_]INDX\b', name, re.IGNORECASE) or re.search(r'[-_]SUB\b', name, re.IGNORECASE):
            variable_patterns["has_index"] = True
        if re.search(r'[-_]SAVE[-_]', name, re.IGNORECASE) or re.search(r'[-_]HOLD[-_]', name, re.IGNORECASE):
            variable_patterns["has_save_area"] = True
        if re.search(r'[-_]CNT\b', name, re.IGNORECASE) or re.search(r'[-_]COUNT\b', name, re.IGNORECASE):
            variable_patterns["has_counter"] = True
        if name.startswith('WS-') and not re.search(r'(?:CNT|ERR|SW|IDX|INDX|SUB|SAVE|HOLD|PREV|ACCUM)', name, re.IGNORECASE):
            if re.search(r'[-_]W\b|[-_]WORK\b|[-_]WK\b|^WS-W[0O]\w', name, re.IGNORECASE):
                variable_patterns["has_work"] = True

    # -- OPEN/CLOSE pattern --
    open_pattern = "sequential"
    if proc_div:
        proc_upper = proc_div.upper()
        open_positions = [m.start() for m in re.finditer(r'\bOPEN\b', proc_upper)]
        close_positions = [m.start() for m in re.finditer(r'\bCLOSE\b', proc_upper)]
        # "open-close-open" exists iff some OPEN < CLOSE < OPEN. Positions are
        # ascending, so checking the first CLOSE after the first OPEN against
        # the last OPEN is sufficient.
        if open_positions:
            first_close = next((c for c in close_positions if c > open_positions[0]), None)
            if first_close is not None and open_positions[-1] > first_close:
                open_pattern = "open-close-open"

    return {
        "paragraphs": sorted(paragraphs) if paragraphs else [],
        "decision_points": decision_points,
        "branch_tree": branch_tree,
        "file_count": len(file_sec) if file_sec else 0,
        "open_directions": open_dir,
        "has_search_all": has_search_all,
        "has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points),
        "has_call": has_call,
        "has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points),
        "total_branches": total_branches,
        "total_paragraphs": len(paragraphs),
        "branch_tree_obj": branch_tree,
        "select_files": select_files,
        "open_directions_detail": open_directions_detail,
        "has_divide": has_divide,
        "divide_constants": divide_constants,
        "has_inspect": has_inspect,
        "has_string": has_string,
        "perform_patterns": perform_patterns,
        "main_loop": main_loop,
        "if_types": if_types,
        "variable_patterns": variable_patterns,
        "open_pattern": open_pattern,
    }
def generate_data(cobol_source: str, structure: dict = None,
                  copybook_dirs: list = None) -> list[dict]:
    """Generate test data covering all enumerated paths of a COBOL program.

    Args:
        cobol_source: raw (not preprocessed) COBOL source text. This function
            resolves includes and calls ``preprocess`` itself; passing
            already-preprocessed text loses copybook paths and therefore fields.
        structure: optional result of ``extract_structure()``, to avoid
            re-parsing. Its ``'branch_tree_obj'`` drives path enumeration, and
            its ``'coverage'`` key is (over)written with a coverage summary or
            ``{'error': ...}``.
        copybook_dirs: optional extra directories for ``resolve_copybooks``.
            Without it only ``resolve_sql_includes`` is applied. Both use
            ``'.'`` as the base directory.

    Returns:
        list[dict]: test records mapping field names to values; an empty list
        when no branch tree is available.
    """
    if structure is None:
        structure = extract_structure(cobol_source)
    branch_tree = structure.get("branch_tree_obj")
    if branch_tree is None:
        return []
    if copybook_dirs:
        src_resolved = resolve_copybooks(cobol_source, '.', extra_search_paths=copybook_dirs)
        src_resolved = resolve_sql_includes(src_resolved, '.')
    else:
        # Still resolve EXEC SQL INCLUDEs without copybook directories.
        src_resolved = resolve_sql_includes(cobol_source, '.')
    preprocessed = preprocess(src_resolved)
    data_div = extract_data_division(preprocessed)
    data_fields = parse_data_division(data_div) if data_div else []
    fields_dict = []
    for f in data_fields:
        pi = f.pic_info
        entry = {
            'name': f.name, 'level': f.level, 'pic': f.pic,
            'pic_info': {
                'type': pi.type if pi else 'unknown',
                'digits': pi.digits if pi else 0,
                'decimal': pi.decimal if pi else 0,
                'length': pi.length if pi else 0,
                'signed': pi.signed if pi else False,
            },
            'section': f.section, 'occurs': f.occurs_count,
            'occurs_depending': f.occurs_depending,
            'value': f.value, 'values': f.values,
            'redefines': f.redefines, 'usage': f.usage,
        }
        if f.is_88:
            entry['is_88'] = True
            entry['parent'] = f.parent
        fields_dict.append(entry)
    fields_dict = expand_occurs(fields_dict)
    proc_div = extract_procedure_division(preprocessed)
    # Only the assignments are used; the branch tree comes from *structure*.
    _, assignments = build_branch_tree_fallback(proc_div, fields_dict)
    file_sec = parse_file_section(preprocessed)
    path_infos = []
    for c, a in mcdc_enum_paths(branch_tree, fields_dict):
        filtered_c, term = get_term_type(c)
        path_infos.append((filtered_c, a, term))

    field_names = {f['name'] for f in fields_dict}
    # Base names of subscripted OCCURS copies, e.g. 'WS-CODE' for 'WS-CODE(1)'.
    subscript_bases = {n.split('(', 1)[0] for n in field_names if '(' in n}
    # COBOL is case-insensitive: detect *and* split the OF qualifier ignoring case.
    of_qualifier = re.compile(r'\s+OF\s+', re.IGNORECASE)

    def _resolve_field(fn: str) -> str:
        """Map a constraint operand to a known field name where possible.

        '_'-prefixed synthetic names are returned unchanged. ``A OF B`` becomes
        ``A``, and ``X(...)`` becomes ``X`` when X (or a subscripted copy of X)
        is a known field. Anything else, including arithmetic expressions,
        is returned as is.
        """
        if fn.startswith("_"):
            return fn
        parts = of_qualifier.split(fn, maxsplit=1)
        if len(parts) > 1:
            fn = parts[0].strip()
        if fn in field_names:
            return fn
        # Subscripted reference: WS-PLAN-CODE(WS-PLAN-IDX) -> WS-PLAN-CODE
        m = re.match(r'^(\w[\w-]*)\s*\(', fn)
        if m:
            base = m.group(1)
            if base in field_names or base in subscript_bases:
                return base
        return fn

    def _is_arith_expr(fn):
        return any(op in fn for op in (' + ', ' - ', ' * ', ' / '))

    # Keep constraints whose operand resolves to a known field (or is synthetic /
    # arithmetic), rewritten to the resolved name; drop the other 4-tuples.
    # Shorter entries pass through untouched.
    filtered_paths = []
    for cons_list, asgn, term in path_infos:
        clean = []
        for c in cons_list:
            if len(c) >= 4:
                raw = str(c[0])
                fn = _resolve_field(raw)
                if (fn in field_names or fn.startswith("_") or _is_arith_expr(raw)
                        or fn in subscript_bases):
                    c = list(c)
                    c[0] = fn
                    clean.append(tuple(c))
            else:
                clean.append(c)
        filtered_paths.append((clean, asgn, term))
    path_infos = filtered_paths
    records, _kept_paths, _term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)

    # -- Coverage marking: which decision branches the generated paths cover --
    # Best-effort: a failure is recorded in structure['coverage'], not raised.
    if branch_tree and fields_dict:
        try:
            dp_list, leaf_stats = collect_decision_points(branch_tree, fields_dict)
            cov_paths = [(pi[0], pi[1]) for pi in path_infos if isinstance(pi, (list, tuple)) and len(pi) >= 2]
            mark_coverage(dp_list, leaf_stats, cov_paths, fields_dict)
            dp_summary = [{
                'id': dp.id, 'kind': dp.kind,
                'label': getattr(dp, 'label', '')[:60],
                'branches': len(dp.branch_names),
                'covered': len(dp.active_branches),
            } for dp in dp_list]
            total = sum(len(dp.branch_names) for dp in dp_list)
            covered = sum(len(dp.active_branches) for dp in dp_list)
            structure['coverage'] = {
                'decision_points': dp_summary,
                'total': total,
                'covered': covered,
                'pct': covered / max(total, 1) * 100,
            }
        except Exception as e:
            logger.debug("coverage marking failed", exc_info=True)
            structure['coverage'] = {'error': str(e)[:80]}

    # Heuristic: for each "IF A op B" where A and B are both fields present in
    # the records, set B = A in the first half of the records (at least one).
    if records:
        proc_upper = (proc_div or "").upper()
        half = max(1, len(records) // 2)
        for m in re.finditer(r'IF\s+(\w[\w-]*)\s*[=<>]\s*(\w[\w-]*)', proc_upper):
            lhs, rhs = m.group(1), m.group(2)
            if (lhs in field_names and rhs in field_names
                    and any(lhs in r for r in records) and any(rhs in r for r in records)):
                for rec in records[:half]:
                    if lhs in rec and rhs in rec:
                        rec[rhs] = rec[lhs]
    return records
def incremental_supplement(branch_tree, decision_gaps: list[int]) -> list[dict]:
    """Build supplementary entries for uncovered decision points.

    Decision points are numbered exactly as in ``extract_structure()``: a
    pre-order walk in which IF, EVALUATE and conditional PERFORM
    (until / para_until / varying / para_varying) nodes each take the next ID,
    descending into PERFORM bodies and SEARCH branches. The numbering must
    match so that IDs taken from ``extract_structure()['decision_points']``
    select the intended nodes.

    Args:
        branch_tree: the ``branch_tree`` value returned by ``extract_structure()``.
        decision_gaps: IDs of uncovered decision points, e.g. ``[1, 3, 5]``.

    Returns:
        list[dict]: one entry per matched decision, in walk order, with keys
        ``_dec_id`` (``"incr_<n>"``), ``_kind`` (``"IF"``, ``"EVALUATE"`` or
        ``"PERFORM"``) and ``_label`` (first 60 characters of the
        condition/subject).
    """
    from .models import BrIf, BrEval, BrSeq, BrPerform, BrSearch
    target_decisions = set(decision_gaps)
    found = []

    def _find_decisions(node, counter):
        if isinstance(node, BrIf):
            counter[0] += 1
            if counter[0] in target_decisions:
                found.append(("IF", node.condition))
            _find_decisions(node.true_seq, counter)
            _find_decisions(node.false_seq, counter)
        elif isinstance(node, BrEval):
            counter[0] += 1
            if counter[0] in target_decisions:
                found.append(("EVALUATE", node.subject))
            for _, seq in node.when_list:
                _find_decisions(seq, counter)
            _find_decisions(node.other_seq, counter)
        elif isinstance(node, BrSeq):
            for child in node.children:
                _find_decisions(child, counter)
        elif isinstance(node, BrPerform):
            # Same rule as extract_structure(): only conditional PERFORMs are decisions.
            if node.condition and node.perf_type in ('until', 'para_until', 'varying', 'para_varying'):
                counter[0] += 1
                if counter[0] in target_decisions:
                    found.append(("PERFORM", node.condition))
            _find_decisions(node.body_seq, counter)
        elif isinstance(node, BrSearch):
            _find_decisions(node.at_end_seq, counter)
            for _, seq in node.when_list:
                _find_decisions(seq, counter)

    _find_decisions(branch_tree, [0])
    return [
        {
            "_dec_id": f"incr_{i}",
            "_kind": kind,
            "_label": str(label)[:60],
        }
        for i, (kind, label) in enumerate(found)
    ]