merge local cobol_testgen improvements into v3 shared modules

- cond.py: SQLCODE/SQLSTATE handling, alphanumeric >/< boundary fix
- output.py: termination tracking, db_input support, _is_field_assigned filter
- coverage.py: mark_from_gcov, THRU support, KeyError protection
- gcov.py: new file (dependency for coverage.py)
- grammar.lark: multi-segment PIC support
- read.py: SQL INCLUDE resolution, DECLARE TABLE parsing, * comment fix
- core.py: SQL parsing, blocked_names, keyword list
- design.py: multi-sentinel, THRU ranges, PERFORM VARYING last iteration
- __init__.py: local main() + v3 API functions, guarded imports

All 6 ZAN programs verified passing through v3 pipeline
This commit is contained in:
hangshuo652
2026-06-23 22:38:17 +08:00
parent e5ab3baa46
commit 7fb9304212
9 changed files with 1595 additions and 326 deletions
+370 -84
View File
@@ -1,14 +1,14 @@
"""COBOL Test Data Generator — 模块化版入口 """COBOL Test Data Generator — 模块化版入口
from __future__ import annotations
公开 API: 公开 API:
extract_structure() — 解析 COBOL 控制流 → dict extract_structure() — 解析 COBOL 控制流 → dict
generate_data() — 生成测试数据 → list[dict] generate_data() — 生成测试数据 → list[dict]
incremental_supplement — 差分补充数据 → list[dict] incremental_supplement — 差分补充数据 → list[dict]
check_coverage() — 覆盖率报告 → dict
""" """
import os
import sys import sys
import json
import re import re
import logging import logging
from datetime import datetime from datetime import datetime
@@ -16,25 +16,45 @@ from pathlib import Path
# ── 配置(必须放在本地模块导入之前,避免循环导入) ── # ── 配置(必须放在本地模块导入之前,避免循环导入) ──
CONFIG = {} CONFIG = {
'abend_programs': ['SUB03END'],
}
from .read import preprocess, extract_data_division, extract_procedure_division from .read import preprocess, extract_data_division, extract_procedure_division
from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements, parse_file_control from .read import resolve_copybooks, parse_data_division, parse_file_section, scan_open_statements
from .core import classify_field_roles, _init_child_names from .read import parse_file_control, resolve_sql_includes, strip_exec_sql_from_data_div
from .pipeline_bridge import build_branch_tree_fallback from .core import build_branch_tree, classify_field_roles, _init_child_names, sql_register_virtual_fields, _find_multi_write_fds
from .cond import parse_single_condition, is_field, collect_leaves from .cond import parse_single_condition, is_field, collect_leaves
from .design_mcdc import enum_paths, _filter_stop from .pipeline_bridge import build_branch_tree_fallback
from .design import generate_records from .design_mcdc import enum_paths as mcdc_enum_paths, _filter_stop
from .design import enum_paths, generate_records, get_term_type, extend_abend_programs
from .output import output_json, output_input_files from .output import output_json, output_input_files
from .coverage import run_coverage, generate_coverage_index, check_coverage from .coverage import run_coverage, generate_coverage_index
from japanese_data import generate_fullwidth_text, generate_halfwidth_katakana, generate_wareki_date from japanese_data import generate_fullwidth_text, generate_halfwidth_katakana, generate_wareki_date
try:
from .runner import run_and_compare, run_all, GroupInfo, GroupResult
_HAVE_RUNNER = True
except ImportError:
_HAVE_RUNNER = False
try:
from .gcov import run_gcov
_HAVE_GCOV = True
except ImportError:
_HAVE_GCOV = False
try:
from .to_sql import collect_sql_meta, build_db_input
_HAVE_TOSQL = True
except ImportError:
_HAVE_TOSQL = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
n__all__ = [ __all__ = [
"extract_structure", "extract_structure",
"generate_data", "generate_data",
"incremental_supplement", "incremental_supplement",
"check_coverage",
"CONFIG", "CONFIG",
"generate_fullwidth_text", "generate_fullwidth_text",
"generate_halfwidth_katakana", "generate_halfwidth_katakana",
@@ -107,6 +127,149 @@ def expand_occurs(fields):
return result return result
# ── PREV 连锁 ──
def _constraint_in(cons, field, op, value, want):
for c in cons:
if len(c) == 4 and c[0] == field and c[1] == op and c[2] == value and c[3] == want:
return True
return False
def _inc_str(s, length):
try:
return str(int(s) + 1).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 9Zz\xff':
c[i] = chr(ord(c[i]) + 1)
break
if c[i] == ' ':
c[i] = '0'
break
if c[i] == '9':
c[i] = '0'
return ''.join(c)
def _dec_str(s, length):
try:
n = max(0, int(s) - 1)
return str(n).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 0Aa\x00':
c[i] = chr(ord(c[i]) - 1)
break
if c[i] == ' ':
break
if c[i] == '0':
c[i] = '9'
return ''.join(c)
def _field_length(fname, fields):
for f in fields:
if f['name'] == fname:
pi = f.get('pic_info', {})
return pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0) or 1
return 1
def _chain_prev(records, path_infos, fields, fd_fields, field_to_fd, open_dir):
    """Chain PREV values across consecutive records, mutating ``records`` in place.

    The goal is that the path taken when the records are executed as one batch
    agrees with the comparisons that actually happen at run time.

    The constraints of path k-1 (PREV OP CURRENT) correspond to the comparison
    made in loop iteration k-1 of the batch:
        PREV    = records[prev_src].R01  (previous value kept inside the program)
        CURRENT = records[k].R01         (value just read)
    This function adjusts the fields of records[k] so that the cross-record
    comparison satisfies the path constraints.

    ``path_infos[k - 1][0]`` is read as the constraint list for record ``k``.
    ``fields`` is only used to look up field lengths.  ``fd_fields``,
    ``field_to_fd`` and ``open_dir`` are accepted but not used in this body.

    NOTE(review): the block nesting below was reconstructed from a
    whitespace-stripped view of the source; confirm it against the real file,
    in particular the spots marked with NOTE(review).
    """
    N = len(records)
    if N < 2:
        # Nothing to chain with fewer than two records.
        return

    # Discover the key fields and the start/end time fields from the first record.
    key_fields = []
    time_start_field = None
    time_end_field = None
    for fname in records[0]:
        if fname.startswith('R01') and not fname.startswith('R01INNREC'):
            base = fname[3:]
            prev_name = 'WRK-PREV-' + base
            if prev_name in records[0]:
                # NOTE(review): assumes all three checks below sit under the
                # "has a WRK-PREV- counterpart" test — confirm.
                if 'EMP-ID' in fname or 'APPL-DATE' in fname:
                    key_fields.append(fname)
                if 'END-TIME' in fname:
                    time_end_field = fname
                if 'START-TIME' in fname:
                    time_start_field = fname

    # Index of the record whose R01 values act as PREV for the next record.
    prev_src = 0
    for k in range(1, N):
        if k - 1 >= len(path_infos):
            # No path constraints left for this and later records.
            break
        cons = path_infos[k - 1][0]

        # Path requires every key field to equal its WRK-PREV- counterpart.
        is_same_key = all(
            _constraint_in(cons, f'WRK-PREV-{fn[3:]}', '=', fn, True)
            for fn in key_fields
        ) if key_fields else False

        # Same key and the path requires PREV end-time > current start-time.
        is_overlap = is_same_key and time_end_field and time_start_field and \
            _constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, True)

        # Same key and the path requires PREV end-time <= current start-time
        # (expressed either as "<=" true or as ">" false).
        is_normal = is_same_key and time_end_field and time_start_field and \
            (_constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '<=', time_start_field, True) or
             _constraint_in(cons, f'WRK-PREV-{time_end_field[3:]}', '>', time_start_field, False))

        # Copy the source record's R01 values into record k's WRK-PREV- fields.
        for fname in records[prev_src]:
            if fname.startswith('R01') and not fname.startswith('R01INNREC'):
                base = fname[3:]
                prev_name = 'WRK-PREV-' + base
                if prev_name in records[k]:
                    records[k][prev_name] = records[prev_src][fname]

        if is_same_key:
            # Force record k to carry the same key values as the source record.
            for kf in key_fields:
                if kf in records[k] and kf in records[prev_src]:
                    records[k][kf] = records[prev_src][kf]

            # NOTE(review): assumes the two time adjustments below are nested
            # inside the same-key branch (the ``else`` pairs with ``if is_same_key``).
            if is_normal and time_end_field and time_start_field:
                prev_end = records[prev_src].get(time_end_field, '')
                curr_start = records[k].get(time_start_field, '')
                if prev_end >= curr_start:
                    # Move the current start-time just past the previous end-time.
                    length = _field_length(time_start_field, fields)
                    records[k][time_start_field] = _inc_str(prev_end, length)

            if is_overlap and time_end_field and time_start_field:
                prev_end = records[prev_src].get(time_end_field, '')
                curr_start = records[k].get(time_start_field, '')
                if prev_end <= curr_start:
                    # Move the current start-time just before the previous end-time.
                    length = _field_length(time_start_field, fields)
                    records[k][time_start_field] = _dec_str(prev_end, length) if prev_end else '0' * length
        else:
            # Path does not require equal keys: bump any key that happens to
            # coincide with the source record's value.
            for kf in key_fields:
                if kf in records[k] and kf in records[prev_src]:
                    if records[k][kf] == records[prev_src][kf]:
                        length = _field_length(kf, fields)
                        records[k][kf] = _inc_str(str(records[k][kf]), length)

        # Marker entries stored on the record; how they are consumed is not
        # visible here — TODO confirm against the output/runner code.
        records[k]['_w02_path'] = is_same_key and time_end_field and time_start_field and not is_overlap
        records[k]['_overlap_path'] = is_overlap

        # Mirror each (possibly adjusted) R01 value into its W01 counterpart, if present.
        for fn in list(records[k].keys()):
            if fn.startswith('R01') and not fn.startswith('R01INNREC'):
                wfn = 'W01' + fn[3:]
                if wfn in records[k]:
                    records[k][wfn] = records[k][fn]

        # On an overlap path the PREV source is kept; otherwise record k
        # becomes the PREV source for the next record.
        if is_overlap:
            pass
        else:
            prev_src = k
