Files
cobol-java-v3/cobol_testgen/read.py
hangshuo652 7fb9304212 merge local cobol_testgen improvements into v3 shared modules
- cond.py: SQLCODE/SQLSTATE handling, alphanumeric >/< boundary fix
- output.py: termination tracking, db_input support, _is_field_assigned filter
- coverage.py: mark_from_gcov, THRU support, KeyError protection
- gcov.py: new file (dependency for coverage.py)
- grammar.lark: multi-segment PIC support
- read.py: SQL INCLUDE resolution, DECLARE TABLE parsing, * comment fix
- core.py: SQL parsing, blocked_names, keyword list
- design.py: multi-sentinel, THRU ranges, PERFORM VARYING last iteration
- __init__.py: local main() + v3 API functions, guarded imports

All 6 ZAN programs verified passing through v3 pipeline
2026-06-23 22:38:17 +08:00

676 lines
22 KiB
Python

"""Preprocessor + COPYBOOK + DATA DIVISION parse + PIC"""
import re
import logging
from pathlib import Path
from lark import Lark, Transformer, v_args
logger = logging.getLogger(__name__)
from .models import FieldDef, PicInfo
# ── Preprocessor ──
def _is_fixed_format(source: str) -> bool:
    """Guess whether *source* is fixed-format COBOL.

    An explicit ``>>SOURCE`` directive wins; FREE is checked before FIXED.
    The words ``FORMAT`` and ``IS`` are optional in the directive, so
    ``>>SOURCE FORMAT IS FREE``, ``>>SOURCE FORMAT FREE`` and
    ``>>SOURCE FREE`` are all recognized.

    Without a directive, the first ten non-blank lines are sampled: a line of
    72 or more characters counts as a free-format hint, otherwise an indicator
    character (``*``, ``/``, ``-``, ``D``) in column 7 counts as a
    fixed-format hint.  Ties and the absence of any hint default to fixed.

    Args:
        source: Raw COBOL source text.

    Returns:
        True when the source should be treated as fixed format.
    """
    directive = r'>>SOURCE\s+(?:FORMAT\s+)?(?:IS\s+)?'
    if re.search(directive + 'FREE', source, re.IGNORECASE):
        return False
    if re.search(directive + 'FIXED', source, re.IGNORECASE):
        return True

    sample = [ln for ln in source.splitlines() if ln.strip()][:10]
    fixed_hits = 0
    free_hits = 0
    for line in sample:
        if len(line) >= 72:
            free_hits += 1
        elif len(line) >= 7 and line[6] in ('*', '/', '-', 'D'):
            fixed_hits += 1

    if fixed_hits + free_hits == 0:
        # No evidence either way: keep the historical default.
        return True
    return fixed_hits >= free_hits
