From 80d2974b083ab30ee12d60853063f1c370505af3 Mon Sep 17 00:00:00 2001 From: Lingrui98 Date: Wed, 29 Jul 2020 16:41:21 +0800 Subject: [PATCH] BPU: Initiate refactoring --- src/main/scala/xiangshan/Bundle.scala | 3 +- src/main/scala/xiangshan/frontend/BPU.scala | 560 ++++++++++++- src/main/scala/xiangshan/frontend/IFU.scala | 22 - src/main/scala/xiangshan/frontend/Tage.scala | 772 +++++++++--------- src/main/scala/xiangshan/frontend/btb.scala | 496 +++++------ src/main/scala/xiangshan/frontend/jbtac.scala | 302 +++---- 6 files changed, 1327 insertions(+), 828 deletions(-) diff --git a/src/main/scala/xiangshan/Bundle.scala b/src/main/scala/xiangshan/Bundle.scala index 8c8024848..8b3c8a4eb 100644 --- a/src/main/scala/xiangshan/Bundle.scala +++ b/src/main/scala/xiangshan/Bundle.scala @@ -60,7 +60,7 @@ class BranchInfo extends XSBundle { this.histPtr := histPtr this.tageMeta := tageMeta this.rasSp := rasSp - this.rasTopCtr + this.rasTopCtr := rasTopCtr this.asUInt } def size = 0.U.asTypeOf(this).getWidth @@ -142,7 +142,6 @@ class Redirect extends XSBundle with HasRoqIdx { val pc = UInt(VAddrBits.W) val target = UInt(VAddrBits.W) val brTag = new BrqPtr - val histPtr = UInt(log2Up(ExtHistoryLength).W) } class Dp1ToDp2IO extends XSBundle { diff --git a/src/main/scala/xiangshan/frontend/BPU.scala b/src/main/scala/xiangshan/frontend/BPU.scala index b77639cde..bb8b4e0af 100644 --- a/src/main/scala/xiangshan/frontend/BPU.scala +++ b/src/main/scala/xiangshan/frontend/BPU.scala @@ -7,21 +7,46 @@ import xiangshan._ import xiangshan.backend.ALUOpType import xiangshan.backend.JumpOpType -class BPUStage1To2IO extends XSBundle { - // TODO +class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle { + def tagBits = VAddrBits - idxBits - 1 + + val tag = UInt(tagBits.W) + val idx = UInt(idxBits.W) + val offset = UInt(1.W) + + def fromUInt(x: UInt) = x.asTypeOf(UInt(VAddrBits.W)).asTypeOf(this) + def getTag(x: UInt) = fromUInt(x).tag + def getIdx(x: UInt) = fromUInt(x).idx + def getBank(x: UInt) = getIdx(x)(log2Up(banks) - 1, 0) + def getBankIdx(x: UInt) = getIdx(x)(idxBits - 1, log2Up(banks)) } -class BPUStage2To3IO extends XSBundle { - // TODO +class BTBResponse extends XSBundle { + // the valid bits indicates whether a target is hit + val ubtb = new Bundle { + val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VaddrBits.W))) + val takens = Vec(PredictWidth, Bool()) + } + // the valid bits indicates whether a target is hit + val btb = new Bundle { + val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VaddrBits.W))) + val takens = Vec(PredictWidth, Bool()) + } } +class BPUStageIO extends XSBundle { + val pc = Output(UInt(VAddrBits.W)) + val btbResp = Output(new BTBResponse) + val brInfo = Output(Vec(PredictWidth, new BranchInfo)) +} + + class BPUStage1 extends XSModule { val io = IO(new Bundle() { val flush = Input(Bool()) val in = new Bundle { val pc = Flipped(ValidIO(UInt(VAddrBits.W))) } - val s1_out = Decoupled(new BranchPrediction) - val out = Decoupled(new BPUStage1To2IO) - val redirect = Flipped(ValidIO(new Redirect)) // used to fix ghr + val pred = Decoupled(new BranchPrediction) + val out = Decoupled(new BPUStageIO) val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) }) @@ -31,29 +56,27 @@ class BPUStage1 extends XSModule { class BPUStage2 extends XSModule { val io = IO(new Bundle() { val flush = Input(Bool()) - val in = Flipped(Decoupled(new BPUStage1To2IO)) - val s2_out = Decoupled(new BranchPrediction) - val out = Decoupled(new BPUStage2To3IO) + val in = Flipped(Decoupled(new BPUStageIO)) + val pred = Decoupled(new BranchPrediction) + val out = Decoupled(new BPUStageIO) val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // delete this if useless val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // delete this if useless }) - } class BPUStage3 extends XSModule { val io = IO(new Bundle() { val flush = Input(Bool()) - val in = Flipped(Decoupled(new BPUStage2To3IO)) - val s3_out = Decoupled(new BranchPrediction) + val in = Flipped(Decoupled(new BPUStageIO)) + val pred = Decoupled(new BranchPrediction) val predecode = Flipped(ValidIO(new Predecode)) }) } -class BPU extends XSModule { +class BaseBPU extends XSModule { val io = IO(new Bundle() { // from backend - val redirect = Flipped(ValidIO(new Redirect)) val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) // from ifu, frontend redirect @@ -65,8 +88,19 @@ class BPU extends XSModule { // from if4 val predecode = Flipped(ValidIO(new Predecode)) // to if4, some bpu info used for updating - val branchInfo = Decoupled(new BranchInfo) + val branchInfo = Decoupled(Vec(PredictWidth, new BranchInfo)) }) +} + +class FakeBPU extends BaseBPU { + io.out.foreach(i => { + i <> DontCare + i.redirect := false.B + }) + io.branchInfo <> DontCare +} + +class BPU extends BaseBPU { val s1 = Module(new BPUStage1) val s2 = Module(new BPUStage2) @@ -80,9 +114,9 @@ class BPU extends XSModule { s2.io.in <> s1.io.out s3.io.in <> s2.io.out - io.out(0) <> s1.io.s1_out - io.out(1) <> s2.io.s2_out - io.out(2) <> s3.io.s3_out + io.out(0) <> s1.io.pred + io.out(1) <> s2.io.pred + io.out(2) <> s3.io.pred s1.io.redirect <> io.redirect s1.io.outOfOrderBrInfo <> io.outOfOrderBrInfo @@ -92,3 +126,491 @@ class BPU extends XSModule { s3.io.predecode <> io.predecode } + + + + + + + + + + + +// class BPUStage1 extends XSModule { +// val io = IO(new Bundle() { +// val in = new Bundle { val pc = Flipped(Decoupled(UInt(VAddrBits.W))) } +// // from backend +// val redirectInfo = Input(new RedirectInfo) +// // from Stage3 +// val flush = Input(Bool()) +// val s3RollBackHist = Input(UInt(HistoryLength.W)) +// val s3Taken = Input(Bool()) +// // to ifu, quick prediction result +// val s1OutPred = ValidIO(new BranchPrediction) +// // to Stage2 +// val out = Decoupled(new Stage1To2IO) +// }) + +// io.in.pc.ready := true.B + +// // flush Stage1 when io.flush +// val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true) +// val s1OutPredLatch = RegEnable(io.s1OutPred.bits, RegNext(io.in.pc.fire())) +// val outLatch = RegEnable(io.out.bits, RegNext(io.in.pc.fire())) + +// val s1Valid = RegInit(false.B) +// when (io.flush) { +// s1Valid := true.B +// }.elsewhen (io.in.pc.fire()) { +// s1Valid := true.B +// }.elsewhen (io.out.fire()) { +// s1Valid := false.B +// } +// io.out.valid := s1Valid + + +// // global history register +// val ghr = RegInit(0.U(HistoryLength.W)) +// // modify updateGhr and newGhr when updating ghr +// val updateGhr = WireInit(false.B) +// val newGhr = WireInit(0.U(HistoryLength.W)) +// when (updateGhr) { ghr := newGhr } +// // use hist as global history!!! +// val hist = Mux(updateGhr, newGhr, ghr) + +// // Tage predictor +// val tage = if(EnableBPD) Module(new Tage) else Module(new FakeTAGE) +// tage.io.req.valid := io.in.pc.fire() +// tage.io.req.bits.pc := io.in.pc.bits +// tage.io.req.bits.hist := hist +// tage.io.redirectInfo <> io.redirectInfo +// io.s1OutPred.bits.tageMeta := tage.io.meta + +// // latch pc for 1 cycle latency when reading SRAM +// val pcLatch = RegEnable(io.in.pc.bits, io.in.pc.fire()) +// // TODO: pass real mask in +// // val maskLatch = RegEnable(btb.io.in.mask, io.in.pc.fire()) +// val maskLatch = Fill(PredictWidth, 1.U(1.W)) + +// val r = io.redirectInfo.redirect +// val updateFetchpc = r.pc - (r.fetchIdx << 1.U) +// // BTB +// val btb = Module(new BTB) +// btb.io.in.pc <> io.in.pc +// btb.io.in.pcLatch := pcLatch +// // TODO: pass real mask in +// btb.io.in.mask := Fill(PredictWidth, 1.U(1.W)) +// btb.io.redirectValid := io.redirectInfo.valid +// btb.io.flush := io.flush + +// // btb.io.update.fetchPC := updateFetchpc +// // btb.io.update.fetchIdx := r.fetchIdx +// btb.io.update.pc := r.pc +// btb.io.update.hit := r.btbHit +// btb.io.update.misPred := io.redirectInfo.misPred +// // btb.io.update.writeWay := r.btbVictimWay +// btb.io.update.oldCtr := r.btbPredCtr +// btb.io.update.taken := r.taken +// btb.io.update.target := r.brTarget +// btb.io.update.btbType := r.btbType +// // TODO: add RVC logic +// btb.io.update.isRVC := r.isRVC + +// // val btbHit = btb.io.out.hit +// val btbTaken = btb.io.out.taken +// val btbTakenIdx = btb.io.out.takenIdx +// val btbTakenTarget = btb.io.out.target +// // val btbWriteWay = btb.io.out.writeWay +// val btbNotTakens = btb.io.out.notTakens +// val btbCtrs = VecInit(btb.io.out.dEntries.map(_.pred)) +// val btbValids = btb.io.out.hits +// val btbTargets = VecInit(btb.io.out.dEntries.map(_.target)) +// val btbTypes = VecInit(btb.io.out.dEntries.map(_.btbType)) +// val btbIsRVCs = VecInit(btb.io.out.dEntries.map(_.isRVC)) + + +// val jbtac = Module(new JBTAC) +// jbtac.io.in.pc <> io.in.pc +// jbtac.io.in.pcLatch := pcLatch +// // TODO: pass real mask in +// jbtac.io.in.mask := Fill(PredictWidth, 1.U(1.W)) +// jbtac.io.in.hist := hist +// jbtac.io.redirectValid := io.redirectInfo.valid +// jbtac.io.flush := io.flush + +// jbtac.io.update.fetchPC := updateFetchpc +// jbtac.io.update.fetchIdx := r.fetchIdx +// jbtac.io.update.misPred := io.redirectInfo.misPred +// jbtac.io.update.btbType := r.btbType +// jbtac.io.update.target := r.target +// jbtac.io.update.hist := r.hist +// jbtac.io.update.isRVC := r.isRVC + +// val jbtacHit = jbtac.io.out.hit +// val jbtacTarget = jbtac.io.out.target +// val jbtacHitIdx = jbtac.io.out.hitIdx +// val jbtacIsRVC = jbtac.io.out.isRVC + +// // calculate global history of each instr +// val firstHist = RegNext(hist) +// val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W))) +// val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W)))) +// (0 until PredictWidth).foreach(i => shift(i) := Mux(!btbNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W)))) +// for (j <- 0 until PredictWidth) { +// var tmp = 0.U +// for (i <- 0 until PredictWidth) { +// tmp = tmp + shift(i)(j) +// } +// histShift(j) := tmp +// } + +// // update ghr +// updateGhr := io.s1OutPred.bits.redirect || +// RegNext(io.in.pc.fire) && ~io.s1OutPred.bits.redirect && (btbNotTakens.asUInt & maskLatch).orR || // TODO: use parallel or +// io.flush +// val brJumpIdx = Mux(!btbTaken, 0.U, UIntToOH(btbTakenIdx)) +// val indirectIdx = Mux(!jbtacHit, 0.U, UIntToOH(jbtacHitIdx)) +// // if backend redirects, restore history from backend; +// // if stage3 redirects, restore history from stage3; +// // if stage1 redirects, speculatively update history; +// // if none of above happens, check if stage1 has not-taken branches and shift zeroes accordingly +// newGhr := Mux(io.redirectInfo.flush(), (r.hist << 1.U) | !(r.btbType === BTBtype.B && !r.taken), +// Mux(io.flush, Mux(io.s3Taken, (io.s3RollBackHist << 1.U) | 1.U, io.s3RollBackHist), +// Mux(io.s1OutPred.bits.redirect, (PriorityMux(brJumpIdx | indirectIdx, io.s1OutPred.bits.hist) << 1.U | 1.U), +// io.s1OutPred.bits.hist(0) << PopCount(btbNotTakens.asUInt & maskLatch)))) + +// def getInstrValid(i: Int): UInt = { +// val vec = Wire(Vec(PredictWidth, UInt(1.W))) +// for (j <- 0 until PredictWidth) { +// if (j <= i) +// vec(j) := 1.U +// else +// vec(j) := 0.U +// } +// vec.asUInt +// } + +// // redirect based on BTB and JBTAC +// val takenIdx = LowestBit(brJumpIdx | indirectIdx, PredictWidth) + +// // io.out.valid := RegNext(io.in.pc.fire()) && !io.flush + +// // io.s1OutPred.valid := io.out.valid +// io.s1OutPred.valid := io.out.fire() +// when (RegNext(io.in.pc.fire())) { +// io.s1OutPred.bits.redirect := btbTaken || jbtacHit +// // io.s1OutPred.bits.instrValid := (maskLatch & Fill(PredictWidth, ~io.s1OutPred.bits.redirect || io.s1OutPred.bits.lateJump) | +// // PriorityMux(brJumpIdx | indirectIdx, (0 until PredictWidth).map(getInstrValid(_)))).asTypeOf(Vec(PredictWidth, Bool())) +// io.s1OutPred.bits.instrValid := (maskLatch & Fill(PredictWidth, ~io.s1OutPred.bits.redirect) | +// PriorityMux(brJumpIdx | indirectIdx, (0 until PredictWidth).map(getInstrValid(_)))).asTypeOf(Vec(PredictWidth, Bool())) +// for (i <- 0 until (PredictWidth - 1)) { +// when (!io.s1OutPred.bits.lateJump && (1.U << i) === takenIdx && (!btbIsRVCs(i) && btbValids(i) || !jbtacIsRVC && (1.U << i) === indirectIdx)) { +// io.s1OutPred.bits.instrValid(i+1) := maskLatch(i+1) +// } +// } +// io.s1OutPred.bits.target := Mux(takenIdx === 0.U, pcLatch + (PopCount(maskLatch) << 1.U), Mux(takenIdx === brJumpIdx, btbTakenTarget, jbtacTarget)) +// io.s1OutPred.bits.lateJump := btb.io.out.isRVILateJump || jbtac.io.out.isRVILateJump +// (0 until PredictWidth).map(i => io.s1OutPred.bits.hist(i) := firstHist << histShift(i)) +// // io.s1OutPred.bits.btbVictimWay := btbWriteWay +// io.s1OutPred.bits.predCtr := btbCtrs +// io.s1OutPred.bits.btbHit := btbValids +// io.s1OutPred.bits.tageMeta := tage.io.meta // TODO: enableBPD +// io.s1OutPred.bits.rasSp := DontCare +// io.s1OutPred.bits.rasTopCtr := DontCare +// }.otherwise { +// io.s1OutPred.bits := s1OutPredLatch +// } + +// when (RegNext(io.in.pc.fire())) { +// io.out.bits.pc := pcLatch +// io.out.bits.btb.hits := btbValids.asUInt +// (0 until PredictWidth).map(i => io.out.bits.btb.targets(i) := btbTargets(i)) +// io.out.bits.jbtac.hitIdx := Mux(jbtacHit, UIntToOH(jbtacHitIdx), 0.U) // UIntToOH(jbtacHitIdx) +// io.out.bits.jbtac.target := jbtacTarget +// io.out.bits.tage <> tage.io.out +// // TODO: we don't need this repeatedly! +// io.out.bits.hist := io.s1OutPred.bits.hist +// io.out.bits.btbPred := io.s1OutPred +// }.otherwise { +// io.out.bits := outLatch +// } + + +// // debug info +// XSDebug("in:(%d %d) pc=%x ghr=%b\n", io.in.pc.valid, io.in.pc.ready, io.in.pc.bits, hist) +// XSDebug("outPred:(%d) pc=0x%x, redirect=%d instrValid=%b tgt=%x\n", +// io.s1OutPred.valid, pcLatch, io.s1OutPred.bits.redirect, io.s1OutPred.bits.instrValid.asUInt, io.s1OutPred.bits.target) +// XSDebug(io.flush && io.redirectInfo.flush(), +// "flush from backend: pc=%x tgt=%x brTgt=%x btbType=%b taken=%d oldHist=%b fetchIdx=%d isExcpt=%d\n", +// r.pc, r.target, r.brTarget, r.btbType, r.taken, r.hist, r.fetchIdx, r.isException) +// XSDebug(io.flush && !io.redirectInfo.flush(), +// "flush from Stage3: s3Taken=%d s3RollBackHist=%b\n", io.s3Taken, io.s3RollBackHist) + +// } + +// class Stage2To3IO extends Stage1To2IO { +// } + +// class BPUStage2 extends XSModule { +// val io = IO(new Bundle() { +// // flush from Stage3 +// val flush = Input(Bool()) +// val in = Flipped(Decoupled(new Stage1To2IO)) +// val out = Decoupled(new Stage2To3IO) +// }) + +// // flush Stage2 when Stage3 or banckend redirects +// val flushS2 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true) +// val inLatch = RegInit(0.U.asTypeOf(io.in.bits)) +// when (io.in.fire()) { inLatch := io.in.bits } +// val validLatch = RegInit(false.B) +// when (io.flush) { +// validLatch := false.B +// }.elsewhen (io.in.fire()) { +// validLatch := true.B +// }.elsewhen (io.out.fire()) { +// validLatch := false.B +// } + +// io.out.valid := !io.flush && !flushS2 && validLatch +// io.in.ready := !validLatch || io.out.fire() + +// // do nothing +// io.out.bits := inLatch + +// // debug info +// XSDebug("in:(%d %d) pc=%x out:(%d %d) pc=%x\n", +// io.in.valid, io.in.ready, io.in.bits.pc, io.out.valid, io.out.ready, io.out.bits.pc) +// XSDebug("validLatch=%d pc=%x\n", validLatch, inLatch.pc) +// XSDebug(io.flush, "flush!!!\n") +// } + +// class BPUStage3 extends XSModule { +// val io = IO(new Bundle() { +// val flush = Input(Bool()) +// val in = Flipped(Decoupled(new Stage2To3IO)) +// val out = Decoupled(new BranchPrediction) +// // from icache +// val predecode = Flipped(ValidIO(new Predecode)) +// // from backend +// val redirectInfo = Input(new RedirectInfo) +// // to Stage1 and Stage2 +// val flushBPU = Output(Bool()) +// // to Stage1, restore ghr in stage1 when flushBPU is valid +// val s1RollBackHist = Output(UInt(HistoryLength.W)) +// val s3Taken = Output(Bool()) +// }) + +// val flushS3 = BoolStopWatch(io.flush, io.in.fire(), startHighPriority = true) +// val inLatch = RegInit(0.U.asTypeOf(io.in.bits)) +// val validLatch = RegInit(false.B) +// val predecodeLatch = RegInit(0.U.asTypeOf(io.predecode.bits)) +// val predecodeValidLatch = RegInit(false.B) +// when (io.in.fire()) { inLatch := io.in.bits } +// when (io.flush) { +// validLatch := false.B +// }.elsewhen (io.in.fire()) { +// validLatch := true.B +// }.elsewhen (io.out.fire()) { +// validLatch := false.B +// } + +// when (io.predecode.valid) { predecodeLatch := io.predecode.bits } +// when (io.flush || io.out.fire()) { +// predecodeValidLatch := false.B +// }.elsewhen (io.predecode.valid) { +// predecodeValidLatch := true.B +// } + +// val predecodeValid = io.predecode.valid || predecodeValidLatch +// val predecode = Mux(io.predecode.valid, io.predecode.bits, predecodeLatch) +// io.out.valid := validLatch && predecodeValid && !flushS3 && !io.flush +// io.in.ready := !validLatch || io.out.fire() + +// // RAS +// // TODO: split retAddr and ctr +// def rasEntry() = new Bundle { +// val retAddr = UInt(VAddrBits.W) +// val ctr = UInt(8.W) // layer of nested call functions +// } +// val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry())))) +// val sp = Counter(RasSize) +// val rasTop = ras(sp.value) +// val rasTopAddr = rasTop.retAddr + +// // get the first taken branch/jal/call/jalr/ret in a fetch line +// // brNotTakenIdx indicates all the not-taken branches before the first jump instruction + +// val tageHits = inLatch.tage.hits +// val tageTakens = inLatch.tage.takens +// val btbTakens = inLatch.btbPred.bits.predCtr + +// val brs = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => ALUOpType.isBranch(t) }).asUInt) & predecode.mask +// // val brTakens = brs & inLatch.tage.takens.asUInt +// val brTakens = if (EnableBPD) { +// // If tage hits, use tage takens, otherwise keep btbpreds +// // brs & Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt) +// XSDebug("tageHits=%b, tageTakens=%b\n", tageHits, tageTakens.asUInt) +// brs & Reverse(Cat((0 until PredictWidth).map(i => Mux(tageHits(i), tageTakens(i), btbTakens(i)(1))))) +// } else { +// brs & Reverse(Cat(inLatch.btbPred.bits.predCtr.map {c => c(1)}).asUInt) +// } +// val jals = inLatch.btb.hits & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jal }).asUInt) & predecode.mask +// val calls = inLatch.btb.hits & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.call }).asUInt) +// val jalrs = inLatch.jbtac.hitIdx & predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.jalr }).asUInt) +// val rets = predecode.mask & Reverse(Cat(predecode.fuOpTypes.map { t => t === JumpOpType.ret }).asUInt) + +// val brTakenIdx = PriorityMux(brTakens, (0 until PredictWidth).map(_.U)) +// val jalIdx = PriorityMux(jals, (0 until PredictWidth).map(_.U)) +// val callIdx = PriorityMux(calls, (0 until PredictWidth).map(_.U)) +// val jalrIdx = PriorityMux(jalrs, (0 until PredictWidth).map(_.U)) +// val retIdx = PriorityMux(rets, (0 until PredictWidth).map(_.U)) + +// val jmps = (if (EnableRAS) {brTakens | jals | calls | jalrs | rets} else {brTakens | jals | calls | jalrs}) +// val jmpIdx = MuxCase(0.U, (0 until PredictWidth).map(i => (jmps(i), i.U))) +// io.s3Taken := MuxCase(false.B, (0 until PredictWidth).map(i => (jmps(i), true.B))) + +// // val brNotTakens = VecInit((0 until PredictWidth).map(i => brs(i) && ~inLatch.tage.takens(i) && i.U <= jmpIdx && io.predecode.bits.mask(i))) +// val brNotTakens = if (EnableBPD) { +// VecInit((0 until PredictWidth).map(i => brs(i) && i.U <= jmpIdx && Mux(tageHits(i), ~tageTakens(i), ~btbTakens(i)(1)) && predecode.mask(i))) +// } else { +// VecInit((0 until PredictWidth).map(i => brs(i) && i.U <= jmpIdx && ~inLatch.btbPred.bits.predCtr(i)(1) && predecode.mask(i))) +// } + +// // TODO: what if if4 and if2 late jump to the same target? +// // val lateJump = io.s3Taken && PriorityMux(Reverse(predecode.mask), ((PredictWidth - 1) to 0).map(_.U)) === jmpIdx && !predecode.isRVC(jmpIdx) +// val lateJump = io.s3Taken && PriorityMux(Reverse(predecode.mask), (0 until PredictWidth).map {i => (PredictWidth - 1 - i).U}) === jmpIdx && !predecode.isRVC(jmpIdx) +// io.out.bits.lateJump := lateJump + +// io.out.bits.predCtr := inLatch.btbPred.bits.predCtr +// io.out.bits.btbHit := inLatch.btbPred.bits.btbHit +// io.out.bits.tageMeta := inLatch.btbPred.bits.tageMeta +// //io.out.bits.btbType := Mux(jmpIdx === retIdx, BTBtype.R, +// // Mux(jmpIdx === jalrIdx, BTBtype.I, +// // Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J))) +// val firstHist = inLatch.btbPred.bits.hist(0) +// // there may be several notTaken branches before the first jump instruction, +// // so we need to calculate how many zeroes should each instruction shift in its global history. +// // each history is exclusive of instruction's own jump direction. +// val histShift = Wire(Vec(PredictWidth, UInt(log2Up(PredictWidth).W))) +// val shift = Wire(Vec(PredictWidth, Vec(PredictWidth, UInt(1.W)))) +// (0 until PredictWidth).foreach(i => shift(i) := Mux(!brNotTakens(i), 0.U, ~LowerMask(UIntToOH(i.U), PredictWidth)).asTypeOf(Vec(PredictWidth, UInt(1.W)))) +// for (j <- 0 until PredictWidth) { +// var tmp = 0.U +// for (i <- 0 until PredictWidth) { +// tmp = tmp + shift(i)(j) +// } +// histShift(j) := tmp +// } +// (0 until PredictWidth).foreach(i => io.out.bits.hist(i) := firstHist << histShift(i)) +// // save ras checkpoint info +// io.out.bits.rasSp := sp.value +// io.out.bits.rasTopCtr := rasTop.ctr + +// // flush BPU and redirect when target differs from the target predicted in Stage1 +// val tToNt = inLatch.btbPred.bits.redirect && ~io.s3Taken +// val ntToT = ~inLatch.btbPred.bits.redirect && io.s3Taken +// val dirDiffers = tToNt || ntToT +// val tgtDiffers = inLatch.btbPred.bits.redirect && io.s3Taken && io.out.bits.target =/= inLatch.btbPred.bits.target +// // io.out.bits.redirect := (if (EnableBPD) {dirDiffers || tgtDiffers} else false.B) +// io.out.bits.redirect := dirDiffers || tgtDiffers +// io.out.bits.target := Mux(!io.s3Taken, inLatch.pc + (PopCount(predecode.mask) << 1.U), // TODO: RVC +// Mux(jmpIdx === retIdx, rasTopAddr, +// Mux(jmpIdx === jalrIdx, inLatch.jbtac.target, +// inLatch.btb.targets(jmpIdx)))) +// // for (i <- 0 until FetchWidth) { +// // io.out.bits.instrValid(i) := ((io.s3Taken && i.U <= jmpIdx) || ~io.s3Taken) && io.predecode.bits.mask(i) +// // } +// io.out.bits.instrValid := predecode.mask.asTypeOf(Vec(PredictWidth, Bool())) +// for (i <- PredictWidth - 1 to 0) { +// io.out.bits.instrValid(i) := (io.s3Taken && i.U <= jmpIdx || !io.s3Taken) && predecode.mask(i) +// if (i != (PredictWidth - 1)) { +// when (!lateJump && !predecode.isRVC(i) && io.s3Taken && i.U <= jmpIdx) { +// io.out.bits.instrValid(i+1) := predecode.mask(i+1) +// } +// } +// } +// io.flushBPU := io.out.bits.redirect && io.out.fire() + +// // speculative update RAS +// val rasWrite = WireInit(0.U.asTypeOf(rasEntry())) +// val retAddr = inLatch.pc + (callIdx << 1.U) + Mux(predecode.isRVC(callIdx), 2.U, 4.U) +// rasWrite.retAddr := retAddr +// val allocNewEntry = rasWrite.retAddr =/= rasTopAddr +// rasWrite.ctr := Mux(allocNewEntry, 1.U, rasTop.ctr + 1.U) +// val rasWritePosition = Mux(allocNewEntry, sp.value + 1.U, sp.value) +// when (io.out.fire() && io.s3Taken) { +// when (jmpIdx === callIdx) { +// ras(rasWritePosition) := rasWrite +// when (allocNewEntry) { sp.value := sp.value + 1.U } +// }.elsewhen (jmpIdx === retIdx) { +// when (rasTop.ctr === 1.U) { +// sp.value := Mux(sp.value === 0.U, 0.U, sp.value - 1.U) +// }.otherwise { +// ras(sp.value) := Cat(rasTop.ctr - 1.U, rasTopAddr).asTypeOf(rasEntry()) +// } +// } +// } +// // use checkpoint to recover RAS +// val recoverSp = io.redirectInfo.redirect.rasSp +// val recoverCtr = io.redirectInfo.redirect.rasTopCtr +// when (io.redirectInfo.flush()) { +// sp.value := recoverSp +// ras(recoverSp) := Cat(recoverCtr, ras(recoverSp).retAddr).asTypeOf(rasEntry()) +// } + +// // roll back global history in S1 if S3 redirects +// io.s1RollBackHist := Mux(io.s3Taken, io.out.bits.hist(jmpIdx), +// io.out.bits.hist(0) << PopCount(brs & predecode.mask & ~Reverse(Cat(inLatch.tage.takens.map {t => Fill(2, t.asUInt)}).asUInt))) + +// // debug info +// XSDebug(io.in.fire(), "in:(%d %d) pc=%x\n", io.in.valid, io.in.ready, io.in.bits.pc) +// XSDebug(io.out.fire(), "out:(%d %d) pc=%x redirect=%d predcdMask=%b instrValid=%b tgt=%x\n", +// io.out.valid, io.out.ready, inLatch.pc, io.out.bits.redirect, predecode.mask, io.out.bits.instrValid.asUInt, io.out.bits.target) +// XSDebug("flushS3=%d\n", flushS3) +// XSDebug("validLatch=%d predecode.valid=%d\n", validLatch, predecodeValid) +// XSDebug("brs=%b brTakens=%b brNTakens=%b jals=%b jalrs=%b calls=%b rets=%b\n", +// brs, brTakens, brNotTakens.asUInt, jals, jalrs, calls, rets) +// // ?????condition is wrong +// // XSDebug(io.in.fire() && callIdx.orR, "[RAS]:pc=0x%x, rasWritePosition=%d, rasWriteAddr=0x%x\n", +// // io.in.bits.pc, rasWritePosition, retAddr) +// } + +// class BPU extends XSModule { +// val io = IO(new Bundle() { +// // from backend +// // flush pipeline if misPred and update bpu based on redirect signals from brq +// val redirectInfo = Input(new RedirectInfo) + +// val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) } + +// val btbOut = ValidIO(new BranchPrediction) +// val tageOut = Decoupled(new BranchPrediction) + +// // predecode info from icache +// // TODO: simplify this after implement predecode unit +// val predecode = Flipped(ValidIO(new Predecode)) +// }) + +// val s1 = Module(new BPUStage1) +// val s2 = Module(new BPUStage2) +// val s3 = Module(new BPUStage3) + +// s1.io.redirectInfo <> io.redirectInfo +// s1.io.flush := s3.io.flushBPU || io.redirectInfo.flush() +// s1.io.in.pc.valid := io.in.pc.valid +// s1.io.in.pc.bits <> io.in.pc.bits +// io.btbOut <> s1.io.s1OutPred +// s1.io.s3RollBackHist := s3.io.s1RollBackHist +// s1.io.s3Taken := s3.io.s3Taken + +// s1.io.out <> s2.io.in +// s2.io.flush := s3.io.flushBPU || io.redirectInfo.flush() + +// s2.io.out <> s3.io.in +// s3.io.flush := io.redirectInfo.flush() +// s3.io.predecode <> io.predecode +// io.tageOut <> s3.io.out +// s3.io.redirectInfo <> io.redirectInfo +// } diff --git a/src/main/scala/xiangshan/frontend/IFU.scala b/src/main/scala/xiangshan/frontend/IFU.scala index 4f8fcb2e7..9d2037387 100644 --- a/src/main/scala/xiangshan/frontend/IFU.scala +++ b/src/main/scala/xiangshan/frontend/IFU.scala @@ -24,28 +24,6 @@ class IFUIO extends XSBundle val icacheResp = Flipped(DecoupledIO(new FakeIcacheResp)) } -class BaseBPU extends XSModule { - val io = IO(new Bundle() { - val redirect = Flipped(ValidIO(new Redirect)) - val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) - val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo)) - val in = new Bundle { val pc = Flipped(Valid(UInt(VAddrBits.W))) } - val btbOut = ValidIO(new BranchPrediction) - val tageOut = Decoupled(new BranchPrediction) - val predecode = Flipped(ValidIO(new Predecode)) - }) -} - -class FakeBPU extends BaseBPU { - - io.btbOut.valid := false.B - io.btbOut.bits <> DontCare - io.btbOut.bits.redirect := false.B - io.btbOut.bits.target := DontCare - io.tageOut.valid := false.B - io.tageOut.bits <> DontCare -} - class IFU extends XSModule with HasIFUConst { diff --git a/src/main/scala/xiangshan/frontend/Tage.scala b/src/main/scala/xiangshan/frontend/Tage.scala index a3d9c608b..8b263cfaa 100644 --- a/src/main/scala/xiangshan/frontend/Tage.scala +++ b/src/main/scala/xiangshan/frontend/Tage.scala @@ -1,386 +1,386 @@ -//package xiangshan.frontend -// -//import chisel3._ -//import chisel3.util._ -//import xiangshan._ -//import utils._ -// -//import scala.math.min -// -//trait HasTageParameter { -// // Sets Hist Tag -// val TableInfo = Seq(( 128, 2, 7), -// ( 128, 4, 7), -// ( 256, 8, 8), -// ( 256, 16, 8), -// ( 128, 32, 9), -// ( 128, 64, 9)) -// val TageNTables = TableInfo.size -// val UBitPeriod = 2048 -// val BankWidth = 16 // FetchWidth -// -// val TotalBits = TableInfo.map { -// case (s, h, t) => { -// s * (1+t+3) * BankWidth -// } -// }.reduce(_+_) -//} -// -//abstract class TageBundle extends XSBundle with HasTageParameter -//abstract class TageModule extends XSModule with HasTageParameter -// -//class TageReq extends TageBundle { -// val pc = UInt(VAddrBits.W) -// val hist = UInt(HistoryLength.W) -//} -// -//class TageResp extends TageBundle { -// val ctr = UInt(3.W) -// val u = UInt(2.W) -//} -// -//class TageUpdate extends TageBundle { -// val pc = UInt(VAddrBits.W) -// val hist = UInt(HistoryLength.W) -// // update tag and ctr -// val mask = Vec(BankWidth, Bool()) -// val taken = Vec(BankWidth, Bool()) -// val alloc = Vec(BankWidth, Bool()) -// val oldCtr = Vec(BankWidth, UInt(3.W)) -// // update u -// val uMask = Vec(BankWidth, Bool()) -// val u = Vec(BankWidth, UInt(2.W)) -//} -// -//class FakeTageTable() extends TageModule { -// val io = IO(new Bundle() { -// val req = Input(Valid(new TageReq)) -// val resp = Output(Vec(BankWidth, Valid(new TageResp))) -// val update = Input(new TageUpdate) -// }) -// io.resp := DontCare -// -//} -// -//class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule { -// val io = IO(new Bundle() { -// val req = Input(Valid(new TageReq)) -// val resp = Output(Vec(BankWidth, Valid(new TageResp))) -// val update = Input(new TageUpdate) -// }) -// -// // bypass entries for tage update -// val wrBypassEntries = 8 -// -// def compute_folded_hist(hist: UInt, l: Int) = { -// val nChunks = (histLen + l - 1) / l -// val hist_chunks = (0 until nChunks) map {i => -// hist(min((i+1)*l, histLen)-1, i*l) -// } -// hist_chunks.reduce(_^_) -// } -// -// def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt) = { -// val idx_history = compute_folded_hist(hist, log2Ceil(nRows)) -// val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows)-1,0) -// val tag_history = compute_folded_hist(hist, tagLen) -// // Use another part of pc to make tags -// val tag = ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history)(tagLen-1,0) -// (idx, tag) -// } -// -// def inc_ctr(ctr: UInt, taken: Bool): UInt = { -// Mux(!taken, Mux(ctr === 0.U, 0.U, ctr - 1.U), -// Mux(ctr === 7.U, 7.U, ctr + 1.U)) -// } -// -// val doing_reset = RegInit(true.B) -// val reset_idx = RegInit(0.U(log2Ceil(nRows).W)) -// reset_idx := reset_idx + doing_reset -// when (reset_idx === (nRows-1).U) { doing_reset := false.B } -// -// class TageEntry() extends TageBundle { -// val valid = Bool() -// val tag = UInt(tagLen.W) -// val ctr = UInt(3.W) -// } -// -// val tageEntrySz = 1 + tagLen + 3 -// -// val (hashed_idx, tag) = compute_tag_and_hash(io.req.bits.pc, io.req.bits.hist) -// -// val hi_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false))) -// val lo_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false))) -// val table = List.fill(BankWidth)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false))) -// -// val hi_us_r = Wire(Vec(BankWidth, Bool())) -// val lo_us_r = Wire(Vec(BankWidth, Bool())) -// val table_r = Wire(Vec(BankWidth, new TageEntry)) -// -// (0 until BankWidth).map( -// b => { -// hi_us(b).reset := reset.asBool -// lo_us(b).reset := reset.asBool -// table(b).reset := reset.asBool -// hi_us(b).io.r.req.valid := io.req.valid -// lo_us(b).io.r.req.valid := io.req.valid -// table(b).io.r.req.valid := io.req.valid -// hi_us(b).io.r.req.bits.setIdx := hashed_idx -// lo_us(b).io.r.req.bits.setIdx := hashed_idx -// table(b).io.r.req.bits.setIdx := hashed_idx -// -// hi_us_r(b) := hi_us(b).io.r.resp.data(0) -// lo_us_r(b) := lo_us(b).io.r.resp.data(0) -// table_r(b) := table(b).io.r.resp.data(0) -// -// // io.resp(b).valid := table_r(b).valid && table_r(b).tag === tag // Missing reset logic -// // io.resp(b).bits.ctr := table_r(b).ctr -// // io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b)) -// } -// ) -// -// val req_rhits = VecInit(table_r.map(e => e.valid && e.tag === tag && !doing_reset)) -// -// (0 until BankWidth).map(b => { -// io.resp(b).valid := req_rhits(b) -// io.resp(b).bits.ctr := table_r(b).ctr -// io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b)) -// }) -// -// -// val clear_u_ctr = RegInit(0.U((log2Ceil(uBitPeriod) + log2Ceil(nRows) + 1).W)) -// when (doing_reset) { clear_u_ctr := 1.U } .otherwise { clear_u_ctr := clear_u_ctr + 1.U } -// -// val doing_clear_u = clear_u_ctr(log2Ceil(uBitPeriod)-1,0) === 0.U -// val doing_clear_u_hi = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 1.U -// val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U -// val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod) -// -// val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc, io.update.hist) -// -// val update_wdata = Wire(Vec(BankWidth, new TageEntry)) -// -// -// (0 until BankWidth).map(b => { -// table(b).io.w.req.valid := io.update.mask(b) || doing_reset -// table(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx) -// table(b).io.w.req.bits.data := Mux(doing_reset, 0.U.asTypeOf(new TageEntry), update_wdata(b)) -// }) -// -// val update_hi_wdata = Wire(Vec(BankWidth, Bool())) -// (0 until BankWidth).map(b => { -// hi_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_hi -// hi_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_hi, clear_u_idx, update_idx)) -// hi_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_hi, 0.U, update_hi_wdata(b)) -// }) -// -// val update_lo_wdata = Wire(Vec(BankWidth, Bool())) -// (0 until BankWidth).map(b => { -// lo_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_lo -// lo_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_lo, clear_u_idx, update_idx)) -// lo_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_lo, 0.U, update_lo_wdata(b)) -// }) -// -// val wrbypass_tags = Reg(Vec(wrBypassEntries, UInt(tagLen.W))) -// val wrbypass_idxs = Reg(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W))) -// val wrbypass = Reg(Vec(wrBypassEntries, Vec(BankWidth, UInt(3.W)))) -// val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W)) -// -// val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i => -// !doing_reset && -// wrbypass_tags(i) === update_tag && -// wrbypass_idxs(i) === update_idx -// }) -// val wrbypass_hit = wrbypass_hits.reduce(_||_) -// val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits) -// -// for (w <- 0 until BankWidth) { -// update_wdata(w).ctr := Mux(io.update.alloc(w), -// Mux(io.update.taken(w), 4.U, -// 3.U -// ), -// Mux(wrbypass_hit, inc_ctr(wrbypass(wrbypass_hit_idx)(w), io.update.taken(w)), -// inc_ctr(io.update.oldCtr(w), io.update.taken(w)) -// ) -// ) -// update_wdata(w).valid := true.B -// update_wdata(w).tag := update_tag -// -// update_hi_wdata(w) := io.update.u(w)(1) -// update_lo_wdata(w) := io.update.u(w)(0) -// } -// -// when (io.update.mask.reduce(_||_)) { -// when (wrbypass_hits.reduce(_||_)) { -// wrbypass(wrbypass_hit_idx) := VecInit(update_wdata.map(_.ctr)) -// } .otherwise { -// wrbypass (wrbypass_enq_idx) := VecInit(update_wdata.map(_.ctr)) -// wrbypass_tags(wrbypass_enq_idx) := update_tag -// wrbypass_idxs(wrbypass_enq_idx) := update_idx -// wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0) -// } -// } -// XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%b, idx=%d, tag=%x\n", io.req.bits.pc, io.req.bits.hist, hashed_idx, tag) -// for (i <- 0 until BankWidth) { -// XSDebug(RegNext(io.req.valid), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, RegNext(hashed_idx), req_rhits(i), table_r(i).ctr, Cat(hi_us_r(i),lo_us_r(i)).asUInt) -// } -// -//} -// -//class FakeTAGE extends TageModule { -// val io = IO(new Bundle() { -// val req = Input(Valid(new TageReq)) -// val out = new Bundle { -// val hits = Output(UInt(BankWidth.W)) -// val takens = Output(Vec(BankWidth, Bool())) -// } -// val meta = Output(Vec(BankWidth, (new TageMeta))) -// val redirectInfo = Input(new RedirectInfo) -// }) -// -// io.out.hits := 0.U(BankWidth.W) -// io.out.takens := DontCare -// io.meta := DontCare -//} -// -// -//class Tage extends TageModule { -// val io = IO(new Bundle() { -// val req = Input(Valid(new TageReq)) -// val out = new Bundle { -// val hits = Output(UInt(BankWidth.W)) -// val takens = Output(Vec(BankWidth, Bool())) -// } -// val meta = Output(Vec(BankWidth, (new TageMeta))) -// val redirectInfo = Input(new RedirectInfo) -// }) -// -// val tables = TableInfo.map { -// case (nRows, histLen, tagLen) => { -// val t = if(EnableBPD) Module(new TageTable(nRows, histLen, tagLen, UBitPeriod)) else Module(new FakeTageTable) -// t.io.req <> io.req -// t -// } -// } -// val resps = VecInit(tables.map(_.io.resp)) -// -// val updateMeta = io.redirectInfo.redirect.tageMeta -// //val updateMisPred = UIntToOH(io.redirectInfo.redirect.fetchIdx) & -// // Fill(BankWidth, (io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B).asUInt) -// val updateMisPred = io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B -// -// val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool())))) -// val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool())))) -// val updateTaken = Wire(Vec(TageNTables, Vec(BankWidth, Bool()))) -// val updateAlloc = Wire(Vec(TageNTables, Vec(BankWidth, Bool()))) -// val updateOldCtr = Wire(Vec(TageNTables, Vec(BankWidth, UInt(3.W)))) -// val updateU = Wire(Vec(TageNTables, Vec(BankWidth, UInt(2.W)))) -// updateTaken := DontCare -// updateAlloc := DontCare -// updateOldCtr := DontCare -// updateU := DontCare -// -// // access tag tables and output meta info -// val outHits = Wire(Vec(BankWidth, Bool())) -// for (w <- 0 until BankWidth) { -// var altPred = false.B -// val finalAltPred = WireInit(false.B) -// var provided = false.B -// var provider = 0.U -// outHits(w) := false.B -// io.out.takens(w) := false.B -// -// for (i <- 0 until TageNTables) { -// val hit = resps(i)(w).valid -// val ctr = resps(i)(w).bits.ctr -// when (hit) { -// io.out.takens(w) := Mux(ctr === 3.U || ctr === 4.U, altPred, ctr(2)) // Use altpred on weak taken -// finalAltPred := altPred -// } -// provided = provided || hit // Once hit then provide -// provider = Mux(hit, i.U, provider) // Use the last hit as provider -// altPred = Mux(hit, ctr(2), altPred) // Save current pred as potential altpred -// } -// outHits(w) := provided -// io.meta(w).provider.valid := provided -// io.meta(w).provider.bits := provider -// io.meta(w).altDiffers := finalAltPred =/= io.out.takens(w) -// io.meta(w).providerU := resps(provider)(w).bits.u -// io.meta(w).providerCtr := resps(provider)(w).bits.ctr -// -// // Create a mask fo tables which did not hit our query, and also contain useless entries -// // and also uses a longer history than the provider -// val allocatableSlots = (VecInit(resps.map(r => !r(w).valid && r(w).bits.u === 0.U)).asUInt & -// ~(LowerMask(UIntToOH(provider), TageNTables) & Fill(TageNTables, provided.asUInt)) -// ) -// val allocLFSR = LFSR64()(TageNTables - 1, 0) -// val firstEntry = PriorityEncoder(allocatableSlots) -// val maskedEntry = PriorityEncoder(allocatableSlots & allocLFSR) -// val allocEntry = Mux(allocatableSlots(maskedEntry), maskedEntry, firstEntry) -// io.meta(w).allocate.valid := allocatableSlots =/= 0.U -// io.meta(w).allocate.bits := allocEntry -// -// val isUpdateTaken = io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U && -// io.redirectInfo.redirect.taken && io.redirectInfo.redirect.btbType === BTBtype.B -// when (io.redirectInfo.redirect.btbType === BTBtype.B && io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U) { -// when (updateMeta.provider.valid) { -// val provider = updateMeta.provider.bits -// -// updateMask(provider)(w) := true.B -// updateUMask(provider)(w) := true.B -// -// updateU(provider)(w) := Mux(!updateMeta.altDiffers, updateMeta.providerU, -// Mux(updateMisPred, Mux(updateMeta.providerU === 0.U, 0.U, updateMeta.providerU - 1.U), -// Mux(updateMeta.providerU === 3.U, 3.U, updateMeta.providerU + 1.U)) -// ) -// updateTaken(provider)(w) := isUpdateTaken -// updateOldCtr(provider)(w) := updateMeta.providerCtr -// updateAlloc(provider)(w) := false.B -// } -// } -// } -// -// when (io.redirectInfo.valid && updateMisPred) { -// val idx = io.redirectInfo.redirect.fetchIdx -// val allocate = updateMeta.allocate -// when (allocate.valid) { -// updateMask(allocate.bits)(idx) := true.B -// updateTaken(allocate.bits)(idx) := io.redirectInfo.redirect.taken -// updateAlloc(allocate.bits)(idx) := true.B -// updateUMask(allocate.bits)(idx) := true.B -// updateU(allocate.bits)(idx) := 0.U -// }.otherwise { -// val provider = updateMeta.provider -// val decrMask = Mux(provider.valid, ~LowerMask(UIntToOH(provider.bits), TageNTables), 0.U) -// for (i <- 0 until TageNTables) { -// when (decrMask(i)) { -// updateUMask(i)(idx) := true.B -// updateU(i)(idx) := 0.U -// } -// } -// } -// } -// -// for (i <- 0 until TageNTables) { -// for (w <- 0 until BankWidth) { -// tables(i).io.update.mask(w) := updateMask(i)(w) -// tables(i).io.update.taken(w) := updateTaken(i)(w) -// tables(i).io.update.alloc(w) := updateAlloc(i)(w) -// tables(i).io.update.oldCtr(w) := updateOldCtr(i)(w) -// -// tables(i).io.update.uMask(w) := updateUMask(i)(w) -// tables(i).io.update.u(w) := updateU(i)(w) -// } -// // use fetch pc instead of instruction pc -// tables(i).io.update.pc := io.redirectInfo.redirect.pc - (io.redirectInfo.redirect.fetchIdx << 1.U) -// tables(i).io.update.hist := io.redirectInfo.redirect.hist -// } -// -// io.out.hits := outHits.asUInt -// -// -// val m = updateMeta -// XSDebug(io.req.valid, "req: pc=0x%x, hist=%b\n", io.req.bits.pc, io.req.bits.hist) -// XSDebug(io.redirectInfo.valid, "redirect: provider(%d):%d, altDiffers:%d, providerU:%d, providerCtr:%d, allocate(%d):%d\n", m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits) -// XSDebug(RegNext(io.req.valid), "resp: pc=%x, outHits=%b, takens=%b\n", RegNext(io.req.bits.pc), io.out.hits, io.out.takens.asUInt) -//} \ No newline at end of file +package xiangshan.frontend + +import chisel3._ +import chisel3.util._ +import xiangshan._ +import utils._ + +import scala.math.min + +trait HasTageParameter { + // Sets Hist Tag + val TableInfo = Seq(( 128, 2, 7), + ( 128, 4, 7), + ( 256, 8, 8), + ( 256, 16, 8), + ( 128, 32, 9), + ( 128, 64, 9)) + val TageNTables = TableInfo.size + val UBitPeriod = 2048 + val BankWidth = 16 // FetchWidth + + val TotalBits = TableInfo.map { + case (s, h, t) => { + s * (1+t+3) * BankWidth + } + }.reduce(_+_) +} + +abstract class TageBundle extends XSBundle with HasTageParameter +abstract class TageModule extends XSModule with HasTageParameter + +class TageReq extends TageBundle { + val pc = UInt(VAddrBits.W) + val hist = UInt(HistoryLength.W) +} + +class TageResp extends TageBundle { + val ctr = UInt(3.W) + val u = UInt(2.W) +} + +class TageUpdate extends TageBundle { + val pc = UInt(VAddrBits.W) + val hist = UInt(HistoryLength.W) + // update tag and ctr + val mask = Vec(BankWidth, Bool()) + val taken = Vec(BankWidth, Bool()) + val alloc = Vec(BankWidth, Bool()) + val oldCtr = Vec(BankWidth, UInt(3.W)) + // update u + val uMask = Vec(BankWidth, Bool()) + val u = Vec(BankWidth, UInt(2.W)) +} + +class FakeTageTable() extends TageModule { + val io = IO(new Bundle() { + val req = Input(Valid(new TageReq)) + val resp = Output(Vec(BankWidth, Valid(new TageResp))) + val update = Input(new TageUpdate) + }) + io.resp := DontCare + +} + +class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule { + val io = IO(new Bundle() { + val req = Input(Valid(new TageReq)) + val resp = Output(Vec(BankWidth, Valid(new TageResp))) + val update = Input(new TageUpdate) + }) + + // bypass entries for tage update + val wrBypassEntries = 8 + + def compute_folded_hist(hist: UInt, l: Int) = { + val nChunks = (histLen + l - 1) / l + val hist_chunks = (0 until nChunks) map {i => + hist(min((i+1)*l, histLen)-1, i*l) + } + hist_chunks.reduce(_^_) + } + + def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt) = { + val idx_history = compute_folded_hist(hist, log2Ceil(nRows)) + val idx = (unhashed_idx ^ idx_history)(log2Ceil(nRows)-1,0) + val tag_history = compute_folded_hist(hist, tagLen) + // Use another part of pc to make tags + val tag = ((unhashed_idx >> log2Ceil(nRows)) ^ tag_history)(tagLen-1,0) + (idx, tag) + } + + def inc_ctr(ctr: UInt, taken: Bool): UInt = { + Mux(!taken, Mux(ctr === 0.U, 0.U, ctr - 1.U), + Mux(ctr === 7.U, 7.U, ctr + 1.U)) + } + + val doing_reset = RegInit(true.B) + val reset_idx = RegInit(0.U(log2Ceil(nRows).W)) + reset_idx := reset_idx + doing_reset + when (reset_idx === (nRows-1).U) { doing_reset := false.B } + + class TageEntry() extends TageBundle { + val valid = Bool() + val tag = UInt(tagLen.W) + val ctr = UInt(3.W) + } + + val tageEntrySz = 1 + tagLen + 3 + + val (hashed_idx, tag) = compute_tag_and_hash(io.req.bits.pc, io.req.bits.hist) + + val hi_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false))) + val lo_us = List.fill(BankWidth)(Module(new SRAMTemplate(Bool(), set=nRows, shouldReset=false, holdRead=true, singlePort=false))) + val table = List.fill(BankWidth)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false))) + + val hi_us_r = Wire(Vec(BankWidth, Bool())) + val lo_us_r = Wire(Vec(BankWidth, Bool())) + val table_r = Wire(Vec(BankWidth, new TageEntry)) + + (0 until BankWidth).map( + b => { + hi_us(b).reset := reset.asBool + lo_us(b).reset := reset.asBool + table(b).reset := reset.asBool + hi_us(b).io.r.req.valid := io.req.valid + lo_us(b).io.r.req.valid := io.req.valid + table(b).io.r.req.valid := io.req.valid + hi_us(b).io.r.req.bits.setIdx := hashed_idx + lo_us(b).io.r.req.bits.setIdx := hashed_idx + table(b).io.r.req.bits.setIdx := hashed_idx + + hi_us_r(b) := hi_us(b).io.r.resp.data(0) + lo_us_r(b) := lo_us(b).io.r.resp.data(0) + table_r(b) := table(b).io.r.resp.data(0) + + // io.resp(b).valid := table_r(b).valid && table_r(b).tag === tag // Missing reset logic + // io.resp(b).bits.ctr := table_r(b).ctr + // io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b)) + } + ) + + val req_rhits = VecInit(table_r.map(e => e.valid && e.tag === tag && !doing_reset)) + + (0 until BankWidth).map(b => { + io.resp(b).valid := req_rhits(b) + io.resp(b).bits.ctr := table_r(b).ctr + io.resp(b).bits.u := Cat(hi_us_r(b),lo_us_r(b)) + }) + + + val clear_u_ctr = RegInit(0.U((log2Ceil(uBitPeriod) + log2Ceil(nRows) + 1).W)) + when (doing_reset) { clear_u_ctr := 1.U } .otherwise { clear_u_ctr := clear_u_ctr + 1.U } + + val doing_clear_u = clear_u_ctr(log2Ceil(uBitPeriod)-1,0) === 0.U + val doing_clear_u_hi = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 1.U + val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U + val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod) + + val (update_idx, update_tag) = compute_tag_and_hash(io.update.pc, io.update.hist) + + val update_wdata = Wire(Vec(BankWidth, new TageEntry)) + + + (0 until BankWidth).map(b => { + table(b).io.w.req.valid := io.update.mask(b) || doing_reset + table(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx) + table(b).io.w.req.bits.data := Mux(doing_reset, 0.U.asTypeOf(new TageEntry), update_wdata(b)) + }) + + val update_hi_wdata = Wire(Vec(BankWidth, Bool())) + (0 until BankWidth).map(b => { + hi_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_hi + hi_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_hi, clear_u_idx, update_idx)) + hi_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_hi, 0.U, update_hi_wdata(b)) + }) + + val update_lo_wdata = Wire(Vec(BankWidth, Bool())) + (0 until BankWidth).map(b => { + lo_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_lo + lo_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_lo, clear_u_idx, update_idx)) + lo_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_lo, 0.U, update_lo_wdata(b)) + }) + + val wrbypass_tags = Reg(Vec(wrBypassEntries, UInt(tagLen.W))) + val wrbypass_idxs = Reg(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W))) + val wrbypass = Reg(Vec(wrBypassEntries, Vec(BankWidth, UInt(3.W)))) + val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W)) + + val wrbypass_hits = VecInit((0 until wrBypassEntries) map { i => + !doing_reset && + wrbypass_tags(i) === update_tag && + wrbypass_idxs(i) === update_idx + }) + val wrbypass_hit = wrbypass_hits.reduce(_||_) + val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits) + + for (w <- 0 until BankWidth) { + update_wdata(w).ctr := Mux(io.update.alloc(w), + Mux(io.update.taken(w), 4.U, + 3.U + ), + Mux(wrbypass_hit, inc_ctr(wrbypass(wrbypass_hit_idx)(w), io.update.taken(w)), + inc_ctr(io.update.oldCtr(w), io.update.taken(w)) + ) + ) + update_wdata(w).valid := true.B + update_wdata(w).tag := update_tag + + update_hi_wdata(w) := io.update.u(w)(1) + update_lo_wdata(w) := io.update.u(w)(0) + } + + when (io.update.mask.reduce(_||_)) { + when (wrbypass_hits.reduce(_||_)) { + wrbypass(wrbypass_hit_idx) := VecInit(update_wdata.map(_.ctr)) + } .otherwise { + wrbypass (wrbypass_enq_idx) := VecInit(update_wdata.map(_.ctr)) + wrbypass_tags(wrbypass_enq_idx) := update_tag + wrbypass_idxs(wrbypass_enq_idx) := update_idx + wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0) + } + } + XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%b, idx=%d, tag=%x\n", io.req.bits.pc, io.req.bits.hist, hashed_idx, tag) + for (i <- 0 until BankWidth) { + XSDebug(RegNext(io.req.valid), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, RegNext(hashed_idx), req_rhits(i), table_r(i).ctr, Cat(hi_us_r(i),lo_us_r(i)).asUInt) + } + +} + +class FakeTAGE extends TageModule { + val io = IO(new Bundle() { + val req = Input(Valid(new TageReq)) + val out = new Bundle { + val hits = Output(UInt(BankWidth.W)) + val takens = Output(Vec(BankWidth, Bool())) + } + val meta = Output(Vec(BankWidth, (new TageMeta))) + val redirectInfo = Input(new RedirectInfo) + }) + + io.out.hits := 0.U(BankWidth.W) + io.out.takens := DontCare + io.meta := DontCare +} + + +class Tage extends TageModule { + val io = IO(new Bundle() { + val req = Input(Valid(new TageReq)) + val out = new Bundle { + val hits = Output(UInt(BankWidth.W)) + val takens = Output(Vec(BankWidth, Bool())) + } + val meta = Output(Vec(BankWidth, (new TageMeta))) + val redirectInfo = Input(new RedirectInfo) + }) + + val tables = TableInfo.map { + case (nRows, histLen, tagLen) => { + val t = if(EnableBPD) Module(new TageTable(nRows, histLen, tagLen, UBitPeriod)) else Module(new FakeTageTable) + t.io.req <> io.req + t + } + } + val resps = VecInit(tables.map(_.io.resp)) + + val updateMeta = io.redirectInfo.redirect.tageMeta + //val updateMisPred = UIntToOH(io.redirectInfo.redirect.fetchIdx) & + // Fill(BankWidth, (io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B).asUInt) + val updateMisPred = io.redirectInfo.misPred && io.redirectInfo.redirect.btbType === BTBtype.B + + val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool())))) + val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(BankWidth, Bool())))) + val updateTaken = Wire(Vec(TageNTables, Vec(BankWidth, Bool()))) + val updateAlloc = Wire(Vec(TageNTables, Vec(BankWidth, Bool()))) + val updateOldCtr = Wire(Vec(TageNTables, Vec(BankWidth, UInt(3.W)))) + val updateU = Wire(Vec(TageNTables, Vec(BankWidth, UInt(2.W)))) + updateTaken := DontCare + updateAlloc := DontCare + updateOldCtr := DontCare + updateU := DontCare + + // access tag tables and output meta info + val outHits = Wire(Vec(BankWidth, Bool())) + for (w <- 0 until BankWidth) { + var altPred = false.B + val finalAltPred = WireInit(false.B) + var provided = false.B + var provider = 0.U + outHits(w) := false.B + io.out.takens(w) := false.B + + for (i <- 0 until TageNTables) { + val hit = resps(i)(w).valid + val ctr = resps(i)(w).bits.ctr + when (hit) { + io.out.takens(w) := Mux(ctr === 3.U || ctr === 4.U, altPred, ctr(2)) // Use altpred on weak taken + finalAltPred := altPred + } + provided = provided || hit // Once hit then provide + provider = Mux(hit, i.U, provider) // Use the last hit as provider + altPred = Mux(hit, ctr(2), altPred) // Save current pred as potential altpred + } + outHits(w) := provided + io.meta(w).provider.valid := provided + io.meta(w).provider.bits := provider + io.meta(w).altDiffers := finalAltPred =/= io.out.takens(w) + io.meta(w).providerU := resps(provider)(w).bits.u + io.meta(w).providerCtr := resps(provider)(w).bits.ctr + + // Create a mask fo tables which did not hit our query, and also contain useless entries + // and also uses a longer history than the provider + val allocatableSlots = (VecInit(resps.map(r => !r(w).valid && r(w).bits.u === 0.U)).asUInt & + ~(LowerMask(UIntToOH(provider), TageNTables) & Fill(TageNTables, provided.asUInt)) + ) + val allocLFSR = LFSR64()(TageNTables - 1, 0) + val firstEntry = PriorityEncoder(allocatableSlots) + val maskedEntry = PriorityEncoder(allocatableSlots & allocLFSR) + val allocEntry = Mux(allocatableSlots(maskedEntry), maskedEntry, firstEntry) + io.meta(w).allocate.valid := allocatableSlots =/= 0.U + io.meta(w).allocate.bits := allocEntry + + val isUpdateTaken = io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U && + io.redirectInfo.redirect.taken && io.redirectInfo.redirect.btbType === BTBtype.B + when (io.redirectInfo.redirect.btbType === BTBtype.B && io.redirectInfo.valid && io.redirectInfo.redirect.fetchIdx === w.U) { + when (updateMeta.provider.valid) { + val provider = updateMeta.provider.bits + + updateMask(provider)(w) := true.B + updateUMask(provider)(w) := true.B + + updateU(provider)(w) := Mux(!updateMeta.altDiffers, updateMeta.providerU, + Mux(updateMisPred, Mux(updateMeta.providerU === 0.U, 0.U, updateMeta.providerU - 1.U), + Mux(updateMeta.providerU === 3.U, 3.U, updateMeta.providerU + 1.U)) + ) + updateTaken(provider)(w) := isUpdateTaken + updateOldCtr(provider)(w) := updateMeta.providerCtr + updateAlloc(provider)(w) := false.B + } + } + } + + when (io.redirectInfo.valid && updateMisPred) { + val idx = io.redirectInfo.redirect.fetchIdx + val allocate = updateMeta.allocate + when (allocate.valid) { + updateMask(allocate.bits)(idx) := true.B + updateTaken(allocate.bits)(idx) := io.redirectInfo.redirect.taken + updateAlloc(allocate.bits)(idx) := true.B + updateUMask(allocate.bits)(idx) := true.B + updateU(allocate.bits)(idx) := 0.U + }.otherwise { + val provider = updateMeta.provider + val decrMask = Mux(provider.valid, ~LowerMask(UIntToOH(provider.bits), TageNTables), 0.U) + for (i <- 0 until TageNTables) { + when (decrMask(i)) { + updateUMask(i)(idx) := true.B + updateU(i)(idx) := 0.U + } + } + } + } + + for (i <- 0 until TageNTables) { + for (w <- 0 until BankWidth) { + tables(i).io.update.mask(w) := updateMask(i)(w) + tables(i).io.update.taken(w) := updateTaken(i)(w) + tables(i).io.update.alloc(w) := updateAlloc(i)(w) + tables(i).io.update.oldCtr(w) := updateOldCtr(i)(w) + + tables(i).io.update.uMask(w) := updateUMask(i)(w) + tables(i).io.update.u(w) := updateU(i)(w) + } + // use fetch pc instead of instruction pc + tables(i).io.update.pc := io.redirectInfo.redirect.pc - (io.redirectInfo.redirect.fetchIdx << 1.U) + tables(i).io.update.hist := io.redirectInfo.redirect.hist + } + + io.out.hits := outHits.asUInt + + + val m = updateMeta + XSDebug(io.req.valid, "req: pc=0x%x, hist=%b\n", io.req.bits.pc, io.req.bits.hist) + XSDebug(io.redirectInfo.valid, "redirect: provider(%d):%d, altDiffers:%d, providerU:%d, providerCtr:%d, allocate(%d):%d\n", m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits) + XSDebug(RegNext(io.req.valid), "resp: pc=%x, outHits=%b, takens=%b\n", RegNext(io.req.bits.pc), io.out.hits, io.out.takens.asUInt) +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/btb.scala b/src/main/scala/xiangshan/frontend/btb.scala index 9349b09ac..fe66fd5a1 100644 --- a/src/main/scala/xiangshan/frontend/btb.scala +++ b/src/main/scala/xiangshan/frontend/btb.scala @@ -1,248 +1,248 @@ -//package xiangshan.frontend -// -//import chisel3._ -//import chisel3.util._ -//import xiangshan._ -//import xiangshan.backend.ALUOpType -//import utils._ -//import chisel3.util.experimental.BoringUtils -//import xiangshan.backend.decode.XSTrap -// -//class BTBUpdateBundle extends XSBundle { -// val pc = UInt(VAddrBits.W) -// val hit = Bool() -// val misPred = Bool() -// val oldCtr = UInt(2.W) -// val taken = Bool() -// val target = UInt(VAddrBits.W) -// val btbType = UInt(2.W) -// val isRVC = Bool() -//} -// -//class BTBPred extends XSBundle { -// val taken = Bool() -// val takenIdx = UInt(log2Up(PredictWidth).W) -// val target = UInt(VAddrBits.W) -// -// val notTakens = Vec(PredictWidth, Bool()) -// val dEntries = Vec(PredictWidth, btbDataEntry()) -// val hits = Vec(PredictWidth, Bool()) -// -// // whether an RVI instruction crosses over two fetch packet -// val isRVILateJump = Bool() -//} -// -//case class btbDataEntry() extends XSBundle { -// val target = UInt(VAddrBits.W) -// val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor -// val btbType = UInt(2.W) -// val isRVC = Bool() -//} -// -//case class btbMetaEntry() extends XSBundle { -// val valid = Bool() -// // TODO: don't need full length of tag -// val tag = UInt((VAddrBits - log2Up(BtbSize) - 1).W) -//} -// -//class BTB extends XSModule { -// val io = IO(new Bundle() { -// // Input -// val in = new Bundle { -// val pc = Flipped(Decoupled(UInt(VAddrBits.W))) -// val pcLatch = Input(UInt(VAddrBits.W)) -// val mask = Input(UInt(PredictWidth.W)) -// } -// val redirectValid = Input(Bool()) -// val flush = Input(Bool()) -// val update = Input(new BTBUpdateBundle) -// // Output -// val out = Output(new BTBPred) -// }) -// -// io.in.pc.ready := true.B -// val fireLatch = RegNext(io.in.pc.fire()) -// val maskLatch = RegEnable(io.in.mask, io.in.pc.fire()) -// -// val btbAddr = new TableAddr(log2Up(BtbSize), BtbBanks) -// -// // SRAMs to store BTB meta & data -// val btbMeta = List.fill(BtbBanks)( -// Module(new SRAMTemplate(btbMetaEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true))) -// val btbData = List.fill(BtbBanks)( -// Module(new SRAMTemplate(btbDataEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true))) -// -// // BTB read requests -// val baseBank = btbAddr.getBank(io.in.pc.bits) -// // circular shifting -// def circularShiftLeft(source: UInt, len: Int, shamt: UInt): UInt = { -// val res = Wire(UInt(len.W)) -// val higher = source << shamt -// val lower = source >> (len.U - shamt) -// res := higher | lower -// res -// } -// val realMask = circularShiftLeft(io.in.mask, BtbBanks, baseBank) -// -// // those banks whose indexes are less than baseBank are in the next row -// val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank)) -// -// val baseRow = btbAddr.getBankIdx(io.in.pc.bits) -// // this row is the last row of a bank -// val nextRowStartsUp = baseRow.andR -// val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b.U), Mux(nextRowStartsUp, 0.U, baseRow+1.U), baseRow))) -// val realRowLatch = VecInit(realRow.map(RegNext(_))) -// -// for (b <- 0 until BtbBanks) { -// btbMeta(b).reset := reset.asBool -// btbMeta(b).io.r.req.valid := realMask(b) && io.in.pc.valid -// btbMeta(b).io.r.req.bits.setIdx := realRow(b) -// btbData(b).reset := reset.asBool -// btbData(b).io.r.req.valid := realMask(b) && io.in.pc.valid -// btbData(b).io.r.req.bits.setIdx := realRow(b) -// } -// -// -// // Entries read from SRAM -// val metaRead = Wire(Vec(BtbBanks, btbMetaEntry())) -// val dataRead = Wire(Vec(BtbBanks, btbDataEntry())) -// val readFire = Wire(Vec(BtbBanks, Bool())) -// for (b <- 0 until BtbBanks) { -// readFire(b) := btbMeta(b).io.r.req.fire() && btbData(b).io.r.req.fire() -// metaRead(b) := btbMeta(b).io.r.resp.data(0) -// dataRead(b) := btbData(b).io.r.resp.data(0) -// } -// -// val baseBankLatch = btbAddr.getBank(io.in.pcLatch) -// // val isAlignedLatch = baseBankLatch === 0.U -// val baseTag = btbAddr.getTag(io.in.pcLatch) -// // If the next row starts up, the tag needs to be incremented as well -// val tagIncremented = VecInit((0 until BtbBanks).map(b => RegEnable(isInNextRow(b.U) && nextRowStartsUp, io.in.pc.valid))) -// -// val bankHits = Wire(Vec(BtbBanks, Bool())) -// for (b <- 0 until BtbBanks) { -// bankHits(b) := metaRead(b).valid && -// (Mux(tagIncremented(b), baseTag+1.U, baseTag) === metaRead(b).tag) && !io.flush && RegNext(readFire(b), init = false.B) -// } -// -// // taken branches of jumps from a valid entry -// val predTakens = Wire(Vec(BtbBanks, Bool())) -// // not taken branches from a valid entry -// val notTakenBranches = Wire(Vec(BtbBanks, Bool())) -// for (b <- 0 until BtbBanks) { -// predTakens(b) := bankHits(b) && (dataRead(b).btbType === BTBtype.J || dataRead(b).btbType === BTBtype.B && dataRead(b).pred(1).asBool) -// notTakenBranches(b) := bankHits(b) && dataRead(b).btbType === BTBtype.B && !dataRead(b).pred(1).asBool -// } -// -// // e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4) -// val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch + b.U) % BtbBanks.U)) -// -// val isTaken = predTakens.reduce(_||_) -// // Priority mux which corresponds with inst orders -// // BTB only produce one single prediction -// val takenTarget = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).target))) -// val takenType = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).btbType))) -// // Record which inst is predicted taken -// val takenIdx = MuxCase(0.U, (0 until BtbBanks).map(b => (predTakens(bankIdxInOrder(b)), b.U))) -// -// // Update logic -// // 1 calculate new 2-bit saturated counter value -// def satUpdate(old: UInt, len: Int, taken: Bool): UInt = { -// val oldSatTaken = old === ((1 << len)-1).U -// val oldSatNotTaken = old === 0.U -// Mux(oldSatTaken && taken, ((1 << len)-1).U, -// Mux(oldSatNotTaken && !taken, 0.U, -// Mux(taken, old + 1.U, old - 1.U))) -// } -// -// val u = io.update -// val newCtr = Mux(!u.hit, "b10".U, satUpdate(u.oldCtr, 2, u.taken)) -// -// val updateOnSaturated = u.taken && u.oldCtr === "b11".U || !u.taken && u.oldCtr === "b00".U -// -// // 2 write btb -// val updateBankIdx = btbAddr.getBank(u.pc) -// val updateRow = btbAddr.getBankIdx(u.pc) -// val btbMetaWrite = Wire(btbMetaEntry()) -// btbMetaWrite.valid := true.B -// btbMetaWrite.tag := btbAddr.getTag(u.pc) -// val btbDataWrite = Wire(btbDataEntry()) -// btbDataWrite.target := u.target -// btbDataWrite.pred := newCtr -// btbDataWrite.btbType := u.btbType -// btbDataWrite.isRVC := u.isRVC -// -// val isBr = u.btbType === BTBtype.B -// val isJ = u.btbType === BTBtype.J -// val notBrOrJ = u.btbType =/= BTBtype.B && u.btbType =/= BTBtype.J -// -// // Do not update BTB on indirect or return, or correctly predicted J or saturated counters -// val noNeedToUpdate = (!u.misPred && (isBr && updateOnSaturated || isJ)) || notBrOrJ -// -// // do not update on saturated ctrs -// val btbWriteValid = io.redirectValid && !noNeedToUpdate -// -// for (b <- 0 until BtbBanks) { -// btbMeta(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx -// btbMeta(b).io.w.req.bits.setIdx := updateRow -// btbMeta(b).io.w.req.bits.data := btbMetaWrite -// btbData(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx -// btbData(b).io.w.req.bits.setIdx := updateRow -// btbData(b).io.w.req.bits.data := btbDataWrite -// } -// -// // io.out.hit := bankHits.reduce(_||_) -// io.out.taken := isTaken -// io.out.takenIdx := takenIdx -// io.out.target := takenTarget -// // io.out.writeWay := writeWay -// io.out.notTakens := VecInit((0 until BtbBanks).map(b => notTakenBranches(bankIdxInOrder(b)))) -// io.out.dEntries := VecInit((0 until BtbBanks).map(b => dataRead(bankIdxInOrder(b)))) -// io.out.hits := VecInit((0 until BtbBanks).map(b => bankHits(bankIdxInOrder(b)))) -// io.out.isRVILateJump := io.out.taken && takenIdx === OHToUInt(HighestBit(maskLatch, PredictWidth)) && !dataRead(bankIdxInOrder(takenIdx)).isRVC -// -// // read-after-write bypass -// val rawBypassHit = Wire(Vec(BtbBanks, Bool())) -// for (b <- 0 until BtbBanks) { -// when (b.U === updateBankIdx && realRow(b) === updateRow) { // read and write to the same address -// when (realMask(b) && io.in.pc.valid && btbWriteValid) { // both read and write valid -// rawBypassHit(b) := true.B -// btbMeta(b).io.r.req.valid := false.B -// btbData(b).io.r.req.valid := false.B -// // metaRead(b) := RegNext(btbMetaWrite) -// // dataRead(b) := RegNext(btbDataWrite) -// readFire(b) := true.B -// XSDebug("raw bypass hits: bank=%d, row=%d, meta: %d %x, data: tgt=%x pred=%b btbType=%b isRVC=%d\n", -// b.U, updateRow, -// btbMetaWrite.valid, btbMetaWrite.tag, -// btbDataWrite.target, btbDataWrite.pred, btbDataWrite.btbType, btbDataWrite.isRVC) -// }.otherwise { -// rawBypassHit(b) := false.B -// } -// }.otherwise { -// rawBypassHit(b) := false.B -// } -// -// when (RegNext(rawBypassHit(b))) { -// metaRead(b) := RegNext(btbMetaWrite) -// dataRead(b) := RegNext(btbDataWrite) -// } -// } -// -// XSDebug(io.in.pc.fire(), "read: pc=0x%x, baseBank=%d, realMask=%b\n", io.in.pc.bits, baseBank, realMask) -// XSDebug(fireLatch, "read_resp: pc=0x%x, readIdx=%d-------------------------------\n", -// io.in.pcLatch, btbAddr.getIdx(io.in.pcLatch)) -// for (i <- 0 until BtbBanks){ -// XSDebug(fireLatch, "read_resp[b=%d][r=%d]: valid=%d, tag=0x%x, target=0x%x, type=%d, ctr=%d\n", -// i.U, realRowLatch(i), metaRead(i).valid, metaRead(i).tag, dataRead(i).target, dataRead(i).btbType, dataRead(i).pred) -// } -// XSDebug("out: taken=%d takenIdx=%d tgt=%x notTakens=%b hits=%b isRVILateJump=%d\n", -// io.out.taken, io.out.takenIdx, io.out.target, io.out.notTakens.asUInt, io.out.hits.asUInt, io.out.isRVILateJump) -// XSDebug(fireLatch, "bankIdxInOrder:") -// for (i <- 0 until BtbBanks){ XSDebug(fireLatch, "%d ", bankIdxInOrder(i))} -// XSDebug(fireLatch, "\n") -// XSDebug(io.redirectValid, "update_req: pc=0x%x, hit=%d, misPred=%d, oldCtr=%d, taken=%d, target=0x%x, btbType=%d\n", -// u.pc, u.hit, u.misPred, u.oldCtr, u.taken, u.target, u.btbType) -// XSDebug(io.redirectValid, "update: noNeedToUpdate=%d, writeValid=%d, bank=%d, row=%d, newCtr=%d\n", -// noNeedToUpdate, btbWriteValid, updateBankIdx, updateRow, newCtr) -//} \ No newline at end of file +package xiangshan.frontend + +import chisel3._ +import chisel3.util._ +import xiangshan._ +import xiangshan.backend.ALUOpType +import utils._ +import chisel3.util.experimental.BoringUtils +import xiangshan.backend.decode.XSTrap + +class BTBUpdateBundle extends XSBundle { + val pc = UInt(VAddrBits.W) + val hit = Bool() + val misPred = Bool() + val oldCtr = UInt(2.W) + val taken = Bool() + val target = UInt(VAddrBits.W) + val btbType = UInt(2.W) + val isRVC = Bool() +} + +class BTBPred extends XSBundle { + val taken = Bool() + val takenIdx = UInt(log2Up(PredictWidth).W) + val target = UInt(VAddrBits.W) + + val notTakens = Vec(PredictWidth, Bool()) + val dEntries = Vec(PredictWidth, btbDataEntry()) + val hits = Vec(PredictWidth, Bool()) + + // whether an RVI instruction crosses over two fetch packet + val isRVILateJump = Bool() +} + +case class btbDataEntry() extends XSBundle { + val target = UInt(VAddrBits.W) + val pred = UInt(2.W) // 2-bit saturated counter as a quick predictor + val btbType = UInt(2.W) + val isRVC = Bool() +} + +case class btbMetaEntry() extends XSBundle { + val valid = Bool() + // TODO: don't need full length of tag + val tag = UInt((VAddrBits - log2Up(BtbSize) - 1).W) +} + +class BTB extends XSModule { + val io = IO(new Bundle() { + // Input + val in = new Bundle { + val pc = Flipped(Decoupled(UInt(VAddrBits.W))) + val pcLatch = Input(UInt(VAddrBits.W)) + val mask = Input(UInt(PredictWidth.W)) + } + val redirectValid = Input(Bool()) + val flush = Input(Bool()) + val update = Input(new BTBUpdateBundle) + // Output + val out = Output(new BTBPred) + }) + + io.in.pc.ready := true.B + val fireLatch = RegNext(io.in.pc.fire()) + val maskLatch = RegEnable(io.in.mask, io.in.pc.fire()) + + val btbAddr = new TableAddr(log2Up(BtbSize), BtbBanks) + + // SRAMs to store BTB meta & data + val btbMeta = List.fill(BtbBanks)( + Module(new SRAMTemplate(btbMetaEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true))) + val btbData = List.fill(BtbBanks)( + Module(new SRAMTemplate(btbDataEntry(), set = BtbSize / BtbBanks, shouldReset = true, holdRead = true))) + + // BTB read requests + val baseBank = btbAddr.getBank(io.in.pc.bits) + // circular shifting + def circularShiftLeft(source: UInt, len: Int, shamt: UInt): UInt = { + val res = Wire(UInt(len.W)) + val higher = source << shamt + val lower = source >> (len.U - shamt) + res := higher | lower + res + } + val realMask = circularShiftLeft(io.in.mask, BtbBanks, baseBank) + + // those banks whose indexes are less than baseBank are in the next row + val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank)) + + val baseRow = btbAddr.getBankIdx(io.in.pc.bits) + // this row is the last row of a bank + val nextRowStartsUp = baseRow.andR + val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b.U), Mux(nextRowStartsUp, 0.U, baseRow+1.U), baseRow))) + val realRowLatch = VecInit(realRow.map(RegNext(_))) + + for (b <- 0 until BtbBanks) { + btbMeta(b).reset := reset.asBool + btbMeta(b).io.r.req.valid := realMask(b) && io.in.pc.valid + btbMeta(b).io.r.req.bits.setIdx := realRow(b) + btbData(b).reset := reset.asBool + btbData(b).io.r.req.valid := realMask(b) && io.in.pc.valid + btbData(b).io.r.req.bits.setIdx := realRow(b) + } + + + // Entries read from SRAM + val metaRead = Wire(Vec(BtbBanks, btbMetaEntry())) + val dataRead = Wire(Vec(BtbBanks, btbDataEntry())) + val readFire = Wire(Vec(BtbBanks, Bool())) + for (b <- 0 until BtbBanks) { + readFire(b) := btbMeta(b).io.r.req.fire() && btbData(b).io.r.req.fire() + metaRead(b) := btbMeta(b).io.r.resp.data(0) + dataRead(b) := btbData(b).io.r.resp.data(0) + } + + val baseBankLatch = btbAddr.getBank(io.in.pcLatch) + // val isAlignedLatch = baseBankLatch === 0.U + val baseTag = btbAddr.getTag(io.in.pcLatch) + // If the next row starts up, the tag needs to be incremented as well + val tagIncremented = VecInit((0 until BtbBanks).map(b => RegEnable(isInNextRow(b.U) && nextRowStartsUp, io.in.pc.valid))) + + val bankHits = Wire(Vec(BtbBanks, Bool())) + for (b <- 0 until BtbBanks) { + bankHits(b) := metaRead(b).valid && + (Mux(tagIncremented(b), baseTag+1.U, baseTag) === metaRead(b).tag) && !io.flush && RegNext(readFire(b), init = false.B) + } + + // taken branches of jumps from a valid entry + val predTakens = Wire(Vec(BtbBanks, Bool())) + // not taken branches from a valid entry + val notTakenBranches = Wire(Vec(BtbBanks, Bool())) + for (b <- 0 until BtbBanks) { + predTakens(b) := bankHits(b) && (dataRead(b).btbType === BTBtype.J || dataRead(b).btbType === BTBtype.B && dataRead(b).pred(1).asBool) + notTakenBranches(b) := bankHits(b) && dataRead(b).btbType === BTBtype.B && !dataRead(b).pred(1).asBool + } + + // e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4) + val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch + b.U) % BtbBanks.U)) + + val isTaken = predTakens.reduce(_||_) + // Priority mux which corresponds with inst orders + // BTB only produce one single prediction + val takenTarget = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).target))) + val takenType = MuxCase(0.U, bankIdxInOrder.map(b => (predTakens(b), dataRead(b).btbType))) + // Record which inst is predicted taken + val takenIdx = MuxCase(0.U, (0 until BtbBanks).map(b => (predTakens(bankIdxInOrder(b)), b.U))) + + // Update logic + // 1 calculate new 2-bit saturated counter value + def satUpdate(old: UInt, len: Int, taken: Bool): UInt = { + val oldSatTaken = old === ((1 << len)-1).U + val oldSatNotTaken = old === 0.U + Mux(oldSatTaken && taken, ((1 << len)-1).U, + Mux(oldSatNotTaken && !taken, 0.U, + Mux(taken, old + 1.U, old - 1.U))) + } + + val u = io.update + val newCtr = Mux(!u.hit, "b10".U, satUpdate(u.oldCtr, 2, u.taken)) + + val updateOnSaturated = u.taken && u.oldCtr === "b11".U || !u.taken && u.oldCtr === "b00".U + + // 2 write btb + val updateBankIdx = btbAddr.getBank(u.pc) + val updateRow = btbAddr.getBankIdx(u.pc) + val btbMetaWrite = Wire(btbMetaEntry()) + btbMetaWrite.valid := true.B + btbMetaWrite.tag := btbAddr.getTag(u.pc) + val btbDataWrite = Wire(btbDataEntry()) + btbDataWrite.target := u.target + btbDataWrite.pred := newCtr + btbDataWrite.btbType := u.btbType + btbDataWrite.isRVC := u.isRVC + + val isBr = u.btbType === BTBtype.B + val isJ = u.btbType === BTBtype.J + val notBrOrJ = u.btbType =/= BTBtype.B && u.btbType =/= BTBtype.J + + // Do not update BTB on indirect or return, or correctly predicted J or saturated counters + val noNeedToUpdate = (!u.misPred && (isBr && updateOnSaturated || isJ)) || notBrOrJ + + // do not update on saturated ctrs + val btbWriteValid = io.redirectValid && !noNeedToUpdate + + for (b <- 0 until BtbBanks) { + btbMeta(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx + btbMeta(b).io.w.req.bits.setIdx := updateRow + btbMeta(b).io.w.req.bits.data := btbMetaWrite + btbData(b).io.w.req.valid := btbWriteValid && b.U === updateBankIdx + btbData(b).io.w.req.bits.setIdx := updateRow + btbData(b).io.w.req.bits.data := btbDataWrite + } + + // io.out.hit := bankHits.reduce(_||_) + io.out.taken := isTaken + io.out.takenIdx := takenIdx + io.out.target := takenTarget + // io.out.writeWay := writeWay + io.out.notTakens := VecInit((0 until BtbBanks).map(b => notTakenBranches(bankIdxInOrder(b)))) + io.out.dEntries := VecInit((0 until BtbBanks).map(b => dataRead(bankIdxInOrder(b)))) + io.out.hits := VecInit((0 until BtbBanks).map(b => bankHits(bankIdxInOrder(b)))) + io.out.isRVILateJump := io.out.taken && takenIdx === OHToUInt(HighestBit(maskLatch, PredictWidth)) && !dataRead(bankIdxInOrder(takenIdx)).isRVC + + // read-after-write bypass + val rawBypassHit = Wire(Vec(BtbBanks, Bool())) + for (b <- 0 until BtbBanks) { + when (b.U === updateBankIdx && realRow(b) === updateRow) { // read and write to the same address + when (realMask(b) && io.in.pc.valid && btbWriteValid) { // both read and write valid + rawBypassHit(b) := true.B + btbMeta(b).io.r.req.valid := false.B + btbData(b).io.r.req.valid := false.B + // metaRead(b) := RegNext(btbMetaWrite) + // dataRead(b) := RegNext(btbDataWrite) + readFire(b) := true.B + XSDebug("raw bypass hits: bank=%d, row=%d, meta: %d %x, data: tgt=%x pred=%b btbType=%b isRVC=%d\n", + b.U, updateRow, + btbMetaWrite.valid, btbMetaWrite.tag, + btbDataWrite.target, btbDataWrite.pred, btbDataWrite.btbType, btbDataWrite.isRVC) + }.otherwise { + rawBypassHit(b) := false.B + } + }.otherwise { + rawBypassHit(b) := false.B + } + + when (RegNext(rawBypassHit(b))) { + metaRead(b) := RegNext(btbMetaWrite) + dataRead(b) := RegNext(btbDataWrite) + } + } + + XSDebug(io.in.pc.fire(), "read: pc=0x%x, baseBank=%d, realMask=%b\n", io.in.pc.bits, baseBank, realMask) + XSDebug(fireLatch, "read_resp: pc=0x%x, readIdx=%d-------------------------------\n", + io.in.pcLatch, btbAddr.getIdx(io.in.pcLatch)) + for (i <- 0 until BtbBanks){ + XSDebug(fireLatch, "read_resp[b=%d][r=%d]: valid=%d, tag=0x%x, target=0x%x, type=%d, ctr=%d\n", + i.U, realRowLatch(i), metaRead(i).valid, metaRead(i).tag, dataRead(i).target, dataRead(i).btbType, dataRead(i).pred) + } + XSDebug("out: taken=%d takenIdx=%d tgt=%x notTakens=%b hits=%b isRVILateJump=%d\n", + io.out.taken, io.out.takenIdx, io.out.target, io.out.notTakens.asUInt, io.out.hits.asUInt, io.out.isRVILateJump) + XSDebug(fireLatch, "bankIdxInOrder:") + for (i <- 0 until BtbBanks){ XSDebug(fireLatch, "%d ", bankIdxInOrder(i))} + XSDebug(fireLatch, "\n") + XSDebug(io.redirectValid, "update_req: pc=0x%x, hit=%d, misPred=%d, oldCtr=%d, taken=%d, target=0x%x, btbType=%d\n", + u.pc, u.hit, u.misPred, u.oldCtr, u.taken, u.target, u.btbType) + XSDebug(io.redirectValid, "update: noNeedToUpdate=%d, writeValid=%d, bank=%d, row=%d, newCtr=%d\n", + noNeedToUpdate, btbWriteValid, updateBankIdx, updateRow, newCtr) +} \ No newline at end of file diff --git a/src/main/scala/xiangshan/frontend/jbtac.scala b/src/main/scala/xiangshan/frontend/jbtac.scala index bf73d6c9b..9e0735959 100644 --- a/src/main/scala/xiangshan/frontend/jbtac.scala +++ b/src/main/scala/xiangshan/frontend/jbtac.scala @@ -1,151 +1,151 @@ -//package xiangshan.frontend -// -//import chisel3._ -//import chisel3.util._ -//import xiangshan._ -//import utils._ -//import xiangshan.backend.ALUOpType -// -// -//class JBTACUpdateBundle extends XSBundle { -// val fetchPC = UInt(VAddrBits.W) -// val fetchIdx = UInt(log2Up(PredictWidth).W) -// val hist = UInt(HistoryLength.W) -// val target = UInt(VAddrBits.W) -// val btbType = UInt(2.W) -// val misPred = Bool() -// val isRVC = Bool() -//} -// -//class JBTACPred extends XSBundle { -// val hit = Bool() -// val target = UInt(VAddrBits.W) -// val hitIdx = UInt(log2Up(PredictWidth).W) -// val isRVILateJump = Bool() -// val isRVC = Bool() -//} -// -//class JBTAC extends XSModule { -// val io = IO(new Bundle { -// val in = new Bundle { -// val pc = Flipped(Decoupled(UInt(VAddrBits.W))) -// val pcLatch = Input(UInt(VAddrBits.W)) -// val mask = Input(UInt(PredictWidth.W)) -// val hist = Input(UInt(HistoryLength.W)) -// } -// val redirectValid = Input(Bool()) -// val flush = Input(Bool()) -// val update = Input(new JBTACUpdateBundle) -// -// val out = Output(new JBTACPred) -// }) -// -// io.in.pc.ready := true.B -// -// val fireLatch = RegNext(io.in.pc.fire()) -// -// // JBTAC, divided into 8 banks, makes prediction for indirect jump except ret. -// val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks) -// def jbtacEntry() = new Bundle { -// val valid = Bool() -// // TODO: don't need full length of tag and target -// val tag = UInt(jbtacAddr.tagBits.W + jbtacAddr.idxBits.W) -// val target = UInt(VAddrBits.W) -// val offset = UInt(log2Up(PredictWidth).W) -// val isRVC = Bool() -// } -// -// val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = false))) -// -// val readEntries = Wire(Vec(JbtacBanks, jbtacEntry())) -// -// val readFire = Reg(Vec(JbtacBanks, Bool())) -// // Only read one bank -// val histXORAddr = io.in.pc.bits ^ Cat(io.in.hist, 0.U(1.W))(VAddrBits - 1, 0) -// val histXORAddrLatch = RegEnable(histXORAddr, io.in.pc.valid) -// -// val readBank = jbtacAddr.getBank(histXORAddr) -// val readRow = jbtacAddr.getBankIdx(histXORAddr) -// readFire := 0.U.asTypeOf(Vec(JbtacBanks, Bool())) -// (0 until JbtacBanks).map( -// b => { -// jbtac(b).reset := reset.asBool -// jbtac(b).io.r.req.valid := io.in.pc.fire() && b.U === readBank -// jbtac(b).io.r.req.bits.setIdx := readRow -// readFire(b) := jbtac(b).io.r.req.fire() -// readEntries(b) := jbtac(b).io.r.resp.data(0) -// } -// ) -// -// val readBankLatch = jbtacAddr.getBank(histXORAddrLatch) -// val readRowLatch = jbtacAddr.getBankIdx(histXORAddrLatch) -// val readMaskLatch = RegEnable(io.in.mask, io.in.pc.fire()) -// -// val outHit = readEntries(readBankLatch).valid && -// readEntries(readBankLatch).tag === Cat(jbtacAddr.getTag(io.in.pcLatch), jbtacAddr.getIdx(io.in.pcLatch)) && -// !io.flush && RegNext(readFire(readBankLatch)) && readMaskLatch(readEntries(readBankLatch).offset).asBool -// -// io.out.hit := outHit -// io.out.hitIdx := readEntries(readBankLatch).offset -// io.out.target := readEntries(readBankLatch).target -// io.out.isRVILateJump := io.out.hit && io.out.hitIdx === OHToUInt(HighestBit(readMaskLatch, PredictWidth)) && !readEntries(readBankLatch).isRVC -// io.out.isRVC := readEntries(readBankLatch).isRVC -// -// // update jbtac -// val writeEntry = Wire(jbtacEntry()) -// // val updateHistXORAddr = updatefetchPC ^ Cat(r.hist, 0.U(2.W))(VAddrBits - 1, 0) -// val updateHistXORAddr = io.update.fetchPC ^ Cat(io.update.hist, 0.U(1.W))(VAddrBits - 1, 0) -// writeEntry.valid := true.B -// // writeEntry.tag := jbtacAddr.getTag(updatefetchPC) -// writeEntry.tag := Cat(jbtacAddr.getTag(io.update.fetchPC), jbtacAddr.getIdx(io.update.fetchPC)) -// writeEntry.target := io.update.target -// // writeEntry.offset := updateFetchIdx -// writeEntry.offset := io.update.fetchIdx -// writeEntry.isRVC := io.update.isRVC -// -// val writeBank = jbtacAddr.getBank(updateHistXORAddr) -// val writeRow = jbtacAddr.getBankIdx(updateHistXORAddr) -// val writeValid = io.redirectValid && io.update.misPred && io.update.btbType === BTBtype.I -// for (b <- 0 until JbtacBanks) { -// when (b.U === writeBank) { -// jbtac(b).io.w.req.valid := writeValid -// jbtac(b).io.w.req.bits.setIdx := writeRow -// jbtac(b).io.w.req.bits.data := writeEntry -// }.otherwise { -// jbtac(b).io.w.req.valid := false.B -// jbtac(b).io.w.req.bits.setIdx := DontCare -// jbtac(b).io.w.req.bits.data := DontCare -// } -// } -// -// // read-after-write bypass -// val rawBypassHit = Wire(Vec(JbtacBanks, Bool())) -// for (b <- 0 until JbtacBanks) { -// when (readBank === writeBank && readRow === writeRow && b.U === readBank) { -// when (io.in.pc.fire() && writeValid) { -// rawBypassHit(b) := true.B -// jbtac(b).io.r.req.valid := false.B -// // readEntries(b) := RegNext(writeEntry) -// readFire(b) := true.B -// -// XSDebug("raw bypass hits: bank=%d, row=%d, tag=%x, tgt=%x, offet=%d, isRVC=%d\n", -// b.U, readRow, writeEntry.tag, writeEntry.target, writeEntry.offset, writeEntry.isRVC) -// }.otherwise { -// rawBypassHit(b) := false.B -// } -// }.otherwise { -// rawBypassHit(b) := false.B -// } -// -// when (RegNext(rawBypassHit(b))) { readEntries(b) := RegNext(writeEntry) } -// } -// -// XSDebug(io.in.pc.fire(), "read: pc=0x%x, histXORAddr=0x%x, bank=%d, row=%d, hist=%b\n", -// io.in.pc.bits, histXORAddr, readBank, readRow, io.in.hist) -// XSDebug("out: hit=%d tgt=%x hitIdx=%d iRVILateJump=%d isRVC=%d\n", -// io.out.hit, io.out.target, io.out.hitIdx, io.out.isRVILateJump, io.out.isRVC) -// XSDebug(fireLatch, "read_resp: pc=0x%x, bank=%d, row=%d, target=0x%x, offset=%d, hit=%d\n", -// io.in.pcLatch, readBankLatch, readRowLatch, readEntries(readBankLatch).target, readEntries(readBankLatch).offset, outHit) -// XSDebug(io.redirectValid, "update_req: fetchPC=0x%x, writeValid=%d, hist=%b, bank=%d, row=%d, target=0x%x, offset=%d, type=0x%d\n", -// io.update.fetchPC, writeValid, io.update.hist, writeBank, writeRow, io.update.target, io.update.fetchIdx, io.update.btbType) -//} \ No newline at end of file +package xiangshan.frontend + +import chisel3._ +import chisel3.util._ +import xiangshan._ +import utils._ +import xiangshan.backend.ALUOpType + + +class JBTACUpdateBundle extends XSBundle { + val fetchPC = UInt(VAddrBits.W) + val fetchIdx = UInt(log2Up(PredictWidth).W) + val hist = UInt(HistoryLength.W) + val target = UInt(VAddrBits.W) + val btbType = UInt(2.W) + val misPred = Bool() + val isRVC = Bool() +} + +class JBTACPred extends XSBundle { + val hit = Bool() + val target = UInt(VAddrBits.W) + val hitIdx = UInt(log2Up(PredictWidth).W) + val isRVILateJump = Bool() + val isRVC = Bool() +} + +class JBTAC extends XSModule { + val io = IO(new Bundle { + val in = new Bundle { + val pc = Flipped(Decoupled(UInt(VAddrBits.W))) + val pcLatch = Input(UInt(VAddrBits.W)) + val mask = Input(UInt(PredictWidth.W)) + val hist = Input(UInt(HistoryLength.W)) + } + val redirectValid = Input(Bool()) + val flush = Input(Bool()) + val update = Input(new JBTACUpdateBundle) + + val out = Output(new JBTACPred) + }) + + io.in.pc.ready := true.B + + val fireLatch = RegNext(io.in.pc.fire()) + + // JBTAC, divided into 8 banks, makes prediction for indirect jump except ret. + val jbtacAddr = new TableAddr(log2Up(JbtacSize), JbtacBanks) + def jbtacEntry() = new Bundle { + val valid = Bool() + // TODO: don't need full length of tag and target + val tag = UInt(jbtacAddr.tagBits.W + jbtacAddr.idxBits.W) + val target = UInt(VAddrBits.W) + val offset = UInt(log2Up(PredictWidth).W) + val isRVC = Bool() + } + + val jbtac = List.fill(JbtacBanks)(Module(new SRAMTemplate(jbtacEntry(), set = JbtacSize / JbtacBanks, shouldReset = true, holdRead = true, singlePort = false))) + + val readEntries = Wire(Vec(JbtacBanks, jbtacEntry())) + + val readFire = Reg(Vec(JbtacBanks, Bool())) + // Only read one bank + val histXORAddr = io.in.pc.bits ^ Cat(io.in.hist, 0.U(1.W))(VAddrBits - 1, 0) + val histXORAddrLatch = RegEnable(histXORAddr, io.in.pc.valid) + + val readBank = jbtacAddr.getBank(histXORAddr) + val readRow = jbtacAddr.getBankIdx(histXORAddr) + readFire := 0.U.asTypeOf(Vec(JbtacBanks, Bool())) + (0 until JbtacBanks).map( + b => { + jbtac(b).reset := reset.asBool + jbtac(b).io.r.req.valid := io.in.pc.fire() && b.U === readBank + jbtac(b).io.r.req.bits.setIdx := readRow + readFire(b) := jbtac(b).io.r.req.fire() + readEntries(b) := jbtac(b).io.r.resp.data(0) + } + ) + + val readBankLatch = jbtacAddr.getBank(histXORAddrLatch) + val readRowLatch = jbtacAddr.getBankIdx(histXORAddrLatch) + val readMaskLatch = RegEnable(io.in.mask, io.in.pc.fire()) + + val outHit = readEntries(readBankLatch).valid && + readEntries(readBankLatch).tag === Cat(jbtacAddr.getTag(io.in.pcLatch), jbtacAddr.getIdx(io.in.pcLatch)) && + !io.flush && RegNext(readFire(readBankLatch)) && readMaskLatch(readEntries(readBankLatch).offset).asBool + + io.out.hit := outHit + io.out.hitIdx := readEntries(readBankLatch).offset + io.out.target := readEntries(readBankLatch).target + io.out.isRVILateJump := io.out.hit && io.out.hitIdx === OHToUInt(HighestBit(readMaskLatch, PredictWidth)) && !readEntries(readBankLatch).isRVC + io.out.isRVC := readEntries(readBankLatch).isRVC + + // update jbtac + val writeEntry = Wire(jbtacEntry()) + // val updateHistXORAddr = updatefetchPC ^ Cat(r.hist, 0.U(2.W))(VAddrBits - 1, 0) + val updateHistXORAddr = io.update.fetchPC ^ Cat(io.update.hist, 0.U(1.W))(VAddrBits - 1, 0) + writeEntry.valid := true.B + // writeEntry.tag := jbtacAddr.getTag(updatefetchPC) + writeEntry.tag := Cat(jbtacAddr.getTag(io.update.fetchPC), jbtacAddr.getIdx(io.update.fetchPC)) + writeEntry.target := io.update.target + // writeEntry.offset := updateFetchIdx + writeEntry.offset := io.update.fetchIdx + writeEntry.isRVC := io.update.isRVC + + val writeBank = jbtacAddr.getBank(updateHistXORAddr) + val writeRow = jbtacAddr.getBankIdx(updateHistXORAddr) + val writeValid = io.redirectValid && io.update.misPred && io.update.btbType === BTBtype.I + for (b <- 0 until JbtacBanks) { + when (b.U === writeBank) { + jbtac(b).io.w.req.valid := writeValid + jbtac(b).io.w.req.bits.setIdx := writeRow + jbtac(b).io.w.req.bits.data := writeEntry + }.otherwise { + jbtac(b).io.w.req.valid := false.B + jbtac(b).io.w.req.bits.setIdx := DontCare + jbtac(b).io.w.req.bits.data := DontCare + } + } + + // read-after-write bypass + val rawBypassHit = Wire(Vec(JbtacBanks, Bool())) + for (b <- 0 until JbtacBanks) { + when (readBank === writeBank && readRow === writeRow && b.U === readBank) { + when (io.in.pc.fire() && writeValid) { + rawBypassHit(b) := true.B + jbtac(b).io.r.req.valid := false.B + // readEntries(b) := RegNext(writeEntry) + readFire(b) := true.B + + XSDebug("raw bypass hits: bank=%d, row=%d, tag=%x, tgt=%x, offet=%d, isRVC=%d\n", + b.U, readRow, writeEntry.tag, writeEntry.target, writeEntry.offset, writeEntry.isRVC) + }.otherwise { + rawBypassHit(b) := false.B + } + }.otherwise { + rawBypassHit(b) := false.B + } + + when (RegNext(rawBypassHit(b))) { readEntries(b) := RegNext(writeEntry) } + } + + XSDebug(io.in.pc.fire(), "read: pc=0x%x, histXORAddr=0x%x, bank=%d, row=%d, hist=%b\n", + io.in.pc.bits, histXORAddr, readBank, readRow, io.in.hist) + XSDebug("out: hit=%d tgt=%x hitIdx=%d iRVILateJump=%d isRVC=%d\n", + io.out.hit, io.out.target, io.out.hitIdx, io.out.isRVILateJump, io.out.isRVC) + XSDebug(fireLatch, "read_resp: pc=0x%x, bank=%d, row=%d, target=0x%x, offset=%d, hit=%d\n", + io.in.pcLatch, readBankLatch, readRowLatch, readEntries(readBankLatch).target, readEntries(readBankLatch).offset, outHit) + XSDebug(io.redirectValid, "update_req: fetchPC=0x%x, writeValid=%d, hist=%b, bank=%d, row=%d, target=0x%x, offset=%d, type=0x%d\n", + io.update.fetchPC, writeValid, io.update.hist, writeBank, writeRow, io.update.target, io.update.fetchIdx, io.update.btbType) +} \ No newline at end of file -- GitLab