From 7fb930421231b1f161e077e167a72952981612ce Mon Sep 17 00:00:00 2001 From: hangshuo652 Date: Tue, 23 Jun 2026 22:38:17 +0800 Subject: [PATCH] merge local cobol_testgen improvements into v3 shared modules - cond.py: SQLCODE/SQLSTATE handling, alphanumeric >/< boundary fix - output.py: termination tracking, db_input support, _is_field_assigned filter - coverage.py: mark_from_gcov, THRU support, KeyError protection - gcov.py: new file (dependency for coverage.py) - grammar.lark: multi-segment PIC support - read.py: SQL INCLUDE resolution, DECLARE TABLE parsing, * comment fix - core.py: SQL parsing, blocked_names, keyword list - design.py: multi-sentinel, THRU ranges, PERFORM VARYING last iteration - __init__.py: local main() + v3 API functions, guarded imports All 6 ZAN programs verified passing through v3 pipeline --- cobol_testgen/__init__.py | 454 ++++++++++++++++++++++++++------- cobol_testgen/cond.py | 48 +++- cobol_testgen/core.py | 394 +++++++++++++++++++++++------ cobol_testgen/coverage.py | 119 ++++----- cobol_testgen/design.py | 498 +++++++++++++++++++++++++++++++++---- cobol_testgen/gcov.py | 119 +++++++++ cobol_testgen/grammar.lark | 2 +- cobol_testgen/output.py | 93 +++++-- cobol_testgen/read.py | 194 +++++++++++++-- 9 files changed, 1595 insertions(+), 326 deletions(-) create mode 100644 cobol_testgen/gcov.py diff --git a/cobol_testgen/__init__.py b/cobol_testgen/__init__.py index e893b83..ffb6e0f 100644 --- a/cobol_testgen/__init__.py +++ b/cobol_testgen/__init__.py @@ -1,14 +1,14 @@ """COBOL Test Data Generator — 模块化版入口 -from __future__ import annotations 公开 API: extract_structure() — 解析 COBOL 控制流 → dict generate_data() — 生成测试数据 → list[dict] incremental_supplement — 差分补充数据 → list[dict] - check_coverage() — 覆盖率报告 → dict """ +import os import sys +import json import re import logging from datetime import datetime @@ -16,25 +16,45 @@ from pathlib import Path # ── 配置(必须放在本地模块导入之前,避免循环导入) ── -CONFIG = {} +CONFIG = { + 'abend_programs': ['SUB03END'], +} from .read import preprocess, extract_data_division, extract_procedure_division -from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements, parse_file_control -from .core import classify_field_roles, _init_child_names -from .pipeline_bridge import build_branch_tree_fallback +from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements +from .read import parse_file_control, resolve_sql_includes, strip_exec_sql_from_data_div +from .core import build_branch_tree, classify_field_roles, _init_child_names, sql_register_virtual_fields, _find_multi_write_fds from .cond import parse_single_condition, is_field, collect_leaves -from .design_mcdc import enum_paths, _filter_stop -from .design import generate_records +from .pipeline_bridge import build_branch_tree_fallback +from .design_mcdc import enum_paths as mcdc_enum_paths, _filter_stop +from .design import enum_paths, generate_records, get_term_type, extend_abend_programs from .output import output_json, output_input_files -from .coverage import run_coverage, generate_coverage_index, check_coverage +from .coverage import run_coverage, generate_coverage_index from japanese_data import generate_fullwidth_text, generate_halfwidth_katakana, generate_wareki_date +try: + from .runner import run_and_compare, run_all, GroupInfo, GroupResult + _HAVE_RUNNER = True +except ImportError: + _HAVE_RUNNER = False + +try: + from .gcov import run_gcov + _HAVE_GCOV = True +except ImportError: + _HAVE_GCOV = False + +try: + from .to_sql import collect_sql_meta, build_db_input + _HAVE_TOSQL = True +except ImportError: + _HAVE_TOSQL = False + logger = logging.getLogger(__name__) -n__all__ = [ +__all__ = [ "extract_structure", "generate_data", "incremental_supplement", - "check_coverage", "CONFIG", "generate_fullwidth_text", "generate_halfwidth_katakana", @@ -107,6 +127,149 @@ def expand_occurs(fields): return result +# ── PREV 连锁 ── + + +def _constraint_in(cons, field, op, value, want): + for c in cons: + if len(c) == 4 and c[0] == field and c[1] == op and c[2] == value and c[3] == want: + return True + return False + + +def _inc_str(s, length): + try: + return str(int(s) + 1).zfill(length) + except ValueError: + c = list(str(s).ljust(length)[:length]) + for i in range(len(c) - 1, -1, -1): + if c[i] not in ' 9Zz\xff': + c[i] = chr(ord(c[i]) + 1) + break + if c[i] == ' ': + c[i] = '0' + break + if c[i] == '9': + c[i] = '0' + return ''.join(c) + + +def _dec_str(s, length): + try: + n = max(0, int(s) - 1) + return str(n).zfill(length) + except ValueError: + c = list(str(s).ljust(length)[:length]) + for i in range(len(c) - 1, -1, -1): + if c[i] not in ' 0Aa\x00': + c[i] = chr(ord(c[i]) - 1) + break + if c[i] == ' ': + break + if c[i] == '0': + c[i] = '9' + return ''.join(c) + + +def _field_length(fname, fields): + for f in fields: + if f['name'] == fname: + pi = f.get('pic_info', {}) + return pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0) or 1 + return 1 + + +def _chain_prev(records, path_infos, fields, fd_fields, field_to_fd, open_dir): + """跨记录 PREV 连锁。修改 records 使批次执行的路径与实际比较一致。 + + 每个路径 k-1 的约束(PREV OP CURRENT)对应批次中 loop iter k-1 的实际比较: + PREV = records[prev_src].R01 (程序内部保持的前值) + CURRENT = records[k].R01 (当前读入值) + 本函数调整 records[k] 的字段以保证交叉记录比较满足路径约束。 + """ + N = len(records) + if N < 2: + return + + key_fields = [] + time_start_field = None + time_end_field = None + for fname in records[0]: + if fname.startswith('R01') and not fname.startswith('R01INNREC'): + base = fname[3:] + prev_name = 'WRK-PREV-' + base + if prev_name in records[0]: + if 'EMP-ID' in fname or 'APPL-DATE' in fname: + key_fields.append(fname) + if 'END-TIME' in fname: + time_end_field = fname + if 'START-TIME' in fname: + time_start_field = fname + + prev_src = 0 + for k in range(1, N): + if k - 1 >= len(path_infos): + break + cons = path_infos[k - 1][0] + + is_same_key = all( + _constraint_in(cons, f'WRK-PREV-{fn[3:]}', '=', fn, True) + for fn in key_fields + ) if key_fields else False + is_overlap = is_same_key and time_end_field and time_start_field and \ + _constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, True) + is_normal = is_same_key and time_end_field and time_start_field and \ + (_constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '<=', time_start_field, True) or + _constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, False)) + + for fname in records[prev_src]: + if fname.startswith('R01') and not fname.startswith('R01INNREC'): + base = fname[3:] + prev_name = 'WRK-PREV-' + base + if prev_name in records[k]: + records[k][prev_name] = records[prev_src][fname] + + if is_same_key: + for kf in key_fields: + if kf in records[k] and kf in records[prev_src]: + records[k][kf] = records[prev_src][kf] + + if is_normal and time_end_field and time_start_field: + prev_end = records[prev_src].get(time_end_field, '') + curr_start = records[k].get(time_start_field, '') + if prev_end >= curr_start: + length = _field_length(time_start_field, fields) + records[k][time_start_field] = _inc_str(prev_end, length) + + if is_overlap and time_end_field and time_start_field: + prev_end = records[prev_src].get(time_end_field, '') + curr_start = records[k].get(time_start_field, '') + if prev_end <= curr_start: + length = _field_length(time_start_field, fields) + records[k][time_start_field] = _dec_str(prev_end, length) if prev_end else '0' * length + + else: + for kf in key_fields: + if kf in records[k] and kf in records[prev_src]: + if records[k][kf] == records[prev_src][kf]: + length = _field_length(kf, fields) + records[k][kf] = _inc_str(str(records[k][kf]), length) + + records[k]['_w02_path'] = is_same_key and time_end_field and time_start_field and not is_overlap + records[k]['_overlap_path'] = is_overlap + + for fn in list(records[k].keys()): + if fn.startswith('R01') and not fn.startswith('R01INNREC'): + wfn = 'W01' + fn[3:] + if wfn in records[k]: + records[k][wfn] = records[k][fn] + + if is_overlap: + pass + else: + prev_src = k + + # ── 入口 ── def main(): @@ -116,7 +279,32 @@ def main(): args = sys.argv[1:] - # 分离 cobol 文件与输出目录 + do_run = False + gcov_mode = False + temp_dir = None + if '--run' in args: + do_run = True + args.remove('--run') + if '--gcov' in args: + gcov_mode = True + args.remove('--gcov') + i = 0 + while i < len(args): + if args[i] == '--temp-dir': + if i + 1 < len(args): + temp_dir = args[i + 1] + args.pop(i + 1) + args.pop(i) + else: + args.pop(i) + break + elif args[i].startswith('--temp-dir='): + temp_dir = args[i].split('=', 1)[1] + args.pop(i) + break + else: + i += 1 + cobol_files = [] outdir = None for a in args: @@ -133,13 +321,13 @@ def main(): if outdir is None: outdir = cobol_files[0].parent - # 配置全局 Logger outdir.mkdir(parents=True, exist_ok=True) - log_path = outdir / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log" + (outdir / 'logs').mkdir(parents=True, exist_ok=True) + log_path = outdir / 'logs' / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log" fh = logging.FileHandler(log_path, encoding="utf-8", mode="w") fh.setLevel(logging.DEBUG) fh.setFormatter(logging.Formatter( -"%(asctime)s [%(levelname)s] %(name)s: %(message)s" + "%(asctime)s [%(levelname)s] %(name)s: %(message)s" )) sh = logging.StreamHandler() sh.setLevel(logging.INFO) @@ -157,12 +345,20 @@ def main(): continue source = filepath.read_text(encoding='utf-8') - source = resolve_copybooks(source, str(filepath.parent)) + source = resolve_copybooks( + source, + str(filepath.parent), + extra_search_paths=[str(filepath.parent / '..' / 'cpy')], + ) + source = resolve_sql_includes(source, str(filepath.parent)) preprocessed = preprocess(source) file_sec = parse_file_section(preprocessed) - # DATA DIVISION解析 data_div = extract_data_division(preprocessed) + if data_div: + data_div, declared_columns = strip_exec_sql_from_data_div(data_div) + else: + declared_columns = {} if not data_div: logger.error(f"错误:{filepath.name} 中没有 DATA DIVISION。") continue @@ -172,7 +368,6 @@ def main(): logger.error(f"错误:{filepath.name} 中没有找到含 PIC 的字段。") continue - # FieldDef → dict fields_dict = [] parent_pic = {} filler_counter = 0 @@ -206,7 +401,6 @@ def main(): if f.is_88: entry['is_88'] = True entry['parent'] = f.parent - # Copy parent's pic_info for value generation if f.parent and f.parent in parent_pic: entry['pic_info'] = dict(parent_pic[f.parent]) else: @@ -215,7 +409,8 @@ def main(): fields_dict = expand_occurs(fields_dict) - # Build FD→children 和 field→FD 映射 + sql_register_virtual_fields(fields_dict) + fd_fields = {} field_to_fd = {} if file_sec: @@ -245,13 +440,12 @@ def main(): pic_display = str(f.get('pic', '')) if f.get('pic') else ('88-level' if f.get('is_88') else '') logger.info(f"{f['level']:<6} {f['name']:<25} {pic_display:<15} {t:<12} {l:<5}") - # PROCEDURE DIVISION解析 proc_div = extract_procedure_division(preprocessed) branch_paths = [] assignments = {} if proc_div: - branch_tree, assignments = build_branch_tree_fallback(proc_div, fields_dict) + branch_tree, assignments = build_branch_tree(proc_div, fields_dict, full_source=preprocessed) roles = classify_field_roles(branch_tree, assignments, fields_dict, source=preprocessed, proc_text=proc_div) @@ -261,12 +455,32 @@ def main(): continue logger.info(f" {f['name']:<30} {roles.get(f['name'], '?')}") + abend_list = CONFIG.get('abend_programs', []) + if abend_list: + extend_abend_programs(abend_list) branch_paths_with_assigns = enum_paths(branch_tree, fields_dict) - branch_paths_with_assigns = [ - (_filter_stop(c), a) for c, a in branch_paths_with_assigns - ] + path_infos = [] + for c, a in branch_paths_with_assigns: + filtered_c, term = get_term_type(c) + path_infos.append((filtered_c, a, term)) + + def _is_skip(cons): + eq1_true = 0 + other = 0 + for c in cons: + if len(c) == 4 and c[0] == 'WRK-R01EOF': + val = str(c[2]).strip("'\"") + if val == '1' and c[1] == '=' and c[3]: + eq1_true += 1 + else: + other += 1 + return eq1_true > 0 and other == 0 + + before = len(path_infos) + path_infos = [p for p in path_infos if not _is_skip(p[0])] + after = len(path_infos) + logger.info(f" SKIP 过滤: {before} -> {after} 条路径(预期减少 1)") - # OPEN 方向解析 open_dir = scan_open_statements(proc_div) if proc_div else {} if proc_div: @@ -284,26 +498,104 @@ def main(): else: logger.warning("\n没有找到 PROCEDURE DIVISION。") branch_paths_with_assigns = [([], {})] + path_infos = [([], {}, 'normal')] roles = {f['name']: 'unused' for f in fields_dict} - # 覆盖率报告(传入原始源文本用于行号定位) - cov_prefix = str(outdir / filepath.stem) - index_relpath = 'coverage/index.html' - cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict, - source, cov_prefix, index_relpath=index_relpath) + records, _, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec) - records, kept_path_cons = generate_records(branch_paths_with_assigns, fields_dict, assignments, file_sec=file_sec) + def _is_eof_path(cons): + last_eq1_true = -1 + for i, c in enumerate(cons): + if len(c) == 4 and c[0] == 'WRK-R01EOF': + val = str(c[2]).strip("'\"") + if val == '1' and c[1] == '=' and c[3]: + last_eq1_true = i + if last_eq1_true < 0: + return False + for i in range(last_eq1_true + 1, len(cons)): + if len(cons[i]) == 4 and cons[i][0] == 'WRK-R01EOF': + return False + return True + eof_mask = [_is_eof_path(c) for c, a, t in path_infos] + eof_count = sum(eof_mask) + if eof_count: + term_types = ['eof' if e else t for e, t in zip(eof_mask, term_types)] + logger.info(f" EOF 路径: {eof_count} 条(将单独执行)") - # 输出 JSON(完整文件) - outpath = outdir / (filepath.stem + '.json') + multi_write_fds = _find_multi_write_fds(branch_tree, field_to_fd) if proc_div and branch_tree else set() + if multi_write_fds: + logger.info(f" 检测到多 WRITE FD: {', '.join(sorted(multi_write_fds))}") + _chain_prev(records, path_infos, fields_dict, fd_fields, field_to_fd, open_dir) + + if _HAVE_TOSQL: + sql_meta = collect_sql_meta(assignments, declared_columns) + db_input = build_db_input( + branch_paths_with_assigns, fields_dict, assignments, sql_meta, declared_columns, + records=records, + ) + else: + db_input = None + + (outdir / 'json').mkdir(parents=True, exist_ok=True) + outpath = outdir / 'json' / (filepath.stem + '.json') output_json(records, outpath, roles, fd_fields=fd_fields, field_to_fd=field_to_fd, open_dir=open_dir, - path_cons_list=kept_path_cons) + term_types=term_types, + db_input=db_input if db_input else None, + data_fields=fields_dict) - # 输出入力 JSON(按 FD 拆分) - output_input_files(records, outdir, filepath.stem, roles, - fd_fields, field_to_fd, open_dir) + output_input_files(records, outdir / 'input', filepath.stem, roles, + fd_fields, field_to_fd, open_dir, + term_types=term_types) + + gcov_data = None + if gcov_mode and proc_div and _HAVE_GCOV: + select_info = parse_file_control(preprocessed) + _temp = temp_dir or str(outdir / '.gcov_cache') + source_dir = str(filepath.parent) + expected_records: list[dict] = [{}] * len(records) + if file_sec and os.path.exists(outpath): + with open(outpath, encoding='utf-8') as f: + full_json = json.load(f) + json_records = full_json.get('records', []) + for i in range(len(records)): + exp = {} + if i < len(json_records): + json_rec = json_records[i] + for fd_name in file_sec: + eo = json_rec.get('expected_output', {}) + if fd_name in eo: + exp.update(eo[fd_name]) + expected_records[i] = exp + + group_results = run_all( + filepath.stem, str(outdir), _temp, + fields_dict, fd_fields, select_info, open_dir, + term_types, records, expected_records=expected_records, + source_dir=source_dir, path_infos=path_infos, + multi_write_fds=multi_write_fds, + ) + gcov_data = run_gcov(filepath.stem, _temp) + + passed = sum(1 for r in group_results if r.passed) + total = len(group_results) + logger.info(f"\n 执行验证: {passed}/{total} 组通过") + if passed < total: + for r in group_results: + if not r.passed and r.details: + fails = [d for d in r.details if not d.match][:3] + for d in fails: + logger.warning(f" [{r.name}] {d.field}: " + f"期望={d.expected!r}, 实际={d.actual!r}") + + if do_run and proc_div and _HAVE_RUNNER: + select_info = parse_file_control(preprocessed) + run_and_compare( + filepath.stem, str(outdir), fields_dict, + fd_fields, select_info, open_dir, + term_types, records, + ) logger.info(f"\n输出:{outpath}({len(records)} 条记录)") logger.debug(f"\n记录明细:") @@ -315,11 +607,17 @@ def main(): vals.append(f"{marker}{f['name']}={rec.get(f['name'], '?')}") logger.debug(f" 记录 {i}: {' | '.join(vals)}") + (outdir / 'coverage').mkdir(parents=True, exist_ok=True) + cov_prefix = str(outdir / 'coverage' / filepath.stem) + index_relpath = 'index.html' + cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict, + source, cov_prefix, index_relpath=index_relpath, + gcov_data=gcov_data) + programs.append(cov_result) - # 生成覆盖率总括索引页 if programs: - generate_coverage_index(programs, outdir) + generate_coverage_index(programs, outdir / 'coverage') logger.info(f"\n覆盖率总览:{outdir / 'coverage' / 'index.html'}") @@ -429,18 +727,14 @@ def extract_structure(cobol_source: str) -> dict: if m: paragraphs.add(m.group(1)) - # ── 新增字段: select_files ── select_files = parse_file_control(preprocessed) - # ── 新增字段: open_directions_detail (与 open_directions 一致) ── open_directions_detail = open_dir - # ── 新增字段: has_divide / has_inspect / has_string ── has_divide = bool(re.search(r'\bDIVIDE\b', cobol_source.upper())) has_inspect = bool(re.search(r'\bINSPECT\b', cobol_source.upper())) has_string = bool(re.search(r'\bSTRING\b', cobol_source.upper())) - # ── 新增字段: divide_constants ── divide_constants = [] if has_divide and proc_div: for dm in re.finditer(r'\bDIVIDE\s+([\d.]+)\b', proc_div, re.IGNORECASE): @@ -450,7 +744,6 @@ def extract_structure(cobol_source: str) -> dict: except ValueError: pass - # ── 新增字段: perform_patterns ── perform_patterns = [] def _walk_performs(node): @@ -478,7 +771,6 @@ def extract_structure(cobol_source: str) -> dict: if branch_tree: _walk_performs(branch_tree) - # ── 新增字段: main_loop ── main_loop = None def _find_main_loop(node, depth=0): @@ -533,7 +825,6 @@ def extract_structure(cobol_source: str) -> dict: if branch_tree: _find_main_loop(branch_tree) - # ── 新增字段: if_types ── if_types = {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0} def _walk_if_types(node, depth=0): @@ -543,7 +834,6 @@ def extract_structure(cobol_source: str) -> dict: ct = node.cond_tree if ct: leaves = collect_leaves(ct) - # Check compound: cond_tree is CondAnd or CondOr (not just CondLeaf) if isinstance(ct, (CondAnd, CondOr)): if_types["compound"] += 1 for leaf in leaves: @@ -566,7 +856,6 @@ def extract_structure(cobol_source: str) -> dict: if branch_tree: _walk_if_types(branch_tree) - # ── 新增字段: variable_patterns ── variable_patterns = { "has_prev_key": False, "has_accumulator": False, @@ -597,14 +886,12 @@ def extract_structure(cobol_source: str) -> dict: if re.search(r'[-_]W\b|[-_]WORK\b|[-_]WK\b|^WS-W[0O]\w', name, re.IGNORECASE): variable_patterns["has_work"] = True - # ── 新增字段: open_pattern ── open_pattern = "sequential" if proc_div: proc_upper = proc_div.upper() open_positions = [m.start() for m in re.finditer(r'\bOPEN\b', proc_upper)] close_positions = [m.start() for m in re.finditer(r'\bCLOSE\b', proc_upper)] if open_positions and close_positions: - # Check OPEN ... CLOSE ... OPEN sequence for i, opos in enumerate(open_positions): for cpos in close_positions: if cpos > opos: @@ -618,30 +905,29 @@ def extract_structure(cobol_source: str) -> dict: break return { -"paragraphs": sorted(paragraphs) if paragraphs else [], -"decision_points": decision_points, -"branch_tree": branch_tree, -"file_count": len(file_sec) if file_sec else 0, -"open_directions": open_dir, -"has_search_all": any('SEARCH' in str(dp.get('label', '')) for dp in decision_points), -"has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points), -"has_call": 'CALL' in cobol_source.upper(), -"has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points), -"total_branches": total_branches, -"total_paragraphs": len(paragraphs), -"branch_tree_obj": branch_tree, -# ── 新增 8 类结构特征 ── -"select_files": select_files, -"open_directions_detail": open_directions_detail, -"has_divide": has_divide, -"divide_constants": divide_constants, -"has_inspect": has_inspect, -"has_string": has_string, -"perform_patterns": perform_patterns, -"main_loop": main_loop, -"if_types": if_types, -"variable_patterns": variable_patterns, -"open_pattern": open_pattern, + "paragraphs": sorted(paragraphs) if paragraphs else [], + "decision_points": decision_points, + "branch_tree": branch_tree, + "file_count": len(file_sec) if file_sec else 0, + "open_directions": open_dir, + "has_search_all": any('SEARCH' in str(dp.get('label', '')) for dp in decision_points), + "has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points), + "has_call": 'CALL' in cobol_source.upper(), + "has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points), + "total_branches": total_branches, + "total_paragraphs": len(paragraphs), + "branch_tree_obj": branch_tree, + "select_files": select_files, + "open_directions_detail": open_directions_detail, + "has_divide": has_divide, + "divide_constants": divide_constants, + "has_inspect": has_inspect, + "has_string": has_string, + "perform_patterns": perform_patterns, + "main_loop": main_loop, + "if_types": if_types, + "variable_patterns": variable_patterns, + "open_pattern": open_pattern, } @@ -693,11 +979,12 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]: file_sec = parse_file_section(preprocessed) - branch_paths = enum_paths(branch_tree, fields_dict) - branch_paths = [(_filter_stop(c), a) for c, a in branch_paths] + branch_paths_unfiltered = mcdc_enum_paths(branch_tree, fields_dict) + path_infos = [] + for c, a in branch_paths_unfiltered: + filtered_c, term = get_term_type(c) + path_infos.append((filtered_c, a, term)) - # Filter: remove constraints whose field doesn't exist in fields_dict. - # Resolve OF-qualified names and subscripts for matching. _fdict_names = {f['name'] for f in fields_dict} def _resolve_field(fn: str) -> str: ufn = fn.upper() @@ -708,7 +995,7 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]: return m.group(1) return fn filtered_paths = [] - for cons_list, asgn in branch_paths: + for cons_list, asgn, term in path_infos: clean = [] for c in cons_list: if len(c) >= 4: @@ -718,12 +1005,11 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]: clean.append(tuple(c)) else: clean.append(c) - filtered_paths.append((clean, asgn)) - branch_paths = filtered_paths + filtered_paths.append((clean, asgn, term)) + path_infos = filtered_paths - records, kept_paths = generate_records(branch_paths, fields_dict, assignments, file_sec=file_sec) + records, kept_paths, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec) - # Cross-file KEY alignment for matching programs if records: import re as _re proc_upper = (proc_div or "").upper() diff --git a/cobol_testgen/cond.py b/cobol_testgen/cond.py index bb7fb54..9c5c3c6 100644 --- a/cobol_testgen/cond.py +++ b/cobol_testgen/cond.py @@ -44,12 +44,34 @@ def parse_single_condition(text, fields=None): - Bare: WS-EOF → (WS-EOF, '=', 'Y') - NOT bare: NOT WS-EOF → (WS-EOF, '<>', 'Y') - NOT arith: A+B NOT = C → ('A+B', '<>', 'C') + - SQLCODE: SQLCODE = 100 → ('SQLCODE', '=', '100') + - SQLSTATE: SQLSTATE <> '02000' → ('SQLSTATE', '<>', '02000') Returns None for compound (AND/OR) conditions. """ if ' AND ' in text or ' OR ' in text: return None text = text.strip() + field_name = text.split()[0] if text else '' + + # SQLCODE special handling + if field_name.upper() == 'SQLCODE': + text_upper = text.upper() + if 'GREATER THAN 0' in text_upper or 'GREATER THAN ZERO' in text_upper: + return ('SQLCODE', '>', '0') + if 'LESS THAN 0' in text_upper: + return ('SQLCODE', '<', '0') + if '= 100' in text_upper: + return ('SQLCODE', '=', '100') + if 'NOT = 100' in text_upper: + return ('SQLCODE', '<>', '100') + + # SQLSTATE special handling + if field_name.upper() == 'SQLSTATE': + normalized_sql = re.sub(r'\bNOT\s*=', '<>', text, flags=re.IGNORECASE) + m = re.match(r"SQLSTATE\s*(>=|<=|<>|>|<|=)\s*['\"]?(.+?)['\"]?\s*$", normalized_sql, re.IGNORECASE) + if m: + return ('SQLSTATE', m.group(1), m.group(2).strip().strip("'\"")) # Resolve 88-level condition names if fields: @@ -62,9 +84,9 @@ def parse_single_condition(text, fields=None): # Bare NOT field reference (no operator): NOT WS-EOF → WS-EOF <> 'Y' if text.upper().startswith('NOT ') and not re.search(r'(>=|<=|<>|>|<|=)', text): - field_name = text[4:].strip() - if re.match(r'^[A-Z][A-Z0-9-]*(?:\([^)]*\))?$', field_name, re.IGNORECASE): - return (field_name, '<>', 'Y') + fn = text[4:].strip() + if re.match(r'^[A-Z][A-Z0-9-]*(?:\([^)]*\))?$', fn, re.IGNORECASE): + return (fn, '<>', 'Y') # Normalize COBOL NOT-operators: X NOT = Y → X <> Y normalized = text @@ -292,11 +314,31 @@ def satisfying_value(field_info: dict, operator: str, value, want_true: bool) -> elif operator in ('<>', '!='): other = chr(65 + (ord(base_chr) - 64) % 26) return other.ljust(length, other) + elif operator == '>': + sv = str(value)[:length].ljust(length) + chars = list(sv) + last = chars[-1] + if last not in '9Zz': + chars[-1] = chr(ord(last) + 1) + return ''.join(chars) + elif operator == '<': + sv = str(value)[:length].ljust(length) + chars = list(sv) + last = chars[-1] + if last == ' ': + pass + elif last in '0Aa': + chars[-1] = ' ' + else: + chars[-1] = chr(ord(last) - 1) + return ''.join(chars) else: if operator in ('=', '=='): other = chr(65 + (ord(base_chr) - 64) % 26) return other.ljust(length, other) elif operator in ('<>', '!='): return base_chr.ljust(length, base_chr) + elif operator in ('>', '<'): + return str(value)[:length].ljust(length) return '0'.zfill(total) diff --git a/cobol_testgen/core.py b/cobol_testgen/core.py index 3daeffe..c63714c 100644 --- a/cobol_testgen/core.py +++ b/cobol_testgen/core.py @@ -15,16 +15,29 @@ _COBOL_SCOPE_ENDERS = { 'END-SEARCH', 'ELSE', 'WHEN', 'OTHER', } +_COBOL_KEYWORDS = { + 'GOBACK', 'EXIT', 'STOP', 'CONTINUE', + 'ACCEPT', 'DISPLAY', 'MOVE', 'COMPUTE', 'INITIALIZE', + 'ADD', 'SUBTRACT', 'MULTIPLY', 'DIVIDE', + 'STRING', 'UNSTRING', 'SET', 'INSPECT', + 'OPEN', 'CLOSE', 'READ', 'WRITE', 'REWRITE', 'DELETE', 'START', + 'PERFORM', 'CALL', 'IF', 'EVALUATE', 'SEARCH', 'SORT', 'MERGE', + 'COMMIT', 'ROLLBACK', 'GO', +} -def scan_paragraphs(raw_lines): +def scan_paragraphs(raw_lines, blocked_names=None): paragraphs = {} i = 0 + blocked = set() + if blocked_names: + for n in blocked_names: + blocked.add(n.upper()) while i < len(raw_lines): line = raw_lines[i].strip() m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', line) - sec_m = re.match(r'^([A-Z][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE) - if m and m.group(1) not in _COBOL_SCOPE_ENDERS: + sec_m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE) + if m and m.group(1) not in _COBOL_SCOPE_ENDERS and m.group(1) not in _COBOL_KEYWORDS and m.group(1) not in blocked: name = m.group(1) elif sec_m: name = sec_m.group(1).upper() @@ -36,9 +49,9 @@ def scan_paragraphs(raw_lines): while j < len(raw_lines): nline = raw_lines[j].strip() nm = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', nline) - if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS: + if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS and nm.group(1) not in _COBOL_KEYWORDS and nm.group(1) not in blocked: break - if re.match(r'^[A-Z][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE): + if re.match(r'^[A-Z0-9][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE): break j += 1 paragraphs[name] = (start, j - 1) @@ -46,9 +59,47 @@ def scan_paragraphs(raw_lines): return paragraphs -def build_branch_tree(proc_text, fields=None): +def sql_register_virtual_fields(fields_dict: list[dict]) -> list[dict]: + """Inject SQLCODE, SQLSTATE as virtual fields if not already present.""" + virtual = [] + if not any(f['name'] == 'SQLCODE' for f in fields_dict): + virtual.append({ + 'name': 'SQLCODE', + 'level': 77, 'pic': 'S9(9)', + 'pic_info': {'type': 'numeric', 'digits': 9, 'decimal': 0, + 'length': 4, 'signed': True}, + 'section': 'WORKING-STORAGE', 'is_filler': False, 'redefines': None, + 'usage': 'COMP', 'occurs': 0, 'occurs_depending': None, + 'value': None, 'values': None, + }) + if not any(f['name'] == 'SQLSTATE' for f in fields_dict): + virtual.append({ + 'name': 'SQLSTATE', + 'level': 77, 'pic': 'X(5)', + 'pic_info': {'type': 'alphanumeric', 'length': 5}, + 'section': 'WORKING-STORAGE', 'is_filler': False, 'redefines': None, + 'usage': 'DISPLAY', 'occurs': 0, 'occurs_depending': None, + 'value': None, 'values': None, + }) + fields_dict.extend(virtual) + return fields_dict + + +def build_branch_tree(proc_text, fields=None, full_source=None): raw_lines = proc_text.split('\n') - paragraphs = scan_paragraphs(raw_lines) + # Collect data names (FD names, record names, field names) to block paragraph detection + blocked_names = set() + if fields: + for f in fields: + if isinstance(f, dict): + blocked_names.add(f['name'].upper()) + else: + blocked_names.add(f.name.upper()) + # Extract FD names from full source if available (includes DATA DIVISION) + src = full_source or proc_text + for m in re.finditer(r'\bFD\s+(\w[\w-]*)\b', src, re.IGNORECASE): + blocked_names.add(m.group(1).upper()) + paragraphs = scan_paragraphs(raw_lines, blocked_names=blocked_names) first_para_name = None first_para_idx = None @@ -169,6 +220,13 @@ class _BrParser: if m_search: seq.add(self._parse_search(m_search)) continue + m_exec = re.match(r'^EXEC\s+SQL\s*$', line, re.IGNORECASE) + if m_exec: + sql_block = self._parse_sql_block() + assign_node = self._parse_sql(sql_block) + if assign_node: + seq.add(assign_node) + continue m = re.match(r'^INITIALIZE\s+', line) if m: init_seq = self._parse_initialize() @@ -192,7 +250,7 @@ class _BrParser: seq.add(self._parse_call()) continue m = re.match( - r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS))?\s*$', + r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS|SYSIN|COMMAND-LINE|SYSERR|SYSOUT|ENVIRONMENT-NAME|ENVIRONMENT-VALUE))?\s*$', line, re.IGNORECASE ) if m: @@ -211,21 +269,11 @@ class _BrParser: seq.add(Assign(tgt, info)) self.advance() # 跳过 READ 语句剩余行(AT END / NOT AT END / END-READ) - # 遇到新的语句关键词时停止,避免贪婪吞咽后续内容 - _stmt_boundary = re.compile( - r'^(IF |EVALUATE |PERFORM |SEARCH |INITIALIZE |STRING |' - r'UNSTRING |CALL |ACCEPT |READ |WRITE |REWRITE |SET |' - r'INSPECT |MOVE |COMPUTE |ADD |SUBTRACT |MULTIPLY |DIVIDE |' - r'GO\s+TO |GOBACK |STOP\s+RUN|EXIT\s|CLOSE |OPEN |DISPLAY |' - r'DELETE |START |' - r'END-IF|END-PERFORM|END-EVALUATE|END-READ)', re.IGNORECASE) while self.pos < len(self.lines): cl = self.clean() if cl in ('END-READ', 'END-READ.'): self.advance() break - if _stmt_boundary.match(cl): - break self.advance() continue m_set_false = re.match(r'^SET\s+(\w[\w-]*)\s+TO\s+FALSE\s*$', line, re.IGNORECASE) @@ -366,11 +414,34 @@ class _BrParser: else: tgt_key = tgt_base src_clean = raw_src.strip("'").strip('"') - is_field_name = self.fields and any(f['name'] == src_clean for f in self.fields) - if is_field_name: - info = {'type': 'move', 'source_vars': [src_clean]} + # 检测引用修饰 FIELD(start:length) + rm = re.match(r'^(\w[\w-]*)\(\s*(\d+)\s*:\s*(\d+)\s*\)$', src_clean, re.IGNORECASE) + if rm: + base_src = rm.group(1) + refmod_start = int(rm.group(2)) + refmod_length = int(rm.group(3)) + is_field_name = self.fields and any( + (f['name'] if isinstance(f, dict) else f.name) == base_src + for f in self.fields + ) + if is_field_name: + info = { + 'type': 'move', + 'source_vars': [base_src], + 'refmod_start': refmod_start, + 'refmod_length': refmod_length, + } + else: + info = {'type': 'move_literal', 'literal': src_clean} else: - info = {'type': 'move_literal', 'literal': src_clean} + is_field_name = self.fields and any( + (f['name'] if isinstance(f, dict) else f.name) == src_clean + for f in self.fields + ) + if is_field_name: + info = {'type': 'move', 'source_vars': [src_clean]} + else: + info = {'type': 'move_literal', 'literal': src_clean} self.assignments.setdefault(tgt_key, []).append(info) return Assign(tgt_key, info) @@ -648,40 +719,11 @@ class _BrParser: line = self.clean() m = re.match(r'^IF\s+(.+?)(?:THEN)?\s*$', line) cond_text = m.group(1).strip() - # Truncate at COBOL statement keywords (single-line IF body after condition) - _stmt_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|' - r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|' - r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH)\b') - _stmt_starts = re.compile(_stmt_pat, re.IGNORECASE) - rest = "" # remaining text after condition truncation (single-line IF body) - sm = _stmt_starts.search(cond_text) - if sm: - rest = cond_text[sm.start():] - cond_text = cond_text[:sm.start()] self.advance() - if rest: - rest = rest.strip() - if rest.endswith('.'): - rest = rest[:-1] - # Split on ELSE but keep ELSE as its own line for parse_seq boundary - else_parts = re.split(r'(\s+ELSE\s+)', rest, maxsplit=1, flags=re.IGNORECASE) - parts = [p.strip() for p in else_parts if p.strip()] - insert_parts = [] - for p in parts: - if p.upper() == 'ELSE': - insert_parts.append('ELSE') - else: - insert_parts.append(p if '.' in p else p + '.') - for part in reversed(insert_parts): - self.lines.insert(self.pos, part) # Join continuation lines (multi-line IF conditions) - _cont_keywords = (r'THEN|ELSE|END-IF|MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|' - r'DIVIDE|STRING|UNSTRING|INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|' - r'READ|WRITE|REWRITE|DELETE|START|INSPECT|SET|IF|GO\b|EXIT\b|' - r'STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH') while self.pos < len(self.lines): peek = self.clean() - if re.match(r'^(' + _cont_keywords + r')', peek, re.IGNORECASE): + if re.match(r'^(THEN|ELSE|END-IF|EXEC|MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT)', peek, re.IGNORECASE): break if peek.endswith('.'): cond_text += ' ' + peek.rstrip('.') @@ -697,16 +739,8 @@ class _BrParser: node = BrIf(cond_text) node.cond_tree = parse_compound_condition(node.condition, self.fields) node.true_seq = self.parse_seq(['ELSE', 'END-IF']) - clean = self.clean() - if clean.startswith('ELSE'): - self.advance() # consume ELSE keyword - rest = clean[4:].strip() if len(clean) > 4 else '' - # ELSE IF → reinsert IF statement as next line for recursive parse - if rest.upper().startswith('IF '): - self.lines.insert(self.pos, rest) - elif rest: - # Regular ELSE body text on same line as ELSE: reinsert - self.lines.insert(self.pos, rest if '.' in rest else rest + '.') + if self.clean() == 'ELSE': + self.advance() node.false_seq = self.parse_seq(['END-IF']) if self.clean() == 'END-IF': self.advance() @@ -728,13 +762,6 @@ class _BrParser: m = re.match(r'^WHEN\s+(.+?)\s*$', line) if m: raw_val = m.group(1).strip().strip("'").strip('"') - # Truncate at COBOL statement keywords (single-line WHEN body after condition) - _eval_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|' - r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|' - r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\b|GOBACK|CLOSE|OPEN|SEARCH)\b') - _eval_stmt = re.search(_eval_pat, raw_val, re.IGNORECASE) - if _eval_stmt: - raw_val = raw_val[:_eval_stmt.start()] self.advance() # Capture multi-line WHEN conditions (AND/OR continuation) while self.pos < len(self.lines): @@ -848,6 +875,14 @@ class _BrParser: if um: condition = um.group(1).strip() self.advance() + # Join continuation lines (AND/OR on next lines) + while self.pos < len(self.lines): + peek = self.clean() + if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE): + condition += ' ' + peek + self.advance() + else: + break break break if from_val and by_val and condition: @@ -894,6 +929,30 @@ class _BrParser: m = re.match(r'^PERFORM\s+(\w[\w-]*)\s*$', line) if m: target = m.group(1).strip() + save_pos = self.pos + condition = None + self.advance() + while self.pos < len(self.lines): + nxt = self.clean() + um = re.match(r'^UNTIL\s+(.+)$', nxt) + if um: + condition = um.group(1).strip() + self.advance() + # Join continuation lines (AND/OR on next lines) + while self.pos < len(self.lines): + peek = self.clean() + if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE): + condition += ' ' + peek + self.advance() + else: + break + break + break + if condition: + node = BrPerform('para_until', target=target, condition=condition) + self._inline_perform(node, target) + return node + self.pos = save_pos node = BrPerform('para', target=target) self.advance() self._inline_perform(node, target) @@ -962,12 +1021,18 @@ class _BrParser: parts = [self.clean()] self.advance() while self.pos < len(self.lines): + peek = self.peek() cl = self.clean() if cl == 'END-STRING': self.advance() break + # Stop when a new COBOL statement keyword is encountered + if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-STRING)', peek, re.IGNORECASE): + break parts.append(cl) self.advance() + if peek.rstrip().endswith('.'): + break full = ' '.join(parts) m = re.match(r'^STRING\s+(.+)\s+INTO\s+(\w[\w-]*)\s*$', full, re.IGNORECASE | re.DOTALL) if not m: @@ -985,12 +1050,17 @@ class _BrParser: parts = [self.clean()] self.advance() while self.pos < len(self.lines): + peek = self.peek() cl = self.clean() if cl == 'END-UNSTRING': self.advance() break + if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-UNSTRING)', peek, re.IGNORECASE): + break parts.append(cl) self.advance() + if peek.rstrip().endswith('.'): + break full = ' '.join(parts) m = re.match(r'^UNSTRING\s+(.+?)\s+INTO\s+(.+?)\s*$', full, re.IGNORECASE | re.DOTALL) if not m: @@ -1088,6 +1158,75 @@ class _BrParser: self.advance() return Assign(tgt, info) + # ── EXEC SQL parsing ── + + _RE_SELECT_INTO = re.compile( + r'SELECT\s+(.*?)\s+INTO\s+(:\w[\w-]*(?:\s*,\s*:\w[\w-]*(?::\w[\w-]*)?)*)' + r'\s+FROM\s+(\w[\w-]*)', + re.IGNORECASE + ) + + _RE_WHERE = re.compile(r'\bWHERE\b\s+(.*)', re.IGNORECASE) + + def _parse_sql_block(self) -> str: + """Consume lines from EXEC SQL until END-EXEC. Returns SQL text.""" + texts = [] + self.advance() + while self.pos < len(self.lines): + line = self.lines[self.pos].rstrip('.') + m = re.match(r'(.*?)END-EXEC\.?\s*$', line, re.IGNORECASE) + if m: + before = m.group(1).strip() + if before: + texts.append(before) + self.advance() + break + texts.append(line) + self.advance() + result = ' '.join(texts) + result = re.sub(r'\s+', ' ', result) + return result + + def _parse_sql(self, sql_text: str): + """Parse SQL text from EXEC SQL block. Returns Assign node or None.""" + m = self._RE_SELECT_INTO.search(sql_text) + if not m: + return None + + select_list = m.group(1).strip() + into_raw = m.group(2).strip() + from_table = m.group(3).strip().upper() + remaining = sql_text[m.end():].strip() + + # Parse INTO variables (handle indicator vars: :host:indicator) + into_vars = [] + for v in re.split(r'\s*,\s*', into_raw): + v = v.strip().lstrip(':') + parts = v.split(':') + into_vars.append(parts[0].upper()) + if len(parts) > 1: + into_vars.append(parts[1].upper()) + + # Extract WHERE clause + where_clause = '' + wm = self._RE_WHERE.search(remaining) + if wm: + where_clause = wm.group(1).strip() + + info = { + 'type': 'exec_sql_select', + 'table': from_table, + 'select_list': select_list, + 'into_vars': into_vars, + 'where': where_clause, + 'sql_text': sql_text, + } + + for var in into_vars: + self.assignments.setdefault(var, []).append(info) + + return Assign(into_vars[0], info) + # ── 工具函数 ── @@ -1141,8 +1280,6 @@ def trace_to_root(field_name, assignments, fields, path_assign=None): asgn = asgn_list else: asgn_list = assignments[var] - if not asgn_list: - break asgn = asgn_list[-1] if isinstance(asgn_list, list): for a in reversed(asgn_list): @@ -1152,6 +1289,8 @@ def trace_to_root(field_name, assignments, fields, path_assign=None): asgn = a break chain.append((var, asgn)) + if asgn.get('type') in ('unstring_split',): + break if not asgn.get('source_vars'): break sv = asgn['source_vars'] @@ -1332,8 +1471,36 @@ def propagate_assignments(rec, assignments, fields, file_sec=None): src = asgn['source_vars'][0] resolved_tgt = _resolve_subscript(tgt, rec) resolved_src = _resolve_subscript(src, rec) - if resolved_src in rec: - rec[resolved_tgt] = rec[resolved_src] + tgt_children = _init_child_names(resolved_tgt, fields) + if tgt_children: + # Group MOVE: propagate to child fields by position + src_children = _init_child_names(resolved_src, fields) + if src_children: + src_str = ''.join(str(rec.get(c, '')) for c in src_children) + elif resolved_src in rec: + src_str = str(rec[resolved_src]) + else: + src_str = '' + if src_str: + rec[resolved_tgt] = src_str + pos = 0 + for tgt_c in tgt_children: + child_len = 0 + for f in fields: + if f['name'] == tgt_c: + pi = f.get('pic_info', {}) + child_len = pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0) + break + if child_len > 0: + rec[tgt_c] = src_str[pos:pos + child_len] if pos < len(src_str) else ('0' if child_len else '') + pos += child_len + elif resolved_src in rec: + src_val = str(rec[resolved_src]) + if asgn.get('refmod_start') and asgn.get('refmod_length'): + start = asgn['refmod_start'] - 1 + end = start + asgn['refmod_length'] + src_val = src_val[start:end] + rec[resolved_tgt] = src_val # Pass 2: literal MOVE for tgt, asgn in flat_list: @@ -1439,9 +1606,7 @@ def propagate_assignments(rec, assignments, fields, file_sec=None): resolved_tgt = _resolve_subscript(tgt, rec) if resolved_tgt not in rec: continue - inspect_src = asgn.get('tgt', tgt) - resolved_src = _resolve_subscript(inspect_src, rec) - src_val = str(rec.get(resolved_src, '')) + src_val = str(rec[resolved_tgt]) for op_type, params in asgn.get('sub_ops', []): if op_type == 'tally': cv = params['count_var'].upper() @@ -1495,6 +1660,10 @@ def propagate_assignments(rec, assignments, fields, file_sec=None): src_var = asgn.get('source_vars', [None])[0] resolved_src = _resolve_subscript(src_var, rec) if src_var else None idx = asgn.get('index', 0) + if resolved_src and resolved_src not in rec: + children = _init_child_names(resolved_src, fields) + if children: + resolved_src = children[0] if resolved_src and resolved_src in rec: src_val = str(rec[resolved_src]) ftype = pi.get('type', 'unknown') @@ -1556,6 +1725,23 @@ def propagate_assignments(rec, assignments, fields, file_sec=None): else: rec[resolved_tgt] = val.ljust(length)[:length] if length else val + # Pass 9: EXEC SQL SELECT INTO + for tgt, asgn in flat_list: + if asgn.get('type') == 'exec_sql_select': + resolved_tgt = _resolve_subscript(tgt, rec) + if resolved_tgt not in rec: + continue + src_val = rec.get(resolved_tgt, '') + pi = pi_map.get(resolved_tgt, {}) + if pi.get('type') == 'numeric': + total = pi.get('digits', 0) + pi.get('decimal', 0) + if total > 0: + rec[resolved_tgt] = str(src_val).zfill(total) + elif pi.get('type') in ('alphanumeric', 'alphabetic'): + length = pi.get('length', 0) + if length > 0: + rec[resolved_tgt] = str(src_val).ljust(length)[:length] + # Pass 8: SET var TO TRUE (88-level) for tgt, asgn in flat_list: if asgn['type'] == 'set_true': @@ -1649,6 +1835,13 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None) elif atype == 'write_from': if tgt_base in counts: counts[tgt_base]['read'] += 1 + elif atype == 'exec_sql_select': + if tgt_base in counts: + counts[tgt_base]['write'] += 1 + for v in node.source_info.get('into_vars', []): + v_base = _basename(v) + if v_base in counts: + counts[v_base]['write'] += 1 elif atype == 'set_true': if tgt_base in counts: counts[tgt_base]['write'] += 1 @@ -1705,3 +1898,52 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None) if name not in result: result[name] = role return result + + +# ── 多 WRITE 检测 ── + + +def _collect_write_fds(node, fds_set, field_to_fd): + """Recursively collect output FD names from WRITE Assigns.""" + if isinstance(node, Assign): + st = node.source_info.get('type', '') + if st in ('write_bare', 'write_from'): + fname = node.target + if fname in field_to_fd: + fds_set.add(field_to_fd[fname]) + elif isinstance(node, BrSeq): + for c in node.children: + _collect_write_fds(c, fds_set, field_to_fd) + elif isinstance(node, BrIf): + _collect_write_fds(node.true_seq, fds_set, field_to_fd) + _collect_write_fds(node.false_seq, fds_set, field_to_fd) + elif isinstance(node, BrEval): + for _, seq in node.when_list: + _collect_write_fds(seq, fds_set, field_to_fd) + _collect_write_fds(node.other_seq, fds_set, field_to_fd) + elif isinstance(node, BrPerform): + _collect_write_fds(node.body_seq, fds_set, field_to_fd) + elif isinstance(node, BrSearch): + _collect_write_fds(node.at_end_seq, fds_set, field_to_fd) + for _, seq in node.when_list: + _collect_write_fds(seq, fds_set, field_to_fd) + + +def _find_multi_write_fds(tree, field_to_fd): + """返回在 INIT 段(主循环前)和循环内部都有 WRITE 的 FD 名集合。 + 主循环 = 顶层 BrSeq 中最后一个 UNTIL 型 BrPerform(包含 para_until)。 + """ + if not isinstance(tree, BrSeq): + return set() + main_loop_idx = -1 + for i, child in enumerate(tree.children): + if isinstance(child, BrPerform) and child.perf_type in ('until', 'para_until', 'varying', 'para_varying'): + main_loop_idx = i + if main_loop_idx < 0: + return set() + pre_write = set() + for child in tree.children[:main_loop_idx]: + _collect_write_fds(child, pre_write, field_to_fd) + loop_write = set() + _collect_write_fds(tree.children[main_loop_idx], loop_write, field_to_fd) + return pre_write & loop_write diff --git a/cobol_testgen/coverage.py b/cobol_testgen/coverage.py index f46d7b4..bdc71af 100644 --- a/cobol_testgen/coverage.py +++ b/cobol_testgen/coverage.py @@ -8,6 +8,7 @@ from pathlib import Path logger = logging.getLogger(__name__) from .models import BrSeq, BrIf, BrEval, BrPerform, BrSearch, CondLeaf from .cond import parse_single_condition, parse_compound_condition, is_field, collect_leaves, evaluate_tree +from .gcov import mark_from_gcov # ── 数据模型 ── @@ -190,11 +191,14 @@ def _mark_if(dp, cons): if _match_leaf(c, leaf): assignment[leaf] = c[3] break - if len(assignment) == len(dp.cond_leaves): - if evaluate_tree(dp.cond_tree, assignment): - dp.active_branches.add('T') - else: - dp.active_branches.add('F') + if assignment: + try: + if evaluate_tree(dp.cond_tree, assignment): + dp.active_branches.add('T') + else: + dp.active_branches.add('F') + except KeyError: + pass else: matched = 0 for leaf in dp.leaves: @@ -253,6 +257,15 @@ def _mark_eval(dp, cons, fields=None): dp.active_branches.add(name) elif c[0] == dp.label and c[1] == 'not_in': dp.active_branches.add('OTHER') + thru_lows = {c[2] for c in cons if c[0] == dp.label and c[1] == '>=' and c[3]} + thru_highs = {c[2] for c in cons if c[0] == dp.label and c[1] == '<=' and c[3]} + if thru_lows or thru_highs: + for when_val, _ in dp.when_list: + thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(when_val), re.IGNORECASE) + if thru_m and thru_m.group(1) in thru_lows and thru_m.group(2) in thru_highs: + name = f"WHEN {when_val}" + if name in dp.branch_names: + dp.active_branches.add(name) def _mark_search(dp, cons, fields=None): @@ -309,11 +322,14 @@ def _mark_perform(dp, cons): if _match_leaf(c, leaf): assignment[leaf] = c[3] break - if len(assignment) == len(dp.cond_leaves): - if evaluate_tree(dp.cond_tree, assignment): - dp.active_branches.add('Skip') - else: - dp.active_branches.add('Enter') + if assignment: + try: + if evaluate_tree(dp.cond_tree, assignment): + dp.active_branches.add('Skip') + else: + dp.active_branches.add('Enter') + except KeyError: + pass else: for c in cons: if c[0] == dp.label or any(c[0] == f for f in _get_fields_in_cond(dp.label)): @@ -330,7 +346,6 @@ def _get_fields_in_cond(cond_text): # ── 行号定位(基于原始源文本)── def locate_decision_lines(decision_points, raw_source): - """在原始源文本中搜索每个决策点的近似行号""" lines = raw_source.upper().splitlines() for dp in decision_points: patterns = _build_search_patterns(dp) @@ -344,7 +359,6 @@ def locate_decision_lines(decision_points, raw_source): def _normalize(text): - """标准化条件文本用于比较:去多余空白、标准化引号""" t = re.sub(r'\s+', ' ', text).strip() t = t.replace('"', "'") return t @@ -360,14 +374,13 @@ def _build_search_patterns(dp): texts.append((r'\bUNTIL\b', dp.condition if hasattr(dp, 'condition') else dp.label if dp.label else '')) else: - return [r'$^'] # 永不匹配 + return [r'$^'] patterns = [] for keyword, condition in texts: if not condition: continue norm_cond = _normalize(condition) - # 转义正则特殊字符,但保留空格(替换为\s+) esc = re.escape(norm_cond) esc = esc.replace(r'\ ', r'\s+') esc = esc.replace(r'\'', r"['\"]") @@ -411,7 +424,6 @@ _DETAIL_HTML = ''' }} .section h2 {{ font-size: 16px; font-weight: 600; color: #1a237e; margin-bottom: 16px; padding-bottom: 8px; border-bottom: 2px solid #e8eaf6; }} - /* 统计卡片行 */ .stats-row {{ display: flex; gap: 16px; flex-wrap: wrap; }} .stat-card {{ flex: 1; min-width: 140px; background: #f5f7fa; border-radius: 8px; padding: 14px 18px; @@ -430,7 +442,6 @@ _DETAIL_HTML = ''' .dot-red {{ background: #ffcdd2; }} .dot-amber {{ background: #fff9c4; }} - /* 进度条 */ .prog-bar-detail {{ width: 100%; height: 12px; border-radius: 6px; background: #ffcdd2; overflow: hidden; margin: 10px 0 6px 0; }} @@ -440,20 +451,17 @@ _DETAIL_HTML = ''' .prog-fill-detail.amber {{ background: linear-gradient(90deg, #ffca28, #ff8f00); }} .prog-fill-detail.red {{ background: linear-gradient(90deg, #ef5350, #ff1744); }} - /* 表格 */ table {{ width: 100%; border-collapse: collapse; table-layout: fixed; }} th, td {{ padding: 10px 14px; text-align: left; border-bottom: 1px solid #eceff1; word-break: break-all; }} th {{ background: #f5f7fa; font-weight: 600; font-size: 12px; color: #78909c; text-transform: uppercase; letter-spacing: 0.5px; }} tbody tr:hover {{ background: #e8eaf6; }} tbody tr:last-child td {{ border-bottom: none; }} - /* 决策表列宽 */ .dp-table th:nth-child(1), .dp-table td:nth-child(1) {{ width: 50px; }} .dp-table th:nth-child(2), .dp-table td:nth-child(2) {{ width: 70px; }} .dp-table th:nth-child(3), .dp-table td:nth-child(3) {{ width: 50px; }} .dp-table th:nth-child(5), .dp-table td:nth-child(5) {{ width: 160px; }} - /* 叶条件表列宽 */ .leaf-table th:nth-child(1), .leaf-table td:nth-child(1) {{ width: 110px; }} .leaf-table th:nth-child(2), .leaf-table td:nth-child(2) {{ width: 60px; }} .leaf-table th:nth-child(4), .leaf-table td:nth-child(4), @@ -468,7 +476,6 @@ _DETAIL_HTML = ''' .cond-ok {{ color: #00c853; }} .cond-miss {{ color: #ff5252; }} - /* 源码 */ .source-section {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }} .source-line {{ display: flex; padding: 1px 0; }} .source-line:hover {{ background: #f5f5f5; }} @@ -534,20 +541,22 @@ _DETAIL_HTML = ''' {source_section} + {source_note} + ''' def generate_html_report(decision_points, leaf_stats, source_lines, outpath, - filename='', index_relpath=None, covered_lines=None): + filename='', index_relpath=None, covered_lines=None, + source_note=''): title = f"覆盖率报告 — {filename}" if filename else "覆盖率报告" total_branches = sum(len(dp.branch_names) for dp in decision_points) covered_branches = sum(len(dp.active_branches) for dp in decision_points) implied_branches = sum(len(dp.implied_branches) for dp in decision_points) if covered_lines: - # 无分支程序:隐式 100% total_branches = max(total_branches, 1) covered_branches = max(covered_branches, 1) @@ -555,15 +564,13 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, covered_leaves = (sum(1 for l in leaf_stats if l.covered_true) + sum(1 for l in leaf_stats if l.covered_false)) - # 计算数值 - is_implicit = bool(covered_lines) # 无分支程序,隐式 100% + is_implicit = bool(covered_lines) dec_pct_val = (covered_branches / total_branches * 100) if total_branches else 0 dec_pct_text = "100% ✓" if is_implicit else (f"{dec_pct_val:.1f}%" if total_branches else "无") dec_frac = "全部覆盖" if is_implicit else (f"{covered_branches}/{total_branches}" if total_branches else "—") cond_frac = f"{covered_leaves}/{total_leaves}" if total_leaves else "—" implied_text = f'(+{implied_branches - covered_branches} 推断)' if implied_branches > covered_branches else '' - # 颜色 if is_implicit or not total_branches or dec_pct_val >= 100: dec_val_cls = 'val-green' bar_cls = '' @@ -581,7 +588,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, else: cond_val_cls = 'val-red' - # 决策点表格 if decision_points: dp_rows = [] for dp in decision_points: @@ -608,7 +614,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, else: decision_table = '' - # 叶条件表格 if leaf_stats: leaf_rows = [] for leaf in leaf_stats: @@ -627,7 +632,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, else: leaf_table = '' - # 源码标注 if source_lines: line_cov = {} for dp in decision_points: @@ -643,7 +647,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, else: line_cov[dp.source_line].append('hl-amber') - # 无分支程序:所有 PD 行标记为已覆盖 if covered_lines: for ln in covered_lines: line_cov.setdefault(ln, []).append('hl-green') @@ -677,6 +680,7 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath, leaf_table=leaf_table, source_section=source_section, dp_count_text=('—' if is_implicit else str(len(decision_points))), + source_note=source_note, ) outpath = Path(outpath) @@ -699,7 +703,6 @@ _INDEX_HTML = ''' background: #f0f2f5; color: #37474f; font-size: 14px; line-height: 1.6; }} - /* 顶栏 */ .topbar {{ background: linear-gradient(135deg, #1a237e, #283593); color: #fff; padding: 18px 32px; @@ -711,7 +714,6 @@ _INDEX_HTML = ''' .container {{ max-width: 1200px; margin: 0 auto; padding: 28px 24px; }} - /* 统计卡片 */ .cards {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 16px; margin-bottom: 28px; }} .card {{ background: #fff; border-radius: 10px; padding: 20px 22px; @@ -725,7 +727,6 @@ _INDEX_HTML = ''' .num-red {{ color: #ff1744; }} .num-blue {{ color: #1a237e; }} - /* 图表行 */ .charts-row {{ display: flex; gap: 32px; justify-content: center; flex-wrap: wrap; background: #fff; border-radius: 10px; padding: 28px 20px; @@ -744,7 +745,6 @@ _INDEX_HTML = ''' .legend .dot-red {{ background: #ff5252; }} .legend .dot-amber {{ background: #ffd740; }} - /* 工具栏 */ .toolbar {{ display: flex; justify-content: space-between; align-items: center; margin-bottom: 14px; flex-wrap: wrap; gap: 10px; @@ -764,7 +764,6 @@ _INDEX_HTML = ''' .toolbar .sort-btn:hover {{ background: #eceff1; }} .toolbar .sort-btn.active {{ background: #e8eaf6; border-color: #3f51b5; color: #1a237e; font-weight: 500; }} - /* 表格 */ .table-wrap {{ background: #fff; border-radius: 10px; overflow: hidden; box-shadow: 0 1px 4px rgba(0,0,0,0.06); @@ -789,7 +788,6 @@ _INDEX_HTML = ''' .prog-name a {{ color: #283593; text-decoration: none; }} .prog-name a:hover {{ text-decoration: underline; color: #1a237e; }} - /* 进度条 */ .prog-wrap {{ display: inline-flex; align-items: center; gap: 10px; width: 100%; }} @@ -812,7 +810,6 @@ _INDEX_HTML = ''' .prog-fill.full {{ border-radius: 10px; }} .prog-text {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; white-space: nowrap; min-width: 48px; }} - /* 状态徽标 */ .badge {{ display: inline-block; padding: 3px 10px; border-radius: 12px; font-size: 12px; font-weight: 600; letter-spacing: 0.3px; @@ -821,10 +818,8 @@ _INDEX_HTML = ''' .badge-warn {{ background: #fff8e1; color: #e65100; }} .badge-fail {{ background: #ffebee; color: #c62828; }} - /* 条件覆盖列 */ .cond-cell {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }} - /* 响应式 */ @media (max-width: 680px) {{ .topbar {{ flex-direction: column; align-items: flex-start; gap: 6px; padding: 14px 18px; }} .container {{ padding: 16px 12px; }} @@ -968,7 +963,6 @@ function filterTable() {{ def _ring_svg(pct, color_stops): - """生成 SVG 圆环 HTML。pct: 0-100 浮点数。""" r = 54 circ = 2 * 3.14159265 * r offset = circ * (1 - pct / 100) if pct > 0 else circ @@ -995,7 +989,6 @@ def _ring_svg(pct, color_stops): def generate_coverage_index(programs, outdir): - """生成覆盖率总括索引页。""" from datetime import datetime timestamp = datetime.now().strftime('%Y-%m-%d %H:%M') @@ -1038,7 +1031,6 @@ def generate_coverage_index(programs, outdir): cond_text = f"{cc}/{tc}" if tc else "—" bar_pct = int(pct_dec) - # 进度条颜色 if imp or pct_dec >= 100: bar_cls = '' elif pct_dec >= 80: @@ -1046,7 +1038,6 @@ def generate_coverage_index(programs, outdir): else: bar_cls = ' red' - # 状态徽标 if tb == 0 or (cb == tb and not (ib > cb)): badge = '✓ 完全' elif cb == tb and ib > cb: @@ -1056,7 +1047,6 @@ def generate_coverage_index(programs, outdir): else: badge = '✗ 欠缺' - # 条件覆盖数字颜色 if tc: cond_pct = cc / tc * 100 cond_color = 'num-green' if cond_pct == 100 else ('num-amber' if cond_pct >= 80 else 'num-red') @@ -1107,7 +1097,6 @@ def generate_coverage_index(programs, outdir): # ── PROCEDURE DIVISION 行范围定位(用于无分支程序标记)── def _find_proc_range(raw_source: str): - """返回 PROCEDURE DIVISION 的行范围 (start_line, end_line) 1-indexed,或 None。""" lines = raw_source.splitlines() proc_start = None for i, line in enumerate(lines): @@ -1116,26 +1105,36 @@ def _find_proc_range(raw_source: str): break if proc_start is None: return None - # 找下一个 DIVISION 作为结束边界(或文件尾) for i in range(proc_start, len(lines)): if re.search(r'(IDENTIFICATION|DATA|ENVIRONMENT)\s+DIVISION', lines[i].upper()): - return (proc_start, i) # 不包含下一个 DIVISION + return (proc_start, i) return (proc_start, len(lines) + 1) # ── 接入入口 ── def run_coverage(branch_tree, branch_paths_with_assigns, fields, - raw_source, output_prefix, index_relpath=None): - """完整覆盖率流程:收集 → 标记 → 定位 → 输出。 - - Returns: - dict: 汇总数据,用于总括页聚合 - """ + raw_source, output_prefix, index_relpath=None, + gcov_data=None): decision_points, leaf_stats = collect_decision_points(branch_tree, fields) mark_coverage(decision_points, leaf_stats, branch_paths_with_assigns, fields) + if gcov_data: + mark_from_gcov(decision_points, gcov_data, branch_tree) + for ls in leaf_stats: + ls.covered_true = False + ls.covered_false = False + + _source_note = '' + if gcov_data: + _source_note = ( + '
' + '覆盖率基于 gcov 运行时数据' + '
' + ) + if raw_source: locate_decision_lines(decision_points, raw_source) @@ -1146,7 +1145,6 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields, sum(1 for l in leaf_stats if l.covered_false)) leaf_total = len(leaf_stats) * 2 - # 无决策点但有路径 → PROCEDURE DIVISION 全部覆盖 covered_lines = set() if total == 0 and branch_paths_with_assigns and raw_source: proc_range = _find_proc_range(raw_source) @@ -1161,9 +1159,9 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields, f"{output_prefix}_coverage.html", Path(output_prefix).stem, index_relpath=index_relpath, - covered_lines=covered_lines) + covered_lines=covered_lines, + source_note=_source_note) - # 控制台摘要 if total or leaf_total: logger.info(f"\n=== 分支覆盖率 ===") if covered_lines and not decision_points: @@ -1194,7 +1192,7 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields, implicit_100 = bool(covered_lines) return { 'name': Path(output_prefix).stem if output_prefix else '', - 'detail_relpath': ('../' + Path(output_prefix).stem + '_coverage.html' + 'detail_relpath': (Path(output_prefix).stem + '_coverage.html' if output_prefix else ''), 'total_branches': total, 'covered_branches': covered, @@ -1208,15 +1206,6 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields, def check_coverage(structure: dict, test_records: list[dict]) -> dict: - """报告 COBOL 源码的静态分支结构信息。 - - 注意: 静态分析无法精确判断每条测试数据运行时覆盖了哪些分支。 - 精确的路径追踪依赖 gcov(Phase 3)。此处仅报告总分支数和记录生成情况。 - - Returns: - dict with: paragraph_rate, branch_rate, decision_rate, total_branches, - total_paragraphs, records_count, note - """ total_paragraphs = structure.get("total_paragraphs", 0) total_branches = structure.get("total_branches", 0) decision_points = structure.get("decision_points", []) diff --git a/cobol_testgen/design.py b/cobol_testgen/design.py index a346b36..b8ccb3a 100644 --- a/cobol_testgen/design.py +++ b/cobol_testgen/design.py @@ -8,12 +8,52 @@ from .core import trace_to_root, invert_through_chain, propagate_assignments, _b logger = logging.getLogger(__name__) -_STOP = ('__STOP__', '', None, True) -_MAX_PATHS = 500 +_STOP_EXIT_PERFORM = ('__STOP_EXIT_PERFORM__', '', None, True) +_STOP_SENTINEL = ('__STOP__', '', None, True) +_ABEND_SENTINEL = ('__ABEND__', '', None, True) +_SENTINELS_ALL = {_STOP_EXIT_PERFORM, _STOP_SENTINEL, _ABEND_SENTINEL} +_ABEND_PROGRAMS = {'ABENDPGM'} + +def extend_abend_programs(names: list[str]): + _ABEND_PROGRAMS.update(n.upper() for n in names) +_MAX_PATHS = 10000 + + +def _is_sentinel(c): + return c is _STOP_EXIT_PERFORM or c is _STOP_SENTINEL or c is _ABEND_SENTINEL + + +def _hashable_cons(cons): + """将约束列表转为可哈希形式(列表值转tuple)用于签名去重。""" + result = [] + for c in cons: + if len(c) == 4: + field, op, val, want = c + if isinstance(val, list): + val = tuple(val) + result.append((field, op, val, want)) + else: + result.append(c) + return result def _filter_stop(cons): - return [c for c in cons if c is not _STOP] + """Legacy: strip all sentinel markers. 供旧测试代码使用。""" + return [c for c in cons if not _is_sentinel(c)] + + +def get_term_type(cons): + """提取终止类型,返回 (filtered_cons, term_type).""" + remaining = [] + term = 'normal' + for c in cons: + if c is _ABEND_SENTINEL: + term = 'abend' + elif _is_sentinel(c): + pass + else: + remaining.append(c) + return remaining, term def _cap_paths(paths): @@ -29,11 +69,11 @@ def _cap_paths_fair(new_active, child_paths): k = len(child_paths) if k <= 1: return new_active[:_MAX_PATHS] - # 分离 STOP 路径(不参与组合,直接保留) - stop_paths = [(p, a) for p, a in new_active if any(c is _STOP for c in p)] - combined = [(p, a) for p, a in new_active if not any(c is _STOP for c in p)] + # 分离 sentinel 路径(不参与组合,直接保留) + stop_paths = [(p, a) for p, a in new_active if any(_is_sentinel(c) for c in p)] + combined = [(p, a) for p, a in new_active if not any(_is_sentinel(c) for c in p)] n_pred = len(combined) // k - result = list(stop_paths) + result = [] if n_pred <= 1: result.extend(combined[:_MAX_PATHS - len(result)]) return result[:_MAX_PATHS] @@ -75,24 +115,29 @@ def enum_paths(node, fields): for child in node.children: child_paths = _cap_paths(enum_paths(child, fields)) if not child_paths: - break + continue new_active = [] + covered_sigs = set() for p_cons, p_assign in paths: - if any(c is _STOP for c in p_cons): + if any(_is_sentinel(c) for c in p_cons): new_active.append((p_cons, p_assign)) continue for cp_cons, cp_assign in child_paths: - merged = {} - for d in (p_assign, cp_assign): - for k, v in d.items(): - merged.setdefault(k, []).extend(v if isinstance(v, list) else [v]) merged_cons = p_cons + list(cp_cons) - new_active.append((merged_cons, merged)) - if len(new_active) >= _MAX_PATHS: + sig = frozenset(_hashable_cons(merged_cons)) + if sig not in covered_sigs: + covered_sigs.add(sig) + merged = {} + for d in (p_assign, cp_assign): + for k, v in d.items(): + merged.setdefault(k, []).extend(v if isinstance(v, list) else [v]) + new_active.append((merged_cons, merged)) + if not new_active: + for pc, pa in paths: + if not any(_is_sentinel(c) for c in pc): + new_active.append((pc, dict(pa))) break - if len(new_active) >= _MAX_PATHS: - break - paths = _cap_paths_fair(new_active, child_paths) + paths = new_active return paths elif isinstance(node, BrIf): @@ -186,6 +231,14 @@ def enum_paths(node, fields): constraints.append((cond.field, cond.op, cond.value, True)) paths.append((constraints + sp_cons, sp_assign)) prior_false_sets.append([(cond.field, cond.op, cond.value, False)]) + elif cond and isinstance(cond, CondNot) and isinstance(cond.child, CondLeaf) and is_field(cond.child.field, fields): + leaf = cond.child + sub = _cap_paths(enum_paths(seq, fields)) + for sp_cons, sp_assign in (sub or [([], {})]): + constraints = [c for pf in prior_false_sets for c in pf] + constraints.append((leaf.field, leaf.op, leaf.value, False)) + paths.append((constraints + sp_cons, sp_assign)) + prior_false_sets.append([(leaf.field, leaf.op, leaf.value, True)]) elif cond: leaves = collect_leaves(cond) if leaves and all(is_field(l.field, fields) for l in leaves): @@ -232,13 +285,36 @@ def enum_paths(node, fields): paths = [] for value, seq in node.when_list: sub = _cap_paths(enum_paths(seq, fields)) - for sp_cons, sp_assign in (sub or [([], {})]): - paths.append(([(node.subject, '=', value, True)] + sp_cons, sp_assign)) + thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(value), re.IGNORECASE) + if thru_m and not node.subjects: + low, high = thru_m.group(1), thru_m.group(2) + for sp_cons, sp_assign in (sub or [([], {})]): + paths.append(([(node.subject, '>=', low, True), (node.subject, '<=', high, True)] + sp_cons, sp_assign)) + paths.append(([(node.subject, '<=', high, True), (node.subject, '>=', low, True)] + sp_cons, sp_assign)) + else: + for sp_cons, sp_assign in (sub or [([], {})]): + paths.append(([(node.subject, '=', value, True)] + sp_cons, sp_assign)) if node.has_other: - case_vals = [v for v, _ in node.when_list] sub = _cap_paths(enum_paths(node.other_seq, fields)) - for sp_cons, sp_assign in (sub or [([], {})]): - paths.append(([(node.subject, 'not_in', case_vals, True)] + sp_cons, sp_assign)) + thru_found = False + for v, _ in node.when_list: + thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(v), re.IGNORECASE) + if thru_m and not node.subjects: + thru_found = True + low_int, high_int = int(thru_m.group(1)), int(thru_m.group(2)) + for sp_cons, sp_assign in (sub or [([], {})]): + a_low = dict(sp_assign) + a_low[node.subject] = [{'type': 'move_literal', 'literal': str(max(0, low_int - 1))}] + low_cons = [(node.subject, 'not_in', [thru_m.group(1), thru_m.group(2)], True)] + paths.append((low_cons + sp_cons, a_low)) + a_high = dict(sp_assign) + a_high[node.subject] = [{'type': 'move_literal', 'literal': str(high_int + 1)}] + high_cons = [(node.subject, 'not_in', [thru_m.group(1), thru_m.group(2)], True)] + paths.append((high_cons + sp_cons, a_high)) + if not thru_found: + case_vals = [v for v, _ in node.when_list] + for sp_cons, sp_assign in (sub or [([], {})]): + paths.append(([(node.subject, 'not_in', case_vals, True)] + sp_cons, sp_assign)) return paths elif isinstance(node, BrSearch): @@ -247,7 +323,10 @@ def enum_paths(node, fields): elif isinstance(node, BrPerform): if node.perf_type in ('para', 'thru'): if node.body_seq: - return enum_paths(node.body_seq, fields) + paths = enum_paths(node.body_seq, fields) + # EXIT PERFORM 只在 PERFORM 体内有效,剥离后不影响后续 BrSeq 组合 + paths = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in paths] + return paths return [([], {})] elif node.perf_type in ('until', 'para_until', 'varying', 'para_varying'): # 尝试单条件(现有逻辑) @@ -256,7 +335,9 @@ def enum_paths(node, fields): field, op, val = parsed paths = [] false_sub = _cap_paths(enum_paths(node.body_seq, fields)) + false_sub = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in false_sub] for sp_cons, sp_assign in (false_sub or [([], {})]): + body_assign = dict(sp_assign) # PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径 if node.varying_from and node.varying_var: is_fld = any(f['name'] == node.varying_from for f in fields) if fields else False @@ -268,6 +349,40 @@ def enum_paths(node, fields): merged.setdefault(k, []).extend(v if isinstance(v, list) else [v]) sp_assign = merged paths.append(([(field, op, val, False)] + sp_cons, sp_assign)) + # PERFORM VARYING: 末次迭代路径(下标=MAX) + if node.varying_from and node.varying_var and op in ('>', '>=', '<', '<=', '='): + try: + if op == '>': + max_val = int(val) + elif op == '>=': + max_val = int(val) - 1 + elif op == '<': + max_val = int(val) + elif op == '<=': + max_val = int(val) + 1 + elif op == '=': + by_str = str(node.varying_by or '1') + if by_str.lstrip('-').isdigit() and int(by_str) < 0: + max_val = int(val) + 1 + else: + max_val = int(val) - 1 + from_val = int(node.varying_from) + by_str = str(node.varying_by or '1') + if by_str.lstrip('-').isdigit() and int(by_str) < 0: + ok = max_val <= from_val + else: + ok = max_val >= from_val + if ok: + max_asgn = {'type': 'move_literal', 'literal': str(max_val)} + max_assign = {node.varying_var: [max_asgn]} + merged_max = {} + for d in (max_assign, body_assign): + for k, v in d.items(): + merged_max.setdefault(k, []).extend(v if isinstance(v, list) else [v]) + the_cons = [(field, op, val, False)] + paths.append((the_cons + sp_cons, merged_max)) + except (ValueError, TypeError): + pass paths.append(([(field, op, val, True)], {})) return paths # 尝试复合条件(AND/OR) @@ -279,6 +394,7 @@ def enum_paths(node, fields): if sets: paths = [] false_sub = _cap_paths(enum_paths(node.body_seq, fields)) + false_sub = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in false_sub] for sp_cons, sp_assign in (false_sub or [([], {})]): # PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径 if node.varying_from and node.varying_var: @@ -301,14 +417,18 @@ def enum_paths(node, fields): return [([], {})] elif isinstance(node, CallNode): + if node.program_name in _ABEND_PROGRAMS: + return [([_ABEND_SENTINEL], {})] return [([], {})] elif isinstance(node, ExitNode): - return [([_STOP], {})] + if node.exit_type == 'PERFORM': + return [([_STOP_EXIT_PERFORM], {})] + return [([_STOP_SENTINEL], {})] elif isinstance(node, GoTo): paths = enum_paths(node.body_seq, fields) - return [([_STOP] + c, a) for c, a in paths] + return [([_STOP_SENTINEL] + c, a) for c, a in paths] return [([], {})] @@ -335,7 +455,7 @@ def seq_date(seq_num: int) -> str: def _is_date_field(name: str) -> bool: - patterns = [r'DATE', r'YYMMDD', r'YYYYMM', r'YEAR', r'MONTH', r'DAY'] + patterns = [r'DATE', r'YYMMDD', r'YYYYMM'] for p in patterns: if re.search(p, name.upper()): return True @@ -401,13 +521,12 @@ def _children_of(group_name: str, fields: list) -> list: def _make_numeric_value(idx: int, record_num: int, total_digits: int) -> str: - max_val = 10 ** total_digits - 1 + max_val = 10 ** total_digits for step in (100, 10, 1): val = idx * step + record_num - if val < 10 ** total_digits: - return str(min(val, max_val)).zfill(total_digits) - return str(min(record_num, max_val)).zfill(total_digits) - return str(record_num).zfill(total_digits) + if val < max_val: + return str(val).zfill(total_digits) + return str(record_num % max_val).zfill(total_digits) def _make_alpha_value(idx: int, record_num: int, length: int) -> str: @@ -548,6 +667,16 @@ def _check_constraint_satisfied(rec, field_name, operator, value, want_true, fie return eq == want_true elif operator == '<>': return (not eq) == want_true + elif operator in ('>', '<', '>=', '<='): + if operator == '>': + ok = s_val > s_target + elif operator == '<': + ok = s_val < s_target + elif operator == '>=': + ok = s_val >= s_target + elif operator == '<=': + ok = s_val <= s_target + return ok == want_true return True return False @@ -625,6 +754,95 @@ def _apply_arith_constraint(rec, field_name, operator, value, want_true, fields) rec[right_field] = pick +def _inc_str(s, length): + s = str(s).strip() + try: + r = str(int(s) + 1).zfill(length) + return r if len(r) <= length else '9' * length + except ValueError: + c = list(str(s).ljust(length)[:length]) + for i in range(len(c) - 1, -1, -1): + if c[i] not in ' 9Zz\xff': + c[i] = chr(ord(c[i]) + 1) + break + if c[i] == ' ': + c[i] = '0' + break + if c[i] == '9': + c[i] = '0' + elif c[i] == 'Z': + c[i] = 'A' + elif c[i] == 'z': + c[i] = 'a' + return ''.join(c) + + +def _dec_str(s, length): + s = str(s).strip() + try: + n = max(0, int(s) - 1) + return str(n).zfill(length) + except ValueError: + c = list(str(s).ljust(length)[:length]) + for i in range(len(c) - 1, -1, -1): + if c[i] not in ' 0Aa\x00': + c[i] = chr(ord(c[i]) - 1) + break + if c[i] == ' ': + break + if c[i] == '0': + c[i] = '9' + elif c[i] == 'A': + c[i] = ' ' + elif c[i] == 'a': + c[i] = ' ' + return ''.join(c) + + +def _reconcile_unstring_fields(rec, left_field, operator, right_field, want_true, + fields, left_chain, assignments, path_assign): + right_root, right_chain = trace_to_root(right_field, assignments, fields, path_assign) + if right_root not in rec: + logger.debug(f"字段间比较协调:右侧根 {right_root} 不在 rec,跳过") + return + all_entries = (left_chain or []) + (right_chain or []) + for _, asgn in all_entries: + if asgn.get('type') not in ('move', 'unstring_split'): + logger.debug(f"字段间比较协调:链含非 MOVE 类型 {asgn.get('type')},跳过") + return + left_val = str(rec.get(left_field, '')) + if not left_val.strip(): + logger.debug(f"字段间比较协调:左侧 {left_field} 无值,跳过") + return + length = 0 + for f in fields: + if f['name'] == right_root: + length = f.get('pic_info', {}).get('length', 0) + break + if length == 0: + length = len(left_val) + + if operator in ('>=', '<='): + if want_true: + right_val = left_val + else: + right_val = _inc_str(left_val, length) if operator == '>=' else _dec_str(left_val, length) + elif operator in ('>', '<'): + if want_true: + right_val = _dec_str(left_val, length) if operator == '>' else _inc_str(left_val, length) + else: + right_val = left_val + elif operator == '=': + right_val = left_val if want_true else _inc_str(left_val, length) + elif operator == '<>': + right_val = _inc_str(left_val, length) if want_true else left_val + else: + return + + rec[right_root] = right_val[:length] if right_val else right_val + logger.debug(f"字段间比较协调:{left_field}={left_val} {operator} {right_field} -> {right_root}={rec[right_root]} (want={want_true})") + + def apply_constraint(rec, field_name, operator, value, want_true, fields, assignments=None, path_assign=None): # 标准化字段名:去除括号内空格(WS-CELL ( 1, 1 ) → WS-CELL(1,1)) field_name = re.sub(r'\s*([(),])\s*', r'\1', field_name) @@ -659,6 +877,7 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign apply_constraint(rec, parent_name, operator, value, want_true, fields, assignments, path_assign) return break + chain = None if assignments: root_var, chain = trace_to_root(field_name, assignments, fields, path_assign) if root_var != field_name: @@ -666,8 +885,41 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign if any(f['name'] == new_field_name for f in fields): field_name, operator, value = new_field_name, new_op, new_val + # 字段间比较:在 satisfied check 前解析/处理 + if any(f['name'] == value for f in fields): + resolved_literal = None + for f in fields: + if f['name'] == value and f.get('value') is not None: + resolved_literal = str(f['value']).strip("'").strip('"') + break + if resolved_literal is not None: + value = resolved_literal + elif chain is not None and assignments: + _reconcile_unstring_fields(rec, field_name, operator, value, want_true, + fields, chain, assignments, path_assign) + return + elif re.search(r'[+\-*/]', field_name): + _apply_arith_constraint(rec, field_name, operator, value, want_true, fields) + return + else: + logger.debug(f"字段间比较约束跳过:{field_name} {operator} {value}") + return + # 如果当前值已满足该约束,跳过覆盖(保持先前约束的一致性) + # 但零值时强制使用边界值(非 0/非 min) if _check_constraint_satisfied(rec, field_name, operator, value, want_true, fields): + cur = str(rec.get(field_name, '')).strip('0') + if (cur == '' or cur == '.') and ( + (operator in ('>', '>=') and not want_true) or + (operator in ('<', '<=') and want_true) + ): + for f in fields: + if f['name'] == field_name: + pi = f.get('pic_info', {}) + if pi.get('type') == 'numeric': + val = satisfying_value(pi, operator, value, want_true) + rec[field_name] = val + return return if operator == 'not_in': @@ -687,13 +939,6 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign rec[field_name] = str(n).zfill(pi.get('digits', 0) + pi.get('decimal', 0)) return return - # 字段间比较(值侧也是字段名) - if any(f['name'] == value for f in fields): - if re.search(r'[+\-*/]', field_name): - _apply_arith_constraint(rec, field_name, operator, value, want_true, fields) - else: - logger.debug(f"字段间比较约束跳过:{field_name} {operator} {value}") - return for f in fields: if f['name'] == field_name: pi = f.get('pic_info', {}) @@ -738,6 +983,31 @@ def sync_redefined_fields(rec, fields): def apply_occurs_depending(rec, fields): """根据 OCCURS DEPENDING ON 变量的当前值,清零超范围的下标字段。""" + # Phase 1: 将零值的 DEPENDING ON 变量设为最大下标 + dep_max = {} + for f in fields: + dep_var = f.get('occurs_depending') + if not dep_var: + continue + m = re.search(r'\((\d+)\)$', f['name']) + if m: + sub = int(m.group(1)) + if sub > dep_max.get(dep_var, 0): + dep_max[dep_var] = sub + for dep_var, max_sub in dep_max.items(): + try: + cur_val = int(float(str(rec.get(dep_var, '0')))) + except (ValueError, TypeError): + cur_val = 0 + if cur_val == 0: + for f in fields: + if f['name'] == dep_var: + pi = f.get('pic_info', {}) + digits = pi.get('digits', 0) + pi.get('decimal', 0) + if digits > 0: + rec[dep_var] = str(max_sub).zfill(digits) + break + # Phase 2: 清零超范围的下标字段 for f in fields: dep_var = f.get('occurs_depending') if not dep_var: @@ -805,7 +1075,10 @@ def _enum_search_paths(node, fields): base = re.sub(r'\s*\(.*?\)\s*$', '', cond_tree.field) matching_val = cond_tree.value elem_key = f'{base}({i + 1})' - extra_assign[elem_key] = [{'type': 'move_literal', 'literal': matching_val}] + if any(f['name'] == matching_val for f in fields): + extra_assign[elem_key] = [{'type': 'move', 'source_vars': [matching_val]}] + else: + extra_assign[elem_key] = [{'type': 'move_literal', 'literal': matching_val}] non_match = _non_match_for(cond_tree, fields) or ' ' for j in range(i): prev_key = f'{base}({j + 1})' @@ -815,7 +1088,10 @@ def _enum_search_paths(node, fields): merged_assign = dict(extra_assign) for k, v in sp_assign.items(): merged_assign.setdefault(k, []).extend(v if isinstance(v, list) else [v]) - paths.append((sp_cons, merged_assign)) + if cond_tree and isinstance(cond_tree, CondLeaf): + paths.append(([(elem_key, cond_tree.op, matching_val, True)] + sp_cons, merged_assign)) + else: + paths.append((sp_cons, merged_assign)) if node.has_at_end: sub = _cap_paths(enum_paths(node.at_end_seq, fields)) @@ -837,16 +1113,20 @@ def _enum_search_paths(node, fields): return paths -def generate_records(branch_paths_with_assigns, data_fields, base_assignments=None, file_sec=None): +def generate_records(path_infos, data_fields, base_assignments=None, file_sec=None): """生成测试数据记录。 - branch_paths_with_assigns: list of (constraints, path_assignments). + path_infos: list of (constraints, path_assignments) 或 (constraints, path_assignments, term_type). base_assignments: 全局 assignments dict (用于 trace_to_root). - 返回: (records, kept_path_cons) — kept_path_cons 是与 records 一一对应的约束。 + 返回: (records, kept_path_cons, term_types). """ + # 自动兼容旧 2-tuple 格式 + if path_infos and len(path_infos[0]) == 2: + path_infos = [(c, a, 'normal') for c, a in path_infos] records = [] kept_path_cons = [] - if branch_paths_with_assigns: - for seq, (path_cons, path_assign) in enumerate(branch_paths_with_assigns, start=1): + term_types = [] + if path_infos: + for seq, (path_cons, path_assign, term_type) in enumerate(path_infos, start=1): path_cons = _filter_stop(path_cons) rec = make_base_record(seq, data_fields) # Pass A: 先传播赋值(MOVE/COMPUTE/READ INTO 等),模拟到决策点前的程序状态 @@ -869,6 +1149,26 @@ def generate_records(branch_paths_with_assigns, data_fields, base_assignments=No if not _check_constraint_satisfied(rec, root_var, new_op, new_val, want, data_fields): skip_impossible = True break + elif field in rec: + asgn_val = path_assign.get(field) + if asgn_val is not None: + asgn_list = asgn_val if isinstance(asgn_val, list) else [asgn_val] + if asgn_list and asgn_list[-1]['type'] == 'move_literal': + cur_val = str(rec.get(field, '')) + if cur_val != '': + pi = next((f.get('pic_info', {}) for f in data_fields if f['name'] == field), {}) + if pi.get('type') == 'numeric': + try: + nv = int(float(cur_val)) + tv = int(float(str(val))) + ops = {'>': lambda a,b: a > b, '<': lambda a,b: a < b, '=': lambda a,b: a == b, '<>': lambda a,b: a != b, '>=': lambda a,b: a >= b, '<=': lambda a,b: a <= b} + if op in ops: + satisfied = ops[op](nv, tv) == want + if not satisfied: + skip_impossible = True + break + except (ValueError, TypeError): + pass if skip_impossible: continue # Pass B: 约束覆盖(确保决策条件满足,覆盖 MOVE 带来的值) @@ -886,17 +1186,121 @@ def generate_records(branch_paths_with_assigns, data_fields, base_assignments=No forward[tgt] = filtered if forward: propagate_assignments(rec, forward, data_fields, file_sec=file_sec) + # Pass B.75: COMPUTE 重算(约束修改了 COMPUTE 源字段的值) + if isinstance(path_assign, dict): + compute_only = {} + for tgt, asgn_val in path_assign.items(): + asgn_list = asgn_val if isinstance(asgn_val, list) else [asgn_val] + filtered = [a for a in asgn_list if a['type'] == 'compute'] + if filtered: + compute_only[tgt] = filtered + if compute_only: + propagate_assignments(rec, compute_only, data_fields, file_sec=file_sec) + # Pass B.8: UNSTRING source reconstruction (targets → source) + if base_assignments: + _reconstruct_unstring_sources(rec, base_assignments, data_fields) # Pass C: 同步 REDEFINES(确保共享存储一致) sync_redefined_fields(rec, data_fields) # Pass D: OCCURS DEPENDING ON — 清零超范围的下标字段 apply_occurs_depending(rec, data_fields) + # Pass E: PIC 长度约束 — 模拟 COBOL 截断语义 + for f in data_fields: + name = f['name'] + if name in rec and not f.get('is_88') and not f.get('is_filler'): + pi = f.get('pic_info', {}) + ftype = pi.get('type', 'unknown') + val = str(rec[name]) + if ftype == 'numeric': + total = pi.get('digits', 0) + pi.get('decimal', 0) + if total > 0 and len(val) > total: + rec[name] = val[-total:].zfill(total) + elif ftype in ('alphanumeric', 'alphabetic'): + length = pi.get('length', 0) + if length > 0 and len(val) > length: + rec[name] = val[:length] + records.append(rec) kept_path_cons.append(path_cons) + term_types.append(term_type) + # Track which fields were explicitly assigned in this path + if isinstance(path_assign, dict): + rec['_assigned_fields'] = set(path_assign.keys()) + else: + rec['_assigned_fields'] = set() if not records: rec = make_base_record(1, data_fields) if base_assignments: propagate_assignments(rec, base_assignments, data_fields, file_sec=file_sec) + if base_assignments: + _reconstruct_unstring_sources(rec, base_assignments, data_fields) + rec['_assigned_fields'] = set() records.append(rec) kept_path_cons.append([]) - return records, kept_path_cons + term_types.append('normal') + return records, kept_path_cons, term_types + + +def _reconstruct_unstring_sources(rec, base_assignments, data_fields): + """Build UNSTRING source field value from comma-separated target values. + After constraints determine target field values, construct the source + string so the COBOL UNSTRING can correctly parse it. + """ + groups = {} + for tgt, asgn_list in base_assignments.items(): + for asgn in asgn_list: + if asgn.get('type') == 'unstring_split' and asgn.get('source_vars'): + src = asgn['source_vars'][0] + idx = asgn.get('index', 0) + groups.setdefault(src, []).append((idx, tgt)) + + for src_var, targets in groups.items(): + targets.sort(key=lambda x: x[0]) + # Resolve group→child name if source not directly in rec + resolved_src = src_var + if resolved_src not in rec: + grp_level = None + found = False + for f in data_fields: + if not found and f['name'] == resolved_src: + grp_level = f.get('level', 0) + found = True + continue + if found: + if f.get('level', 0) <= grp_level or f.get('level') == 77: + break + if f.get('pic'): + resolved_src = f['name'] + break + if resolved_src not in rec: + continue + csv_parts = [] + for idx, tgt in targets: + val = rec.get(tgt, '') + csv_parts.append(val if val is not None else '') + csv_value = ','.join(csv_parts) + src_len = 0 + for f in data_fields: + if f['name'] == resolved_src: + pi = f.get('pic_info', {}) + if pi: + src_len = pi.get('length', 0) + break + if src_len > 0: + csv_value = csv_value.ljust(src_len)[:src_len] + rec[resolved_src] = csv_value + # Also sync to child fields (group→elementary) for FD output consistency + if resolved_src == src_var: + grp_level = None + found = False + for f in data_fields: + if not found and f['name'] == resolved_src: + grp_level = f.get('level', 0) + found = True + continue + if found: + if f.get('level', 0) <= grp_level or f.get('level') == 77: + break + if f.get('pic'): + rec[f['name']] = csv_value + break diff --git a/cobol_testgen/gcov.py b/cobol_testgen/gcov.py new file mode 100644 index 0000000..41572bf --- /dev/null +++ b/cobol_testgen/gcov.py @@ -0,0 +1,119 @@ +"""gcov 覆盖率数据解析和分支标记""" + +import re +import logging +import subprocess +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def parse_cbl_gcov(gcov_path: str) -> dict[int, int]: + """解析 .cbl.gcov 文件,返回 {COBOL行号: 执行次数}。 + + gcov 行格式: + #####: 6: 源码行 → 未执行(0 次) + 75*: 12: 源码行 → 执行 75 次 + 1*: 14: 源码行 → 执行 1 次 + -: 17: 源码行 → 不可执行(注释/声明行,跳过) + """ + counts = {} + with open(gcov_path, encoding='utf-8') as f: + for line in f: + m = re.match(r'^\s*(#####|\d+\*?|-):\s*(\d+):', line) + if not m: + continue + count_str = m.group(1) + lineno = int(m.group(2)) + if count_str == '#####': + counts[lineno] = 0 + elif count_str == '-': + continue + else: + counts[lineno] = int(count_str.rstrip('*')) + return counts + + +def run_gcov(program_name: str, work_dir: str) -> dict[int, int]: + """在 work_dir 中通过 WSL 执行 gcov 并解析 COBOL 行计数。 + + Args: + program_name: 程序名(不含扩展名),如 "ALLCMDS" + work_dir: 包含 .gcda/.gcno 的目录(Windows 路径) + + Returns: + {COBOL行号: 执行次数} 字典。失败时返回空 dict。 + """ + wsl_work = _wsl_path(work_dir) + cmd = ['wsl', 'sh', '-c', f'cd {wsl_work} && gcov {program_name}.c'] + result = subprocess.run( + cmd, + capture_output=True, text=True, + encoding='utf-8', errors='replace', + timeout=30, + ) + if result.returncode != 0: + logger.warning(f"gcov 失败 (exit={result.returncode}): {result.stderr.strip()}") + return {} + + cbl_gcov = Path(work_dir) / f'{program_name}.cbl.gcov' + if not cbl_gcov.exists(): + logger.warning(f"gcov 输出不存在: {cbl_gcov}") + return {} + + gcov_data = parse_cbl_gcov(str(cbl_gcov)) + logger.info(f"gcov 解析: {len(gcov_data)} 行, " + f"{sum(1 for v in gcov_data.values() if v > 0)} 行已执行") + return gcov_data + + +def _wsl_path(windows_path: str) -> str: + path = Path(windows_path).resolve() + drive = path.drive.lower().rstrip(':') + rest = str(path.relative_to(path.anchor)).replace('\\', '/') + return f'/mnt/{drive}/{rest}' + + +def mark_from_gcov(decision_points: list, gcov_data: dict[int, int], + branch_tree) -> None: + """用 gcov 行执行计数推断决策点分支覆盖,直接修改 decision_points 的 active_branches。 + + 推断规则(简化版,先覆盖主要场景): + + IF (条件行 L): + - 条件行 L 在 gcov 中 count == 0 → 不可到达,不标记 + - 条件行 L 在 gcov 中 count > 0 → 标记 T 和 F 都覆盖 + + EVALUATE: + - subject 行 count > 0 → 标记所有 WHEN 为已覆盖 + + PERFORM UNTIL (条件行 L): + - count == 1 → 条件初始即为真,循环体未进入 → Skip 覆盖 + - count > 1 → 循环体至少进入一次 → Enter 覆盖 + - Skip 总视为覆盖(无论进入与否,最终都会跳出) + """ + for dp in decision_points: + ln = dp.source_line + if ln <= 0 or ln not in gcov_data: + continue + + count = gcov_data.get(ln) + if count is None: + continue + + if dp.kind == 'IF': + if count == 0: + continue + dp.active_branches.add('T') + dp.active_branches.add('F') + + elif dp.kind == 'EVALUATE': + if count == 0: + continue + for bn in dp.branch_names: + dp.active_branches.add(bn) + + elif dp.kind == 'PERFORM': + if count > 1: + dp.active_branches.add('Enter') + dp.active_branches.add('Skip') diff --git a/cobol_testgen/grammar.lark b/cobol_testgen/grammar.lark index 2931ad4..6fe4d5e 100644 --- a/cobol_testgen/grammar.lark +++ b/cobol_testgen/grammar.lark @@ -13,7 +13,7 @@ clause: pic_clause | value_clause | occurs_clause | redefines_clause | usage_cla | "JUSTIFIED" "RIGHT"? | "BLANK" "WHEN" "ZERO" | "GLOBAL" | "EXTERNAL" -pic_clause: "PIC" "IS"? PICTURE_STRING +pic_clause: "PIC" "IS"? PICTURE_STRING ("." PICTURE_STRING)* value_clause: "VALUE" "IS"? value_literal+ value_literal: INT | SIGNED_NUMBER | STRING | SQSTRING | "ZERO" | "ZEROS" | "ZEROES" diff --git a/cobol_testgen/output.py b/cobol_testgen/output.py index ef8a5aa..64e1e11 100644 --- a/cobol_testgen/output.py +++ b/cobol_testgen/output.py @@ -23,27 +23,68 @@ def _scenario_text(path_cons): return ', '.join(parts) +def _write_json(entries, outpath): + if not entries: + return + outpath.parent.mkdir(parents=True, exist_ok=True) + with open(outpath, 'w', encoding='utf-8') as f: + json.dump(entries, f, ensure_ascii=False, indent=2) + + +def _is_field_assigned(fname, assigned_set, fields, fd_fields_lookup): + if not assigned_set: + return False + if fname in assigned_set: + return True + level_map = {} + name_order = [] + for f in fields: + fn = f['name'] + lv = f.get('level', 77) + level_map[fn] = lv + name_order.append((lv, fn)) + flv = level_map.get(fname, 77) + ancestor = None + for lv, fn in name_order: + if fn == fname: + break + if lv < flv: + ancestor = fn + if ancestor and ancestor in assigned_set: + return True + return False + + def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None, - open_dir=None, path_cons_list=None): + open_dir=None, term_types=None, db_input=None, data_fields=None): outpath.parent.mkdir(parents=True, exist_ok=True) if not roles: + out = [] + for i, rec in enumerate(records): + entry = dict(rec) + entry['termination'] = (term_types or ['normal'] * len(records))[i] + out.append(entry) + obj = {'program': outpath.stem, 'records': out} + if db_input: + obj['db_input'] = db_input with open(outpath, 'w', encoding='utf-8') as f: - json.dump(records, f, ensure_ascii=False, indent=2) + json.dump(obj, f, ensure_ascii=False, indent=2) return - # FD direction lookup + term_types = term_types or ['normal'] * len(records) + out = [] for i, rec in enumerate(records): inp = {} out_exp = {} ws = {} - # Group by FD if fd_fields and field_to_fd: for fd_name, fds_set in fd_fields.items(): direction = (open_dir or {}).get(fd_name, '') inp_block = {} out_block = {} + assigned_set = rec.get('_assigned_fields', set()) for fname in fds_set: if fname not in rec: continue @@ -52,13 +93,13 @@ def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None, if direction in ('INPUT', 'I-O') and r in ('input', 'inout'): inp_block[fname] = val if direction in ('OUTPUT', 'I-O') and r in ('output', 'inout'): - out_block[fname] = val + if _is_field_assigned(fname, assigned_set, data_fields or [], fd_fields): + out_block[fname] = val if inp_block: inp[fd_name] = inp_block if out_block: out_exp[fd_name] = out_block - # Working-storage: not belonging to any FD for name, val in rec.items(): if not field_to_fd or name not in field_to_fd: ws[name] = val @@ -66,25 +107,21 @@ def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None, entry = { 'input': inp, 'expected_output': out_exp, - 'working_storage': ws, + 'working_storage': {k: v for k, v in ws.items() if k != '_assigned_fields'}, + 'termination': term_types[i] if i < len(term_types) else 'normal', } - if path_cons_list and i < len(path_cons_list): - text = _scenario_text(path_cons_list[i]) - if text: - entry['scenario'] = text - out.append(entry) - with open(outpath, 'w', encoding='utf-8') as f: - json.dump(out, f, ensure_ascii=False, indent=2) + obj = {'program': outpath.stem, 'records': out} + if db_input: + obj['db_input'] = db_input + _write_json(obj, outpath) -def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, open_dir): - """按 FD 名拆分出力入力 JSON 文件。 - 每个 INPUT / I-O 方向 FD 生成一个文件:{stem}_{fd_name}.json - 内容为路径数 × 记录,每条只含该 FD 的入力字段值。 - """ +def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, open_dir, + term_types=None): + term_types = term_types or ['normal'] * len(records) input_fds = {} for fd_name, fds_set in fd_fields.items(): direction = (open_dir or {}).get(fd_name, '') @@ -101,9 +138,11 @@ def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, ope outdir.mkdir(parents=True, exist_ok=True) for fd_name, fds_set in input_fds.items(): - fd_records = [] + normals = [] + abends = [] direction = (open_dir or {}).get(fd_name, '') - for rec in records: + for i, rec in enumerate(records): + term = term_types[i] if i < len(term_types) else 'normal' fd_rec = {} for fname in fds_set: r = roles.get(fname, 'unused') @@ -111,8 +150,12 @@ def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, ope if fname in rec: fd_rec[fname] = rec[fname] if fd_rec: - fd_records.append(fd_rec) + if term == 'abend': + abends.append(fd_rec) + else: + normals.append(fd_rec) - outpath = outdir / f'{stem}_{fd_name}.json' - with open(outpath, 'w', encoding='utf-8') as f: - json.dump(fd_records, f, ensure_ascii=False, indent=2) + if normals: + _write_json(normals, outdir / f'{stem}_{fd_name}.json') + if abends: + _write_json(abends, outdir / f'{stem}_abend_{fd_name}.json') diff --git a/cobol_testgen/read.py b/cobol_testgen/read.py index 9759dcf..6531fd2 100644 --- a/cobol_testgen/read.py +++ b/cobol_testgen/read.py @@ -1,9 +1,12 @@ -"""??????? + COPYBOOK + DATA DIVISION?? + PIC""" +"""Preprocessor + COPYBOOK + DATA DIVISION parse + PIC""" import re +import logging from pathlib import Path from lark import Lark, Transformer, v_args +logger = logging.getLogger(__name__) + from .models import FieldDef, PicInfo @@ -85,6 +88,8 @@ def preprocess(source: str) -> str: if len(line) >= 7 and line[6].upper() == 'D': continue content = line[6:] if len(line) >= 7 else line + if content.strip().startswith('*'): + continue else: comment_pos = line.find('*>') if comment_pos >= 0: @@ -192,6 +197,125 @@ def resolve_copybooks(source: str, source_dir: str, _recursion_depth: int = 0, return '\n'.join(result) +# ── EXEC SQL INCLUDE Resolution ── + +_RE_SQL_INC = re.compile( + r'EXEC\s+SQL\s+INCLUDE\s+(\w[\w-]*)\s+END-EXEC\.', + re.IGNORECASE | re.DOTALL +) + +_BUILTIN_SQLCA = """\ + 01 SQLCA. + 05 SQLCAID PIC X(8). + 05 SQLCABC PIC S9(9) COMP. + 05 SQLCODE PIC S9(9) COMP. + 05 SQLERRM. + 10 SQLERRML PIC S9(4) COMP. + 10 SQLERRMC PIC X(70). + 05 SQLERRP PIC X(8). + 05 SQLERRD OCCURS 6 TIMES PIC S9(9) COMP. + 05 SQLWARN. + 10 SQLWARN0 PIC X. + 10 SQLWARN1 PIC X. + 10 SQLWARN2 PIC X. + 10 SQLWARN3 PIC X. + 10 SQLWARN4 PIC X. + 10 SQLWARN5 PIC X. + 10 SQLWARN6 PIC X. + 10 SQLWARN7 PIC X. + 05 SQLSTATE PIC X(5). +""" + + +def resolve_sql_includes(source: str, source_dir: str) -> str: + """Resolve EXEC SQL INCLUDE name END-EXEC. like COPY. Injects built-in SQLCA if not found.""" + def _resolve_one(m): + name = m.group(1).upper() + for ext in ('', '.cpy', '.CPY', '.cbl', '.CBL'): + p = Path(source_dir) / f"{name}{ext}" + if p.exists(): + return p.read_text(encoding='utf-8') + if name == 'SQLCA': + return _BUILTIN_SQLCA + logger.warning(f"SQL INCLUDE {name} not found, injecting as comment") + return f" * SQL INCLUDE {name} NOT RESOLVED\n" + while True: + new_source = _RE_SQL_INC.sub(_resolve_one, source) + if new_source == source: + break + source = new_source + return source + + +_RE_SQL_BLOCK = re.compile( + r'EXEC\s+SQL\s+(.*?)\s+END-EXEC\.?', + re.IGNORECASE | re.DOTALL +) + +_RE_DECLARE_TABLE = re.compile( + r'EXEC\s+SQL\s+DECLARE\s+(\w[\w-]*)\s+TABLE\s*\((.*?)\)\s+END-EXEC\.?', + re.IGNORECASE | re.DOTALL +) + + +def strip_exec_sql_from_data_div(source: str) -> tuple: + """Strip EXEC SQL blocks from DATA DIVISION. Returns (cleaned_source, declared_columns).""" + declared_columns = {} + def _repl(m): + full = m.group(0) + dm = _RE_DECLARE_TABLE.match(full) + if dm: + table_name = dm.group(1).upper() + col_text = dm.group(2) + cols = _parse_declare_table_columns(col_text) + declared_columns[table_name] = cols + return f" *> DECLARE {table_name} TABLE ({len(cols)} cols)\n" + return " *> SKIPPED EXEC SQL\n" + cleaned = _RE_SQL_BLOCK.sub(_repl, source) + return cleaned, declared_columns + + +def _parse_declare_table_columns(col_text: str) -> list[dict]: + """Parse 'CUST_ID CHAR(5) NOT NULL, BALANCE PIC 9(6)' into column list.""" + cols = [] + for part in re.split(r',\s*', col_text): + part = part.strip() + if not part: + continue + m = re.match( + r'(\w[\w-]*)\s+(CHAR\s*\(\s*(\d+)\s*\)' + r'|VARCHAR\s*\(\s*(\d+)\s*\)' + r'|INTEGER|SMALLINT' + r'|DECIMAL\s*\(\s*(\d+)\s*(?:,\s*(\d+))?\s*\)' + r'|DATE' + r'|PIC\s+([\w().]+))' + r'(?:\s+NOT\s+NULL|\s+NULL)?', + part, re.IGNORECASE + ) + if m: + name = m.group(1).upper() + if m.group(3): + col_type = {'db_type': 'CHAR', 'size': int(m.group(3))} + elif m.group(4): + col_type = {'db_type': 'VARCHAR', 'size': int(m.group(4))} + elif m.group(2).upper() == 'INTEGER': + col_type = {'db_type': 'INTEGER'} + elif m.group(2).upper() == 'SMALLINT': + col_type = {'db_type': 'SMALLINT'} + elif m.group(5): + prec = int(m.group(5)) if m.group(5) else 0 + scale = int(m.group(6)) if m.group(6) else 0 + col_type = {'db_type': 'DECIMAL', 'precision': prec, 'scale': scale} + elif m.group(2).upper() == 'DATE': + col_type = {'db_type': 'DATE'} + elif m.group(7): + col_type = {'db_type': 'PIC', 'pic': m.group(7).upper()} + else: + col_type = {'db_type': 'CHAR', 'size': 1} + cols.append({'name': name, **col_type}) + return cols + + # 鈹€鈹€ Lark Grammar 鈹€鈹€ _GRAMMAR_CACHE = None @@ -464,7 +588,7 @@ def parse_file_control(source: str) -> dict: """Parse FILE-CONTROL paragraph. Returns dict: - {filename: {"assign_to": str, "organization": str | None}} + {filename: {"assign": str, "organization": str, "recording_mode": str}} """ m = re.search(r'FILE-CONTROL\.(.*?)(?=DATA\s+DIVISION|\Z)', source, re.DOTALL | re.IGNORECASE) if not m: @@ -472,21 +596,39 @@ def parse_file_control(source: str) -> dict: fc = m.group(1) result = {} for sel_m in re.finditer( - r'SELECT\s+(\w[\w-]*)\s+[^.]*?\bASSIGN\s+TO\s+(["\'])(.*?)\2', + r'SELECT\s+(\w[\w-]*)\s+[^.]*?\bASSIGN\s+TO\s+' + r'(?:(["\'])(.*?)\2|(\w[\w-]*))' + r'[^.]*\.', fc, re.IGNORECASE ): - fname = sel_m.group(1).upper() - assign_to = sel_m.group(3).upper() - # Extract ORGANIZATION clause within this SELECT statement - org_m = re.search( - r'ORGANIZATION\s+(?:IS\s+)?(\w[\w-]*)', - sel_m.group(0), re.IGNORECASE - ) - org = org_m.group(1).upper() if org_m else None - result[fname] = { - "assign_to": assign_to, - "organization": org, - } + name = sel_m.group(1).upper() + if sel_m.group(2): + assign_to = sel_m.group(3).upper() + else: + assign_to = sel_m.group(4).upper() + clause = sel_m.group(0) + org_m = re.search(r'ORGANIZATION\s+(LINE\s+)?SEQUENTIAL', clause, re.IGNORECASE) + if org_m and org_m.group(1): + org = 'LINE SEQUENTIAL' + elif org_m: + org = 'SEQUENTIAL' + else: + org = 'SEQUENTIAL' + result[name] = {'assign': assign_to, 'organization': org, 'recording_mode': 'F'} + # Extract RECORDING MODE from FD blocks in FILE SECTION + fd_sec_m = re.search(r'FILE\s+SECTION\.(.*?)(?=WORKING-STORAGE\s+SECTION|LINKAGE\s+SECTION|\Z)', + source, re.DOTALL | re.IGNORECASE) + if fd_sec_m: + fs = fd_sec_m.group(1) + for block in re.split(r'\n\s*(?=FD\s+)', fs.strip()): + fd_m = re.match(r'FD\s+(\w[\w-]*)', block, re.IGNORECASE) + if not fd_m: + continue + fd_name = fd_m.group(1).upper() + if fd_name in result: + rm_m = re.search(r'RECORDING\s+MODE\s+IS\s+(\w)', block, re.IGNORECASE) + if rm_m: + result[fd_name]['recording_mode'] = rm_m.group(1).upper() return result @@ -499,14 +641,12 @@ def parse_file_section(source: str) -> dict: fs = m.group(1) result = {} # FD 和 SD 条目 - blocks = re.split(r'\n\s*(?=(?:FD|SD)\s+)', fs.strip()) - for block in blocks: + fd_blocks = re.split(r'\n\s*(?=(?:FD|SD)\s+)', fs.strip()) + for block in fd_blocks: m = re.match(r'(FD|SD)\s+(\w[\w-]*)', block, re.IGNORECASE) if not m: continue - entry_type = m.group(1).upper() # "FD" or "SD" name = m.group(2).upper() - # 找 01 层记录 recs = re.findall(r'^\s*0{0,1}1\s+(\w[\w-]*)', block, re.MULTILINE) result[name] = [r.upper() for r in recs] return result @@ -521,11 +661,15 @@ def scan_open_statements(source: str) -> dict: source, re.IGNORECASE ): full = m.group(1) - for seg_m in re.finditer( - r'(INPUT|OUTPUT|I-O)\s+([\w\s-]+)', full, re.IGNORECASE - ): - direction = seg_m.group(1).upper() - for fname in re.findall(r'\w[\w-]*', seg_m.group(2)): - if fname.upper() not in ('INPUT', 'OUTPUT', 'I-O'): + full = re.sub(r'\s+', ' ', full) + tokens = re.split(r'\s+(?=(?:INPUT|OUTPUT|I-O)\s)', full) + for seg in tokens: + seg = seg.strip() + if not seg: + continue + seg_m = re.match(r'(INPUT|OUTPUT|I-O)\s+([\w -]+)', seg, re.IGNORECASE) + if seg_m: + direction = seg_m.group(1).upper() + for fname in re.findall(r'\w[\w-]*', seg_m.group(2)): dirs[fname.upper()] = direction return dirs