// Review notes for this patch. They sit ahead of the first `diff --git` header, so they are
// not part of any hunk.
//
// Purpose: regroup the PredChecker response into a stage-1 group (Ibuffer write port) and a
// stage-2 group (Ftq write-back port), with the stage-2 values computed one RegNext later.
//
// PreDecode.scala
//   - PredCheckerResp: the flat fields become stage1Out (fixedRange, fixedTaken) and
//     stage2Out (fixedTarget, fixedMissPred, faultType).
//   - PredChecker: stage1Out is driven directly from fixedRange (no RegNext). The signals that
//     feed stage2Out (fixedRange, instrValid, takenIdx, predTaken, predTarget, jumpTargets,
//     seqTargets, pds, jalFaultVec, retFaultVec, notCFITaken, invalidTaken) are each passed
//     through RegNext, and targetFault is computed from those registered copies.
//
// IFU.scala (class NewIFU)
//   - checkerOut is replaced by checkerOutStage1 / checkerOutStage2.
//   - wb_check_result_stage1 is RegNext(checkerOutStage1); wb_check_result_stage2 reads
//     checkerOutStage2 with no extra RegNext -- presumably because its inputs are already
//     registered inside PredChecker; TODO confirm the two stay cycle-aligned with wb_pd/wb_pc.
//   - hasLastHalf no longer contains the `!fixedMissPred(idx)` term; the earlier form is kept
//     only as a commented-out line. NOTE(review): this changes the condition -- confirm intended.
//   - The local lastHalfMask (false for index 0, true elsewhere) is removed and enqEnable is
//     masked with f3_lastHalf_mask instead. f3_lastHalf_mask is defined outside these hunks --
//     TODO confirm it is equivalent to the removed mask.
//   - misOffset, target, jalTarget and checkFaultType read wb_check_result_stage2; cfiOffset and
//     lastTaken read wb_check_result_stage1.
//
// NOTE(review): the hunks below appear with their line breaks collapsed (each `@@` header
// declares a multi-line hunk); restore the original newlines before applying the patch.
diff --git a/src/main/scala/xiangshan/frontend/IFU.scala b/src/main/scala/xiangshan/frontend/IFU.scala index f2c7222df7e3038d76b2dfa0cc0bc4c677545f80..70ce5339f994ff405588bff5cae32eccd1d2587d 100644 --- a/src/main/scala/xiangshan/frontend/IFU.scala +++ b/src/main/scala/xiangshan/frontend/IFU.scala @@ -121,7 +121,7 @@ class NewIFU(implicit p: Parameters) extends XSModule val predChecker = Module(new PredChecker) val frontendTrigger = Module(new FrontendTrigger) val (preDecoderIn, preDecoderOut) = (preDecoder.io.in, preDecoder.io.out) - val (checkerIn, checkerOut) = (predChecker.io.in, predChecker.io.out) + val (checkerIn, checkerOutStage1, checkerOutStage2) = (predChecker.io.in, predChecker.io.out.stage1Out,predChecker.io.out.stage2Out) io.iTLBInter.req_kill := false.B io.iTLBInter.resp.ready := true.B @@ -519,10 +519,11 @@ class NewIFU(implicit p: Parameters) extends XSModule /*** handle half RVI in the last 2 Bytes ***/ def hasLastHalf(idx: UInt) = { - !f3_pd(idx).isRVC && checkerOut.fixedRange(idx) && f3_instr_valid(idx) && !checkerOut.fixedTaken(idx) && !checkerOut.fixedMissPred(idx) && ! f3_req_is_mmio + //!f3_pd(idx).isRVC && checkerOutStage1.fixedRange(idx) && f3_instr_valid(idx) && !checkerOutStage1.fixedTaken(idx) && !checkerOutStage2.fixedMissPred(idx) && ! f3_req_is_mmio + !f3_pd(idx).isRVC && checkerOutStage1.fixedRange(idx) && f3_instr_valid(idx) && !checkerOutStage1.fixedTaken(idx) && ! 
f3_req_is_mmio } - val f3_last_validIdx = ~ParallelPriorityEncoder(checkerOut.fixedRange.reverse) + val f3_last_validIdx = ~ParallelPriorityEncoder(checkerOutStage1.fixedRange.reverse) val f3_hasLastHalf = hasLastHalf((PredictWidth - 1).U) val f3_false_lastHalf = hasLastHalf(f3_last_validIdx) @@ -554,20 +555,19 @@ class NewIFU(implicit p: Parameters) extends XSModule io.toIbuffer.valid := f3_valid && (!f3_req_is_mmio || f3_mmio_can_go) && !f3_flush io.toIbuffer.bits.instrs := f3_expd_instr io.toIbuffer.bits.valid := f3_instr_valid.asUInt - io.toIbuffer.bits.enqEnable := checkerOut.fixedRange.asUInt & f3_instr_valid.asUInt + io.toIbuffer.bits.enqEnable := checkerOutStage1.fixedRange.asUInt & f3_instr_valid.asUInt io.toIbuffer.bits.pd := f3_pd io.toIbuffer.bits.ftqPtr := f3_ftq_req.ftqIdx io.toIbuffer.bits.pc := f3_pc - io.toIbuffer.bits.ftqOffset.zipWithIndex.map{case(a, i) => a.bits := i.U; a.valid := checkerOut.fixedTaken(i) && !f3_req_is_mmio} + io.toIbuffer.bits.ftqOffset.zipWithIndex.map{case(a, i) => a.bits := i.U; a.valid := checkerOutStage1.fixedTaken(i) && !f3_req_is_mmio} io.toIbuffer.bits.foldpc := f3_foldpc io.toIbuffer.bits.ipf := VecInit(f3_pf_vec.zip(f3_crossPageFault).map{case (pf, crossPF) => pf || crossPF}) io.toIbuffer.bits.acf := f3_af_vec io.toIbuffer.bits.crossPageIPFFix := f3_crossPageFault io.toIbuffer.bits.triggered := f3_triggered - val lastHalfMask = VecInit((0 until PredictWidth).map(i => if(i ==0) false.B else true.B)) when(f3_lastHalf.valid){ - io.toIbuffer.bits.enqEnable := checkerOut.fixedRange.asUInt & f3_instr_valid.asUInt & lastHalfMask.asUInt + io.toIbuffer.bits.enqEnable := checkerOutStage1.fixedRange.asUInt & f3_instr_valid.asUInt & f3_lastHalf_mask io.toIbuffer.bits.valid := f3_lastHalf_mask & f3_instr_valid.asUInt } @@ -634,7 +634,8 @@ class NewIFU(implicit p: Parameters) extends XSModule val wb_valid = RegNext(RegNext(f2_fire && !f2_flush) && !f3_req_is_mmio && !f3_flush) val wb_ftq_req = RegNext(f3_ftq_req) - val 
wb_check_result = RegNext(checkerOut) + val wb_check_result_stage1 = RegNext(checkerOutStage1) + val wb_check_result_stage2 = checkerOutStage2 val wb_instr_range = RegNext(io.toIbuffer.bits.enqEnable) val wb_pc = RegNext(f3_pc) val wb_pd = RegNext(f3_pd) @@ -651,7 +652,7 @@ class NewIFU(implicit p: Parameters) extends XSModule /* false oversize */ val lastIsRVC = wb_instr_range.asTypeOf(Vec(PredictWidth,Bool())).last && wb_pd.last.isRVC val lastIsRVI = wb_instr_range.asTypeOf(Vec(PredictWidth,Bool()))(PredictWidth - 2) && !wb_pd(PredictWidth - 2).isRVC - val lastTaken = wb_check_result.fixedTaken.last + val lastTaken = wb_check_result_stage1.fixedTaken.last f3_wb_not_flush := wb_ftq_req.ftqIdx === f3_ftq_req.ftqIdx && f3_valid && wb_valid @@ -662,12 +663,12 @@ class NewIFU(implicit p: Parameters) extends XSModule checkFlushWb.bits.pd.zipWithIndex.map{case(instr,i) => instr.valid := wb_instr_valid(i)} checkFlushWb.bits.ftqIdx := wb_ftq_req.ftqIdx checkFlushWb.bits.ftqOffset := wb_ftq_req.ftqOffset.bits - checkFlushWb.bits.misOffset.valid := ParallelOR(wb_check_result.fixedMissPred) || wb_half_flush - checkFlushWb.bits.misOffset.bits := Mux(wb_half_flush, wb_lastIdx, ParallelPriorityEncoder(wb_check_result.fixedMissPred)) - checkFlushWb.bits.cfiOffset.valid := ParallelOR(wb_check_result.fixedTaken) - checkFlushWb.bits.cfiOffset.bits := ParallelPriorityEncoder(wb_check_result.fixedTaken) - checkFlushWb.bits.target := Mux(wb_half_flush, wb_half_target, wb_check_result.fixedTarget(ParallelPriorityEncoder(wb_check_result.fixedMissPred))) - checkFlushWb.bits.jalTarget := wb_check_result.fixedTarget(ParallelPriorityEncoder(VecInit(wb_pd.zip(wb_instr_valid).map{case (pd, v) => v && pd.isJal }))) + checkFlushWb.bits.misOffset.valid := ParallelOR(wb_check_result_stage2.fixedMissPred) || wb_half_flush + checkFlushWb.bits.misOffset.bits := Mux(wb_half_flush, wb_lastIdx, ParallelPriorityEncoder(wb_check_result_stage2.fixedMissPred)) + checkFlushWb.bits.cfiOffset.valid := 
ParallelOR(wb_check_result_stage1.fixedTaken) + checkFlushWb.bits.cfiOffset.bits := ParallelPriorityEncoder(wb_check_result_stage1.fixedTaken) + checkFlushWb.bits.target := Mux(wb_half_flush, wb_half_target, wb_check_result_stage2.fixedTarget(ParallelPriorityEncoder(wb_check_result_stage2.fixedMissPred))) + checkFlushWb.bits.jalTarget := wb_check_result_stage2.fixedTarget(ParallelPriorityEncoder(VecInit(wb_pd.zip(wb_instr_valid).map{case (pd, v) => v && pd.isJal }))) checkFlushWb.bits.instrRange := wb_instr_range.asTypeOf(Vec(PredictWidth, Bool())) toFtq.pdWb := Mux(wb_valid, checkFlushWb, mmioFlushWb) @@ -675,7 +676,7 @@ class NewIFU(implicit p: Parameters) extends XSModule wb_redirect := checkFlushWb.bits.misOffset.valid && wb_valid /*write back flush type*/ - val checkFaultType = wb_check_result.faultType + val checkFaultType = wb_check_result_stage2.faultType val checkJalFault = wb_valid && checkFaultType.map(_.isjalFault).reduce(_||_) val checkRetFault = wb_valid && checkFaultType.map(_.isRetFault).reduce(_||_) val checkTargetFault = wb_valid && checkFaultType.map(_.istargetFault).reduce(_||_) diff --git a/src/main/scala/xiangshan/frontend/PreDecode.scala b/src/main/scala/xiangshan/frontend/PreDecode.scala index ab6ca25b89f6d68c78e575cfe35623f9b6513a13..0e9dad1c694da67abdc07e37cef542ae8b962963 100644 --- a/src/main/scala/xiangshan/frontend/PreDecode.scala +++ b/src/main/scala/xiangshan/frontend/PreDecode.scala @@ -192,13 +192,17 @@ class CheckInfo extends Bundle { // 8 bit } class PredCheckerResp(implicit p: Parameters) extends XSBundle with HasPdConst { - //to Ibuffer write port (timing critical) - val fixedRange = Vec(PredictWidth, Bool()) - val fixedTaken = Vec(PredictWidth, Bool()) - //to Ftq write back port (not timing critical) - val fixedTarget = Vec(PredictWidth, UInt(VAddrBits.W)) - val fixedMissPred = Vec(PredictWidth, Bool()) - val faultType = Vec(PredictWidth, new CheckInfo) + //to Ibuffer write port (stage 1) + val stage1Out = new Bundle{ + val 
fixedRange = Vec(PredictWidth, Bool()) + val fixedTaken = Vec(PredictWidth, Bool()) + } + //to Ftq write back port (stage 2) + val stage2Out = new Bundle{ + val fixedTarget = Vec(PredictWidth, UInt(VAddrBits.W)) + val fixedMissPred = Vec(PredictWidth, Bool()) + val faultType = Vec(PredictWidth, new CheckInfo) + } } @@ -220,6 +224,7 @@ class PredChecker(implicit p: Parameters) extends XSModule with HasPdConst { * we first detecct remask fault and then use fixedRange to do second check **/ + //Stage 1: detect remask fault /** first check: remask Fault */ jalFaultVec := VecInit(pds.zipWithIndex.map{case(pd, i) => pd.isJal && instrRange(i) && instrValid(i) && (takenIdx > i.U && predTaken || !predTaken) }) retFaultVec := VecInit(pds.zipWithIndex.map{case(pd, i) => pd.isRet && instrRange(i) && instrValid(i) && (takenIdx > i.U && predTaken || !predTaken) }) @@ -228,28 +233,43 @@ class PredChecker(implicit p: Parameters) extends XSModule with HasPdConst { val needRemask = ParallelOR(remaskFault) val fixedRange = instrRange.asUInt & (Fill(PredictWidth, !needRemask) | Fill(PredictWidth, 1.U(1.W)) >> ~remaskIdx) - io.out.fixedRange := fixedRange.asTypeOf((Vec(PredictWidth, Bool()))) + io.out.stage1Out.fixedRange := fixedRange.asTypeOf((Vec(PredictWidth, Bool()))) - io.out.fixedTaken := VecInit(pds.zipWithIndex.map{case(pd, i) => instrValid (i) && fixedRange(i) && (pd.isRet || pd.isJal || takenIdx === i.U && predTaken && !pd.notCFI) }) + io.out.stage1Out.fixedTaken := VecInit(pds.zipWithIndex.map{case(pd, i) => instrValid (i) && fixedRange(i) && (pd.isRet || pd.isJal || takenIdx === i.U && predTaken && !pd.notCFI) }) /** second check: faulse prediction fault and target fault */ notCFITaken := VecInit(pds.zipWithIndex.map{case(pd, i) => fixedRange(i) && instrValid(i) && i.U === takenIdx && pd.notCFI && predTaken }) invalidTaken := VecInit(pds.zipWithIndex.map{case(pd, i) => fixedRange(i) && !instrValid(i) && i.U === takenIdx && predTaken }) - /** target calculation */ val 
jumpTargets = VecInit(pds.zipWithIndex.map{case(pd,i) => pc(i) + jumpOffset(i)}) - targetFault := VecInit(pds.zipWithIndex.map{case(pd,i) => fixedRange(i) && instrValid(i) && (pd.isJal || pd.isBr) && takenIdx === i.U && predTaken && (predTarget =/= jumpTargets(i))}) - val seqTargets = VecInit((0 until PredictWidth).map(i => pc(i) + Mux(pds(i).isRVC || !instrValid(i), 2.U, 4.U ) )) - io.out.faultType.zipWithIndex.map{case(faultType, i) => faultType.value := Mux(jalFaultVec(i) , FaultType.jalFault , - Mux(retFaultVec(i), FaultType.retFault , + //Stage 2: detect target fault + /** target calculation: in the next stage */ + val fixedRangeNext = RegNext(fixedRange) + val instrValidNext = RegNext(instrValid) + val takenIdxNext = RegNext(takenIdx) + val predTakenNext = RegNext(predTaken) + val predTargetNext = RegNext(predTarget) + val jumpTargetsNext = RegNext(jumpTargets) + val seqTargetsNext = RegNext(seqTargets) + val pdsNext = RegNext(pds) + val jalFaultVecNext = RegNext(jalFaultVec) + val retFaultVecNext = RegNext(retFaultVec) + val notCFITakenNext = RegNext(notCFITaken) + val invalidTakenNext = RegNext(invalidTaken) + + targetFault := VecInit(pdsNext.zipWithIndex.map{case(pd,i) => fixedRangeNext(i) && instrValidNext(i) && (pd.isJal || pd.isBr) && takenIdxNext === i.U && predTakenNext && (predTargetNext =/= jumpTargetsNext(i))}) + + + io.out.stage2Out.faultType.zipWithIndex.map{case(faultType, i) => faultType.value := Mux(jalFaultVecNext(i) , FaultType.jalFault , + Mux(retFaultVecNext(i), FaultType.retFault , Mux(targetFault(i), FaultType.targetFault , - Mux(notCFITaken(i) , FaultType.notCFIFault, - Mux(invalidTaken(i), FaultType.invalidTaken, FaultType.noFault)))))} + Mux(notCFITakenNext(i) , FaultType.notCFIFault, + Mux(invalidTakenNext(i), FaultType.invalidTaken, FaultType.noFault)))))} - io.out.fixedMissPred.zipWithIndex.map{case(missPred, i ) => missPred := jalFaultVec(i) || retFaultVec(i) || notCFITaken(i) || invalidTaken(i) || targetFault(i)} - 
io.out.fixedTarget.zipWithIndex.map{case(target, i) => target := Mux(jalFaultVec(i) || targetFault(i), jumpTargets(i), seqTargets(i) )} + io.out.stage2Out.fixedMissPred.zipWithIndex.map{case(missPred, i ) => missPred := jalFaultVecNext(i) || retFaultVecNext(i) || notCFITakenNext(i) || invalidTakenNext(i) || targetFault(i)} + io.out.stage2Out.fixedTarget.zipWithIndex.map{case(target, i) => target := Mux(jalFaultVecNext(i) || targetFault(i), jumpTargetsNext(i), seqTargetsNext(i) )} }