Tage.scala 27.5 KB
Newer Older
L
Lingrui98 已提交
1 2 3 4 5 6
package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import utils._
L
Lingrui98 已提交
7
import chisel3.experimental.chiselName
L
Lingrui98 已提交
8 9

import scala.math.min
L
Lingrui98 已提交
10
import scala.util.matching.Regex
L
Lingrui98 已提交
11

12
trait HasTageParameter extends HasXSParameter with HasBPUParameter with HasIFUConst {
L
Lingrui98 已提交
13 14 15 16 17 18 19
  //                   Sets  Hist   Tag
  val TableInfo = Seq(( 128,    2,    7),
                      ( 128,    4,    7),
                      ( 256,    8,    8),
                      ( 256,   16,    8),
                      ( 128,   32,    9),
                      ( 128,   64,    9))
G
GouLingrui 已提交
20 21 22 23 24 25
                      // (  64,   64,   11),
                      // (  64,  101,   12),
                      // (  64,  160,   12),
                      // (  64,  254,   13),
                      // (  32,  403,   14),
                      // (  32,  640,   15))
L
Lingrui98 已提交
26
  val TageNTables = TableInfo.size
27
  val UBitPeriod = 2048
L
Lingrui98 已提交
28
  val TageBanks = PredictWidth // FetchWidth
29 30 31 32 33 34
  val TageCtrBits = 3
  val SCHistLens = 0 :: TableInfo.map{ case (_,h,_) => h}.toList
  val SCNTables = 6
  val SCCtrBits = 6
  val SCNRows = 1024
  val SCTableInfo = Seq.fill(SCNTables)((SCNRows, SCCtrBits)) zip SCHistLens map {case ((n, cb), h) => (n, cb, h)}
L
Lingrui98 已提交
35 36
  val TotalBits = TableInfo.map {
    case (s, h, t) => {
37
      s * (1+t+TageCtrBits) * PredictWidth
L
Lingrui98 已提交
38 39
    }
  }.reduce(_+_)
L
Lingrui98 已提交
40 41
}

42
abstract class TageBundle extends XSBundle with HasTageParameter with PredictorUtils
43
abstract class TageModule extends XSModule with HasTageParameter with PredictorUtils { val debug = true }
L
Lingrui98 已提交
44

L
Lingrui98 已提交
45 46 47



L
Lingrui98 已提交
48
class TageReq extends TageBundle {
L
Lingrui98 已提交
49 50 51
  val pc = UInt(VAddrBits.W)
  val hist = UInt(HistoryLength.W)
  val mask = UInt(PredictWidth.W)
L
Lingrui98 已提交
52 53 54
}

class TageResp extends TageBundle {
55
  val ctr = UInt(TageCtrBits.W)
L
Lingrui98 已提交
56
  val u = UInt(2.W)
L
Lingrui98 已提交
57 58 59
}

class TageUpdate extends TageBundle {
L
Lingrui98 已提交
60
  val pc = UInt(VAddrBits.W)
61
  val fetchIdx = UInt(log2Up(TageBanks).W)
L
Lingrui98 已提交
62 63 64 65 66
  val hist = UInt(HistoryLength.W)
  // update tag and ctr
  val mask = Vec(TageBanks, Bool())
  val taken = Vec(TageBanks, Bool())
  val alloc = Vec(TageBanks, Bool())
67
  val oldCtr = Vec(TageBanks, UInt(TageCtrBits.W))
L
Lingrui98 已提交
68 69 70
  // update u
  val uMask = Vec(TageBanks, Bool())
  val u = Vec(TageBanks, UInt(2.W))
L
Lingrui98 已提交
71 72 73
}

class FakeTageTable() extends TageModule {
L
Lingrui98 已提交
74 75 76 77 78 79
  val io = IO(new Bundle() {
    val req = Input(Valid(new TageReq))
    val resp = Output(Vec(TageBanks, Valid(new TageResp)))
    val update = Input(new TageUpdate)
  })
  io.resp := DontCare
L
Lingrui98 已提交
80 81

}
L
Lingrui98 已提交
82
@chiselName
L
Lingrui98 已提交
83
class TageTable(val nRows: Int, val histLen: Int, val tagLen: Int, val uBitPeriod: Int) extends TageModule with HasIFUConst {
L
Lingrui98 已提交
84 85 86 87 88
  val io = IO(new Bundle() {
    val req = Input(Valid(new TageReq))
    val resp = Output(Vec(TageBanks, Valid(new TageResp)))
    val update = Input(new TageUpdate)
  })
89
  // override val debug = true
L
Lingrui98 已提交
90
  // bypass entries for tage update
L
Lingrui98 已提交
91
  val wrBypassEntries = 4
L
Lingrui98 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109

  /** Folds the low `histLen` bits of `hist` into a value of at most `l` bits.
    *
    * The history is cut into consecutive `l`-bit chunks starting at bit 0 (the
    * last chunk is narrower when `l` does not divide `histLen`) and all chunks
    * are XOR-ed together.
    *
    * @param hist history vector; only bits `histLen-1 .. 0` are read
    * @param l    chunk width in bits, must be positive
    * @return the XOR of all chunks, or constant zero when this table uses no history
    */
  def compute_folded_hist(hist: UInt, l: Int): UInt = {
    // A zero chunk width would otherwise surface as a bare division-by-zero during elaboration.
    require(l > 0, s"folded history chunk width must be positive, got $l")
    if (histLen > 0) {
      val nChunks = (histLen + l - 1) / l
      val hist_chunks = (0 until nChunks) map {i =>
        hist(min((i+1)*l, histLen)-1, i*l)
      }
      hist_chunks.reduce(_^_)
    } else {
      // No history bits to fold: `reduce` on an empty chunk list would throw.
      0.U
    }
  }

