提交 f227c0cc 编写于 作者: L Lingrui98

BPU, TAGE: use FakeTage when disable BPD, implement a BaseTage class

上级 160e49bb
......@@ -30,7 +30,7 @@ trait HasXSParameter {
val FetchWidth = 8
val PredictWidth = FetchWidth * 2
val EnableBPU = true
val EnableBPD = true // enable backing predictor(like Tage) in BPUStage3
val EnableBPD = false // enable backing predictor(like Tage) in BPUStage3
val EnableRAS = false
val HistoryLength = 64
val ExtHistoryLength = HistoryLength * 2
......
......@@ -131,17 +131,22 @@ class BPUStage extends XSModule {
// get the last valid inst
// val lastValidPos = MuxCase(0.U, (PredictWidth-1 to 0).map(i => (inLatch.mask(i), i.U)))
val lastValidPos = PriorityMux(Reverse(inLatch.mask), (PredictWidth-1 to 0 by -1).map(i => i.U))
val lastHit = WireInit(false.B)
val lastIsRVC = WireInit(false.B)
// val lastValidPos = WireInit(0.U(log2Up(PredictWidth).W))
// for (i <- 0 until PredictWidth) {
// when (inLatch.mask(i)) { lastValidPos := i.U }
// }
val target = WireInit(0.U(VAddrBits.W))
val targetSrc = VecInit((0 until PredictWidth).map(i => 0.U(VAddrBits.W)))
val target = Mux(taken, targetSrc(jmpIdx), npc(inLatch.pc, PopCount(inLatch.mask)))
io.pred.bits <> DontCare
io.pred.bits.redirect := target =/= inLatch.target
io.pred.bits.taken := taken
io.pred.bits.jmpIdx := jmpIdx
io.pred.bits.hasNotTakenBrs := hasNTBr
io.pred.bits.target := target
io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken ) && !lastIsRVC && lastHit
io.out.bits <> DontCare
io.out.bits.pc := inLatch.pc
......@@ -181,7 +186,7 @@ class BPUStage1 extends BPUStage {
.elsewhen(outFire) { predValid := false.B }
.otherwise { predValid := predValid }
io.in.ready := !predValid || io.out.fire() && io.pred.fire() || io.flush
io.out.valid := predValid
// io.out.valid := predValid
// ubtb is accessed with inLatch pc in s1,
// so we use io.in instead of inLatch
......@@ -189,10 +194,10 @@ class BPUStage1 extends BPUStage {
// the read operation is already masked, so we do not need to mask here
takens := VecInit((0 until PredictWidth).map(i => ubtbResp.hits(i) && ubtbResp.takens(i)))
notTakens := VecInit((0 until PredictWidth).map(i => ubtbResp.hits(i) && ubtbResp.notTakens(i)))
target := Mux(taken, ubtbResp.targets(jmpIdx), npc(inLatch.pc, PopCount(inLatch.mask)))
targetSrc := ubtbResp.targets
io.pred.bits.redirect := taken
io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken ) && !ubtbResp.is_RVC(lastValidPos) && ubtbResp.hits(lastValidPos)
lastIsRVC := ubtbResp.is_RVC(lastValidPos)
lastHit := ubtbResp.hits(lastValidPos)
// resp and brInfo are from the components,
// so it does not need to be latched
......@@ -207,10 +212,10 @@ class BPUStage2 extends BPUStage {
val bimResp = inLatch.resp.bim
takens := VecInit((0 until PredictWidth).map(i => btbResp.hits(i) && (btbResp.types(i) === BrType.branch && bimResp.ctrs(i)(1) || btbResp.types(i) === BrType.jal)))
notTakens := VecInit((0 until PredictWidth).map(i => btbResp.hits(i) && btbResp.types(i) === BrType.branch && !bimResp.ctrs(i)(1)))
target := Mux(taken, btbResp.targets(jmpIdx), npc(inLatch.pc, PopCount(inLatch.mask)))
targetSrc := btbResp.targets
io.pred.bits.redirect := target =/= inLatch.target
io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken ) && !btbResp.isRVC(lastValidPos) && btbResp.hits(lastValidPos)
lastIsRVC := btbResp.isRVC(lastValidPos)
lastHit := btbResp.hits(lastValidPos)
}
class BPUStage3 extends BPUStage {
......@@ -232,11 +237,11 @@ class BPUStage3 extends BPUStage {
val brs = pdMask & Reverse(Cat(pds.map(_.isBr)))
val jals = pdMask & Reverse(Cat(pds.map(_.isJal)))
val jalrs = pdMask & Reverse(Cat(pds.map(_.isJalr)))
val calls = pdMask & Reverse(Cat(pds.map(_.isCall)))
val rets = pdMask & Reverse(Cat(pds.map(_.isRet)))
// val calls = pdMask & Reverse(Cat(pds.map(_.isCall)))
// val rets = pdMask & Reverse(Cat(pds.map(_.isRet)))
val callIdx = PriorityEncoder(calls)
val retIdx = PriorityEncoder(rets)
// val callIdx = PriorityEncoder(calls)
// val retIdx = PriorityEncoder(rets)
val brTakens =
if (EnableBPD) {
......@@ -250,11 +255,12 @@ class BPUStage3 extends BPUStage {
// Whether should we count in branches that are not recorded in btb?
// PS: Currently counted in. Whenever tage does not provide a valid
// taken prediction, the branch is counted as a not taken branch
notTakens := VecInit((0 until PredictWidth).map(i => brs(i) && !tageValidTakens(i)))
target := Mux(taken, inLatch.resp.btb.targets(jmpIdx), npc(inLatch.pc, PopCount(inLatch.mask)))
notTakens := (if (EnableBPD) { VecInit((0 until PredictWidth).map(i => brs(i) && !tageValidTakens(i)))}
else { VecInit((0 until PredictWidth).map(i => brs(i) && bimTakens(i)))})
targetSrc := inLatch.resp.btb.targets
io.pred.bits.redirect := target =/= inLatch.target
io.pred.bits.saveHalfRVI := ((lastValidPos === jmpIdx && taken) || !taken ) && !pds(lastValidPos).isRVC && pdMask(lastValidPos)
lastIsRVC := pds(lastValidPos).isRVC
lastHit := pdMask(lastValidPos)
// Wrap tage resp and tage meta in
// This is ugly
......@@ -275,7 +281,8 @@ trait BranchPredictorComponents extends HasXSParameter {
val ubtb = Module(new MicroBTB)
val btb = Module(new BTB)
val bim = Module(new BIM)
val tage = Module(new Tage)
val tage = (if(EnableBPD) { Module(new Tage) }
else { Module(new FakeTage) })
val preds = Seq(ubtb, btb, bim, tage)
preds.map(_.io := DontCare)
}
......@@ -316,6 +323,8 @@ abstract class BaseBPU extends XSModule with BranchPredictorComponents{
val branchInfo = Decoupled(Vec(PredictWidth, new BranchInfo))
})
def npc(pc: UInt, instCount: UInt) = pc + (instCount << 1.U)
preds.map(_.io.update <> io.inOrderBrInfo)
val s1 = Module(new BPUStage1)
......@@ -413,7 +422,7 @@ class BPU extends BaseBPU {
s1.io.in.valid := io.in.valid
s1.io.in.bits.pc := io.in.bits.pc
s1.io.in.bits.mask := io.in.bits.inMask
s1.io.in.bits.target := DontCare
s1.io.in.bits.target := npc(s1_inLatch.bits.pc, PopCount(s1_inLatch.bits.inMask)) // Deault target npc
s1.io.in.bits.resp := s1_resp_in
s1.io.in.bits.brInfo <> s1_brInfo_in
......
......@@ -258,11 +258,12 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
}
class FakeTAGE extends BasePredictor with HasTageParameter {
abstract class BaseTage extends BasePredictor with HasTageParameter {
class TAGEResp extends Resp {
val takens = Vec(PredictWidth, ValidUndirectioned(Bool()))
val takens = Vec(PredictWidth, Bool())
val hits = Vec(PredictWidth, Bool())
}
class TAGEMeta extends Meta {
class TAGEMeta extends Meta{
}
class FromBIM extends FromOthers {
val ctrs = Vec(PredictWidth, UInt(2.W))
......@@ -275,30 +276,15 @@ class FakeTAGE extends BasePredictor with HasTageParameter {
}
override val io = IO(new TageIO)
}
class FakeTage extends BaseTage {
io.resp <> DontCare
io.meta <> DontCare
}
class Tage extends BasePredictor with HasTageParameter {
class TAGEResp extends Resp {
val takens = Vec(PredictWidth, Bool())
val hits = Vec(PredictWidth, Bool())
}
class TAGEMeta extends Meta{
}
class FromBIM extends FromOthers {
val ctrs = Vec(PredictWidth, UInt(2.W))
}
class TageIO extends DefaultBasePredictorIO {
val resp = Output(new TAGEResp)
val meta = Output(Vec(PredictWidth, new TageMeta))
val bim = Input(new FromBIM)
val s3Fire = Input(Bool())
}
override val io = IO(new TageIO)
class Tage extends BaseTage {
val tables = TableInfo.map {
case (nRows, histLen, tagLen) => {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册