def preprocess(source: str, source_dir: str = '.') -> str:
    """Normalize COBOL source text before it is handed to the Lark grammar.

    Steps, in order:

    1. Expand or drop ``COPY`` statements via ``resolve_copybooks``.
    2. Remove ``EXEC CICS`` / ``EXEC SQL`` ... ``END-EXEC`` blocks.
    3. Rewrite VALUE clauses (drop list commas, drop ``ALL``), collapse
       ``&`` literal continuations and turn a decimal point inside a PIC
       string into ``V``.
    4. Drop comment/debug lines, join fixed-format continuation lines and
       upper-case everything; a ``FALSE <word>`` phrase is removed.
    5. Append a period to data-description lines that lack one when the next
       line starts with a level number.

    Args:
        source: Raw COBOL source text.
        source_dir: Directory searched for copybooks.  Defaults to ``'.'``,
            which is what earlier versions always used.

    Returns:
        The cleaned, upper-cased source text.
    """
    # COPY is a preprocessor directive the grammar does not know, so it has
    # to be expanded (or removed) before parsing.
    source = resolve_copybooks(source, source_dir)

    # Strip EXEC ... END-EXEC blocks (CICS/SQL) before Lark parsing.
    source = re.sub(
        r'EXEC\s+(?:CICS|SQL)\b.*?END-EXEC\.?',
        '',
        source, flags=re.IGNORECASE | re.DOTALL
    )

    # Strip commas from VALUE clauses (VALUE 'A', 'B', 'C' -> VALUE 'A' 'B' 'C').
    # Quoted literals are matched first and passed through untouched so that a
    # comma inside a literal (VALUE 'A,B') is not rewritten.
    def _strip_value_commas(m):
        return re.sub(
            r'("[^"]*"|\'[^\']*\')|\s*,\s*',
            lambda part: part.group(1) or ' ',
            m.group(0),
        )

    source = re.sub(r'VALUE\s+[^.\n]+', _strip_value_commas, source, flags=re.IGNORECASE)

    # Strip ALL from VALUE ALL (VALUE ALL '*.' -> VALUE '*.').
    source = re.sub(r'\bVALUE\s+ALL\b', 'VALUE', source, flags=re.IGNORECASE)

    # Collapse &-concatenated literals that are split across two lines:
    #     "............" &
    #     "............"
    # The closing quote of the first part and the opening quote of the second
    # part are kept; only the "&" and the line break between them are removed.
    source = re.sub(
        r'([Xx]?["\'])\s*&\s*\n\s*([Xx]?["\'])',
        lambda m: m.group(1) + m.group(2),
        source
    )
    # Remove any remaining "&" that has no quote after it on the same line.
    source = re.sub(r'&(?=[^"\']*$)', '', source, flags=re.MULTILINE)

    # Convert a decimal point inside a PIC string to V (implied decimal):
    # PIC Z(9)9.99. -> PIC Z(9)9V99.
    source = re.sub(
        r'(PIC\s+)([A-Z0-9(),\-*/V\$]+)\.(\d+)',
        r'\1\2V\3',
        source, flags=re.IGNORECASE
    )

    fixed = _is_fixed_format(source)
    lines = []
    for raw_line in source.splitlines():
        line = raw_line.rstrip()
        if not line:
            lines.append('')
            continue
        if fixed:
            has_indicator = len(line) >= 7
            # Column 7 '*' or '/': comment line.
            if has_indicator and line[6] in ('*', '/'):
                continue
            # Column 7 '-': continuation of the previous kept line.  The text
            # is upper-cased like every other line that is kept.
            if has_indicator and line[6] == '-':
                if lines:
                    lines[-1] = lines[-1] + ' ' + line[7:].lstrip().upper()
                continue
            # Column 7 'D': debugging line.
            if has_indicator and line[6].upper() == 'D':
                continue
            content = line[6:] if has_indicator else line
            # Comment marker that is not exactly in column 7.
            if content.strip().startswith('*'):
                continue
        else:
            comment_pos = line.find('*>')
            if comment_pos >= 0:
                line = line[:comment_pos]
            line = line.strip()
            if not line:
                continue
            # Bare "*" comment lines in free format (after *> removal).
            if line.startswith('*') and not line.startswith('*>'):
                continue
            content = line
        lines.append(re.sub(r'\s+FALSE\s+[^\s.]+', '', content.upper()))

    # Lines carrying PIC/VALUE/REDEFINES/OCCURS/USAGE but no trailing period
    # get one, but only when the following line starts with a level number,
    # i.e. looks like the next data item.
    fixed_lines = []
    for i, line in enumerate(lines):
        stripped = line.strip()
        if stripped and not stripped.endswith('.'):
            if re.search(r'\b(PIC|VALUE|REDEFINES|OCCURS|USAGE)\b', stripped, re.IGNORECASE):
                if i + 1 < len(lines) and re.match(r'^\s*(0[1-9]|[0-4][0-9]|49|66|77|88)\s', lines[i + 1]):
                    line = line.rstrip() + ' .'
        fixed_lines.append(line)
    return '\n'.join(fixed_lines)
def extract_data_division(source: str) -> str:
    """Return the text between ``DATA DIVISION.`` and ``PROCEDURE DIVISION``.

    The search is case-sensitive.  The header itself is not part of the
    result.  When there is no PROCEDURE DIVISION the rest of the text is
    returned; when there is no DATA DIVISION the result is ``''``.
    """
    header = re.search(r'DATA\s+DIVISION\s*\.', source)
    if header is None:
        return ''
    body = source[header.end():]
    procedure = re.search(r'PROCEDURE\s+DIVISION', body)
    if procedure is not None:
        body = body[:procedure.start()]
    return body.strip()
def extract_procedure_division(source: str) -> str:
    """Return *source* from ``PROCEDURE DIVISION`` onward, stripped.

    The header is included in the result.  The search is case-sensitive and
    ``''`` is returned when no PROCEDURE DIVISION header is present.
    """
    header = re.search(r'PROCEDURE\s+DIVISION', source)
    return source[header.start():].strip() if header else ''