  /** Hashes a pc-derived value with the branch history into a row index and a tag.
    *
    * @param unhashed_idx pc bits used for indexing (callers in this file pass the pc
    *                     shifted right by `instOffsetBits + log2Ceil(TageBanks)`)
    * @param hist         history vector, folded via `compute_folded_hist`
    * @return `(idx, tag)`: `idx` is `log2Ceil(nRows)` bits wide, `tag` is `tagLen` bits wide
    */
  def compute_tag_and_hash(unhashed_idx: UInt, hist: UInt) = {
    val idxBits = log2Ceil(nRows)

    // Row index: low pc bits XOR history folded to the index width.
    val idx_history = compute_folded_hist(hist, idxBits)
    val idx = (unhashed_idx ^ idx_history)(idxBits-1,0)

    // Tag: the pc bits above the index field XOR history folded to the tag width.
    val tag_history = compute_folded_hist(hist, tagLen)
    val tag = ((unhashed_idx >> idxBits) ^ tag_history)(tagLen-1,0)

    (idx, tag)
  }

110
  /** Next value of a `TageCtrBits`-wide prediction counter for the given outcome.
    *
    * Delegates to `satUpdate`, which is defined outside this file
    * (presumably a saturating up/down step -- confirm against its definition).
    */
  def inc_ctr(ctr: UInt, taken: Bool): UInt = {
    satUpdate(ctr, TageCtrBits, taken)
  }
L
Lingrui98 已提交
111

L
Lingrui98 已提交
112 113 114 115 116 117 118 119
  // Post-reset initialisation sequencer.
  // `doing_reset` starts high and `reset_idx` starts at row 0; while `doing_reset`
  // is set, the write ports of this table use `reset_idx` as the address and write
  // zeroed data, so the rows are cleared one per cycle.
  val doing_reset = RegInit(true.B)
  val reset_idx = RegInit(0.U(log2Ceil(nRows).W))
  // Adding the Bool advances the index only while the sweep is active.
  reset_idx := reset_idx + doing_reset
  // The last row is being written in this cycle; the sweep ends on the next one.
  when (reset_idx === (nRows-1).U) { doing_reset := false.B }

  class TageEntry() extends TageBundle {
    val valid = Bool()
    val tag = UInt(tagLen.W)
120
    val ctr = UInt(TageCtrBits.W)
L
Lingrui98 已提交
121 122
  }

123
  val tageEntrySz = instOffsetBits + tagLen + TageCtrBits
L
Lingrui98 已提交
124

125
  val if2_bankAlignedPC = bankAligned(io.req.bits.pc)
L
Lingrui98 已提交
126
  // this bank means cache bank
127
  val if2_startsAtOddBank = bankInGroup(if2_bankAlignedPC)(0)
L
Lingrui98 已提交
128
  // use real address to index
129
  val if2_unhashed_idx = Wire(Vec(2, UInt((log2Ceil(nRows)+tagLen).W)))
L
Lingrui98 已提交
130
  // the first bank idx always corresponds to the request pc
131
  if2_unhashed_idx(0) := io.req.bits.pc >> (instOffsetBits+log2Ceil(TageBanks))
L
Lingrui98 已提交
132
  // when pc is at odd bank, the second bank is at the next idx
133
  if2_unhashed_idx(1) := if2_unhashed_idx(0) + if2_startsAtOddBank
L
Lingrui98 已提交
134

135 136 137 138 139
  // val idxes_and_tags = (0 until TageBanks).map(b => compute_tag_and_hash(if2_unhashed_idxes(b.U), io.req.bits.hist))
  // val (idx, tag) = compute_tag_and_hash(if2_unhashed_idx, io.req.bits.hist)
  val if2_idxes_and_tags = if2_unhashed_idx.map(compute_tag_and_hash(_, io.req.bits.hist))
  // val idxes = VecInit(if2_idxes_and_tags.map(_._1))
  // val tags = VecInit(if2_idxes_and_tags.map(_._2))
L
Lingrui98 已提交
140

141 142 143
  val if3_idxes = RegEnable(VecInit(if2_idxes_and_tags.map(_._1)), io.req.valid)
  val if3_tags = RegEnable(VecInit(if2_idxes_and_tags.map(_._2)), io.req.valid)
  // and_if3_tags = RegEnable(if2_idxes_and_tags, enable=io.req.valid)
L
Lingrui98 已提交
144 145 146

  // val idxLatch = RegEnable(idx, enable=io.req.valid)
  // val tagLatch = RegEnable(tag, enable=io.req.valid)
L
Lingrui98 已提交
147

L
Lingrui98 已提交
148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  /** One bank of single-bit storage with `nRows` rows, used for the hi/lo usefulness bits.
    *
    * It has one read port and one write port. The read value is captured in a
    * register that only updates while the read request is valid, so the data
    * is available the cycle after the request.
    */
  class HL_Bank (val nRows: Int = nRows) extends TageModule {
    val io = IO(new Bundle {
      // Read side: row address in, one bit out.
      val r = new Bundle {
        val req = Flipped(ValidIO(new Bundle {
          val setIdx = UInt(log2Ceil(nRows).W)
        }))
        val resp = new Bundle {
          val data = Output(Bool())
        }
      }
      // Write side: row address plus the bit to store.
      val w = new Bundle {
        val req = Flipped(ValidIO(new Bundle {
          val setIdx = UInt(log2Ceil(nRows).W)
          val data = Bool()
        }))
      }
    })

    val mem = Mem(nRows, Bool())

    // Write port: store the bit whenever a write request is presented.
    when (io.w.req.valid) {
      mem.write(io.w.req.bits.setIdx, io.w.req.bits.data)
    }

    // Read port: registering the combinational read gives the
    // 1-cycle latency of a SyncReadMem.
    io.r.resp.data := RegEnable(mem.read(io.r.req.bits.setIdx), io.r.req.valid)
  }

