// uBTB.scala
package xiangshan.frontend

import chisel3._
import chisel3.util._
import xiangshan._
import xiangshan.backend.ALUOpType
import utils._

/** Sizing parameters for [[MicroBTB]].
  *
  * `tagSize` needs `VAddrBits` and `PredictWidth`, which are not declared in this trait.
  * The self-type makes them visible. NOTE(review): this assumes they are members of
  * `HasXSParameter`, which `MicroBTB` is expected to mix in through `BasePredictor`.
  */
trait MicroBTBPatameter { this: HasXSParameter =>
    /** Number of ways in each bank. */
    val nWays = 16
    /** Width in bits of the signed PC-relative target offset stored per entry. */
    val offsetSize = 13
    /** Tag width: a virtual address minus the in-block bank index and the halfword bit. */
    val tagSize = VAddrBits - log2Ceil(PredictWidth) - 1
}

/** Micro branch target buffer (uBTB) for conditional branches and JALs.
  *
  * Storage is `PredictWidth` banks (one per 2-byte slot of an aligned fetch block) by
  * `nWays` ways. Each way holds a [[MicroBTBMeta]] (tag, valid, 2-bit direction counter,
  * branch / RVC flags) and a [[MicroBTBEntry]] (signed PC-relative target offset).
  *
  * A fetch whose PC falls in slot `basebank` of its block reads slots `basebank`,
  * `basebank + 1`, ... with wrap-around. Slots whose bank index is below `basebank` belong
  * to the next block and are looked up with `tag + 1`. Every per-slot vector in `io.out`
  * and in the prediction meta is indexed by the slot's position in the fetch packet
  * (position 0 is the instruction at the fetch PC).
  */
