Commit a63ad672, authored by: YikeZhou

Merge branch 'master' into decode-alt

[submodule "rocket-chip"]
path = rocket-chip
url = https://github.com/chipsalliance/rocket-chip.git
branch = d6bd3c61993637c3f10544c59e861fae8af29f39
url = https://github.com/RISCVERS/rocket-chip.git
branch = 147bdcc4a26c74e5d7a47e3d667d456699d6d11f
[submodule "block-inclusivecache-sifive"]
path = block-inclusivecache-sifive
url = https://github.com/RISCVERS/block-inclusivecache-sifive.git
......
......@@ -27,7 +27,7 @@ help:
$(TOP_V): $(SCALA_FILE)
mkdir -p $(@D)
mill XiangShan.test.runMain $(SIMTOP) -X verilog -td $(@D) --full-stacktrace --output-file $(@F) --disable-all --fpga-platform $(SIM_ARGS)
mill XiangShan.test.runMain $(SIMTOP) -X verilog -td $(@D) --full-stacktrace --output-file $(@F) --disable-all --fpga-platform --remove-assert $(SIM_ARGS)
# mill XiangShan.runMain top.$(TOP) -X verilog -td $(@D) --output-file $(@F) --infer-rw $(FPGATOP) --repl-seq-mem -c:$(FPGATOP):-o:$(@D)/$(@F).conf
# $(MEM_GEN) $(@D)/$(@F).conf >> $@
# sed -i -e 's/_\(aw\|ar\|w\|r\|b\)_\(\|bits_\)/_\1/g' $@
......@@ -139,6 +139,8 @@ endif
SEED ?= $(shell shuf -i 1-10000 -n 1)
VME_SOURCE ?= $(shell pwd)
VME_MODULE ?=
# log will only be printed when (B<=GTimer<=E) && (L < loglevel)
# use 'emu -h' to see more details
......@@ -165,6 +167,23 @@ emu: $(EMU)
ls build
$(EMU) -i $(IMAGE) $(EMU_FLAGS)
# extract verilog module from sim_top.v
# usage: make vme VME_MODULE=Roq
vme: $(SIM_TOP_V)
mill XiangShan.runMain utils.ExtractVerilogModules -m $(VME_MODULE)
# usage: make phy_evaluate VME_MODULE=Roq REMOTE=100
phy_evaluate: vme
scp -r ./build/extracted/* $(REMOTE):~/phy_evaluation/remote_run/rtl
ssh -tt $(REMOTE) 'cd ~/phy_evaluation/remote_run && $(MAKE) evaluate DESIGN_NAME=$(VME_MODULE)'
scp -r $(REMOTE):~/phy_evaluation/remote_run/rpts ./build
# usage: make phy_evaluate_atc VME_MODULE=Roq REMOTE=100
phy_evaluate_atc: vme
scp -r ./build/extracted/* $(REMOTE):~/phy_evaluation/remote_run/rtl
ssh -tt $(REMOTE) 'cd ~/phy_evaluation/remote_run && $(MAKE) evaluate_atc DESIGN_NAME=$(VME_MODULE)'
scp -r $(REMOTE):~/phy_evaluation/remote_run/rpts ./build
cache:
$(MAKE) emu IMAGE=Makefile
......
......@@ -54,7 +54,7 @@ class AXI4RAM
val mems = (0 until split).map {_ => Module(new RAMHelper(bankByte))}
mems.zipWithIndex map { case (mem, i) =>
mem.io.clk := clock
mem.io.en := !reset.asBool() && (state === s_rdata)
mem.io.en := !reset.asBool() && ((state === s_rdata) || (state === s_wdata))
mem.io.rIdx := (rIdx << log2Up(split)) + i.U
mem.io.wIdx := (wIdx << log2Up(split)) + i.U
mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
......
......@@ -111,4 +111,8 @@ object GenMask {
def apply(pos: Int) = {
(1.U << pos).asUInt()
}
}
\ No newline at end of file
}
// Thermometer mask from an index: bits [ptr-1:0] set, the rest clear,
// computed as UIntToOH(ptr) - 1 (e.g. ptr = 3 -> ...000111; ptr = 0 -> 0).
object UIntToMask {
def apply(ptr: UInt) = UIntToOH(ptr) - 1.U
}
package utils
/*
https://github.com/Lingrui98/scalaTage/blob/vme/src/main/scala/getVerilogModules.scala
*/
import scala.io.Source
import java.io._
import scala.language.postfixOps
import sys.process._
import sys._
// Parses a single-file Verilog dump (as emitted by Chisel) into per-module
// records, then writes each requested top module -- and, recursively, every
// submodule type it instantiates -- into its own .v file under a dated
// output directory.
class VerilogModuleExtractor {
// matches a module header line; captures the module name
val modulePattern = "module ([\\w]+)\\(".r.unanchored
// matches a submodule instantiation line; captures (type, instance name)
val subMoudlePattern = "([\\w]+) ([\\w]+) \\((?: //.*)*\\Z".r.unanchored
// end of a module's port list; NOTE(review): declared but never used below
val endMoudleIOPattern = "\\);".r.unanchored
// matches the "endmodule" terminator of a module body
val endMoudlePattern = "endmodule".r.unanchored
// (submodule type, submodule instance name)
type SubMoudleRecord = Tuple2[String, String]
// (module source lines in file order, submodules instantiated by it)
type ModuleRecord = Tuple2[List[String], List[SubMoudleRecord]]
// module name -> its record
type ModuleMap = Map[String, ModuleRecord]
// thin wrapper over BufferedSource.getLines (kept for API symmetry)
def getLines(s: scala.io.BufferedSource): Iterator[String] = s.getLines()
// Builds a ModuleMap from the raw line stream: one entry per "module" header
// encountered. The iterator is shared between traverse and processModule, so
// processModule consumes lines up to and including "endmodule".
def makeRecord(s: Iterator[String]): ModuleMap = {
val m: ModuleMap = Map()
// called before we see the first line of a module
def processModule(firstLine: String, it: Iterator[String]): ModuleRecord = {
val content: List[String] = List(firstLine)
val submodules: List[SubMoudleRecord] = List()
// accumulates lines (reversed) and submodule records until "endmodule";
// NOTE(review): calls it.next() without hasNext -- a file truncated
// mid-module would raise NoSuchElementException
def iter(cont: List[String], subm: List[SubMoudleRecord]): ModuleRecord =
it.next() match {
case l: String => l match {
case endMoudlePattern() => (l :: cont, subm)
case subMoudlePattern(ty, name) =>
// println(s"submoudle $ty $name")
iter(l :: cont, (ty, name) :: subm)
case _ => iter(l :: cont, subm)
}
case _ => println("Should not reach here"); (cont, subm)
}
val temp = iter(content, submodules)
// lines were prepended, so reverse to restore file order
(temp._1.reverse, temp._2)
}
// scans for module headers; everything between modules is reported and skipped
def traverse(m: ModuleMap, it: Iterator[String]): ModuleMap =
if (it.hasNext) {
it.next() match {
case l: String =>
// println(f"traversing $l")
l match {
case modulePattern(name) =>
// println(f"get Module of name $name")
traverse(m ++ Map(name -> processModule(l, it)), it)
case _ =>
println(f"line $l is not a module definition")
traverse(m, it)
}
case _ => traverse(m, it)
}
}
else m
traverse(m, s)
}
// Convenience: build the ModuleMap straight from a file path.
// NOTE(review): the BufferedSource is never closed
def makeRecordFromFile(file: String): ModuleMap = {
val bufSrc = Source.fromFile(file)
makeRecord(bufSrc.getLines())
}
// Writes one module's source lines to <dir><name>.v (dir must end with '/')
def writeModuleToFile(name: String, record: ModuleRecord, dir: String) = {
val path = dir+name+".v"
val writer = new PrintWriter(new File(path))
println(f"Writing module $name%20s to $path")
record._1.foreach(r => {
writer.write(f"$r\n")
})
writer.close()
}
// get moudle definition of specified name
// NOTE(review): throws NoSuchElementException when absent (Map.apply)
def getModule(name: String, m: ModuleMap): ModuleRecord = {
m(name)
}
// Debug helper: dump a record's submodules and contents to stdout
def showModuleRecord(r: ModuleRecord) = {
val (content, submodules) = r
submodules.foreach {
case (t, n) => println(f"submoudle type: $t, submodule name: $n")
}
println("\nprinting module contents...")
content.foreach(println(_))
}
// We first get records of all the modules and its submodule record
// Then we choose a module as the root node to traverse its submodule
// `top` carries (top module name, is-this-call-the-top) for error reporting;
// `doneSet` holds submodule types already emitted, to avoid duplicates.
def processFromModule(name: String, map: ModuleMap, outPath: String, doneSet: Set[String] = Set(), top: Tuple2[String, Boolean]): Unit = {
// debug aid: report SRAM-typed submodules of the current module
def printSRAMs(sub: List[SubMoudleRecord]) = {
sub map {
case (ty, subn) if (ty contains "SRAM") => println(s"top module $name, sub module type $ty, name $subn")
case _ =>
}
}
val (topName, isTop) = top
if (!map.contains(name)) {
println(s"${if (isTop) "chosen top" else s"submodule of ${topName},"} module $name does not exist!")
return
}
if (isTop) println(s"\nProcessing top module $name")
val r = map(name)
new File(outPath).mkdirs() // ensure the path exists
writeModuleToFile(name, r, outPath)
val submodules = r._2
// printSRAMs(submodules)
// DFS
val subTypesSet = submodules map (m => m._1) toSet
// remove the current module so a self-instantiating module cannot recurse forever
val nowMap = map - name
val nowSet = doneSet ++ subTypesSet
subTypesSet.foreach { s => if (!doneSet.contains(s)) processFromModule(s, nowMap, outPath, nowSet, (if (isTop) name else topName, false)) }
}
// Today's date as yyyyMMdd (LocalDate string with '-' removed)
def getDate: String = {
val d = java.time.LocalDate.now
d.toString.toCharArray.filterNot(_ == '-').mkString
}
// Output directory "<outDir>/<date>-<user>-<topModule>/";
// NOTE(review): outDir.last throws on an empty outDir string
def makePath(topModule: String, outDir: String , user: String = "glr"): String = {
(if (outDir.last == '/')
outDir
else
outDir+"/") + getDate + "-" + user + "-" + topModule + "/"
}
// Extract one top module; `mapp` lets callers reuse a pre-built ModuleMap
// instead of re-parsing `src`.
def extract(src: String, topModule: String, outDir: String, user: String, mapp: Option[ModuleMap]): Unit = {
val useMap = mapp.getOrElse(makeRecordFromFile(src))
val path = makePath(topModule, outDir, user)
processFromModule(topModule, useMap, path, top=(topModule, true))
}
// Extract several top modules, parsing the source file only once.
def extract(src: String, topModules: List[String], outDir: String, user: String): Unit = {
// avoid repeat
val mapp = makeRecordFromFile(src)
topModules.foreach(n => extract(src, n, outDir, user, Some(mapp)))
}
}
/** Command-line parsing for the Verilog module extractor (VME).
 *
 *  Recognized options: -s/--source, -o/--output, -u/--usr, -h/--help and
 *  -m/--modules (variadic, must come last). `wrapParams` resolves the parsed
 *  options to concrete values, falling back to $NOOP_HOME-based defaults.
 */
