// LoopPredictor.scala: loop predictor built from Loop Termination Buffer (LTB) columns.
package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.brq.BrqPtr
import chisel3.experimental.chiselName

/** Geometry shared by every Loop Termination Buffer (LTB) type.
  *
  * PC layout as used by the LTB:
  * {{{
  *  +-----------+---------+--------------+-----------+
  *  |    tag    |   idx   |    4 bits    | 0 (1 bit) |
  *  +-----------+---------+--------------+-----------+
  * }}}
  * The 4-bit field presumably selects the column/bank (the LTB builds its address helper as
  * `TableAddr(idxLen + 4, PredictWidth)`) -- confirm against `TableAddr`.
  */
trait LTBParams extends HasXSParameter with HasBPUParameter {
  /** Width of the partial tag stored per entry. */
  val tagLen: Int = 24
  /** Number of rows in each LTB column. */
  val nRows: Int = 16
  /** Width of a row index; must be initialized after `nRows`. */
  val idxLen: Int = log2Up(nRows)
  /** Width of the trip / speculative / non-speculative loop counters. */
  val cntBits: Int = 10
}

/** Base class for LTB bundles: an XSBundle that also sees the LTB geometry parameters. */
abstract class LTBBundle extends XSBundle with LTBParams

/** Base class for LTB modules: an XSModule that also sees the LTB geometry parameters. */
abstract class LTBModule extends XSModule with LTBParams {
  /** Per-module debug switch; LTBColumn only prints when both `BPUDebug` and this are set. */
  val debug: Boolean = false
}

// class LoopMeta extends LTBBundle {
// }

/** One row of an LTB column: what has been learned about a single loop-closing branch.
  *
  * Field order is significant (it fixes the bundle's bit layout), so keep it stable.
  */
class LoopEntry extends LTBBundle {
  /** Partial PC tag identifying the branch. */
  val tag: UInt = UInt(tagLen.W)
  /** How many times in a row the same loop trip count has been seen (saturates at 7). */
  val conf: UInt = UInt(3.W)
  /** Intended as a usefulness/age counter gating replacement.
    * NOTE(review): insertion in LTBColumn does not currently consult it. TODO: delete this.
    */
  val age: UInt = UInt(3.W)
  /** Learned loop trip count; LTBColumn predicts a loop exit when `specCnt + 1` equals it. */
  val tripCnt: UInt = UInt(cntBits.W)
  /** Number of times the loop branch has been taken in a row, counted speculatively at fetch. */
  val specCnt: UInt = UInt(cntBits.W)
  /** Number of times the loop branch has been taken in a row, counted non-speculatively at resolution. */
  val nSpecCnt: UInt = UInt(cntBits.W)
  /** brTag of the latest not-taken / loop-exit branch (also recorded on insertion). */
  val brTag: BrqPtr = new BrqPtr
  /** Set when a mispredicted not-taken resolution reports a count larger than `tripCnt`. */
  val unusable: Bool = Bool()

  /** Confidence counter has saturated (all three bits set). */
  def isLearned: Bool = conf.andR
  /** At least one matching trip count has been observed. */
  def isConf: Bool = conf.orR
  /** No confidence yet. */
  def isUnconf: Bool = !isConf
}

/** IF4 lookup request sent to one LTB column. */
class LTBColumnReq extends LTBBundle {
  /** PC of the fetch slot mapped onto this column; only used for debug printing. */
  val pc: UInt = UInt(VAddrBits.W) // only for debug!!!
  /** Row to look up inside the column. */
  val idx: UInt = UInt(idxLen.W)
  /** Partial tag compared against the stored `LoopEntry.tag`. */
  val tag: UInt = UInt(tagLen.W)
}

/** Prediction returned by one LTB column. */
class LTBColumnResp extends LTBBundle {
  /** Predict that the loop branch exits (falls through) this time. */
  val exit: Bool = Bool()
  /** Speculative counter value (`specCnt + 1`) forwarded as prediction metadata;
    * presumably comes back later as `LTBColumnUpdate.meta` -- confirm in the BPU plumbing.
    */
  val meta: UInt = UInt(cntBits.W)
}

