From 4b5b8503b7dee6dd595b8ea0890cc7d6fb360c78 Mon Sep 17 00:00:00 2001 From: Sindre Stephansen Date: Mon, 14 Oct 2019 23:24:26 +0200 Subject: [PATCH] Implement jumps --- src/main/scala/ALU.scala | 46 ++++++++++++++++ src/main/scala/CPU.scala | 40 ++++++++++---- src/main/scala/Comparator.scala | 40 ++++++++++++++ src/main/scala/Decoder.scala | 78 +++++++++++++++------------- src/main/scala/Execute.scala | 56 ++++++++------------ src/main/scala/ID.scala | 69 +++++++++++------------- src/main/scala/IF.scala | 11 +++- src/main/scala/ToplevelSignals.scala | 4 ++ 8 files changed, 225 insertions(+), 119 deletions(-) create mode 100644 src/main/scala/ALU.scala create mode 100644 src/main/scala/Comparator.scala diff --git a/src/main/scala/ALU.scala b/src/main/scala/ALU.scala new file mode 100644 index 0000000..c8c29c4 --- /dev/null +++ b/src/main/scala/ALU.scala @@ -0,0 +1,46 @@ +package FiveStage + +import chisel3._ +import chisel3.util._ + + +class ALU extends Module { + val io = IO( + new Bundle { + val ALUop = Input(UInt(4.W)) + val data1 = Input(UInt(32.W)) + val data2 = Input(UInt(32.W)) + + val result = Output(UInt(32.W)) + } + ) + + val data1S = io.data1.asSInt + val data2S = io.data2.asSInt + val data1U = io.data1.asUInt + val data2U = io.data2.asUInt + + val ALUopMap = Array( + ALUOps.ADD -> (data1S + data2S).asUInt, + ALUOps.SUB -> (data1S - data2S).asUInt, + ALUOps.AND -> (data1U & data2U), + ALUOps.OR -> (data1U | data2U), + ALUOps.XOR -> (data1U ^ data2U), + ALUOps.SLT -> (data1S < data2S).asUInt, + ALUOps.SLTU -> (data1U < data2U), + ALUOps.SLL -> (data1S << data2U(4, 0)).asUInt, + ALUOps.SRL -> (data1U >> data2U(4, 0)).asUInt, + ALUOps.SRA -> (data1S >> data2U(4, 0)).asUInt, + ALUOps.COPY_A -> (data1U), + ALUOps.COPY_B -> (data2U), + ALUOps.ADDR -> ((data1S + data2S) & 0xFFFFFFFE.S).asUInt, + ALUOps.LUI -> (data2U << 12.U).asUInt, + ALUOps.AUIPC -> (data1U + (data2U << 12.U)), + ) + + io.result := MuxLookup( + io.ALUop, + 0.U(32.W), + ALUopMap, + ) +} diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index 360fc1c..fa99e76 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -54,30 +54,50 @@ class CPU extends MultiIOModule { /** TODO: Your code here */ + // Stage 1 + + // The IF gets the instruction, but it won't be ready until the next cycle + //printf(p"S1: PC=${Hexadecimal(IF.io.PC)} || ") + // Stage 2 - ID.io.PC := IF.io.PC ID.io.instruction := IF.io.instruction - //printf(p"S2: PC=${IF.io.PC}, Opcode=${IF.io.instruction.opcode}, rd=${IF.io.instruction.registerRd}, rs1=${IF.io.instruction.registerRs1}, rs2=${IF.io.instruction.registerRs2} || ") + //printf(p"S2: Opcode=${IF.io.instruction.opcode}, rd=${IF.io.instruction.registerRd}, rs1=${IF.io.instruction.registerRs1}, rs2=${IF.io.instruction.registerRs2} || ") // Stage 3 + EX.io.PC := ShiftRegister(IF.io.PC, 2) + EX.io.controlSignals := ID.io.controlSignals + EX.io.branchType := ID.io.branchType EX.io.ALUop := ID.io.ALUop - EX.io.data1 := ID.io.data1 - EX.io.data2 := ID.io.data2 - //printf(p"S3: ALUop=${ID.io.ALUop}, data1=${ID.io.data1.asSInt}, data2=${ID.io.data2.asSInt} || ") + EX.io.reg1 := ID.io.reg1 + EX.io.reg2 := ID.io.reg2 + EX.io.imm := ID.io.imm + //printf(p"S3: ALUop=${ID.io.ALUop}, reg1=${ID.io.reg1.asSInt}, reg2=${ID.io.reg2.asSInt}, imm=${ID.io.imm.asSInt} || ") // Stage 4 - MEM.io.dataIn := ShiftRegister(ID.io.data3, 1) + MEM.io.dataIn := ShiftRegister(ID.io.reg2, 1) MEM.io.dataAddress := EX.io.result MEM.io.writeEnable := ShiftRegister(ID.io.controlSignals.memWrite, 1) + + IF.io.jumpEnable := EX.io.branch + IF.io.jumpAddr := EX.io.result + //printf(p"S4: res=${EX.io.result} || ") // Stage 5 - val memOrEx = ShiftRegister(ID.io.controlSignals.memToReg, 2) // From stage 3 val exResult = ShiftRegister(EX.io.result, 1) // From stage 4 ID.io.writeEnable := ShiftRegister(ID.io.controlSignals.regWrite, 2) // From stage 3 - ID.io.writeAddr := ShiftRegister(IF.io.instruction.registerRd, 3) // From stage 2 - ID.io.writeData := Mux(memOrEx, MEM.io.dataOut, exResult) - //printf(p"S5: Mem=${MEM.io.dataOut} Ex=${exResult}") + ID.io.writeAddr := ShiftRegister(IF.io.instruction.registerRd, 3) // From stage 2 + ID.io.writeData := exResult + + when (ShiftRegister(ID.io.controlSignals.memToReg, 2)) { + ID.io.writeData := MEM.io.dataOut + } + + when (ShiftRegister(ID.io.controlSignals.jump, 2)) { + ID.io.writeData := ShiftRegister(IF.io.PC, 4) + 4.U + } + + //printf(p"S5: WB=${ID.io.writeData}") //printf("\n\n") } diff --git a/src/main/scala/Comparator.scala b/src/main/scala/Comparator.scala new file mode 100644 index 0000000..a656046 --- /dev/null +++ b/src/main/scala/Comparator.scala @@ -0,0 +1,40 @@ +package FiveStage +import chisel3._ +import chisel3.util._ + + +class Comparator extends Module { + val io = IO( + new Bundle { + val branchType = Input(UInt(3.W)) + val data1 = Input(UInt(32.W)) + val data2 = Input(UInt(32.W)) + + val result = Output(Bool()) + }) + + io.result := false.B + + //printf(p"Comparator: ${io.result}\n") + + switch (io.branchType) { + is (branchType.beq) { + io.result := io.data1 === io.data2 + } + is (branchType.neq) { + io.result := io.data1 =/= io.data2 + } + is (branchType.gte) { + io.result := io.data1.asSInt >= io.data2.asSInt + } + is (branchType.gteu) { + io.result := io.data1 >= io.data2 + } + is (branchType.lt) { + io.result := io.data1.asSInt < io.data2.asSInt + } + is (branchType.ltu) { + io.result := io.data1 < io.data2 + } + } +} diff --git a/src/main/scala/Decoder.scala b/src/main/scala/Decoder.scala index e069cfa..2e93759 100644 --- a/src/main/scala/Decoder.scala +++ b/src/main/scala/Decoder.scala @@ -20,8 +20,6 @@ class Decoder() extends Module { val controlSignals = Output(new ControlSignals) val branchType = Output(UInt(3.W)) - val op1Select = Output(UInt(1.W)) - val op2Select = Output(UInt(1.W)) val immType = Output(UInt(3.W)) val ALUop = Output(UInt(4.W)) }) @@ -46,30 +44,41 @@ class Decoder() extends Module { */ val opcodeMap: Array[(BitPat, List[UInt])] = Array( - // signal memToReg, regWrite, memRead, memWrite, branch, jump, branchType, Op1Select, Op2Select, ImmSelect, ALUOp - LW -> List(Y, Y, Y, N, N, N, branchType.DC, rs1, imm, ITYPE, ALUOps.ADD), - - SW -> List(N, N, N, Y, N, N, branchType.DC, rs1, imm, STYPE, ALUOps.ADD), - - ADD -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.ADD), - ADDI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ITYPE, ALUOps.ADD), - SUB -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SUB), - AND -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.AND), - ANDI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.AND), - OR -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.OR), - ORI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.OR), - XOR -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.XOR), - XORI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.XOR), - SLT -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLT), - SLTI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.SLT), - SLTU -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLTU), - SLTIU -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.SLTU), - SRA -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRA), - SRAI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.SRA), - SRL -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SRL), - SRLI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.SRL), - SLL -> List(N, Y, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.SLL), - SLLI -> List(N, Y, N, N, N, N, branchType.DC, rs1, imm, ImmFormat.DC, ALUOps.SLL), + // signal memToReg, regWrite, memRead, memWrite, immediate, branch, jump, branchType, ImmSelect, ALUOp + LW -> List(Y, Y, Y, N, Y, N, N, branchType.DC, ITYPE, ALUOps.ADD), + SW -> List(N, N, N, Y, Y, N, N, branchType.DC, STYPE, ALUOps.ADD), + + JAL -> List(N, Y, N, N, Y, N, Y, branchType.jump, JTYPE, ALUOps.ADD), + JALR -> List(N, Y, N, N, Y, N, Y, branchType.jump, ITYPE, ALUOps.ADDR), + BEQ -> List(N, N, N, N, Y, Y, N, branchType.beq, BTYPE, ALUOps.ADD), + BNE -> List(N, N, N, N, Y, Y, N, branchType.neq, BTYPE, ALUOps.ADD), + BLT -> List(N, N, N, N, Y, Y, N, branchType.lt, BTYPE, ALUOps.ADD), + BLTU -> List(N, N, N, N, Y, Y, N, branchType.ltu, BTYPE, ALUOps.ADD), + BGE -> List(N, N, N, N, Y, Y, N, branchType.gte, BTYPE, ALUOps.ADD), + BGEU -> List(N, N, N, N, Y, Y, N, branchType.gteu, BTYPE, ALUOps.ADD), + + LUI -> List(N, Y, N, N, Y, N, N, branchType.DC, ImmFormat.DC, ALUOps.LUI), + AUIPC -> List(N, Y, N, N, Y, N, N, branchType.DC, ImmFormat.DC, ALUOps.AUIPC), + + ADD -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.ADD), + ADDI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.ADD), + SUB -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SUB), + AND -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.AND), + ANDI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.AND), + OR -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.OR), + ORI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.OR), + XOR -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.XOR), + XORI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.XOR), + SLT -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SLT), + SLTI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.SLT), + SLTU -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SLTU), + SLTIU -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.SLTU), + SRA -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SRA), + SRAI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.SRA), + SRL -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SRL), + SRLI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.SRL), + SLL -> List(N, Y, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.SLL), + SLLI -> List(N, Y, N, N, Y, N, N, branchType.DC, ITYPE, ALUOps.SLL), /** TODO: Fill in the blanks @@ -77,7 +86,7 @@ class Decoder() extends Module { ) - val NOP = List(N, N, N, N, N, N, branchType.DC, rs1, rs2, ImmFormat.DC, ALUOps.DC) + val NOP = List(N, N, N, N, N, N, N, branchType.DC, ImmFormat.DC, ALUOps.DC) val decodedControlSignals = ListLookup( io.instruction.asUInt(), @@ -88,12 +97,11 @@ class Decoder() extends Module { io.controlSignals.regWrite := decodedControlSignals(1) io.controlSignals.memRead := decodedControlSignals(2) io.controlSignals.memWrite := decodedControlSignals(3) - io.controlSignals.branch := decodedControlSignals(4) - io.controlSignals.jump := decodedControlSignals(5) - - io.branchType := decodedControlSignals(6) - io.op1Select := decodedControlSignals(7) - io.op2Select := decodedControlSignals(8) - io.immType := decodedControlSignals(9) - io.ALUop := decodedControlSignals(10) + io.controlSignals.immediate := decodedControlSignals(4) + io.controlSignals.branch := decodedControlSignals(5) + io.controlSignals.jump := decodedControlSignals(6) + + io.branchType := decodedControlSignals(7) + io.immType := decodedControlSignals(8) + io.ALUop := decodedControlSignals(9) } diff --git a/src/main/scala/Execute.scala b/src/main/scala/Execute.scala index 8fcf37f..3b2bcf6 100644 --- a/src/main/scala/Execute.scala +++ b/src/main/scala/Execute.scala @@ -2,45 +2,35 @@ package FiveStage import chisel3._ import chisel3.util._ + class Execute extends Module { val io = IO( new Bundle { + val controlSignals = Input(new ControlSignals) + val branchType = Input(UInt(3.W)) val ALUop = Input(UInt(4.W)) - val data1 = Input(UInt(32.W)) - val data2 = Input(UInt(32.W)) + val PC = Input(UInt(32.W)) + val reg1 = Input(UInt(32.W)) + val reg2 = Input(UInt(32.W)) + val imm = Input(UInt(32.W)) val result = Output(UInt(32.W)) + val branch = Output(Bool()) }) - val result = RegInit(UInt(32.W), 0.U) - - val data1S = io.data1.asSInt - val data2S = io.data2.asSInt - val data1U = io.data1.asUInt - val data2U = io.data2.asUInt - - val ALUopMap = Array( - ALUOps.ADD -> (data1S + data2S).asUInt, - ALUOps.SUB -> (data1S - data2S).asUInt, - ALUOps.AND -> (data1U & data2U), - ALUOps.OR -> (data1U | data2U), - ALUOps.XOR -> (data1U ^ data2U), - // TODO: SLT. Set GPR? - ALUOps.SLT -> (data1S < data2S).asUInt, - ALUOps.SLTU -> (data1U < data2U), - ALUOps.SLL -> (data1S << data2U(4, 0)).asUInt, - ALUOps.SRL -> (data1U >> data2U).asUInt, - ALUOps.SRA -> (data1S >> data2U).asUInt, // TODO: SRA sign-extend? - ALUOps.COPY_A -> (data1U), - ALUOps.COPY_B -> (data2U), - - ) - - result := MuxLookup( - io.ALUop, - 0.U(32.W), - ALUopMap, - ) - - io.result := result + val alu = Module(new ALU).io + val comparator = Module(new Comparator).io + + val usePC = io.controlSignals.branch | (io.controlSignals.jump && (io.ALUop =/= ALUOps.ADDR)) | (io.ALUop === ALUOps.AUIPC) + + alu.ALUop := io.ALUop + alu.data1 := Mux(usePC, io.PC, io.reg1) + alu.data2 := Mux(io.controlSignals.immediate, io.imm, io.reg2) + + comparator.branchType := io.branchType + comparator.data1 := io.reg1 + comparator.data2 := io.reg2 + + io.result := ShiftRegister(alu.result, 1) + io.branch := ShiftRegister(io.controlSignals.jump | (comparator.result & io.controlSignals.branch), 1) } diff --git a/src/main/scala/ID.scala b/src/main/scala/ID.scala index dd27e6e..6394a61 100644 --- a/src/main/scala/ID.scala +++ b/src/main/scala/ID.scala @@ -21,16 +21,16 @@ class InstructionDecode extends MultiIOModule { /** * TODO: Your code here. */ - val PC = Input(UInt(32.W)) val instruction = Input(new Instruction) val writeEnable = Input(UInt(32.W)) val writeAddr = Input(UInt(32.W)) val writeData = Input(UInt(32.W)) val controlSignals = Output(new ControlSignals) - val data1 = Output(UInt(32.W)) - val data2 = Output(UInt(32.W)) - val data3 = Output(UInt(32.W)) + val branchType = Output(UInt(3.W)) + val reg1 = Output(UInt(32.W)) + val reg2 = Output(UInt(32.W)) + val imm = Output(UInt(32.W)) val ALUop = Output(UInt(4.W)) } ) @@ -38,10 +38,11 @@ class InstructionDecode extends MultiIOModule { val registers = Module(new Registers) val decoder = Module(new Decoder).io - val data1 = RegInit(UInt(32.W), 0.U) - val data2 = RegInit(UInt(32.W), 0.U) - val data3 = RegInit(UInt(32.W), 0.U) + val reg1 = RegInit(UInt(32.W), 0.U) + val reg2 = RegInit(UInt(32.W), 0.U) + val imm = RegInit(UInt(32.W), 0.U) val ALUop = RegInit(UInt(4.W), 0.U) + val branchType = RegInit(UInt(3.W), 0.U) val controlSignals = Reg(new ControlSignals) @@ -68,41 +69,29 @@ class InstructionDecode extends MultiIOModule { io.ALUop := ALUop controlSignals := decoder.controlSignals io.controlSignals := controlSignals + branchType := decoder.branchType + io.branchType := branchType - io.data1 := data1 - io.data2 := data2 - io.data3 := data3 + reg1 := registers.io.readData1 + reg2 := registers.io.readData2 - data3 := registers.io.readData2 + io.reg1 := reg1 + io.reg2 := reg2 + io.imm := imm - switch (decoder.op1Select) { - is (Op1Select.rs1) { - data1 := registers.io.readData1 - } - is (Op1Select.PC) { - data1 := io.PC - } - } - switch (decoder.op2Select) { - is (Op2Select.rs2) { - data2 := registers.io.readData2 - } - is (Op2Select.imm) { - val immTypeMap = Array( - ImmFormat.ITYPE -> io.instruction.immediateIType, - ImmFormat.STYPE -> io.instruction.immediateSType, - ImmFormat.BTYPE -> io.instruction.immediateBType, - ImmFormat.UTYPE -> io.instruction.immediateUType, - ImmFormat.JTYPE -> io.instruction.immediateJType, - ImmFormat.SHAMT -> 0.S, // TODO: Implement SHAMT - ) - - data2 := MuxLookup( - decoder.immType, - 0.S, - immTypeMap, - ).pad(32).asUInt - } - } + val immTypeMap = Array( + ImmFormat.ITYPE -> io.instruction.immediateIType, + ImmFormat.STYPE -> io.instruction.immediateSType, + ImmFormat.BTYPE -> io.instruction.immediateBType, + ImmFormat.UTYPE -> io.instruction.immediateUType, + ImmFormat.JTYPE -> io.instruction.immediateJType, + ImmFormat.SHAMT -> 0.S, // TODO: Implement SHAMT + ) + + imm := MuxLookup( + decoder.immType, + 0.S, + immTypeMap, + ).pad(32).asUInt } diff --git a/src/main/scala/IF.scala b/src/main/scala/IF.scala index 22ede75..b545461 100644 --- a/src/main/scala/IF.scala +++ b/src/main/scala/IF.scala @@ -25,6 +25,9 @@ class InstructionFetch extends MultiIOModule { new Bundle { val PC = Output(UInt()) val instruction = Output(new Instruction) + + val jumpEnable = Input(Bool()) + val jumpAddr = Input(UInt(32.W)) }) val IMEM = Module(new IMEM) @@ -43,9 +46,15 @@ class InstructionFetch extends MultiIOModule { * * You should expand on or rewrite the code below. */ + + val addr = Mux(io.jumpEnable, io.jumpAddr, PC + 4.U) + when (io.jumpEnable) { + //printf(p"Jump to ${Hexadecimal(addr)}\n") + } + IMEM.io.instructionAddress := PC - PC := PC + 4.U + PC := addr io.PC := PC io.instruction := IMEM.io.instruction.asTypeOf(new Instruction) diff --git a/src/main/scala/ToplevelSignals.scala b/src/main/scala/ToplevelSignals.scala index fe31613..ecf5d14 100644 --- a/src/main/scala/ToplevelSignals.scala +++ b/src/main/scala/ToplevelSignals.scala @@ -43,6 +43,7 @@ class ControlSignals extends Bundle(){ val regWrite = Bool() val memRead = Bool() val memWrite = Bool() + val immediate = Bool() val branch = Bool() val jump = Bool() } @@ -121,6 +122,9 @@ object ALUOps { val SRA = 9.U(4.W) val COPY_A = 10.U(4.W) val COPY_B = 11.U(4.W) + val ADDR = 12.U(4.W) + val LUI = 13.U(4.W) + val AUIPC = 14.U(4.W) val DC = 15.U(4.W) }