diff --git a/src/main/scala/utils/SRAMTemplate.scala b/src/main/scala/utils/SRAMTemplate.scala index 77ffc51d8cf70b7384c179d6bb37ebae2cb17b50..0aea9b812cd1f4102d358a483379442833f9ff1e 100644 --- a/src/main/scala/utils/SRAMTemplate.scala +++ b/src/main/scala/utils/SRAMTemplate.scala @@ -187,7 +187,12 @@ class FoldedSRAMTemplate[T <: Data](gen: T, set: Int, width: Int = 4, way: Int = val wdata = VecInit(Seq.fill(width)(io.w.req.bits.data).flatten) val waddr = io.w.req.bits.setIdx >> log2Ceil(width) val widthIdx = if (width != 1) io.w.req.bits.setIdx(log2Ceil(width)-1, 0) else 0.U - val wmask = if (width*way != 1) VecInit(Seq.tabulate(width*way)(n => (n / way).U === widthIdx)).asUInt else 1.U(1.W) + val wmask = (width, way) match { + case (1, 1) => 1.U(1.W) + case (x, 1) => UIntToOH(widthIdx) + case _ => VecInit(Seq.tabulate(width*way)(n => (n / way).U === widthIdx && io.w.req.bits.waymask.get(n % way))).asUInt + } + require(wmask.getWidth == way*width) array.io.w.apply(wen, wdata, waddr, wmask) }