diff --git a/src/simulator/machine.rs b/src/simulator/machine.rs index fdb9b0f..d1fcd3a 100644 --- a/src/simulator/machine.rs +++ b/src/simulator/machine.rs @@ -385,24 +385,24 @@ impl Machine { /// Executes RISC-V Integer Register-Immediate Instructions on the machine fn opi_instruction(&mut self, inst: Instruction) -> Result<(), MachineError> { - let mut compute = |operation: &dyn Fn (i64, i64) -> i64| { - self.int_reg.set_reg(inst.rd, operation(self.int_reg.get_reg(inst.rs1), inst.imm12_I_signed as i64)); + let rs1 = self.int_reg.get_reg(inst.rs1); + let imm12 = inst.imm12_I_signed as i64; + let shamt = inst.shamt as i64; + let mut compute = |operation: &dyn Fn (i64, i64) -> i64, a, b| { + self.int_reg.set_reg(inst.rd, operation(a, b)); Ok(()) }; match inst.funct3 { - RISCV_OPI_ADDI => compute(&std::ops::Add::add), - RISCV_OPI_SLTI => compute(&|a, b| { (a < b) as i64 }), - RISCV_OPI_XORI => compute(&|a, b| { a ^ b }), - RISCV_OPI_ORI => compute(&|a, b| { a | b }), - RISCV_OPI_ANDI => compute(&|a, b| { a & b }), - RISCV_OPI_SLLI => compute(&|a, b| { a << b }), - RISCV_OPI_SRI => { - if inst.funct7_smaller == RISCV_OPI_SRI_SRLI { - self.int_reg.set_reg(inst.rd, (self.int_reg.get_reg(inst.rs1) >> inst.shamt) & self.shiftmask[inst.shamt as usize] as i64) - } else { // SRAI - self.int_reg.set_reg(inst.rd, self.int_reg.get_reg(inst.rs1) >> inst.shamt) - } - Ok(()) + RISCV_OPI_ADDI => compute(&std::ops::Add::add, rs1, imm12), + RISCV_OPI_SLTI => compute(&|a, b| (a < b) as i64, rs1, imm12), + RISCV_OPI_XORI => compute(&core::ops::BitXor::bitxor, rs1, imm12), + RISCV_OPI_ORI => compute(&core::ops::BitOr::bitor, rs1, imm12), + RISCV_OPI_ANDI => compute(&core::ops::BitAnd::bitand, rs1, imm12), + RISCV_OPI_SLLI => compute(&core::ops::Shl::shl, rs1, imm12), + RISCV_OPI_SRI => if inst.funct7_smaller == RISCV_OPI_SRI_SRLI { + compute(&|a, b| { (a >> b) & self.shiftmask[inst.shamt as usize] as i64 }, rs1, shamt) + } else { + compute(&|a, b| { a >> b }, rs1, shamt) } _ => Err(MachineError::new(format!("In OPI switch case, this should never happen... Instr was %x\n {}", inst.value).as_str())) }