Files
cobol-java-v3/cobol_testgen/__init__.py
T
hangshuo652 7fb9304212 merge local cobol_testgen improvements into v3 shared modules
- cond.py: SQLCODE/SQLSTATE handling, alphanumeric >/< boundary fix
- output.py: termination tracking, db_input support, _is_field_assigned filter
- coverage.py: mark_from_gcov, THRU support, KeyError protection
- gcov.py: new file (dependency for coverage.py)
- grammar.lark: multi-segment PIC support
- read.py: SQL INCLUDE resolution, DECLARE TABLE parsing, * comment fix
- core.py: SQL parsing, blocked_names, keyword list
- design.py: multi-sentinel, THRU ranges, PERFORM VARYING last iteration
- __init__.py: local main() + v3 API functions, guarded imports

All 6 ZAN programs verified passing through v3 pipeline
2026-06-23 22:38:17 +08:00

1073 lines
40 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""COBOL Test Data Generator — 模块化版入口
Public API:
extract_structure() — parse COBOL control flow → dict
generate_data() — generate test data → list[dict]
incremental_supplement() — generate differential supplementary data → list[dict]
"""
import os
import sys
import json
import re
import logging
from datetime import datetime
from pathlib import Path
# ── Configuration (must be placed before the local-module imports to avoid circular imports) ──
# 'abend_programs' is read by main() and handed to extend_abend_programs().
CONFIG = {
'abend_programs': ['SUB03END'],
}
from .read import preprocess, extract_data_division, extract_procedure_division
from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements
from .read import parse_file_control, resolve_sql_includes, strip_exec_sql_from_data_div
from .core import build_branch_tree, classify_field_roles, _init_child_names, sql_register_virtual_fields, _find_multi_write_fds
from .cond import parse_single_condition, is_field, collect_leaves
from .pipeline_bridge import build_branch_tree_fallback
from .design_mcdc import enum_paths as mcdc_enum_paths, _filter_stop
from .design import enum_paths, generate_records, get_term_type, extend_abend_programs
from .output import output_json, output_input_files
from .coverage import run_coverage, generate_coverage_index
from japanese_data import generate_fullwidth_text, generate_halfwidth_katakana, generate_wareki_date
try:
from .runner import run_and_compare, run_all, GroupInfo, GroupResult
_HAVE_RUNNER = True
except ImportError:
_HAVE_RUNNER = False
try:
from .gcov import run_gcov
_HAVE_GCOV = True
except ImportError:
_HAVE_GCOV = False
try:
from .to_sql import collect_sql_meta, build_db_input
_HAVE_TOSQL = True
except ImportError:
_HAVE_TOSQL = False
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Public names of the package: the three API functions defined below, CONFIG,
# and the three generators imported from japanese_data above.
__all__ = [
"extract_structure",
"generate_data",
"incremental_supplement",
"CONFIG",
"generate_fullwidth_text",
"generate_halfwidth_katakana",
"generate_wareki_date",
]
# ── OCCURS 展开 ──
def _add_subscript(name, occ):
"""追加或扩展下标:WS-CELL → WS-CELL(1), WS-CELL(1) → WS-CELL(1,2)"""
if name.endswith(')'):
return name[:-1] + f',{occ})'
return name + f'({occ})'
def expand_occurs(fields):
    """Expand OCCURS fields into subscripted copies, recursing for nested OCCURS.

    Args:
        fields: list of field dicts in declaration order. Keys read here are
            ``name``, ``level``, ``occurs``, ``occurs_depending``, ``is_88``
            and ``parent``.

    Returns:
        A new list in which every non-88 entry with ``occurs > 0`` has been
        replaced. An entry with subordinate entries keeps one un-subscripted
        header (``occurs`` reset to 0) followed by subscripted copies of its
        children, once per occurrence. An entry without subordinates is
        replaced by ``NAME(1)`` .. ``NAME(n)``. Other entries are passed
        through unchanged (same dict objects).
    """
    def _count(entry):
        # A missing or None 'occurs' means "no OCCURS clause".
        return entry.get('occurs') or 0

    def _expandable(entry):
        # Only non-88 entries with a positive count are expanded by the loop.
        return _count(entry) > 0 and not entry.get('is_88')

    result = []
    i = 0
    while i < len(fields):
        f = fields[i]
        if not _expandable(f):
            result.append(f)
            i += 1
            continue

        # Collect subordinate entries: 88-levels always belong to the item;
        # anything else stops the scan once its level is not deeper than the
        # OCCURS item, or when it is a level-77 entry.
        children = []
        j = i + 1
        while j < len(fields):
            child = fields[j]
            if child.get('is_88'):
                children.append(child)
                j += 1
                continue
            if child['level'] <= f['level'] or child.get('level') == 77:
                break
            children.append(child)
            j += 1

        occurrences = range(1, _count(f) + 1)
        if children:
            # NOTE(review): when the only children are 88-levels the item
            # itself is emitted once without a subscript while the 88 copies
            # point at subscripted parents — confirm this is intended.
            group = dict(f)
            group['occurs'] = 0
            result.append(group)
            for occ in occurrences:
                for child in children:
                    copy = dict(child)
                    if not _count(child):
                        # Children without their own OCCURS are final after
                        # this pass and inherit the DEPENDING ON of the group.
                        copy['occurs'] = 0
                        copy['occurs_depending'] = f.get('occurs_depending')
                    if child.get('is_88'):
                        parent = child.get('parent') or f['name']
                        copy['parent'] = _add_subscript(parent, occ)
                    copy['name'] = _add_subscript(child['name'], occ)
                    result.append(copy)
        else:
            for occ in occurrences:
                copy = dict(f)
                copy['name'] = _add_subscript(f['name'], occ)
                copy['occurs'] = 0
                copy['occurs_depending'] = f.get('occurs_depending')
                result.append(copy)
        i = j

    # Children with their own OCCURS were copied with the count intact, so
    # another pass is needed. Only entries the loop can actually expand are
    # considered; checking the raw count would recurse forever on an 88-level
    # entry that carries a count.
    if any(_expandable(f) for f in result):
        return expand_occurs(result)
    return result
# ── PREV 连锁 ──
def _constraint_in(cons, field, op, value, want):
for c in cons:
if len(c) == 4 and c[0] == field and c[1] == op and c[2] == value and c[3] == want:
return True
return False
def _inc_str(s, length):
try:
return str(int(s) + 1).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 9Zz\xff':
c[i] = chr(ord(c[i]) + 1)
break
if c[i] == ' ':
c[i] = '0'
break
if c[i] == '9':
c[i] = '0'
return ''.join(c)
def _dec_str(s, length):
try:
n = max(0, int(s) - 1)
return str(n).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 0Aa\x00':
c[i] = chr(ord(c[i]) - 1)
break
if c[i] == ' ':
break
if c[i] == '0':
c[i] = '9'
return ''.join(c)
def _field_length(fname, fields):
for f in fields:
if f['name'] == fname:
pi = f.get('pic_info', {})
return pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0) or 1
return 1
def _chain_prev(records, path_infos, fields, fd_fields, field_to_fd, open_dir):
"""Chain PREV values across records. Mutates ``records`` in place so that the path taken in a batch run matches the actual comparisons.
The constraints of path k-1 (PREV OP CURRENT) correspond to the real comparison made in loop iteration k-1 of the batch:
PREV = records[prev_src].R01 (previous value kept inside the program)
CURRENT = records[k].R01 (value currently read)
This function adjusts the fields of records[k] so that the cross-record comparison satisfies the path constraints.
"""
# fd_fields, field_to_fd and open_dir are accepted but not read in this body.
# Nothing to chain with fewer than two records.
N = len(records)
if N < 2:
return
# Discover the key and time fields from the names in the first record:
# R01* names (R01INNREC* excluded) are paired with 'WRK-PREV-' + the name
# minus its 'R01' prefix; names containing EMP-ID / APPL-DATE are collected
# as key fields, names containing END-TIME / START-TIME as the time fields.
key_fields = []
time_start_field = None
time_end_field = None
for fname in records[0]:
if fname.startswith('R01') and not fname.startswith('R01INNREC'):
base = fname[3:]
prev_name = 'WRK-PREV-' + base
if prev_name in records[0]:
if 'EMP-ID' in fname or 'APPL-DATE' in fname:
key_fields.append(fname)
if 'END-TIME' in fname:
time_end_field = fname
if 'START-TIME' in fname:
time_start_field = fname
# Index of the record whose R01 values currently act as the "previous" values.
prev_src = 0
for k in range(1, N):
# Record k is driven by the constraints of path k-1; stop when none exist.
if k - 1 >= len(path_infos):
break
cons = path_infos[k - 1][0]
# Same key: every key field is constrained as WRK-PREV-<key> = <key> (true).
is_same_key = all(
_constraint_in(cons, f'WRK-PREV-{fn[3:]}', '=', fn, True)
for fn in key_fields
) if key_fields else False
# Overlap: same key and WRK-PREV-<end> > <start> is required to be true.
is_overlap = is_same_key and time_end_field and time_start_field and \
_constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, True)
# Normal: same key and either WRK-PREV-<end> <= <start> is true or
# WRK-PREV-<end> > <start> is false.
is_normal = is_same_key and time_end_field and time_start_field and \
(_constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '<=', time_start_field, True) or
_constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, False))
# Copy the R01* values of the source record into the matching WRK-PREV-*
# entries of record k.
for fname in records[prev_src]:
if fname.startswith('R01') and not fname.startswith('R01INNREC'):
base = fname[3:]
prev_name = 'WRK-PREV-' + base
if prev_name in records[k]:
records[k][prev_name] = records[prev_src][fname]
# Same-key paths reuse the source record's key values; the start time of
# record k is then moved past (normal) or before (overlap) the previous end
# time when the existing values do not already compare that way.
if is_same_key:
for kf in key_fields:
if kf in records[k] and kf in records[prev_src]:
records[k][kf] = records[prev_src][kf]
if is_normal and time_end_field and time_start_field:
prev_end = records[prev_src].get(time_end_field, '')
curr_start = records[k].get(time_start_field, '')
if prev_end >= curr_start:
length = _field_length(time_start_field, fields)
records[k][time_start_field] = _inc_str(prev_end, length)
if is_overlap and time_end_field and time_start_field:
prev_end = records[prev_src].get(time_end_field, '')
curr_start = records[k].get(time_start_field, '')
if prev_end <= curr_start:
length = _field_length(time_start_field, fields)
records[k][time_start_field] = _dec_str(prev_end, length) if prev_end else '0' * length
else:
# Otherwise a key value equal to the source record's is incremented so the
# two records differ.
for kf in key_fields:
if kf in records[k] and kf in records[prev_src]:
if records[k][kf] == records[prev_src][kf]:
length = _field_length(kf, fields)
records[k][kf] = _inc_str(str(records[k][kf]), length)
# Marker entries stored on the record itself.
# NOTE(review): both values come from and-chains, so they can be None or
# False rather than a strict bool.
records[k]['_w02_path'] = is_same_key and time_end_field and time_start_field and not is_overlap
records[k]['_overlap_path'] = is_overlap
# Mirror every R01* value into the W01* entry with the same suffix, when present.
for fn in list(records[k].keys()):
if fn.startswith('R01') and not fn.startswith('R01INNREC'):
wfn = 'W01' + fn[3:]
if wfn in records[k]:
records[k][wfn] = records[k][fn]
# On an overlap path the source record stays the same; otherwise record k
# becomes the source for the next iteration.
if is_overlap:
pass
else:
prev_src = k
# ── 入口 ──
def main():
# Command-line entry point. Reads COBOL file paths, an optional output
# directory and the --run / --gcov / --temp-dir options from sys.argv, then
# for each program writes JSON records, input files and a coverage report.
if len(sys.argv) < 2:
print("用法: python -m cobol_testgen <cobol文件1> [cobol文件2 ...] [输出目录]")
sys.exit(1)
args = sys.argv[1:]
do_run = False
gcov_mode = False
temp_dir = None
# Boolean switches are removed from args once recognised.
if '--run' in args:
do_run = True
args.remove('--run')
if '--gcov' in args:
gcov_mode = True
args.remove('--gcov')
# --temp-dir is accepted as "--temp-dir PATH" or "--temp-dir=PATH"; the scan
# stops at the first occurrence.
i = 0
while i < len(args):
if args[i] == '--temp-dir':
if i + 1 < len(args):
temp_dir = args[i + 1]
args.pop(i + 1)
args.pop(i)
else:
args.pop(i)
break
elif args[i].startswith('--temp-dir='):
temp_dir = args[i].split('=', 1)[1]
args.pop(i)
break
else:
i += 1
# Remaining arguments: an existing directory becomes the output directory,
# paths ending in .CBL/.COB/.CPY (case-insensitive) are COBOL inputs, and
# anything else is skipped with a warning.
cobol_files = []
outdir = None
for a in args:
p = Path(a)
if p.is_dir():
outdir = p
elif p.suffix.upper() in ('.CBL', '.COB', '.CPY'):
cobol_files.append(p)
else:
print(f"警告:跳过未知参数 {a}")
if not cobol_files:
print("错误:未找到任何 COBOL 文件")
sys.exit(1)
# Without an explicit directory, output goes next to the first COBOL file.
if outdir is None:
outdir = cobol_files[0].parent
outdir.mkdir(parents=True, exist_ok=True)
(outdir / 'logs').mkdir(parents=True, exist_ok=True)
# Logging setup: DEBUG and above to a timestamped file under <outdir>/logs,
# INFO and above (message text only) to a stream handler.
# NOTE(review): both handlers are attached to the root logger on every call
# and are not removed or closed here.
log_path = outdir / 'logs' / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log"
fh = logging.FileHandler(log_path, encoding="utf-8", mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter(
"%(asctime)s [%(levelname)s] %(name)s: %(message)s"
))
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
sh.setFormatter(logging.Formatter("%(message)s"))
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
root_logger.addHandler(fh)
root_logger.addHandler(sh)
# Coverage results collected per program for the final index page.
programs = []
for filepath in cobol_files:
if not filepath.exists():
logger.error(f"错误:文件不存在 {filepath}")
continue
# Source is read as UTF-8; copybooks are resolved relative to the file's
# directory with ../cpy as an extra search path, then SQL INCLUDEs.
source = filepath.read_text(encoding='utf-8')
source = resolve_copybooks(
source,
str(filepath.parent),
extra_search_paths=[str(filepath.parent / '..' / 'cpy')],
)
source = resolve_sql_includes(source, str(filepath.parent))
preprocessed = preprocess(source)
file_sec = parse_file_section(preprocessed)
data_div = extract_data_division(preprocessed)
if data_div:
data_div, declared_columns = strip_exec_sql_from_data_div(data_div)
else:
declared_columns = {}
# Programs without a DATA DIVISION or without any parsed field are skipped.
if not data_div:
logger.error(f"错误:{filepath.name} 中没有 DATA DIVISION。")
continue
data_fields = parse_data_division(data_div)
if not data_fields:
logger.error(f"错误:{filepath.name} 中没有找到含 PIC 的字段。")
continue
# Convert the parsed field objects into plain dicts. The second and later
# FILLER entries are renamed FILLER_<n> so their names stay distinct;
# parent_pic remembers pic_info by name so 88-level entries can reuse the
# pic_info of their parent.
fields_dict = []
parent_pic = {}
filler_counter = 0
for f in data_fields:
pi = f.pic_info
name = f.name
if name == 'FILLER':
filler_counter += 1
if filler_counter > 1:
name = f'FILLER_{filler_counter}'
entry = {
'name': name,
'level': f.level,
'pic': f.pic,
'pic_info': {
'type': pi.type if pi else 'unknown',
'digits': pi.digits if pi else 0,
'decimal': pi.decimal if pi else 0,
'length': pi.length if pi else 0,
'signed': pi.signed if pi else False,
},
'value': f.value,
'values': f.values,
'section': f.section,
'is_filler': f.is_filler,
'redefines': f.redefines,
'usage': f.usage,
'occurs': f.occurs_count,
'occurs_depending': f.occurs_depending,
}
if f.is_88:
entry['is_88'] = True
entry['parent'] = f.parent
if f.parent and f.parent in parent_pic:
entry['pic_info'] = dict(parent_pic[f.parent])
else:
parent_pic[name] = entry['pic_info']
fields_dict.append(entry)
fields_dict = expand_occurs(fields_dict)
sql_register_virtual_fields(fields_dict)
# fd_fields: FD name -> ordered, de-duplicated list of its record names plus
# the names returned by _init_child_names for each record.
# field_to_fd: the reverse lookup from each of those names to its FD.
fd_fields = {}
field_to_fd = {}
if file_sec:
for fd_name, rec_names in file_sec.items():
fds = []
seen = set()
for rec in rec_names:
if rec not in seen:
fds.append(rec)
seen.add(rec)
for child in _init_child_names(rec, fields_dict):
if child not in seen:
fds.append(child)
seen.add(child)
fd_fields[fd_name] = fds
for child in fds:
field_to_fd[child] = fd_name
# Log a table of all fields (level, name, PIC, type, length).
logger.info(f"\n========== {filepath.name} ==========")
logger.info(f"\n字段列表:")
logger.info(f"{'层级':<6} {'名称':<25} {'PIC':<15} {'类型':<12} {'长度':<5}")
logger.info("-" * 65)
for f in fields_dict:
pi = f['pic_info']
t = pi.get('type', '?')
l = pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0)
pic_display = str(f.get('pic', '')) if f.get('pic') else ('88-level' if f.get('is_88') else '')
logger.info(f"{f['level']:<6} {f['name']:<25} {pic_display:<15} {t:<12} {l:<5}")
# Control-flow analysis of the PROCEDURE DIVISION, when there is one.
# NOTE(review): branch_tree is only assigned under "if proc_div" below but
# is passed to run_coverage() near the end of the loop body — confirm what
# happens for a program without a PROCEDURE DIVISION.
proc_div = extract_procedure_division(preprocessed)
branch_paths = []
assignments = {}
if proc_div:
branch_tree, assignments = build_branch_tree(proc_div, fields_dict, full_source=preprocessed)
roles = classify_field_roles(branch_tree, assignments, fields_dict,
source=preprocessed, proc_text=proc_div)
logger.info(f"\n字段角色(输入/输出/出入/未用):")
for f in fields_dict:
if f.get('is_88'):
continue
logger.info(f" {f['name']:<30} {roles.get(f['name'], '?')}")
abend_list = CONFIG.get('abend_programs', [])
if abend_list:
extend_abend_programs(abend_list)
# Each enumerated path becomes (filtered constraints, assignments, term type).
branch_paths_with_assigns = enum_paths(branch_tree, fields_dict)
path_infos = []
for c, a in branch_paths_with_assigns:
filtered_c, term = get_term_type(c)
path_infos.append((filtered_c, a, term))
# _is_skip: a path is dropped when it has at least one WRK-R01EOF = '1'
# (true) constraint and the "other" counter stayed at zero.
def _is_skip(cons):
eq1_true = 0
other = 0
for c in cons:
if len(c) == 4 and c[0] == 'WRK-R01EOF':
val = str(c[2]).strip("'\"")
if val == '1' and c[1] == '=' and c[3]:
eq1_true += 1
else:
other += 1
return eq1_true > 0 and other == 0
before = len(path_infos)
path_infos = [p for p in path_infos if not _is_skip(p[0])]
after = len(path_infos)
logger.info(f" SKIP 过滤: {before} -> {after} 条路径(预期减少 1")
open_dir = scan_open_statements(proc_div) if proc_div else {}
if proc_div:
logger.info(f"\n分支路径数:{len(branch_paths_with_assigns)}")
for i, (path_cons, _path_assign) in enumerate(branch_paths_with_assigns):
descs = []
for c in path_cons:
if len(c) == 4:
field, op, val, want = c
if op == 'not_in':
descs.append(f"{field} not in {val}")
else:
descs.append(f"{field} {op} {val} ({'T' if want else 'F'})")
logger.debug(f" 路径 {i + 1}: {', '.join(descs)}")
else:
# No PROCEDURE DIVISION: fall back to a single empty path with term type
# 'normal' and mark every field as unused.
logger.warning("\n没有找到 PROCEDURE DIVISION。")
branch_paths_with_assigns = [([], {})]
path_infos = [([], {}, 'normal')]
roles = {f['name']: 'unused' for f in fields_dict}
records, _, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)
# _is_eof_path: True when the path contains a WRK-R01EOF = '1' (true)
# constraint and no WRK-R01EOF constraint appears after the last such one.
def _is_eof_path(cons):
last_eq1_true = -1
for i, c in enumerate(cons):
if len(c) == 4 and c[0] == 'WRK-R01EOF':
val = str(c[2]).strip("'\"")
if val == '1' and c[1] == '=' and c[3]:
last_eq1_true = i
if last_eq1_true < 0:
return False
for i in range(last_eq1_true + 1, len(cons)):
if len(cons[i]) == 4 and cons[i][0] == 'WRK-R01EOF':
return False
return True
# Paths flagged as EOF get their term type replaced by 'eof'.
eof_mask = [_is_eof_path(c) for c, a, t in path_infos]
eof_count = sum(eof_mask)
if eof_count:
term_types = ['eof' if e else t for e, t in zip(eof_mask, term_types)]
logger.info(f" EOF 路径: {eof_count} 条(将单独执行)")
multi_write_fds = _find_multi_write_fds(branch_tree, field_to_fd) if proc_div and branch_tree else set()
if multi_write_fds:
logger.info(f" 检测到多 WRITE FD: {', '.join(sorted(multi_write_fds))}")
# Adjust the records in place so consecutive records satisfy the PREV constraints.
_chain_prev(records, path_infos, fields_dict, fd_fields, field_to_fd, open_dir)
# Database input is only built when the optional to_sql module was importable.
if _HAVE_TOSQL:
sql_meta = collect_sql_meta(assignments, declared_columns)
db_input = build_db_input(
branch_paths_with_assigns, fields_dict, assignments, sql_meta, declared_columns,
records=records,
)
else:
db_input = None
# Write <outdir>/json/<program>.json and the input files under <outdir>/input.
(outdir / 'json').mkdir(parents=True, exist_ok=True)
outpath = outdir / 'json' / (filepath.stem + '.json')
output_json(records, outpath, roles,
fd_fields=fd_fields, field_to_fd=field_to_fd,
open_dir=open_dir,
term_types=term_types,
db_input=db_input if db_input else None,
data_fields=fields_dict)
output_input_files(records, outdir / 'input', filepath.stem, roles,
fd_fields, field_to_fd, open_dir,
term_types=term_types)
# --gcov mode: run every group, collect gcov data and log mismatches.
# NOTE(review): run_all is imported under the _HAVE_RUNNER guard, but this
# branch only checks _HAVE_GCOV — confirm both imports always succeed together.
gcov_data = None
if gcov_mode and proc_div and _HAVE_GCOV:
select_info = parse_file_control(preprocessed)
_temp = temp_dir or str(outdir / '.gcov_cache')
source_dir = str(filepath.parent)
# Every slot initially refers to the same empty dict; slots are replaced
# one by one below when the JSON output can be read back.
expected_records: list[dict] = [{}] * len(records)
if file_sec and os.path.exists(outpath):
with open(outpath, encoding='utf-8') as f:
full_json = json.load(f)
json_records = full_json.get('records', [])
# Merge the expected_output of every FD into one dict per record.
for i in range(len(records)):
exp = {}
if i < len(json_records):
json_rec = json_records[i]
for fd_name in file_sec:
eo = json_rec.get('expected_output', {})
if fd_name in eo:
exp.update(eo[fd_name])
expected_records[i] = exp
group_results = run_all(
filepath.stem, str(outdir), _temp,
fields_dict, fd_fields, select_info, open_dir,
term_types, records, expected_records=expected_records,
source_dir=source_dir, path_infos=path_infos,
multi_write_fds=multi_write_fds,
)
gcov_data = run_gcov(filepath.stem, _temp)
passed = sum(1 for r in group_results if r.passed)
total = len(group_results)
logger.info(f"\n 执行验证: {passed}/{total} 组通过")
# For failing groups, log at most three mismatching details each.
if passed < total:
for r in group_results:
if not r.passed and r.details:
fails = [d for d in r.details if not d.match][:3]
for d in fails:
logger.warning(f" [{r.name}] {d.field}: "
f"期望={d.expected!r}, 实际={d.actual!r}")
# --run mode: execute and compare, only when the runner module is available.
if do_run and proc_div and _HAVE_RUNNER:
select_info = parse_file_control(preprocessed)
run_and_compare(
filepath.stem, str(outdir), fields_dict,
fd_fields, select_info, open_dir,
term_types, records,
)
logger.info(f"\n输出:{outpath}{len(records)} 条记录)")
# Debug dump of every record; fields with a known, used role are prefixed
# with the upper-cased first letter of that role in brackets.
logger.debug(f"\n记录明细:")
for i, rec in enumerate(records, 1):
vals = []
for f in fields_dict:
r = roles.get(f['name'], '?')
marker = f"[{r[0].upper()}]" if r != '?' and r != 'unused' else ''
vals.append(f"{marker}{f['name']}={rec.get(f['name'], '?')}")
logger.debug(f" 记录 {i}: {' | '.join(vals)}")
# Per-program coverage report under <outdir>/coverage; the result is kept
# for the index generated after all programs are processed.
(outdir / 'coverage').mkdir(parents=True, exist_ok=True)
cov_prefix = str(outdir / 'coverage' / filepath.stem)
index_relpath = 'index.html'
cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict,
source, cov_prefix, index_relpath=index_relpath,
gcov_data=gcov_data)
programs.append(cov_result)
if programs:
generate_coverage_index(programs, outdir / 'coverage')
logger.info(f"\n覆盖率总览:{outdir / 'coverage' / 'index.html'}")
# ════════════════════════════════════════════
# Phase 1: 可编程 API(供 orchestrator.py 调用)
# ════════════════════════════════════════════
def extract_structure(cobol_source: str) -> dict:
"""Analyse the structure of COBOL source and return a structural summary. Static analysis only; no test data is generated.
Returns:
dict with: paragraphs, decision_points, branch_tree, file_count,
open_directions, has_search_all, has_evaluate,
has_call, has_break, total_branches, total_paragraphs
(the returned dict also carries the additional keys listed in the return statement below)
"""
preprocessed = preprocess(cobol_source)
data_div = extract_data_division(preprocessed)
data_fields = parse_data_division(data_div) if data_div else []
# Build plain-dict field descriptors. Every FILLER is renamed
# FILLER_<n>, where n is its 1-based position in data_fields.
fields_dict = []
for idx, f in enumerate(data_fields):
entry = {
'name': f.name if f.name != 'FILLER' else f'FILLER_{idx + 1}',
'level': f.level, 'pic': f.pic,
'pic_info': {
'type': f.pic_info.type if f.pic_info else 'unknown',
'digits': f.pic_info.digits if f.pic_info else 0,
'decimal': f.pic_info.decimal if f.pic_info else 0,
'length': f.pic_info.length if f.pic_info else 0,
'signed': f.pic_info.signed if f.pic_info else False,
},
'section': f.section, 'occurs': f.occurs_count,
'occurs_depending': f.occurs_depending,
'redefines': f.redefines, 'usage': f.usage,
}
if f.is_88:
entry['is_88'] = True
entry['parent'] = f.parent
entry['value'] = f.value
entry['values'] = f.values
fields_dict.append(entry)
fields_dict = expand_occurs(fields_dict)
proc_div = extract_procedure_division(preprocessed)
branch_tree = None
assignments = {}
if proc_div:
branch_tree, assignments = build_branch_tree_fallback(proc_div, fields_dict)
file_sec = parse_file_section(preprocessed)
open_dir = scan_open_statements(proc_div) if proc_div else {}
from .models import BrIf, BrEval, BrSeq, BrPerform, BrSearch, Assign, CondAnd, CondOr
# ── Decision points ──
# _walk numbers decision points in traversal order using counter[0]:
# each BrIf counts 2 branches, each BrEval counts len(when_list) plus one
# when has_other is set, and a BrPerform with a condition and an
# until/varying perf_type counts 2. BrSearch adds no decision point itself;
# only its at_end_seq and WHEN sequences are visited.
decision_points = []
total_branches = 0
def _walk(node, counter):
nonlocal total_branches
if isinstance(node, BrIf):
counter[0] += 1
branches = 2
decision_points.append({
"id": counter[0], "kind": "IF",
"label": str(node.condition)[:80], "branches": branches,
})
total_branches += branches
_walk(node.true_seq, counter)
_walk(node.false_seq, counter)
elif isinstance(node, BrEval):
counter[0] += 1
n = len(node.when_list) + (1 if node.has_other else 0)
decision_points.append({
"id": counter[0], "kind": "EVALUATE",
"label": str(node.subject)[:80], "branches": n,
})
total_branches += n
for _, seq in node.when_list:
_walk(seq, counter)
_walk(node.other_seq, counter)
elif isinstance(node, BrSeq):
for child in node.children:
_walk(child, counter)
elif isinstance(node, BrPerform):
if node.condition and node.perf_type in ('until', 'para_until', 'varying', 'para_varying'):
counter[0] += 1
decision_points.append({
"id": counter[0], "kind": "PERFORM",
"label": str(node.condition)[:80], "branches": 2,
})
total_branches += 2
_walk(node.body_seq, counter)
elif isinstance(node, BrSearch):
_walk(node.at_end_seq, counter)
for _, seq in node.when_list:
_walk(seq, counter)
if branch_tree:
_walk(branch_tree, [0])
# ── Paragraph names ──
# A line that consists solely of "NAME." (letters, digits, hyphens) is
# taken as a paragraph label.
lines = proc_div.split('\n') if proc_div else []
paragraphs = set()
for line in lines:
m = re.match(r'^\s*([A-Z0-9][A-Z0-9-]*)\.\s*$', line.strip())
if m:
paragraphs.add(m.group(1))
select_files = parse_file_control(preprocessed)
open_directions_detail = open_dir
# ── Verb presence flags (whole-word match on the upper-cased source) ──
has_divide = bool(re.search(r'\bDIVIDE\b', cobol_source.upper()))
has_inspect = bool(re.search(r'\bINSPECT\b', cobol_source.upper()))
has_string = bool(re.search(r'\bSTRING\b', cobol_source.upper()))
# Numeric literals that directly follow DIVIDE in the procedure text.
divide_constants = []
if has_divide and proc_div:
for dm in re.finditer(r'\bDIVIDE\s+([\d.]+)\b', proc_div, re.IGNORECASE):
val = dm.group(1)
try:
divide_constants.append(float(val))
except ValueError:
pass
# ── PERFORM inventory ──
# One entry per BrPerform reached through BrPerform bodies, BrIf branches,
# BrEval WHEN/OTHER sequences and BrSeq children (BrSearch is not descended).
perform_patterns = []
def _walk_performs(node):
if isinstance(node, BrPerform):
entry = {
"type": node.perf_type,
"target": node.target,
"condition": node.condition,
"times": node.times,
"varying_var": node.varying_var,
}
perform_patterns.append(entry)
_walk_performs(node.body_seq)
elif isinstance(node, BrIf):
_walk_performs(node.true_seq)
_walk_performs(node.false_seq)
elif isinstance(node, BrEval):
for _, seq in node.when_list:
_walk_performs(seq)
_walk_performs(node.other_seq)
elif isinstance(node, BrSeq):
for c in node.children:
_walk_performs(c)
if branch_tree:
_walk_performs(branch_tree)
# ── Main read loop ──
# main_loop is set from the first BrPerform (in traversal order) for which
# _perform_has_read() is true; "has_at_end" is always stored as False here.
main_loop = None
def _find_main_loop(node, depth=0):
nonlocal main_loop
if main_loop is not None:
return
if isinstance(node, BrPerform):
if _perform_has_read(node):
main_loop = {
"type": node.perf_type,
"read_file": _perform_read_file(node),
"has_at_end": False,
}
return
_find_main_loop(node.body_seq, depth + 1)
elif isinstance(node, BrIf):
_find_main_loop(node.true_seq, depth + 1)
_find_main_loop(node.false_seq, depth + 1)
elif isinstance(node, BrEval):
for _, seq in node.when_list:
_find_main_loop(seq, depth + 1)
_find_main_loop(node.other_seq, depth + 1)
elif isinstance(node, BrSeq):
for c in node.children:
_find_main_loop(c, depth + 1)
# True when the PERFORM body holds an Assign whose source_info type is
# 'read_into'; only Assign nodes and BrSeq children are inspected.
def _perform_has_read(perf_node):
def _walk_seq(seq):
if isinstance(seq, Assign):
if seq.source_info.get('type') == 'read_into':
return True
elif isinstance(seq, BrSeq):
for ch in seq.children:
if _walk_seq(ch):
return True
return False
return _walk_seq(perf_node.body_seq)
# Returns the 'file' entry of the first such 'read_into' Assign, or None.
def _perform_read_file(perf_node):
def _walk_seq(seq):
if isinstance(seq, Assign):
if seq.source_info.get('type') == 'read_into':
return seq.source_info.get('file', '')
elif isinstance(seq, BrSeq):
for ch in seq.children:
result = _walk_seq(ch)
if result:
return result
return None
return _walk_seq(perf_node.body_seq)
if branch_tree:
_find_main_loop(branch_tree)
# ── IF statistics ──
# Counts BrIf nodes, the deepest traversal depth at which one occurs, IFs
# whose condition tree is a CondAnd/CondOr, and leaf operators split into
# comparison (> < >= <=) and equality (= <>).
if_types = {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0}
def _walk_if_types(node, depth=0):
if isinstance(node, BrIf):
if_types["total"] += 1
if_types["nested_depth"] = max(if_types["nested_depth"], depth)
ct = node.cond_tree
if ct:
leaves = collect_leaves(ct)
if isinstance(ct, (CondAnd, CondOr)):
if_types["compound"] += 1
for leaf in leaves:
if leaf.op in ('>', '<', '>=', '<='):
if_types["comparison"] += 1
elif leaf.op in ('=', '<>'):
if_types["equality"] += 1
_walk_if_types(node.true_seq, depth + 1)
_walk_if_types(node.false_seq, depth + 1)
elif isinstance(node, BrEval):
for _, seq in node.when_list:
_walk_if_types(seq, depth + 1)
_walk_if_types(node.other_seq, depth + 1)
elif isinstance(node, BrPerform):
_walk_if_types(node.body_seq, depth + 1)
elif isinstance(node, BrSeq):
for c in node.children:
_walk_if_types(c, depth + 1)
if branch_tree:
_walk_if_types(branch_tree)
# ── Variable naming heuristics ──
# Flags are derived purely from field-name patterns (case-insensitive).
# Names matching -CNT set both has_accumulator and has_counter.
variable_patterns = {
"has_prev_key": False,
"has_accumulator": False,
"has_error_flag": False,
"has_switch": False,
"has_index": False,
"has_save_area": False,
"has_counter": False,
"has_work": False,
}
for f in fields_dict:
name = f.get('name', '')
if re.search(r'\bWS-PREV[-_]', name, re.IGNORECASE):
variable_patterns["has_prev_key"] = True
if re.search(r'[-_]CNT\b', name, re.IGNORECASE) or re.search(r'[-_]ACCUM\b', name, re.IGNORECASE):
variable_patterns["has_accumulator"] = True
if re.search(r'[-_]ERR\b', name, re.IGNORECASE) or re.search(r'[-_]ERROR[-_]', name, re.IGNORECASE):
variable_patterns["has_error_flag"] = True
if re.search(r'[-_]SW\b', name, re.IGNORECASE) or re.search(r'[-_]FLAG\b', name, re.IGNORECASE):
variable_patterns["has_switch"] = True
if re.search(r'[-_]IDX\b', name, re.IGNORECASE) or re.search(r'[-_]INDX\b', name, re.IGNORECASE) or re.search(r'[-_]SUB\b', name, re.IGNORECASE):
variable_patterns["has_index"] = True
if re.search(r'[-_]SAVE[-_]', name, re.IGNORECASE) or re.search(r'[-_]HOLD[-_]', name, re.IGNORECASE):
variable_patterns["has_save_area"] = True
if re.search(r'[-_]CNT\b', name, re.IGNORECASE) or re.search(r'[-_]COUNT\b', name, re.IGNORECASE):
variable_patterns["has_counter"] = True
if name.startswith('WS-') and not re.search(r'(?:CNT|ERR|SW|IDX|INDX|SUB|SAVE|HOLD|PREV|ACCUM)', name, re.IGNORECASE):
if re.search(r'[-_]W\b|[-_]WORK\b|[-_]WK\b|^WS-W[0O]\w', name, re.IGNORECASE):
variable_patterns["has_work"] = True
# ── OPEN/CLOSE ordering ──
# "open-close-open" when some OPEN is followed by a CLOSE that is itself
# followed by another OPEN in the procedure text; otherwise "sequential".
# The check is positional only and does not look at file names.
open_pattern = "sequential"
if proc_div:
proc_upper = proc_div.upper()
open_positions = [m.start() for m in re.finditer(r'\bOPEN\b', proc_upper)]
close_positions = [m.start() for m in re.finditer(r'\bCLOSE\b', proc_upper)]
if open_positions and close_positions:
for i, opos in enumerate(open_positions):
for cpos in close_positions:
if cpos > opos:
for opos2 in open_positions:
if opos2 > cpos:
open_pattern = "open-close-open"
break
if open_pattern == "open-close-open":
break
if open_pattern == "open-close-open":
break
# NOTE(review): "branch_tree" and "branch_tree_obj" hold the same object, as
# do "open_directions" and "open_directions_detail".
# NOTE(review): "has_call" is a plain substring test (it also matches words
# that merely contain CALL), unlike the \b-delimited checks used above.
# NOTE(review): "has_search_all" looks for 'SEARCH' in decision-point labels,
# but _walk records no decision point for BrSearch nodes — confirm intent.
return {
"paragraphs": sorted(paragraphs) if paragraphs else [],
"decision_points": decision_points,
"branch_tree": branch_tree,
"file_count": len(file_sec) if file_sec else 0,
"open_directions": open_dir,
"has_search_all": any('SEARCH' in str(dp.get('label', '')) for dp in decision_points),
"has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points),
"has_call": 'CALL' in cobol_source.upper(),
"has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points),
"total_branches": total_branches,
"total_paragraphs": len(paragraphs),
"branch_tree_obj": branch_tree,
"select_files": select_files,
"open_directions_detail": open_directions_detail,
"has_divide": has_divide,
"divide_constants": divide_constants,
"has_inspect": has_inspect,
"has_string": has_string,
"perform_patterns": perform_patterns,
"main_loop": main_loop,
"if_types": if_types,
"variable_patterns": variable_patterns,
"open_pattern": open_pattern,
}
def generate_data(cobol_source: str, structure: dict = None) -> list[dict]:
"""Generate test data covering all paths of the given COBOL source.
Args:
cobol_source: text of the COBOL program source
structure: optional; pass the result of an earlier extract_structure() call to avoid parsing twice
Returns:
list[dict]: list of test-data records, each holding a value for every field
"""
if structure is None:
structure = extract_structure(cobol_source)
# Without a branch tree there is nothing to enumerate.
branch_tree = structure.get("branch_tree_obj")
if branch_tree is None:
return []
preprocessed = preprocess(cobol_source)
data_div = extract_data_division(preprocessed)
data_fields = parse_data_division(data_div) if data_div else []
# Rebuild the field descriptors from the source.
# NOTE(review): FILLER names are kept as-is here, whereas
# extract_structure() renames them to FILLER_<n> — confirm the difference
# is intended.
fields_dict = []
for f in data_fields:
entry = {
'name': f.name, 'level': f.level, 'pic': f.pic,
'pic_info': {
'type': f.pic_info.type if f.pic_info else 'unknown',
'digits': f.pic_info.digits if f.pic_info else 0,
'decimal': f.pic_info.decimal if f.pic_info else 0,
'length': f.pic_info.length if f.pic_info else 0,
'signed': f.pic_info.signed if f.pic_info else False,
},
'section': f.section, 'occurs': f.occurs_count,
'occurs_depending': f.occurs_depending,
'value': f.value, 'values': f.values,
'redefines': f.redefines, 'usage': f.usage,
}
if f.is_88:
entry['is_88'] = True
entry['parent'] = f.parent
fields_dict.append(entry)
fields_dict = expand_occurs(fields_dict)
proc_div = extract_procedure_division(preprocessed)
# Only the assignments are taken from this second tree build; the tree
# itself comes from *structure*.
_, assignments = build_branch_tree_fallback(proc_div, fields_dict)
file_sec = parse_file_section(preprocessed)
# Enumerate paths with the MC/DC enumerator and split off the term type.
branch_paths_unfiltered = mcdc_enum_paths(branch_tree, fields_dict)
path_infos = []
for c, a in branch_paths_unfiltered:
filtered_c, term = get_term_type(c)
path_infos.append((filtered_c, a, term))
_fdict_names = {f['name'] for f in fields_dict}
# _resolve_field: drop an " OF <group>" qualifier, then reduce "NAME(...)"
# to NAME when NAME is a known field; otherwise return the name unchanged.
def _resolve_field(fn: str) -> str:
ufn = fn.upper()
if ' OF ' in ufn:
fn = fn.split(' OF ')[0].strip()
m = re.match(r'^(\w[\w-]*)\s*\(', fn)
if m and m.group(1) in _fdict_names:
return m.group(1)
return fn
# Rewrite the field name of each constraint through _resolve_field.
filtered_paths = []
for cons_list, asgn, term in path_infos:
clean = []
for c in cons_list:
if len(c) >= 4:
fn = _resolve_field(str(c[0]))
if fn in _fdict_names:
c = list(c); c[0] = fn
clean.append(tuple(c))
else:
clean.append(c)
filtered_paths.append((clean, asgn, term))
path_infos = filtered_paths
records, kept_paths, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)
# Post-processing for field-to-field IF comparisons: when both operands are
# known fields present in the records, the right-hand field is overwritten
# with the left-hand value in records from the first half of the list.
# NOTE(review): "import re as _re" duplicates the module-level "re" import.
if records:
import re as _re
proc_upper = (proc_div or "").upper()
for m in _re.finditer(r'IF\s+(\w[\w-]*)\s*[=<>]\s*(\w[\w-]*)', proc_upper):
lhs, rhs = m.group(1), m.group(2)
lhs_in = any(lhs == f['name'] for f in fields_dict)
rhs_in = any(rhs == f['name'] for f in fields_dict)
if lhs_in and rhs_in and any(lhs in r for r in records) and any(rhs in r for r in records):
half = max(1, len(records) // 2)
for i, rec in enumerate(records):
if lhs in rec and rhs in rec and i < half:
rec[rhs] = rec[lhs]
return records
def incremental_supplement(branch_tree, decision_gaps: list[int]) -> list[dict]:
    """Describe the uncovered decision points that need supplementary data.

    Decision points are numbered in the same traversal order as the
    ``decision_points`` list built by ``extract_structure()``: IF and EVALUATE
    nodes are always counted, a PERFORM is counted when it has a condition and
    an until/varying type, and the bodies of PERFORM and SEARCH nodes are
    descended into. Keeping both numberings aligned is what lets the IDs in
    *decision_gaps* address the right nodes.

    Args:
        branch_tree: the ``branch_tree`` value returned by extract_structure().
        decision_gaps: IDs of the decision points not yet covered, e.g.
            ``[1, 3, 5]``. ``None`` is treated as an empty list.

    Returns:
        list[dict]: one entry per matched decision point with the keys
        ``_dec_id`` (``incr_<n>``), ``_kind`` (``IF``, ``EVALUATE`` or
        ``PERFORM``) and ``_label`` (first 60 characters of the condition or
        subject).
    """
    from .models import BrIf, BrEval, BrSeq, BrPerform, BrSearch

    target_decisions = set(decision_gaps or ())
    # PERFORM types that extract_structure() numbers as decision points.
    counted_perform_types = ('until', 'para_until', 'varying', 'para_varying')
    found = []

    def _find_decisions(node, counter):
        if isinstance(node, BrIf):
            counter[0] += 1
            if counter[0] in target_decisions:
                found.append(("IF", node.condition))
            _find_decisions(node.true_seq, counter)
            _find_decisions(node.false_seq, counter)
        elif isinstance(node, BrEval):
            counter[0] += 1
            if counter[0] in target_decisions:
                found.append(("EVALUATE", node.subject))
            for _, seq in node.when_list:
                _find_decisions(seq, counter)
            _find_decisions(node.other_seq, counter)
        elif isinstance(node, BrSeq):
            for child in node.children:
                _find_decisions(child, counter)
        elif isinstance(node, BrPerform):
            # Conditional loops consume an ID; every loop body is searched so
            # decisions nested inside a PERFORM are reachable.
            if node.condition and node.perf_type in counted_perform_types:
                counter[0] += 1
                if counter[0] in target_decisions:
                    found.append(("PERFORM", node.condition))
            _find_decisions(node.body_seq, counter)
        elif isinstance(node, BrSearch):
            # SEARCH itself has no ID; AT END is visited before the WHENs.
            _find_decisions(node.at_end_seq, counter)
            for _, seq in node.when_list:
                _find_decisions(seq, counter)

    _find_decisions(branch_tree, [0])
    return [
        {
            "_dec_id": f"incr_{i}",
            "_kind": kind,
            "_label": str(label)[:60],
        }
        for i, (kind, label) in enumerate(found)
    ]