150 lines
5.5 KiB
Python
150 lines
5.5 KiB
Python
import shutil, time
|
|
from pathlib import Path
|
|
from data.field_tree import FieldTree
|
|
from data.test_case import TestSuite, SparkConfig
|
|
from data.diff_result import VerificationRun, FieldResult
|
|
from runners.runner import Runner
|
|
from runners.native_java_runner import NativeJavaRunner
|
|
from runners.spark_java_runner import SparkJavaRunner
|
|
from runners.cobol_runner import CobolRunner
|
|
from runners.data_writer import DataWriter
|
|
from agents.agent1_parser import Agent1Parser
|
|
from agents.agent2_data import Agent2Data
|
|
from agents.llm import LLMClient
|
|
from comparator.aligner import align_records
|
|
from comparator.field_compare import compare_field
|
|
from comparator.normalizer import Normalizer
|
|
from comparator.cobol_binary_reader import CobolBinaryReader
|
|
from report.generator import ReportGenerator
|
|
from storage.bundle import TestDataBundle
|
|
from config import Config
|
|
|
|
|
|
def _finish(vr: VerificationRun, start: float, status: str,
            exit_code: int) -> VerificationRun:
    """Record a terminal status, exit code and elapsed time on *vr* and return it."""
    vr.status = status
    vr.exit_code = exit_code
    vr.duration_s = time.time() - start
    return vr


def _compare_records(cobol_records, java_records, tree: FieldTree,
                     key_field: str, tolerance) -> list:
    """Align COBOL and Java records on *key_field* and compare them field by field.

    Returns a list of ``FieldResult``. An unmatched record contributes a single
    result named ``"unknown"``. Its status is ``NOT_SET`` when the aligner
    reports ``MISSING_IN_SPARK``, and ``EXTRA`` for any other non-``MATCHED``
    status. For matched pairs, every COBOL field except *key_field* is compared
    with ``compare_field``.
    """
    if not cobol_records and not java_records:
        # Nothing to align: both programs produced no records.
        return []

    results = []
    for c_rec, j_rec, status in align_records(cobol_records, java_records,
                                              key_field=key_field):
        if status != "MATCHED":
            results.append(FieldResult(
                field_name="unknown",
                status="NOT_SET" if status == "MISSING_IN_SPARK" else "EXTRA"))
            continue
        for name, c_val in c_rec.items():
            if name == key_field:
                # The alignment key is equal by construction; skip it.
                continue
            mapped = tree.get_by_name(name)
            # Only DISPLAY fields compare as text. COMP-3, any other usage,
            # and fields not found in the tree compare as decimals.
            field_type = ("string" if mapped and mapped.usage == "DISPLAY"
                          else "decimal")
            results.append(compare_field(name, str(c_val),
                                         str(j_rec.get(name, "")),
                                         field_type, tolerance=tolerance))
    return results


def _write_reports(vr: VerificationRun) -> None:
    """Write JSON, HTML and machine JSON reports under reports/<program>/<timestamp>/."""
    report_dir = Path("reports") / vr.program / vr.timestamp
    report_dir.mkdir(parents=True, exist_ok=True)
    gen = ReportGenerator()
    gen.generate_json(vr, report_dir / "result.json")
    gen.generate_html(vr, report_dir / "report.html")
    gen.generate_machine_json(vr, report_dir / "machine.json")
    vr.report_path = str(report_dir)