/** Branch-resolution information used to train one LTB column. */
class LTBColumnUpdate extends LTBBundle {
  /** The branch was mispredicted. */
  val misPred: Bool = Bool()
  /** PC of the resolved branch; provides the row index and tag. */
  val pc: UInt = UInt(VAddrBits.W)
  /** Counter metadata captured at prediction time. */
  val meta: UInt = UInt(cntBits.W)
  /** Resolved direction of the branch. */
  val taken: Bool = Bool()
  /** Branch-queue pointer of the resolved branch. */
  val brTag: BrqPtr = new BrqPtr
}

// each column/bank of Loop Termination Buffer
/** One column (bank) of the Loop Termination Buffer.
  *
  * Responsibilities:
  *  - Holds `nRows` [[LoopEntry]]s.
  *  - In IF4, looks up the row addressed by `io.req`.
  *  - Predicts a loop exit when `specCnt + 1` reaches the learned `tripCnt`.
  *  - Speculatively advances `specCnt` when the slot is actually delivered
  *    (`if4_fire && tag hit && outMask`).
  *  - On branch resolution (`io.update`), trains the entry.
  *  - On a misprediction, restores speculative counters from their non-speculative copies.
  */
@chiselName
class LTBColumn extends LTBModule {
  val io = IO(new Bundle() {
    // if3 send req
    val req = Input(new LTBColumnReq)
    val if3_fire = Input(Bool())
    val if4_fire = Input(Bool())
    val outMask = Input(Bool())
    // send out resp to if4
    val resp = Output(new LTBColumnResp)
    val update = Input(Valid(new LTBColumnUpdate))
    val repair = Input(Bool()) // roll back every specCnt in this column (misprediction not trained here)
  })

  /** Row storage for this column, built on an asynchronous-read `Mem`.
    *
    * `rdata`/`urdata` are combinational reads of rows `rIdx` (IF4 lookup) and `urIdx` (update).
    * Each row is written at most once per cycle. The source is chosen per row by priority:
    *  1. Committed write (`wen` && `wIdx` == row): resolved state from branch training.
    *  2. Speculative write (`swen` && `swIdx` == row): IF4 counter update or reset clearing.
    *     If the row is also flagged in `copyCnt`, its specCnt is replaced by the stored nSpecCnt.
    *  3. Counter copy (`copyCnt(row)`): specCnt := nSpecCnt, every other field kept.
    */
  class LTBMem extends LTBModule {
    val io = IO(new Bundle {
      val rIdx = Input(UInt(idxLen.W))
      val rdata = Output(new LoopEntry)
      val urIdx = Input(UInt(idxLen.W))
      val urdata = Output(new LoopEntry)
      val wen = Input(Bool())
      val wIdx = Input(UInt(idxLen.W))
      val wdata = Input(new LoopEntry)
      val swen = Input(Bool())
      val swIdx = Input(UInt(idxLen.W))
      val swdata = Input(new LoopEntry)
      val copyCnt = Input(Vec(nRows, Bool()))
    })

    // val mem = RegInit(0.U.asTypeOf(Vec(nRows, new LoopEntry)))
    val mem = Mem(nRows, new LoopEntry)

    io.rdata  := mem(io.rIdx)
    io.urdata := mem(io.urIdx)

