提交 f5c046cd 编写于 作者: Z zhanglinjuan

bpu: fix history shifting logic in Stage3

上级 028970c4
......@@ -3,7 +3,6 @@ package xiangshan
import chisel3._
import chisel3.util._
import bus.simplebus._
import xiangshan.frontend.HasTageParameter
import xiangshan.backend.brq.BrqPtr
import xiangshan.backend.rename.FreeListPtr
......
......@@ -7,7 +7,7 @@ import noop.{Cache, CacheConfig, HasExceptionNO, TLB, TLBConfig}
import xiangshan.backend._
import xiangshan.backend.dispatch.DP1Config
import xiangshan.backend.exu.ExuConfig
import xiangshan.frontend.Frontend
import xiangshan.frontend.{Frontend, HasTageParameter}
import xiangshan.utils._
trait HasXSParameter {
......@@ -88,6 +88,7 @@ trait NeedImpl { this: Module =>
abstract class XSBundle extends Bundle
with HasXSParameter
with HasTageParameter
case class XSConfig
(
......
......@@ -4,6 +4,7 @@ import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.utils._
import xiangshan.backend.ALUOpType
import utils._
class TableAddr(val idxBits: Int, val banks: Int) extends XSBundle {
......@@ -24,7 +25,7 @@ class Stage1To2IO extends XSBundle {
val pc = Output(UInt(VAddrBits.W))
val btb = new Bundle {
val hits = Output(UInt(FetchWidth.W))
val targets = Output(Vec(FetchWidth, UInt(VAddrBits.B)))
val targets = Output(Vec(FetchWidth, UInt(VAddrBits.W)))
}
val jbtac = new Bundle {
val hitIdx = Output(UInt(FetchWidth.W))
......@@ -52,6 +53,13 @@ class BPUStage1 extends XSModule {
val out = Decoupled(new Stage1To2IO)
})
// TODO: delete this!!!
io.in.pc.ready := true.B
io.btbOut.valid := false.B
io.btbOut.bits := DontCare
io.out.valid := false.B
io.out.bits := DontCare
// flush Stage1 when io.flush
val flushS1 = BoolStopWatch(io.flush, io.in.pc.fire(), startHighPriority = true)
......@@ -115,7 +123,7 @@ class BPUStage3 extends XSModule {
val validLatch = RegInit(false.B)
when (io.in.fire()) { inLatch := io.in.bits }
when (io.in.fire()) {
validLatch := !io.in.flush
validLatch := !io.flush
}.elsewhen (io.out.valid) {
validLatch := false.B
}
......@@ -128,7 +136,7 @@ class BPUStage3 extends XSModule {
val retAddr = UInt(VAddrBits.W)
val ctr = UInt(8.W) // layer of nested call functions
}
val ras = RegInit(VecInit(RasSize, 0.U.asTypeOf(rasEntry())))
val ras = RegInit(VecInit(Seq.fill(RasSize)(0.U.asTypeOf(rasEntry()))))
val sp = Counter(RasSize)
val rasTop = ras(sp.value)
val rasTopAddr = rasTop.retAddr
......@@ -136,12 +144,12 @@ class BPUStage3 extends XSModule {
// get the first taken branch/jal/call/jalr/ret in a fetch line
// brTakenIdx/jalIdx/callIdx/jalrIdx/retIdx/jmpIdx is one-hot encoded.
// brNotTakenIdx indicates all the not-taken branches before the first jump instruction.
val brIdx = inLatch.btb.hits & io.predecode.bits.fuTypes.map { t => ALUOpType.isBranch(t) }.asUInt & io.predecode.bits.mask
val brIdx = inLatch.btb.hits & Cat(io.predecode.bits.fuTypes.map { t => ALUOpType.isBranch(t) }).asUInt & io.predecode.bits.mask
val brTakenIdx = LowestBit(brIdx & inLatch.tage.takens.asUInt, FetchWidth)
val jalIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.fuTypes.map { t => t === ALUOpType.jal }.asUInt & io.predecode.bits.mask, FetchWidth)
val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & io.predecode.bits.fuTypes.map { t => t === ALUOpType.call }.asUInt, FetchWidth)
val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & io.predecode.bits.fuTypes.map { t => t === ALUOpType.jalr }.asUInt, FetchWidth)
val retIdx = LowestBit(io.predecode.bits.mask & io.predecode.bits.fuTypes.map { t => t === ALUOpType.ret }.asUInt, FetchWidth)
val jalIdx = LowestBit(inLatch.btb.hits & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.jal }).asUInt & io.predecode.bits.mask, FetchWidth)
val callIdx = LowestBit(inLatch.btb.hits & io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.call }).asUInt, FetchWidth)
val jalrIdx = LowestBit(inLatch.jbtac.hitIdx & io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.jalr }).asUInt, FetchWidth)
val retIdx = LowestBit(io.predecode.bits.mask & Cat(io.predecode.bits.fuTypes.map { t => t === ALUOpType.ret }).asUInt, FetchWidth)
val jmpIdx = LowestBit(brTakenIdx | jalIdx | callIdx | jalrIdx | retIdx, FetchWidth)
val brNotTakenIdx = brIdx & ~inLatch.tage.takens.asUInt & LowerMask(jmpIdx, FetchWidth)
......@@ -156,13 +164,21 @@ class BPUStage3 extends XSModule {
//io.out.bits._type := Mux(jmpIdx === retIdx, BTBtype.R,
// Mux(jmpIdx === jalrIdx, BTBtype.I,
// Mux(jmpIdx === brTakenIdx, BTBtype.B, BTBtype.J)))
val firstHist = inLatch.btbPred.bits.hist
val firstHist = inLatch.btbPred.bits.hist(0)
// there may be several notTaken branches before the first jump instruction,
// so we need to calculate how many zeroes should each instruction shift in its global history.
// each history is exclusive of instruction's own jump direction.
val histShift = WireInit(VecInit(FetchWidth, 0.U(log2Up(FetchWidth).W)))
histShift := (0 until FetchWidth).map(i => Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), FetchWidth))).reduce(_+_)
(0 until FetchWidth).map(i => io.out.bits.hist(i) := firstHist << histShift)
val histShift = Wire(Vec(FetchWidth, UInt(log2Up(FetchWidth).W)))
val shift = Wire(Vec(FetchWidth, Vec(FetchWidth, UInt(1.W))))
(0 until FetchWidth).map(i => shift(i) := Mux(!brNotTakenIdx(i), 0.U, ~LowerMask(UIntToOH(i.U), FetchWidth)).asTypeOf(Vec(FetchWidth, UInt(1.W))))
for (j <- 0 until FetchWidth) {
var tmp = 0.U
for (i <- 0 until FetchWidth) {
tmp = tmp + shift(i)(j)
}
histShift(j) := tmp
}
(0 until FetchWidth).map(i => io.out.bits.hist(i) := firstHist << histShift(i))
// save ras checkpoint info
io.out.bits.rasSp := sp.value
io.out.bits.rasTopCtr := rasTop.ctr
......@@ -238,7 +254,7 @@ class BPU extends XSModule {
s3.io.redirectInfo <> io.redirectInfo
// TODO: delete this and put BTB and JBTAC into Stage1
/*
val flush = BoolStopWatch(io.redirect.valid, io.in.pc.valid, startHighPriority = true)
// BTB makes a quick prediction for branch and direct jump, which is
......
package xiangshan.frontend
import chisel3._
import chisel3.util._
import device.RAMHelper
import xiangshan._
import utils.{Debug, GTimer, XSDebug}
trait HasIFUConst { this: XSModule =>
val resetVector = 0x80000000L//TODO: set reset vec
// 4-byte align * FetchWidth-inst
val groupAlign = log2Up(FetchWidth * 4)
def groupPC(pc: UInt): UInt = Cat(pc(VAddrBits-1, groupAlign), 0.U(groupAlign.W))
}
class FakeCache extends XSModule with HasIFUConst {
val io = IO(new Bundle {
val addr = Input(UInt(VAddrBits.W))
val rdata = Output(Vec(FetchWidth, UInt(32.W)))
})
val memByte = 128 * 1024 * 1024
val ramHelpers = Array.fill(FetchWidth/2)(Module(new RAMHelper(memByte)).io)
ramHelpers.foreach(_.clk := clock)
val gpc = groupPC(io.addr)
val offsetBits = log2Up(memByte)
val offsetMask = (1 << offsetBits) - 1
def index(addr: UInt): UInt = ((addr & offsetMask.U) >> log2Ceil(DataBytes)).asUInt()
def inRange(idx: UInt): Bool = idx < (memByte / 8).U
for(i <- ramHelpers.indices) {
val rIdx = index(gpc) + i.U
ramHelpers(i).rIdx := rIdx
io.rdata(2*i) := ramHelpers(i).rdata.tail(32)
io.rdata(2*i+1) := ramHelpers(i).rdata.head(32)
Seq(
ramHelpers(i).wmask,
ramHelpers(i).wdata,
ramHelpers(i).wen,
ramHelpers(i).wIdx
).foreach(_ := 0.U)
}
}
class FakeIFU extends XSModule with HasIFUConst {
val io = IO(new Bundle() {
val fetchPacket = DecoupledIO(new FetchPacket)
val redirect = Flipped(ValidIO(new Redirect))
})
val pc = RegInit(resetVector.U(VAddrBits.W))
val pcUpdate = io.redirect.valid || io.fetchPacket.fire()
val gpc = groupPC(pc) // fetch group's pc
val snpc = Cat(pc(VAddrBits-1, groupAlign) + 1.U, 0.U(groupAlign.W)) // sequential next pc
// val bpu = Module(new BPU)
// val predRedirect = bpu.io.predMask.asUInt.orR
// val predTarget = PriorityMux(bpu.io.predMask, bpu.io.predTargets)
val npc = Mux(io.redirect.valid, io.redirect.bits.target, snpc) // next pc
// val npc = Mux(io.redirect.valid, io.redirect.bits.target, Mux(predRedirect, predTarget, snpc))
// bpu.io.redirect := io.redirect
// bpu.io.in.pc.valid := io.fetchPacket.fire()
// bpu.io.in.pc.bits := npc
when(pcUpdate){
pc := npc
}
val fakeCache = Module(new FakeCache)
fakeCache.io.addr := pc
io.fetchPacket.valid := !io.redirect.valid && (GTimer() > 500.U)
io.fetchPacket.bits.mask := Fill(FetchWidth*2, 1.U(1.W)) << pc(2+log2Up(FetchWidth)-1, 1)
io.fetchPacket.bits.pc := pc
io.fetchPacket.bits.instrs := fakeCache.io.rdata
// io.fetchPacket.bits.pnpc := bpu.io.predTargets
io.fetchPacket.bits.pnpc := DontCare
XSDebug(p"pc=${Hexadecimal(pc)}\n")
}
......@@ -14,8 +14,8 @@ trait HasIFUConst { this: XSModule =>
}
sealed abstract IFUBundle extends XSBundle with HasIFUConst
sealed abstract IFUModule extends XSModule with HasIFUConst with NeedImpl
sealed abstract class IFUBundle extends XSBundle with HasIFUConst
sealed abstract class IFUModule extends XSModule with HasIFUConst with NeedImpl
class IFUIO extends IFUBundle
{
......@@ -64,7 +64,7 @@ class IFU(implicit val p: XSConfig) extends IFUModule
when(if1_pcUpdate)
{
if1_pc = if1_npc
if1_pc := if1_npc
}
bpu.io.in.valid := if1_valid
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册