# ── COPYBOOK Resolution ──
# File-name suffixes tried, in this order, when looking for a copybook.
_COPYBOOK_EXTENSIONS = ['.cpy', '.cbl', '.cpb', '']


def resolve_copybooks(source: str, source_dir: str, _recursion_depth: int = 0,
                      extra_search_paths: list[str] = None) -> str:
    """Replace single-line ``COPY`` statements with the copybook text.

    Each directory (``source_dir`` first, then ``extra_search_paths``) is
    searched for the upper-cased copybook name with every suffix in
    ``_COPYBOOK_EXTENSIONS``; the name as written and its lower-case form are
    tried afterwards so that lower-case files are found on case-sensitive
    file systems.  Copybooks are expanded recursively with the same search
    directories.  ``REPLACING ==a== BY ==b==`` pairs are applied to the
    expanded text as case-insensitive, literal substitutions.

    A ``COPY`` statement whose copybook cannot be found is removed from the
    output (it is a preprocessor directive the grammar cannot parse) and a
    warning is logged.

    Args:
        source: COBOL source text.
        source_dir: First directory to search.
        _recursion_depth: Internal nesting counter; expansion stops beyond 10.
        extra_search_paths: Additional directories searched after
            ``source_dir``.

    Returns:
        The source text with COPY statements expanded or removed.
    """
    copy_re = re.compile(
        r"^\s*COPY\s+(\w[\w-]*|\"[^\"]*\"|\'[^\']*\')(?:\s+REPLACING\s+(.+?))?\s*\.?\s*$",
        re.IGNORECASE
    )
    pair_re = re.compile(r"==(.+?)==\s+BY\s+==(.+?)==", re.IGNORECASE)
    search_dirs = [source_dir] + (extra_search_paths or [])

    def _locate(stems):
        # Directory order has priority, then spelling, then suffix.
        for directory in search_dirs:
            for stem in stems:
                for ext in _COPYBOOK_EXTENSIONS:
                    candidate = Path(directory, stem + ext)
                    if candidate.is_file():
                        return candidate
        return None

    result = []
    for line in source.split('\n'):
        m = copy_re.match(line)
        if not m:
            result.append(line)
            continue

        written = m.group(1).strip('"').strip("'")
        name = written.upper()
        # Upper case first (the historical lookup), then other spellings.
        stems = list(dict.fromkeys([name, written, written.lower()]))
        found = _locate(stems)

        if found is None:
            # The statement may sit inside an FD/SD entry; leaving it in place
            # would break the parse, so it is dropped.
            logger.warning("COPY %s not found in %s; statement dropped", name, search_dirs)
            continue
        if _recursion_depth > 10:
            logger.warning(f"COPY circular dependency detected for {name}, skipping")
            continue

        cb = found.read_text(encoding='utf-8')
        # Nested COPY statements see the same search directories.
        cb = resolve_copybooks(cb, source_dir, _recursion_depth + 1, extra_search_paths)
        if m.group(2):
            for old, new in pair_re.findall(m.group(2)):
                replacement = new.strip()
                # A callable keeps backslashes in the replacement literal.
                cb = re.sub(
                    re.escape(old.strip()), lambda _m, text=replacement: text,
                    cb, flags=re.IGNORECASE
                )
        # No marker comment is added around the expansion: a comment line
        # inside an FD entry would confuse the grammar.
        result.append(cb)
    return '\n'.join(result)
