浏览代码

Add branch predictor

sindre-ex2
父节点
当前提交
9c8fc2bf21
共有 2 个文件被更改,包括 120 次插入14 次删除
  1. +98
    -0
      src/main/scala/BranchPredictor.scala
  2. +22
    -14
      src/main/scala/CPU.scala

+ 98
- 0
src/main/scala/BranchPredictor.scala 查看文件

@@ -0,0 +1,98 @@
package FiveStage

import chisel3._
import chisel3.util._


/**
* A two bit branch predictor. When given an address of an instruction it returns
* the predicted address of the next instruction. This works for all instructions.
* If it isn't in the prediction table, the following address is returned.
*/
class BranchPredictor(size: Int) extends Module {
val io = IO(new Bundle {
val PC = Input(UInt(32.W))
val nextAddr = Output(UInt(32.W))

// Set these after executing a branch instruction
// to update the prediction table
val update = Input(new Bundle {
val PC = UInt(32.W)
val branchAddr = UInt(32.W)
val branchTaken = Bool()
val enable = Bool()
})

val misjump = Output(Bool())
})

def getIndex(addr: UInt): UInt = {
// A chisel version of getTag
val bitsLeft = addr.getWidth - (log2Ceil(size) + 2)
val bitsRight = addr.getWidth - log2Ceil(size)
(addr << bitsLeft) >> bitsRight
}

class PredictionBundle extends Bundle {
val state = UInt(2.W)
val addr = UInt(32.W)
val branchAddr = UInt(32.W)
}

val table = Mem(size, new PredictionBundle)

// Whether the previous prediction was wrong. Is set on update.
// The CPU should clear the barriers when this is true.
val misjump = WireInit(false.B)
io.misjump := misjump

/* Predict the next address */

when (misjump) {
// We predicted incorrectly. Use the correct address now
io.nextAddr := Mux(io.update.branchTaken, io.update.branchAddr, io.update.PC + 4.U)
}.otherwise {
val entry = table(getIndex(io.PC))
when (entry.addr === io.PC && entry.state >= 2.U) {
// Take the branch
io.nextAddr := entry.branchAddr
}.otherwise {
// The next address is by default the following instruction
io.nextAddr := io.PC + 4.U
}
}

/* Update the prediction */

when (io.update.enable) {
val entry = table(getIndex(io.update.PC))

when (entry.addr === io.update.PC && entry.branchAddr === io.update.branchAddr) {
// Call a misjump if we predicted a jump when we shouldn't, or the other way around
misjump := entry.state >= 2.U ^ io.update.branchTaken

// Update the state
when (io.update.branchTaken) {
switch (entry.state) {
is (0.U) { entry.state := 1.U }
is (1.U) { entry.state := 3.U }
is (2.U) { entry.state := 3.U }
}
}.otherwise {
switch (entry.state) {
is (1.U) { entry.state := 0.U }
is (2.U) { entry.state := 0.U }
is (3.U) { entry.state := 2.U }
}
}
}.otherwise {
// Add a new entry to the table
entry.addr := io.update.PC
entry.branchAddr := io.update.branchAddr
entry.state := Mux(io.update.branchTaken, 2.U, 1.U)

// We only call a misjump if we jump now, or if we jumped to the wrong address on the prediction.
misjump := io.update.branchTaken || (entry.addr === io.update.PC && entry.state >= 2.U)
}
}
}

+ 22
- 14
src/main/scala/CPU.scala 查看文件

@@ -48,11 +48,6 @@ class CPU extends MultiIOModule {
testHarness.currentPC := IF.testHarness.PC


val NORMAL = 0.U
val LOAD_FREEZE = 1.U
val state = Reg(UInt(4.W))
state := NORMAL

val freeze = Wire(Bool())
freeze := false.B
IFBarrier.freeze := freeze
@@ -78,9 +73,22 @@ class CPU extends MultiIOModule {
forwarder.WB := WBBarrier.out
forwarder.writeback := writeback

val predictor = Module(new BranchPredictor(128)).io
predictor.update.enable := false.B
predictor.update.PC := 0.U
predictor.update.branchAddr := 0.U
predictor.update.branchTaken := false.B

when (predictor.misjump) {
// Clear the barriers on a misjump
IDBarrier.clear := true.B
EXBarrier.clear := true.B
}

// Stage 1
IFBarrier.in := IF.io
IF.io.addr := IFBarrier.out.PC + 4.U
predictor.PC := IFBarrier.out.PC
IF.io.addr := predictor.nextAddr

// Stage 2
ID.io.instruction := IFBarrier.out.instruction
@@ -94,7 +102,7 @@ class CPU extends MultiIOModule {
IDBarrier.in.ALUop := ID.io.ALUop

// Stage 3
when (forwarder.loadFreeze && state =/= LOAD_FREEZE) {
when (forwarder.loadFreeze && !RegNext(freeze)) {
// Freeze the IF and ID barriers, repeating the instruction.
// EX is cleared, so the instruction isn't computed and written twice.
freeze := true.B
@@ -102,7 +110,6 @@ class CPU extends MultiIOModule {
WBBarrier.freeze := false.B
EXBarrier.freeze := false.B
EXBarrier.clear := true.B
state := LOAD_FREEZE
}

EX.io.PC := IDBarrier.out.PC
@@ -121,6 +128,13 @@ class CPU extends MultiIOModule {
EXBarrier.in.branch := EX.io.branch

// Stage 4
when (EXBarrier.out.controlSignals.jump || EXBarrier.out.controlSignals.branch) {
predictor.update.enable := true.B
predictor.update.PC := EXBarrier.out.PC
predictor.update.branchAddr := EXBarrier.out.result
predictor.update.branchTaken := EXBarrier.out.branch
}

MEM.io.dataIn := forwarder.memWrite
MEM.io.writeEnable := EXBarrier.out.controlSignals.memWrite
MEM.io.dataAddress := EXBarrier.out.result
@@ -133,12 +147,6 @@ class CPU extends MultiIOModule {
MEMBarrier.in.dataOut := MEM.io.dataOut

// Stage 5
when (EXBarrier.out.branch) {
IF.io.addr := EXBarrier.out.result
IDBarrier.clear := true.B
EXBarrier.clear := true.B
}

ID.io.writeEnable := MEMBarrier.out.controlSignals.regWrite
ID.io.writeAddr := MEMBarrier.out.instruction.registerRd
ID.io.writeData := writeback


正在加载...
取消
保存