    for (i <- 0 until nRows) {
      val row = mem(i)
      val copyValid = io.copyCnt(i)
      // Each data source must be qualified by its own row index. Otherwise a speculative
      // write to some other row would overwrite the row targeted by a committed write.
      val commitHit = io.wen && io.wIdx === i.U
      val specHit = io.swen && io.swIdx === i.U
      val wd = WireInit(row)
      when (commitHit) {
        // Committed state wins. When both writes target the same row, LTBColumn already
        // bypasses wEntry.specCnt into swdata, so swdata would only add stale fields.
        wd := io.wdata
      }.elsewhen (specHit) {
        wd := io.swdata
        when (copyValid) { wd.specCnt := row.nSpecCnt }
      }.elsewhen (copyValid) {
        // Roll the speculative counter back to the committed (non-speculative) one.
        wd.specCnt := row.nSpecCnt
      }
      when (commitHit || specHit || copyValid) {
        mem.write(i.U, wd)
      }
    }
  }
  // val ltb = Reg(Vec(nRows, new LoopEntry))
  val ltb = Module(new LTBMem).io
  val ltbAddr = new TableAddr(idxLen + 4, PredictWidth)
  val updateIdx = ltbAddr.getBankIdx(io.update.bits.pc)
  val updateTag = ltbAddr.getTag(io.update.bits.pc)(tagLen - 1, 0)
  val updateBrTag = io.update.bits.brTag

  // Clear one row per cycle after reset. Training, bypasses and counter copies are
  // suppressed until the sweep finishes.
  val doingReset = RegInit(true.B)
  val resetIdx = RegInit(0.U(idxLen.W))
  resetIdx := resetIdx + doingReset
  when (resetIdx === (nRows - 1).U) { doingReset := false.B }

  // during branch prediction
  val if4_idx = io.req.idx
  val if4_tag = io.req.tag
  ltb.rIdx := if4_idx
  val if4_entry = WireInit(ltb.rdata)

  // The IF4 request is live from IF3 fire until IF4 consumes it or a misprediction resolves.
  val valid = RegInit(false.B)
  when (io.if4_fire) { valid := false.B }
  when (io.if3_fire) { valid := true.B }
  when (io.update.valid && io.update.bits.misPred) { valid := false.B }

  io.resp.meta := if4_entry.specCnt + 1.U
  // io.resp.exit := if4_tag === if4_entry.tag && (if4_entry.specCnt + 1.U) === if4_entry.tripCnt && valid && !if4_entry.unusable
  io.resp.exit := if4_tag === if4_entry.tag && (if4_entry.specCnt + 1.U) === if4_entry.tripCnt && valid && if4_entry.isConf

  // when resolving a branch
  ltb.urIdx := updateIdx
  val entry = ltb.urdata
  val tagMatch = entry.tag === updateTag
  val cntMatch = entry.tripCnt === io.update.bits.meta // debug only
  val wEntry = WireInit(entry)

  ltb.wIdx := updateIdx
  ltb.wdata := wEntry
  ltb.wen := false.B

  when (io.update.valid && !doingReset) {
    // When a branch resolves and is found to not be in the LTB,
    // it is inserted into the LTB if determined to be a loop-branch and if it is mispredicted by the default predictor.
    when (!tagMatch && io.update.bits.misPred) {
      wEntry.tag := updateTag
      wEntry.conf := 0.U
      wEntry.age := 7.U
      wEntry.tripCnt := Fill(cntBits, 1.U(1.W))
      wEntry.specCnt := Mux(io.update.bits.taken, 1.U, 0.U)
      wEntry.nSpecCnt := Mux(io.update.bits.taken, 1.U, 0.U)
      wEntry.brTag := updateBrTag
      wEntry.unusable := false.B
      // ltb(updateIdx) := wEntry
      ltb.wen := true.B
    }.elsewhen (tagMatch) {
      // During resolution, a taken branch found in the LTB has its nSpecCnt incremented by one.
      when (io.update.bits.taken) {
        wEntry.nSpecCnt := entry.nSpecCnt + 1.U
        wEntry.specCnt := Mux(io.update.bits.misPred/* && !entry.brTag.needBrFlush(updateBrTag)*/, entry.nSpecCnt + 1.U, entry.specCnt)
        wEntry.conf := Mux(io.update.bits.misPred, 0.U, entry.conf)
        // wEntry.tripCnt := Fill(cntBits, 1.U(1.W))
        wEntry.tripCnt := Mux(io.update.bits.misPred, Fill(cntBits, 1.U(1.W)), entry.tripCnt)
      // A not-taken loop-branch found in the LTB during branch resolution updates its trip count and conf.
      }.otherwise {
        // wEntry.conf := Mux(entry.nSpecCnt === entry.tripCnt, Mux(entry.isLearned, 7.U, entry.conf + 1.U), 0.U)
        // wEntry.conf := Mux(io.update.bits.misPred, 0.U, Mux(entry.isLearned, 7.U, entry.conf + 1.U))
        wEntry.conf := Mux((entry.nSpecCnt + 1.U) === entry.tripCnt, Mux(entry.isLearned, 7.U, entry.conf + 1.U), 0.U)
        // wEntry.tripCnt := entry.nSpecCnt + 1.U
        wEntry.tripCnt := io.update.bits.meta
        wEntry.specCnt := Mux(io.update.bits.misPred, 0.U, entry.specCnt/* - entry.nSpecCnt - 1.U*/)
        wEntry.nSpecCnt := 0.U
        wEntry.brTag := updateBrTag
        wEntry.unusable := io.update.bits.misPred && (io.update.bits.meta > entry.tripCnt)
      }
      // ltb(updateIdx) := wEntry
      ltb.wen := true.B
    }
  }

