// uBTB.scala
package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._

/** Sizing constants mixed into the micro-BTB. */
trait MicroBTBPatameter
{
    /** Number of ways stored per bank. */
    val nWays = 16

    /** Bit width of the signed target offset kept in each entry. */
    val offsetSize = 13

    /** Tag width: the PC bits left after dropping the bank-index bits and bit 0. */
    val tagSize = VAddrBits - (log2Ceil(PredictWidth) + 1)
}

/** Micro branch target buffer (uBTB).
  *
  * Register-based predictor organised as `PredictWidth` banks of `nWays` ways.
  * Each entry holds a tag, a valid bit, a 2-bit saturating direction counter,
  * branch-kind / RVC flags and a signed offset from the branch PC to its target.
  *
  * Read path: every slot of the fetch packet (PC = `io.pc + 2*i`) is looked up in
  * the bank selected by its own PC and compared against its own tag.
  * Update path: the entry selected by `ubtbWriteWay` in the branch's bank is
  * rewritten when a branch/jal update arrives.
  */
class MicroBTB extends BasePredictor
    with MicroBTBPatameter
{
    /** Per-slot prediction result driven on `io.out`. */
    class MicroBTBResp extends resp
    {
        val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
        val takens = Vec(PredictWidth, Bool())
        val notTakens = Vec(PredictWidth, Bool())
        val isRVC = Vec(PredictWidth, Bool())
    }

    /** Prediction-time bookkeeping: the way to write on update and the per-slot hit flags. */
    class MicroBTBPredictMeta extends meta
    {
        val writeWay = UInt(log2Ceil(nWays).W)
        val hits = Vec(PredictWidth, Bool())
    }
    // Must be the predict-meta bundle: writeWay / hits are driven on it below.
    val out_meta = Wire(new MicroBTBPredictMeta)
    override val metaLen = out_meta.asUInt.getWidth

    class MicroBTBIO extends defaultBasePredictorIO
    {
        val out = Output(new MicroBTBResp)
        val meta = Output(new MicroBTBPredictMeta)
    }

    val io = IO(new MicroBTBIO)

    /** Tag of `pc`: everything above the bank-index bits. */
    def getTag(pc: UInt) = pc >> (log2Ceil(PredictWidth) + 1)

    /** Bank index of `pc`: the `log2Ceil(PredictWidth)` bits directly above bit 0. */
    def getBank(pc: UInt) = pc(log2Ceil(PredictWidth), 1)

    /** Next value of a `len`-bit saturating counter.
      *
      * @param old   current counter value
      * @param len   counter width in bits
      * @param taken true to count up, false to count down
      * @return `old` +/- 1, clamped to the range [0, 2^len - 1]
      */
    def satUpdate(old: UInt, len: Int, taken: Bool): UInt = {
        val maxVal = ((1 << len) - 1).U
        val oldSatTaken = old === maxVal
        val oldSatNotTaken = old === 0.U
        Mux(oldSatTaken && taken, maxVal,
            Mux(oldSatNotTaken && !taken, 0.U,
                Mux(taken, old + 1.U, old - 1.U)))
    }

    /** Control state of one entry. */
    class MicroBTBMeta extends XSBundle
    {
        val is_Br = Bool()
        val is_RVC = Bool()
        val valid = Bool()
        val pred = UInt(2.W)
        val tag = UInt(tagSize.W)
    }

    /** Data of one entry: signed distance from the branch PC to its target. */
    class MicroBTBEntry extends XSBundle
    {
        val offset = SInt(offsetSize.W)
    }

    // The reset value has to be cast *before* RegInit; casting the register afterwards
    // only reinterprets a 1-bit register and yields something that cannot be written.
    val uBTBMeta = RegInit(0.U.asTypeOf(Vec(PredictWidth, Vec(nWays, new MicroBTBMeta))))
    val uBTB  = Reg(Vec(PredictWidth, Vec(nWays, new MicroBTBEntry)))

    //uBTB read
    val read_req_tag = getTag(io.pc)
    val read_req_basebank = getBank(io.pc)
    val read_mask = io.inMask

    /** Intermediate lookup result for one slot of the fetch packet. */
    class ReadRespEntry extends XSBundle
    {
        val is_RVC = Bool()
        val target = UInt(VAddrBits.W)
        val valid = Bool()
        val taken = Bool()
        val notTaken = Bool()
    }
    val read_resp = Wire(Vec(PredictWidth, new ReadRespEntry))

    // PC of every slot in fetch order. Bank and tag are derived per slot so that slots
    // which wrap past the last bank are compared against the next row's tag, matching
    // how the update path derives update_bank / update_tag from the branch PC.
    val read_req_pcs = (0 until PredictWidth).map(i => io.pc + (i << 1).U)
    val read_req_tags = read_req_pcs.map(pc => getTag(pc))
    val read_bank_inOrder = VecInit(read_req_pcs.map(pc => getBank(pc)))

    // Per slot: one bit per way, set when that way is valid and its tag matches.
    val read_hit_ohs = (0 until PredictWidth).map { i =>
        VecInit((0 until nWays).map { w =>
            val entry = uBTBMeta(read_bank_inOrder(i))(w)
            entry.valid && entry.tag === read_req_tags(i)
        })
    }

    val read_hit_vec = VecInit(read_hit_ohs.map(oh => oh.asUInt.orR))
    val read_hit_ways = VecInit(read_hit_ohs.map(oh => PriorityEncoder(oh.asUInt)))
    val read_hit = read_hit_vec.asUInt.orR
    // OR the per-slot way vectors (not the 1-bit hit flags) so the encoder sees nWays bits.
    val read_hit_way = PriorityEncoder(read_hit_ohs.map(_.asUInt).reduce(_ | _))

    for (i <- 0 until PredictWidth)
    {
        val meta_resp = uBTBMeta(read_bank_inOrder(i))(read_hit_ways(i))
        val btb_resp = uBTB(read_bank_inOrder(i))(read_hit_ways(i))

        read_resp(i).valid := meta_resp.valid && read_hit_vec(i) && read_mask(i)
        // The counter's MSB gives the predicted direction.
        read_resp(i).taken := read_resp(i).valid && meta_resp.pred(1)
        read_resp(i).notTaken := read_resp(i).valid && !meta_resp.pred(1)
        read_resp(i).target := (io.pc.asSInt + (i << 1).S + btb_resp.offset).asUInt
        read_resp(i).is_RVC := meta_resp.is_RVC

        out_meta.hits(i) := read_hit_vec(i)
    }

    //TODO: way alloc algorithm
    // Placeholder: XOR-fold all stored tags together with the request tag into a way index.
    val alloc_way = {
        val r_metas = Cat(VecInit(uBTBMeta.map(e => VecInit(e.map(_.tag)))).asUInt, read_req_tag(tagSize - 1, 0))
        val l = log2Ceil(nWays)
        val nChunks = (r_metas.getWidth + l - 1) / l
        val chunks = (0 until nChunks).map { i =>
            r_metas(math.min((i + 1) * l, r_metas.getWidth) - 1, i * l)
        }
        chunks.reduce(_ ^ _)
    }
    out_meta.writeWay := Mux(read_hit, read_hit_way, alloc_way)
    io.meta := out_meta

    //response
    //only when hit and instruction valid and entry valid can output data
    // Each slot is driven on its own; slots without a valid response read as zero.
    // NOTE(review): assumes ValidUndirectioned exposes `valid` / `bits` like chisel3.util.Valid - confirm.
    for (i <- 0 until PredictWidth)
    {
        val slot = read_resp(i)
        io.out.targets(i).valid := slot.valid
        io.out.targets(i).bits := Mux(slot.valid, slot.target, 0.U)
        io.out.takens(i) := slot.taken
        io.out.notTakens(i) := slot.notTaken
        io.out.isRVC(i) := slot.valid && slot.is_RVC
    }

    //uBTB update
    //backend should send fetch pc to update
    // NOTE(review): wired to io.update.bits.pc as the original TODO suggested; confirm that this
    // field carries the fetch-packet PC and not the branch PC itself.
    val update_fetch_pc  = Wire(UInt(VAddrBits.W))
    update_fetch_pc := io.update.bits.pc
    val update_idx = io.update.bits.fetchIdx
    val update_br_offset = update_idx << 1
    val update_br_pc = update_fetch_pc + update_br_offset
    val update_write_way = io.update.bits.brInfo.ubtbWriteWay
    val update_hits = io.update.bits.brInfo.ubtbHits
    val update_taken = io.update.bits.taken

    val update_bank = getBank(update_br_pc)
    val update_tag = getTag(update_br_pc)
    val update_taget_offset =  io.update.bits.target.asSInt - update_br_pc.asSInt
    val update_is_BR_or_JAL = (io.update.bits.pd.brType === BrType.branch) || (io.update.bits.pd.brType === BrType.jal)

    // Targets are only written for branch/jal, the same kinds that get a meta entry; otherwise
    // a mispredict of another kind would overwrite the offset of an unrelated entry.
    val uBTB_write_valid = io.update.valid && io.update.bits.isMisPred && update_is_BR_or_JAL
    val uBTB_Meta_write_valid = io.update.valid && update_is_BR_or_JAL
    //write btb target when miss prediction
    when(uBTB_write_valid)
    {
        uBTB(update_bank)(update_write_way).offset := update_taget_offset
    }
    //write the meta
    val update_meta_entry = uBTBMeta(update_bank)(update_write_way)
    when(uBTB_Meta_write_valid)
    {
        //commit update
        update_meta_entry.is_Br := io.update.bits.pd.brType === BrType.branch
        update_meta_entry.is_RVC := io.update.bits.pd.isRVC
        update_meta_entry.valid := true.B
        update_meta_entry.tag := update_tag
        // A newly allocated entry starts fully saturated in the observed direction;
        // an entry that hit at prediction time steps its counter.
        // NOTE(review): hits is filled per fetch slot on the read side but indexed by bank
        // here - confirm which index the backend returns in ubtbHits.
        update_meta_entry.pred :=
            Mux(!update_hits(update_bank),
                Mux(update_taken, 3.U, 0.U),
                satUpdate(update_meta_entry.pred, 2, update_taken)
            )
    }

    //bypass:read-after-write
    //TODO: a read in the same cycle as an update still sees the old entry

}