  val hi_us = List.fill(TageBanks)(Module(new HL_Bank(nRows)))
  val lo_us = List.fill(TageBanks)(Module(new HL_Bank(nRows)))
L
Lingrui98 已提交
176 177
  val table = List.fill(TageBanks)(Module(new SRAMTemplate(new TageEntry, set=nRows, shouldReset=false, holdRead=true, singlePort=false)))

178 179 180
  val if3_hi_us_r = WireInit(0.U.asTypeOf(Vec(TageBanks, Bool())))
  val if3_lo_us_r = WireInit(0.U.asTypeOf(Vec(TageBanks, Bool())))
  val if3_table_r = WireInit(0.U.asTypeOf(Vec(TageBanks, new TageEntry)))
L
Lingrui98 已提交
181

182
  val if2_baseBank = io.req.bits.pc(log2Up(TageBanks), instOffsetBits)
183
  val if3_baseBank = RegEnable(if2_baseBank, enable=io.req.valid)
L
Lingrui98 已提交
184

185
  val if2_realMask = Mux(if2_startsAtOddBank,
L
Lingrui98 已提交
186 187
                      Cat(io.req.bits.mask(bankWidth-1,0), io.req.bits.mask(PredictWidth-1, bankWidth)),
                      io.req.bits.mask)
188
  val if3_realMask = RegEnable(if2_realMask, enable=io.req.valid)
L
Lingrui98 已提交
189

L
Lingrui98 已提交
190 191


L
Lingrui98 已提交
192 193
  (0 until TageBanks).map(
    b => {
194 195 196 197
      val idxes = VecInit(if2_idxes_and_tags.map(_._1))
      val idx = (if (b < bankWidth) Mux(if2_startsAtOddBank, idxes(1), idxes(0))
                 else Mux(if2_startsAtOddBank, idxes(0), idxes(1)))
      hi_us(b).io.r.req.valid := io.req.valid && if2_realMask(b)
L
Lingrui98 已提交
198 199
      hi_us(b).io.r.req.bits.setIdx := idx

200
      lo_us(b).io.r.req.valid := io.req.valid && if2_realMask(b)
L
Lingrui98 已提交
201 202 203
      lo_us(b).io.r.req.bits.setIdx := idx

      table(b).reset := reset.asBool
204
      table(b).io.r.req.valid := io.req.valid && if2_realMask(b)
L
Lingrui98 已提交
205
      table(b).io.r.req.bits.setIdx := idx
L
Lingrui98 已提交
206

207 208 209
      if3_hi_us_r(b) := hi_us(b).io.r.resp.data
      if3_lo_us_r(b) := lo_us(b).io.r.resp.data
      if3_table_r(b) := table(b).io.r.resp.data(0)
L
Lingrui98 已提交
210 211 212
    }
  )

213
  val if3_startsAtOddBank = RegEnable(if2_startsAtOddBank, io.req.valid)
L
Lingrui98 已提交
214

215 216 217 218 219 220
  val if3_req_rhits = VecInit((0 until TageBanks).map(b => {
    val tag = (if (b < bankWidth) Mux(if3_startsAtOddBank, if3_tags(1), if3_tags(0))
               else Mux(if3_startsAtOddBank, if3_tags(0), if3_tags(1)))
    val bank = (if (b < bankWidth) Mux(if3_startsAtOddBank, (b+bankWidth).U, b.U)
                else Mux(if3_startsAtOddBank, (b-bankWidth).U, b.U))
    if3_table_r(bank).valid && if3_table_r(bank).tag === tag
L
Lingrui98 已提交
221 222
  }))
  
L
Lingrui98 已提交
223
  (0 until TageBanks).map(b => {
224 225 226 227 228
    val bank = (if (b < bankWidth) Mux(if3_startsAtOddBank, (b+bankWidth).U, b.U)
                else Mux(if3_startsAtOddBank, (b-bankWidth).U, b.U))
    io.resp(b).valid := if3_req_rhits(b) && if3_realMask(b)
    io.resp(b).bits.ctr := if3_table_r(bank).ctr
    io.resp(b).bits.u := Cat(if3_hi_us_r(bank),if3_lo_us_r(bank))
L
Lingrui98 已提交
229 230 231 232 233 234 235 236 237 238 239
  })


  // Free-running counter that schedules the periodic clearing of the usefulness bits.
  // Bit layout, from the LSB: log2Ceil(uBitPeriod) "period" bits, then
  // log2Ceil(nRows) row bits, then one top bit selecting the hi or lo array.
  val clear_u_ctr = RegInit(0.U((log2Ceil(uBitPeriod) + log2Ceil(nRows) + 1).W))
  // Held at 1 during the reset sweep, counts up afterwards.
  when (doing_reset) { clear_u_ctr := 1.U } .otherwise { clear_u_ctr := clear_u_ctr + 1.U }

  // A clear is issued in the cycles where the period bits are all zero.
  val doing_clear_u = clear_u_ctr(log2Ceil(uBitPeriod)-1,0) === 0.U
  // The top counter bit decides which of the two u-bit arrays is cleared.
  val doing_clear_u_hi = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 1.U
  val doing_clear_u_lo = doing_clear_u && clear_u_ctr(log2Ceil(uBitPeriod) + log2Ceil(nRows)) === 0.U
  // Row to clear: the counter with the period bits dropped (still carries the hi/lo select bit on top).
  val clear_u_idx = clear_u_ctr >> log2Ceil(uBitPeriod)

240
  // Use fetchpc to compute hash
241
  val (update_idx, update_tag) = compute_tag_and_hash((io.update.pc >> (instOffsetBits + log2Ceil(TageBanks))), io.update.hist)
L
Lingrui98 已提交
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267

  val update_wdata = Wire(Vec(TageBanks, new TageEntry))


  // Per-bank write data for the two usefulness-bit arrays (driven from io.update.u further down).
  val update_hi_wdata = Wire(Vec(TageBanks, Bool()))
  val update_lo_wdata = Wire(Vec(TageBanks, Bool()))