# ── EXEC SQL INCLUDE Resolution ──
# Matches "EXEC SQL INCLUDE <name> END-EXEC." and captures <name>.
# The trailing period is required by this pattern.
_RE_SQL_INC = re.compile(
r'EXEC\s+SQL\s+INCLUDE\s+(\w[\w-]*)\s+END-EXEC\.',
re.IGNORECASE | re.DOTALL
)
# Fallback SQLCA record layout; resolve_sql_includes returns this text for
# "INCLUDE SQLCA" when no matching file exists in the source directory.
_BUILTIN_SQLCA = """\
01 SQLCA.
05 SQLCAID PIC X(8).
05 SQLCABC PIC S9(9) COMP.
05 SQLCODE PIC S9(9) COMP.
05 SQLERRM.
10 SQLERRML PIC S9(4) COMP.
10 SQLERRMC PIC X(70).
05 SQLERRP PIC X(8).
05 SQLERRD OCCURS 6 TIMES PIC S9(9) COMP.
05 SQLWARN.
10 SQLWARN0 PIC X.
10 SQLWARN1 PIC X.
10 SQLWARN2 PIC X.
10 SQLWARN3 PIC X.
10 SQLWARN4 PIC X.
10 SQLWARN5 PIC X.
10 SQLWARN6 PIC X.
10 SQLWARN7 PIC X.
05 SQLSTATE PIC X(5).
"""
def resolve_sql_includes(source: str, source_dir: str) -> str:
    """Expand ``EXEC SQL INCLUDE name END-EXEC.`` statements, like COPY.

    For each statement the upper-cased name is looked up in ``source_dir``
    with the suffixes ``''``, ``.cpy``, ``.CPY``, ``.cbl`` and ``.CBL``.
    When no file exists, ``SQLCA`` is replaced by the built-in layout and any
    other name by a comment line (a warning is logged).

    Substitution is repeated so that includes inside included files are
    expanded as well.  The number of passes is capped: a file that includes
    itself would otherwise never reach a fixed point.

    Args:
        source: COBOL source text.
        source_dir: Directory searched for include files.

    Returns:
        The source text with the include statements expanded.
    """
    max_passes = 10

    def _resolve_one(m):
        name = m.group(1).upper()
        for ext in ('', '.cpy', '.CPY', '.cbl', '.CBL'):
            p = Path(source_dir) / f"{name}{ext}"
            if p.exists():
                return p.read_text(encoding='utf-8')
        if name == 'SQLCA':
            return _BUILTIN_SQLCA
        logger.warning(f"SQL INCLUDE {name} not found, injecting as comment")
        return f"       * SQL INCLUDE {name} NOT RESOLVED\n"

    for _ in range(max_passes):
        expanded = _RE_SQL_INC.sub(_resolve_one, source)
        if expanded == source:
            # Nothing left to expand.
            return source
        source = expanded

    if _RE_SQL_INC.search(source):
        logger.warning(
            "SQL INCLUDE nesting exceeded %d passes (circular include?); "
            "remaining INCLUDE statements left unresolved", max_passes
        )
    return source
# Any EXEC SQL ... END-EXEC block (optional trailing period); group 1 is the
# text between "EXEC SQL" and "END-EXEC".
_RE_SQL_BLOCK = re.compile(
r'EXEC\s+SQL\s+(.*?)\s+END-EXEC\.?',
re.IGNORECASE | re.DOTALL
)
# EXEC SQL DECLARE <table> TABLE ( <columns> ) END-EXEC; group 1 is the table
# name, group 2 the raw column list between the outer parentheses.
_RE_DECLARE_TABLE = re.compile(
r'EXEC\s+SQL\s+DECLARE\s+(\w[\w-]*)\s+TABLE\s*\((.*?)\)\s+END-EXEC\.?',
re.IGNORECASE | re.DOTALL
)
def strip_exec_sql_from_data_div(source: str) -> tuple:
    """Replace every EXEC SQL block with a one-line ``*>`` comment.

    ``DECLARE <table> TABLE (...)`` blocks are parsed on the way: their
    columns are collected under the upper-cased table name.

    Returns:
        ``(cleaned_source, declared_columns)`` where ``declared_columns`` maps
        table name to the list produced by ``_parse_declare_table_columns``.
    """
    tables = {}

    def _placeholder(block):
        declare = _RE_DECLARE_TABLE.match(block.group(0))
        if declare is None:
            return "      *> SKIPPED EXEC SQL\n"
        table = declare.group(1).upper()
        columns = _parse_declare_table_columns(declare.group(2))
        tables[table] = columns
        return f"      *> DECLARE {table} TABLE ({len(columns)} cols)\n"

    without_sql = _RE_SQL_BLOCK.sub(_placeholder, source)
    return without_sql, tables
