From 967327d825d0045c06f42e2312a7e86c2e6b6662 Mon Sep 17 00:00:00 2001
From: LinJiawei
Date: Sat, 15 Oct 2022 09:25:18 +0800
Subject: [PATCH] sms: prefetch to l1

---
 src/main/scala/xiangshan/XSCore.scala         |   6 -
 .../scala/xiangshan/backend/MemBlock.scala    |  15 +-
 .../mem/prefetch/BasePrefecher.scala          |   3 +-
 .../mem/prefetch/SMSPrefetcher.scala          | 378 +++++++++++-------
 4 files changed, 246 insertions(+), 156 deletions(-)

diff --git a/src/main/scala/xiangshan/XSCore.scala b/src/main/scala/xiangshan/XSCore.scala
index 6b74278fb..4683dcf80 100644
--- a/src/main/scala/xiangshan/XSCore.scala
+++ b/src/main/scala/xiangshan/XSCore.scala
@@ -366,12 +366,6 @@ class XSCoreImp(outer: XSCoreBase) extends LazyModuleImp(outer)
   XSPerfHistogram("fastIn_count", PopCount(allFastUop1.map(_.valid)), true.B, 0, allFastUop1.length, 1)
   XSPerfHistogram("wakeup_count", PopCount(rfWriteback.map(_.valid)), true.B, 0, rfWriteback.length, 1)
 
-  // l1 prefetch fuzzer, for debug only
-  val debug_l1PrefetchFuzzer = Module(new L1PrefetchFuzzer)
-  debug_l1PrefetchFuzzer.io.req <> memBlock.io.prefetch_req
-  debug_l1PrefetchFuzzer.io.vaddr := memBlock.io.writeback(0).bits.debug.vaddr
-  debug_l1PrefetchFuzzer.io.paddr := memBlock.io.writeback(0).bits.debug.paddr
-
   ctrlBlock.perfinfo.perfEventsEu0 := exuBlocks(0).getPerf.dropRight(outer.exuBlocks(0).scheduler.numRs)
   ctrlBlock.perfinfo.perfEventsEu1 := exuBlocks(1).getPerf.dropRight(outer.exuBlocks(1).scheduler.numRs)
   if (!coreParams.softPTW) {
diff --git a/src/main/scala/xiangshan/backend/MemBlock.scala b/src/main/scala/xiangshan/backend/MemBlock.scala
index bac2fc8c5..a98737e2d 100644
--- a/src/main/scala/xiangshan/backend/MemBlock.scala
+++ b/src/main/scala/xiangshan/backend/MemBlock.scala
@@ -140,6 +140,7 @@ class MemBlockImp(outer: MemBlock) extends LazyModuleImp(outer)
   val stdExeUnits = Seq.fill(exuParameters.StuCnt)(Module(new StdExeUnit))
   val stData = stdExeUnits.map(_.io.out)
   val exeUnits = loadUnits ++ storeUnits
+  val l1_pf_req = Wire(Decoupled(new L1PrefetchReq()))
   val prefetcherOpt: Option[BasePrefecher] = coreParams.prefetcher.map {
     case _: SMSParams =>
       val sms = Module(new SMSPrefetcher())
@@ -156,6 +157,12 @@ class MemBlockImp(outer: MemBlock) extends LazyModuleImp(outer)
     outer.pf_sender_opt.get.out.head._1.l2_pf_en := RegNextN(io.csrCtrl.l2_pf_enable, 2, Some(true.B))
     pf.io.enable := RegNextN(io.csrCtrl.l1D_pf_enable, 2, Some(false.B))
   })
+  prefetcherOpt match {
+    case Some(pf) => l1_pf_req <> pf.io.l1_req
+    case None =>
+      l1_pf_req.valid := false.B
+      l1_pf_req.bits := DontCare
+  }
   val pf_train_on_hit = RegNextN(io.csrCtrl.l1D_pf_train_on_hit, 2, Some(true.B))
 
   loadUnits.zipWithIndex.map(x => x._1.suggestName("LoadUnit_"+x._2))
@@ -186,14 +193,14 @@ class MemBlockImp(outer: MemBlock) extends LazyModuleImp(outer)
   val stOut = io.writeback.drop(exuParameters.LduCnt).dropRight(exuParameters.StuCnt)
 
   // prefetch to l1 req
-  loadUnits.map(load_unit => {
-    load_unit.io.prefetch_req.valid <> io.prefetch_req.valid
-    load_unit.io.prefetch_req.bits <> io.prefetch_req.bits
+  loadUnits.foreach(load_unit => {
+    load_unit.io.prefetch_req.valid <> l1_pf_req.valid
+    load_unit.io.prefetch_req.bits <> l1_pf_req.bits
   })
   // when loadUnits(0) stage 0 is busy, hw prefetch will never use that pipeline
   loadUnits(0).io.prefetch_req.bits.confidence := 0.U
 
-  io.prefetch_req.ready := (io.prefetch_req.bits.confidence > 0.U) ||
+  l1_pf_req.ready := (l1_pf_req.bits.confidence > 0.U) ||
     loadUnits.map(!_.io.ldin.valid).reduce(_ || _)
 
   // TODO: fast load wakeup
diff --git a/src/main/scala/xiangshan/mem/prefetch/BasePrefecher.scala b/src/main/scala/xiangshan/mem/prefetch/BasePrefecher.scala
index 290ade8aa..59f327e8c 100644
--- a/src/main/scala/xiangshan/mem/prefetch/BasePrefecher.scala
+++ b/src/main/scala/xiangshan/mem/prefetch/BasePrefecher.scala
@@ -5,12 +5,13 @@ import chisel3.util._
 import chipsalliance.rocketchip.config.Parameters
 import xiangshan._
 import xiangshan.cache.mmu.TlbRequestIO
-import xiangshan.mem.LsPipelineBundle
+import xiangshan.mem.{L1PrefetchReq, LsPipelineBundle}
 
 class PrefetcherIO()(implicit p: Parameters) extends XSBundle {
   val ld_in = Flipped(Vec(exuParameters.LduCnt, ValidIO(new LsPipelineBundle())))
   val tlb_req = new TlbRequestIO(nRespDups = 2)
   val pf_addr = ValidIO(UInt(PAddrBits.W))
+  val l1_req = DecoupledIO(new L1PrefetchReq())
   val enable = Input(Bool())
 }
 
diff --git a/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala b/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala
index b4ef999ab..f00d761df 100644
--- a/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala
+++ b/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala
@@ -7,13 +7,16 @@ import xiangshan._
 import utils._
 import xiangshan.cache.HasDCacheParameters
 import xiangshan.cache.mmu._
+import xiangshan.mem.L1PrefetchReq
 
 case class SMSParams
 (
   region_size: Int = 1024,
   vaddr_hash_width: Int = 5,
   block_addr_raw_width: Int = 10,
-  filter_table_size: Int = 16,
+  stride_pc_bits: Int = 10,
+  max_stride: Int = 1024,
+  stride_entries: Int = 16,
   active_gen_table_size: Int = 16,
   pht_size: Int = 64,
   pht_ways: Int = 2,
@@ -42,6 +45,8 @@ trait HasSMSModuleHelper extends HasCircularQueuePtrHelper with HasDCacheParamet
   // page bit index in block addr
   val BLOCK_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / dcacheParameters.blockBytes)
   val REGION_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / smsParams.region_size)
+  val STRIDE_PC_BITS = smsParams.stride_pc_bits
+  val STRIDE_BLK_ADDR_BITS = log2Up(smsParams.max_stride)
 
   def block_addr(x: UInt): UInt = {
     val offset = log2Up(dcacheParameters.blockBytes)
@@ -93,79 +98,128 @@ trait HasSMSModuleHelper extends HasCircularQueuePtrHelper with HasDCacheParamet
   def pht_tag(pc: UInt): UInt = {
     pc(PHT_INDEX_BITS + 2 + PHT_TAG_BITS - 1, PHT_INDEX_BITS + 2)
   }
+
+  def get_alias_bits(region_vaddr: UInt): UInt = region_vaddr(7, 6)
 }
 
-class FilterTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
+class StridePF()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
   val io = IO(new Bundle() {
-    val s0_lookup = Flipped(ValidIO(new FilterEntry()))
-    val s1_result = ValidIO(new FilterEntry())
-    val s1_update = Input(Bool())
+    val stride_en = Input(Bool())
+    val s0_lookup = Flipped(new ValidIO(new Bundle() {
+      val pc = UInt(STRIDE_PC_BITS.W)
+      val vaddr = UInt(VAddrBits.W)
+      val paddr = UInt(PAddrBits.W)
+    }))
+    val s1_valid = Input(Bool())
+    val s2_gen_req = ValidIO(new PfGenReq())
   })
 
-  val s0_lookup_entry = io.s0_lookup.bits
-  val s0_lookup_valid = io.s0_lookup.valid
+  val prev_valid = RegNext(io.s0_lookup.valid, false.B)
+  val prev_pc = RegEnable(io.s0_lookup.bits.pc, io.s0_lookup.valid)
 
-  val entries = Seq.fill(smsParams.filter_table_size){ Reg(new FilterEntry()) }
-  val valids = Seq.fill(smsParams.filter_table_size){ RegInit(false.B) }
-  val w_ptr = RegInit(0.U(log2Up(smsParams.filter_table_size).W))
+  val s0_valid = io.s0_lookup.valid && !(prev_valid && prev_pc === io.s0_lookup.bits.pc)
 
-  val prev_entry = RegEnable(s0_lookup_entry, s0_lookup_valid)
-  val prev_lookup_valid = RegNext(s0_lookup_valid, false.B)
+  def entry_map[T](fn: Int => T) = (0 until smsParams.stride_entries).map(fn)
 
-  val s0_entry_match_vec = entries.zip(valids).map({
-    case (ent, v) => v && ent.region_tag === s0_lookup_entry.region_tag && ent.offset =/= s0_lookup_entry.offset
+  val replacement = ReplacementPolicy.fromString("plru", smsParams.stride_entries)
+  val valids = entry_map(_ => RegInit(false.B))
+  val entries_pc = entry_map(_ => Reg(UInt(STRIDE_PC_BITS.W)) )
+  val entries_conf = entry_map(_ => RegInit(1.U(2.W)))
+  val entries_last_addr = entry_map(_ => Reg(UInt(STRIDE_BLK_ADDR_BITS.W)) )
+  val entries_stride = entry_map(_ => Reg(SInt((STRIDE_BLK_ADDR_BITS+1).W)))
+
+
+  val s0_match_vec = valids.zip(entries_pc).map({
+    case (v, pc) => v && pc === io.s0_lookup.bits.pc
   })
-  val s0_any_entry_match = Cat(s0_entry_match_vec).orR
-  val s0_matched_entry = Mux1H(s0_entry_match_vec, entries)
-  val s0_match_s1 = prev_lookup_valid &&
-    prev_entry.region_tag === s0_lookup_entry.region_tag && prev_entry.offset =/= s0_lookup_entry.offset
-
-  val s0_hit = s0_lookup_valid && (s0_any_entry_match || s0_match_s1)
-
-  val s0_lookup_result = Wire(new FilterEntry())
-  s0_lookup_result := Mux(s0_match_s1, prev_entry, s0_matched_entry)
-  io.s1_result.valid := RegNext(s0_hit, false.B)
-  io.s1_result.bits := RegEnable(s0_lookup_result, s0_hit)
-
-  val s0_invalid_mask = valids.map(!_)
-  val s0_has_invalid_entry = Cat(s0_invalid_mask).orR
-  val s0_invalid_index = PriorityEncoder(s0_invalid_mask)
-  // if match, invalidte entry
-  for((v, i) <- valids.zipWithIndex){
-    when(s0_lookup_valid && s0_entry_match_vec(i)){
-      v := false.B
-    }
-  }
-  // stage1
-  val s1_has_invalid_entry = RegEnable(s0_has_invalid_entry, s0_lookup_valid)
-  val s1_invalid_index = RegEnable(s0_invalid_index, s0_lookup_valid)
-  // alloc entry if (s0 miss && s1_update)
-  val s1_do_update = io.s1_update && prev_lookup_valid && !io.s1_result.valid
-  val update_ptr = Mux(s1_has_invalid_entry, s1_invalid_index, w_ptr)
-  when(s1_do_update && !s1_has_invalid_entry){ w_ptr := w_ptr + 1.U }
-  for((ent, i) <- entries.zipWithIndex){
-    val wen = s1_do_update && update_ptr === i.U
-    when(wen){
+  val s0_hit = s0_valid && Cat(s0_match_vec).orR
+  val s0_miss = s0_valid && !s0_hit
+  val s0_matched_conf = Mux1H(s0_match_vec, entries_conf)
+  val s0_matched_last_addr = Mux1H(s0_match_vec, entries_last_addr)
+  val s0_matched_last_stride = Mux1H(s0_match_vec, entries_stride)
+
+
+  val s1_vaddr = RegEnable(io.s0_lookup.bits.vaddr, s0_valid)
+  val s1_paddr = RegEnable(io.s0_lookup.bits.paddr, s0_valid)
+  val s1_hit = RegNext(s0_hit) && io.s1_valid
+  val s1_alloc = RegNext(s0_miss) && io.s1_valid
+  val s1_conf = RegNext(s0_matched_conf)
+  val s1_last_addr = RegNext(s0_matched_last_addr)
+  val s1_last_stride = RegNext(s0_matched_last_stride)
+  val s1_match_vec = RegNext(VecInit(s0_match_vec))
+
+  val BLOCK_OFFSET = log2Up(dcacheParameters.blockBytes)
+  val s1_new_stride_vaddr = s1_vaddr(BLOCK_OFFSET + STRIDE_BLK_ADDR_BITS - 1, BLOCK_OFFSET)
+  val s1_new_stride = (0.U(1.W) ## s1_new_stride_vaddr).asSInt - (0.U(1.W) ## s1_last_addr).asSInt
+  val s1_stride_non_zero = s1_last_stride =/= 0.S
+  val s1_stride_match = s1_new_stride === s1_last_stride && s1_stride_non_zero
+  val s1_replace_idx = replacement.way
+
+  for(i <- 0 until smsParams.stride_entries){
+    val alloc = s1_alloc && i.U === s1_replace_idx
+    val update = s1_hit && s1_match_vec(i)
+    when(update){
+      assert(valids(i))
+      entries_conf(i) := Mux(s1_stride_match,
+        Mux(s1_conf === 3.U, 3.U, s1_conf + 1.U),
+        Mux(s1_conf === 0.U, 0.U, s1_conf - 1.U)
+      )
+      entries_last_addr(i) := s1_new_stride_vaddr
+      when(!s1_conf(1)){
+        entries_stride(i) := s1_new_stride
+      }
+    }
+    when(alloc){
       valids(i) := true.B
-      ent := prev_entry
+      entries_pc(i) := prev_pc
+      entries_conf(i) := 0.U
+      entries_last_addr(i) := s1_new_stride_vaddr
+      entries_stride(i) := 0.S
     }
+    assert(!(update && alloc))
   }
-
-  XSPerfAccumulate("sms_filter_table_hit", io.s1_result.valid)
-  XSPerfAccumulate("sms_filter_table_update", s1_do_update)
-  for(i <- 0 until smsParams.filter_table_size){
-    XSPerfAccumulate(s"sms_filter_table_access_$i",
-      s1_do_update && update_ptr === i.U
-    )
+  when(s1_hit){
+    replacement.access(OHToUInt(s1_match_vec.asUInt))
+  }.elsewhen(s1_alloc){
+    replacement.access(s1_replace_idx)
   }
-}
 
-class FilterEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
-  val pht_index = UInt(PHT_INDEX_BITS.W)
-  val pht_tag = UInt(PHT_TAG_BITS.W)
-  val region_tag = UInt(REGION_TAG_WIDTH.W)
-  val offset = UInt(REGION_OFFSET.W)
+  val s1_block_vaddr = block_addr(s1_vaddr)
+  val s1_pf_block_vaddr = (s1_block_vaddr.asSInt + s1_last_stride).asUInt
+  val s1_pf_cross_page = s1_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT) =/= s1_block_vaddr(BLOCK_ADDR_PAGE_BIT)
+
+  val s2_pf_gen_valid = RegNext(s1_hit && s1_stride_match, false.B)
+  val s2_pf_gen_paddr_valid = RegEnable(!s1_pf_cross_page, s1_hit && s1_stride_match)
+  val s2_pf_block_vaddr = RegEnable(s1_pf_block_vaddr, s1_hit && s1_stride_match)
+  val s2_block_paddr = RegEnable(block_addr(s1_paddr), s1_hit && s1_stride_match)
+
+  val s2_pf_block_addr = Mux(s2_pf_gen_paddr_valid,
+    Cat(
+      s2_block_paddr(PAddrBits - BLOCK_OFFSET - 1, BLOCK_ADDR_PAGE_BIT),
+      s2_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT - 1, 0)
+    ),
+    s2_pf_block_vaddr
+  )
+  val s2_pf_full_addr = Wire(UInt(VAddrBits.W))
+  s2_pf_full_addr := s2_pf_block_addr ## 0.U(BLOCK_OFFSET.W)
+
+  val s2_pf_region_addr = region_addr(s2_pf_full_addr)
+  val s2_pf_region_offset = s2_pf_block_addr(REGION_OFFSET - 1, 0)
+
+  val s2_full_vaddr = Wire(UInt(VAddrBits.W))
+  s2_full_vaddr := s2_pf_block_vaddr ## 0.U(BLOCK_OFFSET.W)
+
+  val s2_region_tag = region_hash_tag(region_addr(s2_full_vaddr))
+
+  io.s2_gen_req.valid := s2_pf_gen_valid && io.stride_en
+  io.s2_gen_req.bits.region_tag := s2_region_tag
+  io.s2_gen_req.bits.region_addr := s2_pf_region_addr
+  io.s2_gen_req.bits.alias_bits := get_alias_bits(region_addr(s2_full_vaddr))
+  io.s2_gen_req.bits.region_bits := region_offset_to_bits(s2_pf_region_offset)
+  io.s2_gen_req.bits.paddr_valid := s2_pf_gen_paddr_valid
+  io.s2_gen_req.bits.decr_mode := false.B
+
 }
 
@@ -173,6 +227,7 @@ class AGTEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelpe
   val pht_tag = UInt(PHT_TAG_BITS.W)
   val region_bits = UInt(REGION_BLKS.W)
   val region_tag = UInt(REGION_TAG_WIDTH.W)
+  val region_offset = UInt(REGION_OFFSET.W)
   val access_cnt = UInt((REGION_BLKS-1).U.getWidth.W)
   val decr_mode = Bool()
 }
@@ -183,10 +238,12 @@ class PfGenReq()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelpe
   val region_bits = UInt(REGION_BLKS.W)
   val paddr_valid = Bool()
   val decr_mode = Bool()
+  val alias_bits = UInt(2.W)
 }
 
 class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
   val io = IO(new Bundle() {
+    val agt_en = Input(Bool())
     val s0_lookup = Flipped(ValidIO(new Bundle() {
       val region_tag = UInt(REGION_TAG_WIDTH.W)
       val region_p1_tag = UInt(REGION_TAG_WIDTH.W)
@@ -201,12 +258,10 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
     val region_paddr = UInt(REGION_ADDR_BITS.W)
     val region_vaddr = UInt(REGION_ADDR_BITS.W)
   }))
-  // do not alloc entry in filter table if agt hit
-  val s1_match_or_alloc = Output(Bool())
-  // if agt missed, try lookup pht
+  val s1_sel_stride = Output(Bool())
+  val s2_stride_hit = Input(Bool())
+  // if agt/stride missed, try lookup pht
   val s2_pht_lookup = ValidIO(new PhtLookup())
-  // receive second hit from filter table
-  val s1_recv_entry = Flipped(ValidIO(new AGTEntry()))
   // evict entry to pht
   val s2_evict = ValidIO(new AGTEntry())
   val s2_pf_gen_req = ValidIO(new PfGenReq())
@@ -218,6 +273,8 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   val valids = Seq.fill(smsParams.active_gen_table_size){ RegInit(false.B) }
   val replacement = ReplacementPolicy.fromString("plru", smsParams.active_gen_table_size)
 
+  val s1_replace_mask_w = Wire(UInt(smsParams.active_gen_table_size.W))
+
   val s0_lookup = io.s0_lookup.bits
   val s0_lookup_valid = io.s0_lookup.valid
 
@@ -241,9 +298,8 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   val any_region_m1_match = Cat(region_m1_match_vec_s0).orR && s0_lookup.allow_cross_region_m1
 
   val s0_region_hit = any_region_match
-  // region miss, but cross region match
-  val s0_alloc = !s0_region_hit && (any_region_p1_match || any_region_m1_match) && !s0_match_prev
-  val s0_match_or_alloc = any_region_match || any_region_p1_match || any_region_m1_match
+  val s0_cross_region_hit = any_region_m1_match || any_region_p1_match
+  val s0_alloc = s0_lookup_valid && !s0_region_hit && !s0_match_prev
 
   val s0_pf_gen_match_vec = valids.indices.map(i => {
     Mux(any_region_match,
      region_match_vec_s0(i),
@@ -258,32 +314,35 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   s0_agt_entry.pht_tag := s0_lookup.pht_tag
   s0_agt_entry.region_bits := region_offset_to_bits(s0_lookup.region_offset)
   s0_agt_entry.region_tag := s0_lookup.region_tag
+  s0_agt_entry.region_offset := s0_lookup.region_offset
   s0_agt_entry.access_cnt := 1.U
   // lookup_region + 1 == entry_region
   // lookup_region = entry_region - 1 => decr mode
   s0_agt_entry.decr_mode := !s0_region_hit && !any_region_m1_match && any_region_p1_match
-  val s0_replace_mask = UIntToOH(replacement.way)
+  val s0_replace_way = replacement.way
+  val s0_replace_mask = UIntToOH(s0_replace_way)
   // s0 hit a entry that may be replaced in s1
-  val s0_update_conflict = Cat(VecInit(region_match_vec_s0).asUInt & s0_replace_mask).orR
+  val s0_update_conflict = Cat(VecInit(region_match_vec_s0).asUInt & s1_replace_mask_w).orR
+  val s0_update = s0_lookup_valid && s0_region_hit && !s0_update_conflict
+
+  val s0_access_way = Mux1H(
+    Seq(s0_update, s0_alloc),
+    Seq(OHToUInt(region_match_vec_s0), s0_replace_way)
+  )
+  when(s0_update || s0_alloc) {
+    replacement.access(s0_access_way)
+  }
 
   // stage1: update/alloc
   // region hit, update entry
-  val s1_update_conflict = RegEnable(s0_update_conflict, s0_lookup_valid && s0_region_hit)
-  val s1_update = RegNext(s0_lookup_valid && s0_region_hit, false.B) && !s1_update_conflict
-  val s1_update_mask = RegEnable(
-    VecInit(region_match_vec_s0),
-    VecInit(Seq.fill(smsParams.active_gen_table_size){ false.B }),
-    s0_lookup_valid
-  )
+  val s1_update = RegNext(s0_update, false.B)
+  val s1_update_mask = RegEnable(VecInit(region_match_vec_s0), s0_lookup_valid)
   val s1_agt_entry = RegEnable(s0_agt_entry, s0_lookup_valid)
-  val s1_recv_entry = io.s1_recv_entry
-  val s1_drop = RegInit(false.B)
-  // cross region match or filter table second hit
-  val s1_cross_region_match = RegNext(s0_lookup_valid && s0_alloc, false.B)
-  val s1_alloc = s1_cross_region_match || (s1_recv_entry.valid && !s1_drop && !s1_update)
-  s1_drop := s0_lookup_valid && s0_match_prev && s1_alloc // TODO: use bypass update instead of drop
-  val s1_alloc_entry = Mux(s1_recv_entry.valid, s1_recv_entry.bits, s1_agt_entry)
+  val s1_cross_region_match = RegNext(s0_lookup_valid && s0_cross_region_hit, false.B)
+  val s1_alloc = RegNext(s0_alloc, false.B)
+  val s1_alloc_entry = s1_agt_entry
   val s1_replace_mask = RegEnable(s0_replace_mask, s0_lookup_valid)
+  s1_replace_mask_w := s1_replace_mask & Fill(smsParams.active_gen_table_size, s1_alloc)
   val s1_evict_entry = Mux1H(s1_replace_mask, entries)
   val s1_evict_valid = Mux1H(s1_replace_mask, valids)
   // pf gen
@@ -303,13 +362,6 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
     valids(i) := valids(i) || alloc
     entries(i) := Mux(alloc, s1_alloc_entry, Mux(update, update_entry, entries(i)))
   }
-  when(s1_update) {
-    replacement.access(OHToUInt(s1_update_mask))
-  }.elsewhen(s1_alloc){
-    replacement.access(OHToUInt(s1_replace_mask))
-  }
-
-  io.s1_match_or_alloc := s1_update || s1_alloc || s1_drop
 
   when(s1_update){
     assert(PopCount(s1_update_mask) === 1.U, "multi-agt-update")
@@ -350,10 +402,11 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   )
   val s1_pf_gen_offset_mask = UIntToOH(s1_pf_gen_offset)
   val s1_pf_gen_access_cnt = Mux1H(s1_pf_gen_match_vec, entries.map(_.access_cnt))
-  val s1_pf_gen_valid = prev_lookup_valid && io.s1_match_or_alloc && Mux(s1_pf_gen_decr_mode,
+  val s1_in_active_page = s1_pf_gen_access_cnt > io.act_threshold
+  val s1_pf_gen_valid = prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && Mux(s1_pf_gen_decr_mode,
     !s1_vaddr_dec_cross_max_lim,
     !s1_vaddr_inc_cross_max_lim
-  ) && (s1_pf_gen_access_cnt > io.act_threshold)
+  ) && s1_in_active_page && io.agt_en
   val s1_pf_gen_paddr_valid = Mux(s1_pf_gen_decr_mode, !s1_vaddr_dec_cross_page, !s1_vaddr_inc_cross_page)
   val s1_pf_gen_region_addr = Mux(s1_pf_gen_paddr_valid,
     Cat(s1_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT), s1_pf_gen_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)),
@@ -381,6 +434,8 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   s1_pht_lookup.region_paddr := s1_region_paddr
   s1_pht_lookup.region_offset := s1_region_offset
 
+  io.s1_sel_stride := prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && !s1_in_active_page
+
   // stage2: gen pf reg / evict entry to pht
   val s2_evict_entry = RegEnable(s1_evict_entry, s1_alloc)
   val s2_evict_valid = RegNext(s1_alloc && s1_evict_valid, false.B)
@@ -388,9 +443,10 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   val s2_pf_gen_region_tag = RegEnable(s1_pf_gen_region_tag, s1_pf_gen_valid)
   val s2_pf_gen_decr_mode = RegEnable(s1_pf_gen_decr_mode, s1_pf_gen_valid)
   val s2_pf_gen_region_paddr = RegEnable(s1_pf_gen_region_addr, s1_pf_gen_valid)
+  val s2_pf_gen_alias_bits = RegEnable(get_alias_bits(s1_pf_gen_vaddr), s1_pf_gen_valid)
   val s2_pf_gen_region_bits = RegEnable(s1_pf_gen_region_bits, s1_pf_gen_valid)
   val s2_pf_gen_valid = RegNext(s1_pf_gen_valid, false.B)
-  val s2_pht_lookup_valid = RegNext(s1_pht_lookup_valid, false.B)
+  val s2_pht_lookup_valid = RegNext(s1_pht_lookup_valid, false.B) && !io.s2_stride_hit
   val s2_pht_lookup = RegEnable(s1_pht_lookup, s1_pht_lookup_valid)
 
   io.s2_evict.valid := s2_evict_valid
@@ -398,6 +454,7 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
 
   io.s2_pf_gen_req.bits.region_tag := s2_pf_gen_region_tag
   io.s2_pf_gen_req.bits.region_addr := s2_pf_gen_region_paddr
+  io.s2_pf_gen_req.bits.alias_bits := s2_pf_gen_alias_bits
   io.s2_pf_gen_req.bits.region_bits := s2_pf_gen_region_bits
   io.s2_pf_gen_req.bits.paddr_valid := s2_paddr_valid
   io.s2_pf_gen_req.bits.decr_mode := s2_pf_gen_decr_mode
@@ -406,6 +463,7 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
   io.s2_pht_lookup.valid := s2_pht_lookup_valid
   io.s2_pht_lookup.bits := s2_pht_lookup
 
+  XSPerfAccumulate("sms_agt_in", io.s0_lookup.valid)
   XSPerfAccumulate("sms_agt_alloc", s1_alloc) // cross region match or filter evict
   XSPerfAccumulate("sms_agt_update", s1_update) // entry hit
   XSPerfAccumulate("sms_agt_pf_gen", io.s2_pf_gen_req.valid)
@@ -466,6 +524,11 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   evict_queue.io.in := io.agt_update
   val evict = evict_queue.io.out
 
+  XSPerfAccumulate("sms_pht_lookup_in", lookup_queue.io.in.fire)
+  XSPerfAccumulate("sms_pht_lookup_out", lookup_queue.io.out.fire)
+  XSPerfAccumulate("sms_pht_evict_in", evict_queue.io.in.fire)
+  XSPerfAccumulate("sms_pht_evict_out", evict_queue.io.out.fire)
+
   val s3_ram_en = Wire(Bool())
   val s1_valid = Wire(Bool())
   // if s1.raddr == s2.waddr or s3 is using ram port, block s1
@@ -481,9 +544,9 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
     lookup.bits.pht_index
   )
   val s0_tag = Mux(evict.valid, evict.bits.pht_tag, lookup.bits.pht_tag)
+  val s0_region_offset = Mux(evict.valid, evict.bits.region_offset, lookup.bits.region_offset)
   val s0_region_paddr = lookup.bits.region_paddr
   val s0_region_vaddr = lookup.bits.region_vaddr
-  val s0_region_offset = lookup.bits.region_offset
   val s0_region_bits = evict.bits.region_bits
   val s0_decr_mode = evict.bits.decr_mode
   val s0_evict = evict.valid
@@ -520,8 +583,8 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   )
 
   // pipe s2: generate ram write addr/data
-  val s2_valid = RegNext(s1_valid && !s3_ram_en, false.B)
-  val s2_reg_en = s1_valid && !s3_ram_en
+  val s2_valid = RegNext(s1_valid && !s1_wait, false.B)
+  val s2_reg_en = s1_valid && !s1_wait
   val s2_hist_update_mask = RegEnable(s1_hist_update_mask, s2_reg_en)
   val s2_hist_bits = RegEnable(s1_hist_bits, s2_reg_en)
   val s2_tag = RegEnable(s1_tag, s2_reg_en)
@@ -632,7 +695,17 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   val s3_incr_region_valid = s3_pf_gen_valid && (s3_hist_hi & (~s3_hist_update_mask.head(REGION_BLKS - 1)).asUInt).orR
   val s3_decr_region_valid = s3_pf_gen_valid && (s3_hist_lo & (~s3_hist_update_mask.tail(REGION_BLKS - 1)).asUInt).orR
   val s3_incr_region_vaddr = s3_region_vaddr + 1.U
+  val s3_incr_alias_bits = get_alias_bits(s3_incr_region_vaddr)
   val s3_decr_region_vaddr = s3_region_vaddr - 1.U
+  val s3_decr_alias_bits = get_alias_bits(s3_decr_region_vaddr)
+  val s3_incr_region_paddr = Cat(
+    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
+    s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
+  )
+  val s3_decr_region_paddr = Cat(
+    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
+    s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
+  )
   val s3_incr_crosspage = s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
   val s3_decr_crosspage = s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
   val s3_cur_region_tag = region_hash_tag(s3_region_vaddr)
@@ -650,6 +723,7 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   s4_pf_gen_cur_region_valid := s3_cur_region_valid
   when(s3_cur_region_valid){
     s4_pf_gen_cur_region.region_addr := s3_region_paddr
+    s4_pf_gen_cur_region.alias_bits := get_alias_bits(s3_region_vaddr)
     s4_pf_gen_cur_region.region_tag := s3_cur_region_tag
     s4_pf_gen_cur_region.region_bits := s3_cur_region_bits
     s4_pf_gen_cur_region.paddr_valid := true.B
@@ -658,7 +732,8 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   s4_pf_gen_incr_region_valid := s3_incr_region_valid ||
     (!pf_gen_req_arb.io.in(1).ready && s4_pf_gen_incr_region_valid)
   when(s3_incr_region_valid){
-    s4_pf_gen_incr_region.region_addr := Mux(s3_incr_crosspage, s3_incr_region_vaddr, s3_region_paddr)
+    s4_pf_gen_incr_region.region_addr := Mux(s3_incr_crosspage, s3_incr_region_vaddr, s3_incr_region_paddr)
+    s4_pf_gen_incr_region.alias_bits := s3_incr_alias_bits
     s4_pf_gen_incr_region.region_tag := s3_incr_region_tag
     s4_pf_gen_incr_region.region_bits := s3_incr_region_bits
     s4_pf_gen_incr_region.paddr_valid := !s3_incr_crosspage
@@ -667,7 +742,8 @@ class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMS
   s4_pf_gen_decr_region_valid := s3_decr_region_valid ||
     (!pf_gen_req_arb.io.in(2).ready && s4_pf_gen_decr_region_valid)
   when(s3_decr_region_valid){
-    s4_pf_gen_decr_region.region_addr := Mux(s3_decr_crosspage, s3_decr_region_vaddr, s3_region_paddr)
+    s4_pf_gen_decr_region.region_addr := Mux(s3_decr_crosspage, s3_decr_region_vaddr, s3_decr_region_paddr)
+    s4_pf_gen_decr_region.alias_bits := s3_decr_alias_bits
     s4_pf_gen_decr_region.region_tag := s3_decr_region_tag
     s4_pf_gen_decr_region.region_bits := s3_decr_region_bits
     s4_pf_gen_decr_region.paddr_valid := !s3_decr_crosspage
@@ -703,6 +779,7 @@ class PrefetchFilterEntry()(implicit p: Parameters) extends XSBundle with HasSMS
   val region_addr = UInt(REGION_ADDR_BITS.W)
   val region_bits = UInt(REGION_BLKS.W)
   val filter_bits = UInt(REGION_BLKS.W)
+  val alias_bits = UInt(2.W)
   val paddr_valid = Bool()
   val decr_mode = Bool()
 }
@@ -712,6 +789,7 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
     val gen_req = Flipped(ValidIO(new PfGenReq()))
     val tlb_req = new TlbRequestIO(2)
     val l2_pf_addr = ValidIO(UInt(PAddrBits.W))
+    val pf_alias_bits = Output(UInt(2.W))
   })
   val entries = Seq.fill(smsParams.pf_filter_size){ Reg(new PrefetchFilterEntry()) }
   val valids = Seq.fill(smsParams.pf_filter_size){ RegInit(false.B) }
@@ -720,14 +798,17 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
   val prev_valid = RegNext(io.gen_req.valid, false.B)
   val prev_gen_req = RegEnable(io.gen_req.bits, io.gen_req.valid)
 
-  val tlb_req_arb = Module(new RRArbiter(new TlbReq, smsParams.pf_filter_size))
-  val pf_req_arb = Module(new RRArbiter(UInt(PAddrBits.W), smsParams.pf_filter_size))
+  val tlb_req_arb = Module(new RRArbiterInit(new TlbReq, smsParams.pf_filter_size))
+  val pf_req_arb = Module(new RRArbiterInit(UInt(PAddrBits.W), smsParams.pf_filter_size))
 
   io.tlb_req.req <> tlb_req_arb.io.out
   io.tlb_req.resp.ready := true.B
   io.tlb_req.req_kill := false.B
   io.l2_pf_addr.valid := pf_req_arb.io.out.valid
   io.l2_pf_addr.bits := pf_req_arb.io.out.bits
+  io.pf_alias_bits := Mux1H(entries.zipWithIndex.map({
+    case (entry, i) => (i.U === pf_req_arb.io.chosen) -> entry.alias_bits
+  }))
 
   pf_req_arb.io.out.ready := true.B
 
   val s1_valid = Wire(Bool())
@@ -753,16 +834,17 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
     tlb_req_arb.io.in(i).bits.cmd := TlbCmd.read
     tlb_req_arb.io.in(i).bits.size := 3.U
     tlb_req_arb.io.in(i).bits.robIdx := DontCare
+    tlb_req_arb.io.in(i).bits.no_translate := false.B
     tlb_req_arb.io.in(i).bits.debug := DontCare
 
     val pending_req_vec = ent.region_bits & (~ent.filter_bits).asUInt
     val first_one_offset = PriorityMux(
       pending_req_vec.asBools,
-      (0 until smsParams.filter_table_size).map(_.U(REGION_OFFSET.W))
+      (0 until smsParams.pf_filter_size).map(_.U(REGION_OFFSET.W))
     )
     val last_one_offset = PriorityMux(
       pending_req_vec.asBools.reverse,
-      (0 until smsParams.filter_table_size).reverse.map(_.U(REGION_OFFSET.W))
+      (0 until smsParams.pf_filter_size).reverse.map(_.U(REGION_OFFSET.W))
    )
     val pf_addr = Cat(
       ent.region_addr,
@@ -776,9 +858,16 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
   val s0_tlb_fire_vec = VecInit(tlb_req_arb.io.in.map(_.fire))
   val s0_pf_fire_vec = VecInit(pf_req_arb.io.in.map(_.fire))
 
+  val s0_update_way = OHToUInt(s0_match_vec)
+  val s0_replace_way = replacement.way
+  val s0_access_way = Mux(s0_any_matched, s0_update_way, s0_replace_way)
+  when(s0_gen_req_valid){
+    replacement.access(s0_access_way)
+  }
+
   // s1: update or alloc
   val s1_valid_r = RegNext(s0_gen_req_valid, false.B)
-  val s1_hit_r = RegEnable(s0_hit, s0_gen_req_valid)
+  val s1_hit_r = RegEnable(s0_hit, false.B, s0_gen_req_valid)
   val s1_gen_req = RegEnable(s0_gen_req, s0_gen_req_valid)
   val s1_replace_vec_r = RegEnable(s0_replace_vec, s0_gen_req_valid && !s0_hit)
   val s1_update_vec = RegEnable(VecInit(s0_match_vec).asUInt, s0_gen_req_valid && s0_hit)
@@ -794,6 +883,7 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
   s1_alloc_entry.paddr_valid := s1_gen_req.paddr_valid
   s1_alloc_entry.decr_mode := s1_gen_req.decr_mode
   s1_alloc_entry.filter_bits := 0.U
+  s1_alloc_entry.alias_bits := s1_gen_req.alias_bits
   for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
     val alloc = s1_valid && !s1_hit && s1_replace_vec(i)
     val update = s1_valid && s1_hit && s1_update_vec(i)
@@ -816,11 +906,6 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
       v := true.B
     }
   }
-  val s1_access_mask = Mux(s1_hit, s1_update_vec, s1_replace_vec)
-  val s1_access_way = OHToUInt(s1_access_mask.asUInt)
-  when(s1_valid){
-    replacement.access(s1_access_way)
-  }
   when(s1_valid && s1_hit){
     assert(PopCount(s1_update_vec) === 1.U, "sms_pf_filter: multi-hit")
   }
@@ -830,7 +915,7 @@ class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModul
   XSPerfAccumulate("sms_pf_filter_tlb_req", io.tlb_req.req.fire)
   XSPerfAccumulate("sms_pf_filter_tlb_resp_miss", io.tlb_req.resp.fire && io.tlb_req.resp.bits.miss)
   for(i <- 0 until smsParams.pf_filter_size){
-    XSPerfAccumulate(s"sms_pf_filter_access_way_$i", s1_valid && s1_access_way === i.U)
+    XSPerfAccumulate(s"sms_pf_filter_access_way_$i", s0_gen_req_valid && s0_access_way === i.U)
   }
   XSPerfAccumulate("sms_pf_filter_l2_req", io.l2_pf_addr.valid)
 }
@@ -840,6 +925,7 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
   require(exuParameters.LduCnt == 2)
 
   val io_agt_en = IO(Input(Bool()))
+  val io_stride_en = IO(Input(Bool()))
   val io_pht_en = IO(Input(Bool()))
   val io_act_threshold = IO(Input(UInt(REGION_OFFSET.W)))
   val io_act_stride = IO(Input(UInt(6.W)))
@@ -869,19 +955,14 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
   val pending_sel_ld0 = RegNext(Mux(pending_vld, ld0_older_than_ld1, !ld0_older_than_ld1))
   val pending_ld = Mux(pending_sel_ld0, ld_prev.head, ld_prev.last)
   val pending_ld_block_tag = Mux(pending_sel_ld0, ld_prev_block_tag.head, ld_prev_block_tag.last)
-
-  // prepare training data
-  val train_ld = RegEnable(
-    Mux(pending_vld, pending_ld, Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr.head, ld_curr.last)),
-    pending_vld || Cat(ld_curr_vld).orR
+  val oldest_ld = Mux(pending_vld,
+    pending_ld,
+    Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr.head, ld_curr.last)
   )
-  val train_block_tag = RegEnable(
-    Mux(pending_vld, pending_ld_block_tag,
-      Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr_block_tag.head, ld_curr_block_tag.last)
-    ),
-    pending_vld || Cat(ld_curr_vld).orR
-  )
+
+  val train_ld = RegEnable(oldest_ld, pending_vld || Cat(ld_curr_vld).orR)
+
+  val train_block_tag = block_hash_tag(train_ld.vaddr)
   val train_region_tag = train_block_tag.head(REGION_TAG_WIDTH)
 
   val train_region_addr_raw = region_addr(train_ld.vaddr)(REGION_TAG_WIDTH + 2 * VADDR_HASH_WIDTH - 1, 0)
@@ -904,8 +985,8 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
 
 
   // prefetch stage0
-  val filter_table = Module(new FilterTable())
   val active_gen_table = Module(new ActiveGenerationTable())
+  val stride = Module(new StridePF())
   val pht = Module(new PatternHistoryTable())
   val pf_filter = Module(new PrefetchFilter())
 
@@ -924,13 +1005,7 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
   val train_region_paddr_s0 = RegEnable(train_region_paddr, train_vld)
   val train_region_vaddr_s0 = RegEnable(train_region_vaddr, train_vld)
 
-  filter_table.io.s0_lookup.valid := train_vld_s0
-  filter_table.io.s0_lookup.bits.pht_tag := train_pht_tag_s0
-  filter_table.io.s0_lookup.bits.pht_index := train_pht_index_s0
-  filter_table.io.s0_lookup.bits.region_tag := train_region_tag_s0
-  filter_table.io.s0_lookup.bits.offset := train_region_offset_s0
-  filter_table.io.s1_update := !active_gen_table.io.s1_match_or_alloc
-
+  active_gen_table.io.agt_en := io_agt_en
   active_gen_table.io.act_threshold := io_act_threshold
   active_gen_table.io.act_stride := io_act_stride
   active_gen_table.io.s0_lookup.valid := train_vld_s0
@@ -946,33 +1021,44 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
   active_gen_table.io.s0_lookup.bits.region_m1_cross_page := train_region_m1_cross_page_s0
   active_gen_table.io.s0_lookup.bits.region_paddr := train_region_paddr_s0
   active_gen_table.io.s0_lookup.bits.region_vaddr := train_region_vaddr_s0
+  active_gen_table.io.s2_stride_hit := stride.io.s2_gen_req.valid
 
-  val train_region_offset_s1 = RegEnable(train_region_offset_s0, train_vld_s0)
-  val agt_region_bits_s1 = region_offset_to_bits(train_region_offset_s1) |
-    region_offset_to_bits(filter_table.io.s1_result.bits.offset)
-
-  active_gen_table.io.s1_recv_entry.valid := filter_table.io.s1_result.valid
-  active_gen_table.io.s1_recv_entry.bits.pht_index := filter_table.io.s1_result.bits.pht_index
-  active_gen_table.io.s1_recv_entry.bits.pht_tag := filter_table.io.s1_result.bits.pht_tag
-  active_gen_table.io.s1_recv_entry.bits.region_bits := agt_region_bits_s1
-  active_gen_table.io.s1_recv_entry.bits.region_tag := filter_table.io.s1_result.bits.region_tag
-  active_gen_table.io.s1_recv_entry.bits.access_cnt := 2.U
-  active_gen_table.io.s1_recv_entry.bits.decr_mode := false.B
+  stride.io.stride_en := io_stride_en
+  stride.io.s0_lookup.valid := train_vld_s0
+  stride.io.s0_lookup.bits.pc := train_s0.uop.cf.pc(STRIDE_PC_BITS - 1, 0)
+  stride.io.s0_lookup.bits.vaddr := Cat(
+    train_region_vaddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
+  )
+  stride.io.s0_lookup.bits.paddr := Cat(
+    train_region_paddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
+  )
+  stride.io.s1_valid := active_gen_table.io.s1_sel_stride
 
   pht.io.s2_agt_lookup := active_gen_table.io.s2_pht_lookup
   pht.io.agt_update := active_gen_table.io.s2_evict
 
   val pht_gen_valid = pht.io.pf_gen_req.valid && io_pht_en
-  val agt_gen_valid = active_gen_table.io.s2_pf_gen_req.valid && io_agt_en
-  val pf_gen_req = Mux(agt_gen_valid,
-    active_gen_table.io.s2_pf_gen_req.bits,
+  val agt_gen_valid = active_gen_table.io.s2_pf_gen_req.valid
+  val stride_gen_valid = stride.io.s2_gen_req.valid
+  val pf_gen_req = Mux(agt_gen_valid || stride_gen_valid,
+    Mux1H(Seq(
+      agt_gen_valid -> active_gen_table.io.s2_pf_gen_req.bits,
+      stride_gen_valid -> stride.io.s2_gen_req.bits
+    )),
     pht.io.pf_gen_req.bits
   )
-  pf_filter.io.gen_req.valid := pht_gen_valid || agt_gen_valid
+  assert(!(agt_gen_valid && stride_gen_valid))
+  pf_filter.io.gen_req.valid := pht_gen_valid || agt_gen_valid || stride_gen_valid
   pf_filter.io.gen_req.bits := pf_gen_req
   io.tlb_req <> pf_filter.io.tlb_req
-  io.pf_addr.valid := pf_filter.io.l2_pf_addr.valid && io.enable
+  val is_valid_address = pf_filter.io.l2_pf_addr.bits > 0x80000000L.U
+  io.pf_addr.valid := false.B //pf_filter.io.l2_pf_addr.valid && io.enable && is_valid_address
   io.pf_addr.bits := pf_filter.io.l2_pf_addr.bits
+  io.l1_req.bits.paddr := pf_filter.io.l2_pf_addr.bits
+  io.l1_req.bits.alias := pf_filter.io.pf_alias_bits
+  io.l1_req.bits.is_store := true.B
+  io.l1_req.bits.confidence := 1.U
+  io.l1_req.valid := pf_filter.io.l2_pf_addr.valid && io.enable && is_valid_address
 
   XSPerfAccumulate("sms_pf_gen_conflict",
     pht_gen_valid && agt_gen_valid
@@ -980,4 +1066,6 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
   XSPerfAccumulate("sms_pht_disabled", pht.io.pf_gen_req.valid && !io_pht_en)
   XSPerfAccumulate("sms_agt_disabled", active_gen_table.io.s2_pf_gen_req.valid && !io_agt_en)
   XSPerfAccumulate("sms_pf_real_issued", io.pf_addr.valid)
+  XSPerfAccumulate("sms_l1_req_valid", io.l1_req.valid)
+  XSPerfAccumulate("sms_l1_req_fire", io.l1_req.fire)
 }
\ No newline at end of file
--
GitLab
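
Note on the training policy in the new StridePF above: each table entry keeps a 2-bit
saturating confidence counter, incremented when a trained load repeats its last stride
and decremented on a mismatch, and the stored stride is only retrained while confidence
is still weak (!s1_conf(1), i.e. conf < 2). A minimal stand-alone Chisel sketch of just
that counter policy, with illustrative module and port names that are not part of the
patch:

    import chisel3._

    class StrideConfidence extends Module {
      val io = IO(new Bundle {
        val update = Input(Bool())       // a trained load hit this entry
        val strideMatch = Input(Bool())  // new stride equals the stored stride
        val conf = Output(UInt(2.W))
      })
      val conf = RegInit(0.U(2.W))
      when(io.update) {
        // saturate at 3 on a repeated stride, at 0 on a mismatch
        conf := Mux(io.strideMatch,
          Mux(conf === 3.U, 3.U, conf + 1.U),
          Mux(conf === 0.U, 0.U, conf - 1.U))
      }
      io.conf := conf
    }

Since a prefetch is only generated on s1_hit && s1_stride_match and a fresh entry is
allocated with stride 0, an entry must in effect observe the same non-zero stride at
least twice before it can emit a request.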
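In PrefetchFilter, the new pf_alias_bits output travels as side-band data next to the
arbitrated prefetch address: a Mux1H over (i.U === pf_req_arb.io.chosen) terms, i.e. a
one-hot select driven by the arbiter's grant index. A self-contained sketch of the same
pattern using the stock chisel3.util.RRArbiter (the patch's RRArbiterInit is assumed to
behave like RRArbiter with initialized grant state; all names here are illustrative):

    import chisel3._
    import chisel3.util._

    class ArbWithSideband(n: Int) extends Module {
      val io = IO(new Bundle {
        val in = Vec(n, Flipped(Decoupled(UInt(36.W)))) // per-entry request address
        val side = Input(Vec(n, UInt(2.W)))             // per-entry alias bits
        val out = Decoupled(UInt(36.W))
        val sideOut = Output(UInt(2.W))
      })
      val arb = Module(new RRArbiter(UInt(36.W), n))
      arb.io.in <> io.in
      io.out <> arb.io.out
      // exactly one (chosen === i) term is true per grant, so Mux1H is safe;
      // sideOut is only meaningful while io.out.valid is asserted
      io.sideOut := Mux1H((0 until n).map(i => (arb.io.chosen === i.U) -> io.side(i)))
    }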
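The MemBlock change near the top gives the new l1_pf_req channel a confidence-based
back-pressure rule: a request with confidence > 0 is always accepted (it may take over
a load pipeline's stage 0), while a zero-confidence request is only accepted when some
load pipeline has no incoming load that cycle. Restated as a hypothetical stand-alone
module (in the patch this is a single assignment inside MemBlockImp; the port width is
an assumption):

    import chisel3._

    class PrefetchReady(nLoadPipes: Int) extends Module {
      val io = IO(new Bundle {
        val confidence = Input(UInt(1.W))              // width of L1PrefetchReq.confidence assumed
        val ldinValid = Input(Vec(nLoadPipes, Bool())) // a demand load enters the pipe this cycle
        val ready = Output(Bool())
      })
      // confident prefetches always win; weak ones wait for an idle pipe
      io.ready := (io.confidence > 0.U) || io.ldinValid.map(!_).reduce(_ || _)
    }

Note that loadUnits(0) additionally forces prefetch_req.bits.confidence := 0.U, so a
hardware prefetch never displaces a demand load on the first load pipeline.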