  // speculatively update specCnt
  ltb.swen := valid && if4_entry.tag === if4_tag || doingReset
  ltb.swIdx := Mux(doingReset, resetIdx, if4_idx)
  val swEntry = WireInit(if4_entry)
  ltb.swdata := Mux(doingReset, 0.U.asTypeOf(new LoopEntry), swEntry)
  when (io.if4_fire && if4_entry.tag === if4_tag && io.outMask) {
    when ((if4_entry.specCnt + 1.U) === if4_entry.tripCnt/* && if4_entry.isConf*/) {
      swEntry.age := 7.U
      swEntry.specCnt := 0.U
    }.otherwise {
      swEntry.age := Mux(if4_entry.age === 7.U, 7.U, if4_entry.age + 1.U)
      swEntry.specCnt := if4_entry.specCnt + 1.U
    }
  }

  // Reseting
  // when (doingReset) {
  //   ltb(resetIdx) := 0.U.asTypeOf(new LoopEntry)
  // }

  // when a branch misprediction occurs, all of the nSpecCnts copy their values into the specCnts
  // (except the row being trained, which gets its specCnt from wEntry). Gated by !doingReset,
  // like every other training path, so it cannot disturb the row being cleared.
  for (i <- 0 until nRows) {
    ltb.copyCnt(i) := !doingReset && (io.update.valid && io.update.bits.misPred && i.U =/= updateIdx || io.repair)
  }

  // bypass for if4_entry.specCnt
  when (io.update.valid && !doingReset && valid && updateIdx === if4_idx) {
    when (!tagMatch && io.update.bits.misPred || tagMatch) {
      swEntry.specCnt := wEntry.specCnt
    }
  }
  when (io.repair && !doingReset && valid) {
    swEntry.specCnt := if4_entry.nSpecCnt
  }