def run_pipeline(cfg: Config, copybook_path: str, cobol_src: str,
                 java_src: str, mapping_path: str, *,
                 key_field: str = "CUST-ID",
                 llm_call_cost: float = 0.002) -> VerificationRun:
    """Verify a Java (native or Spark) port against its COBOL original.

    Steps:

    1. Read the COPYBOOK.
    2. Parse it into a field tree with Agent 1.
    3. Design test data with Agent 2.
    4. Write the test-data bundle.
    5. Compile and run the COBOL program.
    6. Compile and run the Java program.
    7. Compare both outputs field by field.
    8. Write reports.

    Args:
        cfg: Pipeline configuration. Supplies runner mode, LLM settings,
            dialect, tolerance, LLM cost budget and Spark settings.
        copybook_path: Path to the COPYBOOK describing the record layout.
        cobol_src: Path to the COBOL source program.
        java_src: Path to the Java source. Its file stem names the program.
        mapping_path: Not used by any step in this function. It is kept for
            interface compatibility.
        key_field: Field used to align COBOL and Java records. It is excluded
            from the per-field comparison.
        llm_call_cost: Amount added to ``vr.llm_cost`` per successful agent call.

    Returns:
        The populated ``VerificationRun``. Early exits set status/exit code:

        * ``BLOCKED``/2 when an agent fails, a build fails, or ``java`` is
          not on PATH.
        * ``BLOCKED``/3 when the LLM cost budget is exceeded.
        * ``ERROR``/3 when the COBOL or Java program run fails.

        Otherwise the result is ``PASS``/0, or ``MISMATCH``/1 when any field
        mismatches, is missing, or belongs to an extra record.
    """
    import logging

    log = logging.getLogger(__name__)
    start = time.time()
    vr = VerificationRun(program=Path(java_src).stem, runner=cfg.runner_mode)
    spark_mode = cfg.runner_mode == "spark"

    # Step 1: read COPYBOOK
    copybook_text = Path(copybook_path).read_text()

    # Step 2: Agent 1 parses the COPYBOOK into a field tree
    llm = LLMClient(model=cfg.llm_model, timeout=cfg.llm_timeout,
                    cache_dir=cfg.llm_cache_dir)
    try:
        tree = Agent1Parser(llm).parse(copybook_text)
    except Exception:
        # Pipeline boundary: a parse failure blocks the run, but keep the
        # traceback instead of discarding it.
        log.exception("Agent 1 failed to parse copybook %s", copybook_path)
        return _finish(vr, start, "BLOCKED", 2)
    vr.llm_cost += llm_call_cost

    # Step 3: Agent 2 designs test data (handled the same way as Agent 1)
    try:
        suite = Agent2Data(llm).design(tree, cfg.coverage_default,
                                       spark_mode=spark_mode)
    except Exception:
        log.exception("Agent 2 failed to design test data for %s", vr.program)
        return _finish(vr, start, "BLOCKED", 2)
    vr.llm_cost += llm_call_cost
    if vr.llm_cost > cfg.max_llm_cost:
        return _finish(vr, start, "BLOCKED", 3)

    # Step 4: write test data for both sides
    bundle = TestDataBundle(base_path=Path("test-data-bundle"), format="json")
    bundle.ensure_dirs()
    writer = DataWriter()
    writer.write_cobol_binary(suite.test_cases, bundle.cobol_input())
    if spark_mode:
        spark_cfg = suite.spark_config or SparkConfig(num_records=cfg.num_records)
        writer.write_spark_json(suite.test_cases, spark_cfg,
                                bundle.spark_input_dir())
    else:
        writer.write_native_json(suite.test_cases, bundle.native_input())

    # Step 5: COBOL compile + run
    cobol = CobolRunner()
    build = cobol.compile(cobol_src, cfg.dialect)
    if not build.success:
        return _finish(vr, start, "BLOCKED", 2)
    cobol_out = Path("cobol_output.bin")
    cobol_run = cobol.run(build.artifact_path, str(bundle.cobol_input()),
                          str(cobol_out))
    if not cobol_run.success:
        return _finish(vr, start, "ERROR", 3)

    # Step 6: Java/Spark compile + run
    if shutil.which("java") is None:
        return _finish(vr, start, "BLOCKED", 2)
    runner: Runner = (SparkJavaRunner(master_url=cfg.spark_master)
                      if spark_mode else NativeJavaRunner())
    jbuild = runner.compile(java_src)
    if not jbuild.success:
        return _finish(vr, start, "BLOCKED", 2)
    java_input = bundle.spark_input_dir() if spark_mode else bundle.native_input()
    jrun = runner.run(jbuild.artifact_path, str(java_input), "java_output")
    # Mirror the COBOL run check. A failed Java run must not be compared as
    # if it produced output; with empty COBOL output it would otherwise PASS.
    # getattr keeps run results without a `success` attribute working as before.
    if not getattr(jrun, "success", True):
        return _finish(vr, start, "ERROR", 3)

    # Step 7: compare
    cobol_records = CobolBinaryReader().read(str(cobol_out), tree)
    field_results = _compare_records(cobol_records, jrun.records, tree,
                                     key_field, cfg.tolerance)
    # Extra records count as failures too. Otherwise additional Java output
    # would PASS and be tallied under fields_matched.
    failing = ("MISMATCH", "NOT_SET", "EXTRA")
    mismatches = sum(1 for fr in field_results if fr.status in failing)
    vr.fields_matched = len(field_results) - mismatches
    vr.fields_mismatched = mismatches
    vr.field_results = field_results
    # Duration is taken before reporting, so it excludes report generation.
    _finish(vr, start,
            "PASS" if mismatches == 0 else "MISMATCH",
            0 if mismatches == 0 else 1)

    # Step 8: report
    _write_reports(vr)

    return vr
|