提交 58cc8bf7 编写于 作者: L Lingrui98

BPU: fixed all grammatic errors

上级 eb9c4583
......@@ -77,6 +77,14 @@ abstract class BasePredictor extends XSModule {
res := higher | lower
res
}
def circularShiftRight(source: UInt, len: Int, shamt: UInt): UInt = {
val res = Wire(UInt(len.W))
val higher = source << (len.U - shamt)
val lower = source >> shamt
res := higher | lower
res
}
}
class BPUStageIO extends XSBundle {
......
......@@ -49,10 +49,10 @@ class BIM extends BasePredictor with BimParams{
val baseBank = bimAddr.getBank(io.pc.bits)
val realMask = circularShiftLeft(io.inMask, BimBanks, baseBank)
val realMask = circularShiftRight(io.inMask, BimBanks, baseBank)
// those banks whose indexes are less than baseBank are in the next row
val isInNextRow = VecInit((0 until BtbBanks).map(b => ((BimBanks - baseBank) +& b.U)(log2Up(BimBanks))))
val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
val baseRow = bimAddr.getBankIdx(io.pc.bits)
......@@ -71,7 +71,7 @@ class BIM extends BasePredictor with BimParams{
val baseBankLatch = bimAddr.getBank(pcLatch)
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BimBanks).map(b => ((BimBanks - baseBankLatch) +& b.U)(log2Up(BimBanks)-1, 0)))
val bankIdxInOrder = VecInit((0 until BimBanks).map(b => (baseBankLatch +& b.U)(log2Up(BimBanks)-1, 0)))
for (b <- 0 until BimBanks) {
val ctr = bimRead(bankIdxInOrder(b))
......
......@@ -74,10 +74,10 @@ class BTB extends BasePredictor with BTBParams{
// BTB read requests
val baseBank = btbAddr.getBank(io.pc.bits)
val realMask = circularShiftLeft(io.inMask, BtbBanks, baseBank)
val realMask = circularShiftRight(io.inMask, BtbBanks, baseBank)
// those banks whose indexes are less than baseBank are in the next row
val isInNextRow = VecInit((0 until BtbBanks).map(b => ((BtbBanks - baseBank) +& b.U)(log2Up(BtbBanks))))
val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
val baseRow = btbAddr.getBankIdx(io.pc.bits)
......@@ -124,7 +124,7 @@ class BTB extends BasePredictor with BTBParams{
))
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => ((BtbBanks - baseBankLatch) +& b.U)(log2Up(BtbBanks)-1,0)))
val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch +& b.U)(log2Up(BtbBanks)-1,0)))
for (b <- 0 until BtbBanks) {
......
......@@ -98,6 +98,14 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
Mux(ctr === 7.U, 7.U, ctr + 1.U))
}
def circularShiftRight(source: UInt, len: Int, shamt: UInt): UInt = {
val res = Wire(UInt(len.W))
val higher = source << (len.U - shamt)
val lower = source >> shamt
res := higher | lower
res
}
val doing_reset = RegInit(true.B)
val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
reset_idx := reset_idx + doing_reset
......@@ -111,30 +119,36 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val tageEntrySz = 1 + tagLen + 3
val unhashed_idxes = VecInit((0 until TageBanks).map((io.req.bits.pc >> 1.U) + _.U))
val bankIdxes = VecInit(unhashed_idxes.map(_(log2Up(TageBanks)-1, 0)))
// use real address to index
val unhashed_idxes = VecInit((0 until TageBanks).map(b => ((io.req.bits.pc >> 1.U) + b.U) >> log2Up(TageBanks).U))
val idxes_and_tags = (0 until TageBanks).map(b => compute_tag_and_hash(unhashed_idxes(b.U), io.req.bits.hist))
val idxes = VecInit(idxes_and_tags.map(_._1))
val tags = VecInit(idxes_and_tags.map(_._2))
val idxLatch = RegEnable(idxes, enable=io.req.valid)
val tagLatch = RegEnable(tags, enable=io.req.valid)
val hi_us = List.fill(TageBanks)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val lo_us = List.fill(TageBanks)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val table = List.fill(TageBanks)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false)))
val bankIdxesLatch = RegEnable(bankIdxes, enable=io.req.valid)
val idxLatch = RegEnable(VecInit(idxes_and_tags.map(_._1)), enable=io.req.valid)
val tagLatch = RegEnable(VecInit(idxes_and_tags.map(_._2)), enable=io.req.valid)
val hi_us_r = Wire(Vec(TageBanks, Bool()))
val lo_us_r = Wire(Vec(TageBanks, Bool()))
val table_r = Wire(Vec(TageBanks, new TageEntry))
val baseBank = unhashed_idxes(0)(log2Up(TageBanks)-1, 0)
val baseBank = io.req.bits.pc(log2Up(TageBanks), 1)
val bankIdxInOrder = VecInit((0 until TageBanks).map(b => ((TageBanks - baseBank) +& b.U)(log2Up(TageBanks)-1, 0)))
val bankIdxInOrderLatch= RegEnable(bankIdxInOrder, enable=io.req.valid)
// This is different from that in BTB and BIM
// We want to pass the correct index and tag into the TAGE table
// if baseBank == 9, then we want to pass idxes_and_tags(0) to bank 9,
// 0 1 8 9 10 15
// so the correct order is 7, 8, ..., 15, 0, 1, ..., 6
val iAndTIdxInOrder = VecInit((0 until TageBanks).map(b => ((TageBanks.U - baseBank) + b.U)(log2Up(TageBanks)-1, 0)))
val iAndTIdxInOrderLatch = RegEnable(iAndTIdxInOrder, enable=io.req.valid)
val realMask = circularShiftLeft(io.inMask, TageBanks, baseBank)
val realMask = circularShiftRight(io.req.bits.mask, TageBanks, baseBank)
(0 until TageBanks).map(
b => {
......@@ -144,14 +158,14 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
hi_us(b).io.r.req.valid := io.req.valid
lo_us(b).io.r.req.valid := io.req.valid
table(b).io.r.req.valid := io.req.valid
lo_us(b).io.r.req.bits.setIdx := idxes_and_tags(bankIdxInOrder(b))._1
hi_us(b).io.r.req.bits.setIdx := idxes_and_tags(bankIdxInOrder(b))._1
table(b).io.r.req.bits.setIdx := idxes_and_tags(bankIdxInOrder(b))._1
lo_us(b).io.r.req.bits.setIdx := idxes(iAndTIdxInOrder(b.U))
hi_us(b).io.r.req.bits.setIdx := idxes(iAndTIdxInOrder(b.U))
table(b).io.r.req.bits.setIdx := idxes(iAndTIdxInOrder(b.U))
// Reorder done
hi_us_r(bankIdxInOrderLatch(b)) := hi_us(b).io.r.resp.data(0)
lo_us_r(bankIdxInOrderLatch(b)) := lo_us(b).io.r.resp.data(0)
table_r(bankIdxInOrderLatch(b)) := table(b).io.r.resp.data(0)
hi_us_r(iAndTIdxInOrderLatch(b)) := hi_us(b).io.r.resp.data(0)
lo_us_r(iAndTIdxInOrderLatch(b)) := lo_us(b).io.r.resp.data(0)
table_r(iAndTIdxInOrderLatch(b)) := table(b).io.r.resp.data(0)
}
)
......@@ -172,7 +186,7 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U
val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod)
val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc >> 1.U, io.update.hist)
val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc >> (1.U + log2Up(TageBanks).U), io.update.hist)
val update_wdata = Wire(Vec(TageBanks, new TageEntry))
......@@ -237,25 +251,25 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
}
}
XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%b, base_idx=%d, base_tag=%x\n",
io.req.bits.pc, io.req.bits.hist, idxes_and_tags(0)._1, idxes_and_tags(0)._2)
io.req.bits.pc, io.req.bits.hist, idxes(0.U), tags(0.U))
for (i <- 0 until TageBanks) {
XSDebug(RegNext(io.req.valid), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, idxLatch(i), req_rhits(i), table_r(i).ctr, Cat(hi_us_r(i),lo_us_r(i)).asUInt)
}
}
class FakeTAGE extends TageModule {
class FakeTAGE extends BasePredictor with HasTageParameter {
class TAGEResp extends Resp {
val takens = Vec(PredictWidth, ValidUndirectioned(Bool()))
}
class TAGEMeta extends Meta with TageMeta{
class TAGEMeta extends Meta {
}
class FromBIM extends FromOthers {
val ctrs = Vec(PredictWidth, UInt(2.W))
}
class TageIO extends DefaultBasePredictorIO {
val resp = Output(new TAGEResp)
val meta = Output(Vec(PredictWidth, new TAGEMeta))
val meta = Output(Vec(PredictWidth, new TageMeta))
val bim = Input(new FromBIM)
val s3Fire = Input(Bool())
}
......@@ -267,18 +281,18 @@ class FakeTAGE extends TageModule {
}
class Tage extends BasePredictor with TageModule {
class Tage extends BasePredictor with HasTageParameter {
class TAGEResp extends Resp {
val takens = Vec(PredictWidth, ValidUndirectioned(Bool()))
}
class TAGEMeta extends Meta with TageMeta{
class TAGEMeta extends Meta{
}
class FromBIM extends FromOthers {
val ctrs = Vec(PredictWidth, UInt(2.W))
}
class TageIO extends DefaultBasePredictorIO {
val resp = Output(new TAGEResp)
val meta = Output(Vec(PredictWidth, new TAGEMeta))
val meta = Output(Vec(PredictWidth, new TageMeta))
val bim = Input(new FromBIM)
val s3Fire = Input(Bool())
}
......@@ -297,7 +311,7 @@ class Tage extends BasePredictor with TageModule {
}
// Keep the table responses to process in s3
val resps = VecInit(tables.map(RegEnable(_.io.resp, enable=io.s3Fire)))
val resps = VecInit(tables.map(t => RegEnable(t.io.resp, enable=io.s3Fire)))
val s2_bim = RegEnable(io.bim, enable=io.pc.valid) // actually it is s2Fire
val s3_bim = RegEnable(s2_bim, enable=io.s3Fire)
......@@ -305,8 +319,8 @@ class Tage extends BasePredictor with TageModule {
val debug_pc_s2 = RegEnable(io.pc.bits, enable=io.pc.valid)
val debug_pc_s3 = RegEnable(debug_pc_s2, enable=io.s3Fire)
val updateMeta = io.update.brInfo.tageMeta
val updateMisPred = io.update.isMisPred && io.update.pd.isBr
val updateMeta = io.update.bits.brInfo.tageMeta
val updateMisPred = io.update.bits.isMisPred && io.update.bits.pd.isBr
val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(TageBanks, Bool()))))
val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(TageBanks, Bool()))))
......@@ -319,7 +333,7 @@ class Tage extends BasePredictor with TageModule {
updateOldCtr := DontCare
updateU := DontCare
val updateBank = io.update.pc >> 1.U
val updateBank = io.update.bits.pc >> 1.U
// access tag tables and output meta info
for (w <- 0 until TageBanks) {
......@@ -331,7 +345,8 @@ class Tage extends BasePredictor with TageModule {
io.resp.takens(w).bits := s3_bim.ctrs(w)(1)
for (i <- 0 until TageNTables) {
io.resp.takens(w).valid = resps(i)(w).valid
val hit = resps(i)(w).valid
io.resp.takens(w).valid := hit
val ctr = resps(i)(w).bits.ctr
when (hit) {
io.resp.takens(w).bits := Mux(ctr === 3.U || ctr === 4.U, altPred, ctr(2)) // Use altpred on weak taken
......@@ -344,7 +359,7 @@ class Tage extends BasePredictor with TageModule {
io.resp.takens(w).valid := provided
io.meta(w).provider.valid := provided
io.meta(w).provider.bits := provider
io.meta(w).altDiffers := finalAltPred =/= io.out.takens(w)
io.meta(w).altDiffers := finalAltPred =/= io.resp.takens(w).bits
io.meta(w).providerU := resps(provider)(w).bits.u
io.meta(w).providerCtr := resps(provider)(w).bits.ctr
......@@ -361,8 +376,8 @@ class Tage extends BasePredictor with TageModule {
io.meta(w).allocate.bits := allocEntry
val isUpdateTaken = io.update.valid && updateBank === w.U &&
io.udpate.taken && io.update.pd.isBr
when (io.update.pd.isBr && io.update.valid && updateBank === w.U) {
io.update.bits.taken && io.update.bits.pd.isBr
when (io.update.bits.pd.isBr && io.update.valid && updateBank === w.U) {
when (updateMeta.provider.valid) {
val provider = updateMeta.provider.bits
......@@ -385,7 +400,7 @@ class Tage extends BasePredictor with TageModule {
val allocate = updateMeta.allocate
when (allocate.valid) {
updateMask(allocate.bits)(idx) := true.B
updateTaken(allocate.bits)(idx) := io.update.taken
updateTaken(allocate.bits)(idx) := io.update.bits.taken
updateAlloc(allocate.bits)(idx) := true.B
updateUMask(allocate.bits)(idx) := true.B
updateU(allocate.bits)(idx) := 0.U
......@@ -412,17 +427,17 @@ class Tage extends BasePredictor with TageModule {
tables(i).io.update.u(w) := updateU(i)(w)
}
// use fetch pc instead of instruction pc
tables(i).io.update.pc := io.update.pc
tables(i).io.update.hist := io.update.hist
tables(i).io.update.pc := io.update.bits.pc
tables(i).io.update.hist := io.update.bits.hist
}
val m = updateMeta
XSDebug(io.req.valid, "req: pc=0x%x, hist=%b\n", io.req.bits.pc, io.req.bits.hist)
XSDebug(io.pc.valid, "req: pc=0x%x, hist=%b\n", io.pc.bits, io.hist)
XSDebug(io.update.valid, "redirect: provider(%d):%d, altDiffers:%d, providerU:%d, providerCtr:%d, allocate(%d):%d\n",
m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits)
// This is not reversed
XSDebug(true.B, "s3Fire:%d, resp: pc=%x, hits=%b, takens=%b\n",
debug_pc_s3, Cat(io.resp.map(_.valid)).asUInt, Cat(io.resp.map(_.bits)).asUInt)
debug_pc_s3, Cat(io.resp.takens.map(_.valid)).asUInt, Cat(io.resp.takens.map(_.bits)).asUInt)
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册