  // Drive the write port of every bank. In each array the reset sweep has the
  // highest priority, then (for the u-bit arrays) the periodic clear, and
  // finally the regular predictor update.
  for (b <- 0 until TageBanks) {
    // Tag / counter array.
    table(b).io.w.req.valid := io.update.mask(b) || doing_reset
    table(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, update_idx)
    table(b).io.w.req.bits.data := Mux(doing_reset, 0.U.asTypeOf(new TageEntry), update_wdata(b))

    // High usefulness bit.
    hi_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_hi
    hi_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_hi, clear_u_idx, update_idx))
    hi_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_hi, 0.U, update_hi_wdata(b))

    // Low usefulness bit.
    lo_us(b).io.w.req.valid := io.update.uMask(b) || doing_reset || doing_clear_u_lo
    lo_us(b).io.w.req.bits.setIdx := Mux(doing_reset, reset_idx, Mux(doing_clear_u_lo, clear_u_idx, update_idx))
    lo_us(b).io.w.req.bits.data := Mux(doing_reset || doing_clear_u_lo, 0.U, update_lo_wdata(b))
  }

  val wrbypass_tags    = Reg(Vec(wrBypassEntries, UInt(tagLen.W)))
  val wrbypass_idxs    = Reg(Vec(wrBypassEntries, UInt(log2Ceil(nRows).W)))
268
  val wrbypass_ctrs    = Reg(Vec(wrBypassEntries, Vec(TageBanks, UInt(TageCtrBits.W))))
L
Lingrui98 已提交
269
  val wrbypass_ctr_valids = Reg(Vec(wrBypassEntries, Vec(TageBanks, Bool())))
L
Lingrui98 已提交
270 271
  val wrbypass_enq_idx = RegInit(0.U(log2Ceil(wrBypassEntries).W))

L
Lingrui98 已提交
272 273
  when (reset.asBool) { wrbypass_ctr_valids.foreach(_.foreach(_ := false.B))}

L
Lingrui98 已提交
274 275 276 277 278
  // One hit flag per write-bypass entry: the entry matches when both its stored tag
  // and its stored row index equal those of the current update. Hits are suppressed
  // during the reset sweep.
  val wrbypass_hits    = VecInit((wrbypass_tags zip wrbypass_idxs) map { case (entryTag, entryIdx) =>
    !doing_reset &&
    entryTag === update_tag &&
    entryIdx === update_idx
  })
279 280


281
  val wrbypass_hit      = wrbypass_hits.reduce(_||_)
L
Lingrui98 已提交
282
  // val wrbypass_rhit     = wrbypass_rhits.reduce(_||_)
283
  val wrbypass_hit_idx  = PriorityEncoder(wrbypass_hits)
L
Lingrui98 已提交
284
  // val wrbypass_rhit_idx = PriorityEncoder(wrbypass_rhits)
285

L
Lingrui98 已提交
286
  // val wrbypass_rctr_hits = VecInit((0 until TageBanks).map( b => wrbypass_ctr_valids(wrbypass_rhit_idx)(b)))
287

L
Lingrui98 已提交
288
  // val rhit_ctrs = RegEnable(wrbypass_ctrs(wrbypass_rhit_idx), wrbypass_rhit)
289

L
Lingrui98 已提交
290 291 292
  // when (RegNext(wrbypass_rhit)) {
  //   for (b <- 0 until TageBanks) {
  //     when (RegNext(wrbypass_rctr_hits(b.U + baseBank))) {
293
  //       io.resp(b).bits.ctr := rhit_ctrs(if3_bankIdxInOrder(b))
L
Lingrui98 已提交
294 295 296
  //     }
  //   }
  // }
297 298


L
Lingrui98 已提交
299
  val updateBank = PriorityEncoder(io.update.mask)
L
Lingrui98 已提交
300 301 302 303 304 305

  for (w <- 0 until TageBanks) {
    update_wdata(w).ctr   := Mux(io.update.alloc(w),
      Mux(io.update.taken(w), 4.U,
                              3.U
      ),
L
Lingrui98 已提交
306 307 308 309
      Mux(wrbypass_hit && wrbypass_ctr_valids(wrbypass_hit_idx)(w),
            inc_ctr(wrbypass_ctrs(wrbypass_hit_idx)(w), io.update.taken(w)),
            inc_ctr(io.update.oldCtr(w), io.update.taken(w))
      )
L
Lingrui98 已提交
310 311 312 313 314 315 316 317 318 319
    )
    update_wdata(w).valid := true.B
    update_wdata(w).tag   := update_tag

    update_hi_wdata(w)    := io.update.u(w)(1)
    update_lo_wdata(w)    := io.update.u(w)(0)
  }

  when (io.update.mask.reduce(_||_)) {
    when (wrbypass_hits.reduce(_||_)) {
L
Lingrui98 已提交
320
      wrbypass_ctrs(wrbypass_hit_idx)(updateBank) := update_wdata(updateBank).ctr
321
      wrbypass_ctr_valids(wrbypass_hit_idx)(updateBank) := true.B
L
Lingrui98 已提交
322
    } .otherwise {
L
Lingrui98 已提交
323
      wrbypass_ctrs(wrbypass_enq_idx)(updateBank) := update_wdata(updateBank).ctr
324
      (0 until TageBanks).foreach(b => wrbypass_ctr_valids(wrbypass_enq_idx)(b) := false.B) // reset valid bits
L
Lingrui98 已提交
325
      wrbypass_ctr_valids(wrbypass_enq_idx)(updateBank) := true.B
L
Lingrui98 已提交
326 327 328 329 330
      wrbypass_tags(wrbypass_enq_idx) := update_tag
      wrbypass_idxs(wrbypass_enq_idx) := update_idx
      wrbypass_enq_idx := (wrbypass_enq_idx + 1.U)(log2Ceil(wrBypassEntries)-1,0)
    }
  }
331