# ── 入口 ── # ── 入口 ──
def main(): def main():
@@ -116,7 +279,32 @@ def main():
args = sys.argv[1:] args = sys.argv[1:]
# 分离 cobol 文件与输出目录 do_run = False
gcov_mode = False
temp_dir = None
if '--run' in args:
do_run = True
args.remove('--run')
if '--gcov' in args:
gcov_mode = True
args.remove('--gcov')
i = 0
while i < len(args):
if args[i] == '--temp-dir':
if i + 1 < len(args):
temp_dir = args[i + 1]
args.pop(i + 1)
args.pop(i)
else:
args.pop(i)
break
elif args[i].startswith('--temp-dir='):
temp_dir = args[i].split('=', 1)[1]
args.pop(i)
break
else:
i += 1
cobol_files = [] cobol_files = []
outdir = None outdir = None
for a in args: for a in args:
@@ -133,13 +321,13 @@ def main():
if outdir is None: if outdir is None:
outdir = cobol_files[0].parent outdir = cobol_files[0].parent
# 配置全局 Logger
outdir.mkdir(parents=True, exist_ok=True) outdir.mkdir(parents=True, exist_ok=True)
log_path = outdir / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log" (outdir / 'logs').mkdir(parents=True, exist_ok=True)
log_path = outdir / 'logs' / f"cobol_testgen_{datetime.now():%Y%m%d_%H%M%S}.log"
fh = logging.FileHandler(log_path, encoding="utf-8", mode="w") fh = logging.FileHandler(log_path, encoding="utf-8", mode="w")
fh.setLevel(logging.DEBUG) fh.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter( fh.setFormatter(logging.Formatter(
"%(asctime)s [%(levelname)s] %(name)s: %(message)s" "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)) ))
sh = logging.StreamHandler() sh = logging.StreamHandler()
sh.setLevel(logging.INFO) sh.setLevel(logging.INFO)
@@ -157,12 +345,20 @@ def main():
continue continue
source = filepath.read_text(encoding='utf-8') source = filepath.read_text(encoding='utf-8')
source = resolve_copybooks(source, str(filepath.parent)) source = resolve_copybooks(
source,
str(filepath.parent),
extra_search_paths=[str(filepath.parent / '..' / 'cpy')],
)
source = resolve_sql_includes(source, str(filepath.parent))
preprocessed = preprocess(source) preprocessed = preprocess(source)
file_sec = parse_file_section(preprocessed) file_sec = parse_file_section(preprocessed)
# DATA DIVISION解析
data_div = extract_data_division(preprocessed) data_div = extract_data_division(preprocessed)
if data_div:
data_div, declared_columns = strip_exec_sql_from_data_div(data_div)
else:
declared_columns = {}
if not data_div: if not data_div:
logger.error(f"错误:{filepath.name} 中没有 DATA DIVISION。") logger.error(f"错误:{filepath.name} 中没有 DATA DIVISION。")
continue continue
@@ -172,7 +368,6 @@ def main():
logger.error(f"错误:{filepath.name} 中没有找到含 PIC 的字段。") logger.error(f"错误:{filepath.name} 中没有找到含 PIC 的字段。")
continue continue
# FieldDef → dict
fields_dict = [] fields_dict = []
parent_pic = {} parent_pic = {}
filler_counter = 0 filler_counter = 0
@@ -206,7 +401,6 @@ def main():
if f.is_88: if f.is_88:
entry['is_88'] = True entry['is_88'] = True
entry['parent'] = f.parent entry['parent'] = f.parent
# Copy parent's pic_info for value generation
if f.parent and f.parent in parent_pic: if f.parent and f.parent in parent_pic:
entry['pic_info'] = dict(parent_pic[f.parent]) entry['pic_info'] = dict(parent_pic[f.parent])
else: else:
@@ -215,7 +409,8 @@ def main():
fields_dict = expand_occurs(fields_dict) fields_dict = expand_occurs(fields_dict)
# Build FD→children 和 field→FD 映射 sql_register_virtual_fields(fields_dict)
fd_fields = {} fd_fields = {}
field_to_fd = {} field_to_fd = {}
if file_sec: if file_sec:
@@ -245,13 +440,12 @@ def main():
pic_display = str(f.get('pic', '')) if f.get('pic') else ('88-level' if f.get('is_88') else '') pic_display = str(f.get('pic', '')) if f.get('pic') else ('88-level' if f.get('is_88') else '')
logger.info(f"{f['level']:<6} {f['name']:<25} {pic_display:<15} {t:<12} {l:<5}") logger.info(f"{f['level']:<6} {f['name']:<25} {pic_display:<15} {t:<12} {l:<5}")
# PROCEDURE DIVISION解析
proc_div = extract_procedure_division(preprocessed) proc_div = extract_procedure_division(preprocessed)
branch_paths = [] branch_paths = []
assignments = {} assignments = {}
if proc_div: if proc_div:
branch_tree, assignments = build_branch_tree_fallback(proc_div, fields_dict) branch_tree, assignments = build_branch_tree(proc_div, fields_dict, full_source=preprocessed)
roles = classify_field_roles(branch_tree, assignments, fields_dict, roles = classify_field_roles(branch_tree, assignments, fields_dict,
source=preprocessed, proc_text=proc_div) source=preprocessed, proc_text=proc_div)
@@ -261,12 +455,32 @@ def main():
continue continue
logger.info(f" {f['name']:<30} {roles.get(f['name'], '?')}") logger.info(f" {f['name']:<30} {roles.get(f['name'], '?')}")
abend_list = CONFIG.get('abend_programs', [])
if abend_list:
extend_abend_programs(abend_list)
branch_paths_with_assigns = enum_paths(branch_tree, fields_dict) branch_paths_with_assigns = enum_paths(branch_tree, fields_dict)
branch_paths_with_assigns = [ path_infos = []
(_filter_stop(c), a) for c, a in branch_paths_with_assigns for c, a in branch_paths_with_assigns:
] filtered_c, term = get_term_type(c)
path_infos.append((filtered_c, a, term))
def _is_skip(cons):
eq1_true = 0
other = 0
for c in cons:
if len(c) == 4 and c[0] == 'WRK-R01EOF':
val = str(c[2]).strip("'\"")
if val == '1' and c[1] == '=' and c[3]:
eq1_true += 1
else:
other += 1
return eq1_true > 0 and other == 0
before = len(path_infos)
path_infos = [p for p in path_infos if not _is_skip(p[0])]
after = len(path_infos)
logger.info(f" SKIP 过滤: {before} -> {after} 条路径(预期减少 1")
# OPEN 方向解析
open_dir = scan_open_statements(proc_div) if proc_div else {} open_dir = scan_open_statements(proc_div) if proc_div else {}
if proc_div: if proc_div:
@@ -284,26 +498,104 @@ def main():
else: else:
logger.warning("\n没有找到 PROCEDURE DIVISION。") logger.warning("\n没有找到 PROCEDURE DIVISION。")
branch_paths_with_assigns = [([], {})] branch_paths_with_assigns = [([], {})]
path_infos = [([], {}, 'normal')]
roles = {f['name']: 'unused' for f in fields_dict} roles = {f['name']: 'unused' for f in fields_dict}
# 覆盖率报告(传入原始源文本用于行号定位) records, _, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)
cov_prefix = str(outdir / filepath.stem)
index_relpath = 'coverage/index.html'
cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict,
source, cov_prefix, index_relpath=index_relpath)
records, kept_path_cons = generate_records(branch_paths_with_assigns, fields_dict, assignments, file_sec=file_sec) def _is_eof_path(cons):
last_eq1_true = -1
for i, c in enumerate(cons):
if len(c) == 4 and c[0] == 'WRK-R01EOF':
val = str(c[2]).strip("'\"")
if val == '1' and c[1] == '=' and c[3]:
last_eq1_true = i
if last_eq1_true < 0:
return False
for i in range(last_eq1_true + 1, len(cons)):
if len(cons[i]) == 4 and cons[i][0] == 'WRK-R01EOF':
return False
return True
eof_mask = [_is_eof_path(c) for c, a, t in path_infos]
eof_count = sum(eof_mask)
if eof_count:
term_types = ['eof' if e else t for e, t in zip(eof_mask, term_types)]
logger.info(f" EOF 路径: {eof_count} 条(将单独执行)")
# 输出 JSON(完整文件) multi_write_fds = _find_multi_write_fds(branch_tree, field_to_fd) if proc_div and branch_tree else set()
outpath = outdir / (filepath.stem + '.json') if multi_write_fds:
logger.info(f" 检测到多 WRITE FD: {', '.join(sorted(multi_write_fds))}")
_chain_prev(records, path_infos, fields_dict, fd_fields, field_to_fd, open_dir)
if _HAVE_TOSQL:
sql_meta = collect_sql_meta(assignments, declared_columns)
db_input = build_db_input(
branch_paths_with_assigns, fields_dict, assignments, sql_meta, declared_columns,
records=records,
)
else:
db_input = None
(outdir / 'json').mkdir(parents=True, exist_ok=True)
outpath = outdir / 'json' / (filepath.stem + '.json')
output_json(records, outpath, roles, output_json(records, outpath, roles,
fd_fields=fd_fields, field_to_fd=field_to_fd, fd_fields=fd_fields, field_to_fd=field_to_fd,
open_dir=open_dir, open_dir=open_dir,
path_cons_list=kept_path_cons) term_types=term_types,
db_input=db_input if db_input else None,
data_fields=fields_dict)
# 输出入力 JSON(按 FD 拆分) output_input_files(records, outdir / 'input', filepath.stem, roles,
output_input_files(records, outdir, filepath.stem, roles, fd_fields, field_to_fd, open_dir,
fd_fields, field_to_fd, open_dir) term_types=term_types)
gcov_data = None
if gcov_mode and proc_div and _HAVE_GCOV:
select_info = parse_file_control(preprocessed)
_temp = temp_dir or str(outdir / '.gcov_cache')
source_dir = str(filepath.parent)
expected_records: list[dict] = [{}] * len(records)
if file_sec and os.path.exists(outpath):
with open(outpath, encoding='utf-8') as f:
full_json = json.load(f)
json_records = full_json.get('records', [])
for i in range(len(records)):
exp = {}
if i < len(json_records):
json_rec = json_records[i]
for fd_name in file_sec:
eo = json_rec.get('expected_output', {})
if fd_name in eo:
exp.update(eo[fd_name])
expected_records[i] = exp
group_results = run_all(
filepath.stem, str(outdir), _temp,
fields_dict, fd_fields, select_info, open_dir,
term_types, records, expected_records=expected_records,
source_dir=source_dir, path_infos=path_infos,
multi_write_fds=multi_write_fds,
)
gcov_data = run_gcov(filepath.stem, _temp)
passed = sum(1 for r in group_results if r.passed)
total = len(group_results)
logger.info(f"\n 执行验证: {passed}/{total} 组通过")
if passed < total:
for r in group_results:
if not r.passed and r.details:
fails = [d for d in r.details if not d.match][:3]
for d in fails:
logger.warning(f" [{r.name}] {d.field}: "
f"期望={d.expected!r}, 实际={d.actual!r}")
if do_run and proc_div and _HAVE_RUNNER:
select_info = parse_file_control(preprocessed)
run_and_compare(
filepath.stem, str(outdir), fields_dict,
fd_fields, select_info, open_dir,
term_types, records,
)
logger.info(f"\n输出:{outpath}{len(records)} 条记录)") logger.info(f"\n输出:{outpath}{len(records)} 条记录)")
logger.debug(f"\n记录明细:") logger.debug(f"\n记录明细:")
@@ -315,11 +607,17 @@ def main():
vals.append(f"{marker}{f['name']}={rec.get(f['name'], '?')}") vals.append(f"{marker}{f['name']}={rec.get(f['name'], '?')}")
logger.debug(f" 记录 {i}: {' | '.join(vals)}") logger.debug(f" 记录 {i}: {' | '.join(vals)}")
(outdir / 'coverage').mkdir(parents=True, exist_ok=True)
cov_prefix = str(outdir / 'coverage' / filepath.stem)
index_relpath = 'index.html'
cov_result = run_coverage(branch_tree, branch_paths_with_assigns, fields_dict,
source, cov_prefix, index_relpath=index_relpath,
gcov_data=gcov_data)
programs.append(cov_result) programs.append(cov_result)
# 生成覆盖率总括索引页
if programs: if programs:
generate_coverage_index(programs, outdir) generate_coverage_index(programs, outdir / 'coverage')
logger.info(f"\n覆盖率总览:{outdir / 'coverage' / 'index.html'}") logger.info(f"\n覆盖率总览:{outdir / 'coverage' / 'index.html'}")
@@ -429,18 +727,14 @@ def extract_structure(cobol_source: str) -> dict:
if m: if m:
paragraphs.add(m.group(1)) paragraphs.add(m.group(1))
# ── 新增字段: select_files ──
select_files = parse_file_control(preprocessed) select_files = parse_file_control(preprocessed)
# ── 新增字段: open_directions_detail (与 open_directions 一致) ──
open_directions_detail = open_dir open_directions_detail = open_dir
# ── 新增字段: has_divide / has_inspect / has_string ──
has_divide = bool(re.search(r'\bDIVIDE\b', cobol_source.upper())) has_divide = bool(re.search(r'\bDIVIDE\b', cobol_source.upper()))
has_inspect = bool(re.search(r'\bINSPECT\b', cobol_source.upper())) has_inspect = bool(re.search(r'\bINSPECT\b', cobol_source.upper()))
has_string = bool(re.search(r'\bSTRING\b', cobol_source.upper())) has_string = bool(re.search(r'\bSTRING\b', cobol_source.upper()))
# ── 新增字段: divide_constants ──
divide_constants = [] divide_constants = []
if has_divide and proc_div: if has_divide and proc_div:
for dm in re.finditer(r'\bDIVIDE\s+([\d.]+)\b', proc_div, re.IGNORECASE): for dm in re.finditer(r'\bDIVIDE\s+([\d.]+)\b', proc_div, re.IGNORECASE):
@@ -450,7 +744,6 @@ def extract_structure(cobol_source: str) -> dict:
except ValueError: except ValueError:
pass pass
# ── 新增字段: perform_patterns ──
perform_patterns = [] perform_patterns = []
def _walk_performs(node): def _walk_performs(node):
@@ -478,7 +771,6 @@ def extract_structure(cobol_source: str) -> dict:
if branch_tree: if branch_tree:
_walk_performs(branch_tree) _walk_performs(branch_tree)
# ── 新增字段: main_loop ──
main_loop = None main_loop = None
def _find_main_loop(node, depth=0): def _find_main_loop(node, depth=0):
@@ -533,7 +825,6 @@ def extract_structure(cobol_source: str) -> dict:
if branch_tree: if branch_tree:
_find_main_loop(branch_tree) _find_main_loop(branch_tree)
# ── 新增字段: if_types ──
if_types = {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0} if_types = {"total": 0, "comparison": 0, "equality": 0, "compound": 0, "nested_depth": 0}
def _walk_if_types(node, depth=0): def _walk_if_types(node, depth=0):
@@ -543,7 +834,6 @@ def extract_structure(cobol_source: str) -> dict:
ct = node.cond_tree ct = node.cond_tree
if ct: if ct:
leaves = collect_leaves(ct) leaves = collect_leaves(ct)
# Check compound: cond_tree is CondAnd or CondOr (not just CondLeaf)
if isinstance(ct, (CondAnd, CondOr)): if isinstance(ct, (CondAnd, CondOr)):
if_types["compound"] += 1 if_types["compound"] += 1
for leaf in leaves: for leaf in leaves:
@@ -566,7 +856,6 @@ def extract_structure(cobol_source: str) -> dict:
if branch_tree: if branch_tree:
_walk_if_types(branch_tree) _walk_if_types(branch_tree)
# ── 新增字段: variable_patterns ──
variable_patterns = { variable_patterns = {
"has_prev_key": False, "has_prev_key": False,
"has_accumulator": False, "has_accumulator": False,
@@ -597,14 +886,12 @@ def extract_structure(cobol_source: str) -> dict:
if re.search(r'[-_]W\b|[-_]WORK\b|[-_]WK\b|^WS-W[0O]\w', name, re.IGNORECASE): if re.search(r'[-_]W\b|[-_]WORK\b|[-_]WK\b|^WS-W[0O]\w', name, re.IGNORECASE):
variable_patterns["has_work"] = True variable_patterns["has_work"] = True
# ── 新增字段: open_pattern ──
open_pattern = "sequential" open_pattern = "sequential"
if proc_div: if proc_div:
proc_upper = proc_div.upper() proc_upper = proc_div.upper()
open_positions = [m.start() for m in re.finditer(r'\bOPEN\b', proc_upper)] open_positions = [m.start() for m in re.finditer(r'\bOPEN\b', proc_upper)]
close_positions = [m.start() for m in re.finditer(r'\bCLOSE\b', proc_upper)] close_positions = [m.start() for m in re.finditer(r'\bCLOSE\b', proc_upper)]
if open_positions and close_positions: if open_positions and close_positions:
# Check OPEN ... CLOSE ... OPEN sequence
for i, opos in enumerate(open_positions): for i, opos in enumerate(open_positions):
for cpos in close_positions: for cpos in close_positions:
if cpos > opos: if cpos > opos:
@@ -618,30 +905,29 @@ def extract_structure(cobol_source: str) -> dict:
break break
return { return {
"paragraphs": sorted(paragraphs) if paragraphs else [], "paragraphs": sorted(paragraphs) if paragraphs else [],
"decision_points": decision_points, "decision_points": decision_points,
"branch_tree": branch_tree, "branch_tree": branch_tree,
"file_count": len(file_sec) if file_sec else 0, "file_count": len(file_sec) if file_sec else 0,
"open_directions": open_dir, "open_directions": open_dir,
"has_search_all": any('SEARCH' in str(dp.get('label', '')) for dp in decision_points), "has_search_all": any('SEARCH' in str(dp.get('label', '')) for dp in decision_points),
"has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points), "has_evaluate": any(dp['kind'] == 'EVALUATE' for dp in decision_points),
"has_call": 'CALL' in cobol_source.upper(), "has_call": 'CALL' in cobol_source.upper(),
"has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points), "has_break": any('KEY' in str(dp.get('label', '')).upper() for dp in decision_points),
"total_branches": total_branches, "total_branches": total_branches,
"total_paragraphs": len(paragraphs), "total_paragraphs": len(paragraphs),
"branch_tree_obj": branch_tree, "branch_tree_obj": branch_tree,
# ── 新增 8 类结构特征 ── "select_files": select_files,
"select_files": select_files, "open_directions_detail": open_directions_detail,
"open_directions_detail": open_directions_detail, "has_divide": has_divide,
"has_divide": has_divide, "divide_constants": divide_constants,
"divide_constants": divide_constants, "has_inspect": has_inspect,
"has_inspect": has_inspect, "has_string": has_string,
"has_string": has_string, "perform_patterns": perform_patterns,
"perform_patterns": perform_patterns, "main_loop": main_loop,
"main_loop": main_loop, "if_types": if_types,
"if_types": if_types, "variable_patterns": variable_patterns,
"variable_patterns": variable_patterns, "open_pattern": open_pattern,
"open_pattern": open_pattern,
} }
@@ -693,11 +979,12 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]:
file_sec = parse_file_section(preprocessed) file_sec = parse_file_section(preprocessed)
branch_paths = enum_paths(branch_tree, fields_dict) branch_paths_unfiltered = mcdc_enum_paths(branch_tree, fields_dict)
branch_paths = [(_filter_stop(c), a) for c, a in branch_paths] path_infos = []
for c, a in branch_paths_unfiltered:
filtered_c, term = get_term_type(c)
path_infos.append((filtered_c, a, term))
# Filter: remove constraints whose field doesn't exist in fields_dict.
# Resolve OF-qualified names and subscripts for matching.
_fdict_names = {f['name'] for f in fields_dict} _fdict_names = {f['name'] for f in fields_dict}
def _resolve_field(fn: str) -> str: def _resolve_field(fn: str) -> str:
ufn = fn.upper() ufn = fn.upper()
@@ -708,7 +995,7 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]:
return m.group(1) return m.group(1)
return fn return fn
filtered_paths = [] filtered_paths = []
for cons_list, asgn in branch_paths: for cons_list, asgn, term in path_infos:
clean = [] clean = []
for c in cons_list: for c in cons_list:
if len(c) >= 4: if len(c) >= 4:
@@ -718,12 +1005,11 @@ def generate_data(cobol_source: str, structure: dict = None) -> list[dict]:
clean.append(tuple(c)) clean.append(tuple(c))
else: else:
clean.append(c) clean.append(c)
filtered_paths.append((clean, asgn)) filtered_paths.append((clean, asgn, term))
branch_paths = filtered_paths path_infos = filtered_paths
records, kept_paths = generate_records(branch_paths, fields_dict, assignments, file_sec=file_sec) records, kept_paths, term_types = generate_records(path_infos, fields_dict, assignments, file_sec=file_sec)
# Cross-file KEY alignment for matching programs
if records: if records:
import re as _re import re as _re
proc_upper = (proc_div or "").upper() proc_upper = (proc_div or "").upper()
+45 -3
View File
@@ -44,12 +44,34 @@ def parse_single_condition(text, fields=None):
- Bare: WS-EOF → (WS-EOF, '=', 'Y') - Bare: WS-EOF → (WS-EOF, '=', 'Y')
- NOT bare: NOT WS-EOF → (WS-EOF, '<>', 'Y') - NOT bare: NOT WS-EOF → (WS-EOF, '<>', 'Y')
- NOT arith: A+B NOT = C → ('A+B', '<>', 'C') - NOT arith: A+B NOT = C → ('A+B', '<>', 'C')
- SQLCODE: SQLCODE = 100 → ('SQLCODE', '=', '100')
- SQLSTATE: SQLSTATE <> '02000' → ('SQLSTATE', '<>', '02000')
Returns None for compound (AND/OR) conditions. Returns None for compound (AND/OR) conditions.
""" """
if ' AND ' in text or ' OR ' in text: if ' AND ' in text or ' OR ' in text:
return None return None
text = text.strip() text = text.strip()
field_name = text.split()[0] if text else ''
# SQLCODE special handling
if field_name.upper() == 'SQLCODE':
text_upper = text.upper()
if 'GREATER THAN 0' in text_upper or 'GREATER THAN ZERO' in text_upper:
return ('SQLCODE', '>', '0')
if 'LESS THAN 0' in text_upper:
return ('SQLCODE', '<', '0')
if '= 100' in text_upper:
return ('SQLCODE', '=', '100')
if 'NOT = 100' in text_upper:
return ('SQLCODE', '<>', '100')
# SQLSTATE special handling
if field_name.upper() == 'SQLSTATE':
normalized_sql = re.sub(r'\bNOT\s*=', '<>', text, flags=re.IGNORECASE)
m = re.match(r"SQLSTATE\s*(>=|<=|<>|>|<|=)\s*['\"]?(.+?)['\"]?\s*$", normalized_sql, re.IGNORECASE)
if m:
return ('SQLSTATE', m.group(1), m.group(2).strip().strip("'\""))
# Resolve 88-level condition names # Resolve 88-level condition names
if fields: if fields:
@@ -62,9 +84,9 @@ def parse_single_condition(text, fields=None):
# Bare NOT field reference (no operator): NOT WS-EOF → WS-EOF <> 'Y' # Bare NOT field reference (no operator): NOT WS-EOF → WS-EOF <> 'Y'
if text.upper().startswith('NOT ') and not re.search(r'(>=|<=|<>|>|<|=)', text): if text.upper().startswith('NOT ') and not re.search(r'(>=|<=|<>|>|<|=)', text):
field_name = text[4:].strip() fn = text[4:].strip()
if re.match(r'^[A-Z][A-Z0-9-]*(?:\([^)]*\))?$', field_name, re.IGNORECASE): if re.match(r'^[A-Z][A-Z0-9-]*(?:\([^)]*\))?$', fn, re.IGNORECASE):
return (field_name, '<>', 'Y') return (fn, '<>', 'Y')
# Normalize COBOL NOT-operators: X NOT = Y → X <> Y # Normalize COBOL NOT-operators: X NOT = Y → X <> Y
normalized = text normalized = text
@@ -292,11 +314,31 @@ def satisfying_value(field_info: dict, operator: str, value, want_true: bool) ->
elif operator in ('<>', '!='): elif operator in ('<>', '!='):
other = chr(65 + (ord(base_chr) - 64) % 26) other = chr(65 + (ord(base_chr) - 64) % 26)
return other.ljust(length, other) return other.ljust(length, other)
elif operator == '>':
sv = str(value)[:length].ljust(length)
chars = list(sv)
last = chars[-1]
if last not in '9Zz':
chars[-1] = chr(ord(last) + 1)
return ''.join(chars)
elif operator == '<':
sv = str(value)[:length].ljust(length)
chars = list(sv)
last = chars[-1]
if last == ' ':
pass
elif last in '0Aa':
chars[-1] = ' '
else:
chars[-1] = chr(ord(last) - 1)
return ''.join(chars)
else: else:
if operator in ('=', '=='): if operator in ('=', '=='):
other = chr(65 + (ord(base_chr) - 64) % 26) other = chr(65 + (ord(base_chr) - 64) % 26)
return other.ljust(length, other) return other.ljust(length, other)
elif operator in ('<>', '!='): elif operator in ('<>', '!='):
return base_chr.ljust(length, base_chr) return base_chr.ljust(length, base_chr)
elif operator in ('>', '<'):
return str(value)[:length].ljust(length)
return '0'.zfill(total) return '0'.zfill(total)
+315 -73
View File
@@ -15,16 +15,29 @@ _COBOL_SCOPE_ENDERS = {
'END-SEARCH', 'END-SEARCH',
'ELSE', 'WHEN', 'OTHER', 'ELSE', 'WHEN', 'OTHER',
} }
_COBOL_KEYWORDS = {
'GOBACK', 'EXIT', 'STOP', 'CONTINUE',
'ACCEPT', 'DISPLAY', 'MOVE', 'COMPUTE', 'INITIALIZE',
'ADD', 'SUBTRACT', 'MULTIPLY', 'DIVIDE',
'STRING', 'UNSTRING', 'SET', 'INSPECT',
'OPEN', 'CLOSE', 'READ', 'WRITE', 'REWRITE', 'DELETE', 'START',
'PERFORM', 'CALL', 'IF', 'EVALUATE', 'SEARCH', 'SORT', 'MERGE',
'COMMIT', 'ROLLBACK', 'GO',
}
def scan_paragraphs(raw_lines): def scan_paragraphs(raw_lines, blocked_names=None):
paragraphs = {} paragraphs = {}
i = 0 i = 0
blocked = set()
if blocked_names:
for n in blocked_names:
blocked.add(n.upper())
while i < len(raw_lines): while i < len(raw_lines):
line = raw_lines[i].strip() line = raw_lines[i].strip()
m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', line) m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', line)
sec_m = re.match(r'^([A-Z][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE) sec_m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE)
if m and m.group(1) not in _COBOL_SCOPE_ENDERS: if m and m.group(1) not in _COBOL_SCOPE_ENDERS and m.group(1) not in _COBOL_KEYWORDS and m.group(1) not in blocked:
name = m.group(1) name = m.group(1)
elif sec_m: elif sec_m:
name = sec_m.group(1).upper() name = sec_m.group(1).upper()
@@ -36,9 +49,9 @@ def scan_paragraphs(raw_lines):
while j < len(raw_lines): while j < len(raw_lines):
nline = raw_lines[j].strip() nline = raw_lines[j].strip()
nm = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', nline) nm = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', nline)
if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS: if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS and nm.group(1) not in _COBOL_KEYWORDS and nm.group(1) not in blocked:
break break
if re.match(r'^[A-Z][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE): if re.match(r'^[A-Z0-9][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE):
break break
j += 1 j += 1
paragraphs[name] = (start, j - 1) paragraphs[name] = (start, j - 1)
@@ -46,9 +59,47 @@ def scan_paragraphs(raw_lines):
return paragraphs return paragraphs
def sql_register_virtual_fields(fields_dict: list[dict]) -> list[dict]:
    """Make sure the virtual fields SQLCODE and SQLSTATE are declared.

    For each of the two names that no entry of ``fields_dict`` carries in its
    ``'name'`` key (exact, case-sensitive comparison), a synthetic level-77
    WORKING-STORAGE entry is appended -- SQLCODE first, then SQLSTATE.

    The list is extended in place and the very same list object is returned.
    """
    # Fresh dicts are built on every call so callers never share state.
    templates = (
        {
            'name': 'SQLCODE',
            'level': 77, 'pic': 'S9(9)',
            'pic_info': {'type': 'numeric', 'digits': 9, 'decimal': 0,
                         'length': 4, 'signed': True},
            'section': 'WORKING-STORAGE', 'is_filler': False, 'redefines': None,
            'usage': 'COMP', 'occurs': 0, 'occurs_depending': None,
            'value': None, 'values': None,
        },
        {
            'name': 'SQLSTATE',
            'level': 77, 'pic': 'X(5)',
            'pic_info': {'type': 'alphanumeric', 'length': 5},
            'section': 'WORKING-STORAGE', 'is_filler': False, 'redefines': None,
            'usage': 'DISPLAY', 'occurs': 0, 'occurs_depending': None,
            'value': None, 'values': None,
        },
    )

    missing = []
    for template in templates:
        wanted = template['name']
        if not any(entry['name'] == wanted for entry in fields_dict):
            missing.append(template)

    fields_dict.extend(missing)
    return fields_dict
def build_branch_tree(proc_text, fields=None, full_source=None):
raw_lines = proc_text.split('\n') raw_lines = proc_text.split('\n')
paragraphs = scan_paragraphs(raw_lines) # Collect data names (FD names, record names, field names) to block paragraph detection
blocked_names = set()
if fields:
for f in fields:
if isinstance(f, dict):
blocked_names.add(f['name'].upper())
else:
blocked_names.add(f.name.upper())
# Extract FD names from full source if available (includes DATA DIVISION)
src = full_source or proc_text
for m in re.finditer(r'\bFD\s+(\w[\w-]*)\b', src, re.IGNORECASE):
blocked_names.add(m.group(1).upper())
paragraphs = scan_paragraphs(raw_lines, blocked_names=blocked_names)
first_para_name = None first_para_name = None
first_para_idx = None first_para_idx = None
@@ -169,6 +220,13 @@ class _BrParser:
if m_search: if m_search:
seq.add(self._parse_search(m_search)) seq.add(self._parse_search(m_search))
continue continue
m_exec = re.match(r'^EXEC\s+SQL\s*$', line, re.IGNORECASE)
if m_exec:
sql_block = self._parse_sql_block()
assign_node = self._parse_sql(sql_block)
if assign_node:
seq.add(assign_node)
continue
m = re.match(r'^INITIALIZE\s+', line) m = re.match(r'^INITIALIZE\s+', line)
if m: if m:
init_seq = self._parse_initialize() init_seq = self._parse_initialize()
@@ -192,7 +250,7 @@ class _BrParser:
seq.add(self._parse_call()) seq.add(self._parse_call())
continue continue
m = re.match( m = re.match(
r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS))?\s*$', r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS|SYSIN|COMMAND-LINE|SYSERR|SYSOUT|ENVIRONMENT-NAME|ENVIRONMENT-VALUE))?\s*$',
line, re.IGNORECASE line, re.IGNORECASE
) )
if m: if m:
@@ -211,21 +269,11 @@ class _BrParser:
seq.add(Assign(tgt, info)) seq.add(Assign(tgt, info))
self.advance() self.advance()
# 跳过 READ 语句剩余行(AT END / NOT AT END / END-READ # 跳过 READ 语句剩余行(AT END / NOT AT END / END-READ
# 遇到新的语句关键词时停止,避免贪婪吞咽后续内容
_stmt_boundary = re.compile(
r'^(IF |EVALUATE |PERFORM |SEARCH |INITIALIZE |STRING |'
r'UNSTRING |CALL |ACCEPT |READ |WRITE |REWRITE |SET |'
r'INSPECT |MOVE |COMPUTE |ADD |SUBTRACT |MULTIPLY |DIVIDE |'
r'GO\s+TO |GOBACK |STOP\s+RUN|EXIT\s|CLOSE |OPEN |DISPLAY |'
r'DELETE |START |'
r'END-IF|END-PERFORM|END-EVALUATE|END-READ)', re.IGNORECASE)
while self.pos < len(self.lines): while self.pos < len(self.lines):
cl = self.clean() cl = self.clean()
if cl in ('END-READ', 'END-READ.'): if cl in ('END-READ', 'END-READ.'):
self.advance() self.advance()
break break
if _stmt_boundary.match(cl):
break
self.advance() self.advance()
continue continue
m_set_false = re.match(r'^SET\s+(\w[\w-]*)\s+TO\s+FALSE\s*$', line, re.IGNORECASE) m_set_false = re.match(r'^SET\s+(\w[\w-]*)\s+TO\s+FALSE\s*$', line, re.IGNORECASE)
@@ -366,7 +414,30 @@ class _BrParser:
else: else:
tgt_key = tgt_base tgt_key = tgt_base
src_clean = raw_src.strip("'").strip('"') src_clean = raw_src.strip("'").strip('"')
is_field_name = self.fields and any(f['name'] == src_clean for f in self.fields) # 检测引用修饰 FIELD(start:length)
rm = re.match(r'^(\w[\w-]*)\(\s*(\d+)\s*:\s*(\d+)\s*\)$', src_clean, re.IGNORECASE)
if rm:
base_src = rm.group(1)
refmod_start = int(rm.group(2))
refmod_length = int(rm.group(3))
is_field_name = self.fields and any(
(f['name'] if isinstance(f, dict) else f.name) == base_src
for f in self.fields
)
if is_field_name:
info = {
'type': 'move',
'source_vars': [base_src],
'refmod_start': refmod_start,
'refmod_length': refmod_length,
}
else:
info = {'type': 'move_literal', 'literal': src_clean}
else:
is_field_name = self.fields and any(
(f['name'] if isinstance(f, dict) else f.name) == src_clean
for f in self.fields
)
if is_field_name: if is_field_name:
info = {'type': 'move', 'source_vars': [src_clean]} info = {'type': 'move', 'source_vars': [src_clean]}
else: else:
@@ -648,40 +719,11 @@ class _BrParser:
line = self.clean() line = self.clean()
m = re.match(r'^IF\s+(.+?)(?:THEN)?\s*$', line) m = re.match(r'^IF\s+(.+?)(?:THEN)?\s*$', line)
cond_text = m.group(1).strip() cond_text = m.group(1).strip()
# Truncate at COBOL statement keywords (single-line IF body after condition)
_stmt_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|'
r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|'
r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH)\b')
_stmt_starts = re.compile(_stmt_pat, re.IGNORECASE)
rest = "" # remaining text after condition truncation (single-line IF body)
sm = _stmt_starts.search(cond_text)
if sm:
rest = cond_text[sm.start():]
cond_text = cond_text[:sm.start()]
self.advance() self.advance()
if rest:
rest = rest.strip()
if rest.endswith('.'):
rest = rest[:-1]
# Split on ELSE but keep ELSE as its own line for parse_seq boundary
else_parts = re.split(r'(\s+ELSE\s+)', rest, maxsplit=1, flags=re.IGNORECASE)
parts = [p.strip() for p in else_parts if p.strip()]
insert_parts = []
for p in parts:
if p.upper() == 'ELSE':
insert_parts.append('ELSE')
else:
insert_parts.append(p if '.' in p else p + '.')
for part in reversed(insert_parts):
self.lines.insert(self.pos, part)
# Join continuation lines (multi-line IF conditions) # Join continuation lines (multi-line IF conditions)
_cont_keywords = (r'THEN|ELSE|END-IF|MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|'
r'DIVIDE|STRING|UNSTRING|INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|'
r'READ|WRITE|REWRITE|DELETE|START|INSPECT|SET|IF|GO\b|EXIT\b|'
r'STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH')
while self.pos < len(self.lines): while self.pos < len(self.lines):
peek = self.clean() peek = self.clean()
if re.match(r'^(' + _cont_keywords + r')', peek, re.IGNORECASE): if re.match(r'^(THEN|ELSE|END-IF|EXEC|MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT)', peek, re.IGNORECASE):
break break
if peek.endswith('.'): if peek.endswith('.'):
cond_text += ' ' + peek.rstrip('.') cond_text += ' ' + peek.rstrip('.')
@@ -697,16 +739,8 @@ class _BrParser:
node = BrIf(cond_text) node = BrIf(cond_text)
node.cond_tree = parse_compound_condition(node.condition, self.fields) node.cond_tree = parse_compound_condition(node.condition, self.fields)
node.true_seq = self.parse_seq(['ELSE', 'END-IF']) node.true_seq = self.parse_seq(['ELSE', 'END-IF'])
clean = self.clean() if self.clean() == 'ELSE':
if clean.startswith('ELSE'): self.advance()
self.advance() # consume ELSE keyword
rest = clean[4:].strip() if len(clean) > 4 else ''
# ELSE IF → reinsert IF statement as next line for recursive parse
if rest.upper().startswith('IF '):
self.lines.insert(self.pos, rest)
elif rest:
# Regular ELSE body text on same line as ELSE: reinsert
self.lines.insert(self.pos, rest if '.' in rest else rest + '.')
node.false_seq = self.parse_seq(['END-IF']) node.false_seq = self.parse_seq(['END-IF'])
if self.clean() == 'END-IF': if self.clean() == 'END-IF':
self.advance() self.advance()
@@ -728,13 +762,6 @@ class _BrParser:
m = re.match(r'^WHEN\s+(.+?)\s*$', line) m = re.match(r'^WHEN\s+(.+?)\s*$', line)
if m: if m:
raw_val = m.group(1).strip().strip("'").strip('"') raw_val = m.group(1).strip().strip("'").strip('"')
# Truncate at COBOL statement keywords (single-line WHEN body after condition)
_eval_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|'
r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|'
r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\b|GOBACK|CLOSE|OPEN|SEARCH)\b')
_eval_stmt = re.search(_eval_pat, raw_val, re.IGNORECASE)
if _eval_stmt:
raw_val = raw_val[:_eval_stmt.start()]
self.advance() self.advance()
# Capture multi-line WHEN conditions (AND/OR continuation) # Capture multi-line WHEN conditions (AND/OR continuation)
while self.pos < len(self.lines): while self.pos < len(self.lines):
@@ -848,6 +875,14 @@ class _BrParser:
if um: if um:
condition = um.group(1).strip() condition = um.group(1).strip()
self.advance() self.advance()
# Join continuation lines (AND/OR on next lines)
while self.pos < len(self.lines):
peek = self.clean()
if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE):
condition += ' ' + peek
self.advance()
else:
break
break break
break break
if from_val and by_val and condition: if from_val and by_val and condition:
@@ -894,6 +929,30 @@ class _BrParser:
m = re.match(r'^PERFORM\s+(\w[\w-]*)\s*$', line) m = re.match(r'^PERFORM\s+(\w[\w-]*)\s*$', line)
if m: if m:
target = m.group(1).strip() target = m.group(1).strip()
save_pos = self.pos
condition = None
self.advance()
while self.pos < len(self.lines):
nxt = self.clean()
um = re.match(r'^UNTIL\s+(.+)$', nxt)
if um:
condition = um.group(1).strip()
self.advance()
# Join continuation lines (AND/OR on next lines)
while self.pos < len(self.lines):
peek = self.clean()
if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE):
condition += ' ' + peek
self.advance()
else:
break
break
break
if condition:
node = BrPerform('para_until', target=target, condition=condition)
self._inline_perform(node, target)
return node
self.pos = save_pos
node = BrPerform('para', target=target) node = BrPerform('para', target=target)
self.advance() self.advance()
self._inline_perform(node, target) self._inline_perform(node, target)
@@ -962,12 +1021,18 @@ class _BrParser:
parts = [self.clean()] parts = [self.clean()]
self.advance() self.advance()
while self.pos < len(self.lines): while self.pos < len(self.lines):
peek = self.peek()
cl = self.clean() cl = self.clean()
if cl == 'END-STRING': if cl == 'END-STRING':
self.advance() self.advance()
break break
# Stop when a new COBOL statement keyword is encountered
if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-STRING)', peek, re.IGNORECASE):
break
parts.append(cl) parts.append(cl)
self.advance() self.advance()
if peek.rstrip().endswith('.'):
break
full = ' '.join(parts) full = ' '.join(parts)
m = re.match(r'^STRING\s+(.+)\s+INTO\s+(\w[\w-]*)\s*$', full, re.IGNORECASE | re.DOTALL) m = re.match(r'^STRING\s+(.+)\s+INTO\s+(\w[\w-]*)\s*$', full, re.IGNORECASE | re.DOTALL)
if not m: if not m:
@@ -985,12 +1050,17 @@ class _BrParser:
parts = [self.clean()] parts = [self.clean()]
self.advance() self.advance()
while self.pos < len(self.lines): while self.pos < len(self.lines):
peek = self.peek()
cl = self.clean() cl = self.clean()
if cl == 'END-UNSTRING': if cl == 'END-UNSTRING':
self.advance() self.advance()
break break
if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-UNSTRING)', peek, re.IGNORECASE):
break
parts.append(cl) parts.append(cl)
self.advance() self.advance()
if peek.rstrip().endswith('.'):
break
full = ' '.join(parts) full = ' '.join(parts)
m = re.match(r'^UNSTRING\s+(.+?)\s+INTO\s+(.+?)\s*$', full, re.IGNORECASE | re.DOTALL) m = re.match(r'^UNSTRING\s+(.+?)\s+INTO\s+(.+?)\s*$', full, re.IGNORECASE | re.DOTALL)
if not m: if not m:
@@ -1088,6 +1158,75 @@ class _BrParser:
self.advance() self.advance()
return Assign(tgt, info) return Assign(tgt, info)
# ── EXEC SQL parsing ──
_RE_SELECT_INTO = re.compile(
r'SELECT\s+(.*?)\s+INTO\s+(:\w[\w-]*(?:\s*,\s*:\w[\w-]*(?::\w[\w-]*)?)*)'
r'\s+FROM\s+(\w[\w-]*)',
re.IGNORECASE
)
_RE_WHERE = re.compile(r'\bWHERE\b\s+(.*)', re.IGNORECASE)
def _parse_sql_block(self) -> str:
"""Consume lines from EXEC SQL until END-EXEC. Returns SQL text."""
texts = []
self.advance()
while self.pos < len(self.lines):
line = self.lines[self.pos].rstrip('.')
m = re.match(r'(.*?)END-EXEC\.?\s*$', line, re.IGNORECASE)
if m:
before = m.group(1).strip()
if before:
texts.append(before)
self.advance()
break
texts.append(line)
self.advance()
result = ' '.join(texts)
result = re.sub(r'\s+', ' ', result)
return result
def _parse_sql(self, sql_text: str):
"""Parse SQL text from EXEC SQL block. Returns Assign node or None."""
m = self._RE_SELECT_INTO.search(sql_text)
if not m:
return None
select_list = m.group(1).strip()
into_raw = m.group(2).strip()
from_table = m.group(3).strip().upper()
remaining = sql_text[m.end():].strip()
# Parse INTO variables (handle indicator vars: :host:indicator)
into_vars = []
for v in re.split(r'\s*,\s*', into_raw):
v = v.strip().lstrip(':')
parts = v.split(':')
into_vars.append(parts[0].upper())
if len(parts) > 1:
into_vars.append(parts[1].upper())
# Extract WHERE clause
where_clause = ''
wm = self._RE_WHERE.search(remaining)
if wm:
where_clause = wm.group(1).strip()
info = {
'type': 'exec_sql_select',
'table': from_table,
'select_list': select_list,
'into_vars': into_vars,
'where': where_clause,
'sql_text': sql_text,
}
for var in into_vars:
self.assignments.setdefault(var, []).append(info)
return Assign(into_vars[0], info)
# ── 工具函数 ── # ── 工具函数 ──
@@ -1141,8 +1280,6 @@ def trace_to_root(field_name, assignments, fields, path_assign=None):
asgn = asgn_list asgn = asgn_list
else: else:
asgn_list = assignments[var] asgn_list = assignments[var]
if not asgn_list:
break
asgn = asgn_list[-1] asgn = asgn_list[-1]
if isinstance(asgn_list, list): if isinstance(asgn_list, list):
for a in reversed(asgn_list): for a in reversed(asgn_list):
@@ -1152,6 +1289,8 @@ def trace_to_root(field_name, assignments, fields, path_assign=None):
asgn = a asgn = a
break break
chain.append((var, asgn)) chain.append((var, asgn))
if asgn.get('type') in ('unstring_split',):
break
if not asgn.get('source_vars'): if not asgn.get('source_vars'):
break break
sv = asgn['source_vars'] sv = asgn['source_vars']
@@ -1332,8 +1471,36 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
src = asgn['source_vars'][0] src = asgn['source_vars'][0]
resolved_tgt = _resolve_subscript(tgt, rec) resolved_tgt = _resolve_subscript(tgt, rec)
resolved_src = _resolve_subscript(src, rec) resolved_src = _resolve_subscript(src, rec)
if resolved_src in rec: tgt_children = _init_child_names(resolved_tgt, fields)
rec[resolved_tgt] = rec[resolved_src] if tgt_children:
# Group MOVE: propagate to child fields by position
src_children = _init_child_names(resolved_src, fields)
if src_children:
src_str = ''.join(str(rec.get(c, '')) for c in src_children)
elif resolved_src in rec:
src_str = str(rec[resolved_src])
else:
src_str = ''
if src_str:
rec[resolved_tgt] = src_str
pos = 0
for tgt_c in tgt_children:
child_len = 0
for f in fields:
if f['name'] == tgt_c:
pi = f.get('pic_info', {})
child_len = pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0)
break
if child_len > 0:
rec[tgt_c] = src_str[pos:pos + child_len] if pos < len(src_str) else ('0' if child_len else '')
pos += child_len
elif resolved_src in rec:
src_val = str(rec[resolved_src])
if asgn.get('refmod_start') and asgn.get('refmod_length'):
start = asgn['refmod_start'] - 1
end = start + asgn['refmod_length']
src_val = src_val[start:end]
rec[resolved_tgt] = src_val
# Pass 2: literal MOVE # Pass 2: literal MOVE
for tgt, asgn in flat_list: for tgt, asgn in flat_list:
@@ -1439,9 +1606,7 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
resolved_tgt = _resolve_subscript(tgt, rec) resolved_tgt = _resolve_subscript(tgt, rec)
if resolved_tgt not in rec: if resolved_tgt not in rec:
continue continue
inspect_src = asgn.get('tgt', tgt) src_val = str(rec[resolved_tgt])
resolved_src = _resolve_subscript(inspect_src, rec)
src_val = str(rec.get(resolved_src, ''))
for op_type, params in asgn.get('sub_ops', []): for op_type, params in asgn.get('sub_ops', []):
if op_type == 'tally': if op_type == 'tally':
cv = params['count_var'].upper() cv = params['count_var'].upper()
@@ -1495,6 +1660,10 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
src_var = asgn.get('source_vars', [None])[0] src_var = asgn.get('source_vars', [None])[0]
resolved_src = _resolve_subscript(src_var, rec) if src_var else None resolved_src = _resolve_subscript(src_var, rec) if src_var else None
idx = asgn.get('index', 0) idx = asgn.get('index', 0)
if resolved_src and resolved_src not in rec:
children = _init_child_names(resolved_src, fields)
if children:
resolved_src = children[0]
if resolved_src and resolved_src in rec: if resolved_src and resolved_src in rec:
src_val = str(rec[resolved_src]) src_val = str(rec[resolved_src])
ftype = pi.get('type', 'unknown') ftype = pi.get('type', 'unknown')
@@ -1556,6 +1725,23 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
else: else:
rec[resolved_tgt] = val.ljust(length)[:length] if length else val rec[resolved_tgt] = val.ljust(length)[:length] if length else val
# Pass 9: EXEC SQL SELECT INTO
for tgt, asgn in flat_list:
if asgn.get('type') == 'exec_sql_select':
resolved_tgt = _resolve_subscript(tgt, rec)
if resolved_tgt not in rec:
continue
src_val = rec.get(resolved_tgt, '')
pi = pi_map.get(resolved_tgt, {})
if pi.get('type') == 'numeric':
total = pi.get('digits', 0) + pi.get('decimal', 0)
if total > 0:
rec[resolved_tgt] = str(src_val).zfill(total)
elif pi.get('type') in ('alphanumeric', 'alphabetic'):
length = pi.get('length', 0)
if length > 0:
rec[resolved_tgt] = str(src_val).ljust(length)[:length]
# Pass 8: SET var TO TRUE (88-level) # Pass 8: SET var TO TRUE (88-level)
for tgt, asgn in flat_list: for tgt, asgn in flat_list:
if asgn['type'] == 'set_true': if asgn['type'] == 'set_true':
@@ -1649,6 +1835,13 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None)
elif atype == 'write_from': elif atype == 'write_from':
if tgt_base in counts: if tgt_base in counts:
counts[tgt_base]['read'] += 1 counts[tgt_base]['read'] += 1
elif atype == 'exec_sql_select':
if tgt_base in counts:
counts[tgt_base]['write'] += 1
for v in node.source_info.get('into_vars', []):
v_base = _basename(v)
if v_base in counts:
counts[v_base]['write'] += 1
elif atype == 'set_true': elif atype == 'set_true':
if tgt_base in counts: if tgt_base in counts:
counts[tgt_base]['write'] += 1 counts[tgt_base]['write'] += 1
@@ -1705,3 +1898,52 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None)
if name not in result: if name not in result:
result[name] = role result[name] = role
return result return result
# ── 多 WRITE 检测 ──
def _collect_write_fds(node, fds_set, field_to_fd):
    """Walk a branch tree and gather the FD names targeted by WRITE nodes.

    For every Assign whose ``source_info['type']`` is ``'write_bare'`` or
    ``'write_from'`` and whose target is a key of ``field_to_fd``, the mapped
    FD name is added to ``fds_set``.  ``fds_set`` is mutated in place and
    nothing is returned.  Nodes of any other kind than the handled tree
    classes are ignored.
    """
    if isinstance(node, Assign):
        stmt_kind = node.source_info.get('type', '')
        if stmt_kind not in ('write_bare', 'write_from'):
            return
        record = node.target
        if record in field_to_fd:
            fds_set.add(field_to_fd[record])
        return

    # Select the child sequences to descend into for this node type.
    if isinstance(node, BrSeq):
        subtrees = node.children
    elif isinstance(node, BrIf):
        subtrees = [node.true_seq, node.false_seq]
    elif isinstance(node, BrEval):
        subtrees = [seq for _, seq in node.when_list] + [node.other_seq]
    elif isinstance(node, BrPerform):
        subtrees = [node.body_seq]
    elif isinstance(node, BrSearch):
        subtrees = [node.at_end_seq] + [seq for _, seq in node.when_list]
    else:
        return

    for subtree in subtrees:
        _collect_write_fds(subtree, fds_set, field_to_fd)