  if (BPUDebug && debug) {
    //debug info
    XSDebug(doingReset, "Reseting...\n")
    XSDebug("if3_fire=%d if4_fire=%d valid=%d\n", io.if3_fire, io.if4_fire,valid)
    XSDebug("[req] v=%d pc=%x idx=%x tag=%x\n", valid, io.req.pc, io.req.idx, io.req.tag)
    XSDebug("[if4_entry] tag=%x conf=%d age=%d tripCnt=%d specCnt=%d nSpecCnt=%d", 
      if4_entry.tag, if4_entry.conf, if4_entry.age, if4_entry.tripCnt, if4_entry.specCnt, if4_entry.nSpecCnt)
    XSDebug(false, true.B, p" brTag=${if4_entry.brTag} unusable=${if4_entry.unusable}\n")
    XSDebug(io.if4_fire && if4_entry.tag === if4_tag && io.outMask, "[speculative update] new specCnt=%d\n",
      Mux((if4_entry.specCnt + 1.U) === if4_entry.tripCnt, 0.U, if4_entry.specCnt + 1.U))
    XSDebug("[update] v=%d misPred=%d pc=%x idx=%x tag=%x meta=%d taken=%d tagMatch=%d cntMatch=%d", io.update.valid, io.update.bits.misPred, io.update.bits.pc, updateIdx, updateTag, io.update.bits.meta, io.update.bits.taken, tagMatch, cntMatch)
    XSDebug(false, true.B, p" brTag=${updateBrTag}\n")
    XSDebug("[entry ] tag=%x conf=%d age=%d tripCnt=%d specCnt=%d nSpecCnt=%d", entry.tag, entry.conf, entry.age, entry.tripCnt, entry.specCnt, entry.nSpecCnt)
    XSDebug(false, true.B, p" brTag=${entry.brTag} unusable=${entry.unusable}\n")
    XSDebug("[wEntry] tag=%x conf=%d age=%d tripCnt=%d specCnt=%d nSpecCnt=%d", wEntry.tag, wEntry.conf, wEntry.age, wEntry.tripCnt, wEntry.specCnt, wEntry.nSpecCnt)
    XSDebug(false, true.B, p" brTag=${wEntry.brTag} unusable=${wEntry.unusable}\n")
    XSDebug(io.update.valid && io.update.bits.misPred || io.repair, "MisPred or repairing, all of the nSpecCnts copy their values into the specCnts\n")
  }

}

/** Loop predictor made of `PredictWidth` [[LTBColumn]]s, one per bank.
  *
  * Pipeline and lane mapping:
  *  - The fetch PC and mask are latched when `io.pc.valid` fires (IF3) and used the next cycle.
  *  - Column `i` serves the fetch slot whose bank is `i`.
  *  - Slots whose bank is below the starting bank wrap into the next row. Their tag is
  *    incremented when that row index wraps around.
  *  - Column responses are rotated back into fetch order through `bankIdxInOrder`.
  */
@chiselName
class LoopPredictor extends BasePredictor with LTBParams {
  class LoopResp extends Resp {
    val exit = Vec(PredictWidth, Bool())
  }
  class LoopMeta extends Meta {
    val specCnts = Vec(PredictWidth, UInt(cntBits.W))
  }
  class LoopRespIn extends XSBundle {
    val taken = Bool()
    val jmpIdx = UInt(log2Up(PredictWidth).W)
  }

  class LoopIO extends DefaultBasePredictorIO {
    val respIn = Input(new LoopRespIn)
    val resp = Output(new LoopResp)
    val meta = Output(new LoopMeta)
  }

  override val io = IO(new LoopIO)

  val ltbs = Seq.fill(PredictWidth) { Module(new LTBColumn) }

  val ltbAddr = new TableAddr(idxLen + 4, PredictWidth)

  // Latch for 1 cycle
  val pc = RegEnable(io.pc.bits, io.pc.valid)
  val inMask = RegEnable(io.inMask, io.pc.valid)
  val baseBank = ltbAddr.getBank(pc)
  val baseRow = ltbAddr.getBankIdx(pc)
  val baseTag = ltbAddr.getTag(pc)

  val nextRowStartsUp = baseRow.andR // TODO: use parallel andR
  val isInNextRow = VecInit((0 until PredictWidth).map(_.U < baseBank))
  val tagIncremented = VecInit((0 until PredictWidth).map(i => isInNextRow(i) && nextRowStartsUp))
  val realTags = VecInit((0 until PredictWidth).map(i => Mux(tagIncremented(i), baseTag + 1.U, baseTag)(tagLen - 1, 0)))
  val bankIdxInOrder = VecInit((0 until PredictWidth).map(i => (baseBank +& i.U)(log2Up(PredictWidth) - 1, 0)))
  val realMask = circularShiftLeft(inMask, PredictWidth, baseBank)
  // Slots at or before the predicted-taken jump (or all slots when nothing is taken).
  val outMask = inMask & (Fill(PredictWidth, !io.respIn.taken) | (Fill(PredictWidth, 1.U(1.W)) >> (~io.respIn.jmpIdx)))