332
  if (BPUDebug && debug) {
L
Lingrui98 已提交
333 334 335
    val u = io.update
    val b = PriorityEncoder(u.mask)
    val ub = PriorityEncoder(u.uMask)
336 337
    val idx = if2_idxes_and_tags.map(_._1)
    val tag = if2_idxes_and_tags.map(_._2)
L
Lingrui98 已提交
338
    XSDebug(io.req.valid, "tableReq: pc=0x%x, hist=%x, idx=(%d,%d), tag=(%x,%x), baseBank=%d, mask=%b, realMask=%b\n",
339
      io.req.bits.pc, io.req.bits.hist, idx(0), idx(1), tag(0), tag(1), if2_baseBank, io.req.bits.mask, if2_realMask)
L
Lingrui98 已提交
340
    for (i <- 0 until TageBanks) {
341 342
      XSDebug(RegNext(io.req.valid) && if3_req_rhits(i), "TageTableResp[%d]: idx=(%d,%d), hit:%d, ctr:%d, u:%d\n",
        i.U, if3_idxes(0), if3_idxes(1), if3_req_rhits(i), io.resp(i).bits.ctr, io.resp(i).bits.u)
L
Lingrui98 已提交
343
    }
344

345 346
    XSDebug(RegNext(io.req.valid), "TageTableResp: hits:%b, maskLatch is %b\n", if3_req_rhits.asUInt, if3_realMask)
    XSDebug(RegNext(io.req.valid) && !if3_req_rhits.reduce(_||_), "TageTableResp: no hits!\n")
L
Lingrui98 已提交
347 348 349 350 351 352 353 354

    XSDebug(io.update.mask.reduce(_||_), "update Table: pc:%x, fetchIdx:%d, hist:%x, bank:%d, taken:%d, alloc:%d, oldCtr:%d\n",
      u.pc, u.fetchIdx, u.hist, b, u.taken(b), u.alloc(b), u.oldCtr(b))
    XSDebug(io.update.mask.reduce(_||_), "update Table: writing tag:%b, ctr%d in idx:%d\n",
      update_wdata(b).tag, update_wdata(b).ctr, update_idx)
    XSDebug(io.update.mask.reduce(_||_), "update u: pc:%x, fetchIdx:%d, hist:%x, bank:%d, writing in u:%b\n",
      u.pc, u.fetchIdx, u.hist, ub, io.update.u(ub))

355 356 357 358 359
    val updateBank = PriorityEncoder(io.update.mask)
    XSDebug(wrbypass_hit && wrbypass_ctr_valids(wrbypass_hit_idx)(updateBank),
      "wrbypass hits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
      wrbypass_hit_idx, update_tag, update_idx, wrbypass_ctrs(wrbypass_hit_idx)(updateBank), updateBank)

L
Lingrui98 已提交
360 361 362 363 364 365 366
    // when (wrbypass_rhit && wrbypass_ctr_valids(wrbypass_rhit_idx).reduce(_||_)) {
    //   for (b <- 0 until TageBanks) {
    //     XSDebug(wrbypass_ctr_valids(wrbypass_rhit_idx)(b),
    //       "wrbypass rhits, wridx:%d, tag:%x, idx:%d, hitctr:%d, bank:%d\n",
    //       wrbypass_rhit_idx, tag, idx, wrbypass_ctrs(wrbypass_rhit_idx)(b), b.U)
    //   }
    // }
367

L
Lingrui98 已提交
368 369 370 371 372 373 374
    // ------------------------------Debug-------------------------------------
    val valids = Reg(Vec(TageBanks, Vec(nRows, Bool())))
    when (reset.asBool) { valids.foreach(b => b.foreach(r => r := false.B)) }
    (0 until TageBanks).map( b => { when (io.update.mask(b)) { valids(b)(update_idx) := true.B }})
    XSDebug("Table usage:------------------------\n")
    (0 until TageBanks).map( b => { XSDebug("Bank(%d): %d out of %d rows are valid\n", b.U, PopCount(valids(b)), nRows.U)})
  }
L
Lingrui98 已提交
375

L
Lingrui98 已提交
376 377
}

378
abstract class BaseTage extends BasePredictor with HasTageParameter {
L
Lingrui98 已提交
379
  class TAGEResp extends Resp {
380 381
    val takens = Vec(PredictWidth, Bool())
    val hits = Vec(PredictWidth, Bool())
L
Lingrui98 已提交
382
  }
383
  class TAGEMeta extends Meta{
L
Lingrui98 已提交
384 385 386 387 388 389
  }
  class FromBIM extends FromOthers {
    val ctrs = Vec(PredictWidth, UInt(2.W))
  }
  class TageIO extends DefaultBasePredictorIO {
    val resp = Output(new TAGEResp)
L
Lingrui98 已提交
390
    val meta = Output(Vec(PredictWidth, new TageMeta))
L
Lingrui98 已提交
391 392 393 394
    val bim = Input(new FromBIM)
    val s3Fire = Input(Bool())
  }

L
Lingrui98 已提交
395
  override val io = IO(new TageIO)
396
}
L
Lingrui98 已提交
397

398
class FakeTage extends BaseTage {
L
Lingrui98 已提交
399 400
  io.resp <> DontCare
  io.meta <> DontCare
L
Lingrui98 已提交
401 402
}

L
Lingrui98 已提交
403
@chiselName
404
class Tage extends BaseTage {
L
Lingrui98 已提交
405 406 407 408

  val tables = TableInfo.map {
    case (nRows, histLen, tagLen) => {
      val t = if(EnableBPD) Module(new TageTable(nRows, histLen, tagLen, UBitPeriod)) else Module(new FakeTageTable)
L
Lingrui98 已提交
409
      t.io.req.valid := io.pc.valid
L
Lingrui98 已提交
410 411 412 413 414 415 416
      t.io.req.bits.pc := io.pc.bits
      t.io.req.bits.hist := io.hist
      t.io.req.bits.mask := io.inMask
      t
    }
  }

417 418
  val scTables = SCTableInfo.map {
    case (nRows, ctrBits, histLen) => {
L
Lingrui98 已提交
419
      val t = if (EnableSC) Module(new SCTable(nRows/TageBanks, ctrBits, histLen)) else Module(new FakeSCTable)
420
      val req = t.io.req
L
Lingrui98 已提交
421
      req.valid := io.pc.valid
422 423 424 425 426 427 428
      req.bits.pc := io.pc.bits
      req.bits.hist := io.hist
      req.bits.mask := io.inMask
      t
    }
  }

L
Lingrui98 已提交
429
  val scThreshold = RegInit(SCThreshold(5))
430 431 432
  val useThreshold = WireInit(scThreshold.thres)
  val updateThreshold = WireInit((useThreshold << 3) + 21.U)

433
  override val debug = true
434

L
Lingrui98 已提交
435
  // Keep the table responses to process in s3
L
Lingrui98 已提交
436 437 438 439 440
  // val if4_resps = RegEnable(VecInit(tables.map(t => t.io.resp)), enable=io.s3Fire)
  // val if4_scResps = RegEnable(VecInit(scTables.map(t => t.io.resp)), enable=io.s3Fire)
  