def _parse_declare_table_columns(col_text: str) -> list[dict]:
    """Parse a DECLARE TABLE column list into a list of column dicts.

    Example input: ``CUST_ID CHAR(5) NOT NULL, BALANCE DECIMAL(9,2)``.

    The list is split on commas that are *not* inside parentheses, so the
    comma of ``DECIMAL(p,s)`` stays with its column.  Recognized types are
    CHAR(n), VARCHAR(n), INTEGER, SMALLINT, DECIMAL(p[,s]), DATE and
    ``PIC <picture>``; an optional ``NOT NULL`` / ``NULL`` suffix is accepted.
    Columns that do not match are skipped.

    Returns:
        One dict per column with ``name`` (upper case) and ``db_type`` plus,
        depending on the type, ``size``, ``precision``/``scale`` or ``pic``.
    """
    column_re = re.compile(
        r'(\w[\w-]*)\s+(CHAR\s*\(\s*(\d+)\s*\)'
        r'|VARCHAR\s*\(\s*(\d+)\s*\)'
        r'|INTEGER|SMALLINT'
        r'|DECIMAL\s*\(\s*(\d+)\s*(?:,\s*(\d+))?\s*\)'
        r'|DATE'
        r'|PIC\s+([\w().]+))'
        r'(?:\s+NOT\s+NULL|\s+NULL)?',
        re.IGNORECASE
    )
    cols = []
    # A comma followed by ")" with no "(" in between is inside parentheses.
    for part in re.split(r',(?![^()]*\))', col_text):
        part = part.strip()
        if not part:
            continue
        m = column_re.match(part)
        if not m:
            continue
        name = m.group(1).upper()
        type_word = m.group(2).upper()
        if m.group(3):
            col_type = {'db_type': 'CHAR', 'size': int(m.group(3))}
        elif m.group(4):
            col_type = {'db_type': 'VARCHAR', 'size': int(m.group(4))}
        elif type_word == 'INTEGER':
            col_type = {'db_type': 'INTEGER'}
        elif type_word == 'SMALLINT':
            col_type = {'db_type': 'SMALLINT'}
        elif m.group(5):
            scale = int(m.group(6)) if m.group(6) else 0
            col_type = {'db_type': 'DECIMAL', 'precision': int(m.group(5)), 'scale': scale}
        elif type_word == 'DATE':
            col_type = {'db_type': 'DATE'}
        elif m.group(7):
            col_type = {'db_type': 'PIC', 'pic': m.group(7).upper()}
        else:
            col_type = {'db_type': 'CHAR', 'size': 1}
        cols.append({'name': name, **col_type})
    return cols
# ── Lark Grammar ──
# Text of grammar.lark, filled in on first use by _get_grammar().
_GRAMMAR_CACHE = None


def _get_grammar() -> str:
    """Return the contents of ``grammar.lark`` located next to this module.

    The file is read (UTF-8) on the first call only; later calls return the
    cached text.
    """
    global _GRAMMAR_CACHE
    if _GRAMMAR_CACHE is not None:
        return _GRAMMAR_CACHE
    grammar_file = Path(__file__).with_name('grammar.lark')
    _GRAMMAR_CACHE = grammar_file.read_text(encoding='utf-8')
    return _GRAMMAR_CACHE
