Commit 8d22bbae authored by Z zhanglinjuan

bpu: support prediction of RVC

Parent caa40861
@@ -118,6 +118,7 @@ class BPUStage1 extends XSModule {
val btbValids = btb.io.out.hits
val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
val btbTypes = VecInit(btb.io.out.dEntries.map(_._type))
val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))
val jbtac = Module(new JBTAC)
@@ -158,6 +159,10 @@ class BPUStage1 extends XSModule {
updateGhr := io.flush || io.s1OutPred.bits.redirect || RegNext(io.in.pc.fire) && (btbNotTakens.asUInt & maskLatch).orR.asBool
val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
// if backend redirects, restore history from backend;
// if stage3 redirects, restore history from stage3;
// if stage1 redirects, speculatively update history;
// if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r._type === BTBtype.B && !r.taken),
Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
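The nested Mux above encodes the four-way restore priority described in the comments (the final fall-through arm is cut off by the hunk). A rough plain-Scala model of that priority, with illustrative names that are not from the source:

```scala
// Plain-Scala sketch of the history-restore priority above; not the RTL,
// parameter names are illustrative only.
object GhrRestoreModel {
  def nextGhr(backendFlush: Boolean, backendHist: BigInt, backendNotTakenBranch: Boolean,
              s3Flush: Boolean, s3Taken: Boolean, s3RollBackHist: BigInt,
              s1Redirect: Boolean, s1HistAtTakenSlot: BigInt,
              s1NotTakenCount: Int, curGhr: BigInt): BigInt =
    if (backendFlush) {
      // restore from backend: append 0 only for a not-taken branch, else 1
      (backendHist << 1) | (if (backendNotTakenBranch) 0 else 1)
    } else if (s3Flush) {
      // restore from stage 3's rollback history
      if (s3Taken) (s3RollBackHist << 1) | 1 else s3RollBackHist
    } else if (s1Redirect) {
      // speculative update: take the history at the taken slot and append 1
      (s1HistAtTakenSlot << 1) | 1
    } else {
      // no redirect: shift in one 0 per not-taken branch in the packet
      curGhr << s1NotTakenCount
    }
}
```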
@@ -169,8 +174,10 @@ class BPUStage1 extends XSModule {
io.s1OutPred.valid := io.out.valid
io.s1OutPred.bits.redirect := btbTaken || jbtacHit
io.s1OutPred.bits.instrValid := Mux(io.s1OutPred.bits.redirect, LowerMask(takenIdx, PredictWidth), maskLatch).asTypeOf(Vec(PredictWidth, Bool()))
io.s1OutPred.bits.target := Mux(brJumpIdx === takenIdx, btbTakenTarget, Mux(indirectIdx === takenIdx, jbtacTarget, pcLatch + PopCount(maskLatch) << 1.U))
io.s1OutPred.bits.instrValid := Mux(!io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump, maskLatch,
Mux(!btbIsRVCs(OHToUInt(takenIdx)), LowerMask(takenIdx << 1.U, PredictWidth),
LowerMask(takenIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
// io.s1OutPred.bits.btbVictimWay := btbWriteWay
io.s1OutPred.bits.predCtr := btbCtrs
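The RVC-aware instrValid above cuts the valid mask at the taken slot; since a non-compressed instruction spans two 16-bit slots, its mask must extend one slot further. A minimal plain-Scala sketch of that selection (assumed helper names, not the source's LowerMask):

```scala
// Sketch only: models the RVC-aware valid-mask selection above.
object InstrValidModel {
  // ones at and below the single set bit of the one-hot value
  def lowerMask(oneHot: BigInt, width: Int): BigInt = {
    var m = oneHot
    for (_ <- 1 until width) m |= (m >> 1)
    m
  }
  def instrValid(redirect: Boolean, lateJump: Boolean, fetchMask: BigInt,
                 takenIdxOH: BigInt, takenIsRVC: Boolean, predictWidth: Int): BigInt =
    if (!redirect || lateJump) fetchMask                            // keep the whole packet
    else if (!takenIsRVC) lowerMask(takenIdxOH << 1, predictWidth)  // 32-bit: keep both halves
    else lowerMask(takenIdxOH, predictWidth)                        // RVC: keep up to its slot
}
```

For example, with PredictWidth = 8 and the taken instruction at slot 2, an RVC jump yields 0b00000111 while a 32-bit jump yields 0b00001111.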
@@ -300,15 +307,20 @@ class BPUStage3 extends XSModule {
if(HasBPD) ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
else ~Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt))
io.out.bits.redirect := jmpIdx.orR.asBool
io.out.bits.target := Mux(jmpIdx === retIdx, rasTopAddr,
val lateJump = jmpIdx === HighestBit(io.predecode.bits.mask, PredictWidth) && !io.predecode.bits.isRVC(OHToUInt(jmpIdx))
io.out.bits.target := Mux(jmpIdx === 0.U, inLatch.pc + (PopCount(io.predecode.bits.mask) << 1.U),
Mux(jmpIdx === retIdx, rasTopAddr,
Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
Mux(jmpIdx === 0.U, inLatch.pc + 32.U, // TODO: RVC
PriorityMux(jmpIdx, inLatch.btb.targets))))
io.out.bits.instrValid := Mux(jmpIdx.orR, LowerMask(jmpIdx, FetchWidth), Fill(FetchWidth, 1.U(1.W))).asTypeOf(Vec(FetchWidth, Bool()))
PriorityMux(jmpIdx, inLatch.btb.targets)))) // TODO: jal and call's target can be calculated here
io.out.bits.instrValid := Mux(!jmpIdx.orR || lateJump, io.predecode.bits.mask,
Mux(!io.predecode.bits.isRVC(OHToUInt(jmpIdx)), LowerMask(jmpIdx << 1.U, PredictWidth),
LowerMask(jmpIdx, PredictWidth))).asTypeOf(Vec(PredictWidth, Bool()))
// io.out.bits.btbVictimWay := inLatch.btbPred.bits.btbVictimWay
io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
io.out.bits.btbHitWay := inLatch.btbPred.bits.btbHitWay
io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
//io.out.bits._type := Mux(jmpIdx === retIdx, BTBtype.R,
// Mux(jmpIdx === jalrIdx, BTBtype.I,
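The lateJump flag introduced in this hunk marks a taken 32-bit jump whose first half occupies the last valid slot of the packet, so its second half lies in the next fetch packet; in that case the full mask is kept. A small sketch of the condition (illustrative names, not the RTL's HighestBit helper):

```scala
// Sketch of the lateJump condition.
object LateJumpModel {
  // one-hot position of the highest set bit of the fetch mask
  def highestBit(mask: BigInt): BigInt =
    if (mask == 0) BigInt(0) else BigInt(1) << (mask.bitLength - 1)
  def isLateJump(jmpIdxOH: BigInt, fetchMask: BigInt, jmpIsRVC: Boolean): Boolean =
    jmpIdxOH != 0 && jmpIdxOH == highestBit(fetchMask) && !jmpIsRVC
}
```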
@@ -317,33 +329,35 @@ class BPUStage3 extends XSModule {
// there may be several notTaken branches before the first jump instruction,
// so we need to calculate how many zeroes should each instruction shift in its global history.
// each history is exclusive of instruction's own jump direction.
val histShift = Wire(Vec(FetchWidth, UInt(log2Up(FetchWidth).W)))
val shift = Wire(Vec(FetchWidth, Vec(FetchWidth, UInt(1.W))))
(0 until FetchWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), FetchWidth)).asTypeOf(Vec(FetchWidth, UInt(1.W))))
for (j <- 0 until FetchWidth) {
val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
(0 until PredictWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
for (j <- 0 until PredictWidth) {
var tmp = 0.U
for (i <- 0 until FetchWidth) {
for (i <- 0 until PredictWidth) {
tmp = tmp + shift(i)(j)
}
histShift(j) := tmp
}
(0 until FetchWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
(0 until PredictWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
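The histShift loop above gives each slot the packet's incoming history shifted by the number of not-taken branches strictly before it, excluding the slot's own outcome. A plain-Scala model of the same computation (names assumed):

```scala
// Sketch of the per-slot global-history shift computed by the loop above.
object HistShiftModel {
  // slot j shifts by the count of not-taken branches in slots 0 .. j-1
  def histShifts(brNotTaken: Seq[Boolean]): Seq[Int] =
    brNotTaken.indices.map(j => brNotTaken.take(j).count(identity))
  def perSlotHist(firstHist: BigInt, brNotTaken: Seq[Boolean]): Seq[BigInt] =
    histShifts(brNotTaken).map(firstHist << _)
}
```

For brNotTaken = Seq(true, false, true, false) the shifts are Seq(0, 1, 1, 2).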
// save ras checkpoint info
io.out.bits.rasSp := sp.value
io.out.bits.rasTopCtr := rasTop.ctr
// flush BPU and redirect when target differs from the target predicted in Stage1
io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
else false.B)
// io.out.bits.redirect := (if(EnableBPD) (inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
// inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target)
// else false.B)
io.out.bits.redirect := inLatch.btbPred.bits.redirect ^ jmpIdx.orR.asBool ||
inLatch.btbPred.bits.redirect && jmpIdx.orR.asBool && io.out.bits.target =/= inLatch.btbPred.bits.target
io.flushBPU := io.out.bits.redirect && io.out.valid
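The now-uncommented redirect condition flushes the front end when stage 3 disagrees with stage 1 on whether a jump is taken, or when both are taken but to different targets. A one-line sketch of the same boolean (illustrative names):

```scala
// Sketch of the stage-3 redirect decision above.
object Stage3RedirectModel {
  def needRedirect(s1Taken: Boolean, s3Taken: Boolean,
                   s1Target: BigInt, s3Target: BigInt): Boolean =
    (s1Taken ^ s3Taken) || (s1Taken && s3Taken && s3Target != s1Target)
}
```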
// speculative update RAS
val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 2.U) + 4.U
rasWrite.retAddr := inLatch.pc + (OHToUInt(callIdx) << 1.U) + Mux(PriorityMux(callIdx, io.predecode.bits.isRVC), 2.U, 4.U)
val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
when (io.out.valid) {
when (io.out.valid && jmpIdx =/= 0.U) {
when (jmpIdx === callIdx) {
ras(Mux(allocNewEntry, sp.value + 1.U, sp.value)) := rasWrite
when (allocNewEntry) { sp.value := sp.value + 1.U }
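The updated rasWrite.retAddr accounts for RVC: the call's PC is the packet PC plus two bytes per 16-bit slot, and the return address then skips 2 bytes for a compressed call or 4 otherwise. A minimal sketch (assumed names):

```scala
// Sketch of the RVC-aware return address pushed onto the RAS.
object RasRetAddrModel {
  def retAddr(packetPC: BigInt, callSlot: Int, callIsRVC: Boolean): BigInt =
    packetPC + (callSlot << 1) + (if (callIsRVC) 2 else 4)
}
```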
@@ -358,23 +372,23 @@ class BPUStage3 extends XSModule {
// use checkpoint to recover RAS
val recoverSp = io.redirectInfo.redirect.rasSp
val recoverCtr = io.redirectInfo.redirect.rasTopCtr
when (io.redirectInfo.valid && io.redirectInfo.misPred) {
when (io.redirectInfo.flush()) {
sp.value := recoverSp
ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
}
// roll back global history in S1 if S3 redirects
io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brIdx & ~inLatch.tage.takens.asUInt))
io.s1RollBackHist := Mux(io.s3Taken, PriorityMux(jmpIdx, io.out.bits.hist), io.out.bits.hist(0) << PopCount(brNotTakenIdx))
// whether Stage3 has a taken jump
io.s3Taken := jmpIdx.orR.asBool
// debug info
XSDebug(io.in.fire(), "[BPUS3]in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
XSDebug(io.out.valid, "[BPUS3]out:%d pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
XSDebug(io.out.valid, "out:%d pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
io.out.valid, inLatch.pc, io.out.bits.redirect, io.predecode.bits.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
XSDebug(true.B, "[BPUS3]flushS3=%d\n", flushS3)
XSDebug(true.B, "[BPUS3]validLatch=%d predecode.valid=%d\n", validLatch, io.predecode.valid)
XSDebug(true.B, "[BPUS3]brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
XSDebug("flushS3=%d\n", flushS3)
XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, io.predecode.valid)
XSDebug("brIdx=%b brTakenIdx=%b brNTakenIdx=%b jalIdx=%b jalrIdx=%b callIdx=%b retIdx=%b\n",
brIdx, brTakenIdx, brNotTakenIdx, jalIdx, jalrIdx, callIdx, retIdx)
// BPU's TEMP Perf Cnt
......
@@ -54,7 +54,7 @@ class IFU extends XSModule with HasIFUConst
val if1_pc = RegInit(resetVector.U(VAddrBits.W))
//next
val if2_ready = WireInit(false.B)
val if2_snpc = snpc(if1_pc) //TODO: this is ugly
val if2_snpc = snpc(if1_pc) //TODO: calculate snpc according to mask of current fetch packet
val needflush = WireInit(false.B)
//pipe fire
......
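The revised TODO suggests deriving the sequential next PC from the current packet's mask, mirroring the fall-through targets computed in the BPU stages above. A hedged sketch of what that could look like (purely illustrative, not part of this commit):

```scala
// Sketch only: sequential next PC as packet PC plus 2 bytes per valid slot.
object SnpcModel {
  def snpcFromMask(packetPC: BigInt, fetchMask: BigInt): BigInt =
    packetPC + fetchMask.bitCount * 2
}
```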