RAS.scala 6.2 KB
Newer Older
Z
zoujr 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.frontend

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import xiangshan._

/** One entry of the return-address stack: a return address plus a counter. */
class RASEntry()(implicit p: Parameters) extends XSBundle {
  // Return address held by this entry.
  val retAddr = UInt(VAddrBits.W)
  // Layer of nested call functions (8-bit counter).
  val ctr = UInt(8.W)
}

@chiselName
class RAS(implicit p: Parameters) extends BasePredictor {
  /** Factory for [[RASEntry]] wires. */
  object RASEntry {
    /** Creates a new `RASEntry` wire driven by the given return address and counter.
      *
      * @param retAddr value connected to the entry's `retAddr` field
      * @param ctr     value connected to the entry's `ctr` field
      * @return the freshly created wire
      */
    def apply(retAddr: UInt, ctr: UInt): RASEntry = {
      val entry = Wire(new RASEntry)
      entry.ctr     := ctr
      entry.retAddr := retAddr
      entry
    }
  }

  /** Circular return-address stack with a registered copy of its top entry.
    *
    * Each entry stores a return address and a counter. A push whose address equals the
    * current top address (and whose counter is not saturated) increments the counter of
    * the top entry instead of allocating a new slot; a pop decrements the counter and only
    * moves the pointers back once the counter is 1.
    *
    * When `io.recover_valid` is asserted, the step is computed from the `io.recover_*`
    * inputs instead of the speculative inputs and the current register state.
    *
    * @param rasSize number of entries in the stack; sizes the memory, the pointer widths
    *                and the wrap-around bounds
    */
  @chiselName
  class RASStack(val rasSize: Int) extends XSModule {
    val io = IO(new Bundle {
      // Speculative push / pop requests and the address to push.
      val push_valid = Input(Bool())
      val pop_valid = Input(Bool())
      val spec_new_addr = Input(UInt(VAddrBits.W))

      // Recovery request: saved stack pointer and top entry, plus the push / pop to redo
      // on top of that saved state and the address to push.
      val recover_sp = Input(UInt(log2Up(rasSize).W))
      val recover_top = Input(new RASEntry)
      val recover_valid = Input(Bool())
      val recover_push = Input(Bool())
      val recover_pop = Input(Bool())
      val recover_new_addr = Input(UInt(VAddrBits.W))

      // Current stack pointer and top entry.
      val sp = Output(UInt(log2Up(rasSize).W))
      val top = Output(new RASEntry)
    })

    // Size the storage from the constructor parameter so that it always agrees with the
    // pointer widths and with the bounds used by ptrInc / ptrDec.
    val stack = Mem(rasSize, new RASEntry)
    // Slot a newly allocated entry is written to.
    val sp = RegInit(0.U(log2Up(rasSize).W))
    // Registered copy of the top entry.
    val top = RegInit(0.U.asTypeOf(new RASEntry))
    // Slot that holds the top entry.
    val topPtr = RegInit(0.U(log2Up(rasSize).W))

    /** Increments `ptr`, wrapping from `rasSize - 1` back to 0. */
    def ptrInc(ptr: UInt) = Mux(ptr === (rasSize-1).U, 0.U, ptr + 1.U)
    /** Decrements `ptr`, wrapping from 0 back to `rasSize - 1`. */
    def ptrDec(ptr: UInt) = Mux(ptr === 0.U, (rasSize-1).U, ptr - 1.U)

    // A push needs a new slot when its address differs from the top entry's address or
    // when the top entry's counter is all ones (cannot be incremented further).
    val alloc_new = io.spec_new_addr =/= top.retAddr || top.ctr.andR
    val recover_alloc_new = io.recover_new_addr =/= io.recover_top.retAddr || io.recover_top.ctr.andR

    // TODO: fix overflow and underflow bugs
    /** Applies one push / pop / restore step to `sp`, `topPtr`, `top` and `stack`.
      *
      * Push takes priority over pop. With neither requested, state only changes when
      * `recover` is set, in which case the supplied state is written back unchanged.
      *
      * @param recover      true when this step is a recovery; the supplied starting state
      *                     is then also written to the registers wherever the operation
      *                     itself does not overwrite them
      * @param do_push      perform a push
      * @param do_pop       perform a pop (ignored when `do_push` is set)
      * @param do_alloc_new the push must allocate a new slot instead of incrementing the
      *                     top entry's counter
      * @param do_sp        stack pointer the step starts from
      * @param do_top_ptr   slot of the top entry the step starts from
      * @param do_new_addr  address to push
      * @param do_top       top entry the step starts from
      */
    def update(recover: Bool)(do_push: Bool, do_pop: Bool, do_alloc_new: Bool,
                              do_sp: UInt, do_top_ptr: UInt, do_new_addr: UInt,
                              do_top: RASEntry) = {
      when (do_push) {
        when (do_alloc_new) {
          // Allocate the slot at do_sp; it becomes the new top with a counter of 1.
          sp     := ptrInc(do_sp)
          topPtr := do_sp
          top.retAddr := do_new_addr
          top.ctr := 1.U
          stack.write(do_sp, RASEntry(do_new_addr, 1.U))
        }.otherwise {
          // Same address as the top entry: only the counter grows.
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr + 1.U
          stack.write(do_top_ptr, RASEntry(do_new_addr, do_top.ctr + 1.U))
        }
      }.elsewhen (do_pop) {
        when (do_top.ctr === 1.U) {
          // Last nesting level of the top entry: step back to the entry below it and
          // reload the registered top from the stack.
          sp     := ptrDec(do_sp)
          topPtr := ptrDec(do_top_ptr)
          top := stack.read(ptrDec(do_top_ptr))
        }.otherwise {
          // The top entry still has nested levels left: only the counter shrinks.
          when (recover) {
            sp := do_sp
            topPtr := do_top_ptr
            top.retAddr := do_top.retAddr
          }
          top.ctr := do_top.ctr - 1.U
          stack.write(do_top_ptr, RASEntry(do_top.retAddr, do_top.ctr - 1.U))
        }
      }.otherwise {
        when (recover) {
          // Pure restore: write the supplied state back without modification.
          sp := do_sp
          topPtr := do_top_ptr
          top := do_top
          stack.write(do_top_ptr, do_top)
        }
      }
    }

