提交 879035f6 作者: zoujr

BPU: Remove is_* in BranchPrediction

上级 733e7516
......@@ -193,7 +193,7 @@ class FakePredictor(implicit p: Parameters) extends BasePredictor {
}
class BpuToFtqIO(implicit p: Parameters) extends XSBundle {
val resp = DecoupledIO(new BPUToFtqBundle())
val resp = DecoupledIO(new BpuToFtqBundle())
}
class PredictorIO(implicit p: Parameters) extends XSBundle {
......@@ -258,7 +258,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst {
}
// when(toFtq_fire) {
// final_gh := s3_gh.update(io.bpu_to_ftq.resp.bits.preds.is_br.reduce(_||_) && !io.bpu_to_ftq.resp.bits.preds.taken,
// final_gh := s3_gh.update(io.bpu_to_ftq.resp.bits.ftb_entry.brValids.reduce(_||_) && !io.bpu_to_ftq.resp.bits.preds.taken,
// io.bpu_to_ftq.resp.bits.preds.taken)
// }
......@@ -321,7 +321,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst {
predictors.io.s3_fire := s3_fire
io.bpu_to_ftq.resp.valid := s3_valid && !io.ftq_to_bpu.redirect.valid
io.bpu_to_ftq.resp.bits := BPUToFtqBundle(predictors.io.out.resp.s3)
io.bpu_to_ftq.resp.bits := BpuToFtqBundle(predictors.io.out.resp.s3)
io.bpu_to_ftq.resp.bits.meta := predictors.io.out.s3_meta
io.bpu_to_ftq.resp.bits.ghist := s3_ghist
......@@ -331,14 +331,14 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst {
// History manage
// s1
val s1_sawNTBr = Mux(resp.s1.preds.hit,
resp.s1.preds.is_br.zip(resp.s1.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
resp.s1.ftb_entry.brValids.zip(resp.s1.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
false.B)
val s1_takenOnBr = resp.s1.preds.real_br_taken_mask.asUInt =/= 0.U
val s1_takenOnBr = resp.s1.real_br_taken_mask.asUInt =/= 0.U
val s1_predicted_ghist = s1_ghist.update(s1_sawNTBr, s1_takenOnBr)
XSDebug(p"s1_sawNTBR=${s1_sawNTBr}, resp.s1.hit=${resp.s1.preds.hit}, is_br=${Binary(resp.s1.preds.is_br.asUInt)}, taken_mask=${Binary(resp.s1.preds.taken_mask.asUInt)}\n")
XSDebug(p"s1_takenOnBr=$s1_takenOnBr, real_taken_mask=${Binary(resp.s1.preds.real_taken_mask.asUInt)}\n")
XSDebug(p"s1_sawNTBR=${s1_sawNTBr}, resp.s1.hit=${resp.s1.preds.hit}, is_br=${Binary(resp.s1.ftb_entry.brValids.asUInt)}, taken_mask=${Binary(resp.s1.preds.taken_mask.asUInt)}\n")
XSDebug(p"s1_takenOnBr=$s1_takenOnBr, real_taken_mask=${Binary(resp.s1.real_taken_mask.asUInt)}\n")
XSDebug(p"s1_predicted_ghist=${Binary(s1_predicted_ghist.asUInt)}\n")
// when(s1_valid) {
// s0_ghist := s1_predicted_ghist
......@@ -351,9 +351,9 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst {
// s2
val s2_sawNTBr = Mux(resp.s2.preds.hit,
resp.s2.preds.is_br.zip(resp.s2.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
resp.s2.ftb_entry.brValids.zip(resp.s2.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
false.B)
val s2_takenOnBr = resp.s2.preds.real_br_taken_mask.asUInt =/= 0.U
val s2_takenOnBr = resp.s2.real_br_taken_mask.asUInt =/= 0.U
val s2_predicted_ghist = s2_ghist.update(s2_sawNTBr, s2_takenOnBr)
val s2_correct_s1_ghist = s1_ghist =/= s2_predicted_ghist
......@@ -375,9 +375,9 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst {
// s3
val s3_sawNTBr = Mux(resp.s3.preds.hit,
resp.s3.preds.is_br.zip(resp.s3.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
resp.s3.ftb_entry.brValids.zip(resp.s3.preds.taken_mask).map{ case (b, t) => b && !t }.reduce(_||_),
false.B)
val s3_takenOnBr = resp.s3.preds.real_br_taken_mask.asUInt =/= 0.U
val s3_takenOnBr = resp.s3.real_br_taken_mask.asUInt =/= 0.U
val s3_predicted_ghist = s3_ghist.update(s3_sawNTBr, s3_takenOnBr)
val s3_correct_s2_ghist = s2_ghist =/= s3_predicted_ghist
val s3_correct_s1_ghist = s1_ghist =/= s3_predicted_ghist
......
......@@ -85,7 +85,7 @@ class BIM(implicit p: Parameters) extends BasePredictor with BimParams with BPUU
))
val update_mask = LowerMask(PriorityEncoderOH(update.preds.taken_mask.asUInt))
val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.preds.is_br(i) && update_mask(i)))
val need_to_update = VecInit((0 until numBr).map(i => u_valid && update.ftb_entry.brValids(i) && update_mask(i)))
when (reset.asBool) { wrbypass_ctr_valids.foreach(_ := VecInit(Seq.fill(numBr)(false.B)))}
......@@ -129,7 +129,7 @@ class BIM(implicit p: Parameters) extends BasePredictor with BimParams with BPUU
XSDebug(latch_s0_fire, "last_cycle req %d: ctr=%b\n", i.U, s1_read(i))
}
XSDebug(u_valid, "update_pc=%x, update_idx=%d, is_br=%b\n", update.pc, u_idx, update.preds.is_br.asUInt)
XSDebug(u_valid, "update_pc=%x, update_idx=%d, is_br=%b\n", update.pc, u_idx, update.ftb_entry.brValids.asUInt)
XSDebug(u_valid, "newTakens=%b\n", newTakens.asUInt)
......
......@@ -192,11 +192,11 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU
val s1_latch_call_is_rvc = DontCare // TODO: modify when add RAS
io.out.resp.s2.preds.taken_mask := io.in.bits.resp_in(0).s2.preds.taken_mask
io.out.resp.s2.preds.is_br := ftb_entry.brValids
io.out.resp.s2.preds.is_jal := ftb_entry.jmpValid && !ftb_entry.isJalr
io.out.resp.s2.preds.is_jalr := ftb_entry.isJalr
io.out.resp.s2.preds.is_call := ftb_entry.isCall
io.out.resp.s2.preds.is_ret := ftb_entry.isRet
// io.out.resp.s2.preds.is_br := ftb_entry.brValids
// io.out.resp.s2.preds.is_jal := ftb_entry.jmpValid && !ftb_entry.isJalr
// io.out.resp.s2.preds.is_jalr := ftb_entry.isJalr
// io.out.resp.s2.preds.is_call := ftb_entry.isCall
// io.out.resp.s2.preds.is_ret := ftb_entry.isRet
io.out.resp.s2.preds.hit := s2_hit
io.out.resp.s2.preds.target := s2_target
......@@ -248,7 +248,7 @@ class FTB(implicit p: Parameters) extends BasePredictor with FTBParams with BPUU
XSDebug("req_v=%b, req_pc=%x, ready=%b (resp at next cycle)\n", io.s0_fire, s0_pc, ftbBank.io.read_pc.ready)
XSDebug("s2_hit=%b, hit_way=%b\n", s2_hit, writeWay.asUInt)
XSDebug("s2_taken_mask=%b, s2_real_taken_mask=%b\n",
io.in.bits.resp_in(0).s2.preds.taken_mask.asUInt, io.out.resp.s2.preds.real_taken_mask().asUInt)
io.in.bits.resp_in(0).s2.preds.taken_mask.asUInt, io.out.resp.s2.real_taken_mask().asUInt)
XSDebug("s2_target=%x\n", s2_target)
ftb_entry.display(true.B)
......
......@@ -85,29 +85,16 @@ class TableAddr(val idxBits: Int, val banks: Int)(implicit p: Parameters) extend
}
class BranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUConst {
val taken_mask = Vec(numBr, Bool())
val is_br = Vec(numBr, Bool())
val is_jal = Bool()
val is_jalr = Bool()
val is_call = Bool()
val is_ret = Bool()
val call_is_rvc = Bool()
// val is_br = Vec(numBr, Bool())
// val is_jal = Bool()
// val is_jalr = Bool()
// val is_call = Bool()
// val is_ret = Bool()
// val call_is_rvc = Bool()
val target = UInt(VAddrBits.W)
val hit = Bool()
def taken = taken_mask.reduce(_||_) // || (is_jal || is_jalr)
def real_taken_mask(): Vec[Bool] = {
Mux(hit,
VecInit(taken_mask.zip(is_br).map{ case(m, b) => m && b } :+ (is_jal || is_jalr)),
VecInit(Seq.fill(numBr+1)(false.B)))
}
def real_br_taken_mask(): Vec[Bool] = {
Mux(hit,
VecInit(taken_mask.zip(is_br).map{ case(m, b) => m && b }),
VecInit(Seq.fill(numBr)(false.B)))
}
def hit_taken_on_call = !VecInit(real_taken_mask.take(numBr)).asUInt.orR && hit && is_call
def hit_taken_on_ret = !VecInit(real_taken_mask.take(numBr)).asUInt.orR && hit && is_ret
// override def toPrintable: Printable = {
// p"-----------BranchPrediction----------- " +
......@@ -118,7 +105,7 @@ class BranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUConst
// }
def display(cond: Bool): Unit = {
XSDebug(cond, p"[taken_mask] ${Binary(taken_mask.asUInt)} [is_br] ${Binary(is_br.asUInt)} [is_jal] $is_jal [is_jalr] $is_jalr [is_call] $is_call [is_ret] $is_ret\n")
XSDebug(cond, p"[taken_mask] ${Binary(taken_mask.asUInt)}\n")
XSDebug(cond, p"[hit] $hit [target] ${Hexadecimal(target)}\n")
}
}
......@@ -136,6 +123,20 @@ class BranchPredictionBundle(implicit p: Parameters) extends XSBundle with HasBP
val ftb_entry = new FTBEntry() // TODO: Send this entry to ftq
/** Taken mask over the numBr branch slots plus one trailing jump slot.
  *
  * A branch slot is reported taken only when its `preds.taken_mask` bit and the
  * matching `ftb_entry.brValids` bit are both set; the trailing slot mirrors
  * `ftb_entry.jmpValid`. When `preds.hit` is low, every slot is false.
  *
  * @return a Vec of numBr+1 Bools
  */
def real_taken_mask(): Vec[Bool] = {
  // Gate each predicted-taken bit with the validity of its branch slot.
  val brTaken = preds.taken_mask.zip(ftb_entry.brValids).map{ case (taken, valid) => taken && valid }
  val maskOnHit = VecInit(brTaken :+ ftb_entry.jmpValid)
  val maskOnMiss = VecInit(Seq.fill(numBr+1)(false.B))
  Mux(preds.hit, maskOnHit, maskOnMiss)
}
/** Taken mask over the numBr branch slots only (no jump slot).
  *
  * A slot is reported taken only when its `preds.taken_mask` bit and the
  * matching `ftb_entry.brValids` bit are both set. When `preds.hit` is low,
  * every slot is false.
  *
  * @return a Vec of numBr Bools
  */
def real_br_taken_mask(): Vec[Bool] = {
  // Gate each predicted-taken bit with the validity of its branch slot.
  val brTaken = preds.taken_mask.zip(ftb_entry.brValids).map{ case (taken, valid) => taken && valid }
  val maskOnHit = VecInit(brTaken)
  val maskOnMiss = VecInit(Seq.fill(numBr)(false.B))
  Mux(preds.hit, maskOnHit, maskOnMiss)
}
/** High when the prediction hits, the FTB entry is marked as a call, and none
  * of the first numBr slots of `real_taken_mask` is taken.
  */
def hit_taken_on_call = {
  val anyBrTaken = VecInit(real_taken_mask().take(numBr)).asUInt.orR
  preds.hit && ftb_entry.isCall && !anyBrTaken
}
/** High when the prediction hits, the FTB entry is marked as a return, and none
  * of the first numBr slots of `real_taken_mask` is taken.
  */
def hit_taken_on_ret = {
  val anyBrTaken = VecInit(real_taken_mask().take(numBr)).asUInt.orR
  preds.hit && ftb_entry.isRet && !anyBrTaken
}
// override def toPrintable: Printable = {
// p"-----------BranchPredictionBundle----------- " +
// p"[pc] ${Hexadecimal(pc)} " +
......@@ -159,13 +160,13 @@ class BranchPredictionResp(implicit p: Parameters) extends XSBundle with HasBPUC
val s3 = new BranchPredictionBundle()
}
class BPUToFtqBundle(implicit p: Parameters) extends BranchPredictionBundle with HasBPUConst {
class BpuToFtqBundle(implicit p: Parameters) extends BranchPredictionBundle with HasBPUConst {
val meta = UInt(MaxMetaLength.W)
}
object BPUToFtqBundle {
def apply(resp: BranchPredictionBundle)(implicit p: Parameters): BPUToFtqBundle = {
val e = Wire(new BPUToFtqBundle())
object BpuToFtqBundle {
def apply(resp: BranchPredictionBundle)(implicit p: Parameters): BpuToFtqBundle = {
val e = Wire(new BpuToFtqBundle())
e.pc := resp.pc
e.preds := resp.preds
e.ghist := resp.ghist
......
......@@ -444,10 +444,10 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe
entry_fetch_status(enqIdx) := f_to_send
commitStateQueue(enqIdx) := VecInit(Seq.fill(PredictWidth)(c_invalid))
entry_hit_status(enqIdx) := Mux(io.fromBpu.resp.bits.preds.hit, h_hit, h_not_hit) // pd may change it to h_false_hit
enq_cfiIndex.valid := preds.real_taken_mask.asUInt.orR
enq_cfiIndex.valid := io.fromBpu.resp.bits.real_taken_mask.asUInt.orR
// when no slot is predicted taken, set cfiIndex to PredictWidth-1
enq_cfiIndex.bits := ParallelPriorityMux(preds.real_taken_mask, ftb_entry.getOffsetVec) |
Fill(log2Ceil(PredictWidth), (!preds.real_taken_mask.asUInt.orR).asUInt)
enq_cfiIndex.bits := ParallelPriorityMux(io.fromBpu.resp.bits.real_taken_mask, ftb_entry.getOffsetVec) |
Fill(log2Ceil(PredictWidth), (!io.fromBpu.resp.bits.real_taken_mask.asUInt.orR).asUInt)
cfiIndex_vec(enqIdx) := enq_cfiIndex
mispredict_vec(enqIdx) := WireInit(VecInit(Seq.fill(PredictWidth)(false.B)))
update_target(enqIdx) := preds.target
......@@ -817,12 +817,12 @@ class Ftq(implicit p: Parameters) extends XSModule with HasCircularQueuePtrHelpe
}
val preds = update.preds
preds.is_br := update_ftb_entry.brValids
preds.is_jal := update_ftb_entry.jmpValid && !update_ftb_entry.isJalr
preds.is_jalr := update_ftb_entry.jmpValid && update_ftb_entry.isJalr
preds.is_call := update_ftb_entry.jmpValid && update_ftb_entry.isCall
preds.is_ret := update_ftb_entry.jmpValid && update_ftb_entry.isRet
preds.call_is_rvc := update_ftb_entry.jmpValid && update_ftb_entry.isCall && update_ftb_entry.last_is_rvc
// preds.is_br := update_ftb_entry.brValids
// preds.is_jal := update_ftb_entry.jmpValid && !update_ftb_entry.isJalr
// preds.is_jalr := update_ftb_entry.jmpValid && update_ftb_entry.isJalr
// preds.is_call := update_ftb_entry.jmpValid && update_ftb_entry.isCall
// preds.is_ret := update_ftb_entry.jmpValid && update_ftb_entry.isRet
// preds.call_is_rvc := update_ftb_entry.jmpValid && update_ftb_entry.isCall && update_ftb_entry.last_is_rvc
preds.target := commit_target
preds.taken_mask := ftbEntryGen.taken_mask
......
......@@ -158,8 +158,8 @@ class RAS(implicit p: Parameters) extends BasePredictor {
val spec_top_addr = spec_ras.top.retAddr
// confirm that the call/ret is the taken cfi
spec_push := io.s3_fire && io.in.bits.resp_in(0).s3.preds.hit_taken_on_call
spec_pop := io.s3_fire && io.in.bits.resp_in(0).s3.preds.hit_taken_on_ret
spec_push := io.s3_fire && io.in.bits.resp_in(0).s3.hit_taken_on_call
spec_pop := io.s3_fire && io.in.bits.resp_in(0).s3.hit_taken_on_ret
when (spec_pop) {
io.out.resp.s3.preds.target := spec_top_addr
......
......@@ -523,8 +523,8 @@ class Tage(implicit p: Parameters) extends BaseTage {
val fallThruAddr = getFallThroughAddr(s3_pc, ftb_entry.carry, ftb_entry.pftAddr)
when(ftb_hit) {
io.out.resp.s3.preds.target := Mux((resp_s3.preds.real_taken_mask.asUInt & ftb_entry.brValids.asUInt) =/= 0.U,
PriorityMux(resp_s3.preds.real_taken_mask.asUInt & ftb_entry.brValids.asUInt, ftb_entry.brTargets),
io.out.resp.s3.preds.target := Mux((resp_s3.real_taken_mask.asUInt & ftb_entry.brValids.asUInt) =/= 0.U,
PriorityMux(resp_s3.real_taken_mask.asUInt & ftb_entry.brValids.asUInt, ftb_entry.brTargets),
Mux(ftb_entry.jmpValid, ftb_entry.jmpTarget, fallThruAddr))
}
......
......@@ -229,13 +229,15 @@ class MicroBTB(implicit p: Parameters) extends BasePredictor
io.out.resp.s1.pc := s1_pc
io.out.resp.s1.preds.target := Mux(banks.read_hit, read_resps.target, s1_pc + (FetchWidth*4).U)
io.out.resp.s1.preds.taken_mask := read_resps.taken_mask
io.out.resp.s1.preds.is_br := read_resps.brValids
// io.out.resp.s1.preds.is_br := read_resps.brValids
io.out.resp.s1.preds.hit := banks.read_hit
// io.out.bits.resp.s1.preds.is_jal := read_resps.jmpValid && !(read_resps.isCall || read_resps.isRet || read_resps.isJalr)
// io.out.bits.resp.s1.preds.is_jalr := read_resps.jmpValid && read_resps.isJalr
// io.out.bits.resp.s1.preds.is_call := read_resps.jmpValid && read_resps.isCall
// io.out.bits.resp.s1.preds.is_ret := read_resps.jmpValid && read_resps.isRet
// io.out.bits.resp.s1.preds.call_is_rvc := read_resps.last_is_rvc
io.out.resp.s1.ftb_entry := DontCare
io.out.resp.s1.ftb_entry.brValids := read_resps.brValids
io.out.s3_meta := RegEnable(RegEnable(read_resps.asUInt, io.s1_fire), io.s2_fire)
// Update logic
......@@ -250,7 +252,7 @@ class MicroBTB(implicit p: Parameters) extends BasePredictor
// val u_target_lower = update.preds.target(lowerBitSize-1+instOffsetBits, instOffsetBits)
val data_write_valid = u_valid && u_taken
val meta_write_valid = u_valid && (u_taken || update.preds.is_br.reduce(_||_))
val meta_write_valid = u_valid && (u_taken || update.ftb_entry.brValids.reduce(_||_))
val update_write_datas = Wire(new MicroBTBData)
val update_write_metas = Wire(new MicroBTBMeta)
......@@ -260,8 +262,8 @@ class MicroBTB(implicit p: Parameters) extends BasePredictor
update_write_metas.valid := true.B
update_write_metas.tag := u_tag
// brOffset
update_write_metas.brValids := update.preds.is_br
update_write_metas.jmpValid := update.preds.is_jal || update.preds.is_jalr // || update.preds.is_call || update.preds.is_ret
update_write_metas.brValids := update.ftb_entry.brValids
update_write_metas.jmpValid := update.ftb_entry.jmpValid
// isJalr
// isCall
// isRet
......@@ -287,7 +289,7 @@ class MicroBTB(implicit p: Parameters) extends BasePredictor
XSDebug(u_valid, "Update from ftq\n")
XSDebug(u_valid, "update_pc=%x, tag=%x\n", u_pc, getTag(u_pc))
XSDebug(u_valid, "taken_mask=%b, brValids=%b, jmpValid=%b\n",
u_taken_mask.asUInt, update.preds.is_br.asUInt, update.preds.is_jal || update.preds.is_jalr)
u_taken_mask.asUInt, update.ftb_entry.brValids.asUInt, update.ftb_entry.jmpValid)
XSPerfAccumulate("ubtb_read_hits", RegNext(io.s1_fire) && banks.read_hit)
XSPerfAccumulate("ubtb_read_misses", RegNext(io.s1_fire) && !banks.read_hit)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册