  // Route fetch slot j to the column whose bank it occupies.
  for (i <- 0 until PredictWidth) {
    ltbs(i).io.req.pc := pc
    ltbs(i).io.outMask := false.B
    for (j <- 0 until PredictWidth) {
      // `+&` keeps the carry. With plain `+` the sum wraps to the bank width and can never
      // equal PredictWidth + i, so next-row lanes would never receive their slot.
      val slotBank = baseBank +& j.U
      when (Mux(isInNextRow(i), slotBank === (PredictWidth + i).U, slotBank === i.U)) {
        ltbs(i).io.req.pc := pc + (j.U << 1)
        ltbs(i).io.outMask := outMask(j).asBool
      }
    }
  }

  val updateBank = ltbAddr.getBank(io.update.bits.ui.pc)
  val updateMisPred = io.update.valid && io.update.bits.ui.isMisPred
  for (i <- 0 until PredictWidth) {
    // This column is trained only by resolved branches that live in its bank.
    val trainsHere = i.U === updateBank && io.update.valid && io.update.bits.ui.pd.isBr
    ltbs(i).io.if3_fire := io.pc.valid
    ltbs(i).io.if4_fire := io.outFire
    ltbs(i).io.req.idx := Mux(isInNextRow(i), baseRow + 1.U, baseRow)
    ltbs(i).io.req.tag := realTags(i)
    // ltbs(i).io.outMask := outMask(i)
    ltbs(i).io.update.valid := trainsHere
    ltbs(i).io.update.bits.misPred := io.update.bits.ui.isMisPred
    ltbs(i).io.update.bits.pc := io.update.bits.ui.pc
    ltbs(i).io.update.bits.meta := io.update.bits.ui.brInfo.specCnt
    ltbs(i).io.update.bits.taken := io.update.bits.ui.taken
    ltbs(i).io.update.bits.brTag := io.update.bits.ui.brTag
    // On a misprediction, every column that is not being trained rolls back all of its specCnts.
    // This includes the update's own bank when the mispredicted instruction is not a branch.
    // The trained column rolls back its other rows by itself.
    ltbs(i).io.repair := updateMisPred && !trainsHere
  }

  val ltbResps = VecInit((0 until PredictWidth).map(i => ltbs(i).io.resp))

  (0 until PredictWidth).foreach(i => io.resp.exit(i) := ltbResps(bankIdxInOrder(i)).exit)
  (0 until PredictWidth).foreach(i => io.meta.specCnts(i) := ltbResps(bankIdxInOrder(i)).meta)

  if (BPUDebug && debug) {
    // debug info
    XSDebug("[IF3][req] fire=%d flush=%d fetchpc=%x\n", io.pc.valid, io.flush, io.pc.bits)
    XSDebug("[IF4][req] fire=%d baseBank=%x baseRow=%x baseTag=%x\n", io.outFire, baseBank, baseRow, baseTag)
    XSDebug("[IF4][req] isInNextRow=%b tagInc=%b\n", isInNextRow.asUInt, tagIncremented.asUInt)
    for (i <- 0 until PredictWidth) {
      XSDebug("[IF4][req] bank %d: realMask=%d pc=%x idx=%x tag=%x\n", i.U, realMask(i), ltbs(i).io.req.pc, ltbs(i).io.req.idx, ltbs(i).io.req.tag)
    }
    XSDebug("[IF4] baseBank=%x bankIdxInOrder=", baseBank)
    for (i <- 0 until PredictWidth) {
      XSDebug(false, true.B, "%x ", bankIdxInOrder(i))
    }
    XSDebug(false, true.B, "\n")
    // Print responses eight lanes per line (for PredictWidth = 16: headers at 0/8, breaks after 7/15).
    for (i <- 0 until PredictWidth) {
      val startsLine = i % 8 == 0
      val endsLine = i % 8 == 7 || i == PredictWidth - 1
      XSDebug(io.outFire && startsLine.B, "[IF4][resps]")
      XSDebug(false, io.outFire, " %d:%d %d", i.U, io.resp.exit(i), io.meta.specCnts(i))
      XSDebug(false, io.outFire && endsLine.B, "\n")
    }
  }
}