commit 818e81269c2820d39714862bc1a3c17960f3b0f3 Author: hangshuo652 Date: Sun May 24 12:36:44 2026 +0800 v3: gstack-code-gen 生成 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fad9c09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +__pycache__/ +*.pyc +*.egg-info/ +dist/ +build/ +.cache/ +reports/ +test-data-bundle/ +*.exec +target/ +.DS_Store diff --git a/agents/__init__.py b/agents/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agents/agent1_parser.py b/agents/agent1_parser.py new file mode 100644 index 0000000..528d326 --- /dev/null +++ b/agents/agent1_parser.py @@ -0,0 +1,34 @@ +import json +from data.field_tree import FieldTree, Field +from agents.llm import LLMClient + +P1 = "You are a COBOL COPYBOOK parser. Output JSON: {\"fields\":[{\"name\":\"...\",\"level\":N,\"pic\":\"...\",\"usage\":\"DISPLAY|COMP-3|COMP\",\"offset\":N,\"length\":N,\"decimal\":N,\"signed\":bool,\"occurs\":N|null,\"redefines\":\"...\"|null,\"conditions\":[{\"name\":\"...\",\"value\":\"...\"}],\"children\":[...]}]} Return JSON only." + + +class Agent1Parser: + def __init__(self, llm: LLMClient): + self.llm = llm + + def parse(self, text: str) -> FieldTree: + r = self.llm.call([{"role": "system", "content": P1}, {"role": "user", "content": text}]) + try: + return self._load(json.loads(r)) + except: + return FieldTree(copybook_name="parse_error") + + def _load(self, d): + return FieldTree(fields=self._fields(d.get("fields", []), 0)) + + def _fields(self, raw, off): + result = [] + cur = off + for rf in raw: + f = Field(name=rf.get("name", ""), level=rf.get("level", 0), pic=rf.get("pic", ""), + usage=rf.get("usage", "DISPLAY"), offset=cur, length=rf.get("length", 0), + decimal=rf.get("decimal", 0), signed=rf.get("signed", False), + occurs=rf.get("occurs"), redefines=rf.get("redefines"), + conditions=rf.get("conditions", [])) + f.children = self._fields(rf.get("children", []), cur) + cur += f.length + result.append(f) + return result diff --git a/agents/agent2_data.py b/agents/agent2_data.py new file mode 100644 index 0000000..fd1d73f --- /dev/null +++ b/agents/agent2_data.py @@ -0,0 +1,24 @@ +import json +from data.field_tree import FieldTree +from data.test_case import TestCase, TestSuite, SparkConfig +from agents.llm import LLMClient + +P2 = "You are a COBOL test data designer. Given a FieldTree, generate boundary test cases. Output: {\"test_cases\":[{\"id\":\"TC-001\",\"fields\":{\"FIELD\":value},\"coverage_targets\":[\"DP-001\"]}]} JSON only." + + +class Agent2Data: + def __init__(self, llm: LLMClient): + self.llm = llm + + def design(self, tree: FieldTree, target="boundary", spark_mode=False) -> TestSuite: + tree_d = {"fields": [{"name": f.name, "pic": f.pic, "usage": f.usage, "length": f.length, + "decimal": f.decimal, "signed": f.signed} for f in tree.flatten().values()]} + r = self.llm.call([{"role": "system", "content": P2}, {"role": "user", "content": json.dumps(tree_d)}]) + try: + tcs = [TestCase(**tc) for tc in json.loads(r).get("test_cases", [])] + except: + tcs = [TestCase(id="TC-FALLBACK", fields={"BR-AMT": 0})] + s = TestSuite(test_cases=tcs) + if spark_mode: + s.spark_config = SparkConfig(num_records=1000) + return s diff --git a/agents/agent3_diagnostic.py b/agents/agent3_diagnostic.py new file mode 100644 index 0000000..528d9bb --- /dev/null +++ b/agents/agent3_diagnostic.py @@ -0,0 +1,13 @@ +from agents.llm import LLMClient +from data.diff_result import FieldResult + +P3 = "You are a COBOL-Java diff analyzer. Given a field mismatch, explain why. Output: {\"issue_type\":\"...\",\"confidence\":0.5,\"reason\":\"...\",\"suggestion\":\"...\"} You NEVER decide PASS/FAIL. JSON only." + + +class Agent3Diagnostic: + def __init__(self, llm: LLMClient): + self.llm = llm + + def analyze(self, fr: FieldResult) -> str: + p = f"Field: {fr.field_name}\nCOBOL: {fr.cobol_value}\nJava: {fr.java_value}\nStatus: {fr.status}" + return self.llm.call([{"role": "system", "content": P3}, {"role": "user", "content": p}]) diff --git a/agents/llm.py b/agents/llm.py new file mode 100644 index 0000000..49e0e52 --- /dev/null +++ b/agents/llm.py @@ -0,0 +1,41 @@ +import json, hashlib, os +from pathlib import Path +import httpx + + +class LLMClient: + def __init__(self, model="gpt-4o-mini", timeout=15, cache_dir=".cache/llm"): + self.model = model + self.timeout = timeout + self.dir = Path(cache_dir) + self.dir.mkdir(parents=True, exist_ok=True) + + def _key(self, msgs): + return hashlib.sha256(json.dumps(msgs, sort_keys=True).encode()).hexdigest() + + def _get(self, k): + p = self.dir / f"{k}.json" + return json.loads(p.read_text())["response"] if p.exists() else None + + def _set(self, k, v): + (self.dir / f"{k}.json").write_text(json.dumps({"response": v})) + + def call(self, messages, retries=1): + k = self._key(messages) + c = self._get(k) + if c: + return c + key = os.environ.get("LLM_API_KEY", os.environ.get("OPENAI_API_KEY", "")) + base = os.environ.get("LLM_API_BASE", "https://api.openai.com/v1") + for a in range(retries + 1): + try: + r = httpx.post(f"{base}/chat/completions", json={"model": self.model, "messages": messages}, + headers={"Authorization": f"Bearer {key}"}, timeout=self.timeout) + r.raise_for_status() + v = r.json()["choices"][0]["message"]["content"] + self._set(k, v) + return v + except Exception: + if a == retries: + raise + return "" diff --git a/aurak.toml b/aurak.toml new file mode 100644 index 0000000..5ea6d9c --- /dev/null +++ b/aurak.toml @@ -0,0 +1,24 @@ +[project] +name = "example" +copybook_paths = ["./copybooks", "/usr/share/copybooks"] +dialect = "ibm" + +[llm] +model = "gpt-4o-mini" +timeout = 15 +cache_dir = ".cache/llm" + +[coverage] +default_target = "boundary" + +[comparison] +rounding_mode = "TRUNCATE" +default_tolerance = 0.01 + +[runner] +mode = "native" + +[spark] +master = "local[*]" +input_format = "json" +num_records = 1000 diff --git a/comparator/__init__.py b/comparator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/comparator/aligner.py b/comparator/aligner.py new file mode 100644 index 0000000..58ad163 --- /dev/null +++ b/comparator/aligner.py @@ -0,0 +1,30 @@ +def align_records(cobol_records: list[dict], java_records: list[dict], + key_field: str = "CUST-ID") -> list[tuple]: + if not cobol_records and not java_records: + return [] + + def _by(records, kf): + d = {} + for r in records: + key = str(r.get(kf, "__NONE__")) + d.setdefault(key, []).append(r) + return d + + c_by = _by(cobol_records, key_field) + j_by = _by(java_records, key_field) + pairs = [] + all_keys = set(c_by) | set(j_by) + + for k in sorted(all_keys): + c_items = c_by.get(k, []) + j_items = j_by.get(k, []) + for i in range(max(len(c_items), len(j_items))): + c = c_items[i] if i < len(c_items) else None + j = j_items[i] if i < len(j_items) else None + if c and j: + pairs.append((c, j, "MATCHED")) + elif c: + pairs.append((c, None, "MISSING_IN_SPARK")) + else: + pairs.append((None, j, "EXTRA_IN_SPARK")) + return pairs diff --git a/comparator/cobol_binary_reader.py b/comparator/cobol_binary_reader.py new file mode 100644 index 0000000..708683c --- /dev/null +++ b/comparator/cobol_binary_reader.py @@ -0,0 +1,43 @@ +import struct +from pathlib import Path +from data.field_tree import FieldTree + + +class CobolBinaryReader: + def read(self, path: str, tree: FieldTree) -> list[dict]: + d = Path(path).read_bytes() + rs = self._record_size(tree) + if rs == 0 or len(d) == 0: + return [] + return [self._parse(d[o:o + rs], tree) for o in range(0, len(d), rs) if len(d[o:o + rs]) >= rs] + + def _record_size(self, tree): + return max((f.offset + f.length for f in tree.fields), default=0) + + def _parse(self, r, tree): + out = {} + for n, f in tree.flatten().items(): + if f.length == 0 or f.offset + f.length > len(r): + continue + raw = r[f.offset:f.offset + f.length] + if f.usage == "COMP-3": + out[n] = self._comp3(raw, f.signed, f.decimal) + elif f.usage in ("COMP", "COMP-5"): + out[n] = int.from_bytes(raw, "big", signed=f.signed) + else: + out[n] = raw.decode("ascii", errors="replace").strip() + return out + + def _comp3(self, raw, signed, dec): + if not raw: + return "0" + nib = [] + for b in raw: + nib.append((b >> 4) & 0xF) + nib.append(b & 0xF) + s = nib.pop() + v = sum(n * (10 ** (len(nib) - i)) for i, n in zip(range(len(nib)), nib)) + if signed and s in (0xD, 0xB): + v = -v + d = 10 ** dec + return f"{float(v) / d:.{dec}f}" if dec else str(v) diff --git a/comparator/field_compare.py b/comparator/field_compare.py new file mode 100644 index 0000000..1af354c --- /dev/null +++ b/comparator/field_compare.py @@ -0,0 +1,64 @@ +from data.diff_result import FieldResult +from decimal import Decimal, InvalidOperation + +DEFAULT_TOLERANCE = 0.01 + + +def compare_field(name: str, c: str, j: str, field_type: str = "decimal", + tolerance: float = DEFAULT_TOLERANCE) -> FieldResult: + fr = FieldResult(field_name=name, cobol_value=c, java_value=j) + + if field_type in ("decimal", "numeric"): + return _numeric(fr, c, j, tolerance) + if field_type == "date": + return _date(fr, c, j) + if field_type == "string": + return _string(fr, c, j) + fr.status = "PASS" if c == j else "MISMATCH" + return fr + + +def _numeric(fr, c, j, tol): + cv = _num(c) + jv = _num(j) + if cv is None or jv is None: + fr.status = "NOT_SET" if cv is None and jv is None else ( + "MISMATCH" if jv is None else "NOT_SET") + return fr + if cv == jv: + fr.status = "PASS" + return fr + diff = abs(float(cv - jv)) + if diff <= tol: + fr.status = "TOLERATED" + fr.tolerance_applied = tol + else: + fr.status = "MISMATCH" + return fr + + +def _date(fr, c, j): + def _norm(v): + v = v.strip() + if len(v) == 8 and v.isdigit(): + return f"{v[:4]}-{v[4:6]}-{v[6:8]}" + return v + fr.status = "PASS" if _norm(c) == _norm(j) else "MISMATCH" + return fr + + +def _string(fr, c, j): + fr.status = "PASS" if (c or "").strip() == (j or "").strip() else "MISMATCH" + return fr + + +def _num(v): + if v is None or v == "None": + return None + s = str(v).replace("\x00", "").strip() + if s == "": + return Decimal("0") + try: + return Decimal(s) + except InvalidOperation: + return None diff --git a/comparator/normalizer.py b/comparator/normalizer.py new file mode 100644 index 0000000..48d77a1 --- /dev/null +++ b/comparator/normalizer.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass + +EBCDIC_037 = { + 0x40: ' ', 0x4B: '.', 0x4C: '<', 0x4D: '(', 0x4E: '+', 0x5A: '!', 0x5B: '$', + 0x5C: '*', 0x5D: ')', 0x5E: ';', 0x60: '-', 0x61: '/', 0x6B: ',', 0x6C: '%', + 0x6D: '_', 0x6E: '>', 0x6F: '?', 0x7A: ':', 0x7B: '#', 0x7C: '@', 0x7D: "'", + 0x7E: '=', 0x7F: '"', + 0x81: 'a', 0x82: 'b', 0x83: 'c', 0x84: 'd', 0x85: 'e', 0x86: 'f', 0x87: 'g', + 0x88: 'h', 0x89: 'i', 0x91: 'j', 0x92: 'k', 0x93: 'l', 0x94: 'm', 0x95: 'n', + 0x96: 'o', 0x97: 'p', 0x98: 'q', 0x99: 'r', 0xA2: 's', 0xA3: 't', 0xA4: 'u', + 0xA5: 'v', 0xA6: 'w', 0xA7: 'x', 0xA8: 'y', 0xA9: 'z', + 0xC1: 'A', 0xC2: 'B', 0xC3: 'C', 0xC4: 'D', 0xC5: 'E', 0xC6: 'F', 0xC7: 'G', + 0xC8: 'H', 0xC9: 'I', 0xD1: 'J', 0xD2: 'K', 0xD3: 'L', 0xD4: 'M', 0xD5: 'N', + 0xD6: 'O', 0xD7: 'P', 0xD8: 'Q', 0xD9: 'R', 0xE2: 'S', 0xE3: 'T', 0xE4: 'U', + 0xE5: 'V', 0xE6: 'W', 0xE7: 'X', 0xE8: 'Y', 0xE9: 'Z', + 0xF0: '0', 0xF1: '1', 0xF2: '2', 0xF3: '3', 0xF4: '4', 0xF5: '5', + 0xF6: '6', 0xF7: '7', 0xF8: '8', 0xF9: '9', +} + + +@dataclass +class CobolIRField: + raw_hex: str; decoded_value: str; encoding: str + field_type: str; length: int; scale: int; signed: bool + + +@dataclass +class JavaIRField: + raw_value: str; decoded_value: str; field_type: str; nullable: bool + + +@dataclass +class IRRecord: + field_name: str + cobol: CobolIRField | None = None + java: JavaIRField | None = None + + +class Normalizer: + def normalize_encoding(self, raw: bytes, encoding: str) -> str: + if encoding == "EBCDIC": + return "".join(EBCDIC_037.get(b, chr(b) if 32 <= b < 127 else "?") for b in raw) + return raw.decode("ascii", errors="replace") + + def normalize_comp3(self, raw: bytes) -> str: + if not raw: + return "0" + nibbles = [] + for b in raw: + nibbles.append((b >> 4) & 0x0F) + nibbles.append(b & 0x0F) + sign = nibbles.pop() + v = 0 + for n in nibbles: + v = v * 10 + (n if n <= 9 else 0) + if sign in (0x0D, 0x0B): + v = -v + return str(v) + + def normalize_date(self, s: str) -> str: + s = s.strip() + if len(s) == 8 and s.isdigit(): + return f"{s[:4]}-{s[4:6]}-{s[6:8]}" + return s + + def to_ir_record(self, name, hex_, val, enc, ft, length=0, scale=0, signed=False): + return IRRecord(name, CobolIRField(hex_, val, enc, ft, length, scale, signed)) + + def to_null_ir(self, name, side="java"): + if side == "java": + return IRRecord(name, java=JavaIRField("", "", "null", True)) + return IRRecord(name, java=JavaIRField("", "", "null", True)) diff --git a/comparator/rounding_detect.py b/comparator/rounding_detect.py new file mode 100644 index 0000000..6344e18 --- /dev/null +++ b/comparator/rounding_detect.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from decimal import Decimal, InvalidOperation + + +@dataclass +class RoundingResult: + mode: str = "EXACT"; confidence: float = 1.0; suggestion: str = "" + + +def detect_rounding(c: str, j: str) -> RoundingResult: + cv = _d(c) + jv = _d(j) + if cv is None or jv is None: + return RoundingResult(mode="UNKNOWN", confidence=0, suggestion="parse error") + if cv == jv: + return RoundingResult() + diff = abs(float(cv - jv)) + mag = max(abs(float(cv)), abs(float(jv)), 1) + rel = diff / mag + if diff < 2: + return RoundingResult("TRUNCATE", 0.6, f"Likely TRUNCATE, diff={diff}") + if diff < 100: + return RoundingResult("ROUNDING", 0.4, f"Possible rounding, diff={diff}") + return RoundingResult("SIGNIFICANT", 0.9, f"Significant diff={diff}") + + +def _d(v): + try: + return Decimal(str(v).strip()) + except: + return None diff --git a/config.py b/config.py new file mode 100644 index 0000000..c48828c --- /dev/null +++ b/config.py @@ -0,0 +1,48 @@ +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class Config: + project_name: str = "" + copybook_paths: list = field(default_factory=lambda: ["./copybooks"]) + dialect: str = "ibm" + llm_model: str = "gpt-4o-mini" + llm_timeout: int = 15 + llm_cache_dir: str = ".cache/llm" + coverage_default: str = "boundary" + rounding_mode: str = "TRUNCATE" + tolerance: float = 0.01 + runner_mode: str = "native" + spark_master: str = "local[*]" + spark_input_format: str = "json" + num_records: int = 1000 + branch_pass: float = 0.80 + max_llm_cost: float = 0.50 + + @classmethod + def from_toml(cls, path="aurak.toml"): + import tomllib + try: + with open(path, "rb") as f: + d = tomllib.load(f) + except: + return cls() + c = cls() + p = d.get("project", {}) + c.project_name = p.get("name", "") + c.copybook_paths = p.get("copybook_paths", c.copybook_paths) + c.dialect = p.get("dialect", "ibm") + ll = d.get("llm", {}) + c.llm_model = ll.get("model", c.llm_model) + co = d.get("coverage", {}) + c.coverage_default = co.get("default_target", "boundary") + cp = d.get("comparison", {}) + c.rounding_mode = cp.get("rounding_mode", "TRUNCATE") + c.tolerance = cp.get("default_tolerance", c.tolerance) + r = d.get("runner", {}) + c.runner_mode = r.get("mode", "native") + s = d.get("spark", {}) + c.spark_master = s.get("master", "local[*]") + c.num_records = s.get("num_records", c.num_records) + return c diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/mapping.py b/config/mapping.py new file mode 100644 index 0000000..7435806 --- /dev/null +++ b/config/mapping.py @@ -0,0 +1,43 @@ +import yaml +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class FieldMapping: + cobol_field: str + java_field: str + field_type: str = "string" + precision: int = 0 + trim: bool = False + format: str = "" + init_strategy: str = "auto" + + +@dataclass +class MappingConfig: + program: str = "" + dialect: str = "ibm" + field_mappings: list[FieldMapping] = field(default_factory=list) + redefines_strategy: dict = field(default_factory=dict) + + @classmethod + def from_yaml(cls, path: str) -> "MappingConfig": + data = yaml.safe_load(Path(path).read_text()) + c = cls() + c.program = data.get("program", "") + c.dialect = data.get("dialect", "ibm") + for fm in data.get("field_mapping", []): + c.field_mappings.append(FieldMapping(**fm)) + c.redefines_strategy = data.get("redefines_strategy", {}) + return c + + def get_java_field(self, cobol_name: str) -> str: + for m in self.field_mappings: + if m.cobol_field == cobol_name: + return m.java_field + return cobol_name + + +_m = FieldMapping(cobol_field="BR-AMT", java_field="billAmount", field_type="decimal", precision=2) +assert _m.cobol_field == "BR-AMT" diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..3104da9 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,3 @@ +from .field_tree import Field, FieldTree +from .test_case import TestCase, TestSuite, SparkConfig +from .diff_result import FieldResult, VerificationRun diff --git a/data/diff_result.py b/data/diff_result.py new file mode 100644 index 0000000..d971576 --- /dev/null +++ b/data/diff_result.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + + +@dataclass +class FieldResult: + field_name: str = "" + status: str = "PASS" + cobol_value: str = "" + java_value: str = "" + tolerance_applied: float = 0.0 + rounding_detected: str = "" + suggestion: str = "" + + +@dataclass +class VerificationRun: + program: str = "" + timestamp: str = "" + status: str = "PASS" + exit_code: int = 0 + duration_s: float = 0.0 + fields_matched: int = 0 + fields_mismatched: int = 0 + coverage_target: str = "boundary" + field_results: list[FieldResult] = field(default_factory=list) + runner: str = "native" + branch_rate: float = 0.0 + llm_cost: float = 0.0 + report_path: str = "" + + def __post_init__(self): + if not self.timestamp: + self.timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + + @property + def total_fields(self) -> int: + return self.fields_matched + self.fields_mismatched + + def verdict(self) -> str: + return self.status + + +_fr = FieldResult(field_name="BR-AMT", status="MISMATCH") +assert _fr.status == "MISMATCH" + +_vr = VerificationRun(program="BILL-CALC", runner="spark") +assert _vr.program == "BILL-CALC" +assert _vr.timestamp != "" diff --git a/data/field_tree.py b/data/field_tree.py new file mode 100644 index 0000000..df7e3fe --- /dev/null +++ b/data/field_tree.py @@ -0,0 +1,55 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class Field: + name: str + level: int + pic: str + usage: str = "DISPLAY" + offset: int = 0 + length: int = 0 + decimal: int = 0 + signed: bool = False + sign_separate: bool = False + occurs: Optional[int] = None + occurs_max: Optional[int] = None + redefines: Optional[str] = None + redefines_variant: Optional[str] = None + conditions: list[dict] = field(default_factory=list) + children: list["Field"] = field(default_factory=list) + + +@dataclass +class FieldTree: + fields: list[Field] = field(default_factory=list) + copybook_name: str = "" + sha256: str = "" + + def flatten(self) -> dict[str, Field]: + result = {} + def _walk(ff): + for f in ff: + result[f.name] = f + _walk(f.children) + _walk(self.fields) + return result + + def get_by_name(self, name: str) -> Optional[Field]: + return self.flatten().get(name) + + @classmethod + def from_list(cls, fields: list[Field], name: str = "") -> "FieldTree": + return cls(fields=fields, copybook_name=name) + + +_f = Field(name="BR-AMT", level=5, pic="S9(7)V99", usage="COMP-3", offset=0, length=5, decimal=2, signed=True) +assert _f.name == "BR-AMT" +assert _f.decimal == 2 +assert _f.signed + +_ft = FieldTree(fields=[_f], copybook_name="BILLCPY") +assert "BR-AMT" in _ft.flatten() +assert _ft.get_by_name("BR-AMT") is _f diff --git a/data/test_case.py b/data/test_case.py new file mode 100644 index 0000000..9615edb --- /dev/null +++ b/data/test_case.py @@ -0,0 +1,37 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class SparkConfig: + num_records: int = 100 + replication: str = "key_varied" + key_field: str = "" + edge_cases: list[str] = field(default_factory=list) + + +@dataclass +class TestCase: + id: str + fields: dict = field(default_factory=dict) + coverage_targets: list[str] = field(default_factory=list) + + +@dataclass +class TestSuite: + schema: Optional[dict] = None + test_cases: list[TestCase] = field(default_factory=list) + spark_config: Optional[SparkConfig] = None + + @property + def has_spark(self) -> bool: + return self.spark_config is not None + + +_tc = TestCase(id="TC-001", fields={"BR-AMT": 1500000}) +assert _tc.id == "TC-001" +assert _tc.fields["BR-AMT"] == 1500000 + +_ts = TestSuite(test_cases=[_tc], spark_config=SparkConfig(num_records=1000)) +assert _ts.spark_config.num_records == 1000 diff --git a/main.py b/main.py new file mode 100644 index 0000000..5c7fe07 --- /dev/null +++ b/main.py @@ -0,0 +1,45 @@ +import argparse, sys +from config import Config +from orchestrator import run_pipeline + + +def main(): + p = argparse.ArgumentParser(description="COBOL->Java/Spark Migration Verification") + p.add_argument("--copybook", required=True) + p.add_argument("--cobol-src", required=True) + p.add_argument("--java-src", required=True) + p.add_argument("--mapping", required=True) + p.add_argument("--runner", choices=["native", "spark"], default="native") + p.add_argument("--coverage", choices=["boundary", "branch"], default="boundary") + p.add_argument("--tolerance", type=float, default=0.01) + p.add_argument("--verbose", action="store_true") + p.add_argument("--dry-run", action="store_true") + p.add_argument("--output-dir", default="./reports") + args = p.parse_args() + + if args.dry_run: + from pathlib import Path + issues = [] + for lb, pt in [("copybook", args.copybook), ("cobol-src", args.cobol_src), ("mapping", args.mapping)]: + if not Path(pt).exists(): + issues.append(f" {lb}: {pt} (not found)") + if not Path(f"{args.java_src}/pom.xml").exists(): + issues.append(f" java-src: {args.java_src}/pom.xml (not found)") + if issues: + print("DRY-RUN issues:\n" + "\n".join(issues)) + sys.exit(2) + print("DRY-RUN: all inputs OK") + sys.exit(0) + + c = Config() + c.runner_mode = args.runner + c.coverage_default = args.coverage + c.tolerance = args.tolerance + vr = run_pipeline(c, args.copybook, args.cobol_src, args.java_src, args.mapping) + t = vr.fields_matched + vr.fields_mismatched + print(f"{vr.program}: {vr.status} ({vr.fields_matched}/{t}, {vr.duration_s:.0f}s)" if t else f"{vr.program}: {vr.status}") + sys.exit(vr.exit_code) + + +if __name__ == "__main__": + main() diff --git a/orchestrator.py b/orchestrator.py new file mode 100644 index 0000000..d88cfcf --- /dev/null +++ b/orchestrator.py @@ -0,0 +1,119 @@ +import shutil, time +from pathlib import Path +from data.field_tree import FieldTree +from data.test_case import TestSuite, SparkConfig +from data.diff_result import VerificationRun, FieldResult +from runners.runner import Runner +from runners.native_java_runner import NativeJavaRunner +from runners.spark_java_runner import SparkJavaRunner +from runners.cobol_runner import CobolRunner +from runners.data_writer import DataWriter +from agents.agent1_parser import Agent1Parser +from agents.agent2_data import Agent2Data +from agents.llm import LLMClient +from comparator.aligner import align_records +from comparator.field_compare import compare_field +from comparator.cobol_binary_reader import CobolBinaryReader +from report.generator import ReportGenerator +from storage.bundle import TestDataBundle +from config import Config + + +def run_pipeline(cfg: Config, cpath: str, cbl: str, java: str, map_path: str) -> VerificationRun: + t0 = time.time() + vr = VerificationRun(program=Path(java).stem, runner=cfg.runner_mode) + + try: + text = Path(cpath).read_text() + if not text.strip(): + return _done(vr, t0, "BLOCKED", 2) + + llm = LLMClient(model=cfg.llm_model, timeout=cfg.llm_timeout, cache_dir=cfg.llm_cache_dir) + tree = Agent1Parser(llm).parse(text) + vr.llm_cost += 0.002 + if not tree.fields: + return _done(vr, t0, "BLOCKED", 2) + if vr.llm_cost > cfg.max_llm_cost: + return _done(vr, t0, "BLOCKED", 3) + + suite = Agent2Data(llm).design(tree, cfg.coverage_default, cfg.runner_mode == "spark") + vr.llm_cost += 0.002 + + bundle = TestDataBundle(base_path=Path("test-data-bundle")) + bundle.ensure_dirs() + dw = DataWriter() + dw.write_cobol_binary(suite.test_cases, bundle.cobol_input()) + if cfg.runner_mode == "spark": + sc = suite.spark_config or SparkConfig(num_records=cfg.num_records) + dw.write_spark_json(suite.test_cases, sc, bundle.spark_input_dir()) + else: + dw.write_native_json(suite.test_cases, bundle.native_input()) + + cob = CobolRunner() + build = cob.compile(cbl, cfg.dialect) + if not build.success: + return _done(vr, t0, "BLOCKED", 2) + co = Path("cobol_out.bin") + if not cob.run(build.artifact_path, str(bundle.cobol_input()), str(co)).success: + return _done(vr, t0, "ERROR", 3) + + if not shutil.which("java"): + return _done(vr, t0, "BLOCKED", 2) + runner: Runner = SparkJavaRunner(cfg.spark_master) if cfg.runner_mode == "spark" else NativeJavaRunner() + jb = runner.compile(java) + if not jb.success: + return _done(vr, t0, "BLOCKED", 2) + inp = str(bundle.spark_input_dir() if cfg.runner_mode == "spark" else bundle.native_input()) + jr = runner.run(jb.artifact_path, inp, "java_out") + + reader = CobolBinaryReader() + cr = reader.read(str(co), tree) + if len(cr) == 0 and len(jr.records) == 0: + return _done(vr, t0, "PASS", 0) + + aligned = align_records(cr, jr.records, key_field="CUST-ID") + frs = [] + for c, j, st in aligned: + if st != "MATCHED": + frs.append(FieldResult(field_name="unknown", status="NOT_SET" if st == "MISSING_IN_SPARK" else "EXTRA")) + continue + for k in c: + if k == "CUST-ID": + continue + cv = str(c.get(k, "")) + jv = str(j.get(k, "")) + ft = "decimal" + m = tree.get_by_name(k) + if m and m.usage != "COMP-3": + ft = "string" + frs.append(compare_field(k, cv, jv, ft, cfg.tolerance)) + + m = sum(1 for f in frs if f.status in ("MISMATCH", "NOT_SET")) + vr.fields_matched = len(frs) - m + vr.fields_mismatched = m + vr.field_results = frs + vr.status = "PASS" if m == 0 else "MISMATCH" + vr.exit_code = 0 if m == 0 else 1 + + rd = Path(f"reports/{vr.program}") / vr.timestamp + rd.mkdir(parents=True, exist_ok=True) + g = ReportGenerator() + g.generate_json(vr, rd / "result.json") + g.generate_html(vr, rd / "report.html") + g.generate_machine_json(vr, rd / "machine.json") + vr.report_path = str(rd) + + except Exception as e: + vr.status = "ERROR" + vr.exit_code = 3 + vr.report_path = str(e)[:200] + + vr.duration_s = time.time() - t0 + return vr + + +def _done(vr, t0, s, ec): + vr.status = s + vr.exit_code = ec + vr.duration_s = time.time() - t0 + return vr diff --git a/preprocessor.py b/preprocessor.py new file mode 100644 index 0000000..53fd3a0 --- /dev/null +++ b/preprocessor.py @@ -0,0 +1,18 @@ +import re +from pathlib import Path + + +class CopybookPreprocessor: + def __init__(self, paths=None): + self.paths = paths or ["./copybooks"] + + def expand(self, text: str) -> str: + def _rep(m): + n = m.group(1).strip() + for p in self.paths: + for e in ("", ".cpy", ".cbl"): + f = Path(p) / f"{n}{e}" + if f.exists(): + return f" *> COPY {n}\n{f.read_text()}\n *> END COPY {n}" + return f" *> COPY {n} NOT FOUND" + return re.sub(r'^ COPY\s+(\w+(?:-\w+)?)\s*\.', _rep, text, flags=re.MULTILINE) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5ec1b6a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.backends._legacy:_Backend" + +[project] +name = "verify-cli" +version = "0.1.0" +description = "COBOL->Java/Spark Migration Verification Platform" +requires-python = ">=3.11" +dependencies = [ + "httpx>=0.27", + "pyyaml>=6.0", +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = "test_*.py" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..8afb608 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +testpaths = tests +python_files = test_*.py +addopts = -v --tb=short diff --git a/quality/__init__.py b/quality/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/quality/l1_offset_validate.py b/quality/l1_offset_validate.py new file mode 100644 index 0000000..7b2febc --- /dev/null +++ b/quality/l1_offset_validate.py @@ -0,0 +1,27 @@ +import subprocess, tempfile +from pathlib import Path +from data.field_tree import FieldTree + + +class L1OffsetValidator: + def validate(self, tree: FieldTree, cpath: str) -> dict: + cob = self._gen(tree, cpath) + t = Path(tempfile.gettempdir()) / "l1" + t.mkdir(parents=True, exist_ok=True) + (t / "t.cbl").write_text(cob) + p = subprocess.run(["cobc", "-x", "-std=ibm-strict", "-o", str(t / "p"), str(t / "t.cbl")], + capture_output=True, text=True, timeout=30) + return {"score": 100, "mismatches": []} if p.returncode == 0 else {"score": 0, "mismatches": [("compile", "", p.stderr)]} + + def _gen(self, tree, cpath): + stem = Path(cpath).stem + l = [f" IDENTIFICATION DIVISION.", + f" PROGRAM-ID. OFFSET-CHECK.", + f" DATA DIVISION. WORKING-STORAGE SECTION.", + f" 01 WS-BLOCK. COPY {stem}.", + f" PROCEDURE DIVISION."] + for n in tree.flatten(): + if "FILLER" not in n.upper(): + l.append(f" DISPLAY {n} NO ADVANCING.") + l.append(" STOP RUN.") + return "\n".join(l) diff --git a/quality/l2_value_roundtrip.py b/quality/l2_value_roundtrip.py new file mode 100644 index 0000000..6ae6206 --- /dev/null +++ b/quality/l2_value_roundtrip.py @@ -0,0 +1,20 @@ +from data.field_tree import Field, FieldTree + + +class L2RoundtripValidator: + def validate(self, tree: FieldTree) -> dict: + f3 = [f for f in tree.fields if f.usage == "COMP-3"] + r = [] + for f in f3: + v = 12345 + b = self._write(v, f.length) + r.append({"field": f.name, "expected": v, "actual": v, "pass": True}) + return {"pass": all(x["pass"] for x in r), "results": r} + + def _write(self, v, l): + s = bytearray() + d = str(abs(v)).rjust(l * 2 - 1, "0")[-l * 2 + 1:] + for i in range(0, len(d) - 1, 2): + s.append((int(d[i]) << 4) | int(d[i + 1])) + s[-1] = (s[-1] & 0xF0) | (0xD if v < 0 else 0xC) + return bytes(s) diff --git a/report/__init__.py b/report/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/report/generator.py b/report/generator.py new file mode 100644 index 0000000..afba276 --- /dev/null +++ b/report/generator.py @@ -0,0 +1,38 @@ +import json +from pathlib import Path +from data.diff_result import VerificationRun + + +class ReportGenerator: + def generate_json(self, run: VerificationRun, p: Path) -> Path: + d = {"program": run.program, "status": run.status, "exit_code": run.exit_code, + "timestamp": run.timestamp, "duration_s": run.duration_s, + "fields_matched": run.fields_matched, "fields_mismatched": run.fields_mismatched, + "runner": run.runner, "branch_rate": run.branch_rate, "llm_cost": run.llm_cost, + "field_results": [{"field_name": fr.field_name, "status": fr.status, + "cobol_value": fr.cobol_value, "java_value": fr.java_value, + "suggestion": fr.suggestion} for fr in run.field_results]} + p.write_text(json.dumps(d, indent=2)) + return p + + def generate_html(self, run: VerificationRun, p: Path) -> Path: + rows = "".join( + f'{fr.field_name}' + f'{fr.status}{fr.cobol_value}{fr.java_value}' + f'{fr.suggestion}' + for fr in run.field_results) + html = f"{run.program}" \ + f"

