提交 f506e33b 编写于 作者: Y Yinan Xu

dispatch queue: rewrite index update logic, support preg state replay

上级 7695ca79
......@@ -3,6 +3,7 @@ package xiangshan.backend.dispatch
import chisel3._
import chisel3.util._
import utils.{XSDebug, XSError, XSInfo}
import xiangshan.backend.decode.SrcType
import xiangshan.{MicroOp, Redirect, ReplayPregReq, RoqCommit, XSBundle, XSModule}
......@@ -50,50 +51,31 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
// enq: starting from tail ptr
val enqPtr = (0 until enqnum).map(i => tailPtr + i.U)
val enqIndex = enqPtr.map(ptr => ptr(indexWidth - 1, 0))
// walkDispatch: in case of redirect, walk backward
val walkDispatchPtr = (0 until RenameWidth).map(i => dispatchPtr - (i + 1).U)
val walkDispatchIndex = walkDispatchPtr.map(ptr => ptr(indexWidth - 1, 0))
// walkTail: in case of redirect, walk backward
val walkTailPtr = (0 until RenameWidth).map(i => tailPtr - (i + 1).U)
val walkTailIndex = walkTailPtr.map(ptr => ptr(indexWidth - 1, 0))
// debug: dump dispatch queue states
def greaterOrEqualThan(left: UInt, right: UInt) = {
Mux(
left(indexWidth) === right(indexWidth),
left(indexWidth - 1, 0) >= right(indexWidth - 1, 0),
left(indexWidth - 1, 0) <= right(indexWidth - 1, 0)
)
}
XSError(!greaterOrEqualThan(tailPtr, headPtr), p"assert greaterOrEqualThan(tailPtr: $tailPtr, headPtr: $headPtr) failed\n")
XSError(!greaterOrEqualThan(tailPtr, dispatchPtr), p"assert greaterOrEqualThan(tailPtr: $tailPtr, dispatchPtr: $dispatchPtr) failed\n")
XSError(!greaterOrEqualThan(dispatchPtr, headPtr), p"assert greaterOrEqualThan(dispatchPtr: $dispatchPtr, headPtr: $headPtr) failed\n")
XSDebug(p"head: $headPtr, tail: $tailPtr, dispatch: $dispatchPtr\n")
XSDebug(p"state: ")
stateEntries.reverse.foreach { s =>
XSDebug(false, s === s_invalid, "-")
XSDebug(false, s === s_valid, "v")
XSDebug(false, s === s_dispatched, "d")
def distanceBetween(left: UInt, right: UInt) = {
Mux(left(indexWidth) === right(indexWidth),
left(indexWidth - 1, 0) - right(indexWidth - 1, 0),
size.U + left(indexWidth - 1, 0) - right(indexWidth - 1, 0))
}
XSDebug(false, true.B, "\n")
XSDebug(p"ptr: ")
(0 until size).reverse.foreach { i =>
val isPtr = i.U === headIndex || i.U === tailIndex || i.U === dispatchIndex
XSDebug(false, isPtr, "^")
XSDebug(false, !isPtr, " ")
}
XSDebug(false, true.B, "\n")
val validEntries = Mux(headDirection === tailDirection, tailIndex - headIndex, size.U + tailIndex - headIndex)
val dispatchEntries = Mux(dispatchDirection === tailDirection, tailIndex - dispatchIndex, size.U + tailIndex - dispatchIndex)
XSError(validEntries < dispatchEntries, "validEntries should be less than dispatchEntries\n")
val validEntries = distanceBetween(tailPtr, headPtr)
val dispatchEntries = distanceBetween(tailPtr, dispatchPtr)
val commitEntries = validEntries - dispatchEntries
val emptyEntries = size.U - validEntries
/**
* Part 1: update states and uops when enqueue, dequeue, commit, redirect/replay
*
* uop only changes when a new instruction enqueues.
*
* state changes when
* (1) enqueue: from s_invalid to s_valid
* (2) dequeue: from s_valid to s_dispatched
* (3) commit: from s_dispatched to s_invalid
* (4) redirect (branch misprediction or exception): from any state to s_invalid (flushed)
* (5) redirect (replay): from s_dispatched to s_valid (re-dispatch)
*/
// enqueue: from s_invalid to s_valid
for (i <- 0 until enqnum) {
when (io.enq(i).fire()) {
uopEntries(enqIndex(i)) := io.enq(i).bits
......@@ -101,9 +83,11 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
}
}
// dequeue: from s_valid to s_dispatched
for (i <- 0 until deqnum) {
when (io.deq(i).fire()) {
stateEntries(deqIndex(i)) := s_dispatched
XSError(stateEntries(deqIndex(i)) =/= s_valid, "state of the dispatch entry is not s_valid\n")
}
}
......@@ -114,21 +98,24 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
for (i <- 0 until CommitWidth) {
when (commitBits(i)) {
stateEntries(commitIndex(i)) := s_invalid
XSError(stateEntries(commitIndex(i)) =/= s_dispatched, "state of the commit entry is not s_dispatched\n")
}
}
// redirect: cancel uops currently in the queue
val mispredictionValid = io.redirect.valid && io.redirect.bits.isMisPred
val exceptionValid = io.redirect.valid && io.redirect.bits.isException
val roqNeedFlush = Wire(Vec(size, Bool()))
val needCancel = Wire(Vec(size, Bool()))
for (i <- 0 until size) {
roqNeedFlush(i) := uopEntries(i.U).needFlush(io.redirect)
val needCancel = stateEntries(i) =/= s_invalid && ((roqNeedFlush(i) && io.redirect.bits.isMisPred) || exceptionValid)
when (needCancel) {
needCancel(i) := stateEntries(i) =/= s_invalid && ((roqNeedFlush(i) && mispredictionValid) || exceptionValid)
when (needCancel(i)) {
stateEntries(i) := s_invalid
}
XSInfo(needCancel, p"valid entry($i)(pc = ${Hexadecimal(uopEntries(i.U).cf.pc)})" +
XSInfo(needCancel(i), p"valid entry($i)(pc = ${Hexadecimal(uopEntries(i.U).cf.pc)})" +
p"roqIndex 0x${Hexadecimal(uopEntries(i.U).roqIdx)} " +
p"cancelled with redirect roqIndex 0x${Hexadecimal(io.redirect.bits.roqIdx)}\n")
}
......@@ -146,19 +133,74 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
}
/**
* Part 2: update indices
* Part 2: walk
*
* Instead of keeping the walking distances, we keep the walking target position for simplicity.
*
* (1) replay: move dispatchPtr to the first needReplay entry
* (2) redirect (branch misprediction): move dispatchPtr, tailPtr to the first cancelled entry
*
*/
// getFirstIndex: get the head index of consecutive ones
// note that it returns the position starting from either the leftmost or the rightmost
// 00000001 => 0
// 00111000 => 3
// 11000111 => 1
// 10000000 => 0
def getFirstMaskPosition(mask: Vec[Bool]) = {
Mux(mask(size - 1),
PriorityEncoder(mask.reverse.map(m => !m)),
PriorityEncoder(mask)
)
}
val cancelPosition = getFirstMaskPosition(needCancel)
val replayPosition = getFirstMaskPosition(needReplay)
assert(cancelPosition.getWidth == indexWidth)
assert(replayPosition.getWidth == indexWidth)
// If the highest bit is one, the direction flips.
// Otherwise, the direction keeps the same.
val tailCancelPtr = Cat(Mux(needCancel(size - 1), ~tailDirection, tailDirection), cancelPosition)
// In case of branch mis-prediction, the last dispatched instruction must be the mis-prediction instruction.
// Thus, we only need to reset dispatchPtr to tailPtr.
val dispatchCancelPtr = tailCancelPtr
// In case of replay, we need to walk back and recover preg states in the busy table.
// We keep track of the number of entries needed to be walked instead of target position to reduce overhead
val dispatchReplayCnt = Mux(needReplay(size - 1), dispatchIndex + replayPosition, dispatchIndex - replayPosition)
val inReplayWalk = dispatchReplayCnt =/= 0.U
val dispatchReplayCntReg = Reg(UInt(indexWidth.W))
val dispatchReplayStep = Mux(dispatchReplayCntReg > replayWidth.U, replayWidth.U, dispatchReplayCntReg)
when (io.redirect.valid && io.redirect.bits.isReplay) {
dispatchReplayCntReg := dispatchReplayCnt
}.otherwise {
dispatchReplayCntReg := dispatchReplayCntReg - dispatchReplayStep
}
val replayIndex = (0 until replayWidth).map(i => (dispatchPtr - i.U)(indexWidth - 1, 0))
for (i <- 0 until replayWidth) {
val replayValid = stateEntries(replayIndex(i)) === s_valid
io.replayPregReq(i).isInt := replayValid && uopEntries(replayIndex(i)).ctrl.src1Type === SrcType.reg
io.replayPregReq(i).isFp := replayValid && uopEntries(replayIndex(i)).ctrl.src1Type === SrcType.fp
io.replayPregReq(i).preg := uopEntries(replayIndex(i)).pdest
}
/**
* Part 3: update indices
*
* tail: (1) enqueue; (2) walk in case of redirect
* dispatch: (1) dequeue; (2) replay; (3) walk in case of redirect
* dispatch: (1) dequeue; (2) walk in case of replay; (3) walk in case of redirect
* head: commit
*/
// enqueue
val numEnqTry = Mux(emptyEntries > enqnum.U, enqnum.U, emptyEntries)
val numEnq = PriorityEncoder(io.enq.map(!_.fire()) :+ true.B)
val numWalkTailTry = PriorityEncoder(walkTailIndex.map(i => stateEntries(i) =/= s_invalid) :+ true.B)
val numWalkTail = Mux(numWalkTailTry > validEntries, validEntries, numWalkTailTry)
XSError(numEnq =/= 0.U && numWalkTail =/= 0.U, "should not enqueue when walk\n")
tailPtr := Mux(exceptionValid, 0.U, tailPtr + Mux(numEnq =/= 0.U, numEnq, -numWalkTail))
XSError(numEnq =/= 0.U && (mispredictionValid || exceptionValid), "should not enqueue when redirect\n")
tailPtr := Mux(exceptionValid,
0.U,
Mux(mispredictionValid,
tailCancelPtr,
tailPtr + numEnq)
)
// dequeue
val numDeqTry = Mux(dispatchEntries > deqnum.U, deqnum.U, dispatchEntries)
......@@ -169,34 +211,57 @@ class DispatchQueue(size: Int, enqnum: Int, deqnum: Int, replayWidth: Int) exten
!deq.fire() && (if (i == 0) true.B else stateEntries(deqIndex(i)) =/= s_dispatched)
} :+ true.B)
val numDeq = Mux(numDeqTry > numDeqFire, numDeqFire, numDeqTry)
// TODO: this is unaccptable since it needs to add 64 bits
val headMask = (1.U((size+1).W) << headIndex).asUInt() - 1.U
val dispatchMask = (1.U((size + 1).W) << dispatchIndex).asUInt() - 1.U
val mask = headMask ^ dispatchMask
val replayMask = Mux(headDirection === dispatchDirection, mask, ~mask)
val numReplay = PopCount((0 until size).map(i => needReplay(i) & replayMask(i)))
val numWalkDispatchTry = PriorityEncoder(walkDispatchPtr.map(i => stateEntries(i) =/= s_invalid) :+ true.B)
val numWalkDispatch = Mux(numWalkDispatchTry > commitEntries, commitEntries, numWalkDispatchTry)
val walkCntDispatch = numWalkDispatch + numReplay
// note that numDeq === 0.U entries after dispatch are all flushed
// so, numDeq and walkCntDispatch cannot be nonzero at the same time
XSError(numDeq =/= 0.U && walkCntDispatch =/= 0.U, "should not dequeue when walk\n")
dispatchPtr := Mux(exceptionValid, 0.U, dispatchPtr + Mux(numDeq =/= 0.U, numDeq, -walkCntDispatch))
dispatchPtr := Mux(exceptionValid,
0.U,
// TODO: misprediction when replay? need to compare ROB index
Mux(mispredictionValid,
dispatchCancelPtr,
dispatchPtr - dispatchReplayCntReg)
)
headPtr := Mux(exceptionValid, 0.U, headPtr + numCommit)
/**
* Part 3: set output and input
* Part 4: set output and input
*/
val enqReadyBits = (1.U << numEnqTry).asUInt() - 1.U
for (i <- 0 until enqnum) {
io.enq(i).ready := enqReadyBits(i).asBool()
io.enq(i).ready := enqReadyBits(i).asBool() && !inReplayWalk
}
for (i <- 0 until deqnum) {
io.deq(i).bits := uopEntries(deqIndex(i))
// do not dequeue when io.redirect valid because it may cause dispatchPtr work improperly
io.deq(i).valid := stateEntries(deqIndex(i)) === s_valid && !io.redirect.valid
io.deq(i).valid := stateEntries(deqIndex(i)) === s_valid && !io.redirect.valid && !inReplayWalk
}
// debug: dump dispatch queue states
def greaterOrEqualThan(left: UInt, right: UInt) = {
Mux(
left(indexWidth) === right(indexWidth),
left(indexWidth - 1, 0) >= right(indexWidth - 1, 0),
left(indexWidth - 1, 0) <= right(indexWidth - 1, 0)
)
}
XSError(!greaterOrEqualThan(tailPtr, headPtr), p"assert greaterOrEqualThan(tailPtr: $tailPtr, headPtr: $headPtr) failed\n")
XSError(!greaterOrEqualThan(tailPtr, dispatchPtr) && !inReplayWalk, p"assert greaterOrEqualThan(tailPtr: $tailPtr, dispatchPtr: $dispatchPtr) failed\n")
XSError(!greaterOrEqualThan(dispatchPtr, headPtr), p"assert greaterOrEqualThan(dispatchPtr: $dispatchPtr, headPtr: $headPtr) failed\n")
XSError(validEntries < dispatchEntries && !inReplayWalk, "validEntries should be less than dispatchEntries\n")
XSDebug(p"head: $headPtr, tail: $tailPtr, dispatch: $dispatchPtr\n")
XSDebug(p"state: ")
stateEntries.reverse.foreach { s =>
XSDebug(false, s === s_invalid, "-")
XSDebug(false, s === s_valid, "v")
XSDebug(false, s === s_dispatched, "d")
}
XSDebug(false, true.B, "\n")
XSDebug(p"ptr: ")
(0 until size).reverse.foreach { i =>
val isPtr = i.U === headIndex || i.U === tailIndex || i.U === dispatchIndex
XSDebug(false, isPtr, "^")
XSDebug(false, !isPtr, " ")
}
XSDebug(false, true.B, "\n")
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册