def _find_multi_write_fds(tree, field_to_fd):
    """Return the FD names written both before the main loop and inside it.

    The main loop is the last top-level child of ``tree`` that is a BrPerform
    with perf_type 'until', 'para_until', 'varying' or 'para_varying'.
    Children preceding it form the INIT part.  An empty set is returned when
    ``tree`` is not a BrSeq or when it has no such loop.
    """
    if not isinstance(tree, BrSeq):
        return set()

    loop_kinds = ('until', 'para_until', 'varying', 'para_varying')
    loop_positions = [
        idx for idx, child in enumerate(tree.children)
        if isinstance(child, BrPerform) and child.perf_type in loop_kinds
    ]
    if not loop_positions:
        return set()
    main_loop_idx = loop_positions[-1]

    # FDs written by statements that precede the main loop.
    before_loop = set()
    for child in tree.children[:main_loop_idx]:
        _collect_write_fds(child, before_loop, field_to_fd)

    # FDs written anywhere inside the main loop itself.
    inside_loop = set()
    _collect_write_fds(tree.children[main_loop_idx], inside_loop, field_to_fd)

    return before_loop & inside_loop
+46 -57
View File
@@ -8,6 +8,7 @@ from pathlib import Path
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .models import BrSeq, BrIf, BrEval, BrPerform, BrSearch, CondLeaf from .models import BrSeq, BrIf, BrEval, BrPerform, BrSearch, CondLeaf
from .cond import parse_single_condition, parse_compound_condition, is_field, collect_leaves, evaluate_tree from .cond import parse_single_condition, parse_compound_condition, is_field, collect_leaves, evaluate_tree
from .gcov import mark_from_gcov
# ── 数据模型 ── # ── 数据模型 ──
@@ -190,11 +191,14 @@ def _mark_if(dp, cons):
if _match_leaf(c, leaf): if _match_leaf(c, leaf):
assignment[leaf] = c[3] assignment[leaf] = c[3]
break break
if len(assignment) == len(dp.cond_leaves): if assignment:
try:
if evaluate_tree(dp.cond_tree, assignment): if evaluate_tree(dp.cond_tree, assignment):
dp.active_branches.add('T') dp.active_branches.add('T')
else: else:
dp.active_branches.add('F') dp.active_branches.add('F')
except KeyError:
pass
else: else:
matched = 0 matched = 0
for leaf in dp.leaves: for leaf in dp.leaves:
@@ -253,6 +257,15 @@ def _mark_eval(dp, cons, fields=None):
dp.active_branches.add(name) dp.active_branches.add(name)
elif c[0] == dp.label and c[1] == 'not_in': elif c[0] == dp.label and c[1] == 'not_in':
dp.active_branches.add('OTHER') dp.active_branches.add('OTHER')
thru_lows = {c[2] for c in cons if c[0] == dp.label and c[1] == '>=' and c[3]}
thru_highs = {c[2] for c in cons if c[0] == dp.label and c[1] == '<=' and c[3]}
if thru_lows or thru_highs:
for when_val, _ in dp.when_list:
thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(when_val), re.IGNORECASE)
if thru_m and thru_m.group(1) in thru_lows and thru_m.group(2) in thru_highs:
name = f"WHEN {when_val}"
if name in dp.branch_names:
dp.active_branches.add(name)
def _mark_search(dp, cons, fields=None): def _mark_search(dp, cons, fields=None):
@@ -309,11 +322,14 @@ def _mark_perform(dp, cons):
if _match_leaf(c, leaf): if _match_leaf(c, leaf):
assignment[leaf] = c[3] assignment[leaf] = c[3]
break break
if len(assignment) == len(dp.cond_leaves): if assignment:
try:
if evaluate_tree(dp.cond_tree, assignment): if evaluate_tree(dp.cond_tree, assignment):
dp.active_branches.add('Skip') dp.active_branches.add('Skip')
else: else:
dp.active_branches.add('Enter') dp.active_branches.add('Enter')
except KeyError:
pass
else: else:
for c in cons: for c in cons:
if c[0] == dp.label or any(c[0] == f for f in _get_fields_in_cond(dp.label)): if c[0] == dp.label or any(c[0] == f for f in _get_fields_in_cond(dp.label)):
@@ -330,7 +346,6 @@ def _get_fields_in_cond(cond_text):
# ── 行号定位(基于原始源文本)── # ── 行号定位(基于原始源文本)──
def locate_decision_lines(decision_points, raw_source): def locate_decision_lines(decision_points, raw_source):
"""在原始源文本中搜索每个决策点的近似行号"""
lines = raw_source.upper().splitlines() lines = raw_source.upper().splitlines()
for dp in decision_points: for dp in decision_points:
patterns = _build_search_patterns(dp) patterns = _build_search_patterns(dp)
@@ -344,7 +359,6 @@ def locate_decision_lines(decision_points, raw_source):
def _normalize(text): def _normalize(text):
"""标准化条件文本用于比较:去多余空白、标准化引号"""
t = re.sub(r'\s+', ' ', text).strip() t = re.sub(r'\s+', ' ', text).strip()
t = t.replace('"', "'") t = t.replace('"', "'")
return t return t
@@ -360,14 +374,13 @@ def _build_search_patterns(dp):
texts.append((r'\bUNTIL\b', dp.condition if hasattr(dp, 'condition') else dp.label texts.append((r'\bUNTIL\b', dp.condition if hasattr(dp, 'condition') else dp.label
if dp.label else '')) if dp.label else ''))
else: else:
return [r'$^'] # 永不匹配 return [r'$^']
patterns = [] patterns = []
for keyword, condition in texts: for keyword, condition in texts:
if not condition: if not condition:
continue continue
norm_cond = _normalize(condition) norm_cond = _normalize(condition)
# 转义正则特殊字符,但保留空格(替换为\s+)
esc = re.escape(norm_cond) esc = re.escape(norm_cond)
esc = esc.replace(r'\ ', r'\s+') esc = esc.replace(r'\ ', r'\s+')
esc = esc.replace(r'\'', r"['\"]") esc = esc.replace(r'\'', r"['\"]")
@@ -411,7 +424,6 @@ _DETAIL_HTML = '''<!DOCTYPE html>
}} }}
.section h2 {{ font-size: 16px; font-weight: 600; color: #1a237e; margin-bottom: 16px; padding-bottom: 8px; border-bottom: 2px solid #e8eaf6; }} .section h2 {{ font-size: 16px; font-weight: 600; color: #1a237e; margin-bottom: 16px; padding-bottom: 8px; border-bottom: 2px solid #e8eaf6; }}
/* 统计卡片行 */
.stats-row {{ display: flex; gap: 16px; flex-wrap: wrap; }} .stats-row {{ display: flex; gap: 16px; flex-wrap: wrap; }}
.stat-card {{ .stat-card {{
flex: 1; min-width: 140px; background: #f5f7fa; border-radius: 8px; padding: 14px 18px; flex: 1; min-width: 140px; background: #f5f7fa; border-radius: 8px; padding: 14px 18px;
@@ -430,7 +442,6 @@ _DETAIL_HTML = '''<!DOCTYPE html>
.dot-red {{ background: #ffcdd2; }} .dot-red {{ background: #ffcdd2; }}
.dot-amber {{ background: #fff9c4; }} .dot-amber {{ background: #fff9c4; }}
/* 进度条 */
.prog-bar-detail {{ .prog-bar-detail {{
width: 100%; height: 12px; border-radius: 6px; background: #ffcdd2; overflow: hidden; margin: 10px 0 6px 0; width: 100%; height: 12px; border-radius: 6px; background: #ffcdd2; overflow: hidden; margin: 10px 0 6px 0;
}} }}
@@ -440,20 +451,17 @@ _DETAIL_HTML = '''<!DOCTYPE html>
.prog-fill-detail.amber {{ background: linear-gradient(90deg, #ffca28, #ff8f00); }} .prog-fill-detail.amber {{ background: linear-gradient(90deg, #ffca28, #ff8f00); }}
.prog-fill-detail.red {{ background: linear-gradient(90deg, #ef5350, #ff1744); }} .prog-fill-detail.red {{ background: linear-gradient(90deg, #ef5350, #ff1744); }}
/* 表格 */
table {{ width: 100%; border-collapse: collapse; table-layout: fixed; }} table {{ width: 100%; border-collapse: collapse; table-layout: fixed; }}
th, td {{ padding: 10px 14px; text-align: left; border-bottom: 1px solid #eceff1; word-break: break-all; }} th, td {{ padding: 10px 14px; text-align: left; border-bottom: 1px solid #eceff1; word-break: break-all; }}
th {{ background: #f5f7fa; font-weight: 600; font-size: 12px; color: #78909c; text-transform: uppercase; letter-spacing: 0.5px; }} th {{ background: #f5f7fa; font-weight: 600; font-size: 12px; color: #78909c; text-transform: uppercase; letter-spacing: 0.5px; }}
tbody tr:hover {{ background: #e8eaf6; }} tbody tr:hover {{ background: #e8eaf6; }}
tbody tr:last-child td {{ border-bottom: none; }} tbody tr:last-child td {{ border-bottom: none; }}
/* 决策表列宽 */
.dp-table th:nth-child(1), .dp-table td:nth-child(1) {{ width: 50px; }} .dp-table th:nth-child(1), .dp-table td:nth-child(1) {{ width: 50px; }}
.dp-table th:nth-child(2), .dp-table td:nth-child(2) {{ width: 70px; }} .dp-table th:nth-child(2), .dp-table td:nth-child(2) {{ width: 70px; }}
.dp-table th:nth-child(3), .dp-table td:nth-child(3) {{ width: 50px; }} .dp-table th:nth-child(3), .dp-table td:nth-child(3) {{ width: 50px; }}
.dp-table th:nth-child(5), .dp-table td:nth-child(5) {{ width: 160px; }} .dp-table th:nth-child(5), .dp-table td:nth-child(5) {{ width: 160px; }}
/* 叶条件表列宽 */
.leaf-table th:nth-child(1), .leaf-table td:nth-child(1) {{ width: 110px; }} .leaf-table th:nth-child(1), .leaf-table td:nth-child(1) {{ width: 110px; }}
.leaf-table th:nth-child(2), .leaf-table td:nth-child(2) {{ width: 60px; }} .leaf-table th:nth-child(2), .leaf-table td:nth-child(2) {{ width: 60px; }}
.leaf-table th:nth-child(4), .leaf-table td:nth-child(4), .leaf-table th:nth-child(4), .leaf-table td:nth-child(4),
@@ -468,7 +476,6 @@ _DETAIL_HTML = '''<!DOCTYPE html>
.cond-ok {{ color: #00c853; }} .cond-ok {{ color: #00c853; }}
.cond-miss {{ color: #ff5252; }} .cond-miss {{ color: #ff5252; }}
/* 源码 */
.source-section {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }} .source-section {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }}
.source-line {{ display: flex; padding: 1px 0; }} .source-line {{ display: flex; padding: 1px 0; }}
.source-line:hover {{ background: #f5f5f5; }} .source-line:hover {{ background: #f5f5f5; }}
@@ -534,20 +541,22 @@ _DETAIL_HTML = '''<!DOCTYPE html>
{source_section} {source_section}
{source_note}
</div> </div>
</body> </body>
</html>''' </html>'''
def generate_html_report(decision_points, leaf_stats, source_lines, outpath, def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
filename='', index_relpath=None, covered_lines=None): filename='', index_relpath=None, covered_lines=None,
source_note=''):
title = f"覆盖率报告 — {filename}" if filename else "覆盖率报告" title = f"覆盖率报告 — {filename}" if filename else "覆盖率报告"
total_branches = sum(len(dp.branch_names) for dp in decision_points) total_branches = sum(len(dp.branch_names) for dp in decision_points)
covered_branches = sum(len(dp.active_branches) for dp in decision_points) covered_branches = sum(len(dp.active_branches) for dp in decision_points)
implied_branches = sum(len(dp.implied_branches) for dp in decision_points) implied_branches = sum(len(dp.implied_branches) for dp in decision_points)
if covered_lines: if covered_lines:
# 无分支程序:隐式 100%
total_branches = max(total_branches, 1) total_branches = max(total_branches, 1)
covered_branches = max(covered_branches, 1) covered_branches = max(covered_branches, 1)
@@ -555,15 +564,13 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
covered_leaves = (sum(1 for l in leaf_stats if l.covered_true) + covered_leaves = (sum(1 for l in leaf_stats if l.covered_true) +
sum(1 for l in leaf_stats if l.covered_false)) sum(1 for l in leaf_stats if l.covered_false))
# 计算数值 is_implicit = bool(covered_lines)
is_implicit = bool(covered_lines) # 无分支程序,隐式 100%
dec_pct_val = (covered_branches / total_branches * 100) if total_branches else 0 dec_pct_val = (covered_branches / total_branches * 100) if total_branches else 0
dec_pct_text = "100%" if is_implicit else (f"{dec_pct_val:.1f}%" if total_branches else "") dec_pct_text = "100%" if is_implicit else (f"{dec_pct_val:.1f}%" if total_branches else "")
dec_frac = "全部覆盖" if is_implicit else (f"{covered_branches}/{total_branches}" if total_branches else "") dec_frac = "全部覆盖" if is_implicit else (f"{covered_branches}/{total_branches}" if total_branches else "")
cond_frac = f"{covered_leaves}/{total_leaves}" if total_leaves else "" cond_frac = f"{covered_leaves}/{total_leaves}" if total_leaves else ""
implied_text = f'+{implied_branches - covered_branches} 推断)' if implied_branches > covered_branches else '' implied_text = f'+{implied_branches - covered_branches} 推断)' if implied_branches > covered_branches else ''
# 颜色
if is_implicit or not total_branches or dec_pct_val >= 100: if is_implicit or not total_branches or dec_pct_val >= 100:
dec_val_cls = 'val-green' dec_val_cls = 'val-green'
bar_cls = '' bar_cls = ''
@@ -581,7 +588,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
else: else:
cond_val_cls = 'val-red' cond_val_cls = 'val-red'
# 决策点表格
if decision_points: if decision_points:
dp_rows = [] dp_rows = []
for dp in decision_points: for dp in decision_points:
@@ -608,7 +614,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
else: else:
decision_table = '' decision_table = ''
# 叶条件表格
if leaf_stats: if leaf_stats:
leaf_rows = [] leaf_rows = []
for leaf in leaf_stats: for leaf in leaf_stats:
@@ -627,7 +632,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
else: else:
leaf_table = '' leaf_table = ''
# 源码标注
if source_lines: if source_lines:
line_cov = {} line_cov = {}
for dp in decision_points: for dp in decision_points:
@@ -643,7 +647,6 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
else: else:
line_cov[dp.source_line].append('hl-amber') line_cov[dp.source_line].append('hl-amber')
# 无分支程序:所有 PD 行标记为已覆盖
if covered_lines: if covered_lines:
for ln in covered_lines: for ln in covered_lines:
line_cov.setdefault(ln, []).append('hl-green') line_cov.setdefault(ln, []).append('hl-green')
@@ -677,6 +680,7 @@ def generate_html_report(decision_points, leaf_stats, source_lines, outpath,
leaf_table=leaf_table, leaf_table=leaf_table,
source_section=source_section, source_section=source_section,
dp_count_text=('' if is_implicit else str(len(decision_points))), dp_count_text=('' if is_implicit else str(len(decision_points))),
source_note=source_note,
) )
outpath = Path(outpath) outpath = Path(outpath)
@@ -699,7 +703,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
background: #f0f2f5; color: #37474f; font-size: 14px; line-height: 1.6; background: #f0f2f5; color: #37474f; font-size: 14px; line-height: 1.6;
}} }}
/* 顶栏 */
.topbar {{ .topbar {{
background: linear-gradient(135deg, #1a237e, #283593); background: linear-gradient(135deg, #1a237e, #283593);
color: #fff; padding: 18px 32px; color: #fff; padding: 18px 32px;
@@ -711,7 +714,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.container {{ max-width: 1200px; margin: 0 auto; padding: 28px 24px; }} .container {{ max-width: 1200px; margin: 0 auto; padding: 28px 24px; }}
/* 统计卡片 */
.cards {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 16px; margin-bottom: 28px; }} .cards {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 16px; margin-bottom: 28px; }}
.card {{ .card {{
background: #fff; border-radius: 10px; padding: 20px 22px; background: #fff; border-radius: 10px; padding: 20px 22px;
@@ -725,7 +727,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.num-red {{ color: #ff1744; }} .num-red {{ color: #ff1744; }}
.num-blue {{ color: #1a237e; }} .num-blue {{ color: #1a237e; }}
/* 图表行 */
.charts-row {{ .charts-row {{
display: flex; gap: 32px; justify-content: center; flex-wrap: wrap; display: flex; gap: 32px; justify-content: center; flex-wrap: wrap;
background: #fff; border-radius: 10px; padding: 28px 20px; background: #fff; border-radius: 10px; padding: 28px 20px;
@@ -744,7 +745,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.legend .dot-red {{ background: #ff5252; }} .legend .dot-red {{ background: #ff5252; }}
.legend .dot-amber {{ background: #ffd740; }} .legend .dot-amber {{ background: #ffd740; }}
/* 工具栏 */
.toolbar {{ .toolbar {{
display: flex; justify-content: space-between; align-items: center; display: flex; justify-content: space-between; align-items: center;
margin-bottom: 14px; flex-wrap: wrap; gap: 10px; margin-bottom: 14px; flex-wrap: wrap; gap: 10px;
@@ -764,7 +764,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.toolbar .sort-btn:hover {{ background: #eceff1; }} .toolbar .sort-btn:hover {{ background: #eceff1; }}
.toolbar .sort-btn.active {{ background: #e8eaf6; border-color: #3f51b5; color: #1a237e; font-weight: 500; }} .toolbar .sort-btn.active {{ background: #e8eaf6; border-color: #3f51b5; color: #1a237e; font-weight: 500; }}
/* 表格 */
.table-wrap {{ .table-wrap {{
background: #fff; border-radius: 10px; overflow: hidden; background: #fff; border-radius: 10px; overflow: hidden;
box-shadow: 0 1px 4px rgba(0,0,0,0.06); box-shadow: 0 1px 4px rgba(0,0,0,0.06);
@@ -789,7 +788,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.prog-name a {{ color: #283593; text-decoration: none; }} .prog-name a {{ color: #283593; text-decoration: none; }}
.prog-name a:hover {{ text-decoration: underline; color: #1a237e; }} .prog-name a:hover {{ text-decoration: underline; color: #1a237e; }}
/* 进度条 */
.prog-wrap {{ .prog-wrap {{
display: inline-flex; align-items: center; gap: 10px; width: 100%; display: inline-flex; align-items: center; gap: 10px; width: 100%;
}} }}
@@ -812,7 +810,6 @@ _INDEX_HTML = '''<!DOCTYPE html>
.prog-fill.full {{ border-radius: 10px; }} .prog-fill.full {{ border-radius: 10px; }}
.prog-text {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; white-space: nowrap; min-width: 48px; }} .prog-text {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; white-space: nowrap; min-width: 48px; }}
/* 状态徽标 */
.badge {{ .badge {{
display: inline-block; padding: 3px 10px; border-radius: 12px; display: inline-block; padding: 3px 10px; border-radius: 12px;
font-size: 12px; font-weight: 600; letter-spacing: 0.3px; font-size: 12px; font-weight: 600; letter-spacing: 0.3px;
@@ -821,10 +818,8 @@ _INDEX_HTML = '''<!DOCTYPE html>
.badge-warn {{ background: #fff8e1; color: #e65100; }} .badge-warn {{ background: #fff8e1; color: #e65100; }}
.badge-fail {{ background: #ffebee; color: #c62828; }} .badge-fail {{ background: #ffebee; color: #c62828; }}
/* 条件覆盖列 */
.cond-cell {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }} .cond-cell {{ font-family: "Cascadia Code","Fira Code","JetBrains Mono",Consolas,monospace; font-size: 13px; }}
/* 响应式 */
@media (max-width: 680px) {{ @media (max-width: 680px) {{
.topbar {{ flex-direction: column; align-items: flex-start; gap: 6px; padding: 14px 18px; }} .topbar {{ flex-direction: column; align-items: flex-start; gap: 6px; padding: 14px 18px; }}
.container {{ padding: 16px 12px; }} .container {{ padding: 16px 12px; }}
@@ -968,7 +963,6 @@ function filterTable() {{
def _ring_svg(pct, color_stops): def _ring_svg(pct, color_stops):
"""生成 SVG 圆环 HTML。pct: 0-100 浮点数。"""
r = 54 r = 54
circ = 2 * 3.14159265 * r circ = 2 * 3.14159265 * r
offset = circ * (1 - pct / 100) if pct > 0 else circ offset = circ * (1 - pct / 100) if pct > 0 else circ
@@ -995,7 +989,6 @@ def _ring_svg(pct, color_stops):
def generate_coverage_index(programs, outdir): def generate_coverage_index(programs, outdir):
"""生成覆盖率总括索引页。"""
from datetime import datetime from datetime import datetime
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M') timestamp = datetime.now().strftime('%Y-%m-%d %H:%M')
@@ -1038,7 +1031,6 @@ def generate_coverage_index(programs, outdir):
cond_text = f"{cc}/{tc}" if tc else "" cond_text = f"{cc}/{tc}" if tc else ""
bar_pct = int(pct_dec) bar_pct = int(pct_dec)
# 进度条颜色
if imp or pct_dec >= 100: if imp or pct_dec >= 100:
bar_cls = '' bar_cls = ''
elif pct_dec >= 80: elif pct_dec >= 80:
@@ -1046,7 +1038,6 @@ def generate_coverage_index(programs, outdir):
else: else:
bar_cls = ' red' bar_cls = ' red'
# 状态徽标
if tb == 0 or (cb == tb and not (ib > cb)): if tb == 0 or (cb == tb and not (ib > cb)):
badge = '<span class="badge badge-pass">&#10003; 完全</span>' badge = '<span class="badge badge-pass">&#10003; 完全</span>'
elif cb == tb and ib > cb: elif cb == tb and ib > cb:
@@ -1056,7 +1047,6 @@ def generate_coverage_index(programs, outdir):
else: else:
badge = '<span class="badge badge-fail">&#10007; 欠缺</span>' badge = '<span class="badge badge-fail">&#10007; 欠缺</span>'
# 条件覆盖数字颜色
if tc: if tc:
cond_pct = cc / tc * 100 cond_pct = cc / tc * 100
cond_color = 'num-green' if cond_pct == 100 else ('num-amber' if cond_pct >= 80 else 'num-red') cond_color = 'num-green' if cond_pct == 100 else ('num-amber' if cond_pct >= 80 else 'num-red')
@@ -1107,7 +1097,6 @@ def generate_coverage_index(programs, outdir):
# ── PROCEDURE DIVISION 行范围定位(用于无分支程序标记)── # ── PROCEDURE DIVISION 行范围定位(用于无分支程序标记)──
def _find_proc_range(raw_source: str): def _find_proc_range(raw_source: str):
"""返回 PROCEDURE DIVISION 的行范围 (start_line, end_line) 1-indexed,或 None。"""
lines = raw_source.splitlines() lines = raw_source.splitlines()
proc_start = None proc_start = None
for i, line in enumerate(lines): for i, line in enumerate(lines):
@@ -1116,26 +1105,36 @@ def _find_proc_range(raw_source: str):
break break
if proc_start is None: if proc_start is None:
return None return None
# 找下一个 DIVISION 作为结束边界(或文件尾)
for i in range(proc_start, len(lines)): for i in range(proc_start, len(lines)):
if re.search(r'(IDENTIFICATION|DATA|ENVIRONMENT)\s+DIVISION', lines[i].upper()): if re.search(r'(IDENTIFICATION|DATA|ENVIRONMENT)\s+DIVISION', lines[i].upper()):
return (proc_start, i) # 不包含下一个 DIVISION return (proc_start, i)
return (proc_start, len(lines) + 1) return (proc_start, len(lines) + 1)
# ── 接入入口 ── # ── 接入入口 ──
def run_coverage(branch_tree, branch_paths_with_assigns, fields, def run_coverage(branch_tree, branch_paths_with_assigns, fields,
raw_source, output_prefix, index_relpath=None): raw_source, output_prefix, index_relpath=None,
"""完整覆盖率流程:收集 → 标记 → 定位 → 输出。 gcov_data=None):
Returns:
dict: 汇总数据,用于总括页聚合
"""
decision_points, leaf_stats = collect_decision_points(branch_tree, fields) decision_points, leaf_stats = collect_decision_points(branch_tree, fields)
mark_coverage(decision_points, leaf_stats, branch_paths_with_assigns, fields) mark_coverage(decision_points, leaf_stats, branch_paths_with_assigns, fields)
if gcov_data:
mark_from_gcov(decision_points, gcov_data, branch_tree)
for ls in leaf_stats:
ls.covered_true = False
ls.covered_false = False
_source_note = ''
if gcov_data:
_source_note = (
'<div style="margin-top:16px;font-size:12px;color:#90a4ae;'
'text-align:center;border-top:1px solid #eceff1;padding-top:12px;">'
'覆盖率基于 gcov 运行时数据'
'</div>'
)
if raw_source: if raw_source:
locate_decision_lines(decision_points, raw_source) locate_decision_lines(decision_points, raw_source)
@@ -1146,7 +1145,6 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields,
sum(1 for l in leaf_stats if l.covered_false)) sum(1 for l in leaf_stats if l.covered_false))
leaf_total = len(leaf_stats) * 2 leaf_total = len(leaf_stats) * 2
# 无决策点但有路径 → PROCEDURE DIVISION 全部覆盖
covered_lines = set() covered_lines = set()
if total == 0 and branch_paths_with_assigns and raw_source: if total == 0 and branch_paths_with_assigns and raw_source:
proc_range = _find_proc_range(raw_source) proc_range = _find_proc_range(raw_source)
@@ -1161,9 +1159,9 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields,
f"{output_prefix}_coverage.html", f"{output_prefix}_coverage.html",
Path(output_prefix).stem, Path(output_prefix).stem,
index_relpath=index_relpath, index_relpath=index_relpath,
covered_lines=covered_lines) covered_lines=covered_lines,
source_note=_source_note)
# 控制台摘要
if total or leaf_total: if total or leaf_total:
logger.info(f"\n=== 分支覆盖率 ===") logger.info(f"\n=== 分支覆盖率 ===")
if covered_lines and not decision_points: if covered_lines and not decision_points:
@@ -1194,7 +1192,7 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields,
implicit_100 = bool(covered_lines) implicit_100 = bool(covered_lines)
return { return {
'name': Path(output_prefix).stem if output_prefix else '', 'name': Path(output_prefix).stem if output_prefix else '',
'detail_relpath': ('../' + Path(output_prefix).stem + '_coverage.html' 'detail_relpath': (Path(output_prefix).stem + '_coverage.html'
if output_prefix else ''), if output_prefix else ''),
'total_branches': total, 'total_branches': total,
'covered_branches': covered, 'covered_branches': covered,
@@ -1208,15 +1206,6 @@ def run_coverage(branch_tree, branch_paths_with_assigns, fields,
def check_coverage(structure: dict, test_records: list[dict]) -> dict: def check_coverage(structure: dict, test_records: list[dict]) -> dict:
"""报告 COBOL 源码的静态分支结构信息。
注意: 静态分析无法精确判断每条测试数据运行时覆盖了哪些分支。
精确的路径追踪依赖 gcov(Phase 3)。此处仅报告总分支数和记录生成情况。
Returns:
dict with: paragraph_rate, branch_rate, decision_rate, total_branches,
total_paragraphs, records_count, note
"""
total_paragraphs = structure.get("total_paragraphs", 0) total_paragraphs = structure.get("total_paragraphs", 0)
total_branches = structure.get("total_branches", 0) total_branches = structure.get("total_branches", 0)
decision_points = structure.get("decision_points", []) decision_points = structure.get("decision_points", [])
+441 -37
View File
@@ -8,12 +8,52 @@ from .core import trace_to_root, invert_through_chain, propagate_assignments, _b
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_STOP = ('__STOP__', '', None, True) _STOP_EXIT_PERFORM = ('__STOP_EXIT_PERFORM__', '', None, True)
_MAX_PATHS = 500 _STOP_SENTINEL = ('__STOP__', '', None, True)
_ABEND_SENTINEL = ('__ABEND__', '', None, True)
_SENTINELS_ALL = {_STOP_EXIT_PERFORM, _STOP_SENTINEL, _ABEND_SENTINEL}
_ABEND_PROGRAMS = {'ABENDPGM'}
def extend_abend_programs(names: list[str]):
    """Register extra program names whose CALL terminates a path as an abend.

    Each name is upper-cased and added to the module-level
    ``_ABEND_PROGRAMS`` set, which ``enum_paths`` consults when it meets a
    ``CallNode``.

    Args:
        names: Program names to register. A single string is accepted as one
            name (instead of being iterated character by character); an
            empty or ``None`` value is a no-op.
    """
    if not names:
        return
    if isinstance(names, str):
        names = [names]
    _ABEND_PROGRAMS.update(n.upper() for n in names)
_MAX_PATHS = 10000
def _is_sentinel(c):
    """Return True when *c* is one of the module's path-termination sentinels.

    The comparison is by identity, not equality, so an ordinary constraint
    tuple that merely compares equal to a sentinel is not mistaken for one.
    """
    return (
        c is _STOP_EXIT_PERFORM
        or c is _STOP_SENTINEL
        or c is _ABEND_SENTINEL
    )
def _hashable_cons(cons):
"""将约束列表转为可哈希形式(列表值转tuple)用于签名去重。"""
result = []
for c in cons:
if len(c) == 4:
field, op, val, want = c
if isinstance(val, list):
val = tuple(val)
result.append((field, op, val, want))
else:
result.append(c)
return result
def _filter_stop(cons):
    """Legacy helper: return *cons* with every sentinel marker removed.

    Kept for older test code; all sentinel kinds recognised by
    ``_is_sentinel`` are stripped.
    """
    return [c for c in cons if not _is_sentinel(c)]
def get_term_type(cons):
    """Split a constraint list into real constraints and a termination type.

    Args:
        cons: Iterable of constraints, possibly containing sentinel markers.

    Returns:
        ``(filtered_cons, term_type)`` where ``filtered_cons`` is the input
        without any sentinel and ``term_type`` is ``'abend'`` if the abend
        sentinel was present, otherwise ``'normal'``.
    """
    remaining = []
    term = 'normal'
    for c in cons:
        if c is _ABEND_SENTINEL:
            term = 'abend'
        elif not _is_sentinel(c):
            # Other sentinels are dropped without changing the term type.
            remaining.append(c)
    return remaining, term
def _cap_paths(paths): def _cap_paths(paths):
@@ -29,11 +69,11 @@ def _cap_paths_fair(new_active, child_paths):
k = len(child_paths) k = len(child_paths)
if k <= 1: if k <= 1:
return new_active[:_MAX_PATHS] return new_active[:_MAX_PATHS]
# 分离 STOP 路径(不参与组合,直接保留) # 分离 sentinel 路径(不参与组合,直接保留)
stop_paths = [(p, a) for p, a in new_active if any(c is _STOP for c in p)] stop_paths = [(p, a) for p, a in new_active if any(_is_sentinel(c) for c in p)]
combined = [(p, a) for p, a in new_active if not any(c is _STOP for c in p)] combined = [(p, a) for p, a in new_active if not any(_is_sentinel(c) for c in p)]
n_pred = len(combined) // k n_pred = len(combined) // k
result = list(stop_paths) result = []
if n_pred <= 1: if n_pred <= 1:
result.extend(combined[:_MAX_PATHS - len(result)]) result.extend(combined[:_MAX_PATHS - len(result)])
return result[:_MAX_PATHS] return result[:_MAX_PATHS]
@@ -75,24 +115,29 @@ def enum_paths(node, fields):
for child in node.children: for child in node.children:
child_paths = _cap_paths(enum_paths(child, fields)) child_paths = _cap_paths(enum_paths(child, fields))
if not child_paths: if not child_paths:
break continue
new_active = [] new_active = []
covered_sigs = set()
for p_cons, p_assign in paths: for p_cons, p_assign in paths:
if any(c is _STOP for c in p_cons): if any(_is_sentinel(c) for c in p_cons):
new_active.append((p_cons, p_assign)) new_active.append((p_cons, p_assign))
continue continue
for cp_cons, cp_assign in child_paths: for cp_cons, cp_assign in child_paths:
merged_cons = p_cons + list(cp_cons)
sig = frozenset(_hashable_cons(merged_cons))
if sig not in covered_sigs:
covered_sigs.add(sig)
merged = {} merged = {}
for d in (p_assign, cp_assign): for d in (p_assign, cp_assign):
for k, v in d.items(): for k, v in d.items():
merged.setdefault(k, []).extend(v if isinstance(v, list) else [v]) merged.setdefault(k, []).extend(v if isinstance(v, list) else [v])
merged_cons = p_cons + list(cp_cons)
new_active.append((merged_cons, merged)) new_active.append((merged_cons, merged))
if len(new_active) >= _MAX_PATHS: if not new_active:
for pc, pa in paths:
if not any(_is_sentinel(c) for c in pc):
new_active.append((pc, dict(pa)))
break break
if len(new_active) >= _MAX_PATHS: paths = new_active
break
paths = _cap_paths_fair(new_active, child_paths)
return paths return paths
elif isinstance(node, BrIf): elif isinstance(node, BrIf):
@@ -186,6 +231,14 @@ def enum_paths(node, fields):
constraints.append((cond.field, cond.op, cond.value, True)) constraints.append((cond.field, cond.op, cond.value, True))
paths.append((constraints + sp_cons, sp_assign)) paths.append((constraints + sp_cons, sp_assign))
prior_false_sets.append([(cond.field, cond.op, cond.value, False)]) prior_false_sets.append([(cond.field, cond.op, cond.value, False)])
elif cond and isinstance(cond, CondNot) and isinstance(cond.child, CondLeaf) and is_field(cond.child.field, fields):
leaf = cond.child
sub = _cap_paths(enum_paths(seq, fields))
for sp_cons, sp_assign in (sub or [([], {})]):
constraints = [c for pf in prior_false_sets for c in pf]
constraints.append((leaf.field, leaf.op, leaf.value, False))
paths.append((constraints + sp_cons, sp_assign))
prior_false_sets.append([(leaf.field, leaf.op, leaf.value, True)])
elif cond: elif cond:
leaves = collect_leaves(cond) leaves = collect_leaves(cond)
if leaves and all(is_field(l.field, fields) for l in leaves): if leaves and all(is_field(l.field, fields) for l in leaves):
@@ -232,11 +285,34 @@ def enum_paths(node, fields):
paths = [] paths = []
for value, seq in node.when_list: for value, seq in node.when_list:
sub = _cap_paths(enum_paths(seq, fields)) sub = _cap_paths(enum_paths(seq, fields))
thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(value), re.IGNORECASE)
if thru_m and not node.subjects:
low, high = thru_m.group(1), thru_m.group(2)
for sp_cons, sp_assign in (sub or [([], {})]):
paths.append(([(node.subject, '>=', low, True), (node.subject, '<=', high, True)] + sp_cons, sp_assign))
paths.append(([(node.subject, '<=', high, True), (node.subject, '>=', low, True)] + sp_cons, sp_assign))
else:
for sp_cons, sp_assign in (sub or [([], {})]): for sp_cons, sp_assign in (sub or [([], {})]):
paths.append(([(node.subject, '=', value, True)] + sp_cons, sp_assign)) paths.append(([(node.subject, '=', value, True)] + sp_cons, sp_assign))
if node.has_other: if node.has_other:
case_vals = [v for v, _ in node.when_list]
sub = _cap_paths(enum_paths(node.other_seq, fields)) sub = _cap_paths(enum_paths(node.other_seq, fields))
thru_found = False
for v, _ in node.when_list:
thru_m = re.match(r'^(\d+)\s+THRU\s+(\d+)$', str(v), re.IGNORECASE)
if thru_m and not node.subjects:
thru_found = True
low_int, high_int = int(thru_m.group(1)), int(thru_m.group(2))
for sp_cons, sp_assign in (sub or [([], {})]):
a_low = dict(sp_assign)
a_low[node.subject] = [{'type': 'move_literal', 'literal': str(max(0, low_int - 1))}]
low_cons = [(node.subject, 'not_in', [thru_m.group(1), thru_m.group(2)], True)]
paths.append((low_cons + sp_cons, a_low))
a_high = dict(sp_assign)
a_high[node.subject] = [{'type': 'move_literal', 'literal': str(high_int + 1)}]
high_cons = [(node.subject, 'not_in', [thru_m.group(1), thru_m.group(2)], True)]
paths.append((high_cons + sp_cons, a_high))
if not thru_found:
case_vals = [v for v, _ in node.when_list]
for sp_cons, sp_assign in (sub or [([], {})]): for sp_cons, sp_assign in (sub or [([], {})]):
paths.append(([(node.subject, 'not_in', case_vals, True)] + sp_cons, sp_assign)) paths.append(([(node.subject, 'not_in', case_vals, True)] + sp_cons, sp_assign))
return paths return paths
@@ -247,7 +323,10 @@ def enum_paths(node, fields):
elif isinstance(node, BrPerform): elif isinstance(node, BrPerform):
if node.perf_type in ('para', 'thru'): if node.perf_type in ('para', 'thru'):
if node.body_seq: if node.body_seq:
return enum_paths(node.body_seq, fields) paths = enum_paths(node.body_seq, fields)
# EXIT PERFORM 只在 PERFORM 体内有效,剥离后不影响后续 BrSeq 组合
paths = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in paths]
return paths
return [([], {})] return [([], {})]
elif node.perf_type in ('until', 'para_until', 'varying', 'para_varying'): elif node.perf_type in ('until', 'para_until', 'varying', 'para_varying'):
# 尝试单条件(现有逻辑) # 尝试单条件(现有逻辑)
@@ -256,7 +335,9 @@ def enum_paths(node, fields):
field, op, val = parsed field, op, val = parsed
paths = [] paths = []
false_sub = _cap_paths(enum_paths(node.body_seq, fields)) false_sub = _cap_paths(enum_paths(node.body_seq, fields))
false_sub = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in false_sub]
for sp_cons, sp_assign in (false_sub or [([], {})]): for sp_cons, sp_assign in (false_sub or [([], {})]):
body_assign = dict(sp_assign)
# PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径 # PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径
if node.varying_from and node.varying_var: if node.varying_from and node.varying_var:
is_fld = any(f['name'] == node.varying_from for f in fields) if fields else False is_fld = any(f['name'] == node.varying_from for f in fields) if fields else False
@@ -268,6 +349,40 @@ def enum_paths(node, fields):
merged.setdefault(k, []).extend(v if isinstance(v, list) else [v]) merged.setdefault(k, []).extend(v if isinstance(v, list) else [v])
sp_assign = merged sp_assign = merged
paths.append(([(field, op, val, False)] + sp_cons, sp_assign)) paths.append(([(field, op, val, False)] + sp_cons, sp_assign))
# PERFORM VARYING: 末次迭代路径(下标=MAX)
if node.varying_from and node.varying_var and op in ('>', '>=', '<', '<=', '='):
try:
if op == '>':
max_val = int(val)
elif op == '>=':
max_val = int(val) - 1
elif op == '<':
max_val = int(val)
elif op == '<=':
max_val = int(val) + 1
elif op == '=':
by_str = str(node.varying_by or '1')
if by_str.lstrip('-').isdigit() and int(by_str) < 0:
max_val = int(val) + 1
else:
max_val = int(val) - 1
from_val = int(node.varying_from)
by_str = str(node.varying_by or '1')
if by_str.lstrip('-').isdigit() and int(by_str) < 0:
ok = max_val <= from_val
else:
ok = max_val >= from_val
if ok:
max_asgn = {'type': 'move_literal', 'literal': str(max_val)}
max_assign = {node.varying_var: [max_asgn]}
merged_max = {}
for d in (max_assign, body_assign):
for k, v in d.items():
merged_max.setdefault(k, []).extend(v if isinstance(v, list) else [v])
the_cons = [(field, op, val, False)]
paths.append((the_cons + sp_cons, merged_max))
except (ValueError, TypeError):
pass
paths.append(([(field, op, val, True)], {})) paths.append(([(field, op, val, True)], {}))
return paths return paths
# 尝试复合条件(AND/OR # 尝试复合条件(AND/OR
@@ -279,6 +394,7 @@ def enum_paths(node, fields):
if sets: if sets:
paths = [] paths = []
false_sub = _cap_paths(enum_paths(node.body_seq, fields)) false_sub = _cap_paths(enum_paths(node.body_seq, fields))
false_sub = [([c for c in cons if c is not _STOP_EXIT_PERFORM], a) for cons, a in false_sub]
for sp_cons, sp_assign in (false_sub or [([], {})]): for sp_cons, sp_assign in (false_sub or [([], {})]):
# PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径 # PERFORM VARYING: 将 FROM 值作为 MOVE 赋值加入 Enter 路径
if node.varying_from and node.varying_var: if node.varying_from and node.varying_var:
@@ -301,14 +417,18 @@ def enum_paths(node, fields):
return [([], {})] return [([], {})]
elif isinstance(node, CallNode): elif isinstance(node, CallNode):
if node.program_name in _ABEND_PROGRAMS:
return [([_ABEND_SENTINEL], {})]
return [([], {})] return [([], {})]
elif isinstance(node, ExitNode): elif isinstance(node, ExitNode):
return [([_STOP], {})] if node.exit_type == 'PERFORM':
return [([_STOP_EXIT_PERFORM], {})]
return [([_STOP_SENTINEL], {})]
elif isinstance(node, GoTo): elif isinstance(node, GoTo):
paths = enum_paths(node.body_seq, fields) paths = enum_paths(node.body_seq, fields)
return [([_STOP] + c, a) for c, a in paths] return [([_STOP_SENTINEL] + c, a) for c, a in paths]
return [([], {})] return [([], {})]
@@ -335,7 +455,7 @@ def seq_date(seq_num: int) -> str:
def _is_date_field(name: str) -> bool: def _is_date_field(name: str) -> bool:
patterns = [r'DATE', r'YYMMDD', r'YYYYMM', r'YEAR', r'MONTH', r'DAY'] patterns = [r'DATE', r'YYMMDD', r'YYYYMM']
for p in patterns: for p in patterns:
if re.search(p, name.upper()): if re.search(p, name.upper()):
return True return True
@@ -401,13 +521,12 @@ def _children_of(group_name: str, fields: list) -> list:
def _make_numeric_value(idx: int, record_num: int, total_digits: int) -> str: def _make_numeric_value(idx: int, record_num: int, total_digits: int) -> str:
max_val = 10 ** total_digits - 1 max_val = 10 ** total_digits
for step in (100, 10, 1): for step in (100, 10, 1):
val = idx * step + record_num val = idx * step + record_num
if val < 10 ** total_digits: if val < max_val:
return str(min(val, max_val)).zfill(total_digits) return str(val).zfill(total_digits)
return str(min(record_num, max_val)).zfill(total_digits) return str(record_num % max_val).zfill(total_digits)
return str(record_num).zfill(total_digits)
def _make_alpha_value(idx: int, record_num: int, length: int) -> str: def _make_alpha_value(idx: int, record_num: int, length: int) -> str:
@@ -548,6 +667,16 @@ def _check_constraint_satisfied(rec, field_name, operator, value, want_true, fie
return eq == want_true return eq == want_true
elif operator == '<>': elif operator == '<>':
return (not eq) == want_true return (not eq) == want_true
elif operator in ('>', '<', '>=', '<='):
if operator == '>':
ok = s_val > s_target
elif operator == '<':
ok = s_val < s_target
elif operator == '>=':
ok = s_val >= s_target
elif operator == '<=':
ok = s_val <= s_target
return ok == want_true
return True return True
return False return False
@@ -625,6 +754,95 @@ def _apply_arith_constraint(rec, field_name, operator, value, want_true, fields)
rec[right_field] = pick rec[right_field] = pick
def _inc_str(s, length):
s = str(s).strip()
try:
r = str(int(s) + 1).zfill(length)
return r if len(r) <= length else '9' * length
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 9Zz\xff':
c[i] = chr(ord(c[i]) + 1)
break
if c[i] == ' ':
c[i] = '0'
break
if c[i] == '9':
c[i] = '0'
elif c[i] == 'Z':
c[i] = 'A'
elif c[i] == 'z':
c[i] = 'a'
return ''.join(c)
def _dec_str(s, length):
s = str(s).strip()
try:
n = max(0, int(s) - 1)
return str(n).zfill(length)
except ValueError:
c = list(str(s).ljust(length)[:length])
for i in range(len(c) - 1, -1, -1):
if c[i] not in ' 0Aa\x00':
c[i] = chr(ord(c[i]) - 1)
break
if c[i] == ' ':
break
if c[i] == '0':
c[i] = '9'
elif c[i] == 'A':
c[i] = ' '
elif c[i] == 'a':
c[i] = ' '
return ''.join(c)
def _reconcile_unstring_fields(rec, left_field, operator, right_field, want_true,
                               fields, left_chain, assignments, path_assign):
    """Satisfy a field-to-field comparison by rewriting the right side's root.

    The right-hand field is traced back to its root variable via
    ``trace_to_root``; the root's value in *rec* is then derived from the
    current value of *left_field* so that ``left_field <operator> right_field``
    evaluates to *want_true*. The record is modified in place.

    Nothing is changed (a debug message is logged instead) when:
      * the right-hand root is not present in *rec*;
      * either assignment chain contains an entry whose type is not ``move``
        or ``unstring_split``;
      * the left-hand field currently has no (non-blank) value;
      * the operator is not one of ``=``, ``<>``, ``>``, ``<``, ``>=``, ``<=``.

    Args:
        rec: Record dict being generated; updated in place.
        left_field: Name of the left-hand field of the comparison.
        operator: Comparison operator string.
        right_field: Name of the right-hand field of the comparison.
        want_true: Whether the comparison should hold.
        fields: Field definitions (dicts with at least ``name``).
        left_chain: Assignment chain already traced for the left field, or None.
        assignments: Assignment map passed through to ``trace_to_root``.
        path_assign: Path-local assignments passed through to ``trace_to_root``.
    """
    right_root, right_chain = trace_to_root(right_field, assignments, fields, path_assign)
    if right_root not in rec:
        logger.debug(f"字段间比较协调:右侧根 {right_root} 不在 rec,跳过")
        return

    all_entries = (left_chain or []) + (right_chain or [])
    for _, asgn in all_entries:
        if asgn.get('type') not in ('move', 'unstring_split'):
            logger.debug(f"字段间比较协调:链含非 MOVE 类型 {asgn.get('type')},跳过")
            return

    left_val = str(rec.get(left_field, ''))
    if not left_val.strip():
        logger.debug(f"字段间比较协调:左侧 {left_field} 无值,跳过")
        return

    # Size the new value for the root field; fall back to the left value's
    # own length when the field definition gives none.
    length = 0
    for f in fields:
        if f['name'] == right_root:
            length = (f.get('pic_info') or {}).get('length', 0)
            break
    if length == 0:
        length = len(left_val)

    if operator in ('>=', '<='):
        if want_true:
            right_val = left_val  # equality satisfies both >= and <=
        else:
            right_val = _inc_str(left_val, length) if operator == '>=' else _dec_str(left_val, length)
    elif operator in ('>', '<'):
        if want_true:
            right_val = _dec_str(left_val, length) if operator == '>' else _inc_str(left_val, length)
        else:
            right_val = left_val  # equality falsifies a strict comparison
    elif operator == '=':
        right_val = left_val if want_true else _inc_str(left_val, length)
    elif operator == '<>':
        right_val = _inc_str(left_val, length) if want_true else left_val
    else:
        return

    rec[right_root] = right_val[:length] if right_val else right_val
    logger.debug(f"字段间比较协调:{left_field}={left_val} {operator} {right_field} -> {right_root}={rec[right_root]} (want={want_true})")
def apply_constraint(rec, field_name, operator, value, want_true, fields, assignments=None, path_assign=None): def apply_constraint(rec, field_name, operator, value, want_true, fields, assignments=None, path_assign=None):
# 标准化字段名:去除括号内空格(WS-CELL ( 1, 1 ) → WS-CELL(1,1) # 标准化字段名:去除括号内空格(WS-CELL ( 1, 1 ) → WS-CELL(1,1)
field_name = re.sub(r'\s*([(),])\s*', r'\1', field_name) field_name = re.sub(r'\s*([(),])\s*', r'\1', field_name)
@@ -659,6 +877,7 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign
apply_constraint(rec, parent_name, operator, value, want_true, fields, assignments, path_assign) apply_constraint(rec, parent_name, operator, value, want_true, fields, assignments, path_assign)
return return
break break
chain = None
if assignments: if assignments:
root_var, chain = trace_to_root(field_name, assignments, fields, path_assign) root_var, chain = trace_to_root(field_name, assignments, fields, path_assign)
if root_var != field_name: if root_var != field_name:
@@ -666,8 +885,41 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign
if any(f['name'] == new_field_name for f in fields): if any(f['name'] == new_field_name for f in fields):
field_name, operator, value = new_field_name, new_op, new_val field_name, operator, value = new_field_name, new_op, new_val
# 字段间比较:在 satisfied check 前解析/处理
if any(f['name'] == value for f in fields):
resolved_literal = None
for f in fields:
if f['name'] == value and f.get('value') is not None:
resolved_literal = str(f['value']).strip("'").strip('"')
break
if resolved_literal is not None:
value = resolved_literal
elif chain is not None and assignments:
_reconcile_unstring_fields(rec, field_name, operator, value, want_true,
fields, chain, assignments, path_assign)
return
elif re.search(r'[+\-*/]', field_name):
_apply_arith_constraint(rec, field_name, operator, value, want_true, fields)
return
else:
logger.debug(f"字段间比较约束跳过:{field_name} {operator} {value}")
return
# 如果当前值已满足该约束,跳过覆盖(保持先前约束的一致性) # 如果当前值已满足该约束,跳过覆盖(保持先前约束的一致性)
# 但零值时强制使用边界值(非 0/非 min)
if _check_constraint_satisfied(rec, field_name, operator, value, want_true, fields): if _check_constraint_satisfied(rec, field_name, operator, value, want_true, fields):
cur = str(rec.get(field_name, '')).strip('0')
if (cur == '' or cur == '.') and (
(operator in ('>', '>=') and not want_true) or
(operator in ('<', '<=') and want_true)
):
for f in fields:
if f['name'] == field_name:
pi = f.get('pic_info', {})
if pi.get('type') == 'numeric':
val = satisfying_value(pi, operator, value, want_true)
rec[field_name] = val
return
return return
if operator == 'not_in': if operator == 'not_in':
@@ -687,13 +939,6 @@ def apply_constraint(rec, field_name, operator, value, want_true, fields, assign
rec[field_name] = str(n).zfill(pi.get('digits', 0) + pi.get('decimal', 0)) rec[field_name] = str(n).zfill(pi.get('digits', 0) + pi.get('decimal', 0))
return return
return return
# 字段间比较(值侧也是字段名)
if any(f['name'] == value for f in fields):
if re.search(r'[+\-*/]', field_name):
_apply_arith_constraint(rec, field_name, operator, value, want_true, fields)
else:
logger.debug(f"字段间比较约束跳过:{field_name} {operator} {value}")
return
for f in fields: for f in fields:
if f['name'] == field_name: if f['name'] == field_name:
pi = f.get('pic_info', {}) pi = f.get('pic_info', {})
@@ -738,6 +983,31 @@ def sync_redefined_fields(rec, fields):
def apply_occurs_depending(rec, fields): def apply_occurs_depending(rec, fields):
"""根据 OCCURS DEPENDING ON 变量的当前值,清零超范围的下标字段。""" """根据 OCCURS DEPENDING ON 变量的当前值,清零超范围的下标字段。"""
# Phase 1: 将零值的 DEPENDING ON 变量设为最大下标
dep_max = {}
for f in fields:
dep_var = f.get('occurs_depending')
if not dep_var:
continue
m = re.search(r'\((\d+)\)$', f['name'])
if m:
sub = int(m.group(1))
if sub > dep_max.get(dep_var, 0):
dep_max[dep_var] = sub
for dep_var, max_sub in dep_max.items():
try:
cur_val = int(float(str(rec.get(dep_var, '0'))))
except (ValueError, TypeError):
cur_val = 0
if cur_val == 0:
for f in fields:
if f['name'] == dep_var:
pi = f.get('pic_info', {})
digits = pi.get('digits', 0) + pi.get('decimal', 0)
if digits > 0:
rec[dep_var] = str(max_sub).zfill(digits)
break
# Phase 2: 清零超范围的下标字段
for f in fields: for f in fields:
dep_var = f.get('occurs_depending') dep_var = f.get('occurs_depending')
if not dep_var: if not dep_var:
@@ -805,6 +1075,9 @@ def _enum_search_paths(node, fields):
base = re.sub(r'\s*\(.*?\)\s*$', '', cond_tree.field) base = re.sub(r'\s*\(.*?\)\s*$', '', cond_tree.field)
matching_val = cond_tree.value matching_val = cond_tree.value
elem_key = f'{base}({i + 1})' elem_key = f'{base}({i + 1})'
if any(f['name'] == matching_val for f in fields):
extra_assign[elem_key] = [{'type': 'move', 'source_vars': [matching_val]}]
else:
extra_assign[elem_key] = [{'type': 'move_literal', 'literal': matching_val}] extra_assign[elem_key] = [{'type': 'move_literal', 'literal': matching_val}]
non_match = _non_match_for(cond_tree, fields) or ' ' non_match = _non_match_for(cond_tree, fields) or ' '
for j in range(i): for j in range(i):
@@ -815,6 +1088,9 @@ def _enum_search_paths(node, fields):
merged_assign = dict(extra_assign) merged_assign = dict(extra_assign)
for k, v in sp_assign.items(): for k, v in sp_assign.items():
merged_assign.setdefault(k, []).extend(v if isinstance(v, list) else [v]) merged_assign.setdefault(k, []).extend(v if isinstance(v, list) else [v])
if cond_tree and isinstance(cond_tree, CondLeaf):
paths.append(([(elem_key, cond_tree.op, matching_val, True)] + sp_cons, merged_assign))
else:
paths.append((sp_cons, merged_assign)) paths.append((sp_cons, merged_assign))
if node.has_at_end: if node.has_at_end:
@@ -837,16 +1113,20 @@ def _enum_search_paths(node, fields):
return paths return paths
def generate_records(branch_paths_with_assigns, data_fields, base_assignments=None, file_sec=None): def generate_records(path_infos, data_fields, base_assignments=None, file_sec=None):
"""生成测试数据记录。 """生成测试数据记录。
branch_paths_with_assigns: list of (constraints, path_assignments). path_infos: list of (constraints, path_assignments) 或 (constraints, path_assignments, term_type).
base_assignments: 全局 assignments dict (用于 trace_to_root). base_assignments: 全局 assignments dict (用于 trace_to_root).
返回: (records, kept_path_cons) — kept_path_cons 是与 records 一一对应的约束。 返回: (records, kept_path_cons, term_types).
""" """
# 自动兼容旧 2-tuple 格式
if path_infos and len(path_infos[0]) == 2:
path_infos = [(c, a, 'normal') for c, a in path_infos]
records = [] records = []
kept_path_cons = [] kept_path_cons = []
if branch_paths_with_assigns: term_types = []
for seq, (path_cons, path_assign) in enumerate(branch_paths_with_assigns, start=1): if path_infos:
for seq, (path_cons, path_assign, term_type) in enumerate(path_infos, start=1):
path_cons = _filter_stop(path_cons) path_cons = _filter_stop(path_cons)
rec = make_base_record(seq, data_fields) rec = make_base_record(seq, data_fields)
# Pass A: 先传播赋值(MOVE/COMPUTE/READ INTO 等),模拟到决策点前的程序状态 # Pass A: 先传播赋值(MOVE/COMPUTE/READ INTO 等),模拟到决策点前的程序状态
@@ -869,6 +1149,26 @@ def generate_records(branch_paths_with_assigns, data_fields, base_assignments=No
if not _check_constraint_satisfied(rec, root_var, new_op, new_val, want, data_fields): if not _check_constraint_satisfied(rec, root_var, new_op, new_val, want, data_fields):
skip_impossible = True skip_impossible = True
break break
elif field in rec:
asgn_val = path_assign.get(field)
if asgn_val is not None:
asgn_list = asgn_val if isinstance(asgn_val, list) else [asgn_val]
if asgn_list and asgn_list[-1]['type'] == 'move_literal':
cur_val = str(rec.get(field, ''))
if cur_val != '':
pi = next((f.get('pic_info', {}) for f in data_fields if f['name'] == field), {})
if pi.get('type') == 'numeric':
try:
nv = int(float(cur_val))
tv = int(float(str(val)))
ops = {'>': lambda a,b: a > b, '<': lambda a,b: a < b, '=': lambda a,b: a == b, '<>': lambda a,b: a != b, '>=': lambda a,b: a >= b, '<=': lambda a,b: a <= b}
if op in ops:
satisfied = ops[op](nv, tv) == want
if not satisfied:
skip_impossible = True
break
except (ValueError, TypeError):
pass
if skip_impossible: if skip_impossible:
continue continue
# Pass B: 约束覆盖(确保决策条件满足,覆盖 MOVE 带来的值) # Pass B: 约束覆盖(确保决策条件满足,覆盖 MOVE 带来的值)
@@ -886,17 +1186,121 @@ def generate_records(branch_paths_with_assigns, data_fields, base_assignments=No
forward[tgt] = filtered forward[tgt] = filtered
if forward: if forward:
propagate_assignments(rec, forward, data_fields, file_sec=file_sec) propagate_assignments(rec, forward, data_fields, file_sec=file_sec)
# Pass B.75: COMPUTE 重算(约束修改了 COMPUTE 源字段的值)
if isinstance(path_assign, dict):
compute_only = {}
for tgt, asgn_val in path_assign.items():
asgn_list = asgn_val if isinstance(asgn_val, list) else [asgn_val]
filtered = [a for a in asgn_list if a['type'] == 'compute']
if filtered:
compute_only[tgt] = filtered
if compute_only:
propagate_assignments(rec, compute_only, data_fields, file_sec=file_sec)
# Pass B.8: UNSTRING source reconstruction (targets → source)
if base_assignments:
_reconstruct_unstring_sources(rec, base_assignments, data_fields)
# Pass C: 同步 REDEFINES(确保共享存储一致) # Pass C: 同步 REDEFINES(确保共享存储一致)
sync_redefined_fields(rec, data_fields) sync_redefined_fields(rec, data_fields)
# Pass D: OCCURS DEPENDING ON — 清零超范围的下标字段 # Pass D: OCCURS DEPENDING ON — 清零超范围的下标字段
apply_occurs_depending(rec, data_fields) apply_occurs_depending(rec, data_fields)
# Pass E: PIC 长度约束 — 模拟 COBOL 截断语义
for f in data_fields:
name = f['name']
if name in rec and not f.get('is_88') and not f.get('is_filler'):
pi = f.get('pic_info', {})
ftype = pi.get('type', 'unknown')
val = str(rec[name])
if ftype == 'numeric':
total = pi.get('digits', 0) + pi.get('decimal', 0)
if total > 0 and len(val) > total:
rec[name] = val[-total:].zfill(total)
elif ftype in ('alphanumeric', 'alphabetic'):
length = pi.get('length', 0)
if length > 0 and len(val) > length:
rec[name] = val[:length]
records.append(rec) records.append(rec)
kept_path_cons.append(path_cons) kept_path_cons.append(path_cons)
term_types.append(term_type)
# Track which fields were explicitly assigned in this path
if isinstance(path_assign, dict):
rec['_assigned_fields'] = set(path_assign.keys())
else:
rec['_assigned_fields'] = set()
if not records: if not records:
rec = make_base_record(1, data_fields) rec = make_base_record(1, data_fields)
if base_assignments: if base_assignments:
propagate_assignments(rec, base_assignments, data_fields, file_sec=file_sec) propagate_assignments(rec, base_assignments, data_fields, file_sec=file_sec)
if base_assignments:
_reconstruct_unstring_sources(rec, base_assignments, data_fields)
rec['_assigned_fields'] = set()
records.append(rec) records.append(rec)
kept_path_cons.append([]) kept_path_cons.append([])
return records, kept_path_cons term_types.append('normal')
return records, kept_path_cons, term_types
def _reconstruct_unstring_sources(rec, base_assignments, data_fields):
"""Build UNSTRING source field value from comma-separated target values.
After constraints determine target field values, construct the source
string so the COBOL UNSTRING can correctly parse it.
"""
groups = {}
for tgt, asgn_list in base_assignments.items():
for asgn in asgn_list:
if asgn.get('type') == 'unstring_split' and asgn.get('source_vars'):
src = asgn['source_vars'][0]
idx = asgn.get('index', 0)
groups.setdefault(src, []).append((idx, tgt))
for src_var, targets in groups.items():
targets.sort(key=lambda x: x[0])
# Resolve group→child name if source not directly in rec
resolved_src = src_var
if resolved_src not in rec:
grp_level = None
found = False
for f in data_fields:
if not found and f['name'] == resolved_src:
grp_level = f.get('level', 0)
found = True
continue
if found:
if f.get('level', 0) <= grp_level or f.get('level') == 77:
break
if f.get('pic'):
resolved_src = f['name']
break
if resolved_src not in rec:
continue
csv_parts = []
for idx, tgt in targets:
val = rec.get(tgt, '')
csv_parts.append(val if val is not None else '')
csv_value = ','.join(csv_parts)
src_len = 0
for f in data_fields:
if f['name'] == resolved_src:
pi = f.get('pic_info', {})
if pi:
src_len = pi.get('length', 0)
break
if src_len > 0:
csv_value = csv_value.ljust(src_len)[:src_len]
rec[resolved_src] = csv_value
# Also sync to child fields (group→elementary) for FD output consistency
if resolved_src == src_var:
grp_level = None
found = False
for f in data_fields:
if not found and f['name'] == resolved_src:
grp_level = f.get('level', 0)
found = True
continue
if found:
if f.get('level', 0) <= grp_level or f.get('level') == 77:
break
if f.get('pic'):
rec[f['name']] = csv_value
break
+119
View File
@@ -0,0 +1,119 @@
"""gcov 覆盖率数据解析和分支标记"""
import re
import logging
import subprocess
from pathlib import Path
logger = logging.getLogger(__name__)
def parse_cbl_gcov(gcov_path: str) -> dict[int, int]:
    """Parse a ``.cbl.gcov`` file into ``{COBOL line number: execution count}``.

    Recognised gcov line prefixes::

        #####:    6: source   -> never executed (count 0)
        =====:    6: source   -> never executed, exceptional path (count 0)
          75*:   12: source   -> executed 75 times
            1:   14: source   -> executed once
            -:   17: source   -> not executable (comment/declaration), skipped

    The file is decoded leniently: only the count and line-number columns are
    used, so undecodable bytes in the echoed source text must not abort parsing.
    """
    counts = {}
    with open(gcov_path, encoding='utf-8', errors='replace') as f:
        for line in f:
            m = re.match(r'^\s*(#####|=====|\d+\*?|-):\s*(\d+):', line)
            if not m:
                continue
            count_str = m.group(1)
            if count_str == '-':
                continue
            lineno = int(m.group(2))
            if count_str in ('#####', '====='):
                counts[lineno] = 0
            else:
                # A trailing '*' marks a line with unexecuted basic blocks.
                counts[lineno] = int(count_str.rstrip('*'))
    return counts
def run_gcov(program_name: str, work_dir: str) -> dict[int, int]:
    """Run gcov through WSL inside ``work_dir`` and parse the COBOL line counts.

    Args:
        program_name: program name without extension, e.g. "ALLCMDS".
        work_dir: directory containing the .gcda/.gcno files (Windows path).

    Returns:
        ``{COBOL line number: execution count}``.  An empty dict is returned on
        any failure: gcov could not be launched, timed out, exited non-zero, or
        produced no ``<program>.cbl.gcov`` file.
    """
    import shlex

    wsl_work = _wsl_path(work_dir)
    # Quote both operands: the command line is interpreted by `sh -c`, so a
    # directory or program name containing spaces/metacharacters would
    # otherwise be split or executed.
    shell_cmd = f'cd {shlex.quote(wsl_work)} && gcov {shlex.quote(program_name + ".c")}'
    cmd = ['wsl', 'sh', '-c', shell_cmd]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True, text=True,
            encoding='utf-8', errors='replace',
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired) as exc:
        # wsl missing or gcov hung: honour the "empty dict on failure" contract.
        logger.warning(f"gcov 无法执行: {exc}")
        return {}
    if result.returncode != 0:
        logger.warning(f"gcov 失败 (exit={result.returncode}): {result.stderr.strip()}")
        return {}

    cbl_gcov = Path(work_dir) / f'{program_name}.cbl.gcov'
    if not cbl_gcov.exists():
        logger.warning(f"gcov 输出不存在: {cbl_gcov}")
        return {}

    gcov_data = parse_cbl_gcov(str(cbl_gcov))
    logger.info(f"gcov 解析: {len(gcov_data)} 行, "
                f"{sum(1 for v in gcov_data.values() if v > 0)} 行已执行")
    return gcov_data
def _wsl_path(windows_path: str) -> str:
path = Path(windows_path).resolve()
drive = path.drive.lower().rstrip(':')
rest = str(path.relative_to(path.anchor)).replace('\\', '/')
return f'/mnt/{drive}/{rest}'
def mark_from_gcov(decision_points: list, gcov_data: dict[int, int],
                   branch_tree) -> None:
    """Infer branch coverage from gcov line counts.

    Mutates ``active_branches`` of each decision point in place.  A decision
    point whose source line is non-positive, absent from ``gcov_data`` or has
    a count of 0 was never reached, so nothing is marked for it.

    Simplified inference rules for reached lines (count > 0):
        IF:        both 'T' and 'F' are marked covered.
        EVALUATE:  every name in ``branch_names`` is marked covered.
        PERFORM:   'Skip' is marked (the loop eventually exits); 'Enter' is
                   additionally marked when count > 1 (body entered).

    ``branch_tree`` is accepted for interface compatibility and is not used.
    """
    for dp in decision_points:
        ln = dp.source_line
        if ln <= 0:
            continue
        count = gcov_data.get(ln)
        # None -> line unknown to gcov; 0 -> statement never executed.
        if not count:
            continue

        if dp.kind == 'IF':
            dp.active_branches.update(('T', 'F'))
        elif dp.kind == 'EVALUATE':
            dp.active_branches.update(dp.branch_names)
        elif dp.kind == 'PERFORM':
            if count > 1:
                dp.active_branches.add('Enter')
            dp.active_branches.add('Skip')
+1 -1
View File
@@ -13,7 +13,7 @@ clause: pic_clause | value_clause | occurs_clause | redefines_clause | usage_cla
| "JUSTIFIED" "RIGHT"? | "JUSTIFIED" "RIGHT"?
| "BLANK" "WHEN" "ZERO" | "BLANK" "WHEN" "ZERO"
| "GLOBAL" | "EXTERNAL" | "GLOBAL" | "EXTERNAL"
pic_clause: "PIC" "IS"? PICTURE_STRING pic_clause: "PIC" "IS"? PICTURE_STRING ("." PICTURE_STRING)*
value_clause: "VALUE" "IS"? value_literal+ value_clause: "VALUE" "IS"? value_literal+
value_literal: INT | SIGNED_NUMBER | STRING | SQSTRING value_literal: INT | SIGNED_NUMBER | STRING | SQSTRING
| "ZERO" | "ZEROS" | "ZEROES" | "ZERO" | "ZEROS" | "ZEROES"
+67 -24
View File
@@ -23,27 +23,68 @@ def _scenario_text(path_cons):
return ', '.join(parts) return ', '.join(parts)
def _write_json(entries, outpath):
if not entries:
return
outpath.parent.mkdir(parents=True, exist_ok=True)
with open(outpath, 'w', encoding='utf-8') as f:
json.dump(entries, f, ensure_ascii=False, indent=2)
def _is_field_assigned(fname, assigned_set, fields, fd_fields_lookup):
if not assigned_set:
return False
if fname in assigned_set:
return True
level_map = {}
name_order = []
for f in fields:
fn = f['name']
lv = f.get('level', 77)
level_map[fn] = lv
name_order.append((lv, fn))
flv = level_map.get(fname, 77)
ancestor = None
for lv, fn in name_order:
if fn == fname:
break
if lv < flv:
ancestor = fn
if ancestor and ancestor in assigned_set:
return True
return False
def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None, def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None,
open_dir=None, path_cons_list=None): open_dir=None, term_types=None, db_input=None, data_fields=None):
outpath.parent.mkdir(parents=True, exist_ok=True) outpath.parent.mkdir(parents=True, exist_ok=True)
if not roles: if not roles:
out = []
for i, rec in enumerate(records):
entry = dict(rec)
entry['termination'] = (term_types or ['normal'] * len(records))[i]
out.append(entry)
obj = {'program': outpath.stem, 'records': out}
if db_input:
obj['db_input'] = db_input
with open(outpath, 'w', encoding='utf-8') as f: with open(outpath, 'w', encoding='utf-8') as f:
json.dump(records, f, ensure_ascii=False, indent=2) json.dump(obj, f, ensure_ascii=False, indent=2)
return return
# FD direction lookup term_types = term_types or ['normal'] * len(records)
out = [] out = []
for i, rec in enumerate(records): for i, rec in enumerate(records):
inp = {} inp = {}
out_exp = {} out_exp = {}
ws = {} ws = {}
# Group by FD
if fd_fields and field_to_fd: if fd_fields and field_to_fd:
for fd_name, fds_set in fd_fields.items(): for fd_name, fds_set in fd_fields.items():
direction = (open_dir or {}).get(fd_name, '') direction = (open_dir or {}).get(fd_name, '')
inp_block = {} inp_block = {}
out_block = {} out_block = {}
assigned_set = rec.get('_assigned_fields', set())
for fname in fds_set: for fname in fds_set:
if fname not in rec: if fname not in rec:
continue continue
@@ -52,13 +93,13 @@ def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None,
if direction in ('INPUT', 'I-O') and r in ('input', 'inout'): if direction in ('INPUT', 'I-O') and r in ('input', 'inout'):
inp_block[fname] = val inp_block[fname] = val
if direction in ('OUTPUT', 'I-O') and r in ('output', 'inout'): if direction in ('OUTPUT', 'I-O') and r in ('output', 'inout'):
if _is_field_assigned(fname, assigned_set, data_fields or [], fd_fields):
out_block[fname] = val out_block[fname] = val
if inp_block: if inp_block:
inp[fd_name] = inp_block inp[fd_name] = inp_block
if out_block: if out_block:
out_exp[fd_name] = out_block out_exp[fd_name] = out_block
# Working-storage: not belonging to any FD
for name, val in rec.items(): for name, val in rec.items():
if not field_to_fd or name not in field_to_fd: if not field_to_fd or name not in field_to_fd:
ws[name] = val ws[name] = val
@@ -66,25 +107,21 @@ def output_json(records, outpath, roles=None, fd_fields=None, field_to_fd=None,
entry = { entry = {
'input': inp, 'input': inp,
'expected_output': out_exp, 'expected_output': out_exp,
'working_storage': ws, 'working_storage': {k: v for k, v in ws.items() if k != '_assigned_fields'},
'termination': term_types[i] if i < len(term_types) else 'normal',
} }
if path_cons_list and i < len(path_cons_list):
text = _scenario_text(path_cons_list[i])
if text:
entry['scenario'] = text
out.append(entry) out.append(entry)
with open(outpath, 'w', encoding='utf-8') as f: obj = {'program': outpath.stem, 'records': out}
json.dump(out, f, ensure_ascii=False, indent=2) if db_input:
obj['db_input'] = db_input
_write_json(obj, outpath)
def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, open_dir): def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, open_dir,
"""按 FD 名拆分出力入力 JSON 文件。 term_types=None):
每个 INPUT / I-O 方向 FD 生成一个文件:{stem}_{fd_name}.json term_types = term_types or ['normal'] * len(records)
内容为路径数 × 记录,每条只含该 FD 的入力字段值。
"""
input_fds = {} input_fds = {}
for fd_name, fds_set in fd_fields.items(): for fd_name, fds_set in fd_fields.items():
direction = (open_dir or {}).get(fd_name, '') direction = (open_dir or {}).get(fd_name, '')
@@ -101,9 +138,11 @@ def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, ope
outdir.mkdir(parents=True, exist_ok=True) outdir.mkdir(parents=True, exist_ok=True)
for fd_name, fds_set in input_fds.items(): for fd_name, fds_set in input_fds.items():
fd_records = [] normals = []
abends = []
direction = (open_dir or {}).get(fd_name, '') direction = (open_dir or {}).get(fd_name, '')
for rec in records: for i, rec in enumerate(records):
term = term_types[i] if i < len(term_types) else 'normal'
fd_rec = {} fd_rec = {}
for fname in fds_set: for fname in fds_set:
r = roles.get(fname, 'unused') r = roles.get(fname, 'unused')
@@ -111,8 +150,12 @@ def output_input_files(records, outdir, stem, roles, fd_fields, field_to_fd, ope
if fname in rec: if fname in rec:
fd_rec[fname] = rec[fname] fd_rec[fname] = rec[fname]
if fd_rec: if fd_rec:
fd_records.append(fd_rec) if term == 'abend':
abends.append(fd_rec)
else:
normals.append(fd_rec)
outpath = outdir / f'{stem}_{fd_name}.json' if normals:
with open(outpath, 'w', encoding='utf-8') as f: _write_json(normals, outdir / f'{stem}_{fd_name}.json')
json.dump(fd_records, f, ensure_ascii=False, indent=2) if abends:
_write_json(abends, outdir / f'{stem}_abend_{fd_name}.json')
+166 -22
View File
@@ -1,9 +1,12 @@
"""??????? + COPYBOOK + DATA DIVISION?? + PIC""" """Preprocessor + COPYBOOK + DATA DIVISION parse + PIC"""
import re import re
import logging
from pathlib import Path from pathlib import Path
from lark import Lark, Transformer, v_args from lark import Lark, Transformer, v_args
logger = logging.getLogger(__name__)
from .models import FieldDef, PicInfo from .models import FieldDef, PicInfo
@@ -85,6 +88,8 @@ def preprocess(source: str) -> str:
if len(line) >= 7 and line[6].upper() == 'D': if len(line) >= 7 and line[6].upper() == 'D':
continue continue
content = line[6:] if len(line) >= 7 else line content = line[6:] if len(line) >= 7 else line
if content.strip().startswith('*'):
continue
else: else:
comment_pos = line.find('*>') comment_pos = line.find('*>')
if comment_pos >= 0: if comment_pos >= 0:
@@ -192,6 +197,125 @@ def resolve_copybooks(source: str, source_dir: str, _recursion_depth: int = 0,
return '\n'.join(result) return '\n'.join(result)
# ── EXEC SQL INCLUDE Resolution ──
# Matches one `EXEC SQL INCLUDE <name> END-EXEC.` statement (the trailing
# period is required); group 1 captures the included member name.
_RE_SQL_INC = re.compile(
r'EXEC\s+SQL\s+INCLUDE\s+(\w[\w-]*)\s+END-EXEC\.',
re.IGNORECASE | re.DOTALL
)
# Built-in SQLCA record layout. resolve_sql_includes() substitutes this text
# for `EXEC SQL INCLUDE SQLCA` when no SQLCA member file is found on disk.
_BUILTIN_SQLCA = """\
01 SQLCA.
05 SQLCAID PIC X(8).
05 SQLCABC PIC S9(9) COMP.
05 SQLCODE PIC S9(9) COMP.
05 SQLERRM.
10 SQLERRML PIC S9(4) COMP.
10 SQLERRMC PIC X(70).
05 SQLERRP PIC X(8).
05 SQLERRD OCCURS 6 TIMES PIC S9(9) COMP.
05 SQLWARN.
10 SQLWARN0 PIC X.
10 SQLWARN1 PIC X.
10 SQLWARN2 PIC X.
10 SQLWARN3 PIC X.
10 SQLWARN4 PIC X.
10 SQLWARN5 PIC X.
10 SQLWARN6 PIC X.
10 SQLWARN7 PIC X.
05 SQLSTATE PIC X(5).
"""
def resolve_sql_includes(source: str, source_dir: str) -> str:
    """Expand ``EXEC SQL INCLUDE name END-EXEC.`` statements, like COPY.

    Each include is replaced by the text of the first existing file among
    ``<NAME>``, ``<NAME>.cpy``, ``<NAME>.CPY``, ``<NAME>.cbl``, ``<NAME>.CBL``
    in ``source_dir`` (upper-cased name first, then the name as written).
    If no file is found, ``SQLCA`` falls back to the built-in layout and any
    other member is replaced by a comment line.

    Included text is re-scanned so nested includes are expanded too.  The
    number of expansion passes is capped so a member that (directly or
    indirectly) includes itself cannot loop forever.

    Args:
        source: COBOL source text.
        source_dir: directory searched for include members.

    Returns:
        The source with includes expanded.
    """
    max_passes = 32

    def _resolve_one(m):
        written = m.group(1)
        name = written.upper()
        # Try the canonical upper-case name first; fall back to the spelling
        # used in the source for case-sensitive file systems.
        stems = [name] if written == name else [name, written]
        for stem in stems:
            for ext in ('', '.cpy', '.CPY', '.cbl', '.CBL'):
                p = Path(source_dir) / f"{stem}{ext}"
                # is_file(): a directory with the member's name is not a member.
                if p.is_file():
                    return p.read_text(encoding='utf-8')
        if name == 'SQLCA':
            return _BUILTIN_SQLCA
        logger.warning(f"SQL INCLUDE {name} not found, injecting as comment")
        return f" * SQL INCLUDE {name} NOT RESOLVED\n"

    for _ in range(max_passes):
        new_source = _RE_SQL_INC.sub(_resolve_one, source)
        if new_source == source:
            return source
        source = new_source
    logger.warning(
        f"SQL INCLUDE expansion stopped after {max_passes} passes; "
        "possible circular include"
    )
    return source
# Any `EXEC SQL ... END-EXEC` block (trailing period optional).
_RE_SQL_BLOCK = re.compile(
    r'EXEC\s+SQL\s+(.*?)\s+END-EXEC\.?',
    re.IGNORECASE | re.DOTALL
)

# `EXEC SQL DECLARE <table> TABLE ( <columns> ) END-EXEC`:
# group 1 = table name, group 2 = raw column list.
_RE_DECLARE_TABLE = re.compile(
    r'EXEC\s+SQL\s+DECLARE\s+(\w[\w-]*)\s+TABLE\s*\((.*?)\)\s+END-EXEC\.?',
    re.IGNORECASE | re.DOTALL
)


def strip_exec_sql_from_data_div(source: str) -> tuple:
    """Replace every EXEC SQL block in ``source`` with a comment line.

    DECLARE TABLE blocks are additionally parsed: their columns are collected
    under the upper-cased table name.

    Returns:
        ``(cleaned_source, declared_columns)`` where ``declared_columns`` maps
        table name -> column list from ``_parse_declare_table_columns``.
    """
    tables = {}

    def _substitute(match):
        declare = _RE_DECLARE_TABLE.match(match.group(0))
        if declare is None:
            return " *> SKIPPED EXEC SQL\n"
        table = declare.group(1).upper()
        columns = _parse_declare_table_columns(declare.group(2))
        tables[table] = columns
        return f" *> DECLARE {table} TABLE ({len(columns)} cols)\n"

    stripped = _RE_SQL_BLOCK.sub(_substitute, source)
    return stripped, tables
def _parse_declare_table_columns(col_text: str) -> list[dict]:
"""Parse 'CUST_ID CHAR(5) NOT NULL, BALANCE PIC 9(6)' into column list."""
cols = []
for part in re.split(r',\s*', col_text):
part = part.strip()
if not part:
continue
m = re.match(
r'(\w[\w-]*)\s+(CHAR\s*\(\s*(\d+)\s*\)'
r'|VARCHAR\s*\(\s*(\d+)\s*\)'
r'|INTEGER|SMALLINT'
r'|DECIMAL\s*\(\s*(\d+)\s*(?:,\s*(\d+))?\s*\)'
r'|DATE'
r'|PIC\s+([\w().]+))'
r'(?:\s+NOT\s+NULL|\s+NULL)?',
part, re.IGNORECASE
)
if m:
name = m.group(1).upper()
if m.group(3):
col_type = {'db_type': 'CHAR', 'size': int(m.group(3))}
elif m.group(4):
col_type = {'db_type': 'VARCHAR', 'size': int(m.group(4))}
elif m.group(2).upper() == 'INTEGER':
col_type = {'db_type': 'INTEGER'}
elif m.group(2).upper() == 'SMALLINT':
col_type = {'db_type': 'SMALLINT'}
elif m.group(5):
prec = int(m.group(5)) if m.group(5) else 0
scale = int(m.group(6)) if m.group(6) else 0
col_type = {'db_type': 'DECIMAL', 'precision': prec, 'scale': scale}
elif m.group(2).upper() == 'DATE':
col_type = {'db_type': 'DATE'}
elif m.group(7):
col_type = {'db_type': 'PIC', 'pic': m.group(7).upper()}
else:
col_type = {'db_type': 'CHAR', 'size': 1}
cols.append({'name': name, **col_type})
return cols
# ── Lark Grammar ── # ── Lark Grammar ──
_GRAMMAR_CACHE = None _GRAMMAR_CACHE = None
@@ -464,7 +588,7 @@ def parse_file_control(source: str) -> dict:
"""Parse FILE-CONTROL paragraph. """Parse FILE-CONTROL paragraph.
Returns dict: Returns dict:
{filename: {"assign_to": str, "organization": str | None}} {filename: {"assign": str, "organization": str, "recording_mode": str}}
""" """
m = re.search(r'FILE-CONTROL\.(.*?)(?=DATA\s+DIVISION|\Z)', source, re.DOTALL | re.IGNORECASE) m = re.search(r'FILE-CONTROL\.(.*?)(?=DATA\s+DIVISION|\Z)', source, re.DOTALL | re.IGNORECASE)
if not m: if not m:
@@ -472,21 +596,39 @@ def parse_file_control(source: str) -> dict:
fc = m.group(1) fc = m.group(1)
result = {} result = {}
for sel_m in re.finditer( for sel_m in re.finditer(
r'SELECT\s+(\w[\w-]*)\s+[^.]*?\bASSIGN\s+TO\s+(["\'])(.*?)\2', r'SELECT\s+(\w[\w-]*)\s+[^.]*?\bASSIGN\s+TO\s+'
r'(?:(["\'])(.*?)\2|(\w[\w-]*))'
r'[^.]*\.',
fc, re.IGNORECASE fc, re.IGNORECASE
): ):
fname = sel_m.group(1).upper() name = sel_m.group(1).upper()
if sel_m.group(2):
assign_to = sel_m.group(3).upper() assign_to = sel_m.group(3).upper()
# Extract ORGANIZATION clause within this SELECT statement else:
org_m = re.search( assign_to = sel_m.group(4).upper()
r'ORGANIZATION\s+(?:IS\s+)?(\w[\w-]*)', clause = sel_m.group(0)
sel_m.group(0), re.IGNORECASE org_m = re.search(r'ORGANIZATION\s+(LINE\s+)?SEQUENTIAL', clause, re.IGNORECASE)
) if org_m and org_m.group(1):
org = org_m.group(1).upper() if org_m else None org = 'LINE SEQUENTIAL'
result[fname] = { elif org_m:
"assign_to": assign_to, org = 'SEQUENTIAL'
"organization": org, else:
} org = 'SEQUENTIAL'
result[name] = {'assign': assign_to, 'organization': org, 'recording_mode': 'F'}
# Extract RECORDING MODE from FD blocks in FILE SECTION
fd_sec_m = re.search(r'FILE\s+SECTION\.(.*?)(?=WORKING-STORAGE\s+SECTION|LINKAGE\s+SECTION|\Z)',
source, re.DOTALL | re.IGNORECASE)
if fd_sec_m:
fs = fd_sec_m.group(1)
for block in re.split(r'\n\s*(?=FD\s+)', fs.strip()):
fd_m = re.match(r'FD\s+(\w[\w-]*)', block, re.IGNORECASE)
if not fd_m:
continue
fd_name = fd_m.group(1).upper()
if fd_name in result:
rm_m = re.search(r'RECORDING\s+MODE\s+IS\s+(\w)', block, re.IGNORECASE)
if rm_m:
result[fd_name]['recording_mode'] = rm_m.group(1).upper()
return result return result
@@ -499,14 +641,12 @@ def parse_file_section(source: str) -> dict:
fs = m.group(1) fs = m.group(1)
result = {} result = {}
# FD 和 SD 条目 # FD 和 SD 条目
blocks = re.split(r'\n\s*(?=(?:FD|SD)\s+)', fs.strip()) fd_blocks = re.split(r'\n\s*(?=(?:FD|SD)\s+)', fs.strip())
for block in blocks: for block in fd_blocks:
m = re.match(r'(FD|SD)\s+(\w[\w-]*)', block, re.IGNORECASE) m = re.match(r'(FD|SD)\s+(\w[\w-]*)', block, re.IGNORECASE)
if not m: if not m:
continue continue
entry_type = m.group(1).upper() # "FD" or "SD"
name = m.group(2).upper() name = m.group(2).upper()
# 找 01 层记录
recs = re.findall(r'^\s*0{0,1}1\s+(\w[\w-]*)', block, re.MULTILINE) recs = re.findall(r'^\s*0{0,1}1\s+(\w[\w-]*)', block, re.MULTILINE)
result[name] = [r.upper() for r in recs] result[name] = [r.upper() for r in recs]
return result return result
@@ -521,11 +661,15 @@ def scan_open_statements(source: str) -> dict:
source, re.IGNORECASE source, re.IGNORECASE
): ):
full = m.group(1) full = m.group(1)
for seg_m in re.finditer( full = re.sub(r'\s+', ' ', full)
r'(INPUT|OUTPUT|I-O)\s+([\w\s-]+)', full, re.IGNORECASE tokens = re.split(r'\s+(?=(?:INPUT|OUTPUT|I-O)\s)', full)
): for seg in tokens:
seg = seg.strip()
if not seg:
continue
seg_m = re.match(r'(INPUT|OUTPUT|I-O)\s+([\w -]+)', seg, re.IGNORECASE)
if seg_m:
direction = seg_m.group(1).upper() direction = seg_m.group(1).upper()
for fname in re.findall(r'\w[\w-]*', seg_m.group(2)): for fname in re.findall(r'\w[\w-]*', seg_m.group(2)):
if fname.upper() not in ('INPUT', 'OUTPUT', 'I-O'):
dirs[fname.upper()] = direction dirs[fname.upper()] = direction
return dirs return dirs