  val if3_resps = VecInit(tables.map(t => t.io.resp))
  val if3_scResps = VecInit(scTables.map(t => t.io.resp))
L
Lingrui98 已提交
441
  // val flushLatch = RegNext(io.flush)
L
Lingrui98 已提交
442

L
Lingrui98 已提交
443 444
  val if3_bim = RegEnable(io.bim, enable=io.pc.valid) // actually it is s2Fire
  val if4_bim = RegEnable(if3_bim, enable=io.s3Fire)
L
Lingrui98 已提交
445 446 447 448

  val debug_pc_s2 = RegEnable(io.pc.bits, enable=io.pc.valid)
  val debug_pc_s3 = RegEnable(debug_pc_s2, enable=io.s3Fire)

449 450 451
  val debug_hist_s2 = RegEnable(io.hist, enable=io.pc.valid)
  val debug_hist_s3 = RegEnable(debug_hist_s2, enable=io.s3Fire)

L
Lingrui98 已提交
452
  val u = io.update.bits
453
  val updateValid = io.update.valid && !io.update.bits.isReplay
L
Lingrui98 已提交
454
  val updateHist = u.bpuMeta.predHist.asUInt
L
Lingrui98 已提交
455

L
Lingrui98 已提交
456
  val updateIsBr = u.pd.isBr
L
Lingrui98 已提交
457
  val updateMeta = u.bpuMeta.tageMeta
L
Lingrui98 已提交
458
  val updateMisPred = u.isMisPred && updateIsBr
L
Lingrui98 已提交
459 460 461 462 463

  val updateMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(TageBanks, Bool()))))
  val updateUMask = WireInit(0.U.asTypeOf(Vec(TageNTables, Vec(TageBanks, Bool()))))
  val updateTaken = Wire(Vec(TageNTables, Vec(TageBanks, Bool())))
  val updateAlloc = Wire(Vec(TageNTables, Vec(TageBanks, Bool())))
464
  val updateOldCtr = Wire(Vec(TageNTables, Vec(TageBanks, UInt(TageCtrBits.W))))
L
Lingrui98 已提交
465 466 467 468 469 470
  val updateU = Wire(Vec(TageNTables, Vec(TageBanks, UInt(2.W))))
  updateTaken := DontCare
  updateAlloc := DontCare
  updateOldCtr := DontCare
  updateU := DontCare

L
Lingrui98 已提交
471 472 473
  val scUpdateMask = WireInit(0.U.asTypeOf(Vec(SCNTables, Vec(TageBanks, Bool()))))
  val scUpdateTagePred = Wire(Bool())
  val scUpdateTaken = Wire(Bool())
474 475 476 477 478
  val scUpdateOldCtrs = Wire(Vec(SCNTables, SInt(SCCtrBits.W)))
  scUpdateTagePred := DontCare
  scUpdateTaken := DontCare
  scUpdateOldCtrs := DontCare

L
Lingrui98 已提交
479
  val updateSCMeta = u.bpuMeta.tageMeta.scMeta
L
Lingrui98 已提交
480
  val updateTageMisPred = updateMeta.taken =/= u.taken && updateIsBr
481

482
  val updateBank = u.pc(log2Ceil(TageBanks), instOffsetBits)