    update(io.recover_valid)(
      Mux(io.recover_valid, io.recover_push,        io.push_valid),
      Mux(io.recover_valid, io.recover_pop,         io.pop_valid),
      Mux(io.recover_valid, recover_alloc_new,      alloc_new),
      Mux(io.recover_valid, io.recover_sp,          sp),
      // The top slot sits one below the stack pointer. ptrDec wraps at rasSize, so this
      // stays in range for any rasSize (a plain "- 1.U" only wraps at 2^width); the two
      // are identical when rasSize is a power of two.
      Mux(io.recover_valid, ptrDec(io.recover_sp),  topPtr),
      Mux(io.recover_valid, io.recover_new_addr,    io.spec_new_addr),
      Mux(io.recover_valid, io.recover_top,         top))

    io.sp := sp
    io.top := top

  }

  // Speculative return-address stack instance and a short alias for its ports.
  val spec = Module(new RASStack(RasSize))
  val spec_ras = spec.io


  // Push / pop requests; they default to inactive and are driven further below.
  val spec_push = WireInit(false.B)
  val spec_pop = WireInit(false.B)
  // Disabled (commented-out) computation of the push address:
  // val jump_is_first = io.callIdx.bits === 0.U
  // val call_is_last_half = io.isLastHalfRVI && jump_is_first
  // val spec_new_addr = packetAligned(io.pc.bits) + (io.callIdx.bits << instOffsetBits.U) + Mux( (io.isRVC | call_is_last_half) && HasCExtension.B, 2.U, 4.U)
  // Address pushed on a speculative call, derived from the stage-3 PC and the FTB entry's
  // carry / pftAddr fields.
  // NOTE(review): presumably the fall-through address of the S3 fetch block - confirm
  // against getFallThroughAddr.
  val spec_new_addr = getFallThroughAddr(s3_pc,
                                         io.in.bits.resp_in(0).s3.ftb_entry.carry,
                                         io.in.bits.resp_in(0).s3.ftb_entry.pftAddr)
  spec_ras.push_valid := spec_push
  spec_ras.pop_valid  := spec_pop
  spec_ras.spec_new_addr := spec_new_addr
  val spec_top_addr = spec_ras.top.retAddr

  // confirm that the call/ret is the taken cfi
  spec_push := io.s3_fire && io.in.bits.resp_in(0).s3.preds.hit_taken_on_call
  spec_pop  := io.s3_fire && io.in.bits.resp_in(0).s3.preds.hit_taken_on_ret

  // On a speculative pop, the stage-3 target is taken from the top of the stack.
  when (spec_pop) {
    io.out.resp.s3.preds.target := spec_top_addr
  }

  // Export the current stack pointer and top entry with the stage-3 response; the
  // recovery path below consumes rasSp / rasEntry values carried by the redirect.
  io.out.resp.s3.rasSp  := spec_ras.sp
  io.out.resp.s3.rasTop := spec_ras.top

  // The redirect is registered, so recovery acts one cycle after io.redirect.
  val redirect = RegNext(io.redirect)
  val do_recover = redirect.valid
  val recover_cfi = redirect.bits.cfiUpdate

  // A return / call is redone only for redirects whose level field is 0.
  val retMissPred  = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isRet
  val callMissPred = do_recover && redirect.bits.level === 0.U && recover_cfi.pd.isCall
  // when we mispredict a call, we must redo a push operation
  // similarly, when we mispredict a return, we should redo a pop
  spec_ras.recover_valid := do_recover
  spec_ras.recover_push := callMissPred
  spec_ras.recover_pop  := retMissPred

  spec_ras.recover_sp  := recover_cfi.rasSp
  spec_ras.recover_top := recover_cfi.rasEntry
  // Address to push when redoing a call: the redirect PC plus 2 for an RVC instruction,
  // plus 4 otherwise.
  spec_ras.recover_new_addr := recover_cfi.pc + Mux(recover_cfi.pd.isRVC, 2.U, 4.U)

  // TODO: back-up stack for ras
  // use checkpoint to recover RAS
}