merge local cobol_testgen improvements into v3 shared modules

- cond.py: SQLCODE/SQLSTATE handling, alphanumeric >/< boundary fix
- output.py: termination tracking, db_input support, _is_field_assigned filter
- coverage.py: mark_from_gcov, THRU support, KeyError protection
- gcov.py: new file (dependency for coverage.py)
- grammar.lark: multi-segment PIC support
- read.py: SQL INCLUDE resolution, DECLARE TABLE parsing, * comment fix
- core.py: SQL parsing, blocked_names, keyword list
- design.py: multi-sentinel, THRU ranges, PERFORM VARYING last iteration
- __init__.py: local main() + v3 API functions, guarded imports

All 6 ZAN programs verified passing through v3 pipeline
This commit is contained in:
hangshuo652
2026-06-23 22:38:17 +08:00
parent e5ab3baa46
commit 7fb9304212
9 changed files with 1595 additions and 326 deletions
+318 -76
View File
@@ -15,16 +15,29 @@ _COBOL_SCOPE_ENDERS = {
'END-SEARCH',
'ELSE', 'WHEN', 'OTHER',
}
_COBOL_KEYWORDS = {
'GOBACK', 'EXIT', 'STOP', 'CONTINUE',
'ACCEPT', 'DISPLAY', 'MOVE', 'COMPUTE', 'INITIALIZE',
'ADD', 'SUBTRACT', 'MULTIPLY', 'DIVIDE',
'STRING', 'UNSTRING', 'SET', 'INSPECT',
'OPEN', 'CLOSE', 'READ', 'WRITE', 'REWRITE', 'DELETE', 'START',
'PERFORM', 'CALL', 'IF', 'EVALUATE', 'SEARCH', 'SORT', 'MERGE',
'COMMIT', 'ROLLBACK', 'GO',
}
def scan_paragraphs(raw_lines):
def scan_paragraphs(raw_lines, blocked_names=None):
paragraphs = {}
i = 0
blocked = set()
if blocked_names:
for n in blocked_names:
blocked.add(n.upper())
while i < len(raw_lines):
line = raw_lines[i].strip()
m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', line)
sec_m = re.match(r'^([A-Z][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE)
if m and m.group(1) not in _COBOL_SCOPE_ENDERS:
sec_m = re.match(r'^([A-Z0-9][A-Z0-9-]*)\s+SECTION\.?\s*$', line, re.IGNORECASE)
if m and m.group(1) not in _COBOL_SCOPE_ENDERS and m.group(1) not in _COBOL_KEYWORDS and m.group(1) not in blocked:
name = m.group(1)
elif sec_m:
name = sec_m.group(1).upper()
@@ -36,9 +49,9 @@ def scan_paragraphs(raw_lines):
while j < len(raw_lines):
nline = raw_lines[j].strip()
nm = re.match(r'^([A-Z0-9][A-Z0-9-]*)\.\s*$', nline)
if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS:
if nm and nm.group(1) not in _COBOL_SCOPE_ENDERS and nm.group(1) not in _COBOL_KEYWORDS and nm.group(1) not in blocked:
break
if re.match(r'^[A-Z][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE):
if re.match(r'^[A-Z0-9][A-Z0-9-]*\s+SECTION\.\s*$', nline, re.IGNORECASE):
break
j += 1
paragraphs[name] = (start, j - 1)
@@ -46,9 +59,47 @@ def scan_paragraphs(raw_lines):
return paragraphs
def build_branch_tree(proc_text, fields=None):
def sql_register_virtual_fields(fields_dict: list[dict]) -> list[dict]:
    """Append SQLCODE / SQLSTATE pseudo-fields to *fields_dict* when missing.

    Presence is decided by an exact match on each entry's ``'name'`` key.
    An absent SQLCODE is added as a level-77 signed ``S9(9)`` COMP item and
    an absent SQLSTATE as a level-77 ``X(5)`` DISPLAY item.  The list is
    extended in place and the very same list object is returned.
    """
    def _declared(field_name):
        # Short-circuits on the first entry carrying this name.
        return any(entry['name'] == field_name for entry in fields_dict)

    def _new_field(field_name, pic, pic_info, usage):
        # A fresh dict on every call so callers never share mutable state.
        return {
            'name': field_name,
            'level': 77, 'pic': pic,
            'pic_info': pic_info,
            'section': 'WORKING-STORAGE', 'is_filler': False, 'redefines': None,
            'usage': usage, 'occurs': 0, 'occurs_depending': None,
            'value': None, 'values': None,
        }

    missing = []
    if not _declared('SQLCODE'):
        missing.append(_new_field(
            'SQLCODE', 'S9(9)',
            {'type': 'numeric', 'digits': 9, 'decimal': 0,
             'length': 4, 'signed': True},
            'COMP',
        ))
    if not _declared('SQLSTATE'):
        missing.append(_new_field(
            'SQLSTATE', 'X(5)',
            {'type': 'alphanumeric', 'length': 5},
            'DISPLAY',
        ))
    fields_dict.extend(missing)
    return fields_dict
def build_branch_tree(proc_text, fields=None, full_source=None):
raw_lines = proc_text.split('\n')
paragraphs = scan_paragraphs(raw_lines)
# Collect data names (FD names, record names, field names) to block paragraph detection
blocked_names = set()
if fields:
for f in fields:
if isinstance(f, dict):
blocked_names.add(f['name'].upper())
else:
blocked_names.add(f.name.upper())
# Extract FD names from full source if available (includes DATA DIVISION)
src = full_source or proc_text
for m in re.finditer(r'\bFD\s+(\w[\w-]*)\b', src, re.IGNORECASE):
blocked_names.add(m.group(1).upper())
paragraphs = scan_paragraphs(raw_lines, blocked_names=blocked_names)
first_para_name = None
first_para_idx = None
@@ -169,6 +220,13 @@ class _BrParser:
if m_search:
seq.add(self._parse_search(m_search))
continue
m_exec = re.match(r'^EXEC\s+SQL\s*$', line, re.IGNORECASE)
if m_exec:
sql_block = self._parse_sql_block()
assign_node = self._parse_sql(sql_block)
if assign_node:
seq.add(assign_node)
continue
m = re.match(r'^INITIALIZE\s+', line)
if m:
init_seq = self._parse_initialize()
@@ -192,7 +250,7 @@ class _BrParser:
seq.add(self._parse_call())
continue
m = re.match(
r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS))?\s*$',
r'^ACCEPT\s+(\w[\w-]*)(?:\s+FROM\s+(DATE|TIME|DAY|DAY-OF-WEEK|YEAR|YYYYMMDD|HHMMSS|SYSIN|COMMAND-LINE|SYSERR|SYSOUT|ENVIRONMENT-NAME|ENVIRONMENT-VALUE))?\s*$',
line, re.IGNORECASE
)
if m:
@@ -211,21 +269,11 @@ class _BrParser:
seq.add(Assign(tgt, info))
self.advance()
# 跳过 READ 语句剩余行(AT END / NOT AT END / END-READ
# 遇到新的语句关键词时停止,避免贪婪吞咽后续内容
_stmt_boundary = re.compile(
r'^(IF |EVALUATE |PERFORM |SEARCH |INITIALIZE |STRING |'
r'UNSTRING |CALL |ACCEPT |READ |WRITE |REWRITE |SET |'
r'INSPECT |MOVE |COMPUTE |ADD |SUBTRACT |MULTIPLY |DIVIDE |'
r'GO\s+TO |GOBACK |STOP\s+RUN|EXIT\s|CLOSE |OPEN |DISPLAY |'
r'DELETE |START |'
r'END-IF|END-PERFORM|END-EVALUATE|END-READ)', re.IGNORECASE)
while self.pos < len(self.lines):
cl = self.clean()
if cl in ('END-READ', 'END-READ.'):
self.advance()
break
if _stmt_boundary.match(cl):
break
self.advance()
continue
m_set_false = re.match(r'^SET\s+(\w[\w-]*)\s+TO\s+FALSE\s*$', line, re.IGNORECASE)
@@ -366,11 +414,34 @@ class _BrParser:
else:
tgt_key = tgt_base
src_clean = raw_src.strip("'").strip('"')
is_field_name = self.fields and any(f['name'] == src_clean for f in self.fields)
if is_field_name:
info = {'type': 'move', 'source_vars': [src_clean]}
# 检测引用修饰 FIELD(start:length)
rm = re.match(r'^(\w[\w-]*)\(\s*(\d+)\s*:\s*(\d+)\s*\)$', src_clean, re.IGNORECASE)
if rm:
base_src = rm.group(1)
refmod_start = int(rm.group(2))
refmod_length = int(rm.group(3))
is_field_name = self.fields and any(
(f['name'] if isinstance(f, dict) else f.name) == base_src
for f in self.fields
)
if is_field_name:
info = {
'type': 'move',
'source_vars': [base_src],
'refmod_start': refmod_start,
'refmod_length': refmod_length,
}
else:
info = {'type': 'move_literal', 'literal': src_clean}
else:
info = {'type': 'move_literal', 'literal': src_clean}
is_field_name = self.fields and any(
(f['name'] if isinstance(f, dict) else f.name) == src_clean
for f in self.fields
)
if is_field_name:
info = {'type': 'move', 'source_vars': [src_clean]}
else:
info = {'type': 'move_literal', 'literal': src_clean}
self.assignments.setdefault(tgt_key, []).append(info)
return Assign(tgt_key, info)
@@ -648,40 +719,11 @@ class _BrParser:
line = self.clean()
m = re.match(r'^IF\s+(.+?)(?:THEN)?\s*$', line)
cond_text = m.group(1).strip()
# Truncate at COBOL statement keywords (single-line IF body after condition)
_stmt_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|'
r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|'
r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH)\b')
_stmt_starts = re.compile(_stmt_pat, re.IGNORECASE)
rest = "" # remaining text after condition truncation (single-line IF body)
sm = _stmt_starts.search(cond_text)
if sm:
rest = cond_text[sm.start():]
cond_text = cond_text[:sm.start()]
self.advance()
if rest:
rest = rest.strip()
if rest.endswith('.'):
rest = rest[:-1]
# Split on ELSE but keep ELSE as its own line for parse_seq boundary
else_parts = re.split(r'(\s+ELSE\s+)', rest, maxsplit=1, flags=re.IGNORECASE)
parts = [p.strip() for p in else_parts if p.strip()]
insert_parts = []
for p in parts:
if p.upper() == 'ELSE':
insert_parts.append('ELSE')
else:
insert_parts.append(p if '.' in p else p + '.')
for part in reversed(insert_parts):
self.lines.insert(self.pos, part)
# Join continuation lines (multi-line IF conditions)
_cont_keywords = (r'THEN|ELSE|END-IF|MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|'
r'DIVIDE|STRING|UNSTRING|INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|'
r'READ|WRITE|REWRITE|DELETE|START|INSPECT|SET|IF|GO\b|EXIT\b|'
r'STOP\s+RUN|GOBACK|CLOSE|OPEN|SEARCH')
while self.pos < len(self.lines):
peek = self.clean()
if re.match(r'^(' + _cont_keywords + r')', peek, re.IGNORECASE):
if re.match(r'^(THEN|ELSE|END-IF|EXEC|MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT)', peek, re.IGNORECASE):
break
if peek.endswith('.'):
cond_text += ' ' + peek.rstrip('.')
@@ -697,16 +739,8 @@ class _BrParser:
node = BrIf(cond_text)
node.cond_tree = parse_compound_condition(node.condition, self.fields)
node.true_seq = self.parse_seq(['ELSE', 'END-IF'])
clean = self.clean()
if clean.startswith('ELSE'):
self.advance() # consume ELSE keyword
rest = clean[4:].strip() if len(clean) > 4 else ''
# ELSE IF → reinsert IF statement as next line for recursive parse
if rest.upper().startswith('IF '):
self.lines.insert(self.pos, rest)
elif rest:
# Regular ELSE body text on same line as ELSE: reinsert
self.lines.insert(self.pos, rest if '.' in rest else rest + '.')
if self.clean() == 'ELSE':
self.advance()
node.false_seq = self.parse_seq(['END-IF'])
if self.clean() == 'END-IF':
self.advance()
@@ -728,13 +762,6 @@ class _BrParser:
m = re.match(r'^WHEN\s+(.+?)\s*$', line)
if m:
raw_val = m.group(1).strip().strip("'").strip('"')
# Truncate at COBOL statement keywords (single-line WHEN body after condition)
_eval_pat = (r'\s(?:MOVE|DISPLAY|COMPUTE|ADD|SUBTRACT|MULTIPLY|DIVIDE|STRING|UNSTRING|'
r'INITIALIZE|ACCEPT|CALL|PERFORM|EVALUATE|READ|WRITE|REWRITE|DELETE|START|'
r'INSPECT|SET|IF|ELSE|END-IF|GO\b|EXIT\b|STOP\b|GOBACK|CLOSE|OPEN|SEARCH)\b')
_eval_stmt = re.search(_eval_pat, raw_val, re.IGNORECASE)
if _eval_stmt:
raw_val = raw_val[:_eval_stmt.start()]
self.advance()
# Capture multi-line WHEN conditions (AND/OR continuation)
while self.pos < len(self.lines):
@@ -848,6 +875,14 @@ class _BrParser:
if um:
condition = um.group(1).strip()
self.advance()
# Join continuation lines (AND/OR on next lines)
while self.pos < len(self.lines):
peek = self.clean()
if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE):
condition += ' ' + peek
self.advance()
else:
break
break
break
if from_val and by_val and condition:
@@ -894,6 +929,30 @@ class _BrParser:
m = re.match(r'^PERFORM\s+(\w[\w-]*)\s*$', line)
if m:
target = m.group(1).strip()
save_pos = self.pos
condition = None
self.advance()
while self.pos < len(self.lines):
nxt = self.clean()
um = re.match(r'^UNTIL\s+(.+)$', nxt)
if um:
condition = um.group(1).strip()
self.advance()
# Join continuation lines (AND/OR on next lines)
while self.pos < len(self.lines):
peek = self.clean()
if re.match(r'^(AND|OR)\s', peek, re.IGNORECASE):
condition += ' ' + peek
self.advance()
else:
break
break
break
if condition:
node = BrPerform('para_until', target=target, condition=condition)
self._inline_perform(node, target)
return node
self.pos = save_pos
node = BrPerform('para', target=target)
self.advance()
self._inline_perform(node, target)
@@ -962,12 +1021,18 @@ class _BrParser:
parts = [self.clean()]
self.advance()
while self.pos < len(self.lines):
peek = self.peek()
cl = self.clean()
if cl == 'END-STRING':
self.advance()
break
# Stop when a new COBOL statement keyword is encountered
if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-STRING)', peek, re.IGNORECASE):
break
parts.append(cl)
self.advance()
if peek.rstrip().endswith('.'):
break
full = ' '.join(parts)
m = re.match(r'^STRING\s+(.+)\s+INTO\s+(\w[\w-]*)\s*$', full, re.IGNORECASE | re.DOTALL)
if not m:
@@ -985,12 +1050,17 @@ class _BrParser:
parts = [self.clean()]
self.advance()
while self.pos < len(self.lines):
peek = self.peek()
cl = self.clean()
if cl == 'END-UNSTRING':
self.advance()
break
if re.match(r'^(MOVE|IF|PERFORM|EVALUATE|COMPUTE|CALL|STRING|UNSTRING|INITIALIZE|ADD|SUBTRACT|MULTIPLY|DIVIDE|GO\b|EXIT\b|DISPLAY|ACCEPT|STOP|READ|WRITE|REWRITE|DELETE|SET|SEARCH|OPEN|CLOSE|INSPECT|CONTINUE|GOBACK|COMMIT|ROLLBACK|MERGE|SORT|ELSE|END-IF|END-EVALUATE|END-PERFORM|END-READ|END-WRITE|END-UNSTRING)', peek, re.IGNORECASE):
break
parts.append(cl)
self.advance()
if peek.rstrip().endswith('.'):
break
full = ' '.join(parts)
m = re.match(r'^UNSTRING\s+(.+?)\s+INTO\s+(.+?)\s*$', full, re.IGNORECASE | re.DOTALL)
if not m:
@@ -1088,6 +1158,75 @@ class _BrParser:
self.advance()
return Assign(tgt, info)
# ── EXEC SQL parsing ──
_RE_SELECT_INTO = re.compile(
r'SELECT\s+(.*?)\s+INTO\s+(:\w[\w-]*(?:\s*,\s*:\w[\w-]*(?::\w[\w-]*)?)*)'
r'\s+FROM\s+(\w[\w-]*)',
re.IGNORECASE
)
_RE_WHERE = re.compile(r'\bWHERE\b\s+(.*)', re.IGNORECASE)
def _parse_sql_block(self) -> str:
"""Consume lines from EXEC SQL until END-EXEC. Returns SQL text."""
texts = []
self.advance()
while self.pos < len(self.lines):
line = self.lines[self.pos].rstrip('.')
m = re.match(r'(.*?)END-EXEC\.?\s*$', line, re.IGNORECASE)
if m:
before = m.group(1).strip()
if before:
texts.append(before)
self.advance()
break
texts.append(line)
self.advance()
result = ' '.join(texts)
result = re.sub(r'\s+', ' ', result)
return result
def _parse_sql(self, sql_text: str):
"""Parse SQL text from EXEC SQL block. Returns Assign node or None."""
m = self._RE_SELECT_INTO.search(sql_text)
if not m:
return None
select_list = m.group(1).strip()
into_raw = m.group(2).strip()
from_table = m.group(3).strip().upper()
remaining = sql_text[m.end():].strip()
# Parse INTO variables (handle indicator vars: :host:indicator)
into_vars = []
for v in re.split(r'\s*,\s*', into_raw):
v = v.strip().lstrip(':')
parts = v.split(':')
into_vars.append(parts[0].upper())
if len(parts) > 1:
into_vars.append(parts[1].upper())
# Extract WHERE clause
where_clause = ''
wm = self._RE_WHERE.search(remaining)
if wm:
where_clause = wm.group(1).strip()
info = {
'type': 'exec_sql_select',
'table': from_table,
'select_list': select_list,
'into_vars': into_vars,
'where': where_clause,
'sql_text': sql_text,
}
for var in into_vars:
self.assignments.setdefault(var, []).append(info)
return Assign(into_vars[0], info)
# ── 工具函数 ──
@@ -1141,8 +1280,6 @@ def trace_to_root(field_name, assignments, fields, path_assign=None):
asgn = asgn_list
else:
asgn_list = assignments[var]
if not asgn_list:
break
asgn = asgn_list[-1]
if isinstance(asgn_list, list):
for a in reversed(asgn_list):
@@ -1152,6 +1289,8 @@ def trace_to_root(field_name, assignments, fields, path_assign=None):
asgn = a
break
chain.append((var, asgn))
if asgn.get('type') in ('unstring_split',):
break
if not asgn.get('source_vars'):
break
sv = asgn['source_vars']
@@ -1332,8 +1471,36 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
src = asgn['source_vars'][0]
resolved_tgt = _resolve_subscript(tgt, rec)
resolved_src = _resolve_subscript(src, rec)
if resolved_src in rec:
rec[resolved_tgt] = rec[resolved_src]
tgt_children = _init_child_names(resolved_tgt, fields)
if tgt_children:
# Group MOVE: propagate to child fields by position
src_children = _init_child_names(resolved_src, fields)
if src_children:
src_str = ''.join(str(rec.get(c, '')) for c in src_children)
elif resolved_src in rec:
src_str = str(rec[resolved_src])
else:
src_str = ''
if src_str:
rec[resolved_tgt] = src_str
pos = 0
for tgt_c in tgt_children:
child_len = 0
for f in fields:
if f['name'] == tgt_c:
pi = f.get('pic_info', {})
child_len = pi.get('digits', 0) + pi.get('decimal', 0) or pi.get('length', 0)
break
if child_len > 0:
rec[tgt_c] = src_str[pos:pos + child_len] if pos < len(src_str) else ('0' if child_len else '')
pos += child_len
elif resolved_src in rec:
src_val = str(rec[resolved_src])
if asgn.get('refmod_start') and asgn.get('refmod_length'):
start = asgn['refmod_start'] - 1
end = start + asgn['refmod_length']
src_val = src_val[start:end]
rec[resolved_tgt] = src_val
# Pass 2: literal MOVE
for tgt, asgn in flat_list:
@@ -1439,9 +1606,7 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
resolved_tgt = _resolve_subscript(tgt, rec)
if resolved_tgt not in rec:
continue
inspect_src = asgn.get('tgt', tgt)
resolved_src = _resolve_subscript(inspect_src, rec)
src_val = str(rec.get(resolved_src, ''))
src_val = str(rec[resolved_tgt])
for op_type, params in asgn.get('sub_ops', []):
if op_type == 'tally':
cv = params['count_var'].upper()
@@ -1495,6 +1660,10 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
src_var = asgn.get('source_vars', [None])[0]
resolved_src = _resolve_subscript(src_var, rec) if src_var else None
idx = asgn.get('index', 0)
if resolved_src and resolved_src not in rec:
children = _init_child_names(resolved_src, fields)
if children:
resolved_src = children[0]
if resolved_src and resolved_src in rec:
src_val = str(rec[resolved_src])
ftype = pi.get('type', 'unknown')
@@ -1556,6 +1725,23 @@ def propagate_assignments(rec, assignments, fields, file_sec=None):
else:
rec[resolved_tgt] = val.ljust(length)[:length] if length else val
# Pass 9: EXEC SQL SELECT INTO
for tgt, asgn in flat_list:
if asgn.get('type') == 'exec_sql_select':
resolved_tgt = _resolve_subscript(tgt, rec)
if resolved_tgt not in rec:
continue
src_val = rec.get(resolved_tgt, '')
pi = pi_map.get(resolved_tgt, {})
if pi.get('type') == 'numeric':
total = pi.get('digits', 0) + pi.get('decimal', 0)
if total > 0:
rec[resolved_tgt] = str(src_val).zfill(total)
elif pi.get('type') in ('alphanumeric', 'alphabetic'):
length = pi.get('length', 0)
if length > 0:
rec[resolved_tgt] = str(src_val).ljust(length)[:length]
# Pass 8: SET var TO TRUE (88-level)
for tgt, asgn in flat_list:
if asgn['type'] == 'set_true':
@@ -1649,6 +1835,13 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None)
elif atype == 'write_from':
if tgt_base in counts:
counts[tgt_base]['read'] += 1
elif atype == 'exec_sql_select':
if tgt_base in counts:
counts[tgt_base]['write'] += 1
for v in node.source_info.get('into_vars', []):
v_base = _basename(v)
if v_base in counts:
counts[v_base]['write'] += 1
elif atype == 'set_true':
if tgt_base in counts:
counts[tgt_base]['write'] += 1
@@ -1705,3 +1898,52 @@ def classify_field_roles(tree, assignments, fields, source=None, proc_text=None)
if name not in result:
result[name] = role
return result
# ── 多 WRITE 检测 ──
def _collect_write_fds(node, fds_set, field_to_fd):
    """Walk *node* recursively and record the FD of every WRITE found.

    An ``Assign`` whose ``source_info['type']`` is ``'write_bare'`` or
    ``'write_from'`` contributes ``field_to_fd[node.target]`` to *fds_set*
    when its target is a key of *field_to_fd*.  Container nodes (``BrSeq``,
    ``BrIf``, ``BrEval``, ``BrPerform``, ``BrSearch``) are descended into;
    any other value, including ``None``, is ignored.  *fds_set* is mutated
    in place and nothing is returned.
    """
    if isinstance(node, Assign):
        kind = node.source_info.get('type', '')
        if kind in ('write_bare', 'write_from') and node.target in field_to_fd:
            fds_set.add(field_to_fd[node.target])
        return

    # Gather the sub-trees of the container node, then recurse into each.
    if isinstance(node, BrSeq):
        subtrees = node.children
    elif isinstance(node, BrIf):
        subtrees = [node.true_seq, node.false_seq]
    elif isinstance(node, BrEval):
        subtrees = [seq for _, seq in node.when_list] + [node.other_seq]
    elif isinstance(node, BrPerform):
        subtrees = [node.body_seq]
    elif isinstance(node, BrSearch):
        subtrees = [node.at_end_seq] + [seq for _, seq in node.when_list]
    else:
        return

    for subtree in subtrees:
        _collect_write_fds(subtree, fds_set, field_to_fd)
def _find_multi_write_fds(tree, field_to_fd):
    """Return the FDs written both before the main loop and inside it.

    The main loop is the last direct child of the top-level ``BrSeq`` that
    is a ``BrPerform`` whose ``perf_type`` is ``'until'``, ``'para_until'``,
    ``'varying'`` or ``'para_varying'``.  An empty set is returned when
    *tree* is not a ``BrSeq`` or contains no such child.
    """
    if not isinstance(tree, BrSeq):
        return set()

    loop_kinds = ('until', 'para_until', 'varying', 'para_varying')
    loop_positions = [
        idx for idx, child in enumerate(tree.children)
        if isinstance(child, BrPerform) and child.perf_type in loop_kinds
    ]
    if not loop_positions:
        return set()
    main_idx = loop_positions[-1]

    # WRITEs in the statements that precede the main loop (the INIT part).
    before_loop = set()
    for stmt in tree.children[:main_idx]:
        _collect_write_fds(stmt, before_loop, field_to_fd)

    # WRITEs anywhere inside the main loop itself.
    inside_loop = set()
    _collect_write_fds(tree.children[main_idx], inside_loop, field_to_fd)

    return before_loop.intersection(inside_loop)