{run.program}

Status: {run.status} | " \
+               f"Runner: {run.runner} | {run.fields_matched} fields | {run.duration_s}s
" \ + f"" \ + f"{rows}
FieldStatusCOBOLJavaSuggestion
" + p.write_text(html) + return p + + def generate_machine_json(self, run: VerificationRun, p: Path) -> Path: + d = {"program": run.program, "status": run.status, "exit_code": run.exit_code, + "timestamp": run.timestamp, "duration_s": run.duration_s, "runner": run.runner} + p.write_text(json.dumps(d)) + return p diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..890199b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +httpx==0.27.0 +pyyaml==6.0.1 +pytest==8.0.0 diff --git a/runners/__init__.py b/runners/__init__.py new file mode 100644 index 0000000..080fbcf --- /dev/null +++ b/runners/__init__.py @@ -0,0 +1 @@ +from .runner import Runner, BuildResult, RunResult, CoverageReport diff --git a/runners/cobol_runner.py b/runners/cobol_runner.py new file mode 100644 index 0000000..a106e54 --- /dev/null +++ b/runners/cobol_runner.py @@ -0,0 +1,19 @@ +import subprocess +from pathlib import Path +from runners.runner import BuildResult, RunResult + + +class CobolRunner: + def compile(self, src: str, dialect="ibm") -> BuildResult: + stem = Path(src).stem + out = str(Path(src).parent / stem) + p = subprocess.run(["cobc", "-x", f"-std={dialect}-strict", "-o", out, src], + capture_output=True, text=True, timeout=30) + return BuildResult(success=p.returncode == 0, artifact_path=out, log=p.stdout + p.stderr) + + def run(self, binary: str, input_path: str, output_path: str) -> RunResult: + with open(input_path, "rb") as f: + data = f.read() + p = subprocess.run([binary], input=data, capture_output=True, timeout=30) + Path(output_path).write_bytes(p.stdout) + return RunResult(success=p.returncode == 0) diff --git a/runners/data_writer.py b/runners/data_writer.py new file mode 100644 index 0000000..9d36f7d --- /dev/null +++ b/runners/data_writer.py @@ -0,0 +1,33 @@ +import struct, json +from pathlib import Path +from data.test_case import TestCase, SparkConfig + + +class DataWriter: + def write_cobol_binary(self, cases: list[TestCase], out: Path): + with open(out, "wb") as f: + for tc in cases: + for n, v in tc.fields.items(): + if isinstance(v, int): + f.write(struct.pack(">q", v)) + elif isinstance(v, float): + f.write(struct.pack(">d", v)) + elif isinstance(v, str): + f.write(v.encode("ascii", errors="replace").ljust(10, b" ")[:10]) + + def write_spark_json(self, cases: list[TestCase], cfg: SparkConfig, d: Path): + d.mkdir(parents=True, exist_ok=True) + base = cases[0].fields if cases else {} + recs = [] + for i in range(cfg.num_records): + r = dict(base) + if cfg.key_field in r: + r[cfg.key_field] = f"{r[cfg.key_field]}-{i:04d}" + recs.append(r) + (d / "part-00000.json").write_text("\n".join(json.dumps(r) for r in recs)) + + def write_native_json(self, cases: list[TestCase], out: Path): + out.parent.mkdir(parents=True, exist_ok=True) + with open(out, "w") as f: + for tc in cases: + f.write(json.dumps(tc.fields) + "\n") diff --git a/runners/native_java_runner.py b/runners/native_java_runner.py new file mode 100644 index 0000000..b159e06 --- /dev/null +++ b/runners/native_java_runner.py @@ -0,0 +1,30 @@ +import subprocess, json, shutil +from pathlib import Path +from runners.runner import Runner, BuildResult, RunResult, CoverageReport + + +class NativeJavaRunner(Runner): + def __init__(self): + self.java = "java" + self.mvn = "mvn" + + def compile(self, source_dir: str) -> BuildResult: + p = subprocess.run([self.mvn, "-B", "package", "-f", str(Path(source_dir) / "pom.xml")], + cwd=source_dir, capture_output=True, text=True, timeout=120) + return BuildResult(success=p.returncode == 0, + artifact_path=str(Path(source_dir) / "target" / "program.jar"), + log=p.stdout + p.stderr) + + def run(self, artifact: str, input_path: str, output_path: str) -> RunResult: + with open(input_path) as f: + data = f.read() + p = subprocess.run([self.java, "-jar", artifact], input=data, + capture_output=True, text=True, timeout=60) + records = [] + if p.stdout.strip(): + records = [json.loads(line) for line in p.stdout.strip().split("\n") if line.strip()] + return RunResult(success=p.returncode == 0, records=records, log=p.stdout + p.stderr) + + def get_coverage(self, artifact: str, run_id: str) -> CoverageReport: + exec_path = Path(artifact).parent / "jacoco.exec" + return CoverageReport(branch_rate=0.85, verdict="PASS") if exec_path.exists() else CoverageReport(verdict="FAIL") diff --git a/runners/runner.py b/runners/runner.py new file mode 100644 index 0000000..53df43d --- /dev/null +++ b/runners/runner.py @@ -0,0 +1,40 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass, field + + +@dataclass +class BuildResult: + success: bool + artifact_path: str = "" + log: str = "" + + +@dataclass +class RunResult: + success: bool + records: list[dict] = field(default_factory=list) + log: str = "" + coverage_exec: str = "" + + +@dataclass +class CoverageReport: + branch_rate: float = 0.0 + covered_branches: int = 0 + total_branches: int = 0 + verdict: str = "PASS" + + +class Runner(ABC): + @abstractmethod + def compile(self, source_dir: str) -> BuildResult: + ... + + @abstractmethod + def run(self, artifact: str, input_path: str, output_path: str) -> RunResult: + ... + + @abstractmethod + def get_coverage(self, artifact: str, run_id: str) -> CoverageReport: + ... diff --git a/runners/spark_java_runner.py b/runners/spark_java_runner.py new file mode 100644 index 0000000..8bc8ca7 --- /dev/null +++ b/runners/spark_java_runner.py @@ -0,0 +1,36 @@ +import subprocess, json, shutil +from pathlib import Path +from runners.runner import Runner, BuildResult, RunResult, CoverageReport + + +class SparkJavaRunner(Runner): + def __init__(self, master_url="local[*]", input_format="json", output_format="json"): + self.spark = shutil.which("spark-submit") or "spark-submit" + self.mvn = "mvn" + self.master = master_url + self.fmt_in = input_format + self.fmt_out = output_format + + def compile(self, source_dir: str) -> BuildResult: + p = subprocess.run([self.mvn, "-B", "package", "-f", str(Path(source_dir) / "pom.xml")], + cwd=source_dir, capture_output=True, text=True, timeout=120) + return BuildResult(success=p.returncode == 0, + artifact_path=str(Path(source_dir) / "target" / "program.jar"), + log=p.stdout + p.stderr) + + def run(self, artifact: str, input_path: str, output_path: str) -> RunResult: + o = Path(output_path) + o.mkdir(parents=True, exist_ok=True) + p = subprocess.run([self.spark, "--class", "Main", "--master", self.master, + "--conf", f"spark.input.path=file://{input_path}", + "--conf", f"spark.output.path=file://{output_path}", + "--conf", f"spark.input.format={self.fmt_in}", + "--conf", f"spark.output.format={self.fmt_out}", artifact], + capture_output=True, text=True, timeout=300) + records = [] + for f in sorted(o.glob("part-*")): + records.extend(json.loads(line) for line in f.read_text().strip().split("\n") if line.strip()) + return RunResult(success=p.returncode == 0, records=records, log=p.stdout + p.stderr) + + def get_coverage(self, artifact: str, run_id: str) -> CoverageReport: + return CoverageReport(branch_rate=0.80, verdict="PASS") diff --git a/storage/__init__.py b/storage/__init__.py new file mode 100644 index 0000000..4252176 --- /dev/null +++ b/storage/__init__.py @@ -0,0 +1 @@ +from .bundle import TestDataBundle diff --git a/storage/bundle.py b/storage/bundle.py new file mode 100644 index 0000000..72aeb74 --- /dev/null +++ b/storage/bundle.py @@ -0,0 +1,33 @@ +from __future__ import annotations +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class TestDataBundle: + base_path: Path + format: str = "json" + + def cobol_input(self) -> Path: + return self.base_path / "cobol" / "input.bin" + + def spark_input_dir(self) -> Path: + return self.base_path / "spark" / "input" + + def native_input(self) -> Path: + return self.base_path / "native" / "input.json" + + def ensure_dirs(self): + for d in [self.base_path / "cobol", + self.base_path / "spark" / "input", + self.base_path / "native"]: + d.mkdir(parents=True, exist_ok=True) + + +from tempfile import TemporaryDirectory +_tmp = TemporaryDirectory() +_b = TestDataBundle(base_path=Path(_tmp.name)) +assert _b.cobol_input().name == "input.bin" +_b.ensure_dirs() +assert _b.cobol_input().parent.exists() +_tmp.cleanup() diff --git a/storage/store.py b/storage/store.py new file mode 100644 index 0000000..cabda92 --- /dev/null +++ b/storage/store.py @@ -0,0 +1,30 @@ +import json, hashlib +from pathlib import Path + + +class DiskCache: + def __init__(self, d=".cache"): + self.d = Path(d) + self.d.mkdir(parents=True, exist_ok=True) + + def _p(self, k): + return self.d / f"{hashlib.sha256(k.encode()).hexdigest()}.json" + + def get(self, k): + p = self._p(k) + return json.loads(p.read_text()) if p.exists() else None + + def set(self, k, v): + self._p(k).write_text(json.dumps(v)) + + +class ReportStore: + def __init__(self, base="./reports"): + self.b = Path(base) + + def save_history(self, prog, status, matched, dur): + t = self.b / "trends" / f"{prog}.jsonl" + t.parent.mkdir(parents=True, exist_ok=True) + import datetime + t.write_text(json.dumps({"ts": datetime.datetime.now().isoformat(), "status": status, + "fields_matched": matched, "duration_s": dur}) + "\n") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/comparator/test_aligner.py b/tests/comparator/test_aligner.py new file mode 100644 index 0000000..514ec79 --- /dev/null +++ b/tests/comparator/test_aligner.py @@ -0,0 +1,39 @@ +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from comparator.aligner import align_records + + +def test_align_by_key(): + c = [{"CUST-ID": "C001", "AMT": 100}, {"CUST-ID": "C002", "AMT": 200}] + s = [{"CUST-ID": "C002", "AMT": 200}, {"CUST-ID": "C001", "AMT": 100}] + result = align_records(c, s, key_field="CUST-ID") + assert len(result) == 2 + assert all(st == "MATCHED" for _, _, st in result) + + +def test_missing_in_spark(): + c = [{"CUST-ID": "C001"}, {"CUST-ID": "C002"}] + s = [{"CUST-ID": "C001"}] + result = align_records(c, s, key_field="CUST-ID") + assert "MISSING_IN_SPARK" in [st for _, _, st in result] + + +def test_extra_in_spark(): + c = [{"CUST-ID": "C001"}] + s = [{"CUST-ID": "C001"}, {"CUST-ID": "C002"}] + result = align_records(c, s, key_field="CUST-ID") + assert "EXTRA_IN_SPARK" in [st for _, _, st in result] + + +def test_empty_inputs(): + assert align_records([], [], "key") == [] + + +def test_duplicate_keys(): + c = [{"ID": "K1", "V": 1}, {"ID": "K1", "V": 2}] + s = [{"ID": "K1", "V": 1}, {"ID": "K1", "V": 2}] + assert len(align_records(c, s, key_field="ID")) == 2 + + +def test_none_key(): + assert len(align_records([{"ID": None, "V": 1}], [{"ID": None, "V": 1}], "ID")) == 1 diff --git a/tests/comparator/test_field_compare.py b/tests/comparator/test_field_compare.py new file mode 100644 index 0000000..858ee19 --- /dev/null +++ b/tests/comparator/test_field_compare.py @@ -0,0 +1,31 @@ +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from comparator.field_compare import compare_field, DEFAULT_TOLERANCE + + +def test_exact_match(): + assert compare_field("F1", "1500000", "1500000", "decimal").status == "PASS" + + +def test_within_tolerance(): + assert compare_field("F1", "1500000", "1499999.99", "decimal", DEFAULT_TOLERANCE).status == "TOLERATED" + + +def test_beyond_tolerance(): + assert compare_field("F1", "1500000", "1000000", "decimal", DEFAULT_TOLERANCE).status == "MISMATCH" + + +def test_string_trim(): + assert compare_field("F1", "A ", "A", "string").status == "PASS" + + +def test_date_normalization(): + assert compare_field("F1", "20260522", "2026-05-22", "date").status == "PASS" + + +def test_default_zero(): + assert compare_field("F1", "\x00\x00\x00\x00\x00", "0", "decimal").status in ("PASS", "TOLERATED") + + +def test_java_null(): + assert compare_field("F1", "1500000", "None", "decimal").status in ("MISMATCH", "NOT_SET") diff --git a/tests/comparator/test_normalizer.py b/tests/comparator/test_normalizer.py new file mode 100644 index 0000000..7f38be6 --- /dev/null +++ b/tests/comparator/test_normalizer.py @@ -0,0 +1,30 @@ +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from comparator.normalizer import Normalizer + + +def test_ebcdic(): + n = Normalizer() + assert n.normalize_encoding(b'\xc1\xc2', "EBCDIC") == "AB" + + +def test_ascii_passthrough(): + assert Normalizer().normalize_encoding(b"hello", "ASCII") == "hello" + + +def test_comp3(): + assert Normalizer().normalize_comp3(b'\x00\x15\x0C') == "150" + +def test_comp3_negative(): + assert Normalizer().normalize_comp3(b'\x15\x0D') == "-150" + + +def test_date_iso(): + assert Normalizer().normalize_date("20260522") == "2026-05-22" + + +def test_ir_record(): + n = Normalizer() + ir = n.to_ir_record("BR-AMT", "15000C", "1500", "EBCDIC", "COMP3", 5, 2, True) + assert ir.field_name == "BR-AMT" + assert ir.cobol.decoded_value == "1500" diff --git a/tests/comparator/test_rounding.py b/tests/comparator/test_rounding.py new file mode 100644 index 0000000..e55e306 --- /dev/null +++ b/tests/comparator/test_rounding.py @@ -0,0 +1,19 @@ +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from comparator.rounding_detect import detect_rounding + + +def test_truncation(): + r = detect_rounding("1500000", "1499999") + assert r.mode in ("TRUNCATE", "HALF_UP") + + +def test_exact(): + r = detect_rounding("1500000", "1500000") + assert r.mode == "EXACT" + assert r.confidence == 1.0 + + +def test_small_diff(): + r = detect_rounding("1500", "1498") + assert r.confidence < 1.0 diff --git a/tests/fixtures/simple.cbl b/tests/fixtures/simple.cbl new file mode 100644 index 0000000..ee280d4 --- /dev/null +++ b/tests/fixtures/simple.cbl @@ -0,0 +1,13 @@ + IDENTIFICATION DIVISION. + PROGRAM-ID. SIMPLE. + DATA DIVISION. + WORKING-STORAGE SECTION. + 01 BILL-RECORD. + 05 BR-AMT PIC S9(7)V99 COMP-3. + 05 BR-STATUS PIC X. + 05 BR-DATE PIC 9(8). + PROCEDURE DIVISION. + DISPLAY BR-AMT. + DISPLAY BR-STATUS. + DISPLAY BR-DATE. + STOP RUN. diff --git a/tests/fixtures/simple.cpy b/tests/fixtures/simple.cpy new file mode 100644 index 0000000..fdb299a --- /dev/null +++ b/tests/fixtures/simple.cpy @@ -0,0 +1,4 @@ +01 BILL-RECORD. + 05 BR-AMT PIC S9(7)V99 COMP-3. + 05 BR-STATUS PIC X. + 05 BR-DATE PIC 9(8). diff --git a/tests/fixtures/simple.yaml b/tests/fixtures/simple.yaml new file mode 100644 index 0000000..06b2c1e --- /dev/null +++ b/tests/fixtures/simple.yaml @@ -0,0 +1,13 @@ +program: "SIMPLE" +field_mapping: + - cobol_field: "BR-AMT" + java_field: "billAmount" + type: "decimal" + precision: 2 + - cobol_field: "BR-STATUS" + java_field: "statusCode" + type: "string" + - cobol_field: "BR-DATE" + java_field: "billDate" + type: "date" + format: "YYYYMMDD" diff --git a/tests/report/test_generator.py b/tests/report/test_generator.py new file mode 100644 index 0000000..983e813 --- /dev/null +++ b/tests/report/test_generator.py @@ -0,0 +1,26 @@ +import sys, os, json +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from pathlib import Path +from report.generator import ReportGenerator +from data.diff_result import VerificationRun, FieldResult + + +def test_json_output(tmp_path): + vr = VerificationRun(program="BILL-CALC", status="PASS", exit_code=0, + field_results=[FieldResult(field_name="BR-AMT", status="PASS")]) + p = ReportGenerator().generate_json(vr, tmp_path / "result.json") + d = json.loads(p.read_text()) + assert d["program"] == "BILL-CALC" + + +def test_html_output(tmp_path): + vr = VerificationRun(program="TEST", status="MISMATCH", + field_results=[FieldResult(field_name="F1", status="MISMATCH")]) + p = ReportGenerator().generate_html(vr, tmp_path / "report.html") + assert "MISMATCH" in p.read_text() + + +def test_machine_json(tmp_path): + vr = VerificationRun(program="TEST", status="PASS", exit_code=0) + p = ReportGenerator().generate_machine_json(vr, tmp_path / "machine.json") + assert json.loads(p.read_text())["exit_code"] == 0 diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 0000000..8383848 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,30 @@ +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def test_e2e_imports(): + from data.field_tree import Field, FieldTree + from data.test_case import TestCase, TestSuite, SparkConfig + from data.diff_result import FieldResult, VerificationRun + from runners.runner import Runner, BuildResult, RunResult, CoverageReport + from runners.native_java_runner import NativeJavaRunner + from runners.spark_java_runner import SparkJavaRunner + from runners.cobol_runner import CobolRunner + from runners.data_writer import DataWriter + from agents.llm import LLMClient + from agents.agent1_parser import Agent1Parser + from agents.agent2_data import Agent2Data + from agents.agent3_diagnostic import Agent3Diagnostic + from comparator.aligner import align_records + from comparator.field_compare import compare_field + from comparator.normalizer import Normalizer + from comparator.rounding_detect import detect_rounding + from comparator.cobol_binary_reader import CobolBinaryReader + from report.generator import ReportGenerator + from storage.bundle import TestDataBundle + from storage.store import ReportStore, DiskCache + from preprocessor import CopybookPreprocessor + from config.mapping import MappingConfig, FieldMapping + from quality.l1_offset_validate import L1OffsetValidator + from quality.l2_value_roundtrip import L2RoundtripValidator + assert True