L
Lingrui98 已提交
483 484 485

  // access tag tables and output meta info
  for (w <- 0 until TageBanks) {
L
Lingrui98 已提交
486 487 488 489 490
    val if3_tageTaken = WireInit(if3_bim.ctrs(w)(1).asBool)
    var if3_altPred = if3_bim.ctrs(w)(1)
    val if3_finalAltPred = WireInit(if3_bim.ctrs(w)(1))
    var if3_provided = false.B
    var if3_provider = 0.U
L
Lingrui98 已提交
491 492

    for (i <- 0 until TageNTables) {
L
Lingrui98 已提交
493 494
      val hit = if3_resps(i)(w).valid
      val ctr = if3_resps(i)(w).bits.ctr
L
Lingrui98 已提交
495
      when (hit) {
L
Lingrui98 已提交
496 497
        if3_tageTaken := Mux(ctr === 3.U || ctr === 4.U, if3_altPred, ctr(2)) // Use altpred on weak taken
        if3_finalAltPred := if3_altPred
L
Lingrui98 已提交
498
      }
L
Lingrui98 已提交
499 500 501
      if3_provided = if3_provided || hit          // Once hit then provide
      if3_provider = Mux(hit, i.U, if3_provider)  // Use the last hit as provider
      if3_altPred = Mux(hit, ctr(2), if3_altPred) // Save current pred as potential altpred
L
Lingrui98 已提交
502
    }
L
Lingrui98 已提交
503 504 505 506 507 508 509
    val if4_provided = RegEnable(if3_provided, io.s3Fire)
    val if4_provider = RegEnable(if3_provider, io.s3Fire)
    val if4_finalAltPred = RegEnable(if3_finalAltPred, io.s3Fire)
    val if4_tageTaken = RegEnable(if3_tageTaken, io.s3Fire)
    val if4_providerU = RegEnable(if3_resps(if3_provider)(w).bits.u, io.s3Fire)
    val if4_providerCtr = RegEnable(if3_resps(if3_provider)(w).bits.ctr, io.s3Fire)

L
Lingrui98 已提交
510
    io.resp.hits(w) := if4_provided
L
Lingrui98 已提交
511
    io.resp.takens(w) := if4_tageTaken
L
Lingrui98 已提交
512 513 514
    io.meta(w).provider.valid := if4_provided
    io.meta(w).provider.bits := if4_provider
    io.meta(w).altDiffers := if4_finalAltPred =/= io.resp.takens(w)
L
Lingrui98 已提交
515 516
    io.meta(w).providerU := if4_providerU
    io.meta(w).providerCtr := if4_providerCtr
L
Lingrui98 已提交
517
    io.meta(w).taken := if4_tageTaken
L
Lingrui98 已提交
518 519 520

    // Create a mask of tables which did not hit our query, which contain useless entries,
    // and which use a longer history than the provider
L
Lingrui98 已提交
521 522
    val allocatableSlots = RegEnable(VecInit(if3_resps.map(r => !r(w).valid && r(w).bits.u === 0.U)).asUInt &
      ~(LowerMask(UIntToOH(if3_provider), TageNTables) & Fill(TageNTables, if3_provided.asUInt)), io.s3Fire
L
Lingrui98 已提交
523 524 525 526 527 528 529 530
    )
    val allocLFSR = LFSR64()(TageNTables - 1, 0)
    val firstEntry = PriorityEncoder(allocatableSlots)
    val maskedEntry = PriorityEncoder(allocatableSlots & allocLFSR)
    val allocEntry = Mux(allocatableSlots(maskedEntry), maskedEntry, firstEntry)
    io.meta(w).allocate.valid := allocatableSlots =/= 0.U
    io.meta(w).allocate.bits := allocEntry

531 532 533 534
    val scMeta = io.meta(w).scMeta
    scMeta := DontCare
    val scTableSums = VecInit(
      (0 to 1) map { i => {
L
Lingrui98 已提交
535
          // val providerCtr = if4_resps(if4_provider)(w).bits.ctr.zext()
536 537 538 539
          // val pvdrCtrCentered = (((providerCtr - 4.S) << 1) + 1.S) << 3
          // sum += pvdrCtrCentered
          if (EnableSC) {
            (0 until SCNTables) map { j => 
L
Lingrui98 已提交
540
              scTables(j).getCenteredValue(RegEnable(if3_scResps(j)(w).ctr(i), io.s3Fire))
541 542 543 544 545 546 547 548
            } reduce (_+_) // TODO: rewrite with adder tree
          }
          else 0.S
        }
      }
    )

    if (EnableSC) {
L
Lingrui98 已提交
549 550 551
      scMeta.tageTaken := if4_tageTaken
      scMeta.scUsed := if4_provided
      scMeta.scPred := if4_tageTaken
L
Lingrui98 已提交
552
      scMeta.sumAbs := 0.U
L
Lingrui98 已提交
553
      when (if4_provided) {
L
Lingrui98 已提交
554
        val providerCtr = if4_providerCtr.zext()
L
Lingrui98 已提交
555
        val pvdrCtrCentered = ((((providerCtr - 4.S) << 1).asSInt + 1.S) << 3).asSInt
L
Lingrui98 已提交
556
        val totalSum = scTableSums(if4_tageTaken.asUInt) + pvdrCtrCentered
L
Lingrui98 已提交
557
        val sumAbs = totalSum.abs().asUInt
558 559
        val sumBelowThreshold = totalSum.abs.asUInt < useThreshold
        val scPred = totalSum >= 0.S
L
Lingrui98 已提交
560
        scMeta.sumAbs := sumAbs
L
Lingrui98 已提交
561
        scMeta.ctrs   := RegEnable(VecInit(if3_scResps.map(r => r(w).ctr(if3_tageTaken.asUInt))), io.s3Fire)
L
Lingrui98 已提交
562
        for (i <- 0 until SCNTables) {
L
Lingrui98 已提交
563
          val if4_scResps = RegEnable(if3_scResps, io.s3Fire)
L
Lingrui98 已提交
564
          XSDebug(RegNext(io.s3Fire), p"SCTable(${i.U})(${w.U}): ctr:(${if4_scResps(i)(w).ctr(0)},${if4_scResps(i)(w).ctr(1)})\n")
L
Lingrui98 已提交
565 566
        }
        XSDebug(RegNext(io.s3Fire), p"SC(${w.U}): pvdCtr(${providerCtr}), pvdCentred(${pvdrCtrCentered}), totalSum(${totalSum}), abs(${sumAbs}) useThres(${useThreshold}), scPred(${scPred})\n")
567 568
        // Use prediction from Statistical Corrector
        when (!sumBelowThreshold) {
L
Lingrui98 已提交
569 570
          XSDebug(RegNext(io.s3Fire), p"SC(${w.U}) overriden pred to ${scPred}\n")
          scMeta.scPred := scPred
571 572 573 574
          io.resp.takens(w) := scPred
        }
      }
    }
L
Lingrui98 已提交
575 576

    val isUpdateTaken = updateValid && updateBank === w.U &&
L
Lingrui98 已提交
577 578
      u.taken && updateIsBr
    when (updateIsBr && updateValid && updateBank === w.U) {
L
Lingrui98 已提交
579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595
      when (updateMeta.provider.valid) {
        val provider = updateMeta.provider.bits

        updateMask(provider)(w) := true.B
        updateUMask(provider)(w) := true.B

        updateU(provider)(w) := Mux(!updateMeta.altDiffers, updateMeta.providerU,
          Mux(updateMisPred, Mux(updateMeta.providerU === 0.U, 0.U, updateMeta.providerU - 1.U),
                              Mux(updateMeta.providerU === 3.U, 3.U, updateMeta.providerU + 1.U))
        )
        updateTaken(provider)(w) := isUpdateTaken
        updateOldCtr(provider)(w) := updateMeta.providerCtr
        updateAlloc(provider)(w) := false.B
      }
    }
  }

L
Lingrui98 已提交
596
  when (updateValid && updateTageMisPred) {
L
Lingrui98 已提交
597 598 599 600
    val idx = updateBank
    val allocate = updateMeta.allocate
    when (allocate.valid) {
      updateMask(allocate.bits)(idx) := true.B
L
Lingrui98 已提交
601
      updateTaken(allocate.bits)(idx) := u.taken
L
Lingrui98 已提交
602 603 604 605 606
      updateAlloc(allocate.bits)(idx) := true.B
      updateUMask(allocate.bits)(idx) := true.B
      updateU(allocate.bits)(idx) := 0.U
    }.otherwise {
      val provider = updateMeta.provider
G
GouLingrui 已提交
607
      val decrMask = Mux(provider.valid, ~LowerMask(UIntToOH(provider.bits), TageNTables), 0.U(TageNTables.W))
L
Lingrui98 已提交
608 609 610 611 612 613 614 615 616
      for (i <- 0 until TageNTables) {
        when (decrMask(i)) {
          updateUMask(i)(idx) := true.B
          updateU(i)(idx) := 0.U
        }
      }
    }
  }

