From 9c8fc2bf211813c23a078d69a8a7ad9112c3c408 Mon Sep 17 00:00:00 2001 From: Sindre Stephansen Date: Wed, 27 Nov 2019 16:34:40 +0100 Subject: [PATCH] Add branch predictor --- src/main/scala/BranchPredictor.scala | 98 ++++++++++++++++++++++++++++ src/main/scala/CPU.scala | 36 ++++++---- 2 files changed, 120 insertions(+), 14 deletions(-) create mode 100644 src/main/scala/BranchPredictor.scala diff --git a/src/main/scala/BranchPredictor.scala b/src/main/scala/BranchPredictor.scala new file mode 100644 index 0000000..3b47379 --- /dev/null +++ b/src/main/scala/BranchPredictor.scala @@ -0,0 +1,98 @@ +package FiveStage + +import chisel3._ +import chisel3.util._ + + +/** + * A two bit branch predictor. When given an address of an instruction it returns + * the predicted address of the next instruction. This works for all instructions. + * If it isn't in the prediction table, the following address is returned. + */ +class BranchPredictor(size: Int) extends Module { + val io = IO(new Bundle { + val PC = Input(UInt(32.W)) + val nextAddr = Output(UInt(32.W)) + + // Set these after executing a branch instruction + // to update the prediction table + val update = Input(new Bundle { + val PC = UInt(32.W) + val branchAddr = UInt(32.W) + val branchTaken = Bool() + val enable = Bool() + }) + + val misjump = Output(Bool()) + }) + + def getIndex(addr: UInt): UInt = { + // A chisel version of getTag + val bitsLeft = addr.getWidth - (log2Ceil(size) + 2) + val bitsRight = addr.getWidth - log2Ceil(size) + (addr << bitsLeft) >> bitsRight + } + + class PredictionBundle extends Bundle { + val state = UInt(2.W) + val addr = UInt(32.W) + val branchAddr = UInt(32.W) + } + + val table = Mem(size, new PredictionBundle) + + // Whether the previous prediction was wrong. Is set on update. + // The CPU should clear the barriers when this is true. + val misjump = WireInit(false.B) + io.misjump := misjump + + /* Predict the next address */ + + when (misjump) { + // We predicted incorrectly. Use the correct address now + io.nextAddr := Mux(io.update.branchTaken, io.update.branchAddr, io.update.PC + 4.U) + }.otherwise { + val entry = table(getIndex(io.PC)) + when (entry.addr === io.PC && entry.state >= 2.U) { + // Take the branch + io.nextAddr := entry.branchAddr + }.otherwise { + // The next address is by default the following instruction + io.nextAddr := io.PC + 4.U + } + } + + /* Update the prediction */ + + when (io.update.enable) { + val entry = table(getIndex(io.update.PC)) + + when (entry.addr === io.update.PC && entry.branchAddr === io.update.branchAddr) { + // Call a misjump if we predicted a jump when we shouldn't, or the other way around + misjump := entry.state >= 2.U ^ io.update.branchTaken + + // Update the state + when (io.update.branchTaken) { + switch (entry.state) { + is (0.U) { entry.state := 1.U } + is (1.U) { entry.state := 3.U } + is (2.U) { entry.state := 3.U } + } + }.otherwise { + switch (entry.state) { + is (1.U) { entry.state := 0.U } + is (2.U) { entry.state := 0.U } + is (3.U) { entry.state := 2.U } + } + } + }.otherwise { + // Add a new entry to the table + entry.addr := io.update.PC + entry.branchAddr := io.update.branchAddr + entry.state := Mux(io.update.branchTaken, 2.U, 1.U) + + // We only call a misjump if we jump now, or if we jumped to the wrong address on the prediction. + misjump := io.update.branchTaken || (entry.addr === io.update.PC && entry.state >= 2.U) + } + } +} diff --git a/src/main/scala/CPU.scala b/src/main/scala/CPU.scala index b459e13..cfdb8c4 100644 --- a/src/main/scala/CPU.scala +++ b/src/main/scala/CPU.scala @@ -48,11 +48,6 @@ class CPU extends MultiIOModule { testHarness.currentPC := IF.testHarness.PC - val NORMAL = 0.U - val LOAD_FREEZE = 1.U - val state = Reg(UInt(4.W)) - state := NORMAL - val freeze = Wire(Bool()) freeze := false.B IFBarrier.freeze := freeze @@ -78,9 +73,22 @@ class CPU extends MultiIOModule { forwarder.WB := WBBarrier.out forwarder.writeback := writeback + val predictor = Module(new BranchPredictor(128)).io + predictor.update.enable := false.B + predictor.update.PC := 0.U + predictor.update.branchAddr := 0.U + predictor.update.branchTaken := false.B + + when (predictor.misjump) { + // Clear the barriers on a misjump + IDBarrier.clear := true.B + EXBarrier.clear := true.B + } + // Stage 1 IFBarrier.in := IF.io - IF.io.addr := IFBarrier.out.PC + 4.U + predictor.PC := IFBarrier.out.PC + IF.io.addr := predictor.nextAddr // Stage 2 ID.io.instruction := IFBarrier.out.instruction @@ -94,7 +102,7 @@ class CPU extends MultiIOModule { IDBarrier.in.ALUop := ID.io.ALUop // Stage 3 - when (forwarder.loadFreeze && state =/= LOAD_FREEZE) { + when (forwarder.loadFreeze && !RegNext(freeze)) { // Freeze the IF and ID barriers, repeating the instruction. // EX is cleared, so the instruction isn't computed and written twice. freeze := true.B @@ -102,7 +110,6 @@ class CPU extends MultiIOModule { WBBarrier.freeze := false.B EXBarrier.freeze := false.B EXBarrier.clear := true.B - state := LOAD_FREEZE } EX.io.PC := IDBarrier.out.PC @@ -121,6 +128,13 @@ class CPU extends MultiIOModule { EXBarrier.in.branch := EX.io.branch // Stage 4 + when (EXBarrier.out.controlSignals.jump || EXBarrier.out.controlSignals.branch) { + predictor.update.enable := true.B + predictor.update.PC := EXBarrier.out.PC + predictor.update.branchAddr := EXBarrier.out.result + predictor.update.branchTaken := EXBarrier.out.branch + } + MEM.io.dataIn := forwarder.memWrite MEM.io.writeEnable := EXBarrier.out.controlSignals.memWrite MEM.io.dataAddress := EXBarrier.out.result @@ -133,12 +147,6 @@ class CPU extends MultiIOModule { MEMBarrier.in.dataOut := MEM.io.dataOut // Stage 5 - when (EXBarrier.out.branch) { - IF.io.addr := EXBarrier.out.result - IDBarrier.clear := true.B - EXBarrier.clear := true.B - } - ID.io.writeEnable := MEMBarrier.out.controlSignals.regWrite ID.io.writeAddr := MEMBarrier.out.instruction.registerRd ID.io.writeData := writeback