# ── Data Transformer ──
@v_args(inline=True)
class DataTransformer(Transformer):
    """Turn the DATA DIVISION parse tree into a flat list of field dicts.

    ``data_item`` queues one dict per data description entry in
    ``self._pending``.  When a section rule (``file_section``,
    ``working_storage``, ``linkage``) is reduced, the queued entries are
    tagged with that section and moved to ``self.fields``.  ``start`` moves
    whatever is still queued (tagging untagged entries as WORKING-STORAGE)
    and returns ``self.fields``.
    """

    # Bare tokens that ``clause`` interprets as a USAGE value.
    _USAGE_TOKENS = (
        'COMP', 'COMP-3', 'COMP-5', 'BINARY', 'PACKED-DECIMAL', 'DISPLAY',
    )

    def __init__(self):
        super().__init__()
        self.fields = []
        # Name of the most recent non-88 item; used as parent of 88 levels.
        self._last_parent = None
        # Items seen since the last section rule was reduced.
        self._pending = []

    def _flush_pending(self, section, keep_existing=False):
        """Move queued items to ``self.fields``, tagging them with *section*.

        With ``keep_existing`` an item that already has a ``section`` key
        keeps its value.
        """
        for field in self._pending:
            if not (keep_existing and 'section' in field):
                field['section'] = section
            self.fields.append(field)
        self._pending = []

    def start(self, *items):
        self._flush_pending('WORKING-STORAGE', keep_existing=True)
        return self.fields

    def file_section(self, *args):
        self._flush_pending('FILE')
        return None

    def working_storage(self, *args):
        self._flush_pending('WORKING-STORAGE')
        return None

    def linkage(self, *args):
        self._flush_pending('LINKAGE')
        return None

    def data_item(self, level_num, name, *clauses):
        """Queue one data description entry built from its clause dicts."""
        level = int(str(level_num))
        name = str(name)

        # Later clauses override earlier ones key by key.
        merged = {}
        for c in clauses:
            if isinstance(c, dict):
                merged.update(c)

        pic = merged.get('pic')
        value = merged.get('value')
        values = merged.get('values')

        base = {
            'level': level,
            'name': name,
            'pic': pic if pic else None,
            'value': value,
            'values': values,
            'is_filler': name.upper() == 'FILLER',
            'redefines': merged.get('redefines'),
            'usage': merged.get('usage'),
            'occurs': merged.get('occurs', 0),
            'occurs_depending': merged.get('depends'),
        }

        if pic is not None:
            # Elementary item.
            self._pending.append(base)
            self._last_parent = name
        elif level == 88 and value is not None:
            # Condition name: attach it to the most recent non-88 item.
            base.update({
                'pic': None,
                'value': value.strip("'").strip('"'),
                'values': [v.strip("'").strip('"') for v in values] if values else None,
                'is_88': True,
                'parent': self._last_parent or '',
            })
            self._pending.append(base)
        else:
            # Group item (no PIC).  A level-88 entry that ends up here (no
            # usable VALUE) must not become the parent of later 88 levels.
            self._pending.append(base)
            if level != 88:
                self._last_parent = name
        return None

    def clause(self, *args):
        """Merge child clause dicts; a bare usage token becomes ``usage``."""
        result = {}
        for a in args:
            if isinstance(a, dict):
                result.update(a)
            elif isinstance(a, str) and a.upper() in self._USAGE_TOKENS:
                result['usage'] = a.upper()
        return result if result else None

    def pic_clause(self, *args):
        # The picture string is the last child; earlier children are keywords.
        return {'pic': str(args[-1])}

    def usage_clause(self, token):
        return {'usage': str(token)}

    def value_clause(self, *args):
        """Collect the literals of a VALUE clause with their quotes removed."""
        values = []
        for a in args:
            if isinstance(a, str) and a.upper() in ('VALUE', 'IS'):
                continue
            values.append(str(a).strip("'").strip('"'))
        return {'value': values[0], 'values': values} if values else {'value': None}

    def value_literal(self, *args):
        if args:
            return str(args[-1])
        return ''

    def occurs_clause(self, *args):
        # First child is the count; a second child, when present, is stored
        # as the DEPENDING ON name.
        # NOTE(review): assumes the grammar passes only these children.
        result = {'occurs': int(args[0])}
        if len(args) >= 2:
            result['depends'] = str(args[1])
        return result

    def redefines_clause(self, *args):
        return {'redefines': str(args[-1])}

    def level_num(self, token):
        return token

    def NAME(self, token):
        return str(token)

    def PICTURE_STRING(self, token):
        return str(token)

    def INT(self, token):
        return int(token)
# ── PIC Parser ──
def _expand_pic(s: str) -> str:
    """Expand repeat counts in a picture string: ``9(3)V99`` -> ``999V99``.

    ``(n)`` repeats the character emitted just before it so that it appears
    *n* times in total.  A ``(n)`` with nothing before it is dropped, and an
    empty ``()`` is copied through unchanged.
    """
    out = []
    pos = 0
    end = len(s)
    while pos < end:
        ch = s[pos]
        if ch == '(':
            close = s.find(')', pos)
            if close > pos + 1:
                repeat = int(s[pos + 1:close])
                if out:
                    # One copy is already in the output.
                    out.extend(out[-1] * (repeat - 1))
                pos = close + 1
                continue
        out.append(ch)
        pos += 1
    return ''.join(out)