617
  if (EnableSC) {
L
Lingrui98 已提交
618
    when (updateValid && updateSCMeta.scUsed.asBool && updateIsBr) {
619 620
      val scPred = updateSCMeta.scPred
      val tageTaken = updateSCMeta.tageTaken
L
Lingrui98 已提交
621
      val sumAbs = updateSCMeta.sumAbs.asUInt
622 623
      val scOldCtrs = updateSCMeta.ctrs
      when (scPred =/= tageTaken && sumAbs < useThreshold - 2.U) {
L
Lingrui98 已提交
624 625 626
        val newThres = scThreshold.update(scPred =/= u.taken)
        scThreshold := newThres
        XSDebug(p"scThres update: old d${useThreshold} --> new ${newThres.thres}\n")
627 628 629
      }
      when (scPred =/= u.taken || sumAbs < updateThreshold) {
        scUpdateMask.foreach(t => t(updateBank) := true.B)
L
Lingrui98 已提交
630 631
        scUpdateTagePred := tageTaken
        scUpdateTaken := u.taken
632
        (scUpdateOldCtrs zip scOldCtrs).foreach{case (t, c) => t := c}
L
Lingrui98 已提交
633 634
        XSDebug(p"scUpdate: bank(${updateBank}), scPred(${scPred}), tageTaken(${tageTaken}), scSumAbs(${sumAbs}), mispred: sc(${updateMisPred}), tage(${updateTageMisPred})\n")
        XSDebug(p"update: sc: ${updateSCMeta}\n")
635 636 637 638
      }
    }
  }

L
Lingrui98 已提交
639 640 641 642 643 644 645 646 647 648 649
  for (i <- 0 until TageNTables) {
    for (w <- 0 until TageBanks) {
      tables(i).io.update.mask(w) := updateMask(i)(w)
      tables(i).io.update.taken(w) := updateTaken(i)(w)
      tables(i).io.update.alloc(w) := updateAlloc(i)(w)
      tables(i).io.update.oldCtr(w) := updateOldCtr(i)(w)

      tables(i).io.update.uMask(w) := updateUMask(i)(w)
      tables(i).io.update.u(w) := updateU(i)(w)
    }
    // use fetch pc instead of instruction pc
L
Lingrui98 已提交
650 651
    tables(i).io.update.pc := u.pc
    tables(i).io.update.hist := updateHist
L
Lingrui98 已提交
652
    tables(i).io.update.fetchIdx := u.bpuMeta.fetchIdx
L
Lingrui98 已提交
653 654
  }

655
  for (i <- 0 until SCNTables) {
L
Lingrui98 已提交
656 657 658
    scTables(i).io.update.mask := scUpdateMask(i)
    scTables(i).io.update.tagePred := scUpdateTagePred
    scTables(i).io.update.taken    := scUpdateTaken
659 660 661
    scTables(i).io.update.oldCtr   := scUpdateOldCtrs(i)
    scTables(i).io.update.pc := u.pc
    scTables(i).io.update.hist := updateHist
L
Lingrui98 已提交
662
    scTables(i).io.update.fetchIdx := u.bpuMeta.fetchIdx
663 664
  }

L
Lingrui98 已提交
665

L
Lingrui98 已提交
666

667
  if (BPUDebug && debug) {
L
Lingrui98 已提交
668
    val m = updateMeta
L
Lingrui98 已提交
669
    val bri = u.bpuMeta
L
Lingrui98 已提交
670
    val if4_resps = RegEnable(if3_resps, io.s3Fire)
L
Lingrui98 已提交
671 672 673 674 675
    XSDebug(io.pc.valid, "req: pc=0x%x, hist=%x\n", io.pc.bits, io.hist)
    XSDebug(io.s3Fire, "s3Fire:%d, resp: pc=%x, hist=%x\n", io.s3Fire, debug_pc_s2, debug_hist_s2)
    XSDebug(RegNext(io.s3Fire), "s3FireOnLastCycle: resp: pc=%x, hist=%x, hits=%b, takens=%b\n",
      debug_pc_s3, debug_hist_s3, io.resp.hits.asUInt, io.resp.takens.asUInt)
    for (i <- 0 until TageNTables) {
L
Lingrui98 已提交
676
      XSDebug(RegNext(io.s3Fire), "TageTable(%d): valids:%b, resp_ctrs:%b, resp_us:%b\n", i.U, VecInit(if4_resps(i).map(_.valid)).asUInt, Cat(if4_resps(i).map(_.bits.ctr)), Cat(if4_resps(i).map(_.bits.u)))
L
Lingrui98 已提交
677
    }
L
Lingrui98 已提交
678
    XSDebug(io.update.valid, "update: pc=%x, fetchpc=%x, cycle=%d, hist=%x, taken:%d, misPred:%d, bimctr:%d, pvdr(%d):%d, altDiff:%d, pvdrU:%d, pvdrCtr:%d, alloc(%d):%d\n",
679
      u.pc, u.pc - (bri.fetchIdx << instOffsetBits.U), bri.debug_tage_cycle,  updateHist, u.taken, u.isMisPred, bri.bimCtr, m.provider.valid, m.provider.bits, m.altDiffers, m.providerU, m.providerCtr, m.allocate.valid, m.allocate.bits)
L
Lingrui98 已提交
680 681
    XSDebug(io.update.valid && updateIsBr, p"update: sc: ${updateSCMeta}\n")
    XSDebug(true.B, p"scThres: use(${useThreshold}), update(${updateThreshold})\n")
L
Lingrui98 已提交
682
  }
L
Lingrui98 已提交
683
}