trait VMEArgParser {
  // option key -> parsed value; None means "not supplied, use the default"
  type OptionMap = Map[String, Option[Any]]
  val usage = """
Usage: sbt "run [OPTION...]"
-s, --source the verilog file generated by chisel, all in one file
default: $NOOP_HOME/build/XSSimTop.v
-h, --help print this help info
-o, --output the place you want to store your extracted verilog
default: $NOOP_HOME/build/extracted
-u, --usr your name, will be used to name the output folder
default: current user
-m, --modules the top modules you would like to extract verilog from
should always be the last argument
default: IFU
"""
  /** Folds the argument list into an OptionMap.
   *  Prints usage and exits on -h/--help; warns on unknown switches. */
  def parse(args: List[String]) = {
    def nextOption(map: OptionMap, l: List[String]): OptionMap = {
      // startsWith is safe on the empty string, unlike s(0), which threw
      // StringIndexOutOfBoundsException when an empty argument was passed
      def isSwitch(s: String) = s.startsWith("-")
      l match {
        case Nil => map
        case ("--help" | "-h") :: tail => {
          println(usage)
          sys.exit()
          map
        }
        case ("--source" | "-s") :: file :: tail =>
          nextOption(map ++ Map("source" -> Some(file)), tail)
        case ("--output" | "-o") :: path :: tail =>
          nextOption(map ++ Map("output" -> Some(path)), tail)
        case ("--usr" | "-u") :: name :: tail =>
          nextOption(map ++ Map("usr" -> Some(name)), tail)
        // this should always be the last argument, since it is length variable
        case ("--modules" | "-m") :: m :: tail =>
          map ++ Map("modules" -> Some(m :: tail))
        case s :: tail => {
          if (isSwitch(s)) println(s"unexpected argument $s")
          nextOption(map, tail)
        }
      }
    }
    nextOption(Map("source" -> None, "output" -> None, "usr" -> None, "modules" -> None), args)
  }
  /** Resolves options to (source file, top modules, output dir, user name).
   *  Defaults: $NOOP_HOME paths, List("IFU"), and `whoami` (newline stripped).
   *  The env/`whoami` lookups only run when the option is absent (by-name
   *  getOrElse), so supplying all options needs neither. */
  def wrapParams(args: Array[String]): (String, List[String], String, String) = {
    val argL = args.toList
    val paramMap = parse(argL)
    (paramMap("source").map(_.asInstanceOf[String]).getOrElse(env("NOOP_HOME")+"/build/XSSimTop.v"),
      paramMap("modules").map(_.asInstanceOf[List[String]]).getOrElse(List("IFU")),
      paramMap("output").map(_.asInstanceOf[String]).getOrElse(env("NOOP_HOME")+"/build/extracted/"),
      paramMap("usr").map(_.asInstanceOf[String]).getOrElse("whoami".!!.init))
  }
}
/** CLI entry point: parse the arguments and run the module extractor. */
object ExtractVerilogModules extends VMEArgParser {
  def main(args: Array[String]): Unit = {
    val (srcFile, tops, outDir, userName) = wrapParams(args)
    new VerilogModuleExtractor().extract(srcFile, tops, outDir, userName)
  }
}
......@@ -12,6 +12,8 @@ import xiangshan.mem.{LqPtr, SqPtr}
import xiangshan.frontend.PreDecodeInfo
import xiangshan.frontend.HasBPUParameter
import xiangshan.frontend.HasTageParameter
import xiangshan.frontend.HasIFUConst
import utils._
import scala.math.max
// Fetch FetchWidth x 32-bit insts from Icache
......@@ -62,14 +64,53 @@ class TageMeta extends XSBundle with HasTageParameter {
val scMeta = new SCMeta(EnableSC)
}
class BranchPrediction extends XSBundle {
val redirect = Bool()
val taken = Bool()
val jmpIdx = UInt(log2Up(PredictWidth).W)
val hasNotTakenBrs = Bool()
val target = UInt(VAddrBits.W)
val saveHalfRVI = Bool()
val takenOnBr = Bool()
// Per-fetch-packet branch prediction result: PredictWidth-wide bit vectors
// (one bit per instruction slot) plus a predicted target per slot. The
// "half RVI" fields handle a 32-bit instruction whose second half spills
// past a bank boundary.
class BranchPrediction extends XSBundle with HasIFUConst {
// val redirect = Bool()
// taken bit for each slot of the packet
val takens = UInt(PredictWidth.W)
// val jmpIdx = UInt(log2Up(PredictWidth).W)
// which slots hold conditional branches / jal-type instructions
val brMask = UInt(PredictWidth.W)
val jalMask = UInt(PredictWidth.W)
// predicted target address per slot
val targets = Vec(PredictWidth, UInt(VAddrBits.W))
// marks the last 2 bytes of this fetch packet
// val endsAtTheEndOfFirstBank = Bool()
// val endsAtTheEndOfLastBank = Bool()
// half RVI could only start at the end of a bank
val firstBankHasHalfRVI = Bool()
val lastBankHasHalfRVI = Bool()
// one-hot mask of the slot holding the trailing half RVI (zero when none);
// the first-bank case takes priority over the last-bank case
def lastHalfRVIMask = Mux(firstBankHasHalfRVI, UIntToOH((bankWidth-1).U),
Mux(lastBankHasHalfRVI, UIntToOH((PredictWidth-1).U),
0.U(PredictWidth.W)
)
)
def lastHalfRVIClearMask = ~lastHalfRVIMask
// is taken from half RVI
def lastHalfRVITaken = (takens & lastHalfRVIMask).orR
// slot index of the trailing half RVI; NOTE(review): meaningful only when
// one of the *HasHalfRVI flags is set -- confirm at call sites
def lastHalfRVIIdx = Mux(firstBankHasHalfRVI, (bankWidth-1).U, (PredictWidth-1).U)
// should not be used if not lastHalfRVITaken
def lastHalfRVITarget = Mux(firstBankHasHalfRVI, targets(bankWidth-1), targets(PredictWidth-1))
// prediction vectors with the trailing-half-RVI slot masked out
def realTakens = takens & lastHalfRVIClearMask
def realBrMask = brMask & lastHalfRVIClearMask
def realJalMask = jalMask & lastHalfRVIClearMask
// conditional branches predicted not-taken (after half-RVI masking)
def brNotTakens = ~realTakens & realBrMask
// per-slot flag: some strictly-earlier slot holds a not-taken branch
def sawNotTakenBr = VecInit((0 until PredictWidth).map(i =>
(if (i == 0) false.B else brNotTakens(i-1,0).orR)))
// any not-taken branch within LowerMaskFromLowest(realTakens) -- presumably
// the slots up to the first taken one; verify against the helper in utils
def hasNotTakenBrs = (brNotTakens & LowerMaskFromLowest(realTakens)).orR
// index of the first taken slot, before half-RVI masking
def unmaskedJmpIdx = PriorityEncoder(takens)
// whether the trailing half RVI must be carried over to the next packet
def saveHalfRVI = (firstBankHasHalfRVI && (unmaskedJmpIdx === (bankWidth-1).U || !(takens.orR))) ||
(lastBankHasHalfRVI && unmaskedJmpIdx === (PredictWidth-1).U)
// could get PredictWidth-1 when only the first bank is valid
def jmpIdx = PriorityEncoder(realTakens)
// only used when taken
def target = targets(jmpIdx)
def taken = realTakens.orR
def takenOnBr = taken && realBrMask(jmpIdx)
}
class BranchInfo extends XSBundle with HasBPUParameter {
......@@ -103,9 +144,10 @@ class BranchInfo extends XSBundle with HasBPUParameter {
def fromUInt(x: UInt) = x.asTypeOf(this)
}
class Predecode extends XSBundle {
val isFetchpcEqualFirstpc = Bool()
// Pre-decode information carried alongside a fetch packet.
class Predecode extends XSBundle with HasIFUConst {
// a 32-bit (RVI) instruction's first half ends this packet -- its second
// half belongs to the next one
val hasLastHalfRVI = Bool()
// one bit per 16-bit slot of the packet (FetchWidth*2 slots); presumably a
// valid-instruction mask -- verify against the IFU
val mask = UInt((FetchWidth*2).W)
// one bit per bank; NOTE(review): encoding assumed to be per-bank
// last-half flags -- confirm against producers
val lastHalf = UInt(nBanksInPacket.W)
// per-slot pre-decode info (instruction class, etc.)
val pd = Vec(FetchWidth*2, (new PreDecodeInfo))
}
......@@ -270,7 +312,7 @@ class FrontendToBackendIO extends XSBundle {
// to backend end
val cfVec = Vec(DecodeWidth, DecoupledIO(new CtrlFlow))
// from backend
val redirect = Flipped(ValidIO(new Redirect))
val redirect = Flipped(ValidIO(UInt(VAddrBits.W)))
val outOfOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
val inOrderBrInfo = Flipped(ValidIO(new BranchUpdateInfo))
}
......
......@@ -36,8 +36,8 @@ case class XSCoreParameters
EnableBPU: Boolean = true,
EnableBPD: Boolean = true,
EnableRAS: Boolean = true,
EnableLB: Boolean = true,
EnableLoop: Boolean = true,
EnableLB: Boolean = false,
EnableLoop: Boolean = false,
EnableSC: Boolean = false,
HistoryLength: Int = 64,
BtbSize: Int = 2048,
......@@ -68,10 +68,7 @@ case class XSCoreParameters
LsDqSize = 96,
IntDqDeqWidth = 4,
FpDqDeqWidth = 4,
LsDqDeqWidth = 4,
IntDqReplayWidth = 4,
FpDqReplayWidth = 4,
LsDqReplayWidth = 4
LsDqDeqWidth = 4
),
exuParameters: ExuParameters = ExuParameters(
JmpCnt = 1,
......@@ -148,7 +145,6 @@ trait HasXSParameter {
val LoadQueueSize = core.LoadQueueSize
val StoreQueueSize = core.StoreQueueSize
val dpParams = core.dpParams
val ReplayWidth = dpParams.IntDqReplayWidth + dpParams.FpDqReplayWidth + dpParams.LsDqReplayWidth
val exuParameters = core.exuParameters
val NRIntReadPorts = core.NRIntReadPorts
val NRIntWritePorts = core.NRIntWritePorts
......@@ -388,7 +384,6 @@ class XSCoreImp(outer: XSCore) extends LazyModuleImp(outer)
memBlock.io.lsqio.commits <> ctrlBlock.io.roqio.commits
memBlock.io.lsqio.roqDeqPtr <> ctrlBlock.io.roqio.roqDeqPtr
memBlock.io.lsqio.oldestStore <> ctrlBlock.io.oldestStore
memBlock.io.lsqio.exceptionAddr.lsIdx.lqIdx := ctrlBlock.io.roqio.exception.bits.lqIdx
memBlock.io.lsqio.exceptionAddr.lsIdx.sqIdx := ctrlBlock.io.roqio.exception.bits.sqIdx
memBlock.io.lsqio.exceptionAddr.isStore := CommitType.lsInstIsStore(ctrlBlock.io.roqio.exception.bits.ctrl.commitType)
......
......@@ -2,6 +2,7 @@ package xiangshan.backend
import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import xiangshan.backend.decode.{DecodeBuffer, DecodeStage}
import xiangshan.backend.rename.{Rename, BusyTable}
......@@ -37,7 +38,7 @@ class CtrlToLsBlockIO extends XSBundle {
val redirect = ValidIO(new Redirect)
}
class CtrlBlock extends XSModule {
class CtrlBlock extends XSModule with HasCircularQueuePtrHelper {
val io = IO(new Bundle {
val frontend = Flipped(new FrontendToBackendIO)
val fromIntBlock = Flipped(new IntBlockToCtrlIO)
......@@ -55,7 +56,6 @@ class CtrlBlock extends XSModule {
val commits = Vec(CommitWidth, ValidIO(new RoqCommit))
val roqDeqPtr = Output(new RoqPtr)
}
val oldestStore = Input(Valid(new RoqPtr))
})
val decode = Module(new DecodeStage)
......@@ -70,18 +70,17 @@ class CtrlBlock extends XSModule {
val roq = Module(new Roq(roqWbSize))
val redirect = Mux(
roq.io.redirect.valid,
roq.io.redirect,
Mux(
brq.io.redirect.valid,
brq.io.redirect,
io.fromLsBlock.replay
)
)
io.frontend.redirect := redirect
io.frontend.redirect.valid := redirect.valid && !redirect.bits.isReplay
// When replay and mis-prediction have the same roqIdx,
// mis-prediction should have higher priority, since mis-prediction flushes the load instruction.
// Thus, only when mis-prediction roqIdx is after replay roqIdx, replay should be valid.
val brqIsAfterLsq = isAfter(brq.io.redirect.bits.roqIdx, io.fromLsBlock.replay.bits.roqIdx)
val redirectArb = Mux(io.fromLsBlock.replay.valid && (!brq.io.redirect.valid || brqIsAfterLsq),
io.fromLsBlock.replay.bits, brq.io.redirect.bits)
val redirectValid = roq.io.redirect.valid || brq.io.redirect.valid || io.fromLsBlock.replay.valid
val redirect = Mux(roq.io.redirect.valid, roq.io.redirect.bits, redirectArb)
io.frontend.redirect.valid := redirectValid
io.frontend.redirect.bits := Mux(roq.io.redirect.valid, roq.io.redirect.bits.target, redirectArb.target)
io.frontend.outOfOrderBrInfo <> brq.io.outOfOrderBrInfo
io.frontend.inOrderBrInfo <> brq.io.inOrderBrInfo
......@@ -91,28 +90,27 @@ class CtrlBlock extends XSModule {
decode.io.out <> decBuf.io.in
brq.io.roqRedirect <> roq.io.redirect
brq.io.memRedirect <> io.fromLsBlock.replay
brq.io.memRedirect.valid := brq.io.redirect.valid || io.fromLsBlock.replay.valid
brq.io.memRedirect.bits <> redirectArb
brq.io.bcommit <> roq.io.bcommit
brq.io.enqReqs <> decode.io.toBrq
brq.io.exuRedirect <> io.fromIntBlock.exuRedirect
decBuf.io.isWalking := roq.io.commits(0).valid && roq.io.commits(0).bits.isWalk
decBuf.io.redirect <> redirect
decBuf.io.redirect.valid <> redirectValid
decBuf.io.redirect.bits <> redirect
decBuf.io.out <> rename.io.in
rename.io.redirect <> redirect
rename.io.redirect.valid <> redirectValid
rename.io.redirect.bits <> redirect
rename.io.roqCommits <> roq.io.commits
rename.io.out <> dispatch.io.fromRename
rename.io.renameBypass <> dispatch.io.renameBypass
dispatch.io.redirect <> redirect
dispatch.io.redirect.valid <> redirectValid
dispatch.io.redirect.bits <> redirect
dispatch.io.enqRoq <> roq.io.enq
dispatch.io.enqLsq <> io.toLsBlock.enqLsq
dispatch.io.dequeueRoqIndex.valid := roq.io.commitRoqIndex.valid || io.oldestStore.valid
dispatch.io.dequeueRoqIndex.bits := Mux(io.oldestStore.valid,
io.oldestStore.bits,
roq.io.commitRoqIndex.bits
)
dispatch.io.readIntRf <> io.toIntBlock.readRf
dispatch.io.readFpRf <> io.toFpBlock.readRf
dispatch.io.allocPregs.zipWithIndex.foreach { case (preg, i) =>
......@@ -126,7 +124,7 @@ class CtrlBlock extends XSModule {
dispatch.io.enqIQData <> io.toIntBlock.enqIqData ++ io.toFpBlock.enqIqData ++ io.toLsBlock.enqIqData
val flush = redirect.valid && (redirect.bits.isException || redirect.bits.isFlushPipe)
val flush = redirectValid && (redirect.isException || redirect.isFlushPipe)
fpBusyTable.io.flush := flush
intBusyTable.io.flush := flush
for((wb, setPhyRegRdy) <- io.fromIntBlock.wbRegs.zip(intBusyTable.io.wbPregs)){
......@@ -141,15 +139,11 @@ class CtrlBlock extends XSModule {
intBusyTable.io.pregRdy <> dispatch.io.intPregRdy
fpBusyTable.io.rfReadAddr <> dispatch.io.readFpRf.map(_.addr)
fpBusyTable.io.pregRdy <> dispatch.io.fpPregRdy
for(i <- 0 until ReplayWidth){
intBusyTable.io.replayPregs(i).valid := dispatch.io.replayPregReq(i).isInt
fpBusyTable.io.replayPregs(i).valid := dispatch.io.replayPregReq(i).isFp
intBusyTable.io.replayPregs(i).bits := dispatch.io.replayPregReq(i).preg
fpBusyTable.io.replayPregs(i).bits := dispatch.io.replayPregReq(i).preg
}
roq.io.memRedirect <> io.fromLsBlock.replay
roq.io.brqRedirect <> brq.io.redirect
roq.io.memRedirect := DontCare
roq.io.memRedirect.valid := false.B
roq.io.brqRedirect.valid := brq.io.redirect.valid || io.fromLsBlock.replay.valid
roq.io.brqRedirect.bits <> redirectArb
roq.io.exeWbResults.take(roqWbSize-1).zip(
io.fromIntBlock.wbRegs ++ io.fromFpBlock.wbRegs ++ io.fromLsBlock.stOut
).foreach{
......@@ -159,9 +153,12 @@ class CtrlBlock extends XSModule {
}
roq.io.exeWbResults.last := brq.io.out
io.toIntBlock.redirect := redirect
io.toFpBlock.redirect := redirect
io.toLsBlock.redirect := redirect
io.toIntBlock.redirect.valid := redirectValid
io.toIntBlock.redirect.bits := redirect
io.toFpBlock.redirect.valid := redirectValid
io.toFpBlock.redirect.bits := redirect
io.toLsBlock.redirect.valid := redirectValid
io.toLsBlock.redirect.bits := redirect
// roq to int block
io.roqio.toCSR <> roq.io.csr
......
......@@ -55,7 +55,6 @@ class MemBlock
val exceptionAddr = new ExceptionAddrIO // to csr
val commits = Flipped(Vec(CommitWidth, Valid(new RoqCommit))) // to lsq
val roqDeqPtr = Input(new RoqPtr) // to lsq
val oldestStore = Output(Valid(new RoqPtr)) // to dispatch
}
})
......@@ -209,7 +208,6 @@ class MemBlock
// Lsq
lsq.io.commits <> io.lsqio.commits
lsq.io.enq <> io.fromCtrlBlock.enqLsq
lsq.io.oldestStore <> io.lsqio.oldestStore
lsq.io.brqRedirect := io.fromCtrlBlock.redirect
lsq.io.roqDeqPtr := io.lsqio.roqDeqPtr
io.toCtrlBlock.replay <> lsq.io.rollback
......
......@@ -99,12 +99,8 @@ class Brq extends XSModule with HasCircularQueuePtrHelper {
commitIdx = 6
*/
val headIdxOH = UIntToOH(headIdx)
val headIdxMaskHiVec = Wire(Vec(BrqSize, Bool()))
for(i <- headIdxMaskHiVec.indices){
headIdxMaskHiVec(i) := { if(i==0) headIdxOH(i) else headIdxMaskHiVec(i-1) || headIdxOH(i) }
}
val headIdxMaskHi = headIdxMaskHiVec.asUInt()
val headIdxMaskLo = (~headIdxMaskHi).asUInt()
val headIdxMaskLo = headIdxOH - 1.U
val headIdxMaskHi = ~headIdxMaskLo
val commitIdxHi = PriorityEncoder((~skipMask).asUInt() & headIdxMaskHi)
val (commitIdxLo, findLo) = PriorityEncoderWithFlag((~skipMask).asUInt() & headIdxMaskLo)
......@@ -163,9 +159,9 @@ class Brq extends XSModule with HasCircularQueuePtrHelper {
headPtr := headPtrNext
io.redirect.valid := commitValid &&
commitIsMisPred &&
!io.roqRedirect.valid &&
!io.redirect.bits.roqIdx.needFlush(io.memRedirect)
commitIsMisPred //&&
// !io.roqRedirect.valid &&
// !io.redirect.bits.roqIdx.needFlush(io.memRedirect)
io.redirect.bits := commitEntry.exuOut.redirect
io.out.valid := commitValid
......@@ -182,11 +178,12 @@ class Brq extends XSModule with HasCircularQueuePtrHelper {
)
// branch insts enq
val validEntries = distanceBetween(tailPtr, headPtr)
for(i <- 0 until DecodeWidth){
val offset = if(i == 0) 0.U else PopCount(io.enqReqs.take(i).map(_.valid))
val brTag = tailPtr + offset
val idx = brTag.value
io.enqReqs(i).ready := stateQueue(idx).isInvalid
io.enqReqs(i).ready := validEntries <= (BrqSize - (i + 1)).U
io.brTags(i) := brTag
when(io.enqReqs(i).fire()){
brQueue(idx).npc := io.enqReqs(i).bits.cf.brUpdate.pnpc
......@@ -220,20 +217,20 @@ class Brq extends XSModule with HasCircularQueuePtrHelper {
headPtr := BrqPtr(false.B, 0.U)
tailPtr := BrqPtr(false.B, 0.U)
brCommitCnt := 0.U
}.elsewhen(io.redirect.valid || io.memRedirect.valid){
}.elsewhen(io.memRedirect.valid){
// misprediction or replay
stateQueue.zipWithIndex.foreach({case(s, i) =>
// replay should flush brTag
val ptr = BrqPtr(brQueue(i).ptrFlag, i.U)
when(s.isWb && brQueue(i).exuOut.uop.roqIdx.needFlush(io.memRedirect)){
s := s_idle
}
when(io.redirect.valid && ptr.needBrFlush(io.redirect.bits.brTag)){
val replayMatch = io.memRedirect.bits.isReplay && ptr === io.memRedirect.bits.brTag
when(io.memRedirect.valid && (ptr.needBrFlush(io.memRedirect.bits.brTag) || replayMatch)){
s := s_invalid
}
})
when(io.redirect.valid){ // Only Br Mispred reset tailPtr, replay does not
tailPtr := io.redirect.bits.brTag + true.B
when(io.memRedirect.valid){
tailPtr := io.memRedirect.bits.brTag + Mux(io.memRedirect.bits.isReplay, 0.U, 1.U)
}
}
......
......@@ -24,7 +24,7 @@ class DecodeBuffer extends XSModule {
})
)
val flush = io.redirect.valid && !io.redirect.bits.isReplay
val flush = io.redirect.valid// && !io.redirect.bits.isReplay
for( i <- 0 until RenameWidth){
when(io.out(i).fire()){
......
......@@ -41,14 +41,12 @@ class DecodeStage extends XSModule {
decoderToDecBuffer(i).brTag := io.brTags(i)
io.out(i).bits := decoderToDecBuffer(i)
val thisReady = io.out(i).ready && io.toBrq(i).ready
val isMret = decoders(i).io.deq.cf_ctrl.cf.instr === BitPat("b001100000010_00000_000_00000_1110011")
val isSret = decoders(i).io.deq.cf_ctrl.cf.instr === BitPat("b000100000010_00000_000_00000_1110011")
val thisBrqValid = io.in(i).valid && (!decoders(i).io.deq.cf_ctrl.cf.brUpdate.pd.notCFI || isMret || isSret) && io.out(i).ready
val thisOutValid = io.in(i).valid && io.toBrq(i).ready
io.in(i).ready := { if (i == 0) thisReady else io.in(i-1).ready && thisReady }
io.out(i).valid := { if (i == 0) thisOutValid else io.in(i-1).ready && thisOutValid }
io.toBrq(i).valid := { if (i == 0) thisBrqValid else io.in(i-1).ready && thisBrqValid }
val thisBrqValid = !decoders(i).io.deq.cf_ctrl.cf.brUpdate.pd.notCFI || isMret || isSret
io.in(i).ready := io.out(i).ready && io.toBrq(i).ready
io.out(i).valid := io.in(i).valid && io.toBrq(i).ready
io.toBrq(i).valid := io.in(i).valid && thisBrqValid && io.out(i).ready
XSDebug(io.in(i).valid || io.out(i).valid || io.toBrq(i).valid, "i:%d In(%d %d) Out(%d %d) ToBrq(%d %d) pc:%x instr:%x\n", i.U, io.in(i).valid, io.in(i).ready, io.out(i).valid, io.out(i).ready, io.toBrq(i).valid, io.toBrq(i).ready, io.in(i).bits.pc, io.in(i).bits.instr)
}
......
......@@ -17,10 +17,7 @@ case class DispatchParameters
LsDqSize: Int,
IntDqDeqWidth: Int,
FpDqDeqWidth: Int,
LsDqDeqWidth: Int,
IntDqReplayWidth: Int,
FpDqReplayWidth: Int,
LsDqReplayWidth: Int
LsDqDeqWidth: Int
)
class Dispatch extends XSModule {
......@@ -30,6 +27,8 @@ class Dispatch extends XSModule {
// from rename
val fromRename = Vec(RenameWidth, Flipped(DecoupledIO(new MicroOp)))
val renameBypass = Input(new RenameBypassInfo)
// to busytable: set pdest to busy (not ready) when they are dispatched
val allocPregs = Vec(RenameWidth, Output(new ReplayPregReq))
// enq Roq
val enqRoq = new Bundle {
val canAccept = Input(Bool())
......@@ -44,16 +43,12 @@ class Dispatch extends XSModule {
val req = Vec(RenameWidth, ValidIO(new MicroOp))
val resp = Vec(RenameWidth, Input(new LSIdx))
}
val dequeueRoqIndex = Input(Valid(new RoqPtr))
// read regfile
val readIntRf = Vec(NRIntReadPorts, Flipped(new RfReadPort))
val readFpRf = Vec(NRFpReadPorts, Flipped(new RfReadPort))
// read reg status (busy/ready)
val intPregRdy = Vec(NRIntReadPorts, Input(Bool()))
val fpPregRdy = Vec(NRFpReadPorts, Input(Bool()))
// replay: set preg status to not ready
val replayPregReq = Output(Vec(ReplayWidth, new ReplayPregReq))
val allocPregs = Vec(RenameWidth, Output(new ReplayPregReq))
// to reservation stations
val numExist = Input(Vec(exuParameters.ExuCnt, UInt(log2Ceil(IssQueSize).W)))
val enqIQCtrl = Vec(exuParameters.ExuCnt, DecoupledIO(new MicroOp))
......@@ -61,13 +56,13 @@ class Dispatch extends XSModule {
})
val dispatch1 = Module(new Dispatch1)
val intDq = Module(new DispatchQueue(dpParams.IntDqSize, dpParams.DqEnqWidth, dpParams.IntDqDeqWidth, dpParams.IntDqReplayWidth))
val fpDq = Module(new DispatchQueue(dpParams.FpDqSize, dpParams.DqEnqWidth, dpParams.FpDqDeqWidth, dpParams.FpDqReplayWidth))
val lsDq = Module(new DispatchQueue(dpParams.LsDqSize, dpParams.DqEnqWidth, dpParams.LsDqDeqWidth, dpParams.LsDqReplayWidth))
val intDq = Module(new DispatchQueue(dpParams.IntDqSize, dpParams.DqEnqWidth, dpParams.IntDqDeqWidth))
val fpDq = Module(new DispatchQueue(dpParams.FpDqSize, dpParams.DqEnqWidth, dpParams.FpDqDeqWidth))
val lsDq = Module(new DispatchQueue(dpParams.LsDqSize, dpParams.DqEnqWidth, dpParams.LsDqDeqWidth))
// pipeline between rename and dispatch
// accepts all at once
val redirectValid = io.redirect.valid && !io.redirect.bits.isReplay
val redirectValid = io.redirect.valid// && !io.redirect.bits.isReplay
for (i <- 0 until RenameWidth) {
PipelineConnect(io.fromRename(i), dispatch1.io.fromRename(i), dispatch1.io.recv(i), redirectValid)
}
......@@ -88,30 +83,8 @@ class Dispatch extends XSModule {
// dispatch queue: queue uops and dispatch them to different reservation stations or issue queues
// it may cancel the uops
intDq.io.redirect <> io.redirect
intDq.io.dequeueRoqIndex <> io.dequeueRoqIndex
intDq.io.replayPregReq.zipWithIndex.map { case(replay, i) =>
io.replayPregReq(i) <> replay
}
intDq.io.otherWalkDone := !fpDq.io.inReplayWalk && !lsDq.io.inReplayWalk
fpDq.io.redirect <> io.redirect
fpDq.io.dequeueRoqIndex <> io.dequeueRoqIndex
fpDq.io.replayPregReq.zipWithIndex.map { case(replay, i) =>
io.replayPregReq(i + dpParams.IntDqReplayWidth) <> replay
}
fpDq.io.otherWalkDone := !intDq.io.inReplayWalk && !lsDq.io.inReplayWalk
lsDq.io.redirect <> io.redirect
lsDq.io.dequeueRoqIndex <> io.dequeueRoqIndex
lsDq.io.replayPregReq.zipWithIndex.map { case(replay, i) =>
io.replayPregReq(i + dpParams.IntDqReplayWidth + dpParams.FpDqReplayWidth) <> replay
}
lsDq.io.otherWalkDone := !intDq.io.inReplayWalk && !fpDq.io.inReplayWalk
if (!env.FPGAPlatform) {
val inWalk = intDq.io.inReplayWalk || fpDq.io.inReplayWalk || lsDq.io.inReplayWalk
ExcitingUtils.addSource(inWalk, "perfCntCondDpqReplay", Perf)
}
// Int dispatch queue to Int reservation stations
val intDispatch = Module(new Dispatch2Int)
......
......@@ -122,7 +122,7 @@ class Dispatch1 extends XSModule {
* acquire ROQ (all), LSQ (load/store only) and dispatch queue slots
* only set valid when all of them provides enough entries
*/
val redirectValid = io.redirect.valid && !io.redirect.bits.isReplay
val redirectValid = io.redirect.valid// && !io.redirect.bits.isReplay
val allResourceReady = io.enqLsq.canAccept && io.enqRoq.canAccept && io.toIntDqReady && io.toFpDqReady && io.toLsDqReady
// Instructions should enter dispatch queues in order.
......
......@@ -3,66 +3,46 @@ package xiangshan.backend.dispatch
import chisel3._
import chisel3.util._
import utils._
import xiangshan.backend.decode.SrcType
import xiangshan._
import xiangshan.backend.roq.RoqPtr
class DispatchQueueIO(enqnum: Int, deqnum: Int, replayWidth: Int) extends XSBundle {
class DispatchQueueIO(enqnum: Int, deqnum: Int) extends XSBundle {
val enq = Vec(enqnum, Flipped(ValidIO(new MicroOp)))
val enqReady = Output(Bool())
val deq = Vec(deqnum, DecoupledIO(new MicroOp))
val dequeueRoqIndex = Input(Valid(new RoqPtr))
val redirect = Flipped(ValidIO(new Redirect))
val replayPregReq = Output(Vec(replayWidth, new ReplayPregReq))
val inReplayWalk = Output(Bool())
val otherWalkDone = Input(Bool())
override def cloneType: DispatchQueueIO.this.type =
new DispatchQueueIO(enqnum, deqnum, replayWidth).asInstanceOf[this.type]
new DispatchQueueIO(enqnum, deqnum).asInstanceOf[this.type]
}
// dispatch queue: accepts at most enqnum uops from dispatch1 and dispatches deqnum uops at every clock cycle
class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) extends XSModule with HasCircularQueuePtrHelper {
val io = IO(new DispatchQueueIO(enqnum, deqnum, replayWidth))
class DispatchQueue(size: Int, enqnum: Int, deqnum: Int) extends XSModule with HasCircularQueuePtrHelper {
val io = IO(new DispatchQueueIO(enqnum, deqnum))
val indexWidth = log2Ceil(size)
val s_invalid :: s_valid :: s_dispatched :: Nil = Enum(3)
val s_invalid :: s_valid:: Nil = Enum(2)
// queue data array
val uopEntries = Mem(size, new MicroOp)
val stateEntries = RegInit(VecInit(Seq.fill(size)(s_invalid)))
// head: first valid entry (dispatched entry)
val headPtr = RegInit(0.U.asTypeOf(new CircularQueuePtr(size)))
// dispatch: first entry that has not been dispatched
val dispatchPtr = RegInit(0.U.asTypeOf(new CircularQueuePtr(size)))
val headPtrMask = UIntToMask(headPtr.value)
// tail: first invalid entry (free entry)
val tailPtr = RegInit(0.U.asTypeOf(new CircularQueuePtr(size)))
val tailPtrMask = UIntToMask(tailPtr.value)
// TODO: make ptr a vector to reduce latency?
// commit: starting from head ptr
val commitIndex = (0 until CommitWidth).map(i => headPtr + i.U).map(_.value)
// deq: starting from dispatch ptr
val deqIndex = (0 until deqnum).map(i => dispatchPtr + i.U).map(_.value)
// deq: starting from head ptr
val deqIndex = (0 until deqnum).map(i => headPtr + i.U).map(_.value)
// enq: starting from tail ptr
val enqIndex = (0 until enqnum).map(i => tailPtr + i.U).map(_.value)
val validEntries = distanceBetween(tailPtr, headPtr)
val dispatchEntries = distanceBetween(tailPtr, dispatchPtr)
val commitEntries = validEntries - dispatchEntries
val emptyEntries = size.U - validEntries
def rangeMask(start: CircularQueuePtr, end: CircularQueuePtr): UInt = {
val startMask = (1.U((size + 1).W) << start.value).asUInt - 1.U
val endMask = (1.U((size + 1).W) << end.value).asUInt - 1.U
val xorMask = startMask(size - 1, 0) ^ endMask(size - 1, 0)
Mux(start.flag === end.flag, xorMask, ~xorMask)
}
val dispatchedMask = rangeMask(headPtr, dispatchPtr)
val allWalkDone = !io.inReplayWalk && io.otherWalkDone
val canEnqueue = validEntries <= (size - enqnum).U && allWalkDone
val canActualEnqueue = canEnqueue && !(io.redirect.valid && !io.redirect.bits.isReplay)
val isTrueEmpty = ~Cat((0 until size).map(i => stateEntries(i) === s_valid)).orR
val canEnqueue = validEntries <= (size - enqnum).U
val canActualEnqueue = canEnqueue && !(io.redirect.valid /*&& !io.redirect.bits.isReplay*/)
/**
* Part 1: update states and uops when enqueue, dequeue, commit, redirect/replay
......@@ -87,36 +67,22 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
// dequeue: from s_valid to s_dispatched
for (i <- 0 until deqnum) {
when (io.deq(i).fire()) {
stateEntries(deqIndex(i)) := s_dispatched
when (io.deq(i).fire() && !io.redirect.valid) {
stateEntries(deqIndex(i)) := s_invalid
XSError(stateEntries(deqIndex(i)) =/= s_valid, "state of the dispatch entry is not s_valid\n")
}
}
// commit: from s_dispatched to s_invalid
val needDequeue = Wire(Vec(size, Bool()))
val deqRoqIdx = io.dequeueRoqIndex.bits
for (i <- 0 until size) {
needDequeue(i) := stateEntries(i) === s_dispatched && io.dequeueRoqIndex.valid && !isAfter(uopEntries(i).roqIdx, deqRoqIdx) && dispatchedMask(i)
when (needDequeue(i)) {
stateEntries(i) := s_invalid
}
XSInfo(needDequeue(i), p"dispatched entry($i)(pc = ${Hexadecimal(uopEntries(i).cf.pc)}) " +
p"roqIndex 0x${Hexadecimal(uopEntries(i).roqIdx.asUInt)} " +
p"left dispatch queue with deqRoqIndex 0x${Hexadecimal(io.dequeueRoqIndex.bits.asUInt)}\n")
}
// redirect: cancel uops currently in the queue
val mispredictionValid = io.redirect.valid && io.redirect.bits.isMisPred
val mispredictionValid = io.redirect.valid //&& io.redirect.bits.isMisPred
val exceptionValid = io.redirect.valid && io.redirect.bits.isException
val flushPipeValid = io.redirect.valid && io.redirect.bits.isFlushPipe
val roqNeedFlush = Wire(Vec(size, Bool()))
val needCancel = Wire(Vec(size, Bool()))
for (i <- 0 until size) {
roqNeedFlush(i) := uopEntries(i.U).roqIdx.needFlush(io.redirect)
needCancel(i) := stateEntries(i) =/= s_invalid && ((roqNeedFlush(i) && mispredictionValid) || exceptionValid || flushPipeValid) && !needDequeue(i)
needCancel(i) := stateEntries(i) =/= s_invalid && ((roqNeedFlush(i) && mispredictionValid) || exceptionValid || flushPipeValid)
when (needCancel(i)) {
stateEntries(i) := s_invalid
......@@ -127,182 +93,76 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
p"cancelled with redirect roqIndex 0x${Hexadecimal(io.redirect.bits.roqIdx.asUInt)}\n")
}
// replay: from s_dispatched to s_valid
val replayValid = io.redirect.valid && io.redirect.bits.isReplay
val needReplay = Wire(Vec(size, Bool()))
for (i <- 0 until size) {
needReplay(i) := roqNeedFlush(i) && stateEntries(i) === s_dispatched && replayValid
when (needReplay(i)) {
stateEntries(i) := s_valid
}
XSInfo(needReplay(i), p"dispatched entry($i)(pc = ${Hexadecimal(uopEntries(i.U).cf.pc)}) " +
p"replayed with roqIndex ${io.redirect.bits.roqIdx}\n")
}
/**
* Part 2: walk
*
* Instead of keeping the walking distances, we keep the walking target position for simplicity.
*
* (1) replay: move dispatchPtr to the first needReplay entry
* (2) redirect (branch misprediction): move dispatchPtr, tailPtr to the first cancelled entry
* Part 2: update indices
*
* tail: (1) enqueue; (2) redirect
* head: dequeue
*/
// getFirstIndex: get the head index of consecutive ones
// note that it returns the position starting from either the leftmost or the rightmost
// 00000001 => 0
// 00111000 => 3
// 11000111 => 2
// 10000000 => 1
// 00000000 => 7
// 11111111 => 7
def getFirstMaskPosition(mask: Seq[Bool]) = {
Mux(mask(size - 1),
PriorityEncoder(mask.reverse.map(m => !m)),
PriorityEncoder(mask)
)
}
val maskedNeedReplay = Cat(needReplay.reverse) & dispatchedMask
val allCancel = Cat(needCancel).andR
val someReplay = Cat(maskedNeedReplay).orR
val allReplay = Cat(maskedNeedReplay).andR
XSDebug(replayValid, p"needReplay: ${Binary(Cat(needReplay))}\n")
XSDebug(replayValid, p"dispatchedMask: ${Binary(dispatchedMask)}\n")
XSDebug(replayValid, p"maskedNeedReplay: ${Binary(maskedNeedReplay)}\n")
// when nothing or everything is cancelled or replayed, the pointers remain unchanged
// if any uop is cancelled or replayed, the pointer should go to the first zero before all ones
// position: target index
// (1) if leftmost bits are ones, count continuous ones from leftmost (target position is the last one)
// (2) if leftmost bit is zero, count rightmost zero btis (target position is the first one)
// if all bits are one, we need to keep the index unchanged
// 00000000, 11111111: unchanged
// otherwise: firstMaskPosition
val cancelPosition = Mux(!Cat(needCancel).orR || allCancel, tailPtr.value, getFirstMaskPosition(needCancel))
val replayPosition = Mux(!someReplay || allReplay, dispatchPtr.value, getFirstMaskPosition(maskedNeedReplay.asBools))
XSDebug(replayValid, p"getFirstMaskPosition: ${getFirstMaskPosition(maskedNeedReplay.asBools)}\n")
assert(cancelPosition.getWidth == indexWidth)
assert(replayPosition.getWidth == indexWidth)
// If the highest bit is one, the direction flips.
// Otherwise, the direction keeps the same.
val tailCancelPtr = Wire(new CircularQueuePtr(size))
tailCancelPtr.flag := Mux(needCancel(size - 1), ~tailPtr.flag, tailPtr.flag)
tailCancelPtr.value := Mux(needCancel(size - 1) && !allCancel, size.U - cancelPosition, cancelPosition)
// In case of branch mis-prediction:
// If mis-prediction happens after dispatchPtr, the pointer keeps the same as before.
// If dispatchPtr needs to be cancelled, reset dispatchPtr to tailPtr.
val dispatchCancelPtr = Mux(needCancel(dispatchPtr.value) || dispatchEntries === 0.U, tailCancelPtr, dispatchPtr)
// In case of replay, we need to walk back and recover preg states in the busy table.
// We keep track of the number of entries needed to be walked instead of target position to reduce overhead
// for 11111111, replayPosition is unuseful. We naively set Cnt to size.U
val dispatchReplayCnt = Mux(
allReplay, size.U,
Mux(maskedNeedReplay(size - 1),
// replay makes flag flipped
dispatchPtr.value + replayPosition,
// the new replay does not change the flag
Mux(dispatchPtr.value <= replayPosition,
// but we are currently in a replay that changes the flag
dispatchPtr.value + (size.U - replayPosition),
dispatchPtr.value - replayPosition)))
val dispatchReplayCntReg = RegInit(0.U)
// actually, if deqIndex points to head uops and they are replayed, there's no need for extraWalk
// however, to simplify logic, we simply let it do extra walk now
val needExtraReplayWalk = Cat((0 until deqnum).map(i => needReplay(deqIndex(i)))).orR
val needExtraReplayWalkReg = RegNext(needExtraReplayWalk && replayValid, false.B)
val inReplayWalk = dispatchReplayCntReg =/= 0.U || needExtraReplayWalkReg
val dispatchReplayStep = Mux(needExtraReplayWalkReg, 0.U, Mux(dispatchReplayCntReg > replayWidth.U, replayWidth.U, dispatchReplayCntReg))
when (exceptionValid) {
dispatchReplayCntReg := 0.U
}.elsewhen (inReplayWalk && mispredictionValid && needCancel((dispatchPtr - 1.U).value)) {
val distance = distanceBetween(dispatchPtr, tailCancelPtr)
dispatchReplayCntReg := Mux(dispatchReplayCntReg > distance, dispatchReplayCntReg - distance, 0.U)
}.elsewhen (replayValid && someReplay) {
dispatchReplayCntReg := dispatchReplayCnt - dispatchReplayStep
}.elsewhen (!needExtraReplayWalkReg) {
dispatchReplayCntReg := dispatchReplayCntReg - dispatchReplayStep
}
io.inReplayWalk := inReplayWalk
val replayIndex = (0 until replayWidth).map(i => (dispatchPtr - (i + 1).U).value)
for (i <- 0 until replayWidth) {
val index = Mux(needExtraReplayWalkReg, (if (i < deqnum) deqIndex(i) else 0.U), replayIndex(i))
val shouldResetDest = inReplayWalk && stateEntries(index) === s_valid
io.replayPregReq(i).isInt := shouldResetDest && uopEntries(index).ctrl.rfWen && uopEntries(index).ctrl.ldest =/= 0.U
io.replayPregReq(i).isFp := shouldResetDest && uopEntries(index).ctrl.fpWen
io.replayPregReq(i).preg := uopEntries(index).pdest
XSDebug(shouldResetDest, p"replay $i: " +
p"type (${uopEntries(index).ctrl.rfWen}, ${uopEntries(index).ctrl.fpWen}) " +
p"pdest ${uopEntries(index).pdest} ldest ${uopEntries(index).ctrl.ldest}\n")
}
/**
* Part 3: update indices
*
* tail: (1) enqueue; (2) walk in case of redirect
* dispatch: (1) dequeue; (2) walk in case of replay; (3) walk in case of redirect
* head: commit
*/
// enqueue
val numEnq = Mux(canActualEnqueue, PriorityEncoder(io.enq.map(!_.valid) :+ true.B), 0.U)
XSError(numEnq =/= 0.U && (mispredictionValid || exceptionValid), "should not enqueue when redirect\n")
tailPtr := Mux(exceptionValid,
0.U.asTypeOf(new CircularQueuePtr(size)),
Mux(mispredictionValid,
tailCancelPtr,
tailPtr + numEnq)
)
// dequeue
val numDeqTry = Mux(dispatchEntries > deqnum.U, deqnum.U, dispatchEntries)
val numDeqTry = Mux(validEntries > deqnum.U, deqnum.U, validEntries)
val numDeqFire = PriorityEncoder(io.deq.zipWithIndex.map{case (deq, i) =>
// For dequeue, the first entry should never be s_invalid
// Otherwise, there should be a redirect and tail walks back
// in this case, we set numDeq to 0
!deq.fire() && (if (i == 0) true.B else stateEntries(deqIndex(i)) =/= s_dispatched)
!deq.fire() && (if (i == 0) true.B else stateEntries(deqIndex(i)) =/= s_invalid)
} :+ true.B)
val numDeq = Mux(numDeqTry > numDeqFire, numDeqFire, numDeqTry)
dispatchPtr := Mux(exceptionValid,
// agreement with reservation station: don't dequeue when redirect.valid
val headPtrNext = Mux(mispredictionValid, headPtr, headPtr + numDeq)
headPtr := Mux(exceptionValid, 0.U.asTypeOf(new CircularQueuePtr(size)), headPtrNext)
// For branch mis-prediction or memory violation replay,
// we delay updating the indices for one clock cycle.
// For now, we simply use PopCount to count #instr cancelled.
val lastCycleMisprediction = RegNext(io.redirect.valid && !(io.redirect.bits.isException || io.redirect.bits.isFlushPipe))
// find the last one's position, starting from headPtr and searching backwards
val validBitVec = VecInit((0 until size).map(i => stateEntries(i) === s_valid))
val loValidBitVec = Cat((0 until size).map(i => validBitVec(i) && headPtrMask(i)))
val hiValidBitVec = Cat((0 until size).map(i => validBitVec(i) && ~headPtrMask(i)))
val flippedFlag = loValidBitVec.orR
val lastOneIndex = size.U - PriorityEncoder(Mux(loValidBitVec.orR, loValidBitVec, hiValidBitVec))
val walkedTailPtr = Wire(new CircularQueuePtr(size))
walkedTailPtr.flag := flippedFlag ^ headPtr.flag
walkedTailPtr.value := lastOneIndex
// enqueue
val numEnq = Mux(canActualEnqueue, PriorityEncoder(io.enq.map(!_.valid) :+ true.B), 0.U)
XSError(numEnq =/= 0.U && (mispredictionValid || exceptionValid), "should not enqueue when redirect\n")
tailPtr := Mux(exceptionValid,
0.U.asTypeOf(new CircularQueuePtr(size)),
Mux(mispredictionValid && (!inReplayWalk || needCancel((dispatchPtr - 1.U).value)),
dispatchCancelPtr,
Mux(inReplayWalk, dispatchPtr - dispatchReplayStep, dispatchPtr + numDeq))
Mux(lastCycleMisprediction,
Mux(isTrueEmpty, headPtr, walkedTailPtr),
tailPtr + numEnq)
)
headPtr := Mux(exceptionValid, 0.U.asTypeOf(new CircularQueuePtr(size)), headPtr + PopCount(needDequeue))
/**
* Part 4: set output and input
* Part 3: set output and input
*/
// TODO: remove this when replay moves to roq
for (i <- 0 until deqnum) {
io.deq(i).bits := uopEntries(deqIndex(i))
// do not dequeue when io.redirect valid because it may cause dispatchPtr work improperly
io.deq(i).valid := stateEntries(deqIndex(i)) === s_valid && !io.redirect.valid && allWalkDone
io.deq(i).valid := stateEntries(deqIndex(i)) === s_valid && !lastCycleMisprediction// && !io.redirect.valid
}
// debug: dump dispatch queue states
XSDebug(p"head: $headPtr, tail: $tailPtr, dispatch: $dispatchPtr, " +
p"replayCnt: $dispatchReplayCntReg, needExtraReplayWalkReg: $needExtraReplayWalkReg\n")
XSDebug(p"head: $headPtr, tail: $tailPtr\n")
XSDebug(p"state: ")
stateEntries.reverse.foreach { s =>
XSDebug(false, s === s_invalid, "-")
XSDebug(false, s === s_valid, "v")
XSDebug(false, s === s_dispatched, "d")
}
XSDebug(false, true.B, "\n")
XSDebug(p"ptr: ")
(0 until size).reverse.foreach { i =>
val isPtr = i.U === headPtr.value || i.U === tailPtr.value || i.U === dispatchPtr.value
val isPtr = i.U === headPtr.value || i.U === tailPtr.value
XSDebug(false, isPtr, "^")
XSDebug(false, !isPtr, " ")
}
XSDebug(false, true.B, "\n")
XSError(isAfter(headPtr, tailPtr), p"assert greaterOrEqualThan(tailPtr: $tailPtr, headPtr: $headPtr) failed\n")
XSError(isAfter(dispatchPtr, tailPtr) && !inReplayWalk, p"assert greaterOrEqualThan(tailPtr: $tailPtr, dispatchPtr: $dispatchPtr) failed\n")
XSError(isAfter(headPtr, dispatchPtr), p"assert greaterOrEqualThan(dispatchPtr: $dispatchPtr, headPtr: $headPtr) failed\n")
XSError(validEntries < dispatchEntries && !inReplayWalk, "validEntries should be less than dispatchEntries\n")
}
......@@ -823,7 +823,6 @@ class CSR extends FunctionUnit with HasCSRConst
"MbpIWrong" -> (0xb0b, "perfCntCondMbpIWrong" ),
"MbpRRight" -> (0xb0c, "perfCntCondMbpRRight" ),
"MbpRWrong" -> (0xb0d, "perfCntCondMbpRWrong" ),
"DpqReplay" -> (0xb0e, "perfCntCondDpqReplay" ),
"RoqWalk" -> (0xb0f, "perfCntCondRoqWalk" ),
"RoqWaitInt" -> (0xb10, "perfCntCondRoqWaitInt" ),
"RoqWaitFp" -> (0xb11, "perfCntCondRoqWaitFp" ),
......
......@@ -5,8 +5,6 @@ import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.exu.{Exu, ExuConfig}
import java.rmi.registry.Registry
import java.{util => ju}
class BypassQueue(number: Int) extends XSModule {
val io = IO(new Bundle {
......@@ -206,9 +204,11 @@ class ReservationStationCtrl
// enq
val tailAfterRealDeq = tailPtr - (issFire && !needFeedback|| bubReg)
val isFull = tailAfterRealDeq.flag // tailPtr===qsize.U
tailPtr := tailAfterRealDeq + io.enqCtrl.fire()
// agreement with dispatch: don't fire when io.redirect.valid
val enqFire = io.enqCtrl.fire() && !io.redirect.valid
tailPtr := tailAfterRealDeq + enqFire
io.enqCtrl.ready := !isFull && !io.redirect.valid // TODO: check this redirect && need more optimization
io.enqCtrl.ready := !isFull
val enqUop = io.enqCtrl.bits
val srcSeq = Seq(enqUop.psrc1, enqUop.psrc2, enqUop.psrc3)
val srcTypeSeq = Seq(enqUop.ctrl.src1Type, enqUop.ctrl.src2Type, enqUop.ctrl.src3Type)
......@@ -222,7 +222,7 @@ class ReservationStationCtrl
(srcType === SrcType.reg && src === 0.U)
}
when (io.enqCtrl.fire()) {
when (enqFire) {
stateQueue(enqIdx_ctrl) := s_valid
srcQueue(enqIdx_ctrl).zipWithIndex.map{ case (s, i) =>
s := Mux(enqBpVec(i) || stateCheck(srcSeq(i), srcTypeSeq(i)), true.B,
......@@ -249,7 +249,7 @@ class ReservationStationCtrl
io.data.enqPtr := idxQueue(Mux(tailPtr.flag, deqIdx, tailPtr.value))
io.data.deqPtr.valid := selValid
io.data.deqPtr.bits := idxQueue(selectedIdxWire)
io.data.enqCtrl.valid := io.enqCtrl.fire
io.data.enqCtrl.valid := enqFire
io.data.enqCtrl.bits := io.enqCtrl.bits
// other io
......@@ -335,8 +335,8 @@ class ReservationStationData
// enq
val enqPtr = enq(log2Up(IssQueSize)-1,0)
val enqPtrReg = RegEnable(enqPtr, enqCtrl.fire())
val enqEn = enqCtrl.fire()
val enqPtrReg = RegEnable(enqPtr, enqCtrl.valid)
val enqEn = enqCtrl.valid
val enqEnReg = RegNext(enqEn)
when (enqEn) {
uop(enqPtr) := enqUop
......@@ -407,7 +407,7 @@ class ReservationStationData
val srcSeq = Seq(enqUop.psrc1, enqUop.psrc2, enqUop.psrc3)
val srcTypeSeq = Seq(enqUop.ctrl.src1Type, enqUop.ctrl.src2Type, enqUop.ctrl.src3Type)
io.ctrl.srcUpdate(IssQueSize).zipWithIndex.map{ case (h, i) =>
val (bpHit, bpHitReg, bpData)= bypass(srcSeq(i), srcTypeSeq(i), enqCtrl.fire())
val (bpHit, bpHitReg, bpData)= bypass(srcSeq(i), srcTypeSeq(i), enqCtrl.valid)
when (bpHitReg) { data(enqPtrReg)(i) := bpData }
h := bpHit
// NOTE: enq bp is done here
......
......@@ -12,8 +12,6 @@ class BusyTable(numReadPorts: Int, numWritePorts: Int) extends XSModule {
val allocPregs = Vec(RenameWidth, Flipped(ValidIO(UInt(PhyRegIdxWidth.W))))
// set preg state to ready (write back regfile + roq walk)
val wbPregs = Vec(numWritePorts, Flipped(ValidIO(UInt(PhyRegIdxWidth.W))))
// set preg state to busy when replay
val replayPregs = Vec(ReplayWidth, Flipped(ValidIO(UInt(PhyRegIdxWidth.W))))
// read preg state
val rfReadAddr = Vec(numReadPorts, Input(UInt(PhyRegIdxWidth.W)))
val pregRdy = Vec(numReadPorts, Output(Bool()))
......@@ -27,17 +25,15 @@ class BusyTable(numReadPorts: Int, numWritePorts: Int) extends XSModule {
val wbMask = reqVecToMask(io.wbPregs)
val allocMask = reqVecToMask(io.allocPregs)
val replayMask = reqVecToMask(io.replayPregs)
val tableAfterWb = table & (~wbMask).asUInt
val tableAfterAlloc = tableAfterWb | allocMask
val tableAfterReplay = tableAfterAlloc | replayMask
for((raddr, rdy) <- io.rfReadAddr.zip(io.pregRdy)){
rdy := !tableAfterWb(raddr)
}
table := tableAfterReplay
table := tableAfterAlloc
// for((alloc, i) <- io.allocPregs.zipWithIndex){
// when(alloc.valid){
......
......@@ -43,6 +43,7 @@ class FreeList extends XSModule with HasFreeListConsts with HasCircularQueuePtrH
// do checkpoints
val cpReqs = Vec(RenameWidth, Flipped(ValidIO(new BrqPtr)))
val walk = Flipped(ValidIO(UInt(log2Up(RenameWidth).W)))
// dealloc phy regs
val deallocReqs = Input(Vec(CommitWidth, Bool()))
......@@ -96,15 +97,11 @@ class FreeList extends XSModule with HasFreeListConsts with HasCircularQueuePtrH
val headPtrNext = Mux(hasEnoughRegs, newHeadPtrs.last, headPtr)
freeRegs := distanceBetween(tailPtr, headPtrNext)
headPtr := Mux(io.redirect.valid, // mispredict or exception happen
Mux(io.redirect.bits.isException || io.redirect.bits.isFlushPipe, // TODO: need check by JiaWei
FreeListPtr(!tailPtrNext.flag, tailPtrNext.value),
Mux(io.redirect.bits.isMisPred,
checkPoints(io.redirect.bits.brTag.value),
headPtrNext // replay
)
),
headPtrNext
// when mispredict or exception happens, reset headPtr to tailPtr (freelist is full).
val resetHeadPtr = io.redirect.valid && (io.redirect.bits.isException || io.redirect.bits.isFlushPipe)
headPtr := Mux(resetHeadPtr,
FreeListPtr(!tailPtrNext.flag, tailPtrNext.value),
Mux(io.walk.valid, headPtr - io.walk.bits, headPtrNext)
)
XSDebug(p"head:$headPtr tail:$tailPtr\n")
......
......@@ -54,6 +54,11 @@ class Rename extends XSModule {
def needDestReg[T <: CfCtrl](fp: Boolean, x: T): Bool = {
{if(fp) x.ctrl.fpWen else x.ctrl.rfWen && (x.ctrl.ldest =/= 0.U)}
}
val walkValid = Cat(io.roqCommits.map(_.valid)).orR && io.roqCommits(0).bits.isWalk
fpFreeList.walk.valid := walkValid
intFreeList.walk.valid := walkValid
fpFreeList.walk.bits := PopCount(io.roqCommits.map(c => c.valid && needDestReg(true, c.bits.uop)))
intFreeList.walk.bits := PopCount(io.roqCommits.map(c => c.valid && needDestReg(false, c.bits.uop)))
val uops = Wire(Vec(RenameWidth, new MicroOp))
......
......@@ -55,7 +55,6 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
val exeWbResults = Vec(numWbPorts, Flipped(ValidIO(new ExuOutput)))
val commits = Vec(CommitWidth, Valid(new RoqCommit))
val bcommit = Output(UInt(BrTagWidth.W))
val commitRoqIndex = Output(Valid(new RoqPtr))
val roqDeqPtr = Output(new RoqPtr)
val csr = new RoqCSRIO
})
......@@ -164,7 +163,7 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
XSDebug(p"(ready, valid): ${io.enq.canAccept}, ${Binary(firedDispatch)}\n")
val dispatchCnt = PopCount(firedDispatch)
enqPtrExt := enqPtrExt + PopCount(firedDispatch)
enqPtrExt := enqPtrExt + dispatchCnt
when (firedDispatch.orR) {
XSInfo("dispatched %d insts\n", dispatchCnt)
}
......@@ -220,8 +219,8 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
walkPtrVec(i) := walkPtrExt - i.U
shouldWalkVec(i) := i.U < walkCounter
}
val walkFinished = walkCounter <= CommitWidth.U && // walk finish in this cycle
!io.brqRedirect.valid // no new redirect comes and update walkptr
val walkFinished = walkCounter <= CommitWidth.U //&& // walk finish in this cycle
//!io.brqRedirect.valid // no new redirect comes and update walkptr
// extra space is used weh roq has no enough space, but mispredict recovery needs such info to walk regmap
val needExtraSpaceForMPR = WireInit(VecInit(
......@@ -336,9 +335,6 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
}
val retireCounter = Mux(state === s_idle, commitCnt, 0.U)
XSInfo(retireCounter > 0.U, "retired %d insts\n", retireCounter)
val commitOffset = PriorityEncoder((validCommit :+ false.B).map(!_))
io.commitRoqIndex.valid := state === s_idle
io.commitRoqIndex.bits := deqPtrExt + commitOffset
// commit branch to brq
io.bcommit := PopCount(cfiCommitVec)
......@@ -346,12 +342,13 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
// when redirect, walk back roq entries
when(io.brqRedirect.valid){ // TODO: need check if consider exception redirect?
state := s_walk
walkPtrExt := Mux(state === s_walk && !walkFinished, walkPtrExt - CommitWidth.U, Mux(state === s_extrawalk, walkPtrExt, enqPtrExt - 1.U + dispatchCnt))
val nextEnqPtr = (enqPtrExt - 1.U) + dispatchCnt
walkPtrExt := Mux(state === s_walk,
walkPtrExt - Mux(walkFinished, walkCounter, CommitWidth.U),
Mux(state === s_extrawalk, walkPtrExt, nextEnqPtr))
// walkTgtExt := io.brqRedirect.bits.roqIdx
walkCounter := Mux(state === s_walk,
distanceBetween(walkPtrExt, io.brqRedirect.bits.roqIdx) - commitCnt,
distanceBetween(enqPtrExt, io.brqRedirect.bits.roqIdx) + dispatchCnt -1.U,
)
val currentWalkPtr = Mux(state === s_walk || state === s_extrawalk, walkPtrExt, nextEnqPtr)
walkCounter := distanceBetween(currentWalkPtr, io.brqRedirect.bits.roqIdx) - Mux(state === s_walk, commitCnt, 0.U)
enqPtrExt := io.brqRedirect.bits.roqIdx + 1.U
}
......@@ -425,14 +422,14 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
}
// rollback: write all
// when rollback, reset writebacked entry to valid
when(io.memRedirect.valid) { // TODO: opt timing
for (i <- 0 until RoqSize) {
val recRoqIdx = RoqPtr(flagBkup(i), i.U)
when (valid(i) && isAfter(recRoqIdx, io.memRedirect.bits.roqIdx)) {
writebacked(i) := false.B
}
}
}
// when(io.memRedirect.valid) { // TODO: opt timing
// for (i <- 0 until RoqSize) {
// val recRoqIdx = RoqPtr(flagBkup(i), i.U)
// when (valid(i) && isAfter(recRoqIdx, io.memRedirect.bits.roqIdx)) {
// writebacked(i) := false.B
// }
// }
// }
// read
// deqPtrWritebacked
......
......@@ -244,9 +244,10 @@ class ICache extends ICacheModule
sourceVec_16bit(i*4 + j) := sourceVec(i)(j*16+15, j*16)
}
}
val cutPacket = WireInit(VecInit(Seq.fill(blockWords * 2){0.U(RVCInsLen.W)}))
(0 until blockWords * 2).foreach{ i =>
cutPacket(i) := Mux(mask(i).asBool,sourceVec_16bit(startPtr + i.U),0.U)
val cutPacket = WireInit(VecInit(Seq.fill(PredictWidth){0.U(RVCInsLen.W)}))
val start = Cat(startPtr(4,3),0.U(3.W))
(0 until PredictWidth ).foreach{ i =>
cutPacket(i) := Mux(mask(i).asBool,sourceVec_16bit(start + i.U),0.U)
}
cutPacket.asUInt
}
......
......@@ -6,11 +6,12 @@ import utils._
import xiangshan._
import xiangshan.backend.ALUOpType
import xiangshan.backend.JumpOpType
import chisel3.experimental.chiselName
trait HasBPUParameter extends HasXSParameter {
val BPUDebug = false
val BPUDebug = true
val EnableCFICommitLog = true
val EnbaleCFIPredLog = true
val EnbaleCFIPredLog = false
val EnableBPUTimeRecord = EnableCFICommitLog || EnbaleCFIPredLog
}
......@@ -98,7 +99,8 @@ trait PredictorUtils {
Mux(taken, old + 1.S, old - 1.S)))
}
}
abstract class BasePredictor extends XSModule with HasBPUParameter with PredictorUtils {
abstract class BasePredictor extends XSModule
with HasBPUParameter with HasIFUConst with PredictorUtils {
val metaLen = 0
// An implementation MUST extend the IO bundle with a response
......@@ -126,23 +128,21 @@ class BPUStageIO extends XSBundle {
val pc = UInt(VAddrBits.W)
val mask = UInt(PredictWidth.W)
val resp = new PredictorResponse
val target = UInt(VAddrBits.W)
// val target = UInt(VAddrBits.W)
val brInfo = Vec(PredictWidth, new BranchInfo)
val saveHalfRVI = Bool()
// val saveHalfRVI = Bool()
}
abstract class BPUStage extends XSModule with HasBPUParameter{
abstract class BPUStage extends XSModule with HasBPUParameter with HasIFUConst {
class DefaultIO extends XSBundle {
val flush = Input(Bool())
val in = Input(new BPUStageIO)
val inFire = Input(Bool())
val stageValid = Input(Bool())
val pred = Output(new BranchPrediction) // to ifu
val out = Output(new BPUStageIO) // to the next stage
val outFire = Input(Bool())
val predecode = Input(new Predecode)
val recover = Flipped(ValidIO(new BranchUpdateInfo))
val debug_hist = Input(UInt((if (BPUDebug) (HistoryLength) else 0).W))
val debug_histPtr = Input(UInt((if (BPUDebug) (ExtHistoryLength) else 0).W))
}
......@@ -156,68 +156,45 @@ abstract class BPUStage extends XSModule with HasBPUParameter{
// takens, notTakens and target
val takens = Wire(Vec(PredictWidth, Bool()))
val notTakens = Wire(Vec(PredictWidth, Bool()))
// val notTakens = Wire(Vec(PredictWidth, Bool()))
val brMask = Wire(Vec(PredictWidth, Bool()))
val jmpIdx = PriorityEncoder(takens)
val hasNTBr = (0 until PredictWidth).map(i => i.U <= jmpIdx && notTakens(i) && brMask(i)).reduce(_||_)
val taken = takens.reduce(_||_)
// get the last valid inst
val lastValidPos = WireInit(PriorityMux(Reverse(inLatch.mask), (PredictWidth-1 to 0 by -1).map(i => i.U)))
val lastHit = Wire(Bool())
val lastIsRVC = Wire(Bool())
val saveHalfRVI = ((lastValidPos === jmpIdx && taken) || !taken ) && !lastIsRVC && lastHit
val targetSrc = Wire(Vec(PredictWidth, UInt(VAddrBits.W)))
val target = Mux(taken, targetSrc(jmpIdx), npc(inLatch.pc, PopCount(inLatch.mask)))
val jalMask = Wire(Vec(PredictWidth, Bool()))
val targets = Wire(Vec(PredictWidth, UInt(VAddrBits.W)))
val firstBankHasHalfRVI = Wire(Bool())
val lastBankHasHalfRVI = Wire(Bool())
val lastBankHasInst = WireInit(inLatch.mask(PredictWidth-1, bankWidth).orR)
io.pred <> DontCare
io.pred.redirect := target =/= inLatch.target || inLatch.saveHalfRVI && !saveHalfRVI
io.pred.taken := taken
io.pred.jmpIdx := jmpIdx
io.pred.hasNotTakenBrs := hasNTBr
io.pred.target := target
io.pred.saveHalfRVI := saveHalfRVI
io.pred.takenOnBr := taken && brMask(jmpIdx)
io.pred.takens := takens.asUInt
io.pred.brMask := brMask.asUInt
io.pred.jalMask := jalMask.asUInt
io.pred.targets := targets
io.pred.firstBankHasHalfRVI := firstBankHasHalfRVI
io.pred.lastBankHasHalfRVI := lastBankHasHalfRVI
io.out <> DontCare
io.out.pc := inLatch.pc
io.out.mask := inLatch.mask
io.out.target := target
io.out.resp <> inLatch.resp
io.out.brInfo := inLatch.brInfo
io.out.saveHalfRVI := saveHalfRVI
(0 until PredictWidth).map(i =>
io.out.brInfo(i).sawNotTakenBranch := (if (i == 0) false.B else (brMask.asUInt & notTakens.asUInt)(i-1,0).orR))
// Default logic
// pred.ready not taken into consideration
// could be broken
// when (io.flush) { predValid := false.B }
// .elsewhen (inFire) { predValid := true.B }
// .elsewhen (outFire) { predValid := false.B }
// .otherwise { predValid := predValid }
(0 until PredictWidth).map(i => io.out.brInfo(i).sawNotTakenBranch := io.pred.sawNotTakenBr(i))
if (BPUDebug) {
XSDebug(io.inFire, "in: pc=%x, mask=%b, target=%x\n", io.in.pc, io.in.mask, io.in.target)
XSDebug(io.outFire, "out: pc=%x, mask=%b, target=%x\n", io.out.pc, io.out.mask, io.out.target)
val jmpIdx = io.pred.jmpIdx
val taken = io.pred.taken
val target = Mux(taken, io.pred.targets(jmpIdx), snpc(inLatch.pc))
XSDebug("in(%d): pc=%x, mask=%b\n", io.inFire, io.in.pc, io.in.mask)
XSDebug("inLatch: pc=%x, mask=%b\n", inLatch.pc, inLatch.mask)
XSDebug("out(%d): pc=%x, mask=%b, taken=%d, jmpIdx=%d, target=%x, firstHasHalfRVI=%d, lastHasHalfRVI=%d\n",
io.outFire, io.out.pc, io.out.mask, taken, jmpIdx, target, firstBankHasHalfRVI, lastBankHasHalfRVI)
XSDebug("flush=%d\n", io.flush)
XSDebug("taken=%d, takens=%b, notTakens=%b, jmpIdx=%d, hasNTBr=%d, lastValidPos=%d, target=%x\n",
taken, takens.asUInt, notTakens.asUInt, jmpIdx, hasNTBr, lastValidPos, target)
val p = io.pred
XSDebug(io.outFire, "outPred: redirect=%d, taken=%d, jmpIdx=%d, hasNTBrs=%d, target=%x, saveHalfRVI=%d\n",
p.redirect, p.taken, p.jmpIdx, p.hasNotTakenBrs, p.target, p.saveHalfRVI)
XSDebug(io.outFire && p.taken, "outPredTaken: fetchPC:%x, jmpPC:%x\n",
inLatch.pc, inLatch.pc + (jmpIdx << 1.U))
XSDebug(io.outFire && p.redirect, "outPred: previous target:%x redirected to %x \n",
inLatch.target, p.target)
XSDebug(io.outFire, "outPred targetSrc: ")
for (i <- 0 until PredictWidth) {
XSDebug(false, io.outFire, "(%d):%x ", i.U, targetSrc(i))
}
XSDebug(false, io.outFire, "\n")
}
}
@chiselName
class BPUStage1 extends BPUStage {
// ubtb is accessed with inLatch pc in s1,
......@@ -225,21 +202,19 @@ class BPUStage1 extends BPUStage {
val ubtbResp = io.in.resp.ubtb
// the read operation is already masked, so we do not need to mask here
takens := VecInit((0 until PredictWidth).map(i => ubtbResp.hits(i) && ubtbResp.takens(i)))
notTakens := VecInit((0 until PredictWidth).map(i => ubtbResp.hits(i) && !ubtbResp.takens(i) && ubtbResp.brMask(i)))
targetSrc := ubtbResp.targets
// notTakens := VecInit((0 until PredictWidth).map(i => ubtbResp.hits(i) && !ubtbResp.takens(i) && ubtbResp.brMask(i)))
brMask := ubtbResp.brMask
jalMask := DontCare
targets := ubtbResp.targets
lastIsRVC := ubtbResp.is_RVC(lastValidPos)
lastHit := ubtbResp.hits(lastValidPos)
firstBankHasHalfRVI := Mux(lastBankHasInst, false.B, ubtbResp.hits(bankWidth-1) && !ubtbResp.is_RVC(bankWidth-1) && inLatch.mask(bankWidth-1))
lastBankHasHalfRVI := ubtbResp.hits(PredictWidth-1) && !ubtbResp.is_RVC(PredictWidth-1) && inLatch.mask(PredictWidth-1)
// resp and brInfo are from the components,
// so it does not need to be latched
io.out.resp <> io.in.resp
io.out.brInfo := io.in.brInfo
// we do not need to compare target in stage1
io.pred.redirect := taken
if (BPUDebug) {
XSDebug(io.outFire, "outPred using ubtb resp: hits:%b, takens:%b, notTakens:%b, isRVC:%b\n",
ubtbResp.hits.asUInt, ubtbResp.takens.asUInt, ~ubtbResp.takens.asUInt & brMask.asUInt, ubtbResp.is_RVC.asUInt)
......@@ -248,19 +223,18 @@ class BPUStage1 extends BPUStage {
io.out.brInfo.map(_.debug_ubtb_cycle := GTimer())
}
}
@chiselName
class BPUStage2 extends BPUStage {
// Use latched response from s1
val btbResp = inLatch.resp.btb
val bimResp = inLatch.resp.bim
takens := VecInit((0 until PredictWidth).map(i => btbResp.hits(i) && (btbResp.types(i) === BTBtype.B && bimResp.ctrs(i)(1) || btbResp.types(i) =/= BTBtype.B)))
notTakens := VecInit((0 until PredictWidth).map(i => btbResp.hits(i) && btbResp.types(i) === BTBtype.B && !bimResp.ctrs(i)(1)))
targetSrc := btbResp.targets
brMask := VecInit(btbResp.types.map(_ === BTBtype.B))
lastIsRVC := btbResp.isRVC(lastValidPos)
lastHit := btbResp.hits(lastValidPos)
targets := btbResp.targets
brMask := VecInit(btbResp.types.map(_ === BTBtype.B))
jalMask := DontCare
firstBankHasHalfRVI := Mux(lastBankHasInst, false.B, btbResp.hits(bankWidth-1) && !btbResp.isRVC(bankWidth-1) && inLatch.mask(bankWidth-1))
lastBankHasHalfRVI := btbResp.hits(PredictWidth-1) && !btbResp.isRVC(PredictWidth-1) && inLatch.mask(PredictWidth-1)
if (BPUDebug) {
XSDebug(io.outFire, "outPred using btb&bim resp: hits:%b, ctrTakens:%b\n",
......@@ -270,22 +244,31 @@ class BPUStage2 extends BPUStage {
io.out.brInfo.map(_.debug_btb_cycle := GTimer())
}
}
@chiselName
class BPUStage3 extends BPUStage {
class S3IO extends XSBundle {
val predecode = Input(new Predecode)
val realMask = Input(UInt(PredictWidth.W))
val prevHalf = Input(new PrevHalfInstr)
val recover = Flipped(ValidIO(new BranchUpdateInfo))
}
val s3IO = IO(new S3IO)
// TAGE has its own pipelines and the
// response comes directly from s3,
// so we do not use those from inLatch
val tageResp = io.in.resp.tage
val tageTakens = tageResp.takens
val tageHits = tageResp.hits
val tageValidTakens = VecInit((tageTakens zip tageHits).map{case (t, h) => t && h})
val loopResp = io.in.resp.loop.exit
val pdMask = io.predecode.mask
val pds = io.predecode.pd
// realMask is in it
val pdMask = s3IO.predecode.mask
val pdLastHalf = s3IO.predecode.lastHalf
val pds = s3IO.predecode.pd
val btbHits = inLatch.resp.btb.hits.asUInt
val btbResp = inLatch.resp.btb
val btbHits = btbResp.hits.asUInt
val bimTakens = VecInit(inLatch.resp.bim.ctrs.map(_(1)))
val brs = pdMask & Reverse(Cat(pds.map(_.isBr)))
......@@ -295,42 +278,41 @@ class BPUStage3 extends BPUStage {
val rets = pdMask & Reverse(Cat(pds.map(_.isRet)))
val RVCs = pdMask & Reverse(Cat(pds.map(_.isRVC)))
val callIdx = PriorityEncoder(calls)
val retIdx = PriorityEncoder(rets)
val callIdx = PriorityEncoder(calls)
val retIdx = PriorityEncoder(rets)
// Use bim results for those who tage does not have an entry for
val brTakens = brs &
(if (EnableBPD) Reverse(Cat((0 until PredictWidth).map(i => tageValidTakens(i) || !tageHits(i) && bimTakens(i)))) else Reverse(Cat((0 until PredictWidth).map(i => bimTakens(i))))) &
(if (EnableLoop) ~loopResp.asUInt else Fill(PredictWidth, 1.U(1.W)))
// if (EnableBPD) {
// brs & Reverse(Cat((0 until PredictWidth).map(i => tageValidTakens(i))))
// } else {
// brs & Reverse(Cat((0 until PredictWidth).map(i => bimTakens(i))))
// }
val brPred = (if(EnableBPD) tageTakens else bimTakens).asUInt
val loopRes = (if (EnableLoop) loopResp else VecInit(Fill(PredictWidth, 0.U(1.W)))).asUInt
val prevHalfTaken = s3IO.prevHalf.valid && s3IO.prevHalf.taken
val prevHalfTakenMask = prevHalfTaken.asUInt
val brTakens = ((brs & brPred | prevHalfTakenMask) & ~loopRes)
// VecInit((0 until PredictWidth).map(i => brs(i) && (brPred(i) || (if (i == 0) prevHalfTaken else false.B)) && !loopRes(i)))
// predict taken only if btb has a target, jal targets will be provided by IFU
takens := VecInit((0 until PredictWidth).map(i => (brTakens(i) || jalrs(i)) && btbHits(i) || jals(i)))
// Whether should we count in branches that are not recorded in btb?
// PS: Currently counted in. Whenever tage does not provide a valid
// taken prediction, the branch is counted as a not taken branch
notTakens := ((VecInit((0 until PredictWidth).map(i => brs(i) && !takens(i)))).asUInt |
(if (EnableLoop) { VecInit((0 until PredictWidth).map(i => brs(i) && loopResp(i)))}
else { WireInit(0.U.asTypeOf(UInt(PredictWidth.W))) }).asUInt).asTypeOf(Vec(PredictWidth, Bool()))
targetSrc := inLatch.resp.btb.targets
brMask := WireInit(brs.asTypeOf(Vec(PredictWidth, Bool())))
targets := inLatch.resp.btb.targets
brMask := WireInit(brs.asTypeOf(Vec(PredictWidth, Bool())))
jalMask := WireInit(jals.asTypeOf(Vec(PredictWidth, Bool())))
lastBankHasInst := s3IO.realMask(PredictWidth-1, bankWidth).orR
firstBankHasHalfRVI := Mux(lastBankHasInst, false.B, pdLastHalf(0))
lastBankHasHalfRVI := pdLastHalf(1)
//RAS
if(EnableRAS){
val ras = Module(new RAS)
ras.io <> DontCare
ras.io.pc.bits := inLatch.pc
ras.io.pc.bits := bankAligned(inLatch.pc)
ras.io.pc.valid := io.outFire//predValid
ras.io.is_ret := rets.orR && (retIdx === jmpIdx) && io.stageValid
ras.io.callIdx.valid := calls.orR && (callIdx === jmpIdx) && io.stageValid
ras.io.is_ret := rets.orR && (retIdx === io.pred.jmpIdx)
ras.io.callIdx.valid := calls.orR && (callIdx === io.pred.jmpIdx)
ras.io.callIdx.bits := callIdx
ras.io.isRVC := (calls & RVCs).orR //TODO: this is ugly
ras.io.isLastHalfRVI := !io.predecode.isFetchpcEqualFirstpc
ras.io.recover := io.recover
ras.io.isLastHalfRVI := s3IO.predecode.hasLastHalfRVI
ras.io.recover := s3IO.recover
for(i <- 0 until PredictWidth){
io.out.brInfo(i).rasSp := ras.io.branchInfo.rasSp
......@@ -340,30 +322,34 @@ class BPUStage3 extends BPUStage {
takens := VecInit((0 until PredictWidth).map(i => {
((brTakens(i) || jalrs(i)) && btbHits(i)) ||
jals(i) ||
(!ras.io.out.bits.specEmpty && rets(i)) ||
(ras.io.out.bits.specEmpty && btbHits(i))
(ras.io.out.valid && rets(i)) ||
(!ras.io.out.valid && rets(i) && btbHits(i))
}
))
when(ras.io.is_ret && ras.io.out.valid){
targetSrc(retIdx) := ras.io.out.bits.target
for (i <- 0 until PredictWidth) {
when(rets(i) && ras.io.out.valid){
targets(i) := ras.io.out.bits.target
}
}
}
lastIsRVC := pds(lastValidPos).isRVC
when (lastValidPos === 1.U) {
lastHit := pdMask(1) |
!pdMask(0) & !pdMask(1) |
pdMask(0) & !pdMask(1) & (pds(0).isRVC | !io.predecode.isFetchpcEqualFirstpc)
}.elsewhen (lastValidPos > 0.U) {
lastHit := pdMask(lastValidPos) |
!pdMask(lastValidPos - 1.U) & !pdMask(lastValidPos) |
pdMask(lastValidPos - 1.U) & !pdMask(lastValidPos) & pds(lastValidPos - 1.U).isRVC
}.otherwise {
lastHit := pdMask(0) | !pdMask(0) & !pds(0).isRVC
}
// we should provide the prediction for the first half RVI of the end of a fetch packet
// branch taken information would be lost in the prediction of the next packet,
// so we preserve this information here
when (firstBankHasHalfRVI && btbResp.types(bankWidth-1) === BTBtype.B && btbHits(bankWidth-1)) {
takens(bankWidth-1) := brPred(bankWidth-1) && !loopRes(bankWidth-1)
}
when (lastBankHasHalfRVI && btbResp.types(PredictWidth-1) === BTBtype.B && btbHits(PredictWidth-1)) {
takens(PredictWidth-1) := brPred(PredictWidth-1) && !loopRes(PredictWidth-1)
}
io.pred.saveHalfRVI := ((lastValidPos === jmpIdx && taken && !(jmpIdx === 0.U && !io.predecode.isFetchpcEqualFirstpc)) || !taken ) && !lastIsRVC && lastHit
// targets would be lost as well, since it is from btb
// unless it is a ret, which target is from ras
when (prevHalfTaken && !rets(0)) {
targets(0) := s3IO.prevHalf.target
}
// Wrap tage resp and tage meta in
// This is ugly
......@@ -375,10 +361,10 @@ class BPUStage3 extends BPUStage {
}
if (BPUDebug) {
XSDebug(io.inFire, "predecode: pc:%x, mask:%b\n", inLatch.pc, io.predecode.mask)
XSDebug(io.inFire, "predecode: pc:%x, mask:%b\n", inLatch.pc, s3IO.predecode.mask)
for (i <- 0 until PredictWidth) {
val p = io.predecode.pd(i)
XSDebug(io.inFire && io.predecode.mask(i), "predecode(%d): brType:%d, br:%d, jal:%d, jalr:%d, call:%d, ret:%d, RVC:%d, excType:%d\n",
val p = s3IO.predecode.pd(i)
XSDebug(io.inFire && s3IO.predecode.mask(i), "predecode(%d): brType:%d, br:%d, jal:%d, jalr:%d, call:%d, ret:%d, RVC:%d, excType:%d\n",
i.U, p.brType, p.isBr, p.isJal, p.isJalr, p.isCall, p.isRet, p.isRVC, p.excType)
}
}
......@@ -435,11 +421,12 @@ abstract class BaseBPU extends XSModule with BranchPredictorComponents with HasB
// from if1
val in = Input(new BPUReq)
val inFire = Input(Vec(4, Bool()))
val stageValid = Input(Vec(3, Bool()))
// to if2/if3/if4
val out = Vec(3, Output(new BranchPrediction))
// from if4
val predecode = Input(new Predecode)
val realMask = Input(UInt(PredictWidth.W))
val prevHalf = Input(new PrevHalfInstr)
// to if4, some bpu info used for updating
val branchInfo = Output(Vec(PredictWidth, new BranchInfo))
})
......@@ -474,24 +461,11 @@ abstract class BaseBPU extends XSModule with BranchPredictorComponents with HasB
s2.io.outFire := s3_fire
s3.io.outFire := s4_fire
s1.io.stageValid := io.stageValid(0)
s2.io.stageValid := io.stageValid(1)
s3.io.stageValid := io.stageValid(2)
io.out(0) <> s1.io.pred
io.out(1) <> s2.io.pred
io.out(2) <> s3.io.pred
s1.io.predecode <> DontCare
s2.io.predecode <> DontCare
s3.io.predecode <> io.predecode
io.branchInfo := s3.io.out.brInfo
s1.io.recover <> DontCare
s2.io.recover <> DontCare
s3.io.recover.valid <> io.inOrderBrInfo.valid
s3.io.recover.bits <> io.inOrderBrInfo.bits.ui
if (BPUDebug) {
XSDebug(io.inFire(3), "branchInfo sent!\n")
......@@ -512,11 +486,11 @@ class FakeBPU extends BaseBPU {
io.out.foreach(i => {
// Provide not takens
i <> DontCare
i.redirect := false.B
i.takens := 0.U
})
io.branchInfo <> DontCare
}
@chiselName
class BPU extends BaseBPU {
//**********************Stage 1****************************//
......@@ -574,10 +548,8 @@ class BPU extends BaseBPU {
s1.io.inFire := s1_fire
s1.io.in.pc := io.in.pc
s1.io.in.mask := io.in.inMask
s1.io.in.target := DontCare
s1.io.in.resp <> s1_resp_in
s1.io.in.brInfo <> s1_brInfo_in
s1.io.in.saveHalfRVI := false.B
val s1_hist = RegEnable(io.in.hist, enable=s1_fire)
val s2_hist = RegEnable(s1_hist, enable=s2_fire)
......@@ -624,6 +596,15 @@ class BPU extends BaseBPU {
s3.io.in.brInfo(i).specCnt := loop.io.meta.specCnts(i)
}
s3.s3IO.predecode <> io.predecode
s3.s3IO.realMask := io.realMask
s3.s3IO.prevHalf := io.prevHalf
s3.s3IO.recover.valid <> io.inOrderBrInfo.valid
s3.s3IO.recover.bits <> io.inOrderBrInfo.bits.ui
if (BPUDebug) {
if (debug_verbose) {
val uo = ubtb.io.out
......
......@@ -6,6 +6,7 @@ import xiangshan._
import xiangshan.backend.ALUOpType
import utils._
import xiangshan.backend.decode.XSTrap
import chisel3.experimental.chiselName
trait BimParams extends HasXSParameter {
val BimBanks = PredictWidth
......@@ -14,7 +15,8 @@ trait BimParams extends HasXSParameter {
val bypassEntries = 4
}
class BIM extends BasePredictor with BimParams{
@chiselName
class BIM extends BasePredictor with BimParams {
class BIMResp extends Resp {
val ctrs = Vec(PredictWidth, UInt(2.W))
}
......@@ -29,10 +31,12 @@ class BIM extends BasePredictor with BimParams{
}
override val io = IO(new BIMIO)
override val debug = true
val bimAddr = new TableAddr(log2Up(BimSize), BimBanks)
val pcLatch = RegEnable(io.pc.bits, io.pc.valid)
val bankAlignedPC = bankAligned(io.pc.bits)
val pcLatch = RegEnable(bankAlignedPC, io.pc.valid)
val bim = List.fill(BimBanks) {
Module(new SRAMTemplate(UInt(2.W), set = nRows, shouldReset = false, holdRead = true))
......@@ -43,34 +47,35 @@ class BIM extends BasePredictor with BimParams{
resetRow := resetRow + doing_reset
when (resetRow === (nRows-1).U) { doing_reset := false.B }
val baseBank = bimAddr.getBank(io.pc.bits)
// this bank means cache bank
val startsAtOddBank = bankInGroup(bankAlignedPC)(0)
val realMask = Mux(startsAtOddBank,
Cat(io.inMask(bankWidth-1,0), io.inMask(PredictWidth-1, bankWidth)),
io.inMask)
val realMask = circularShiftRight(io.inMask, BimBanks, baseBank)
// those banks whose indexes are less than baseBank are in the next row
val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
val isInNextRow = VecInit((0 until BimBanks).map(i => Mux(startsAtOddBank, (i < bankWidth).B, false.B)))
val baseRow = bimAddr.getBankIdx(io.pc.bits)
val baseRow = bimAddr.getBankIdx(bankAlignedPC)
val realRow = VecInit((0 until BimBanks).map(b => Mux(isInNextRow(b.U), (baseRow+1.U)(log2Up(nRows)-1, 0), baseRow)))
val realRow = VecInit((0 until BimBanks).map(b => Mux(isInNextRow(b), (baseRow+1.U)(log2Up(nRows)-1, 0), baseRow)))
val realRowLatch = VecInit(realRow.map(RegEnable(_, enable=io.pc.valid)))
for (b <- 0 until BimBanks) {
bim(b).reset := reset.asBool
bim(b).io.r.req.valid := realMask(b) && io.pc.valid
bim(b).io.r.req.bits.setIdx := realRow(b)
}
val bimRead = VecInit(bim.map(_.io.r.resp.data(0)))
val baseBankLatch = bimAddr.getBank(pcLatch)
val startsAtOddBankLatch = bankInGroup(pcLatch)(0)
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BimBanks).map(b => (baseBankLatch +& b.U)(log2Up(BimBanks)-1, 0)))
for (b <- 0 until BimBanks) {
val ctr = bimRead(bankIdxInOrder(b))
val realBank = (if (b < bankWidth) Mux(startsAtOddBankLatch, (b+bankWidth).U, b.U)
else Mux(startsAtOddBankLatch, (b-bankWidth).U, b.U))
val ctr = bimRead(realBank)
io.resp.ctrs(b) := ctr
io.meta.ctrs(b) := ctr
}
......
......@@ -7,6 +7,8 @@ import xiangshan._
import xiangshan.backend.ALUOpType
import utils._
import xiangshan.backend.decode.XSTrap
import chisel3.experimental.chiselName
import scala.math.min
......@@ -70,7 +72,9 @@ class BTB extends BasePredictor with BTBParams{
override val io = IO(new BTBIO)
val btbAddr = new TableAddr(log2Up(BtbSize/BtbWays), BtbBanks)
val pcLatch = RegEnable(io.pc.bits, io.pc.valid)
val bankAlignedPC = bankAligned(io.pc.bits)
val pcLatch = RegEnable(bankAlignedPC, io.pc.valid)
val data = List.fill(BtbWays) {
List.fill(BtbBanks) {
......@@ -82,47 +86,53 @@ class BTB extends BasePredictor with BTBParams{
Module(new SRAMTemplate(new BtbMetaEntry, set = nRows, shouldReset = true, holdRead = true))
}
}
val edata = Module(new SRAMTemplate(UInt(VAddrBits.W), set = extendedNRows, shouldReset = true, holdRead = true))
val edata = List.fill(2)(Module(new SRAMTemplate(UInt(VAddrBits.W), set = extendedNRows/2, shouldReset = true, holdRead = true)))
// BTB read requests
val baseBank = btbAddr.getBank(io.pc.bits)
val realMask = circularShiftLeft(io.inMask, BtbBanks, baseBank)
// this bank means cache bank
val startsAtOddBank = bankInGroup(bankAlignedPC)(0)
val realMaskLatch = RegEnable(realMask, io.pc.valid)
val baseBank = btbAddr.getBank(bankAlignedPC)
// those banks whose indexes are less than baseBank are in the next row
val isInNextRow = VecInit((0 until BtbBanks).map(_.U < baseBank))
val realMask = Mux(startsAtOddBank,
Cat(io.inMask(bankWidth-1,0), io.inMask(PredictWidth-1, bankWidth)),
io.inMask)
val realMaskLatch = RegEnable(realMask, io.pc.valid)
val baseRow = btbAddr.getBankIdx(io.pc.bits)
val isInNextRow = VecInit((0 until BtbBanks).map(i => Mux(startsAtOddBank, (i < bankWidth).B, false.B)))
val baseRow = btbAddr.getBankIdx(bankAlignedPC)
val nextRowStartsUp = baseRow.andR
val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b.U), (baseRow+1.U)(log2Up(nRows)-1, 0), baseRow)))
val realRow = VecInit((0 until BtbBanks).map(b => Mux(isInNextRow(b), (baseRow+1.U)(log2Up(nRows)-1, 0), baseRow)))
val realRowLatch = VecInit(realRow.map(RegEnable(_, enable=io.pc.valid)))
for (w <- 0 until BtbWays) {
for (b <- 0 until BtbBanks) {
meta(w)(b).reset := reset.asBool
meta(w)(b).io.r.req.valid := realMask(b) && io.pc.valid
meta(w)(b).io.r.req.bits.setIdx := realRow(b)
data(w)(b).reset := reset.asBool
data(w)(b).io.r.req.valid := realMask(b) && io.pc.valid
data(w)(b).io.r.req.bits.setIdx := realRow(b)
}
}
edata.reset := reset.asBool
edata.io.r.req.valid := io.pc.valid
edata.io.r.req.bits.setIdx := realRow(0) // Use the baseRow
for (b <- 0 to 1) {
edata(b).io.r.req.valid := io.pc.valid
val row = if (b == 0) { Mux(startsAtOddBank, realRow(bankWidth), realRow(0)) }
else { Mux(startsAtOddBank, realRow(0), realRow(bankWidth))}
edata(b).io.r.req.bits.setIdx := row
}
// Entries read from SRAM
val metaRead = VecInit((0 until BtbWays).map(w => VecInit((0 until BtbBanks).map( b => meta(w)(b).io.r.resp.data(0)))))
val dataRead = VecInit((0 until BtbWays).map(w => VecInit((0 until BtbBanks).map( b => data(w)(b).io.r.resp.data(0)))))
val edataRead = edata.io.r.resp.data(0)
val edataRead = VecInit((0 to 1).map(i => edata(i).io.r.resp.data(0)))
val baseBankLatch = btbAddr.getBank(pcLatch)
val startsAtOddBankLatch = bankInGroup(pcLatch)(0)
val baseTag = btbAddr.getTag(pcLatch)
val tagIncremented = VecInit((0 until BtbBanks).map(b => RegEnable(isInNextRow(b.U) && nextRowStartsUp, io.pc.valid)))
......@@ -165,20 +175,22 @@ class BTB extends BasePredictor with BTBParams{
b => Mux(bankHits(b), bankHitWays(b), allocWays(b))
))
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch +& b.U)(log2Up(BtbBanks)-1,0)))
for (b <- 0 until BtbBanks) {
val meta_entry = metaRead(bankHitWays(bankIdxInOrder(b)))(bankIdxInOrder(b))
val data_entry = dataRead(bankHitWays(bankIdxInOrder(b)))(bankIdxInOrder(b))
val realBank = (if (b < bankWidth) Mux(startsAtOddBankLatch, (b+bankWidth).U, b.U)
else Mux(startsAtOddBankLatch, (b-bankWidth).U, b.U))
val meta_entry = metaRead(bankHitWays(realBank))(realBank)
val data_entry = dataRead(bankHitWays(realBank))(realBank)
val edataBank = (if (b < bankWidth) Mux(startsAtOddBankLatch, 1.U, 0.U)
else Mux(startsAtOddBankLatch, 0.U, 1.U))
// Use real pc to calculate the target
io.resp.targets(b) := Mux(data_entry.extended, edataRead, (pcLatch.asSInt + (b << 1).S + data_entry.offset).asUInt)
io.resp.hits(b) := bankHits(bankIdxInOrder(b))
io.resp.targets(b) := Mux(data_entry.extended, edataRead(edataBank), (pcLatch.asSInt + (b << 1).S + data_entry.offset).asUInt)
io.resp.hits(b) := bankHits(realBank)
io.resp.types(b) := meta_entry.btbType
io.resp.isRVC(b) := meta_entry.isRVC
io.meta.writeWay(b) := writeWay(bankIdxInOrder(b))
io.meta.hitJal(b) := bankHits(bankIdxInOrder(b)) && meta_entry.btbType === BTBtype.J
io.meta.writeWay(b) := writeWay(realBank)
io.meta.hitJal(b) := bankHits(realBank) && meta_entry.btbType === BTBtype.J
}
def pdInfoToBTBtype(pd: PreDecodeInfo) = {
......@@ -200,13 +212,14 @@ class BTB extends BasePredictor with BTBParams{
val updateWay = u.brInfo.btbWriteWay
val updateBankIdx = btbAddr.getBank(u.pc)
val updateEBank = updateBankIdx(log2Ceil(BtbBanks)-1) // highest bit of bank idx
val updateRow = btbAddr.getBankIdx(u.pc)
val updateType = pdInfoToBTBtype(u.pd)
val metaWrite = BtbMetaEntry(btbAddr.getTag(u.pc), updateType, u.pd.isRVC)
val dataWrite = BtbDataEntry(new_offset, new_extended)
val jalFirstEncountered = !u.isMisPred && !u.brInfo.btbHitJal && updateType === BTBtype.J
val updateValid = io.update.valid && (u.isMisPred || jalFirstEncountered || !u.isMisPred && u.pd.isBr)
val updateValid = io.update.valid && (u.isMisPred || jalFirstEncountered)
// Update btb
for (w <- 0 until BtbWays) {
for (b <- 0 until BtbBanks) {
......@@ -218,10 +231,12 @@ class BTB extends BasePredictor with BTBParams{
data(w)(b).io.w.req.bits.data := dataWrite
}
}
edata.io.w.req.valid := updateValid && new_extended
edata.io.w.req.bits.setIdx := updateRow
edata.io.w.req.bits.data := u.target
for (b <- 0 to 1) {
edata(b).io.w.req.valid := updateValid && new_extended && b.U === updateEBank
edata(b).io.w.req.bits.setIdx := updateRow
edata(b).io.w.req.bits.data := u.target
}
if (BPUDebug && debug) {
......@@ -234,7 +249,7 @@ class BTB extends BasePredictor with BTBParams{
})
val validLatch = RegNext(io.pc.valid)
XSDebug(io.pc.valid, "read: pc=0x%x, baseBank=%d, realMask=%b\n", io.pc.bits, baseBank, realMask)
XSDebug(io.pc.valid, "read: pc=0x%x, baseBank=%d, realMask=%b\n", bankAlignedPC, baseBank, realMask)
XSDebug(validLatch, "read_resp: pc=0x%x, readIdx=%d-------------------------------\n",
pcLatch, btbAddr.getIdx(pcLatch))
if (debug_verbose) {
......@@ -245,6 +260,9 @@ class BTB extends BasePredictor with BTBParams{
}
}
}
// e.g: baseBank == 5 => (5, 6,..., 15, 0, 1, 2, 3, 4)
val bankIdxInOrder = VecInit((0 until BtbBanks).map(b => (baseBankLatch +& b.U)(log2Up(BtbBanks)-1,0)))
for (i <- 0 until BtbBanks) {
val idx = bankIdxInOrder(i)
XSDebug(validLatch && bankHits(bankIdxInOrder(i)), "resp(%d): bank(%d) hits, tgt=%x, isRVC=%d, type=%d\n",
......
......@@ -29,7 +29,7 @@ class FakeLoopBuffer extends XSModule {
io.loopBufPar.LBredirect.valid := false.B
}
class LoopBuffer extends XSModule {
class LoopBuffer extends XSModule with HasIFUConst{
val io = IO(new LoopBufferIO)
// FSM state define
......@@ -118,9 +118,10 @@ class LoopBuffer extends XSModule {
// Provide ICacheResp to IFU
when(LBstate === s_active) {
val offsetInBankWire = offsetInBank(io.loopBufPar.fetchReq)
io.out.bits.pc := io.loopBufPar.fetchReq
io.out.bits.data := Cat((31 to 0 by -1).map(i => buffer(io.loopBufPar.fetchReq(7,1) + i.U).inst))
io.out.bits.mask := Cat((31 to 0 by -1).map(i => bufferValid(io.loopBufPar.fetchReq(7,1) + i.U)))
io.out.bits.data := Cat((15 to 0 by -1).map(i => buffer(io.loopBufPar.fetchReq(7,1) + i.U).inst)) >> Cat(offsetInBankWire, 0.U(4.W))
io.out.bits.mask := Cat((15 to 0 by -1).map(i => bufferValid(io.loopBufPar.fetchReq(7,1) + i.U))) >> offsetInBankWire
io.out.bits.ipf := false.B
}
......
......@@ -5,6 +5,7 @@ import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.brq.BrqPtr
import chisel3.experimental.chiselName
trait LTBParams extends HasXSParameter with HasBPUParameter {
// +-----------+---------+--------------+-----------+
......@@ -64,6 +65,7 @@ class LTBColumnUpdate extends LTBBundle {
}
// each column/bank of Loop Termination Buffer
@chiselName
class LTBColumn extends LTBModule {
val io = IO(new Bundle() {
// if3 send req
......@@ -251,6 +253,7 @@ class LTBColumn extends LTBModule {
}
@chiselName
class LoopPredictor extends BasePredictor with LTBParams {
class LoopResp extends Resp {
val exit = Vec(PredictWidth, Bool())
......
......@@ -2,7 +2,7 @@ package xiangshan.frontend
import chisel3._
import chisel3.util._
import utils.XSDebug
import utils._
import xiangshan._
import xiangshan.backend.decode.isa.predecode.PreDecodeInst
import xiangshan.cache._
......@@ -45,14 +45,16 @@ class PreDecodeInfo extends XSBundle { // 8 bit
def notCFI = brType === BrType.notBr
}
class PreDecodeResp extends XSBundle {
class PreDecodeResp extends XSBundle with HasIFUConst {
val instrs = Vec(PredictWidth, UInt(32.W))
val pc = Vec(PredictWidth, UInt(VAddrBits.W))
val mask = UInt(PredictWidth.W)
// one for the first bank
val lastHalf = UInt(nBanksInPacket.W)
val pd = Vec(PredictWidth, (new PreDecodeInfo))
}
class PreDecode extends XSModule with HasPdconst{
class PreDecode extends XSModule with HasPdconst with HasIFUConst {
val io = IO(new Bundle() {
val in = Input(new ICacheResp)
val prev = Flipped(ValidIO(UInt(16.W)))
......@@ -61,38 +63,53 @@ class PreDecode extends XSModule with HasPdconst{
val data = io.in.data
val mask = io.in.mask
val validCount = PopCount(mask)
val bankAlignedPC = bankAligned(io.in.pc)
val bankOffset = offsetInBank(io.in.pc)
val isAligned = bankOffset === 0.U
val firstValidIdx = bankOffset // io.prev.valid should only occur with firstValidIdx = 0
XSError(firstValidIdx =/= 0.U && io.prev.valid, p"pc:${io.in.pc}, mask:${io.in.mask}, prevhalfInst valid occurs on unaligned fetch packet\n")
// val lastHalfInstrIdx = Mux(isInLastBank(pc), (bankWidth-1).U, (bankWidth*2-1).U)
// in case loop buffer gives a packet ending at an unaligned position
val lastHalfInstrIdx = PriorityMux(Reverse(mask), (PredictWidth-1 to 0 by -1).map(i => i.U))
val insts = Wire(Vec(PredictWidth, UInt(32.W)))
val instsMask = Wire(Vec(PredictWidth, Bool()))
val instsEndMask = Wire(Vec(PredictWidth, Bool()))
val instsRVC = Wire(Vec(PredictWidth,Bool()))
val instsPC = Wire(Vec(PredictWidth, UInt(VAddrBits.W)))
val rawInsts = VecInit((0 until PredictWidth).map(i => if (i == PredictWidth-1) Cat(0.U(16.W), data(i*16+15, i*16))
else data(i*16+31, i*16)))
// val nextHalf = Wire(UInt(16.W))
val lastHalfInstrIdx = PopCount(mask) - 1.U
val lastHalf = Wire(Vec(nBanksInPacket, UInt(1.W)))
for (i <- 0 until PredictWidth) {
val inst = Wire(UInt(32.W))
val valid = Wire(Bool())
val pc = io.in.pc + (i << 1).U - Mux(io.prev.valid && (i.U === 0.U), 2.U, 0.U)
if (i==0) {
inst := Mux(io.prev.valid, Cat(data(15,0), io.prev.bits), data(31,0))
// valid := Mux(lastHalfInstrIdx === 0.U, isRVC(inst), true.B)
valid := Mux(lastHalfInstrIdx === 0.U, Mux(!io.prev.valid, isRVC(inst), true.B), true.B)
} else if (i==1) {
inst := data(47,16)
valid := (io.prev.valid || !(instsMask(0) && !isRVC(insts(0)))) && Mux(lastHalfInstrIdx === 1.U, isRVC(inst), true.B)
} else if (i==PredictWidth-1) {
inst := Cat(0.U(16.W), data(i*16+15, i*16))
valid := !(instsMask(i-1) && !isRVC(insts(i-1)) || !isRVC(inst))
} else {
inst := data(i*16+31, i*16)
valid := !(instsMask(i-1) && !isRVC(insts(i-1))) && Mux(i.U === lastHalfInstrIdx, isRVC(inst), true.B)
}
val inst = WireInit(rawInsts(i))
val validStart = Wire(Bool()) // is the beginning of a valid inst
val validEnd = Wire(Bool()) // is the end of a valid inst
val pc = bankAlignedPC + (i << 1).U - Mux(io.prev.valid && (i.U === firstValidIdx), 2.U, 0.U)
val isFirstInPacket = i.U === firstValidIdx
val isLastInPacket = i.U === lastHalfInstrIdx
val currentRVC = isRVC(insts(i))
val lastIsValidEnd = if (i == 0) { !io.prev.valid } else { instsEndMask(i-1) }
inst := Mux(io.prev.valid && i.U === 0.U, Cat(rawInsts(i)(15,0), io.prev.bits), rawInsts(i))
validStart := lastIsValidEnd && !(isLastInPacket && !currentRVC)
validEnd := validStart && currentRVC || !validStart && !(isLastInPacket && !currentRVC)
val currentLastHalf = lastIsValidEnd && (isLastInPacket && !currentRVC)
insts(i) := inst
instsRVC(i) := isRVC(inst)
instsMask(i) := mask(i) && valid
instsMask(i) := (if (i == 0) Mux(io.prev.valid, validEnd, validStart) else validStart)
instsEndMask(i) := validEnd
instsPC(i) := pc
val brType::isCall::isRet::Nil = brInfo(inst)
......@@ -103,14 +120,18 @@ class PreDecode extends XSModule with HasPdconst{
io.out.pd(i).excType := ExcType.notExc
io.out.instrs(i) := insts(i)
io.out.pc(i) := instsPC(i)
if (i == bankWidth-1) { lastHalf(0) := currentLastHalf }
if (i == PredictWidth-1) { lastHalf(1) := currentLastHalf }
}
io.out.mask := instsMask.asUInt
io.out.mask := instsMask.asUInt & mask
io.out.lastHalf := lastHalf.asUInt
for (i <- 0 until PredictWidth) {
XSDebug(true.B,
p"instr ${Hexadecimal(io.out.instrs(i))}, " +
p"mask ${Binary(instsMask(i))}, " +
p"endMask ${Binary(instsEndMask(i))}, " +
p"pc ${Hexadecimal(io.out.pc(i))}, " +
p"isRVC ${Binary(io.out.pd(i).isRVC)}, " +
p"brType ${Binary(io.out.pd(i).brType)}, " +
......
......@@ -5,13 +5,14 @@ import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._
import chisel3.experimental.chiselName
@chiselName
class RAS extends BasePredictor
{
class RASResp extends Resp
{
val target =UInt(VAddrBits.W)
val specEmpty = Bool()
}
class RASBranchInfo extends Meta
......@@ -50,6 +51,7 @@ class RAS extends BasePredictor
override val io = IO(new RASIO)
@chiselName
class RASStack(val rasSize: Int) extends XSModule {
val io = IO(new Bundle {
val push_valid = Input(Bool())
......@@ -64,7 +66,7 @@ class RAS extends BasePredictor
val copy_out_mem = Output(Vec(rasSize, rasEntry()))
val copy_out_sp = Output(UInt(log2Up(rasSize).W))
})
@chiselName
class Stack(val size: Int) extends XSModule {
val io = IO(new Bundle {
val rIdx = Input(UInt(log2Up(size).W))
......@@ -140,7 +142,7 @@ class RAS extends BasePredictor
val spec_push = WireInit(false.B)
val spec_pop = WireInit(false.B)
val spec_new_addr = WireInit(io.pc.bits + (io.callIdx.bits << 1.U) + Mux(io.isRVC,2.U,Mux(io.isLastHalfRVI, 2.U, 4.U)))
val spec_new_addr = WireInit(bankAligned(io.pc.bits) + (io.callIdx.bits << 1.U) + Mux(io.isRVC,2.U,Mux(io.isLastHalfRVI, 2.U, 4.U)))
spec_ras.push_valid := spec_push
spec_ras.pop_valid := spec_pop
spec_ras.new_addr := spec_new_addr
......@@ -167,9 +169,8 @@ class RAS extends BasePredictor
commit_pop := !commit_is_empty && io.recover.valid && io.recover.bits.pd.isRet
io.out.valid := !spec_is_empty && io.is_ret
io.out.valid := !spec_is_empty
io.out.bits.target := spec_top_addr
io.out.bits.specEmpty := spec_is_empty
// TODO: back-up stack for ras
// use checkpoint to recover RAS
......
......@@ -4,6 +4,7 @@ import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import chisel3.experimental.chiselName
import scala.math.min
......@@ -38,6 +39,7 @@ class FakeSCTable extends BaseSCTable {
io.resp := 0.U.asTypeOf(Vec(TageBanks, new SCResp))
}
@chiselName
class SCTable(val nRows: Int, val ctrBits: Int, val histLen: Int) extends BaseSCTable(nRows, ctrBits, histLen) {
val table = List.fill(TageBanks) {
......
......@@ -4,6 +4,7 @@ import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import chisel3.experimental.chiselName
import scala.math.min
......@@ -38,7 +39,7 @@ trait HasTageParameter extends HasXSParameter with HasBPUParameter{
}
abstract class TageBundle extends XSBundle with HasTageParameter with PredictorUtils
abstract class TageModule extends XSModule with HasTageParameter with PredictorUtils { val debug = false }
abstract class TageModule extends XSModule with HasTageParameter with PredictorUtils { val debug = true }
......@@ -77,8 +78,8 @@ class FakeTageTable() extends TageModule {
io.resp := DontCare
}
class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule {
@chiselName
class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule with HasIFUConst {
val io = IO(new Bundle() {
val req = Input(Valid(new TageReq))
val resp = Output(Vec(TageBanks, Valid(new TageResp)))
......@@ -86,7 +87,7 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
})
// override val debug = true
// bypass entries for tage update
val wrBypassEntries = 8
val wrBypassEntries = 4
def compute_folded_hist(hist: UInt, l: Int) = {
val nChunks = (histLen + l - 1) / l
......@@ -120,17 +121,29 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val tageEntrySz = 1 + tagLen + TageCtrBits
val bankAlignedPC = bankAligned(io.req.bits.pc)
// this bank means cache bank
val startsAtOddBank = bankInGroup(bankAlignedPC)(0)
// use real address to index
// val unhashed_idxes = VecInit((0 until TageBanks).map(b => ((io.req.bits.pc >> 1.U) + b.U) >> log2Up(TageBanks).U))
val unhashed_idx = io.req.bits.pc >> 1.U
val unhashed_idx = Wire(Vec(2, UInt((log2Ceil(nRows)+tagLen).W)))
// the first bank idx always correspond with pc
unhashed_idx(0) := io.req.bits.pc >> (1+log2Ceil(TageBanks))
// when pc is at odd bank, the second bank is at the next idx
unhashed_idx(1) := unhashed_idx(0) + startsAtOddBank
// val idxes_and_tags = (0 until TageBanks).map(b => compute_tag_and_hash(unhashed_idxes(b.U), io.req.bits.hist))
val (idx, tag) = compute_tag_and_hash(unhashed_idx, io.req.bits.hist)
// val (idx, tag) = compute_tag_and_hash(unhashed_idx, io.req.bits.hist)
val idxes_and_tags = unhashed_idx.map(compute_tag_and_hash(_, io.req.bits.hist))
// val idxes = VecInit(idxes_and_tags.map(_._1))
// val tags = VecInit(idxes_and_tags.map(_._2))
val idxLatch = RegEnable(idx, enable=io.req.valid)
val tagLatch = RegEnable(tag, enable=io.req.valid)
val idxes_latch = RegEnable(VecInit(idxes_and_tags.map(_._1)), io.req.valid)
val tags_latch = RegEnable(VecInit(idxes_and_tags.map(_._2)), io.req.valid)
// and_tags_latch = RegEnable(idxes_and_tags, enable=io.req.valid)
// val idxLatch = RegEnable(idx, enable=io.req.valid)
// val tagLatch = RegEnable(tag, enable=io.req.valid)
class HL_Bank (val nRows: Int = nRows) extends TageModule {
val io = IO(new Bundle {
......@@ -171,13 +184,18 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val bankIdxInOrder = VecInit((0 until TageBanks).map(b => (baseBankLatch +& b.U)(log2Up(TageBanks)-1, 0)))
val realMask = circularShiftLeft(io.req.bits.mask, TageBanks, baseBank)
val maskLatch = RegEnable(io.req.bits.mask, enable=io.req.valid)
val realMask = Mux(startsAtOddBank,
Cat(io.req.bits.mask(bankWidth-1,0), io.req.bits.mask(PredictWidth-1, bankWidth)),
io.req.bits.mask)
val maskLatch = RegEnable(realMask, enable=io.req.valid)
(0 until TageBanks).map(
b => {
val idxes = VecInit(idxes_and_tags.map(_._1))
val idx = (if (b < bankWidth) Mux(startsAtOddBank, idxes(1), idxes(0))
else Mux(startsAtOddBank, idxes(0), idxes(1)))
hi_us(b).io.r.req.valid := io.req.valid && realMask(b)
hi_us(b).io.r.req.bits.setIdx := idx
......@@ -194,12 +212,22 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
}
)
val req_rhits = VecInit((0 until TageBanks).map(b => table_r(bankIdxInOrder(b)).valid && table_r(bankIdxInOrder(b)).tag === tagLatch))
val startsAtOddBankLatch = RegEnable(startsAtOddBank, io.req.valid)
val req_rhits = VecInit((0 until TageBanks).map(b => {
val tag = (if (b < bankWidth) Mux(startsAtOddBank, tags_latch(1), tags_latch(0))
else Mux(startsAtOddBank, tags_latch(0), tags_latch(1)))
val bank = (if (b < bankWidth) Mux(startsAtOddBankLatch, (b+bankWidth).U, b.U)
else Mux(startsAtOddBankLatch, (b-bankWidth).U, b.U))
table_r(bank).valid && table_r(bank).tag === tag
}))
(0 until TageBanks).map(b => {
val bank = (if (b < bankWidth) Mux(startsAtOddBankLatch, (b+bankWidth).U, b.U)
else Mux(startsAtOddBankLatch, (b-bankWidth).U, b.U))
io.resp(b).valid := req_rhits(b) && maskLatch(b)
io.resp(b).bits.ctr := table_r(bankIdxInOrder(b)).ctr
io.resp(b).bits.u := Cat(hi_us_r(bankIdxInOrder(b)),lo_us_r(bankIdxInOrder(b)))
io.resp(b).bits.ctr := table_r(bank).ctr
io.resp(b).bits.u := Cat(hi_us_r(bank),lo_us_r(bank))
})
......@@ -212,7 +240,7 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod)
// Use fetchpc to compute hash
val (update_idx, update_tag) = compute_tag_and_hash((io.update.pc >> 1.U) - io.update.fetchIdx, io.update.hist)
val (update_idx, update_tag) = compute_tag_and_hash((io.update.pc >> (1 + log2Ceil(TageBanks))), io.update.hist)
val update_wdata = Wire(Vec(TageBanks, new TageEntry))
......@@ -251,28 +279,23 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
wrbypass_idxs(i) === update_idx
})
val wrbypass_rhits = VecInit((0 until wrBypassEntries) map { i =>
io.req.valid &&
wrbypass_tags(i) === tag &&
wrbypass_idxs(i) === idx
})
val wrbypass_hit = wrbypass_hits.reduce(_||_)
val wrbypass_rhit = wrbypass_rhits.reduce(_||_)
// val wrbypass_rhit = wrbypass_rhits.reduce(_||_)
val wrbypass_hit_idx = PriorityEncoder(wrbypass_hits)
val wrbypass_rhit_idx = PriorityEncoder(wrbypass_rhits)
// val wrbypass_rhit_idx = PriorityEncoder(wrbypass_rhits)
val wrbypass_rctr_hits = VecInit((0 until TageBanks).map( b => wrbypass_ctr_valids(wrbypass_rhit_idx)(b)))
// val wrbypass_rctr_hits = VecInit((0 until TageBanks).map( b => wrbypass_ctr_valids(wrbypass_rhit_idx)(b)))
val rhit_ctrs = RegEnable(wrbypass_ctrs(wrbypass_rhit_idx), wrbypass_rhit)
// val rhit_ctrs = RegEnable(wrbypass_ctrs(wrbypass_rhit_idx), wrbypass_rhit)
when (RegNext(wrbypass_rhit)) {
for (b <- 0 until TageBanks) {
when (RegNext(wrbypass_rctr_hits(b.U + baseBank))) {
io.resp(b).bits.ctr := rhit_ctrs(bankIdxInOrder(b))
}
}
}
// when (RegNext(wrbypass_rhit)) {
// for (b <- 0 until TageBanks) {
// when (RegNext(wrbypass_rctr_hits(b.U + baseBank))) {
// io.resp(b).bits.ctr := rhit_ctrs(bankIdxInOrder(b))
// }
// }
// }
val updateBank = PriorityEncoder(io.update.mask)
......@@ -312,10 +335,13 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
val u = io.update
val b = PriorityEncoder(u.mask)
val ub = PriorityEncoder(u.uMask)
XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%x, idx=%d, tag=%x, baseBank=%d, mask=%b, realMask=%b\n",
io.req.bits.pc, io.req.bits.hist, idx, tag, baseBank, io.req.bits.mask, realMask)
val idx = idxes_and_tags.map(_._1)
val tag = idxes_and_tags.map(_._2)
XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%x, idx=(%d,%d), tag=(%x,%x), baseBank=%d, mask=%b, realMask=%b\n",
io.req.bits.pc, io.req.bits.hist, idx(0), idx(1), tag(0), tag(1), baseBank, io.req.bits.mask, realMask)
for (i <- 0 until TageBanks) {
XSDebug(RegNext(io.req.valid) && req_rhits(i), "TageTableResp[%d]: idx=%d, hit:%d, ctr:%d, u:%d\n", i.U, idxLatch, req_rhits(i), io.resp(i).bits.ctr, io.resp(i).bits.u)
XSDebug(RegNext(io.req.valid) && req_rhits(i), "TageTableResp[%d]: idx=(%d,%d), hit:%d, ctr:%d, u:%d\n",
i.U, idxes_latch(0), idxes_latch(1), req_rhits(i), io.resp(i).bits.ctr, io.resp(i).bits.u)
}
XSDebug(RegNext(io.req.valid), "TageTableResp: hits:%b, maskLatch is %b\n", req_rhits.asUInt, maskLatch)
......@@ -333,13 +359,13 @@ class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPerio
"wrbypass hits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
wrbypass_hit_idx, update_tag, update_idx, wrbypass_ctrs(wrbypass_hit_idx)(updateBank), updateBank)
when (wrbypass_rhit && wrbypass_ctr_valids(wrbypass_rhit_idx).reduce(_||_)) {
for (b <- 0 until TageBanks) {
XSDebug(wrbypass_ctr_valids(wrbypass_rhit_idx)(b),
"wrbypass rhits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
wrbypass_rhit_idx, tag, idx, wrbypass_ctrs(wrbypass_rhit_idx)(b), b.U)
}
}
// when (wrbypass_rhit && wrbypass_ctr_valids(wrbypass_rhit_idx).reduce(_||_)) {
// for (b <- 0 until TageBanks) {
// XSDebug(wrbypass_ctr_valids(wrbypass_rhit_idx)(b),
// "wrbypass rhits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
// wrbypass_rhit_idx, tag, idx, wrbypass_ctrs(wrbypass_rhit_idx)(b), b.U)
// }
// }
// ------------------------------Debug-------------------------------------
val valids = Reg(Vec(TageBanks, Vec(nRows, Bool())))
......@@ -376,7 +402,7 @@ class FakeTage extends BaseTage {
io.meta <> DontCare
}
@chiselName
class Tage extends BaseTage {
val tables = TableInfo.map {
......@@ -406,7 +432,7 @@ class Tage extends BaseTage {
val useThreshold = WireInit(scThreshold.thres)
val updateThreshold = WireInit((useThreshold << 3) + 21.U)
// override val debug = true
override val debug = true
// Keep the table responses to process in s3
val resps = VecInit(tables.map(t => RegEnable(t.io.resp, enable=io.s3Fire)))
......
......@@ -4,6 +4,7 @@ import chisel3._
import chisel3.util._
import utils._
import xiangshan._
import chisel3.experimental.chiselName
import scala.math.min
......@@ -15,11 +16,12 @@ trait MicroBTBPatameter{
val extended_stat = false
}
@chiselName
class MicroBTB extends BasePredictor
with MicroBTBPatameter
{
// val tagSize = VAddrBits - log2Ceil(PredictWidth) - 1
val untaggedBits = PredictWidth + 1
val untaggedBits = log2Up(PredictWidth) + 1
class MicroBTBResp extends Resp
{
......@@ -98,6 +100,7 @@ class MicroBTB extends BasePredictor
val pred = UInt(2.W)
}
@chiselName
class UBTBMetaBank(nWays: Int) extends XSModule {
val io = IO(new Bundle {
val wen = Input(Bool())
......@@ -106,6 +109,7 @@ class MicroBTB extends BasePredictor
val rtag = Input(UInt(tagSize.W))
val rdata = Output(new MetaOutput)
val hit_ohs = Output(Vec(nWays, Bool()))
val hit_way = Output(UInt(log2Up(nWays).W))
val allocatable_way = Valid(UInt(log2Up(nWays).W))
val rWay = Input(UInt(log2Up(nWays).W))
val rpred = Output(UInt(2.W))
......@@ -116,6 +120,7 @@ class MicroBTB extends BasePredictor
val hit_way = PriorityEncoder(hit_ohs)
val hit_entry = rentries(hit_way)
io.hit_ohs := hit_ohs
io.hit_way := hit_way
io.rdata.is_Br := hit_entry.is_Br
io.rdata.is_RVC := hit_entry.is_RVC
io.rdata.pred := hit_entry.pred
......@@ -129,6 +134,7 @@ class MicroBTB extends BasePredictor
}
}
@chiselName
class UBTBDataBank(nWays: Int) extends XSModule {
val io = IO(new Bundle {
val wen = Input(Bool())
......@@ -160,9 +166,14 @@ class MicroBTB extends BasePredictor
//uBTB read
//tag is bank align
val bankAlignedPC = bankAligned(io.pc.bits)
val startsAtOddBank = bankInGroup(bankAlignedPC)(0).asBool
val read_valid = io.pc.valid
val read_req_tag = getTag(io.pc.bits)
val read_req_basebank = getBank(io.pc.bits)
val read_req_tag = getTag(bankAlignedPC)
val next_tag = read_req_tag + 1.U
// val read_mask = circularShiftLeft(io.inMask, PredictWidth, read_req_basebank)
......@@ -175,22 +186,21 @@ class MicroBTB extends BasePredictor
val is_Br = Bool()
}
val read_resp = Wire(Vec(PredictWidth,new ReadRespEntry))
val read_bank_inOrder = VecInit((0 until PredictWidth).map(b => (read_req_basebank + b.U)(log2Up(PredictWidth)-1,0) ))
//val read_bank_inOrder = VecInit((0 until PredictWidth).map(b => (read_req_basebank + b.U)(log2Up(PredictWidth)-1,0) ))
// val isInNextRow = VecInit((0 until PredictWidth).map(_.U < read_req_basebank))
(0 until PredictWidth).map{ b => metas(b).rtag := read_req_tag }
val read_hit_ohs = read_bank_inOrder.map{ b => metas(b).hit_ohs }
(0 until PredictWidth).map{ b => metas(b).rtag := Mux(startsAtOddBank && (b > PredictWidth).B,next_tag,read_req_tag) }
val read_hit_ohs = (0 until PredictWidth).map{ b => metas(b).hit_ohs }
val read_hit_vec = VecInit(read_hit_ohs.map{oh => ParallelOR(oh).asBool})
val read_hit_ways = VecInit(read_hit_ohs.map{oh => PriorityEncoder(oh)})
val read_hit_ways = (0 until PredictWidth).map{ b => metas(b).hit_way }
// val read_hit = ParallelOR(read_hit_vec).asBool
// val read_hit_way = PriorityEncoder(ParallelOR(read_hit_ohs.map(_.asUInt)))
(0 until PredictWidth).map(b => datas(b).rWay := read_hit_ways((b.U + PredictWidth.U - read_req_basebank)(log2Up(PredictWidth)-1, 0)))
(0 until PredictWidth).map(b => datas(b).rWay := read_hit_ways(b))
val uBTBMeta_resp = VecInit((0 until PredictWidth).map(b => metas(read_bank_inOrder(b)).rdata))
val btb_resp = VecInit((0 until PredictWidth).map(b => datas(read_bank_inOrder(b)).rdata))
val uBTBMeta_resp = VecInit((0 until PredictWidth).map(b => metas(b).rdata))
val btb_resp = VecInit((0 until PredictWidth).map(b => datas(b).rdata))
for(i <- 0 until PredictWidth){
// do not need to decide whether to produce results\
......@@ -224,7 +234,7 @@ class MicroBTB extends BasePredictor
// }
val alloc_ways = read_bank_inOrder.map{ b =>
val alloc_ways = (0 until PredictWidth).map{ b =>
Mux(metas(b).allocatable_way.valid, metas(b).allocatable_way.bits, LFSR64()(log2Ceil(nWays)-1,0))}
(0 until PredictWidth).map(i => out_ubtb_br_info.writeWay(i) := Mux(read_hit_vec(i).asBool,read_hit_ways(i),alloc_ways(i)))
......@@ -259,8 +269,8 @@ class MicroBTB extends BasePredictor
val jalFirstEncountered = !u.isMisPred && !u.brInfo.btbHitJal && (u.pd.brType === BrType.jal)
val entry_write_valid = io.update.valid && (u.isMisPred || !u.isMisPred && u.pd.isBr || jalFirstEncountered)//io.update.valid //&& update_is_BR_or_JAL
val meta_write_valid = io.update.valid && (u.isMisPred || !u.isMisPred && u.pd.isBr || jalFirstEncountered)//io.update.valid //&& update_is_BR_or_JAL
val entry_write_valid = io.update.valid && (u.isMisPred || jalFirstEncountered)//io.update.valid //&& update_is_BR_or_JAL
val meta_write_valid = io.update.valid && (u.isMisPred || jalFirstEncountered)//io.update.valid //&& update_is_BR_or_JAL
//write btb target when miss prediction
// when(entry_write_valid)
// {
......@@ -293,7 +303,7 @@ class MicroBTB extends BasePredictor
}
if (BPUDebug && debug) {
XSDebug(read_valid,"uBTB read req: pc:0x%x, tag:%x basebank:%d\n",io.pc.bits,read_req_tag,read_req_basebank)
XSDebug(read_valid,"uBTB read req: pc:0x%x, tag:%x startAtOdd:%d\n",io.pc.bits,read_req_tag,startsAtOddBank)
XSDebug(read_valid,"uBTB read resp: read_hit_vec:%b, \n",read_hit_vec.asUInt)
for(i <- 0 until PredictWidth) {
XSDebug(read_valid,"bank(%d) hit:%d way:%d valid:%d is_RVC:%d taken:%d isBr:%d target:0x%x alloc_way:%d\n",
......
......@@ -28,6 +28,11 @@ class LsqEntry extends XSBundle {
val fwdData = Vec(8, UInt(8.W))
}
class FwdEntry extends XSBundle {
val mask = Vec(8, Bool())
val data = Vec(8, UInt(8.W))
}
class LSQueueData(size: Int, nchannel: Int) extends XSModule with HasDCacheParameters with HasCircularQueuePtrHelper {
val io = IO(new Bundle() {
......@@ -124,6 +129,8 @@ class LSQueueData(size: Int, nchannel: Int) extends XSModule with HasDCacheParam
// i.e. forward1 is the target entries with the same flag bits and forward2 otherwise
// entry with larger index should have higher priority since it's data is younger
// FIXME: old fwd logic for assertion, remove when rtl freeze
(0 until nchannel).map(i => {
val forwardMask1 = WireInit(VecInit(Seq.fill(8)(false.B)))
......@@ -152,10 +159,63 @@ class LSQueueData(size: Int, nchannel: Int) extends XSModule with HasDCacheParam
// merge forward lookup results
// forward2 is younger than forward1 and should have higher priority
val oldFwdResult = Wire(new FwdEntry)
(0 until XLEN / 8).map(k => {
io.forward(i).forwardMask(k) := forwardMask1(k) || forwardMask2(k)
io.forward(i).forwardData(k) := Mux(forwardMask2(k), forwardData2(k), forwardData1(k))
oldFwdResult.mask(k) := RegNext(forwardMask1(k) || forwardMask2(k))
oldFwdResult.data(k) := RegNext(Mux(forwardMask2(k), forwardData2(k), forwardData1(k)))
})
// parallel fwd logic
val paddrMatch = Wire(Vec(size, Bool()))
val matchResultVec = Wire(Vec(size * 2, new FwdEntry))
def parallelFwd(xs: Seq[Data]): Data = {
ParallelOperation(xs, (a: Data, b: Data) => {
val l = a.asTypeOf(new FwdEntry)
val r = b.asTypeOf(new FwdEntry)
val res = Wire(new FwdEntry)
(0 until 8).map(p => {
res.mask(p) := l.mask(p) || r.mask(p)
res.data(p) := Mux(r.mask(p), r.data(p), l.data(p))
})
res
})
}
for (j <- 0 until size) {
paddrMatch(j) := io.forward(i).paddr(PAddrBits - 1, 3) === data(j).paddr(PAddrBits - 1, 3)
}
for (j <- 0 until size) {
val needCheck0 = RegNext(paddrMatch(j) && io.needForward(i)(0)(j))
val needCheck1 = RegNext(paddrMatch(j) && io.needForward(i)(1)(j))
(0 until XLEN / 8).foreach(k => {
matchResultVec(j).mask(k) := needCheck0 && data(j).mask(k)
matchResultVec(j).data(k) := data(j).data(8 * (k + 1) - 1, 8 * k)
matchResultVec(size + j).mask(k) := needCheck1 && data(j).mask(k)
matchResultVec(size + j).data(k) := data(j).data(8 * (k + 1) - 1, 8 * k)
})
}
val parallelFwdResult = parallelFwd(matchResultVec).asTypeOf(new FwdEntry)
io.forward(i).forwardMask := parallelFwdResult.mask
io.forward(i).forwardData := parallelFwdResult.data
when(
oldFwdResult.mask.asUInt =/= parallelFwdResult.mask.asUInt
){
printf("%d: mask error: right: %b false %b\n", GTimer(), oldFwdResult.mask.asUInt, parallelFwdResult.mask.asUInt)
}
for (p <- 0 until 8) {
when(
oldFwdResult.data(p) =/= parallelFwdResult.data(p) && oldFwdResult.mask(p)
){
printf("%d: data "+p+" error: right: %x false %x\n", GTimer(), oldFwdResult.data(p), parallelFwdResult.data(p))
}
}
})
// data read
......@@ -189,7 +249,6 @@ class LsqWrappper extends XSModule with HasDCacheParameters {
val dcache = new DCacheLineIO
val uncache = new DCacheWordIO
val roqDeqPtr = Input(new RoqPtr)
val oldestStore = Output(Valid(new RoqPtr))
val exceptionAddr = new ExceptionAddrIO
})
......@@ -232,7 +291,6 @@ class LsqWrappper extends XSModule with HasDCacheParameters {
storeQueue.io.mmioStout <> io.mmioStout
storeQueue.io.commits <> io.commits
storeQueue.io.roqDeqPtr <> io.roqDeqPtr
storeQueue.io.oldestStore <> io.oldestStore
storeQueue.io.exceptionAddr.lsIdx := io.exceptionAddr.lsIdx
storeQueue.io.exceptionAddr.isStore := DontCare
......
......@@ -51,7 +51,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
val dataModule = Module(new LSQueueData(LoadQueueSize, LoadPipelineWidth))
dataModule.io := DontCare
val allocated = RegInit(VecInit(List.fill(LoadQueueSize)(false.B))) // lq entry has been allocated
val valid = RegInit(VecInit(List.fill(LoadQueueSize)(false.B))) // data is valid
val datavalid = RegInit(VecInit(List.fill(LoadQueueSize)(false.B))) // data is valid
val writebacked = RegInit(VecInit(List.fill(LoadQueueSize)(false.B))) // inst has been writebacked to CDB
val commited = Reg(Vec(LoadQueueSize, Bool())) // inst has been writebacked to CDB
val miss = Reg(Vec(LoadQueueSize, Bool())) // load inst missed, waiting for miss queue to accept miss request
......@@ -87,7 +87,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
when(io.enq.req(i).valid) {
uop(index) := io.enq.req(i).bits
allocated(index) := true.B
valid(index) := false.B
datavalid(index) := false.B
writebacked(index) := false.B
commited(index) := false.B
miss(index) := false.B
......@@ -138,7 +138,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
)
}
val loadWbIndex = io.loadIn(i).bits.uop.lqIdx.value
valid(loadWbIndex) := !io.loadIn(i).bits.miss && !io.loadIn(i).bits.mmio
datavalid(loadWbIndex) := !io.loadIn(i).bits.miss && !io.loadIn(i).bits.mmio
writebacked(loadWbIndex) := !io.loadIn(i).bits.miss && !io.loadIn(i).bits.mmio
allocated(loadWbIndex) := !io.loadIn(i).bits.uop.cf.exceptionVec.asUInt.orR
......@@ -237,7 +237,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
dataModule.io.refill.wen(i) := false.B
when(allocated(i) && listening(i) && blockMatch && io.dcache.resp.fire()) {
dataModule.io.refill.wen(i) := true.B
valid(i) := true.B
datavalid(i) := true.B
listening(i) := false.B
}
})
......@@ -245,7 +245,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
// writeback up to 2 missed load insts to CDB
// just randomly pick 2 missed load (data refilled), write them back to cdb
val loadWbSelVec = VecInit((0 until LoadQueueSize).map(i => {
allocated(i) && valid(i) && !writebacked(i)
allocated(i) && datavalid(i) && !writebacked(i)
})).asUInt() // use uint instead vec to reduce verilog lines
val loadWbSel = Wire(Vec(StorePipelineWidth, UInt(log2Up(LoadQueueSize).W)))
val loadWbSelV= Wire(Vec(StorePipelineWidth, Bool()))
......@@ -387,7 +387,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
val lqViolationVec = VecInit((0 until LoadQueueSize).map(j => {
val addrMatch = allocated(j) &&
io.storeIn(i).bits.paddr(PAddrBits - 1, 3) === dataModule.io.rdata(j).paddr(PAddrBits - 1, 3)
val entryNeedCheck = toEnqPtrMask(j) && addrMatch && (valid(j) || listening(j) || miss(j))
val entryNeedCheck = toEnqPtrMask(j) && addrMatch && (datavalid(j) || listening(j) || miss(j))
// TODO: update refilled data
val violationVec = (0 until 8).map(k => dataModule.io.rdata(j).mask(k) && io.storeIn(i).bits.mask(k))
Cat(violationVec).orR() && entryNeedCheck
......@@ -433,6 +433,8 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
rollback(i).bits.isMisPred := false.B
rollback(i).bits.isException := false.B
rollback(i).bits.isFlushPipe := false.B
rollback(i).bits.target := rollbackUop.cf.pc
rollback(i).bits.brTag := rollbackUop.brTag
XSDebug(
l1Violation,
......@@ -500,7 +502,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
dataModule.io.uncache.wen := false.B
when(io.uncache.resp.fire()){
valid(deqPtr) := true.B
datavalid(deqPtr) := true.B
dataModule.io.uncacheWrite(deqPtr, io.uncache.resp.bits.data(XLEN-1, 0))
dataModule.io.uncache.wen := true.B
// TODO: write back exception info
......@@ -529,15 +531,15 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
for (i <- 0 until LoadQueueSize) {
needCancel(i) := uop(i).roqIdx.needFlush(io.brqRedirect) && allocated(i) && !commited(i)
when(needCancel(i)) {
when(io.brqRedirect.bits.isReplay){
valid(i) := false.B
writebacked(i) := false.B
listening(i) := false.B
miss(i) := false.B
pending(i) := false.B
}.otherwise{
// when(io.brqRedirect.bits.isReplay){
// valid(i) := false.B
// writebacked(i) := false.B
// listening(i) := false.B
// miss(i) := false.B
// pending(i) := false.B
// }.otherwise{
allocated(i) := false.B
}
// }
}
}
when (io.brqRedirect.valid && io.brqRedirect.bits.isMisPred) {
......@@ -564,7 +566,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
if (i % 4 == 0) XSDebug("")
XSDebug(false, true.B, "%x [%x] ", uop(i).cf.pc, dataModule.io.rdata(i).paddr)
PrintFlag(allocated(i), "a")
PrintFlag(allocated(i) && valid(i), "v")
PrintFlag(allocated(i) && datavalid(i), "v")
PrintFlag(allocated(i) && writebacked(i), "w")
PrintFlag(allocated(i) && commited(i), "c")
PrintFlag(allocated(i) && miss(i), "m")
......
......@@ -38,7 +38,6 @@ class StoreQueue extends XSModule with HasDCacheParameters with HasCircularQueue
val uncache = new DCacheWordIO
val roqDeqPtr = Input(new RoqPtr)
// val refill = Flipped(Valid(new DCacheLineReq ))
val oldestStore = Output(Valid(new RoqPtr))
val exceptionAddr = new ExceptionAddrIO
})
......@@ -178,13 +177,6 @@ class StoreQueue extends XSModule with HasDCacheParameters with HasCircularQueue
(selValid, selVec)
}
// select the last writebacked instruction
val validStoreVec = VecInit((0 until StoreQueueSize).map(i => !(allocated(i) && datavalid(i))))
val storeNotValid = SqPtr(false.B, getFirstOne(validStoreVec, tailMask))
val storeValidIndex = (storeNotValid - 1.U).value
io.oldestStore.valid := allocated(deqPtrExt.value) && datavalid(deqPtrExt.value) && !commited(storeValidIndex)
io.oldestStore.bits := uop(storeValidIndex).roqIdx
// writeback finished mmio store
io.mmioStout.bits.uop := uop(deqPtr)
io.mmioStout.bits.uop.sqIdx := deqPtrExt
......@@ -340,13 +332,13 @@ class StoreQueue extends XSModule with HasDCacheParameters with HasCircularQueue
for (i <- 0 until StoreQueueSize) {
needCancel(i) := uop(i).roqIdx.needFlush(io.brqRedirect) && allocated(i) && !commited(i)
when(needCancel(i)) {
when(io.brqRedirect.bits.isReplay){
datavalid(i) := false.B
writebacked(i) := false.B
pending(i) := false.B
}.otherwise{
// when(io.brqRedirect.bits.isReplay){
// datavalid(i) := false.B
// writebacked(i) := false.B
// pending(i) := false.B
// }.otherwise{
allocated(i) := false.B
}
// }
}
}
when (io.brqRedirect.valid && io.brqRedirect.bits.isMisPred) {
......
......@@ -129,19 +129,6 @@ class LoadUnit_S1 extends XSModule {
io.out.bits.forwardMask := io.sbuffer.forwardMask
io.out.bits.forwardData := io.sbuffer.forwardData
// generate XLEN/8 Muxs
for (i <- 0 until XLEN / 8) {
when(io.lsq.forwardMask(i)) {
io.out.bits.forwardMask(i) := true.B
io.out.bits.forwardData(i) := io.lsq.forwardData(i)
}
}
XSDebug(io.out.fire(), "[FWD LOAD RESP] pc %x fwd %x(%b) + %x(%b)\n",
s1_uop.cf.pc,
io.lsq.forwardData.asUInt, io.lsq.forwardMask.asUInt,
io.sbuffer.forwardData.asUInt, io.sbuffer.forwardMask.asUInt
)
io.out.valid := io.in.valid && !s1_tlb_miss && !s1_uop.roqIdx.needFlush(io.redirect)
io.out.bits.paddr := s1_paddr
......@@ -161,6 +148,7 @@ class LoadUnit_S2 extends XSModule {
val out = Decoupled(new LsPipelineBundle)
val redirect = Flipped(ValidIO(new Redirect))
val dcacheResp = Flipped(DecoupledIO(new DCacheWordResp))
val lsq = new LoadForwardQueryIO
})
val s2_uop = io.in.bits.uop
......@@ -173,10 +161,16 @@ class LoadUnit_S2 extends XSModule {
io.dcacheResp.ready := true.B
assert(!(io.in.valid && !io.dcacheResp.valid), "DCache response got lost")
val forwardMask = io.in.bits.forwardMask
val forwardData = io.in.bits.forwardData
val forwardMask = io.out.bits.forwardMask
val forwardData = io.out.bits.forwardData
val fullForward = (~forwardMask.asUInt & s2_mask) === 0.U
XSDebug(io.out.fire(), "[FWD LOAD RESP] pc %x fwd %x(%b) + %x(%b)\n",
s2_uop.cf.pc,
io.lsq.forwardData.asUInt, io.lsq.forwardMask.asUInt,
io.in.bits.forwardData.asUInt, io.in.bits.forwardMask.asUInt
)
// data merge
val rdata = VecInit((0 until XLEN / 8).map(j =>
Mux(forwardMask(j), forwardData(j), io.dcacheResp.bits.data(8*(j+1)-1, 8*j)))).asUInt
......@@ -213,9 +207,19 @@ class LoadUnit_S2 extends XSModule {
io.in.ready := io.out.ready || !io.in.valid
// merge forward result
io.lsq := DontCare
// generate XLEN/8 Muxs
for (i <- 0 until XLEN / 8) {
when(io.lsq.forwardMask(i)) {
io.out.bits.forwardMask(i) := true.B
io.out.bits.forwardData(i) := io.lsq.forwardData(i)
}
}
XSDebug(io.out.fire(), "[DCACHE LOAD RESP] pc %x rdata %x <- D$ %x + fwd %x(%b)\n",
s2_uop.cf.pc, rdataPartialLoad, io.dcacheResp.bits.data,
io.in.bits.forwardData.asUInt, io.in.bits.forwardMask.asUInt
io.out.bits.forwardData.asUInt, io.out.bits.forwardMask.asUInt
)
}
......@@ -268,6 +272,9 @@ class LoadUnit extends XSModule {
load_s2.io.redirect <> io.redirect
load_s2.io.dcacheResp <> io.dcache.resp
load_s2.io.lsq := DontCare
load_s2.io.lsq.forwardData <> io.lsq.forward.forwardData
load_s2.io.lsq.forwardMask <> io.lsq.forward.forwardMask
// PipelineConnect(load_s2.io.fp_out, load_s3.io.in, true.B, false.B)
// load_s3.io.redirect <> io.redirect
......
......@@ -4,7 +4,7 @@
#include "common.h"
#include "ram.h"
#define RAMSIZE (64 * 1024 * 1024 * 1024UL)
#define RAMSIZE (256 * 1024 * 1024UL)
#ifdef WITH_DRAMSIM3
#include "cosimulation.h"
......@@ -215,8 +215,9 @@ void ram_finish() {
extern "C" uint64_t ram_read_helper(uint8_t en, uint64_t rIdx) {
if (en && rIdx >= RAMSIZE / sizeof(uint64_t)) {
printf("ERROR: ram rIdx = 0x%lx out of bound!\n", rIdx);
assert(rIdx < RAMSIZE / sizeof(uint64_t));
printf("WARN: ram rIdx = 0x%lx out of bound!\n", rIdx);
// assert(rIdx < RAMSIZE / sizeof(uint64_t));
return 0x12345678deadbeafULL;
}
return (en) ? ram[rIdx] : 0;
}
......
......@@ -5,7 +5,7 @@
#include "VXSSimSoC.h"
#include <verilated_save.h>
class VerilatedSaveMem : public VerilatedSave {
class VerilatedSaveMem : public VerilatedSerialize {
const static long buf_size = 1024 * 1024 * 1024;
uint8_t *buf;
long size;
......
......@@ -41,7 +41,7 @@ class IFUTest extends AnyFlatSpec with ChiselScalatestTester with Matchers {
// Cycle 5
//-----------------
c.io.redirect.valid.poke(true.B)
c.io.redirect.bits.target.poke("h80002800".U)
c.io.redirect.bits.poke("h80002800".U)
c.clock.step()
//-----------------
// Cycle 6
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册