提交 80d2974b 编写于 作者: L Lingrui98

BPU: Initiate refactoring

上级 f226232f
......@@ -60,7 +60,7 @@ class BranchInfo extends XSBundle {
this.histPtr := histPtr
this.tageMeta := tageMeta
this.rasSp := rasSp
this.rasTopCtr
this.rasTopCtr := rasTopCtr
this.asUInt
}
def size = 0.U.asTypeOf(this).getWidth
......@@ -142,7 +142,6 @@ class Redirect extends XSBundle with HasRoqIdx {
val pc = UInt(VAddrBits.W)
val target = UInt(VAddrBits.W)
val brTag = new BrqPtr
val histPtr = UInt(log2Up(ExtHistoryLength).W)
}
class Dp1ToDp2IO extends XSBundle {
......
......@@ -7,21 +7,46 @@ import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
class BPUStage1To2IO extends XSBundle {
// TODO
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
def tagBits = VAddrBits - idxBits - 1
val tag = UInt(tagBits.W)
val idx = UInt(idxBits.W)
val offset = UInt(1.W)
def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this)
def getTag(x: UInt) = fromUInt(x).tag
def getIdx(x: UInt) = fromUInt(x).idx
def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0)
def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks))
}
class BPUStage2To3IO extends XSBundle {
// TODO
class BTBResponse extends XSBundle {
// the valid bits indicates whether a target is hit
val ubtb = new Bundle {
val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VaddrBits.W)))
val takens = Vec(PredictWidth, Bool())
}
// the valid bits indicates whether a target is hit
val btb = new Bundle {
val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VaddrBits.W)))
val takens = Vec(PredictWidth, Bool())
}
}
class BPUStageIO extends XSBundle {
val pc = Output(UInt(VAddrBits.W))
val btbResp = Output(new BTBResponse)
val brInfo = Output(Vec(PredictWidth, new BranchInfo))
}
class BPUStage1 extends XSModule {
val io = IO(new Bundle() {
val flush = Input(Bool())
val in = new Bundle { val pc = Flipped(ValidIO(UInt(VAddrBits.W))) }
val s1_out = Decoupled(new BranchPrediction)
val out = Decoupled(new BPUStage1To2IO)
val redirect = Flipped(ValidIO(new Redirect)) // used to fix ghr
val pred = Decoupled(new BranchPrediction)
val out = Decoupled(new BPUStageIO)
val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
})
......@@ -31,29 +56,27 @@ class BPUStage1 extends XSModule {
class BPUStage2 extends XSModule {
val io = IO(new Bundle() {
val flush = Input(Bool())
val in = Flipped(Decoupled(new BPUStage1To2IO))
val s2_out = Decoupled(new BranchPrediction)
val out = Decoupled(new BPUStage2To3IO)
val in = Flipped(Decoupled(new BPUStageIO))
val pred = Decoupled(new BranchPrediction)
val out = Decoupled(new BPUStageIO)
val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // delete this if useless
val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // delete this if useless
})
}
class BPUStage3 extends XSModule {
val io = IO(new Bundle() {
val flush = Input(Bool())
val in = Flipped(Decoupled(new BPUStage2To3IO))
val s3_out = Decoupled(new BranchPrediction)
val in = Flipped(Decoupled(new BPUStageIO))
val pred = Decoupled(new BranchPrediction)
val predecode = Flipped(ValidIO(new Predecode))
})
}
class BPU extends XSModule {
class BaseBPU extends XSModule {
val io = IO(new Bundle() {
// from backend
val redirect = Flipped(ValidIO(new Redirect))
val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
// from ifu, frontend redirect
......@@ -65,8 +88,19 @@ class BPU extends XSModule {
// from if4
val predecode = Flipped(ValidIO(new Predecode))
// to if4, some bpu info used for updating
val branchInfo = Decoupled(new BranchInfo)
val branchInfo = Decoupled(Vec(PredictWidth, new BranchInfo))
})
}
class FakeBPU extends BaseBPU {
io.out.foreach(i => {
i <> DontCare
i.redirect := false.B
})
io.branchInfo <> DontCare
}
class BPU extends BaseBPU {
val s1 = Module(new BPUStage1)
val s2 = Module(new BPUStage2)
......@@ -80,9 +114,9 @@ class BPU extends XSModule {
s2.io.in <> s1.io.out
s3.io.in <> s2.io.out
io.out(0) <> s1.io.s1_out
io.out(1) <> s2.io.s2_out
io.out(2) <> s3.io.s3_out
io.out(0) <> s1.io.pred
io.out(1) <> s2.io.pred
io.out(2) <> s3.io.pred
s1.io.redirect <> io.redirect
s1.io.outOfOrderBrInfo <> io.outOfOrderBrInfo
......@@ -92,3 +126,491 @@ class BPU extends XSModule {
s3.io.predecode <> io.predecode
}
// class BPUStage1 extends XSModule {
// val io = IO(new Bundle() {
// val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) }
// // from backend
// val redirectInfo = Input(new RedirectInfo)
// // from Stage3
// val flush = Input(Bool())
// val s3RollBackHist = Input(UInt(HistoryLength.W))
// val s3Taken = Input(Bool())
// // to ifu, quick prediction result
// val s1OutPred = ValidIO(new BranchPrediction)
// // to Stage2
// val out = Decoupled(new Stage1To2IO)
// })
// io.in.pc.ready := true.B
// // flush Stage1 when io.flush
// val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)
// val s1OutPredLatch = RegEnable(io.s1OutPred.bits, RegNext(io.in.pc.fire()))
// val outLatch = RegEnable(io.out.bits, RegNext(io.in.pc.fire()))
// val s1Valid = RegInit(false.B)
// when (io.flush) {
// s1Valid := true.B
// }.elsewhen (io.in.pc.fire()) {
// s1Valid := true.B
// }.elsewhen (io.out.fire()) {
// s1Valid := false.B
// }
// io.out.valid := s1Valid
// // global history register
// val ghr = RegInit(0.U(HistoryLength.W))
// // modify updateGhr and newGhr when updating ghr
// val updateGhr = WireInit(false.B)
// val newGhr = WireInit(0.U(HistoryLength.W))
// when (updateGhr) { ghr := newGhr }
// // use hist as global history!!!
// val hist = Mux(updateGhr, newGhr, ghr)
// // Tage predictor
// val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE)
// tage.io.req.valid := io.in.pc.fire()
// tage.io.req.bits.pc := io.in.pc.bits
// tage.io.req.bits.hist := hist
// tage.io.redirectInfo <> io.redirectInfo
// io.s1OutPred.bits.tageMeta := tage.io.meta
// // latch pc for 1 cycle latency when reading SRAM
// val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire())
// // TODO: pass real mask in
// // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire())
// val maskLatch = Fill(PredictWidth, 1.U(1.W))
// val r = io.redirectInfo.redirect
// val updateFetchpc = r.pc - (r.fetchIdx << 1.U)
// // BTB
// val btb = Module(new BTB)
// btb.io.in.pc <> io.in.pc
// btb.io.in.pcLatch := pcLatch
// // TODO: pass real mask in
// btb.io.in.mask := Fill(PredictWidth, 1.U(1.W))
// btb.io.redirectValid := io.redirectInfo.valid
// btb.io.flush := io.flush
// // btb.io.update.fetchPC := updateFetchpc
// // btb.io.update.fetchIdx := r.fetchIdx
// btb.io.update.pc := r.pc
// btb.io.update.hit := r.btbHit
// btb.io.update.misPred := io.redirectInfo.misPred
// // btb.io.update.writeWay := r.btbVictimWay
// btb.io.update.oldCtr := r.btbPredCtr
// btb.io.update.taken := r.taken
// btb.io.update.target := r.brTarget
// btb.io.update.btbType := r.btbType
// // TODO: add RVC logic
// btb.io.update.isRVC := r.isRVC
// // val btbHit = btb.io.out.hit
// val btbTaken = btb.io.out.taken
// val btbTakenIdx = btb.io.out.takenIdx
// val btbTakenTarget = btb.io.out.target
// // val btbWriteWay = btb.io.out.writeWay
// val btbNotTakens = btb.io.out.notTakens
// val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred))
// val btbValids = btb.io.out.hits
// val btbTargets = VecInit(btb.io.out.dEntries.map(_.target))
// val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType))
// val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC))
// val jbtac = Module(new JBTAC)
// jbtac.io.in.pc <> io.in.pc
// jbtac.io.in.pcLatch := pcLatch
// // TODO: pass real mask in
// jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W))
// jbtac.io.in.hist := hist
// jbtac.io.redirectValid := io.redirectInfo.valid
// jbtac.io.flush := io.flush
// jbtac.io.update.fetchPC := updateFetchpc
// jbtac.io.update.fetchIdx := r.fetchIdx
// jbtac.io.update.misPred := io.redirectInfo.misPred
// jbtac.io.update.btbType := r.btbType
// jbtac.io.update.target := r.target
// jbtac.io.update.hist := r.hist
// jbtac.io.update.isRVC := r.isRVC
// val jbtacHit = jbtac.io.out.hit
// val jbtacTarget = jbtac.io.out.target
// val jbtacHitIdx = jbtac.io.out.hitIdx
// val jbtacIsRVC = jbtac.io.out.isRVC
// // calculate global history of each instr
// val firstHist = RegNext(hist)
// val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
// val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
// (0 until PredictWidth).foreach(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
// for (j <- 0 until PredictWidth) {
// var tmp = 0.U
// for (i <- 0 until PredictWidth) {
// tmp = tmp + shift(i)(j)
// }
// histShift(j) := tmp
// }
// // update ghr
// updateGhr := io.s1OutPred.bits.redirect ||
// RegNext(io.in.pc.fire) && ~io.s1OutPred.bits.redirect && (btbNotTakens.asUInt & maskLatch).orR || // TODO: use parallel or
// io.flush
// val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx))
// val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx))
// // if backend redirects, restore history from backend;
// // if stage3 redirects, restore history from stage3;
// // if stage1 redirects, speculatively update history;
// // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly
// newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken),
// Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist),
// Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U),
// io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch))))
// def getInstrValid(i: Int): UInt = {
// val vec = Wire(Vec(PredictWidth, UInt(1.W)))
// for (j <- 0 until PredictWidth) {
// if (j <= i)
// vec(j) := 1.U
// else
// vec(j) := 0.U
// }
// vec.asUInt
// }
// // redirect based on BTB and JBTAC
// val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth)
// // io.out.valid := RegNext(io.in.pc.fire()) && !io.flush
// // io.s1OutPred.valid := io.out.valid
// io.s1OutPred.valid := io.out.fire()
// when (RegNext(io.in.pc.fire())) {
// io.s1OutPred.bits.redirect := btbTaken || jbtacHit
// // io.s1OutPred.bits.instrValid := (maskLatch & Fill(PredictWidth, ~io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump) |
// // PriorityMux(brJumpIdx | indirectIdx, (0 until PredictWidth).map(getInstrValid(_)))).asTypeOf(Vec(PredictWidth, Bool()))
// io.s1OutPred.bits.instrValid := (maskLatch & Fill(PredictWidth, ~io.s1OutPred.bits.redirect) |
// PriorityMux(brJumpIdx | indirectIdx, (0 until PredictWidth).map(getInstrValid(_)))).asTypeOf(Vec(PredictWidth, Bool()))
// for (i <- 0 until (PredictWidth - 1)) {
// when (!io.s1OutPred.bits.lateJump && (1.U << i) === takenIdx && (!btbIsRVCs(i) && btbValids(i) || !jbtacIsRVC && (1.U << i) === indirectIdx)) {
// io.s1OutPred.bits.instrValid(i+1) := maskLatch(i+1)
// }
// }
// io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget))
// io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump
// (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i))
// // io.s1OutPred.bits.btbVictimWay := btbWriteWay
// io.s1OutPred.bits.predCtr := btbCtrs
// io.s1OutPred.bits.btbHit := btbValids
// io.s1OutPred.bits.tageMeta := tage.io.meta // TODO: enableBPD
// io.s1OutPred.bits.rasSp := DontCare
// io.s1OutPred.bits.rasTopCtr := DontCare
// }.otherwise {
// io.s1OutPred.bits := s1OutPredLatch
// }
// when (RegNext(io.in.pc.fire())) {
// io.out.bits.pc := pcLatch
// io.out.bits.btb.hits := btbValids.asUInt
// (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i))
// io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U) // UIntToOH(jbtacHitIdx)
// io.out.bits.jbtac.target := jbtacTarget
// io.out.bits.tage <> tage.io.out
// // TODO: we don't need this repeatedly!
// io.out.bits.hist := io.s1OutPred.bits.hist
// io.out.bits.btbPred := io.s1OutPred
// }.otherwise {
// io.out.bits := outLatch
// }
// // debug info
// XSDebug("in:(%d %d) pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist)
// XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n",
// io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target)
// XSDebug(io.flush && io.redirectInfo.flush(),
// "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n",
// r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException)
// XSDebug(io.flush && !io.redirectInfo.flush(),
// "flush from Stage3: s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist)
// }
// class Stage2To3IO extends Stage1To2IO {
// }
// class BPUStage2 extends XSModule {
// val io = IO(new Bundle() {
// // flush from Stage3
// val flush = Input(Bool())
// val in = Flipped(Decoupled(new Stage1To2IO))
// val out = Decoupled(new Stage2To3IO)
// })
// // flush Stage2 when Stage3 or banckend redirects
// val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
// val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
// when (io.in.fire()) { inLatch := io.in.bits }
// val validLatch = RegInit(false.B)
// when (io.flush) {
// validLatch := false.B
// }.elsewhen (io.in.fire()) {
// validLatch := true.B
// }.elsewhen (io.out.fire()) {
// validLatch := false.B
// }
// io.out.valid := !io.flush && !flushS2 && validLatch
// io.in.ready := !validLatch || io.out.fire()
// // do nothing
// io.out.bits := inLatch
// // debug info
// XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n",
// io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc)
// XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc)
// XSDebug(io.flush, "flush!!!\n")
// }
// class BPUStage3 extends XSModule {
// val io = IO(new Bundle() {
// val flush = Input(Bool())
// val in = Flipped(Decoupled(new Stage2To3IO))
// val out = Decoupled(new BranchPrediction)
// // from icache
// val predecode = Flipped(ValidIO(new Predecode))
// // from backend
// val redirectInfo = Input(new RedirectInfo)
// // to Stage1 and Stage2
// val flushBPU = Output(Bool())
// // to Stage1, restore ghr in stage1 when flushBPU is valid
// val s1RollBackHist = Output(UInt(HistoryLength.W))
// val s3Taken = Output(Bool())
// })
// val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true)
// val inLatch = RegInit(0.U.asTypeOf(io.in.bits))
// val validLatch = RegInit(false.B)
// val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits))
// val predecodeValidLatch = RegInit(false.B)
// when (io.in.fire()) { inLatch := io.in.bits }
// when (io.flush) {
// validLatch := false.B
// }.elsewhen (io.in.fire()) {
// validLatch := true.B
// }.elsewhen (io.out.fire()) {
// validLatch := false.B
// }
// when (io.predecode.valid) { predecodeLatch := io.predecode.bits }
// when (io.flush || io.out.fire()) {
// predecodeValidLatch := false.B
// }.elsewhen (io.predecode.valid) {
// predecodeValidLatch := true.B
// }
// val predecodeValid = io.predecode.valid || predecodeValidLatch
// val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch)
// io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush
// io.in.ready := !validLatch || io.out.fire()
// // RAS
// // TODO: split retAddr and ctr
// def rasEntry() = new Bundle {
// val retAddr = UInt(VAddrBits.W)
// val ctr = UInt(8.W) // layer of nested call functions
// }
// val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
// val sp = Counter(RasSize)
// val rasTop = ras(sp.value)
// val rasTopAddr = rasTop.retAddr
// // get the first taken branch/jal/call/jalr/ret in a fetch line
// // brNotTakenIdx indicates all the not-taken branches before the first jump instruction
// val tageHits = inLatch.tage.hits
// val tageTakens = inLatch.tage.takens
// val btbTakens = inLatch.btbPred.bits.predCtr
// val brs = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask
// // val brTakens = brs & inLatch.tage.takens.asUInt
// val brTakens = if (EnableBPD) {
// // If tage hits, use tage takens, otherwise keep btbpreds
// // brs & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)
// XSDebug("tageHits=%b, tageTakens=%b\n", tageHits, tageTakens.asUInt)
// brs & Reverse(Cat((0 until PredictWidth).map(i => Mux(tageHits(i), tageTakens(i), btbTakens(i)(1)))))
// } else {
// brs & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt)
// }
// val jals = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask
// val calls = inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt)
// val jalrs = inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt)
// val rets = predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt)
// val brTakenIdx = PriorityMux(brTakens, (0 until PredictWidth).map(_.U))
// val jalIdx = PriorityMux(jals, (0 until PredictWidth).map(_.U))
// val callIdx = PriorityMux(calls, (0 until PredictWidth).map(_.U))
// val jalrIdx = PriorityMux(jalrs, (0 until PredictWidth).map(_.U))
// val retIdx = PriorityMux(rets, (0 until PredictWidth).map(_.U))
// val jmps = (if (EnableRAS) {brTakens | jals | calls | jalrs | rets} else {brTakens | jals | calls | jalrs})
// val jmpIdx = MuxCase(0.U, (0 until PredictWidth).map(i => (jmps(i), i.U)))
// io.s3Taken := MuxCase(false.B, (0 until PredictWidth).map(i => (jmps(i), true.B)))
// // val brNotTakens = VecInit((0 until PredictWidth).map(i => brs(i) && ~inLatch.tage.takens(i) && i.U <= jmpIdx && io.predecode.bits.mask(i)))
// val brNotTakens = if (EnableBPD) {
// VecInit((0 until PredictWidth).map(i => brs(i) && i.U <= jmpIdx && Mux(tageHits(i), ~tageTakens(i), ~btbTakens(i)(1)) && predecode.mask(i)))
// } else {
// VecInit((0 until PredictWidth).map(i => brs(i) && i.U <= jmpIdx && ~inLatch.btbPred.bits.predCtr(i)(1) && predecode.mask(i)))
// }
// // TODO: what if if4 and if2 late jump to the same target?
// // val lateJump = io.s3Taken && PriorityMux(Reverse(predecode.mask), ((PredictWidth - 1) to 0).map(_.U)) === jmpIdx && !predecode.isRVC(jmpIdx)
// val lateJump = io.s3Taken && PriorityMux(Reverse(predecode.mask), (0 until PredictWidth).map {i => (PredictWidth - 1 - i).U}) === jmpIdx && !predecode.isRVC(jmpIdx)
// io.out.bits.lateJump := lateJump
// io.out.bits.predCtr := inLatch.btbPred.bits.predCtr
// io.out.bits.btbHit := inLatch.btbPred.bits.btbHit
// io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta
// //io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R,
// // Mux(jmpIdx === jalrIdx, BTBtype.I,
// // Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
// val firstHist = inLatch.btbPred.bits.hist(0)
// // there may be several notTaken branches before the first jump instruction,
// // so we need to calculate how many zeroes should each instruction shift in its global history.
// // each history is exclusive of instruction's own jump direction.
// val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W)))
// val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W))))
// (0 until PredictWidth).foreach(i => shift(i) := Mux(!brNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W))))
// for (j <- 0 until PredictWidth) {
// var tmp = 0.U
// for (i <- 0 until PredictWidth) {
// tmp = tmp + shift(i)(j)
// }
// histShift(j) := tmp
// }
// (0 until PredictWidth).foreach(i => io.out.bits.hist(i) := firstHist << histShift(i))
// // save ras checkpoint info
// io.out.bits.rasSp := sp.value
// io.out.bits.rasTopCtr := rasTop.ctr
// // flush BPU and redirect when target differs from the target predicted in Stage1
// val tToNt = inLatch.btbPred.bits.redirect && ~io.s3Taken
// val ntToT = ~inLatch.btbPred.bits.redirect && io.s3Taken
// val dirDiffers = tToNt || ntToT
// val tgtDiffers = inLatch.btbPred.bits.redirect && io.s3Taken && io.out.bits.target =/= inLatch.btbPred.bits.target
// // io.out.bits.redirect := (if (EnableBPD) {dirDiffers || tgtDiffers} else false.B)
// io.out.bits.redirect := dirDiffers || tgtDiffers
// io.out.bits.target := Mux(!io.s3Taken, inLatch.pc + (PopCount(predecode.mask) << 1.U), // TODO: RVC
// Mux(jmpIdx === retIdx, rasTopAddr,
// Mux(jmpIdx === jalrIdx, inLatch.jbtac.target,
// inLatch.btb.targets(jmpIdx))))
// // for (i <- 0 until FetchWidth) {
// // io.out.bits.instrValid(i) := ((io.s3Taken && i.U <= jmpIdx) || ~io.s3Taken) && io.predecode.bits.mask(i)
// // }
// io.out.bits.instrValid := predecode.mask.asTypeOf(Vec(PredictWidth, Bool()))
// for (i <- PredictWidth - 1 to 0) {
// io.out.bits.instrValid(i) := (io.s3Taken && i.U <= jmpIdx || !io.s3Taken) && predecode.mask(i)
// if (i != (PredictWidth - 1)) {
// when (!lateJump && !predecode.isRVC(i) && io.s3Taken && i.U <= jmpIdx) {
// io.out.bits.instrValid(i+1) := predecode.mask(i+1)
// }
// }
// }
// io.flushBPU := io.out.bits.redirect && io.out.fire()
// // speculative update RAS
// val rasWrite = WireInit(0.U.asTypeOf(rasEntry()))
// val retAddr = inLatch.pc + (callIdx << 1.U) + Mux(predecode.isRVC(callIdx), 2.U, 4.U)
// rasWrite.retAddr := retAddr
// val allocNewEntry = rasWrite.retAddr =/= rasTopAddr
// rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U)
// val rasWritePosition = Mux(allocNewEntry, sp.value + 1.U, sp.value)
// when (io.out.fire() && io.s3Taken) {
// when (jmpIdx === callIdx) {
// ras(rasWritePosition) := rasWrite
// when (allocNewEntry) { sp.value := sp.value + 1.U }
// }.elsewhen (jmpIdx === retIdx) {
// when (rasTop.ctr === 1.U) {
// sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U)
// }.otherwise {
// ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry())
// }
// }
// }
// // use checkpoint to recover RAS
// val recoverSp = io.redirectInfo.redirect.rasSp
// val recoverCtr = io.redirectInfo.redirect.rasTopCtr
// when (io.redirectInfo.flush()) {
// sp.value := recoverSp
// ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry())
// }
// // roll back global history in S1 if S3 redirects
// io.s1RollBackHist := Mux(io.s3Taken, io.out.bits.hist(jmpIdx),
// io.out.bits.hist(0) << PopCount(brs & predecode.mask & ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt)))
// // debug info
// XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc)
// XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n",
// io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target)
// XSDebug("flushS3=%d\n", flushS3)
// XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, predecodeValid)
// XSDebug("brs=%b brTakens=%b brNTakens=%b jals=%b jalrs=%b calls=%b rets=%b\n",
// brs, brTakens, brNotTakens.asUInt, jals, jalrs, calls, rets)
// // ?????condition is wrong
// // XSDebug(io.in.fire() && callIdx.orR, "[RAS]:pc=0x%x, rasWritePosition=%d, rasWriteAddr=0x%x\n",
// // io.in.bits.pc, rasWritePosition, retAddr)
// }
// class BPU extends XSModule {
// val io = IO(new Bundle() {
// // from backend
// // flush pipeline if misPred and update bpu based on redirect signals from brq
// val redirectInfo = Input(new RedirectInfo)
// val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
// val btbOut = ValidIO(new BranchPrediction)
// val tageOut = Decoupled(new BranchPrediction)
// // predecode info from icache
// // TODO: simplify this after implement predecode unit
// val predecode = Flipped(ValidIO(new Predecode))
// })
// val s1 = Module(new BPUStage1)
// val s2 = Module(new BPUStage2)
// val s3 = Module(new BPUStage3)
// s1.io.redirectInfo <> io.redirectInfo
// s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
// s1.io.in.pc.valid := io.in.pc.valid
// s1.io.in.pc.bits <> io.in.pc.bits
// io.btbOut <> s1.io.s1OutPred
// s1.io.s3RollBackHist := s3.io.s1RollBackHist
// s1.io.s3Taken := s3.io.s3Taken
// s1.io.out <> s2.io.in
// s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush()
// s2.io.out <> s3.io.in
// s3.io.flush := io.redirectInfo.flush()
// s3.io.predecode <> io.predecode
// io.tageOut <> s3.io.out
// s3.io.redirectInfo <> io.redirectInfo
// }
......@@ -24,28 +24,6 @@ class IFUIO extends XSBundle
val icacheResp = Flipped(DecoupledIO(new FakeIcacheResp))
}
class BaseBPU extends XSModule {
val io = IO(new Bundle() {
val redirect = Flipped(ValidIO(new Redirect))
val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) }
val btbOut = ValidIO(new BranchPrediction)
val tageOut = Decoupled(new BranchPrediction)
val predecode = Flipped(ValidIO(new Predecode))
})
}
class FakeBPU extends BaseBPU {
io.btbOut.valid := false.B
io.btbOut.bits <> DontCare
io.btbOut.bits.redirect := false.B
io.btbOut.bits.target := DontCare
io.tageOut.valid := false.B
io.tageOut.bits <> DontCare
}
class IFU extends XSModule with HasIFUConst
{
......
//package xiangshan.frontend
//
//import chisel3._
//import chisel3.util._
//import xiangshan._
//import utils._
//
//import scala.math.min
//
//trait HasTageParameter {
// // Sets Hist Tag
// val TableInfo = Seq(( 128, 2, 7),
// ( 128, 4, 7),
// ( 256, 8, 8),
// ( 256, 16, 8),
// ( 128, 32, 9),
// ( 128, 64, 9))
// val TageNTables = TableInfo.size
// val UBitPeriod = 2048
// val BankWidth = 16 // FetchWidth
//
// val TotalBits = TableInfo.map {
// case (s, h, t) => {
// s * (1+t+3) * BankWidth
// }
// }.reduce(_+_)
//}
//
//abstract class TageBundle extends XSBundle with HasTageParameter
//abstract class TageModule extends XSModule with HasTageParameter
//
//class TageReq extends TageBundle {
// val pc = UInt(VAddrBits.W)
// val hist = UInt(HistoryLength.W)
//}
//
//class TageResp extends TageBundle {
// val ctr = UInt(3.W)
// val u = UInt(2.W)
//}
//
//class TageUpdate extends TageBundle {
// val pc = UInt(VAddrBits.W)
// val hist = UInt(HistoryLength.W)
// // update tag and ctr
// val mask = Vec(BankWidth, Bool())
// val taken = Vec(BankWidth, Bool())
// val alloc = Vec(BankWidth, Bool())
// val oldCtr = Vec(BankWidth, UInt(3.W))
// // update u
// val uMask = Vec(BankWidth, Bool())
// val u = Vec(BankWidth, UInt(2.W))
//}
//
//class FakeTageTable() extends TageModule {
// val io = IO(new Bundle() {
// val req = Input(Valid(new TageReq))
// val resp = Output(Vec(BankWidth, Valid(new TageResp)))
// val update = Input(new TageUpdate)
// })
// io.resp := DontCare
//
//}
//
//class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule {
// val io = IO(new Bundle() {
// val req = Input(Valid(new TageReq))
// val resp = Output(Vec(BankWidth, Valid(new TageResp)))
// val update = Input(new TageUpdate)
// })
//
// // bypass entries for tage update
// val wrBypassEntries = 8
//
// def compute_folded_hist(hist: UInt, l: Int) = {
// val nChunks = (histLen + l - 1) / l
// val hist_chunks = (0 until nChunks) map {i =>
// hist(min((i+1)*l, histLen)-1, i*l)
// }
// hist_chunks.reduce(_^_)
// }
//
// def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt) = {
// val idx_history = compute_folded_hist(hist, log2Ceil(nRows))
// val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows)-1,0)
// val tag_history = compute_folded_hist(hist, tagLen)
// // Use another part of pc to make tags
// val tag = ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history)(tagLen-1,0)
// (idx, tag)
// }
//
// def inc_ctr(ctr: UInt, taken: Bool): UInt = {
// Mux(!taken, Mux(ctr === 0.U, 0.U, ctr - 1.U),
// Mux(ctr === 7.U, 7.U, ctr + 1.U))
// }
//
// val doing_reset = RegInit(true.B)
// val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
// reset_idx := reset_idx + doing_reset
// when (reset_idx === (nRows-1).U) { doing_reset := false.B }
//
// class TageEntry() extends TageBundle {
// val valid = Bool()
// val tag = UInt(tagLen.W)
// val ctr = UInt(3.W)
// }
//
// val tageEntrySz = 1 + tagLen + 3
//
// val (hashed_idx, tag) = compute_tag_and_hash(io.req.bits.pc, io.req.bits.hist)
//
// val hi_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
// val lo_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
// val table = List.fill(BankWidth)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
//
// val hi_us_r = Wire(Vec(BankWidth, Bool()))
// val lo_us_r = Wire(Vec(BankWidth, Bool()))
// val table_r = Wire(Vec(BankWidth, new TageEntry))
//
// (0 until BankWidth).map(
// b => {
// hi_us(b).reset := reset.asBool
// lo_us(b).reset := reset.asBool
// table(b).reset := reset.asBool
// hi_us(b).io.r.req.valid := io.req.valid
// lo_us(b).io.r.req.valid := io.req.valid
// table(b).io.r.req.valid := io.req.valid
// hi_us(b).io.r.req.bits.setIdx := hashed_idx
// lo_us(b).io.r.req.bits.setIdx := hashed_idx
// table(b).io.r.req.bits.setIdx := hashed_idx
//
// hi_us_r(b) := hi_us(b).io.r.resp.data(0)
// lo_us_r(b) := lo_us(b).io.r.resp.data(0)
// table_r(b) := table(b).io.r.resp.data(0)
//
// // io.resp(b).valid := table_r(b).valid && table_r(b).tag === tag // Missing reset logic
// // io.resp(b).bits.ctr := table_r(b).ctr
// // io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b))
// }
// )
//
// val req_rhits = VecInit(table_r.map(e => e.valid && e.tag === tag && !doing_reset))
//
// (0 until BankWidth).map(b => {
// io.resp(b).valid := req_rhits(b)
// io.resp(b).bits.ctr := table_r(b).ctr
// io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b))
// })
//
//
// val clear_u_ctr = RegInit(0.U((log2Ceil(uBitPeriod) + log2Ceil(nRows) + 1).W))
// when (doing_reset) { clear_u_ctr := 1.U } .otherwise { clear_u_ctr := clear_u_ctr + 1.U }
//
// val doing_clear_u = clear_u_ctr(log2Ceil(uBitPeriod)-1,0) === 0.U
// val doing_clear_u_hi = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 1.U
// val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U
// val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod)
//
// val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc, io.update.hist)
//
// val update_wdata = Wire(Vec(BankWidth, new TageEntry))
//
//
// (0 until BankWidth).map(b => {
// table(b).io.w.req.valid := io.update.mask(b) || doing_reset
// table(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx)
// table(b).io.w.req.bits.data := Mux(doing_reset, 0.U.asTypeOf(new TageEntry), update_wdata(b))
// })
//
// val update_hi_wdata = Wire(Vec(BankWidth, Bool()))
// (0 until BankWidth).map(b => {
// hi_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_hi
// hi_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_hi, clear_u_idx, update_idx))
// hi_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_hi, 0.U, update_hi_wdata(b))
// })
//
// val update_lo_wdata = Wire(Vec(BankWidth, Bool()))
// (0 until BankWidth).map(b => {
// lo_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_lo
// lo_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_lo, clear_u_idx, update_idx))
// lo_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_lo, 0.U, update_lo_wdata(b))
// })
//
// val wrbypass_tags = Reg(Vec(wrBypassEntries, UInt(tagLen.W)))
// val wrbypass_idxs = Reg(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))
// val wrbypass = Reg(Vec(wrBypassEntries, Vec(BankWidth, UInt(3.W))))
// val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W))
//
// val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i =>
// !doing_reset &&
// wrbypass_tags(i) === update_tag &&
// wrbypass_idxs(i) === update_idx
// })
// val wrbypass_hit = wrbypass_hits.reduce(_||_)
// val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits)
//
// for (w <- 0 until BankWidth) {
// update_wdata(w).ctr := Mux(io.update.alloc(w),
// Mux(io.update.taken(w), 4.U,
// 3.U
// ),
// Mux(wrbypass_hit, inc_ctr(wrbypass(wrbypass_hit_idx)(w), io.update.taken(w)),
// inc_ctr(io.update.oldCtr(w), io.update.taken(w))
// )
// )
// update_wdata(w).valid := true.B
// update_wdata(w).tag := update_tag
//
// update_hi_wdata(w) := io.update.u(w)(1)
// update_lo_wdata(w) := io.update.u(w)(0)
// }
//
// when (io.update.mask.reduce(_||_)) {
// when (wrbypass_hits.reduce(_||_)) {
// wrbypass(wrbypass_hit_idx) := VecInit(update_wdata.map(_.ctr))
// } .otherwise {
// wrbypass (wrbypass_enq_idx) := VecInit(update_wdata.map(_.ctr))
// wrbypass_tags(wrbypass_enq_idx) := update_tag
// wrbypass_idxs(wrbypass_enq_idx) := update_idx
// wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0)
// }
// }
// XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%b, idx=%d, tag=%x\n", io.req.bits.pc, io.req.bits.hist, hashed_idx, tag)
// for (i <- 0 until BankWidth) {
// XSDebug(RegNext(io.req.valid), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, RegNext(hashed_idx), req_rhits(i), table_r(i).ctr, Cat(hi_us_r(i),lo_us_r(i)).asUInt)
// }
//
//}
//
//class FakeTAGE extends TageModule {
// val io = IO(new Bundle() {
// val req = Input(Valid(new TageReq))
// val out = new Bundle {
// val hits = Output(UInt(BankWidth.W))
// val takens = Output(Vec(BankWidth, Bool()))
// }
// val meta = Output(Vec(BankWidth, (new TageMeta)))
// val redirectInfo = Input(new RedirectInfo)
// })
//
// io.out.hits := 0.U(BankWidth.W)
// io.out.takens := DontCare
// io.meta := DontCare
//}
//
//
//class Tage extends TageModule {
// val io = IO(new Bundle() {
// val req = Input(Valid(new TageReq))
// val out = new Bundle {
// val hits = Output(UInt(BankWidth.W))
// val takens = Output(Vec(BankWidth, Bool()))
// }
// val meta = Output(Vec(BankWidth, (new TageMeta)))
// val redirectInfo = Input(new RedirectInfo)
// })
//
// val tables = TableInfo.map {
// case (nRows, histLen, tagLen) => {
// val t = if(EnableBPD) Module(new TageTable(nRows, histLen, tagLen, UBitPeriod)) else Module(new FakeTageTable)
// t.io.req <> io.req
// t
// }
// }
// val resps = VecInit(tables.map(_.io.resp))
//
// val updateMeta = io.redirectInfo.redirect.tageMeta
// //val updateMisPred = UIntToOH(io.redirectInfo.redirect.fetchIdx) &
// // Fill(BankWidth, (io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B).asUInt)
// val updateMisPred = io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B
//
// val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool()))))
// val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool()))))
// val updateTaken = Wire(Vec(TageNTables, Vec(BankWidth, Bool())))
// val updateAlloc = Wire(Vec(TageNTables, Vec(BankWidth, Bool())))
// val updateOldCtr = Wire(Vec(TageNTables, Vec(BankWidth, UInt(3.W))))
// val updateU = Wire(Vec(TageNTables, Vec(BankWidth, UInt(2.W))))
// updateTaken := DontCare
// updateAlloc := DontCare
// updateOldCtr := DontCare
// updateU := DontCare
//
// // access tag tables and output meta info
// val outHits = Wire(Vec(BankWidth, Bool()))
// for (w <- 0 until BankWidth) {
// var altPred = false.B
// val finalAltPred = WireInit(false.B)
// var provided = false.B
// var provider = 0.U
// outHits(w) := false.B
// io.out.takens(w) := false.B
//
// for (i <- 0 until TageNTables) {
// val hit = resps(i)(w).valid
// val ctr = resps(i)(w).bits.ctr
// when (hit) {
// io.out.takens(w) := Mux(ctr === 3.U || ctr === 4.U, altPred, ctr(2)) // Use altpred on weak taken
// finalAltPred := altPred
// }
// provided = provided || hit // Once hit then provide
// provider = Mux(hit, i.U, provider) // Use the last hit as provider
// altPred = Mux(hit, ctr(2), altPred) // Save current pred as potential altpred
// }
// outHits(w) := provided
// io.meta(w).provider.valid := provided
// io.meta(w).provider.bits := provider
// io.meta(w).altDiffers := finalAltPred =/= io.out.takens(w)
// io.meta(w).providerU := resps(provider)(w).bits.u
// io.meta(w).providerCtr := resps(provider)(w).bits.ctr
//
// // Create a mask fo tables which did not hit our query, and also contain useless entries
// // and also uses a longer history than the provider
// val allocatableSlots = (VecInit(resps.map(r => !r(w).valid && r(w).bits.u === 0.U)).asUInt &
// ~(LowerMask(UIntToOH(provider), TageNTables) & Fill(TageNTables, provided.asUInt))
// )
// val allocLFSR = LFSR64()(TageNTables - 1, 0)
// val firstEntry = PriorityEncoder(allocatableSlots)
// val maskedEntry = PriorityEncoder(allocatableSlots & allocLFSR)
// val allocEntry = Mux(allocatableSlots(maskedEntry), maskedEntry, firstEntry)
// io.meta(w).allocate.valid := allocatableSlots =/= 0.U
// io.meta(w).allocate.bits := allocEntry
//
// val isUpdateTaken = io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U &&
// io.redirectInfo.redirect.taken && io.redirectInfo.redirect.btbType === BTBtype.B
// when (io.redirectInfo.redirect.btbType === BTBtype.B && io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U) {
// when (updateMeta.provider.valid) {
// val provider = updateMeta.provider.bits
//
// updateMask(provider)(w) := true.B
// updateUMask(provider)(w) := true.B
//
// updateU(provider)(w) := Mux(!updateMeta.altDiffers, updateMeta.providerU,
// Mux(updateMisPred, Mux(updateMeta.providerU === 0.U, 0.U, updateMeta.providerU - 1.U),
// Mux(updateMeta.providerU === 3.U, 3.U, updateMeta.providerU + 1.U))
// )
// updateTaken(provider)(w) := isUpdateTaken
// updateOldCtr(provider)(w) := updateMeta.providerCtr
// updateAlloc(provider)(w) := false.B
// }
// }
// }
//
// when (io.redirectInfo.valid && updateMisPred) {
// val idx = io.redirectInfo.redirect.fetchIdx
// val allocate = updateMeta.allocate
// when (allocate.valid) {
// updateMask(allocate.bits)(idx) := true.B
// updateTaken(allocate.bits)(idx) := io.redirectInfo.redirect.taken
// updateAlloc(allocate.bits)(idx) := true.B
// updateUMask(allocate.bits)(idx) := true.B
// updateU(allocate.bits)(idx) := 0.U
// }.otherwise {
// val provider = updateMeta.provider
// val decrMask = Mux(provider.valid, ~LowerMask(UIntToOH(provider.bits), TageNTables), 0.U)
// for (i <- 0 until TageNTables) {
// when (decrMask(i)) {
// updateUMask(i)(idx) := true.B
// updateU(i)(idx) := 0.U
// }
// }
// }
// }
//
// for (i <- 0 until TageNTables) {
// for (w <- 0 until BankWidth) {
// tables(i).io.update.mask(w) := updateMask(i)(w)
// tables(i).io.update.taken(w) := updateTaken(i)(w)
// tables(i).io.update.alloc(w) := updateAlloc(i)(w)
// tables(i).io.update.oldCtr(w) := updateOldCtr(i)(w)
//
// tables(i).io.update.uMask(w) := updateUMask(i)(w)
// tables(i).io.update.u(w) := updateU(i)(w)
// }
// // use fetch pc instead of instruction pc
// tables(i).io.update.pc := io.redirectInfo.redirect.pc - (io.redirectInfo.redirect.fetchIdx << 1.U)
// tables(i).io.update.hist := io.redirectInfo.redirect.hist
// }
//
// io.out.hits := outHits.asUInt
//
//
// val m = updateMeta
// XSDebug(io.req.valid, "req: pc=0x%x, hist=%b\n", io.req.bits.pc, io.req.bits.hist)
// XSDebug(io.redirectInfo.valid, "redirect: provider(%d):%d, altDiffers:%d, providerU:%d, providerCtr:%d, allocate(%d):%d\n", m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits)
// XSDebug(RegNext(io.req.valid), "resp: pc=%x, outHits=%b, takens=%b\n", RegNext(io.req.bits.pc), io.out.hits, io.out.takens.asUInt)
//}
\ No newline at end of file
package xiangshan.frontend
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import scala.math.min
trait HasTageParameter {
// Sets Hist Tag
val TableInfo = Seq(( 128, 2, 7),
( 128, 4, 7),
( 256, 8, 8),
( 256, 16, 8),
( 128, 32, 9),
( 128, 64, 9))
val TageNTables = TableInfo.size
val UBitPeriod = 2048
val BankWidth = 16 // FetchWidth
val TotalBits = TableInfo.map {
case (s, h, t) => {
s * (1+t+3) * BankWidth
}
}.reduce(_+_)
}
abstract class TageBundle extends XSBundle with HasTageParameter
abstract class TageModule extends XSModule with HasTageParameter
class TageReq extends TageBundle {
val pc = UInt(VAddrBits.W)
val hist = UInt(HistoryLength.W)
}
class TageResp extends TageBundle {
val ctr = UInt(3.W)
val u = UInt(2.W)
}
class TageUpdate extends TageBundle {
val pc = UInt(VAddrBits.W)
val hist = UInt(HistoryLength.W)
// update tag and ctr
val mask = Vec(BankWidth, Bool())
val taken = Vec(BankWidth, Bool())
val alloc = Vec(BankWidth, Bool())
val oldCtr = Vec(BankWidth, UInt(3.W))
// update u
val uMask = Vec(BankWidth, Bool())
val u = Vec(BankWidth, UInt(2.W))
}
class FakeTageTable() extends TageModule {
val io = IO(new Bundle() {
val req = Input(Valid(new TageReq))
val resp = Output(Vec(BankWidth, Valid(new TageResp)))
val update = Input(new TageUpdate)
})
io.resp := DontCare
}
class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule {
val io = IO(new Bundle() {
val req = Input(Valid(new TageReq))
val resp = Output(Vec(BankWidth, Valid(new TageResp)))
val update = Input(new TageUpdate)
})
// bypass entries for tage update
val wrBypassEntries = 8
def compute_folded_hist(hist: UInt, l: Int) = {
val nChunks = (histLen + l - 1) / l
val hist_chunks = (0 until nChunks) map {i =>
hist(min((i+1)*l, histLen)-1, i*l)
}
hist_chunks.reduce(_^_)
}
def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt) = {
val idx_history = compute_folded_hist(hist, log2Ceil(nRows))
val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows)-1,0)
val tag_history = compute_folded_hist(hist, tagLen)
// Use another part of pc to make tags
val tag = ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history)(tagLen-1,0)
(idx, tag)
}
def inc_ctr(ctr: UInt, taken: Bool): UInt = {
Mux(!taken, Mux(ctr === 0.U, 0.U, ctr - 1.U),
Mux(ctr === 7.U, 7.U, ctr + 1.U))
}
val doing_reset = RegInit(true.B)
val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
reset_idx := reset_idx + doing_reset
when (reset_idx === (nRows-1).U) { doing_reset := false.B }
class TageEntry() extends TageBundle {
val valid = Bool()
val tag = UInt(tagLen.W)
val ctr = UInt(3.W)
}
val tageEntrySz = 1 + tagLen + 3
val (hashed_idx, tag) = compute_tag_and_hash(io.req.bits.pc, io.req.bits.hist)
val hi_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val lo_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val table = List.fill(BankWidth)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val hi_us_r = Wire(Vec(BankWidth, Bool()))
val lo_us_r = Wire(Vec(BankWidth, Bool()))
val table_r = Wire(Vec(BankWidth, new TageEntry))
(0 until BankWidth).map(
b => {
hi_us(b).reset := reset.asBool
lo_us(b).reset := reset.asBool
table(b).reset := reset.asBool
hi_us(b).io.r.req.valid := io.req.valid
lo_us(b).io.r.req.valid := io.req.valid
table(b).io.r.req.valid := io.req.valid
hi_us(b).io.r.req.bits.setIdx := hashed_idx
lo_us(b).io.r.req.bits.setIdx := hashed_idx
table(b).io.r.req.bits.setIdx := hashed_idx
hi_us_r(b) := hi_us(b).io.r.resp.data(0)
lo_us_r(b) := lo_us(b).io.r.resp.data(0)
table_r(b) := table(b).io.r.resp.data(0)
// io.resp(b).valid := table_r(b).valid && table_r(b).tag === tag // Missing reset logic
// io.resp(b).bits.ctr := table_r(b).ctr
// io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b))
}
)
val req_rhits = VecInit(table_r.map(e => e.valid && e.tag === tag && !doing_reset))
(0 until BankWidth).map(b => {
io.resp(b).valid := req_rhits(b)
io.resp(b).bits.ctr := table_r(b).ctr
io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b))
})
val clear_u_ctr = RegInit(0.U((log2Ceil(uBitPeriod) + log2Ceil(nRows) + 1).W))
when (doing_reset) { clear_u_ctr := 1.U } .otherwise { clear_u_ctr := clear_u_ctr + 1.U }
val doing_clear_u = clear_u_ctr(log2Ceil(uBitPeriod)-1,0) === 0.U
val doing_clear_u_hi = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 1.U
val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U
val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod)
val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc, io.update.hist)
val update_wdata = Wire(Vec(BankWidth, new TageEntry))
(0 until BankWidth).map(b => {
table(b).io.w.req.valid := io.update.mask(b) || doing_reset
table(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx)
table(b).io.w.req.bits.data := Mux(doing_reset, 0.U.asTypeOf(new TageEntry), update_wdata(b))
})
val update_hi_wdata = Wire(Vec(BankWidth, Bool()))
(0 until BankWidth).map(b => {
hi_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_hi
hi_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_hi, clear_u_idx, update_idx))
hi_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_hi, 0.U, update_hi_wdata(b))
})
val update_lo_wdata = Wire(Vec(BankWidth, Bool()))
(0 until BankWidth).map(b => {
lo_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_lo
lo_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_lo, clear_u_idx, update_idx))
lo_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_lo, 0.U, update_lo_wdata(b))
})
val wrbypass_tags = Reg(Vec(wrBypassEntries, UInt(tagLen.W)))
val wrbypass_idxs = Reg(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))
val wrbypass = Reg(Vec(wrBypassEntries, Vec(BankWidth, UInt(3.W))))
val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W))
val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i =>
!doing_reset &&
wrbypass_tags(i) === update_tag &&
wrbypass_idxs(i) === update_idx
})
val wrbypass_hit = wrbypass_hits.reduce(_||_)
val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits)
for (w <- 0 until BankWidth) {
update_wdata(w).ctr := Mux(io.update.alloc(w),
Mux(io.update.taken(w), 4.U,
3.U
),
Mux(wrbypass_hit, inc_ctr(wrbypass(wrbypass_hit_idx)(w), io.update.taken(w)),
inc_ctr(io.update.oldCtr(w), io.update.taken(w))
)
)
update_wdata(w).valid := true.B
update_wdata(w).tag := update_tag
update_hi_wdata(w) := io.update.u(w)(1)
update_lo_wdata(w) := io.update.u(w)(0)
}
when (io.update.mask.reduce(_||_)) {
when (wrbypass_hits.reduce(_||_)) {
wrbypass(wrbypass_hit_idx) := VecInit(update_wdata.map(_.ctr))
} .otherwise {
wrbypass (wrbypass_enq_idx) := VecInit(update_wdata.map(_.ctr))
wrbypass_tags(wrbypass_enq_idx) := update_tag
wrbypass_idxs(wrbypass_enq_idx) := update_idx
wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0)
}
}
XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%b, idx=%d, tag=%x\n", io.req.bits.pc, io.req.bits.hist, hashed_idx, tag)
for (i <- 0 until BankWidth) {
XSDebug(RegNext(io.req.valid), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, RegNext(hashed_idx), req_rhits(i), table_r(i).ctr, Cat(hi_us_r(i),lo_us_r(i)).asUInt)
}
}
class FakeTAGE extends TageModule {
val io = IO(new Bundle() {
val req = Input(Valid(new TageReq))
val out = new Bundle {
val hits = Output(UInt(BankWidth.W))
val takens = Output(Vec(BankWidth, Bool()))
}
val meta = Output(Vec(BankWidth, (new TageMeta)))
val redirectInfo = Input(new RedirectInfo)
})
io.out.hits := 0.U(BankWidth.W)
io.out.takens := DontCare
io.meta := DontCare
}
class Tage extends TageModule {
val io = IO(new Bundle() {
val req = Input(Valid(new TageReq))
val out = new Bundle {
val hits = Output(UInt(BankWidth.W))
val takens = Output(Vec(BankWidth, Bool()))
}
val meta = Output(Vec(BankWidth, (new TageMeta)))
val redirectInfo = Input(new RedirectInfo)
})
val tables = TableInfo.map {
case (nRows, histLen, tagLen) => {
val t = if(EnableBPD) Module(new TageTable(nRows, histLen, tagLen, UBitPeriod)) else Module(new FakeTageTable)
t.io.req <> io.req
t
}
}
val resps = VecInit(tables.map(_.io.resp))
val updateMeta = io.redirectInfo.redirect.tageMeta
//val updateMisPred = UIntToOH(io.redirectInfo.redirect.fetchIdx) &
// Fill(BankWidth, (io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B).asUInt)
val updateMisPred = io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B
val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool()))))
val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool()))))
val updateTaken = Wire(Vec(TageNTables, Vec(BankWidth, Bool())))
val updateAlloc = Wire(Vec(TageNTables, Vec(BankWidth, Bool())))
val updateOldCtr = Wire(Vec(TageNTables, Vec(BankWidth, UInt(3.W))))
val updateU = Wire(Vec(TageNTables, Vec(BankWidth, UInt(2.W))))
updateTaken := DontCare
updateAlloc := DontCare
updateOldCtr := DontCare
updateU := DontCare
// access tag tables and output meta info
val outHits = Wire(Vec(BankWidth, Bool()))
for (w <- 0 until BankWidth) {
var altPred = false.B
val finalAltPred = WireInit(false.B)
var provided = false.B
var provider = 0.U
outHits(w) := false.B
io.out.takens(w) := false.B
for (i <- 0 until TageNTables) {
val hit = resps(i)(w).valid
val ctr = resps(i)(w).bits.ctr
when (hit) {
io.out.takens(w) := Mux(ctr === 3.U || ctr === 4.U, altPred, ctr(2)) // Use altpred on weak taken
finalAltPred := altPred
}
provided = provided || hit // Once hit then provide
provider = Mux(hit, i.U, provider) // Use the last hit as provider
altPred = Mux(hit, ctr(2), altPred) // Save current pred as potential altpred
}
outHits(w) := provided
io.meta(w).provider.valid := provided
io.meta(w).provider.bits := provider
io.meta(w).altDiffers := finalAltPred =/= io.out.takens(w)
io.meta(w).providerU := resps(provider)(w).bits.u
io.meta(w).providerCtr := resps(provider)(w).bits.ctr
// Create a mask fo tables which did not hit our query, and also contain useless entries
// and also uses a longer history than the provider
val allocatableSlots = (VecInit(resps.map(r => !r(w).valid && r(w).bits.u === 0.U)).asUInt &
~(LowerMask(UIntToOH(provider), TageNTables) & Fill(TageNTables, provided.asUInt))
)
val allocLFSR = LFSR64()(TageNTables - 1, 0)
val firstEntry = PriorityEncoder(allocatableSlots)
val maskedEntry = PriorityEncoder(allocatableSlots & allocLFSR)
val allocEntry = Mux(allocatableSlots(maskedEntry), maskedEntry, firstEntry)
io.meta(w).allocate.valid := allocatableSlots =/= 0.U
io.meta(w).allocate.bits := allocEntry
val isUpdateTaken = io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U &&
io.redirectInfo.redirect.taken && io.redirectInfo.redirect.btbType === BTBtype.B
when (io.redirectInfo.redirect.btbType === BTBtype.B && io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U) {
when (updateMeta.provider.valid) {
val provider = updateMeta.provider.bits
updateMask(provider)(w) := true.B
updateUMask(provider)(w) := true.B
updateU(provider)(w) := Mux(!updateMeta.altDiffers, updateMeta.providerU,
Mux(updateMisPred, Mux(updateMeta.providerU === 0.U, 0.U, updateMeta.providerU - 1.U),
Mux(updateMeta.providerU === 3.U, 3.U, updateMeta.providerU + 1.U))
)
updateTaken(provider)(w) := isUpdateTaken
updateOldCtr(provider)(w) := updateMeta.providerCtr
updateAlloc(provider)(w) := false.B
}
}
}
when (io.redirectInfo.valid && updateMisPred) {
val idx = io.redirectInfo.redirect.fetchIdx
val allocate = updateMeta.allocate
when (allocate.valid) {
updateMask(allocate.bits)(idx) := true.B
updateTaken(allocate.bits)(idx) := io.redirectInfo.redirect.taken
updateAlloc(allocate.bits)(idx) := true.B
updateUMask(allocate.bits)(idx) := true.B
updateU(allocate.bits)(idx) := 0.U
}.otherwise {
val provider = updateMeta.provider
val decrMask = Mux(provider.valid, ~LowerMask(UIntToOH(provider.bits), TageNTables), 0.U)
for (i <- 0 until TageNTables) {
when (decrMask(i)) {
updateUMask(i)(idx) := true.B
updateU(i)(idx) := 0.U
}
}
}
}
for (i <- 0 until TageNTables) {
for (w <- 0 until BankWidth) {
tables(i).io.update.mask(w) := updateMask(i)(w)
tables(i).io.update.taken(w) := updateTaken(i)(w)
tables(i).io.update.alloc(w) := updateAlloc(i)(w)
tables(i).io.update.oldCtr(w) := updateOldCtr(i)(w)
tables(i).io.update.uMask(w) := updateUMask(i)(w)
tables(i).io.update.u(w) := updateU(i)(w)
}
// use fetch pc instead of instruction pc
tables(i).io.update.pc := io.redirectInfo.redirect.pc - (io.redirectInfo.redirect.fetchIdx << 1.U)
tables(i).io.update.hist := io.redirectInfo.redirect.hist
}
io.out.hits := outHits.asUInt
val m = updateMeta
XSDebug(io.req.valid, "req: pc=0x%x, hist=%b\n", io.req.bits.pc, io.req.bits.hist)
XSDebug(io.redirectInfo.valid, "redirect: provider(%d):%d, altDiffers:%d, providerU:%d, providerCtr:%d, allocate(%d):%d\n", m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits)
XSDebug(RegNext(io.req.valid), "resp: pc=%x, outHits=%b, takens=%b\n", RegNext(io.req.bits.pc), io.out.hits, io.out.takens.asUInt)
}
\ No newline at end of file
//package xiangshan.frontend
//
//import chisel3._
//import chisel3.util._
//import xiangshan._
//import xiangshan.backend.ALUOpType
//import utils._
//import chisel3.util.experimental.BoringUtils
//import xiangshan.backend.decode.XSTrap
//
//class BTBUpdateBundle extends XSBundle {
// val pc = UInt(VAddrBits.W)
// val hit = Bool()
// val misPred = Bool()
// val oldCtr = UInt(2.W)
// val taken = Bool()
// val target = UInt(VAddrBits.W)
// val btbType = UInt(2.W)
// val isRVC = Bool()
//}
//
//class BTBPred extends XSBundle {
// val taken = Bool()
// val takenIdx = UInt(log2Up(PredictWidth).W)
// val target = UInt(VAddrBits.W)
//
// val notTakens = Vec(PredictWidth, Bool())
// val dEntries = Vec(PredictWidth, btbDataEntry())
// val hits = Vec(PredictWidth, Bool())
//
// // whether an RVI instruction crosses over two fetch packet
// val isRVILateJump = Bool()
//}
//
//case class btbDataEntry() extends XSBundle {
// val target = UInt(VAddrBits.W)
// val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor
// val btbType = UInt(2.W)
// val isRVC = Bool()
//}
//
//case class btbMetaEntry() extends XSBundle {
// val valid = Bool()
// // TODO: don't need full length of tag
// val tag = UInt((VAddrBits - log2Up(BtbSize) - 1).W)
//}
//
//class BTB extends XSModule {
// val io = IO(new Bundle() {
// // Input
// val in = new Bundle {
// val pc = Flipped(Decoupled(UInt(VAddrBits.W)))
// val pcLatch = Input(UInt(VAddrBits.W))
// val mask = Input(UInt(PredictWidth.W))
// }
// val redirectValid = Input(Bool())
// val flush = Input(Bool())
// val update = Input(new BTBUpdateBundle)
// // Output
// val out = Output(new BTBPred)
// })
//
// io.in.pc.ready := true.B
// val fireLatch = RegNext(io.in.pc.fire())
// val maskLatch = RegEnable(io.in.mask, io.in.pc.fire())
//
// val btbAddr = new TableAddr(log2Up(BtbSize), BtbBanks)
//
// // SRAMs to store BTB meta & data
// val btbMeta = List.fill(BtbBanks)(
// Module(new SRAMTemplate(btbMetaEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true)))
// val btbData = List.fill(BtbBanks)(
// Module(new SRAMTemplate(btbDataEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true)))
//
// // BTB read requests
// val baseBank = btbAddr.getBank(io.in.pc.bits)
// // circular shifting
// def circularShiftLeft(source: UInt, len: Int, shamt: UInt): UInt = {
// val res = Wire(UInt(len.W))
// val higher = source << shamt
// val lower = source >> (len.U - shamt)
// res := higher | lower
// res
// }
// val realMask = circularShiftLeft(io.in.mask, BtbBanks, baseBank)
//
// // those banks whose indexes are less than baseBank are in the next row
// val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
//
// val baseRow = btbAddr.getBankIdx(io.in.pc.bits)
// // this row is the last row of a bank
// val nextRowStartsUp = baseRow.andR
// val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b.U), Mux(nextRowStartsUp, 0.U, baseRow+1.U), baseRow)))
// val realRowLatch = VecInit(realRow.map(RegNext(_)))
//
// for (b <- 0 until BtbBanks) {
// btbMeta(b).reset := reset.asBool
// btbMeta(b).io.r.req.valid := realMask(b) && io.in.pc.valid
// btbMeta(b).io.r.req.bits.setIdx := realRow(b)
// btbData(b).reset := reset.asBool
// btbData(b).io.r.req.valid := realMask(b) && io.in.pc.valid
// btbData(b).io.r.req.bits.setIdx := realRow(b)
// }
//
//
// // Entries read from SRAM
// val metaRead = Wire(Vec(BtbBanks, btbMetaEntry()))
// val dataRead = Wire(Vec(BtbBanks, btbDataEntry()))
// val readFire = Wire(Vec(BtbBanks, Bool()))
// for (b <- 0 until BtbBanks) {
// readFire(b) := btbMeta(b).io.r.req.fire() && btbData(b).io.r.req.fire()
// metaRead(b) := btbMeta(b).io.r.resp.data(0)
// dataRead(b) := btbData(b).io.r.resp.data(0)
// }
//
// val baseBankLatch = btbAddr.getBank(io.in.pcLatch)
// // val isAlignedLatch = baseBankLatch === 0.U
// val baseTag = btbAddr.getTag(io.in.pcLatch)
// // If the next row starts up, the tag needs to be incremented as well
// val tagIncremented = VecInit((0 until BtbBanks).map(b => RegEnable(isInNextRow(b.U) && nextRowStartsUp, io.in.pc.valid)))
//
// val bankHits = Wire(Vec(BtbBanks, Bool()))
// for (b <- 0 until BtbBanks) {
// bankHits(b) := metaRead(b).valid &&
// (Mux(tagIncremented(b), baseTag+1.U, baseTag) === metaRead(b).tag) && !io.flush && RegNext(readFire(b), init = false.B)
// }
//
// // taken branches of jumps from a valid entry
// val predTakens = Wire(Vec(BtbBanks, Bool()))
// // not taken branches from a valid entry
// val notTakenBranches = Wire(Vec(BtbBanks, Bool()))
// for (b <- 0 until BtbBanks) {
// predTakens(b) := bankHits(b) && (dataRead(b).btbType === BTBtype.J || dataRead(b).btbType === BTBtype.B && dataRead(b).pred(1).asBool)
// notTakenBranches(b) := bankHits(b) && dataRead(b).btbType === BTBtype.B && !dataRead(b).pred(1).asBool
// }
//
// // e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
// val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch + b.U) % BtbBanks.U))
//
// val isTaken = predTakens.reduce(_||_)
// // Priority mux which corresponds with inst orders
// // BTB only produce one single prediction
// val takenTarget = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).target)))
// val takenType = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).btbType)))
// // Record which inst is predicted taken
// val takenIdx = MuxCase(0.U, (0 until BtbBanks).map(b => (predTakens(bankIdxInOrder(b)), b.U)))
//
// // Update logic
// // 1 calculate new 2-bit saturated counter value
// def satUpdate(old: UInt, len: Int, taken: Bool): UInt = {
// val oldSatTaken = old === ((1 << len)-1).U
// val oldSatNotTaken = old === 0.U
// Mux(oldSatTaken && taken, ((1 << len)-1).U,
// Mux(oldSatNotTaken && !taken, 0.U,
// Mux(taken, old + 1.U, old - 1.U)))
// }
//
// val u = io.update
// val newCtr = Mux(!u.hit, "b10".U, satUpdate(u.oldCtr, 2, u.taken))
//
// val updateOnSaturated = u.taken && u.oldCtr === "b11".U || !u.taken && u.oldCtr === "b00".U
//
// // 2 write btb
// val updateBankIdx = btbAddr.getBank(u.pc)
// val updateRow = btbAddr.getBankIdx(u.pc)
// val btbMetaWrite = Wire(btbMetaEntry())
// btbMetaWrite.valid := true.B
// btbMetaWrite.tag := btbAddr.getTag(u.pc)
// val btbDataWrite = Wire(btbDataEntry())
// btbDataWrite.target := u.target
// btbDataWrite.pred := newCtr
// btbDataWrite.btbType := u.btbType
// btbDataWrite.isRVC := u.isRVC
//
// val isBr = u.btbType === BTBtype.B
// val isJ = u.btbType === BTBtype.J
// val notBrOrJ = u.btbType =/= BTBtype.B && u.btbType =/= BTBtype.J
//
// // Do not update BTB on indirect or return, or correctly predicted J or saturated counters
// val noNeedToUpdate = (!u.misPred && (isBr && updateOnSaturated || isJ)) || notBrOrJ
//
// // do not update on saturated ctrs
// val btbWriteValid = io.redirectValid && !noNeedToUpdate
//
// for (b <- 0 until BtbBanks) {
// btbMeta(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx
// btbMeta(b).io.w.req.bits.setIdx := updateRow
// btbMeta(b).io.w.req.bits.data := btbMetaWrite
// btbData(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx
// btbData(b).io.w.req.bits.setIdx := updateRow
// btbData(b).io.w.req.bits.data := btbDataWrite
// }
//
// // io.out.hit := bankHits.reduce(_||_)
// io.out.taken := isTaken
// io.out.takenIdx := takenIdx
// io.out.target := takenTarget
// // io.out.writeWay := writeWay
// io.out.notTakens := VecInit((0 until BtbBanks).map(b => notTakenBranches(bankIdxInOrder(b))))
// io.out.dEntries := VecInit((0 until BtbBanks).map(b => dataRead(bankIdxInOrder(b))))
// io.out.hits := VecInit((0 until BtbBanks).map(b => bankHits(bankIdxInOrder(b))))
// io.out.isRVILateJump := io.out.taken && takenIdx === OHToUInt(HighestBit(maskLatch, PredictWidth)) && !dataRead(bankIdxInOrder(takenIdx)).isRVC
//
// // read-after-write bypass
// val rawBypassHit = Wire(Vec(BtbBanks, Bool()))
// for (b <- 0 until BtbBanks) {
// when (b.U === updateBankIdx && realRow(b) === updateRow) { // read and write to the same address
// when (realMask(b) && io.in.pc.valid && btbWriteValid) { // both read and write valid
// rawBypassHit(b) := true.B
// btbMeta(b).io.r.req.valid := false.B
// btbData(b).io.r.req.valid := false.B
// // metaRead(b) := RegNext(btbMetaWrite)
// // dataRead(b) := RegNext(btbDataWrite)
// readFire(b) := true.B
// XSDebug("raw bypass hits: bank=%d, row=%d, meta: %d %x, data: tgt=%x pred=%b btbType=%b isRVC=%d\n",
// b.U, updateRow,
// btbMetaWrite.valid, btbMetaWrite.tag,
// btbDataWrite.target, btbDataWrite.pred, btbDataWrite.btbType, btbDataWrite.isRVC)
// }.otherwise {
// rawBypassHit(b) := false.B
// }
// }.otherwise {
// rawBypassHit(b) := false.B
// }
//
// when (RegNext(rawBypassHit(b))) {
// metaRead(b) := RegNext(btbMetaWrite)
// dataRead(b) := RegNext(btbDataWrite)
// }
// }
//
// XSDebug(io.in.pc.fire(), "read: pc=0x%x, baseBank=%d, realMask=%b\n", io.in.pc.bits, baseBank, realMask)
// XSDebug(fireLatch, "read_resp: pc=0x%x, readIdx=%d-------------------------------\n",
// io.in.pcLatch, btbAddr.getIdx(io.in.pcLatch))
// for (i <- 0 until BtbBanks){
// XSDebug(fireLatch, "read_resp[b=%d][r=%d]: valid=%d, tag=0x%x, target=0x%x, type=%d, ctr=%d\n",
// i.U, realRowLatch(i), metaRead(i).valid, metaRead(i).tag, dataRead(i).target, dataRead(i).btbType, dataRead(i).pred)
// }
// XSDebug("out: taken=%d takenIdx=%d tgt=%x notTakens=%b hits=%b isRVILateJump=%d\n",
// io.out.taken, io.out.takenIdx, io.out.target, io.out.notTakens.asUInt, io.out.hits.asUInt, io.out.isRVILateJump)
// XSDebug(fireLatch, "bankIdxInOrder:")
// for (i <- 0 until BtbBanks){ XSDebug(fireLatch, "%d ", bankIdxInOrder(i))}
// XSDebug(fireLatch, "\n")
// XSDebug(io.redirectValid, "update_req: pc=0x%x, hit=%d, misPred=%d, oldCtr=%d, taken=%d, target=0x%x, btbType=%d\n",
// u.pc, u.hit, u.misPred, u.oldCtr, u.taken, u.target, u.btbType)
// XSDebug(io.redirectValid, "update: noNeedToUpdate=%d, writeValid=%d, bank=%d, row=%d, newCtr=%d\n",
// noNeedToUpdate, btbWriteValid, updateBankIdx, updateRow, newCtr)
//}
\ No newline at end of file
package xiangshan.frontend
import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._
import chisel3.util.experimental.BoringUtils
import xiangshan.backend.decode.XSTrap
class BTBUpdateBundle extends XSBundle {
val pc = UInt(VAddrBits.W)
val hit = Bool()
val misPred = Bool()
val oldCtr = UInt(2.W)
val taken = Bool()
val target = UInt(VAddrBits.W)
val btbType = UInt(2.W)
val isRVC = Bool()
}
class BTBPred extends XSBundle {
val taken = Bool()
val takenIdx = UInt(log2Up(PredictWidth).W)
val target = UInt(VAddrBits.W)
val notTakens = Vec(PredictWidth, Bool())
val dEntries = Vec(PredictWidth, btbDataEntry())
val hits = Vec(PredictWidth, Bool())
// whether an RVI instruction crosses over two fetch packet
val isRVILateJump = Bool()
}
case class btbDataEntry() extends XSBundle {
val target = UInt(VAddrBits.W)
val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor
val btbType = UInt(2.W)
val isRVC = Bool()
}
case class btbMetaEntry() extends XSBundle {
val valid = Bool()
// TODO: don't need full length of tag
val tag = UInt((VAddrBits - log2Up(BtbSize) - 1).W)
}
class BTB extends XSModule {
val io = IO(new Bundle() {
// Input
val in = new Bundle {
val pc = Flipped(Decoupled(UInt(VAddrBits.W)))
val pcLatch = Input(UInt(VAddrBits.W))
val mask = Input(UInt(PredictWidth.W))
}
val redirectValid = Input(Bool())
val flush = Input(Bool())
val update = Input(new BTBUpdateBundle)
// Output
val out = Output(new BTBPred)
})
io.in.pc.ready := true.B
val fireLatch = RegNext(io.in.pc.fire())
val maskLatch = RegEnable(io.in.mask, io.in.pc.fire())
val btbAddr = new TableAddr(log2Up(BtbSize), BtbBanks)
// SRAMs to store BTB meta & data
val btbMeta = List.fill(BtbBanks)(
Module(new SRAMTemplate(btbMetaEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true)))
val btbData = List.fill(BtbBanks)(
Module(new SRAMTemplate(btbDataEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true)))
// BTB read requests
val baseBank = btbAddr.getBank(io.in.pc.bits)
// circular shifting
def circularShiftLeft(source: UInt, len: Int, shamt: UInt): UInt = {
val res = Wire(UInt(len.W))
val higher = source << shamt
val lower = source >> (len.U - shamt)
res := higher | lower
res
}
val realMask = circularShiftLeft(io.in.mask, BtbBanks, baseBank)
// those banks whose indexes are less than baseBank are in the next row
val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
val baseRow = btbAddr.getBankIdx(io.in.pc.bits)
// this row is the last row of a bank
val nextRowStartsUp = baseRow.andR
val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b.U), Mux(nextRowStartsUp, 0.U, baseRow+1.U), baseRow)))
val realRowLatch = VecInit(realRow.map(RegNext(_)))
for (b <- 0 until BtbBanks) {
btbMeta(b).reset := reset.asBool
btbMeta(b).io.r.req.valid := realMask(b) && io.in.pc.valid
btbMeta(b).io.r.req.bits.setIdx := realRow(b)
btbData(b).reset := reset.asBool
btbData(b).io.r.req.valid := realMask(b) && io.in.pc.valid
btbData(b).io.r.req.bits.setIdx := realRow(b)
}
// Entries read from SRAM
val metaRead = Wire(Vec(BtbBanks, btbMetaEntry()))
val dataRead = Wire(Vec(BtbBanks, btbDataEntry()))
val readFire = Wire(Vec(BtbBanks, Bool()))
for (b <- 0 until BtbBanks) {
readFire(b) := btbMeta(b).io.r.req.fire() && btbData(b).io.r.req.fire()
metaRead(b) := btbMeta(b).io.r.resp.data(0)
dataRead(b) := btbData(b).io.r.resp.data(0)
}
val baseBankLatch = btbAddr.getBank(io.in.pcLatch)
// val isAlignedLatch = baseBankLatch === 0.U
val baseTag = btbAddr.getTag(io.in.pcLatch)
// If the next row starts up, the tag needs to be incremented as well
val tagIncremented = VecInit((0 until BtbBanks).map(b => RegEnable(isInNextRow(b.U) && nextRowStartsUp, io.in.pc.valid)))
val bankHits = Wire(Vec(BtbBanks, Bool()))
for (b <- 0 until BtbBanks) {
bankHits(b) := metaRead(b).valid &&
(Mux(tagIncremented(b), baseTag+1.U, baseTag) === metaRead(b).tag) && !io.flush && RegNext(readFire(b), init = false.B)
}
// taken branches of jumps from a valid entry
val predTakens = Wire(Vec(BtbBanks, Bool()))
// not taken branches from a valid entry
val notTakenBranches = Wire(Vec(BtbBanks, Bool()))
for (b <- 0 until BtbBanks) {
predTakens(b) := bankHits(b) && (dataRead(b).btbType === BTBtype.J || dataRead(b).btbType === BTBtype.B && dataRead(b).pred(1).asBool)
notTakenBranches(b) := bankHits(b) && dataRead(b).btbType === BTBtype.B && !dataRead(b).pred(1).asBool
}
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch + b.U) % BtbBanks.U))
val isTaken = predTakens.reduce(_||_)
// Priority mux which corresponds with inst orders
// BTB only produce one single prediction
val takenTarget = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).target)))
val takenType = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).btbType)))
// Record which inst is predicted taken
val takenIdx = MuxCase(0.U, (0 until BtbBanks).map(b => (predTakens(bankIdxInOrder(b)), b.U)))
// Update logic
// 1 calculate new 2-bit saturated counter value
def satUpdate(old: UInt, len: Int, taken: Bool): UInt = {
val oldSatTaken = old === ((1 << len)-1).U
val oldSatNotTaken = old === 0.U
Mux(oldSatTaken && taken, ((1 << len)-1).U,
Mux(oldSatNotTaken && !taken, 0.U,
Mux(taken, old + 1.U, old - 1.U)))
}
val u = io.update
val newCtr = Mux(!u.hit, "b10".U, satUpdate(u.oldCtr, 2, u.taken))
val updateOnSaturated = u.taken && u.oldCtr === "b11".U || !u.taken && u.oldCtr === "b00".U
// 2 write btb
val updateBankIdx = btbAddr.getBank(u.pc)
val updateRow = btbAddr.getBankIdx(u.pc)
val btbMetaWrite = Wire(btbMetaEntry())
btbMetaWrite.valid := true.B
btbMetaWrite.tag := btbAddr.getTag(u.pc)
val btbDataWrite = Wire(btbDataEntry())
btbDataWrite.target := u.target
btbDataWrite.pred := newCtr
btbDataWrite.btbType := u.btbType
btbDataWrite.isRVC := u.isRVC
val isBr = u.btbType === BTBtype.B
val isJ = u.btbType === BTBtype.J
val notBrOrJ = u.btbType =/= BTBtype.B && u.btbType =/= BTBtype.J
// Do not update BTB on indirect or return, or correctly predicted J or saturated counters
val noNeedToUpdate = (!u.misPred && (isBr && updateOnSaturated || isJ)) || notBrOrJ
// do not update on saturated ctrs
val btbWriteValid = io.redirectValid && !noNeedToUpdate
for (b <- 0 until BtbBanks) {
btbMeta(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx
btbMeta(b).io.w.req.bits.setIdx := updateRow
btbMeta(b).io.w.req.bits.data := btbMetaWrite
btbData(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx
btbData(b).io.w.req.bits.setIdx := updateRow
btbData(b).io.w.req.bits.data := btbDataWrite
}
// io.out.hit := bankHits.reduce(_||_)
io.out.taken := isTaken
io.out.takenIdx := takenIdx
io.out.target := takenTarget
// io.out.writeWay := writeWay
io.out.notTakens := VecInit((0 until BtbBanks).map(b => notTakenBranches(bankIdxInOrder(b))))
io.out.dEntries := VecInit((0 until BtbBanks).map(b => dataRead(bankIdxInOrder(b))))
io.out.hits := VecInit((0 until BtbBanks).map(b => bankHits(bankIdxInOrder(b))))
io.out.isRVILateJump := io.out.taken && takenIdx === OHToUInt(HighestBit(maskLatch, PredictWidth)) && !dataRead(bankIdxInOrder(takenIdx)).isRVC
// read-after-write bypass
val rawBypassHit = Wire(Vec(BtbBanks, Bool()))
for (b <- 0 until BtbBanks) {
when (b.U === updateBankIdx && realRow(b) === updateRow) { // read and write to the same address
when (realMask(b) && io.in.pc.valid && btbWriteValid) { // both read and write valid
rawBypassHit(b) := true.B
btbMeta(b).io.r.req.valid := false.B
btbData(b).io.r.req.valid := false.B
// metaRead(b) := RegNext(btbMetaWrite)
// dataRead(b) := RegNext(btbDataWrite)
readFire(b) := true.B
XSDebug("raw bypass hits: bank=%d, row=%d, meta: %d %x, data: tgt=%x pred=%b btbType=%b isRVC=%d\n",
b.U, updateRow,
btbMetaWrite.valid, btbMetaWrite.tag,
btbDataWrite.target, btbDataWrite.pred, btbDataWrite.btbType, btbDataWrite.isRVC)
}.otherwise {
rawBypassHit(b) := false.B
}
}.otherwise {
rawBypassHit(b) := false.B
}
when (RegNext(rawBypassHit(b))) {
metaRead(b) := RegNext(btbMetaWrite)
dataRead(b) := RegNext(btbDataWrite)
}
}
XSDebug(io.in.pc.fire(), "read: pc=0x%x, baseBank=%d, realMask=%b\n", io.in.pc.bits, baseBank, realMask)
XSDebug(fireLatch, "read_resp: pc=0x%x, readIdx=%d-------------------------------\n",
io.in.pcLatch, btbAddr.getIdx(io.in.pcLatch))
for (i <- 0 until BtbBanks){
XSDebug(fireLatch, "read_resp[b=%d][r=%d]: valid=%d, tag=0x%x, target=0x%x, type=%d, ctr=%d\n",
i.U, realRowLatch(i), metaRead(i).valid, metaRead(i).tag, dataRead(i).target, dataRead(i).btbType, dataRead(i).pred)
}
XSDebug("out: taken=%d takenIdx=%d tgt=%x notTakens=%b hits=%b isRVILateJump=%d\n",
io.out.taken, io.out.takenIdx, io.out.target, io.out.notTakens.asUInt, io.out.hits.asUInt, io.out.isRVILateJump)
XSDebug(fireLatch, "bankIdxInOrder:")
for (i <- 0 until BtbBanks){ XSDebug(fireLatch, "%d ", bankIdxInOrder(i))}
XSDebug(fireLatch, "\n")
XSDebug(io.redirectValid, "update_req: pc=0x%x, hit=%d, misPred=%d, oldCtr=%d, taken=%d, target=0x%x, btbType=%d\n",
u.pc, u.hit, u.misPred, u.oldCtr, u.taken, u.target, u.btbType)
XSDebug(io.redirectValid, "update: noNeedToUpdate=%d, writeValid=%d, bank=%d, row=%d, newCtr=%d\n",
noNeedToUpdate, btbWriteValid, updateBankIdx, updateRow, newCtr)
}
\ No newline at end of file
//package xiangshan.frontend
//
//import chisel3._
//import chisel3.util._
//import xiangshan._
//import utils._
//import xiangshan.backend.ALUOpType
//
//
//class JBTACUpdateBundle extends XSBundle {
// val fetchPC = UInt(VAddrBits.W)
// val fetchIdx = UInt(log2Up(PredictWidth).W)
// val hist = UInt(HistoryLength.W)
// val target = UInt(VAddrBits.W)
// val btbType = UInt(2.W)
// val misPred = Bool()
// val isRVC = Bool()
//}
//
//class JBTACPred extends XSBundle {
// val hit = Bool()
// val target = UInt(VAddrBits.W)
// val hitIdx = UInt(log2Up(PredictWidth).W)
// val isRVILateJump = Bool()
// val isRVC = Bool()
//}
//
//class JBTAC extends XSModule {
// val io = IO(new Bundle {
// val in = new Bundle {
// val pc = Flipped(Decoupled(UInt(VAddrBits.W)))
// val pcLatch = Input(UInt(VAddrBits.W))
// val mask = Input(UInt(PredictWidth.W))
// val hist = Input(UInt(HistoryLength.W))
// }
// val redirectValid = Input(Bool())
// val flush = Input(Bool())
// val update = Input(new JBTACUpdateBundle)
//
// val out = Output(new JBTACPred)
// })
//
// io.in.pc.ready := true.B
//
// val fireLatch = RegNext(io.in.pc.fire())
//
// // JBTAC, divided into 8 banks, makes prediction for indirect jump except ret.
// val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks)
// def jbtacEntry() = new Bundle {
// val valid = Bool()
// // TODO: don't need full length of tag and target
// val tag = UInt(jbtacAddr.tagBits.W + jbtacAddr.idxBits.W)
// val target = UInt(VAddrBits.W)
// val offset = UInt(log2Up(PredictWidth).W)
// val isRVC = Bool()
// }
//
// val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = false)))
//
// val readEntries = Wire(Vec(JbtacBanks, jbtacEntry()))
//
// val readFire = Reg(Vec(JbtacBanks, Bool()))
// // Only read one bank
// val histXORAddr = io.in.pc.bits ^ Cat(io.in.hist, 0.U(1.W))(VAddrBits - 1, 0)
// val histXORAddrLatch = RegEnable(histXORAddr, io.in.pc.valid)
//
// val readBank = jbtacAddr.getBank(histXORAddr)
// val readRow = jbtacAddr.getBankIdx(histXORAddr)
// readFire := 0.U.asTypeOf(Vec(JbtacBanks, Bool()))
// (0 until JbtacBanks).map(
// b => {
// jbtac(b).reset := reset.asBool
// jbtac(b).io.r.req.valid := io.in.pc.fire() && b.U === readBank
// jbtac(b).io.r.req.bits.setIdx := readRow
// readFire(b) := jbtac(b).io.r.req.fire()
// readEntries(b) := jbtac(b).io.r.resp.data(0)
// }
// )
//
// val readBankLatch = jbtacAddr.getBank(histXORAddrLatch)
// val readRowLatch = jbtacAddr.getBankIdx(histXORAddrLatch)
// val readMaskLatch = RegEnable(io.in.mask, io.in.pc.fire())
//
// val outHit = readEntries(readBankLatch).valid &&
// readEntries(readBankLatch).tag === Cat(jbtacAddr.getTag(io.in.pcLatch), jbtacAddr.getIdx(io.in.pcLatch)) &&
// !io.flush && RegNext(readFire(readBankLatch)) && readMaskLatch(readEntries(readBankLatch).offset).asBool
//
// io.out.hit := outHit
// io.out.hitIdx := readEntries(readBankLatch).offset
// io.out.target := readEntries(readBankLatch).target
// io.out.isRVILateJump := io.out.hit && io.out.hitIdx === OHToUInt(HighestBit(readMaskLatch, PredictWidth)) && !readEntries(readBankLatch).isRVC
// io.out.isRVC := readEntries(readBankLatch).isRVC
//
// // update jbtac
// val writeEntry = Wire(jbtacEntry())
// // val updateHistXORAddr = updatefetchPC ^ Cat(r.hist, 0.U(2.W))(VAddrBits - 1, 0)
// val updateHistXORAddr = io.update.fetchPC ^ Cat(io.update.hist, 0.U(1.W))(VAddrBits - 1, 0)
// writeEntry.valid := true.B
// // writeEntry.tag := jbtacAddr.getTag(updatefetchPC)
// writeEntry.tag := Cat(jbtacAddr.getTag(io.update.fetchPC), jbtacAddr.getIdx(io.update.fetchPC))
// writeEntry.target := io.update.target
// // writeEntry.offset := updateFetchIdx
// writeEntry.offset := io.update.fetchIdx
// writeEntry.isRVC := io.update.isRVC
//
// val writeBank = jbtacAddr.getBank(updateHistXORAddr)
// val writeRow = jbtacAddr.getBankIdx(updateHistXORAddr)
// val writeValid = io.redirectValid && io.update.misPred && io.update.btbType === BTBtype.I
// for (b <- 0 until JbtacBanks) {
// when (b.U === writeBank) {
// jbtac(b).io.w.req.valid := writeValid
// jbtac(b).io.w.req.bits.setIdx := writeRow
// jbtac(b).io.w.req.bits.data := writeEntry
// }.otherwise {
// jbtac(b).io.w.req.valid := false.B
// jbtac(b).io.w.req.bits.setIdx := DontCare
// jbtac(b).io.w.req.bits.data := DontCare
// }
// }
//
// // read-after-write bypass
// val rawBypassHit = Wire(Vec(JbtacBanks, Bool()))
// for (b <- 0 until JbtacBanks) {
// when (readBank === writeBank && readRow === writeRow && b.U === readBank) {
// when (io.in.pc.fire() && writeValid) {
// rawBypassHit(b) := true.B
// jbtac(b).io.r.req.valid := false.B
// // readEntries(b) := RegNext(writeEntry)
// readFire(b) := true.B
//
// XSDebug("raw bypass hits: bank=%d, row=%d, tag=%x, tgt=%x, offet=%d, isRVC=%d\n",
// b.U, readRow, writeEntry.tag, writeEntry.target, writeEntry.offset, writeEntry.isRVC)
// }.otherwise {
// rawBypassHit(b) := false.B
// }
// }.otherwise {
// rawBypassHit(b) := false.B
// }
//
// when (RegNext(rawBypassHit(b))) { readEntries(b) := RegNext(writeEntry) }
// }
//
// XSDebug(io.in.pc.fire(), "read: pc=0x%x, histXORAddr=0x%x, bank=%d, row=%d, hist=%b\n",
// io.in.pc.bits, histXORAddr, readBank, readRow, io.in.hist)
// XSDebug("out: hit=%d tgt=%x hitIdx=%d iRVILateJump=%d isRVC=%d\n",
// io.out.hit, io.out.target, io.out.hitIdx, io.out.isRVILateJump, io.out.isRVC)
// XSDebug(fireLatch, "read_resp: pc=0x%x, bank=%d, row=%d, target=0x%x, offset=%d, hit=%d\n",
// io.in.pcLatch, readBankLatch, readRowLatch, readEntries(readBankLatch).target, readEntries(readBankLatch).offset, outHit)
// XSDebug(io.redirectValid, "update_req: fetchPC=0x%x, writeValid=%d, hist=%b, bank=%d, row=%d, target=0x%x, offset=%d, type=0x%d\n",
// io.update.fetchPC, writeValid, io.update.hist, writeBank, writeRow, io.update.target, io.update.fetchIdx, io.update.btbType)
//}
\ No newline at end of file
package xiangshan.frontend
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.ALUOpType
class JBTACUpdateBundle extends XSBundle {
val fetchPC = UInt(VAddrBits.W)
val fetchIdx = UInt(log2Up(PredictWidth).W)
val hist = UInt(HistoryLength.W)
val target = UInt(VAddrBits.W)
val btbType = UInt(2.W)
val misPred = Bool()
val isRVC = Bool()
}
class JBTACPred extends XSBundle {
val hit = Bool()
val target = UInt(VAddrBits.W)
val hitIdx = UInt(log2Up(PredictWidth).W)
val isRVILateJump = Bool()
val isRVC = Bool()
}
class JBTAC extends XSModule {
val io = IO(new Bundle {
val in = new Bundle {
val pc = Flipped(Decoupled(UInt(VAddrBits.W)))
val pcLatch = Input(UInt(VAddrBits.W))
val mask = Input(UInt(PredictWidth.W))
val hist = Input(UInt(HistoryLength.W))
}
val redirectValid = Input(Bool())
val flush = Input(Bool())
val update = Input(new JBTACUpdateBundle)
val out = Output(new JBTACPred)
})
io.in.pc.ready := true.B
val fireLatch = RegNext(io.in.pc.fire())
// JBTAC, divided into 8 banks, makes prediction for indirect jump except ret.
val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks)
def jbtacEntry() = new Bundle {
val valid = Bool()
// TODO: don't need full length of tag and target
val tag = UInt(jbtacAddr.tagBits.W + jbtacAddr.idxBits.W)
val target = UInt(VAddrBits.W)
val offset = UInt(log2Up(PredictWidth).W)
val isRVC = Bool()
}
val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = false)))
val readEntries = Wire(Vec(JbtacBanks, jbtacEntry()))
val readFire = Reg(Vec(JbtacBanks, Bool()))
// Only read one bank
val histXORAddr = io.in.pc.bits ^ Cat(io.in.hist, 0.U(1.W))(VAddrBits - 1, 0)
val histXORAddrLatch = RegEnable(histXORAddr, io.in.pc.valid)
val readBank = jbtacAddr.getBank(histXORAddr)
val readRow = jbtacAddr.getBankIdx(histXORAddr)
readFire := 0.U.asTypeOf(Vec(JbtacBanks, Bool()))
(0 until JbtacBanks).map(
b => {
jbtac(b).reset := reset.asBool
jbtac(b).io.r.req.valid := io.in.pc.fire() && b.U === readBank
jbtac(b).io.r.req.bits.setIdx := readRow
readFire(b) := jbtac(b).io.r.req.fire()
readEntries(b) := jbtac(b).io.r.resp.data(0)
}
)
val readBankLatch = jbtacAddr.getBank(histXORAddrLatch)
val readRowLatch = jbtacAddr.getBankIdx(histXORAddrLatch)
val readMaskLatch = RegEnable(io.in.mask, io.in.pc.fire())
val outHit = readEntries(readBankLatch).valid &&
readEntries(readBankLatch).tag === Cat(jbtacAddr.getTag(io.in.pcLatch), jbtacAddr.getIdx(io.in.pcLatch)) &&
!io.flush && RegNext(readFire(readBankLatch)) && readMaskLatch(readEntries(readBankLatch).offset).asBool
io.out.hit := outHit
io.out.hitIdx := readEntries(readBankLatch).offset
io.out.target := readEntries(readBankLatch).target
io.out.isRVILateJump := io.out.hit && io.out.hitIdx === OHToUInt(HighestBit(readMaskLatch, PredictWidth)) && !readEntries(readBankLatch).isRVC
io.out.isRVC := readEntries(readBankLatch).isRVC
// update jbtac
val writeEntry = Wire(jbtacEntry())
// val updateHistXORAddr = updatefetchPC ^ Cat(r.hist, 0.U(2.W))(VAddrBits - 1, 0)
val updateHistXORAddr = io.update.fetchPC ^ Cat(io.update.hist, 0.U(1.W))(VAddrBits - 1, 0)
writeEntry.valid := true.B
// writeEntry.tag := jbtacAddr.getTag(updatefetchPC)
writeEntry.tag := Cat(jbtacAddr.getTag(io.update.fetchPC), jbtacAddr.getIdx(io.update.fetchPC))
writeEntry.target := io.update.target
// writeEntry.offset := updateFetchIdx
writeEntry.offset := io.update.fetchIdx
writeEntry.isRVC := io.update.isRVC
val writeBank = jbtacAddr.getBank(updateHistXORAddr)
val writeRow = jbtacAddr.getBankIdx(updateHistXORAddr)
val writeValid = io.redirectValid && io.update.misPred && io.update.btbType === BTBtype.I
for (b <- 0 until JbtacBanks) {
when (b.U === writeBank) {
jbtac(b).io.w.req.valid := writeValid
jbtac(b).io.w.req.bits.setIdx := writeRow
jbtac(b).io.w.req.bits.data := writeEntry
}.otherwise {
jbtac(b).io.w.req.valid := false.B
jbtac(b).io.w.req.bits.setIdx := DontCare
jbtac(b).io.w.req.bits.data := DontCare
}
}
// read-after-write bypass
val rawBypassHit = Wire(Vec(JbtacBanks, Bool()))
for (b <- 0 until JbtacBanks) {
when (readBank === writeBank && readRow === writeRow && b.U === readBank) {
when (io.in.pc.fire() && writeValid) {
rawBypassHit(b) := true.B
jbtac(b).io.r.req.valid := false.B
// readEntries(b) := RegNext(writeEntry)
readFire(b) := true.B
XSDebug("raw bypass hits: bank=%d, row=%d, tag=%x, tgt=%x, offet=%d, isRVC=%d\n",
b.U, readRow, writeEntry.tag, writeEntry.target, writeEntry.offset, writeEntry.isRVC)
}.otherwise {
rawBypassHit(b) := false.B
}
}.otherwise {
rawBypassHit(b) := false.B
}
when (RegNext(rawBypassHit(b))) { readEntries(b) := RegNext(writeEntry) }
}
XSDebug(io.in.pc.fire(), "read: pc=0x%x, histXORAddr=0x%x, bank=%d, row=%d, hist=%b\n",
io.in.pc.bits, histXORAddr, readBank, readRow, io.in.hist)
XSDebug("out: hit=%d tgt=%x hitIdx=%d iRVILateJump=%d isRVC=%d\n",
io.out.hit, io.out.target, io.out.hitIdx, io.out.isRVILateJump, io.out.isRVC)
XSDebug(fireLatch, "read_resp: pc=0x%x, bank=%d, row=%d, target=0x%x, offset=%d, hit=%d\n",
io.in.pcLatch, readBankLatch, readRowLatch, readEntries(readBankLatch).target, readEntries(readBankLatch).offset, outHit)
XSDebug(io.redirectValid, "update_req: fetchPC=0x%x, writeValid=%d, hist=%b, bank=%d, row=%d, target=0x%x, offset=%d, type=0x%d\n",
io.update.fetchPC, writeValid, io.update.hist, writeBank, writeRow, io.update.target, io.update.fetchIdx, io.update.btbType)
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册