def parse_pic(pic_str: str) -> PicInfo:
    """Classify a PICTURE string and fill a ``PicInfo``.

    A leading ``S`` sets ``signed`` and is removed; repeat counts are then
    expanded.  The first remaining character selects the category:

    * ``9`` or ``V`` -> ``numeric``; with a ``V`` present ``digits`` counts
      the 9s before it and ``decimal`` the 9s after it (``V99`` has no
      integer digits).
    * ``X`` -> ``alphanumeric``; ``A`` -> ``alphabetic``.
    * ``Z``, ``*``, ``$``, ``+``, ``-`` or a trailing ``CR``/``DB`` ->
      ``numeric-edited``.
    * anything else -> ``alphanumeric``.

    An empty picture, or one that is empty once the sign is removed, returns
    the ``PicInfo`` without a category assigned.
    """
    info = PicInfo()
    s = pic_str.upper().strip()
    if not s:
        return info
    if s.startswith('S'):
        info.signed = True
        s = s[1:]
    expanded = _expand_pic(s)
    if not expanded:
        # e.g. a bare "S": nothing left to classify.
        return info

    def _set_edited(picture):
        # Shared by both numeric-edited branches.  ``decimal`` is only
        # assigned when a V or a decimal point is present.
        info.type = 'numeric-edited'
        info.digits = picture.count('9')
        if 'V' in picture:
            info.decimal = picture.split('V')[1].count('9')
        elif '.' in picture:
            info.decimal = picture.split('.')[1].count('9')
        info.length = len(expanded)

    first = expanded[0]
    if first in ('9', 'V'):
        info.type = 'numeric'
        if 'V' in expanded:
            parts = expanded.split('V')
            info.digits = parts[0].count('9')
            info.decimal = parts[1].count('9')
        else:
            info.digits = expanded.count('9')
            info.decimal = 0
    elif first == 'X':
        info.type = 'alphanumeric'
        info.length = len(expanded)
    elif first == 'A':
        info.type = 'alphabetic'
        info.length = len(expanded)
    elif first in ('Z', '*', '$', '+', '-'):
        _set_edited(expanded)
    elif expanded.endswith('CR') or expanded.endswith('DB'):
        # The two-character sign suffix is not part of the digit positions.
        _set_edited(expanded[:-2])
    else:
        info.type = 'alphanumeric'
        info.length = len(expanded)
    return info
# ── DATA DIVISION entry point ──
def parse_data_division(data_div_text: str) -> list[FieldDef]:
    """Parse DATA DIVISION text into a list of ``FieldDef`` objects.

    The text is parsed with the Lark grammar (Earley parser, dynamic lexer),
    reduced to raw dicts by ``DataTransformer`` and each dict is converted to
    a ``FieldDef``.  ``pic_info`` is the result of ``parse_pic`` for items
    that have a picture and ``None`` otherwise.
    """
    tree = Lark(_get_grammar(), parser='earley', lexer='dynamic').parse(data_div_text)
    records = DataTransformer().transform(tree)

    def _to_field(rec):
        picture = rec.get('pic', '')
        return FieldDef(
            name=rec['name'],
            level=rec['level'],
            pic=picture,
            pic_info=parse_pic(picture) if picture else None,
            is_filler=rec.get('is_filler', False),
            occurs_count=rec.get('occurs', 0),
            occurs_depending=rec.get('occurs_depending'),
            redefines=rec.get('redefines'),
            usage=rec.get('usage'),
            value=rec.get('value'),
            values=rec.get('values'),
            is_88=rec.get('is_88', False),
            parent=rec.get('parent'),
            section=rec.get('section'),
        )

    return [_to_field(rec) for rec in records]