class MicroBTB extends BasePredictor
    with MicroBTBPatameter
{
    /** Per-slot prediction driven on `io.out`. */
    class MicroBTBResp extends resp
    {
        val targets = Vec(PredictWidth, ValidUndirectioned(UInt(VAddrBits.W)))
        val takens = Vec(PredictWidth, Bool())
        val notTakens = Vec(PredictWidth, Bool())
        val isRVC = Vec(PredictWidth, Bool())
    }

    /** Prediction-time state exported on `io.meta`.
      * NOTE(review): the update path reads it back as `io.update.bits.brInfo.ubtbWriteWay` /
      * `ubtbHits`; that plumbing lives outside this file.
      */
    class MicroBTBPredictMeta extends meta
    {
        // way the update should write: the hitting way, or the allocated victim on a miss
        val writeWay = UInt(log2Ceil(nWays).W)
        // per-slot hit, indexed by position in the fetch packet
        val hits = Vec(PredictWidth, Bool())
    }
    // This must be the prediction meta (the bundle with `hits` / `writeWay`),
    // not the storage bundle MicroBTBMeta.
    val out_meta = Wire(new MicroBTBPredictMeta)
    override val metaLen = out_meta.getWidth

    class MicroBTBIO extends defaultBasePredictorIO
    {
        val out = Output(new MicroBTBResp)
        val meta = Output(new MicroBTBPredictMeta)
    }

    val io = IO(new MicroBTBIO)
    io.meta := out_meta

    /** Aligned-block number of `pc`. The static shift makes the result exactly `tagSize`
      * bits wide when `pc` is `VAddrBits` wide.
      */
    def getTag(pc: UInt) = pc >> (log2Ceil(PredictWidth) + 1)
    /** Bank (2-byte slot) index of `pc` within its aligned block. */
    def getBank(pc: UInt) = pc(log2Ceil(PredictWidth), 1)
    /** `len`-bit saturating counter: +1 on `taken`, -1 otherwise, clamped at both ends. */
    def satUpdate(old: UInt, len: Int, taken: Bool): UInt = {
        val oldSatTaken = old === ((1 << len) - 1).U
        val oldSatNotTaken = old === 0.U
        Mux(oldSatTaken && taken, ((1 << len) - 1).U,
            Mux(oldSatNotTaken && !taken, 0.U,
            Mux(taken, old + 1.U, old - 1.U)))
    }

    /** Per-way bookkeeping kept for every bank. */
    class MicroBTBMeta extends XSBundle
    {
        val is_Br = Bool()
        val is_RVC = Bool()
        val valid = Bool()
        val pred = UInt(2.W)   // 2-bit direction counter; MSB set = predict taken
        val tag = UInt(tagSize.W)
    }

    /** Per-way target storage: offset of the target relative to the branch PC. */
    class MicroBTBEntry extends XSBundle
    {
        val offset = SInt(offsetSize.W)
    }

    // The zero value is converted to the Vec type *before* RegInit. Applying asTypeOf to
    // a register yields a new node that writes cannot reach.
    val uBTBMeta = RegInit((0.U).asTypeOf(Vec(PredictWidth, Vec(nWays, new MicroBTBMeta))))
    val uBTB = Reg(Vec(PredictWidth, Vec(nWays, new MicroBTBEntry)))

    // ---------------------------------------------------------------- read
    // NOTE(review): assumes io.pc is the fetch PC as a plain VAddrBits-wide UInt.
    // Earlier code used both io.pc and io.in.pc; confirm against defaultBasePredictorIO.
    val read_pc = io.pc
    val read_req_tag = getTag(read_pc)
    val read_req_basebank = getBank(read_pc)
    val read_mask = io.inMask

    class ReadRespEntry extends XSBundle
    {
        val is_RVC = Bool()
        val target = UInt(VAddrBits.W)
        val valid = Bool()
        val taken = Bool()
        val notTaken = Bool()
    }
    val read_resp = Wire(Vec(PredictWidth, new ReadRespEntry))

    // Physical bank read by each in-order position. The wrap-around keeps only the
    // low log2Ceil(PredictWidth) bits.
    val read_bank_inOrder = VecInit((0 until PredictWidth).map { j =>
        (read_req_basebank + j.U)(log2Ceil(PredictWidth) - 1, 0)
    })
    // Banks numbered below the base bank are reached by wrapping into the next block.
    val isInNextRow = VecInit((0 until PredictWidth).map(_.U < read_req_basebank))
    // Block tag each in-order position must match.
    val read_slot_tags = VecInit(read_bank_inOrder.map { b =>
        Mux(isInNextRow(b), read_req_tag + 1.U, read_req_tag)
    })

    // Per position: one-hot of the ways whose entry is valid and whose tag matches. The valid
    // bit is included so a stale/reset tag cannot shadow a valid way in the priority encoder.
    val read_hit_ohs = (0 until PredictWidth).map { j =>
        val b = read_bank_inOrder(j)
        VecInit((0 until nWays).map { w =>
            uBTBMeta(b)(w).valid && uBTBMeta(b)(w).tag === read_slot_tags(j)
        })
    }
    val read_hit_vec = read_hit_ohs.map(oh => ParallelOR(oh))
    val read_hit_ways = read_hit_ohs.map(oh => PriorityEncoder(oh))
    val read_hit = ParallelOR(read_hit_vec)
    // Lowest way that hits in any position. The per-position one-hots are ORed,
    // not the 1-bit hit flags.
    val read_hit_way = PriorityEncoder(ParallelOR(read_hit_ohs.map(_.asUInt)))

    for (j <- 0 until PredictWidth) {
        val bank = read_bank_inOrder(j)
        val meta_resp = uBTBMeta(bank)(read_hit_ways(j))
        val btb_resp = uBTB(bank)(read_hit_ways(j))
        // read_hit_vec already implies meta_resp.valid (see read_hit_ohs)
        val slot_valid = read_hit_vec(j) && read_mask(j)
        read_resp(j).valid := slot_valid
        read_resp(j).taken := slot_valid && meta_resp.pred(1)
        read_resp(j).notTaken := slot_valid && !meta_resp.pred(1)
        read_resp(j).is_RVC := meta_resp.is_RVC
        // instruction at position j sits at fetch PC + 2*j
        read_resp(j).target := (read_pc.asSInt + (j << 1).S + btb_resp.offset).asUInt
        out_meta.hits(j) := read_hit_vec(j)
    }

    // TODO: way alloc algorithm
    // Pseudo-random victim: XOR-fold every stored tag together with the request tag down to
    // log2Ceil(nWays) bits.
    val alloc_way = {
        val r_metas = Cat(VecInit(uBTBMeta.map(bank => VecInit(bank.map(_.tag)))).asUInt, read_req_tag)
        val l = log2Ceil(nWays)
        val nChunks = (r_metas.getWidth + l - 1) / l
        val chunks = (0 until nChunks).map { i =>
            r_metas(scala.math.min((i + 1) * l, r_metas.getWidth) - 1, i * l)
        }
        chunks.reduce(_ ^ _)
    }
    // NOTE(review): one write way is recorded per fetch packet, but each bank can hit in a
    // different way; an update to a slot whose own hit way differs writes the wrong way.
    out_meta.writeWay := Mux(read_hit, read_hit_way, alloc_way)

    // ---------------------------------------------------------------- response
    // Only a hitting, unmasked position drives a prediction. Every other position is driven
    // to zero individually, so one miss cannot clear the outputs of other positions.
    for (j <- 0 until PredictWidth) {
        val r = read_resp(j)
        io.out.targets(j).valid := r.valid
        io.out.targets(j).bits := Mux(r.valid, r.target, 0.U)
        io.out.takens(j) := r.taken
        io.out.notTakens(j) := r.notTaken
        io.out.isRVC(j) := r.valid && r.is_RVC
    }

    // ---------------------------------------------------------------- update
    // TODO(review): fetchIdx is added to this PC to locate the branch, so it must be the PC of
    // the fetch packet ("backend should send fetch pc to update"). Confirm that
    // io.update.bits.pc carries the fetch-packet PC and not the branch PC.
    val update_fetch_pc = io.update.bits.pc
    val update_idx = io.update.bits.fetchIdx
    val update_br_offset = update_idx << 1
    val update_br_pc = update_fetch_pc + update_br_offset
    val update_write_way = io.update.bits.brInfo.ubtbWriteWay
    val update_hits = io.update.bits.brInfo.ubtbHits
    val update_taken = io.update.bits.taken

    val update_bank = getBank(update_br_pc)
    val update_tag = getTag(update_br_pc)
    val update_target_offset = io.update.bits.target.asSInt - update_br_pc.asSInt
    val update_is_BR_or_JAL = (io.update.bits.pd.brType === BrType.branch) || (io.update.bits.pd.brType === BrType.jal)
    // hits were recorded by position in the fetch packet, so look them up by position
    val update_hit = update_hits(update_idx)

    // The meta and the offset of an entry are written together, and only for branches/JALs.
    // A fresh allocation always writes its offset; a hit rewrites it on mispredict.
    val uBTB_Meta_write_valid = io.update.valid && update_is_BR_or_JAL
    val uBTB_write_valid = uBTB_Meta_write_valid && (io.update.bits.isMisPred || !update_hit)

    when (uBTB_write_valid) {
        // NOTE(review): targets farther than the offsetSize-bit signed range are stored truncated.
        uBTB(update_bank)(update_write_way).offset := update_target_offset(offsetSize - 1, 0).asSInt
    }
    when (uBTB_Meta_write_valid) {
        val entry = uBTBMeta(update_bank)(update_write_way)
        entry.is_Br := io.update.bits.pd.brType === BrType.branch
        entry.is_RVC := io.update.bits.pd.isRVC
        entry.valid := true.B
        entry.tag := update_tag
        // new entry: start saturated in the observed direction; existing entry: train it
        entry.pred := Mux(!update_hit,
            Mux(update_taken, 3.U, 0.U),
            satUpdate(entry.pred, 2, update_taken))
    }

    // ---------------------------------------------------------------- bypass (read-after-write)
    // When this cycle's update writes the same slot (bank and block tag) that a position is
    // reading, forward the update's outcome. The bank is the physical bank of the position.
    // The tag compared is the one the read expects for that bank. These connections come
    // after the response loop so they take precedence.
    val rawBypassHit = Wire(Vec(PredictWidth, Bool()))
    for (j <- 0 until PredictWidth) {
        rawBypassHit(j) := uBTB_Meta_write_valid && read_hit_vec(j) && read_mask(j) &&
            update_bank === read_bank_inOrder(j) && update_tag === read_slot_tags(j)
        when (rawBypassHit(j)) {
            io.out.targets(j).valid := true.B
            io.out.targets(j).bits := io.update.bits.target
            io.out.takens(j) := update_taken
            io.out.isRVC(j) := io.update.bits.pd.isRVC
            io.out.notTakens(j) := (io.update.bits.pd.brType === BrType.branch) && !update_taken
        }
    }
}