# ── FILE-CONTROL / FILE SECTION / OPEN parsing ──
def parse_file_control(source: str) -> dict:
    """Parse the FILE-CONTROL paragraph and the FD entries of *source*.

    Matching is case-insensitive.  The optional words in
    ``SELECT [OPTIONAL] f``, ``ASSIGN [TO]``, ``ORGANIZATION [IS]`` and
    ``RECORDING MODE [IS]`` are accepted with or without the optional word.

    Returns:
        ``{file_name: {"assign": str, "organization": str,
        "recording_mode": str}}`` with upper-cased names.  ``organization``
        is ``'LINE SEQUENTIAL'`` when the SELECT entry says so and
        ``'SEQUENTIAL'`` otherwise; ``recording_mode`` defaults to ``'F'``
        and is overridden by a RECORDING MODE clause in the file's FD entry.
        An empty dict is returned when there is no FILE-CONTROL paragraph.
    """
    fc_m = re.search(r'FILE-CONTROL\.(.*?)(?=DATA\s+DIVISION|\Z)', source, re.DOTALL | re.IGNORECASE)
    if not fc_m:
        return {}

    select_re = re.compile(
        r'SELECT\s+(?:OPTIONAL\s+)?(\w[\w-]*)\s+[^.]*?\bASSIGN\s+(?:TO\s+)?'
        r'(?:(["\'])(.*?)\2|(\w[\w-]*))'
        r'[^.]*\.',
        re.IGNORECASE
    )
    result = {}
    for sel_m in select_re.finditer(fc_m.group(1)):
        name = sel_m.group(1).upper()
        # Group 2 is the opening quote of a literal target, group 4 a word.
        target = sel_m.group(3) if sel_m.group(2) else sel_m.group(4)
        org_m = re.search(
            r'ORGANIZATION\s+(?:IS\s+)?(LINE\s+)?SEQUENTIAL',
            sel_m.group(0), re.IGNORECASE
        )
        org = 'LINE SEQUENTIAL' if (org_m and org_m.group(1)) else 'SEQUENTIAL'
        result[name] = {'assign': target.upper(), 'organization': org, 'recording_mode': 'F'}

    # RECORDING MODE lives in the FD entries of the FILE SECTION.
    fd_sec_m = re.search(r'FILE\s+SECTION\.(.*?)(?=WORKING-STORAGE\s+SECTION|LINKAGE\s+SECTION|\Z)',
                         source, re.DOTALL | re.IGNORECASE)
    if fd_sec_m:
        fs = fd_sec_m.group(1)
        for block in re.split(r'\n\s*(?=FD\s+)', fs.strip(), flags=re.IGNORECASE):
            fd_m = re.match(r'FD\s+(\w[\w-]*)', block, re.IGNORECASE)
            if not fd_m:
                continue
            fd_name = fd_m.group(1).upper()
            if fd_name not in result:
                continue
            rm_m = re.search(r'RECORDING\s+MODE\s+(?:IS\s+)?(\w)', block, re.IGNORECASE)
            if rm_m:
                result[fd_name]['recording_mode'] = rm_m.group(1).upper()
    return result
def parse_file_section(source: str) -> dict:
    """Parse the FILE SECTION into ``{file name: [01-level record names]}``.

    The section ends at WORKING-STORAGE SECTION, LINKAGE SECTION or the end
    of the text.  It is split into FD/SD entries; for each entry the names of
    its level 01 (or 1) records are collected.  Matching is case-insensitive
    throughout and all names are returned in upper case.  An empty dict is
    returned when there is no FILE SECTION.
    """
    section = re.search(r'FILE\s+SECTION\.(.*?)(?=WORKING-STORAGE\s+SECTION|LINKAGE\s+SECTION|\Z)',
                        source, re.DOTALL | re.IGNORECASE)
    if not section:
        return {}
    result = {}
    # One block per FD or SD entry; the flag keeps lower-case "fd"/"sd"
    # entries from being merged into the preceding block.
    fd_blocks = re.split(r'\n\s*(?=(?:FD|SD)\s+)', section.group(1).strip(), flags=re.IGNORECASE)
    for block in fd_blocks:
        entry = re.match(r'(FD|SD)\s+(\w[\w-]*)', block, re.IGNORECASE)
        if not entry:
            continue
        recs = re.findall(r'^\s*0{0,1}1\s+(\w[\w-]*)', block, re.MULTILINE)
        result[entry.group(2).upper()] = [r.upper() for r in recs]
    return result
def scan_open_statements(source: str) -> dict:
    """Scan OPEN statements and return ``{file name: 'INPUT'|'OUTPUT'|'I-O'}``.

    Each OPEN statement may list several modes, each followed by one or more
    file names.  Matching is case-insensitive; names and modes are returned
    in upper case.  When a file is opened more than once, the last mode seen
    wins.
    """
    dirs = {}
    for m in re.finditer(
        r'OPEN\s+((?:INPUT|OUTPUT|I-O)\s+[\w\s-]+'
        r'(?:\s+(?:INPUT|OUTPUT|I-O)\s+[\w\s-]+)*)',
        source, re.IGNORECASE
    ):
        # Collapse line breaks and runs of blanks before splitting.
        full = re.sub(r'\s+', ' ', m.group(1))
        # One segment per mode keyword.  The flag keeps a lower-case
        # "output"/"input"/"i-o" from being taken for a file name.
        segments = re.split(r'\s+(?=(?:INPUT|OUTPUT|I-O)\s)', full, flags=re.IGNORECASE)
        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue
            seg_m = re.match(r'(INPUT|OUTPUT|I-O)\s+([\w -]+)', seg, re.IGNORECASE)
            if not seg_m:
                continue
            direction = seg_m.group(1).upper()
            for fname in re.findall(r'\w[\w-]*', seg_m.group(2)):
                dirs[fname.upper()] = direction
    return dirs