提交 6dfb1ae3 编写于 作者: Z ZhangZifei

[WIP]PTW: change ptwl2 and tlbl2: from 1*256 to 4*64

上级 f7a5c579
......@@ -114,6 +114,66 @@ class TlbEntry extends TlbBundle {
}
}
/**
  * One line of the L2 TLB: `num` page-table entries that share a single tag and level.
  *
  * @param num    number of entries held in the line; the `require` below only accepts
  *               values for which `log2Up(num) == log2Down(num)`
  * @param tagLen width in bits of the stored tag
  */
class TlbEntires(num: Int, tagLen: Int) extends TlbBundle {
  require(log2Up(num)==log2Down(num))

  // The vpn is split into a tag part and an address part; `cutLen` is the number of
  // vpn bits that select one entry inside the line.
  val cutLen = log2Up(num)

  val tag = UInt(tagLen.W) // NOTE: high part of vpn
  val level = UInt(log2Up(Level).W)
  val ppns = Vec(num, UInt(ppnLen.W))
  val perms = Vec(num, new PermBundle(hasV = false))
  val vs = Vec(num, Bool())

  /** Reduces a full vpn to the `tagLen`-bit value compared against `tag`. */
  def tagClip(vpn: UInt, level: UInt) = { // full vpn => tagLen
    // vpn bits above the in-line index of a level whose index starts at `lowBits`,
    // followed by a zero pad.
    // NOTE(review): `0.U(n)` with an Int argument appears to be Chisel's bit-extract, not
    // a width (`n.W`), so this pad is probably one bit wide. The final clip also keeps the
    // LOW tagLen bits although `tag` is described as the high part — confirm the intent.
    def upper(lowBits: Int) = Cat(vpn(vpnLen-1, lowBits+cutLen), 0.U(lowBits+cutLen))
    val widened = Mux(level===0.U, upper(vpnnLen*2),
                  Mux(level===1.U, upper(vpnnLen*1), upper(vpnnLen*0)))
    widened(tagLen-1, 0)
  }

  /** Selects the `cutLen` vpn bits that index an entry inside the line for `level`. */
  // NOTE: get insize idx
  def idxClip(vpn: UInt, level: UInt) = {
    def window(base: Int) = vpn(base+cutLen-1, base)
    Mux(level===0.U, window(vpnnLen*2),
    Mux(level===1.U, window(vpnnLen*1), window(vpnnLen*0)))
  }

  /** True when the stored tag matches `vpn` and the addressed entry is marked valid. */
  def hit(vpn: UInt) = {
    val tagMatch = tag === tagClip(vpn, level)
    val slotValid = vs(idxClip(vpn, level))
    tagMatch && slotValid
  }

  /**
    * Builds a line from `num` XLEN-bit PTEs packed in `data` (entry i at bits
    * (i+1)*XLEN-1 .. i*XLEN), tagged with `vpn` at the given `level`.
    */
  def genEntries(data: UInt, level: UInt, vpn: UInt): TlbEntires = {
    require((data.getWidth / XLEN) == num,
      "input data length must be multiple of pte length")
    assert(level=/=3.U, "level should not be 3")

    val line = Wire(new TlbEntires(num, tagLen))
    line.tag := tagClip(vpn, level)
    line.level := level
    (0 until num).foreach { slot =>
      val pte = data((slot+1)*XLEN-1, slot*XLEN).asTypeOf(new PteBundle)
      line.ppns(slot) := pte.ppn
      line.perms(slot) := pte.perm // this.perms has no v
      // Only entries that are not a page fault and are leaves are stored as valid.
      line.vs(slot) := !pte.isPf(level) && pte.isLeaf()
    }
    line
  }

  /** Expands the entry addressed by `vpn` into a stand-alone `TlbEntry`. */
  def get(vpn: UInt): TlbEntry = {
    val entry = Wire(new TlbEntry())
    val slot = idxClip(vpn, level)
    entry.vpn := vpn // Note: use the input vpn, not a vpn stored in the L2 TLB
    entry.ppn := ppns(slot)
    entry.level := level
    entry.perm := perms(slot)
    entry
  }

  override def cloneType: this.type = (new TlbEntires(num, tagLen)).asInstanceOf[this.type]
}
object TlbCmd {
def read = "b00".U
def write = "b01".U
......
......@@ -11,6 +11,23 @@ import freechips.rocketchip.tilelink.{TLClientNode, TLMasterParameters, TLMaster
trait HasPtwConst extends HasTlbConst with MemoryOpConstants{
// Number of response ports the PTW iterates over.
val PtwWidth = 2
// Width in bits of one memory line as seen by the PTW caches.
val MemBandWidth = 256 // TODO: change to IO bandwidth param
// L2 TLB geometry: XLEN-bit entries per line, and number of lines for TlbL2EntrySize entries.
val TlbL2LineSize = MemBandWidth/XLEN
val TlbL2LineNum = TlbL2EntrySize/TlbL2LineSize
// PTW level-2 cache geometry, derived the same way from PtwL2EntrySize.
val PtwL2LineSize = MemBandWidth/XLEN
val PtwL2LineNum = PtwL2EntrySize/PtwL2LineSize
// Tag widths: physical address minus the byte offset inside one XLEN-bit entry, and for
// level 2 additionally minus log2Up(PtwL2EntrySize) index bits.
val PtwL1TagLen = PAddrBits - log2Up(XLEN/8)
val PtwL2TagLen = PAddrBits - log2Up(XLEN/8) - log2Up(PtwL2EntrySize)
// L2 TLB tag width: vpn minus log2Up(TlbL2EntrySize) index bits.
val TlbL2TagLen = vpnLen - log2Up(TlbL2EntrySize)
/** Extracts the PTW level-2 line (set) index from a physical PTE address. */
def genPtwL2Idx(addr: UInt) = {
  /* address layout: tagLen :: outSizeIdxLen :: insideIdxLen (:: byte offset) */
  val byteOffBits = log2Up(XLEN/8)
  val msb = log2Up(PtwL2EntrySize) - 1 + byteOffBits
  // Skip the byte offset and the bits that select an entry inside one line.
  val lsb = log2Up(PtwL2LineSize) + byteOffBits
  addr(msb, lsb)
}
/**
  * Extracts the L2 TLB line (set) index from a vpn.
  *
  * The lowest log2Up(TlbL2LineSize) vpn bits select an entry inside one line, so they
  * are skipped here. Otherwise the entries that share a line would be spread over
  * different sets and could never be found through their neighbours' refill.
  * This mirrors genPtwL2Idx, which also skips the in-line bits.
  */
def genTlbL2Idx(vpn: UInt) = {
  val inLineBits = log2Up(TlbL2LineSize)
  vpn(log2Up(TlbL2LineNum)-1+inLineBits, inLineBits)
}
def MakeAddr(ppn: UInt, off: UInt) = {
require(off.getWidth == 9)
......@@ -64,9 +81,7 @@ class PteBundle extends PtwBundle{
class PtwEntry(tagLen: Int) extends PtwBundle {
val tag = UInt(tagLen.W)
val ppn = UInt(ppnLen.W)
// val perm = new PermBundle
// TODO: add superpage
def hit(addr: UInt) = {
require(addr.getWidth >= PAddrBits)
tag === addr(PAddrBits-1, PAddrBits-tagLen)
......@@ -75,14 +90,12 @@ class PtwEntry(tagLen: Int) extends PtwBundle {
/**
  * Overwrites this entry in place.
  *
  * @param addr address whose top `tagLen` bits (below PAddrBits) become the tag
  * @param pte  raw page-table entry; only its ppn field is stored
  */
def refill(addr: UInt, pte: UInt): Unit = {
  tag := addr(PAddrBits-1, PAddrBits-tagLen)
  ppn := pte.asTypeOf(pteBundle).ppn
}
/** Builds a fresh `PtwEntry` wire from the tag bits of `addr` and the ppn field of `pte`. */
def genPtwEntry(addr: UInt, pte: UInt) = {
  val entry = Wire(new PtwEntry(tagLen))
  val tagBits = addr(PAddrBits-1, PAddrBits-tagLen)
  entry.tag := tagBits
  entry.ppn := pte.asTypeOf(pteBundle).ppn
  entry
}
......@@ -94,6 +107,49 @@ class PtwEntry(tagLen: Int) extends PtwBundle {
}
}
/**
  * One line of a page-table-walker cache: `num` entries that share a single address tag.
  *
  * @param num    number of entries held in the line; the `require` below only accepts
  *               values for which `log2Up(num) == log2Down(num)`
  * @param tagLen width in bits of the stored tag
  */
class PtwEntries(num: Int, tagLen: Int) extends PtwBundle {
  require(log2Up(num)==log2Down(num))

  val tag = UInt(tagLen.W)
  val ppns = Vec(num, UInt(ppnLen.W))
  val vs = Vec(num, Bool())

  /** Takes the top `tagLen` bits of a PAddrBits-wide address. */
  def tagClip(addr: UInt) = {
    require(addr.getWidth==PAddrBits)
    val lowestTagBit = PAddrBits-tagLen
    addr(PAddrBits-1, lowestTagBit)
  }

  /** True when the tag matches `addr` and entry `idx` of the line is valid. */
  def hit(idx: UInt, addr: UInt) = {
    require(idx.getWidth == log2Up(num), "error idx width")
    val tagMatch = tag === tagClip(addr)
    tagMatch && vs(idx)
  }

  /**
    * Builds a line from `num` XLEN-bit PTEs packed in `data` (entry i at bits
    * (i+1)*XLEN-1 .. i*XLEN), tagged with the top bits of `addr`.
    */
  def genEntries(addr: UInt, data: UInt, level: UInt): PtwEntries = {
    require((data.getWidth / XLEN) == num,
      "input data length must be multiple of pte length")

    val line = Wire(new PtwEntries(num, tagLen))
    line.tag := tagClip(addr)
    (0 until num).foreach { slot =>
      val pte = data((slot+1)*XLEN-1, slot*XLEN).asTypeOf(new PteBundle)
      line.ppns(slot) := pte.ppn
      // Only entries that are not a page fault and are NOT leaves are stored as valid.
      line.vs(slot) := !pte.isPf(level) && !pte.isLeaf()
    }
    line
  }

  /** Returns the (valid, ppn) pair of entry `idx`. */
  def get(idx: UInt) = {
    require(idx.getWidth == log2Up(num), "error idx width")
    val valid = vs(idx)
    val ppn = ppns(idx)
    (valid, ppn)
  }

  override def cloneType: this.type = (new PtwEntries(num, tagLen)).asInstanceOf[this.type]
}
class PtwReq extends PtwBundle {
val vpn = UInt(vpnLen.W)
......@@ -161,7 +217,6 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
val req = RegEnable(arb.io.out.bits, arb.io.out.fire())
val resp = VecInit(io.tlb.map(_.resp))
val valid = ValidHold(arb.io.out.fire(), resp(arbChosen).fire())
val validOneCycle = OneCycleValid(arb.io.out.fire())
arb.io.out.ready := !valid// || resp(arbChosen).fire()
......@@ -174,28 +229,21 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
// two level: l2-tlb-cache && pde/pte-cache
// l2-tlb-cache is ram-larger-edition tlb
// pde/pte-cache is cache of page-table, speeding up ptw
// may separate valid bits to speed up sfence's flush
// Reg/Mem/SyncReadMem is not sure now
val tagLen1 = PAddrBits - log2Up(XLEN/8)
val tagLen2 = PAddrBits - log2Up(XLEN/8) - log2Up(PtwL2EntrySize)
// val tlbl2 = SyncReadMem(TlbL2EntrySize, new TlbEntry)
val tlbl2 = Module(new SRAMTemplate(new TlbEntry, set = TlbL2EntrySize))
val tlbv = RegInit(0.U(TlbL2EntrySize.W)) // valid
val tlbg = Reg(UInt(TlbL2EntrySize.W)) // global
val ptwl1 = Reg(Vec(PtwL1EntrySize, new PtwEntry(tagLen = tagLen1)))
val tlbl2 = Module(new SRAMTemplate(new TlbEntires(num = TlbL2LineSize, tagLen = TlbL2TagLen), set = TlbL2LineNum)) // (total 256, one line is 4 => 64 lines)
val tlbv = RegInit(0.U(TlbL2LineNum.W)) // valid
val tlbg = Reg(UInt(TlbL2LineNum.W)) // global
val ptwl1 = Reg(Vec(PtwL1EntrySize, new PtwEntry(tagLen = PtwL1TagLen)))
val l1v = RegInit(0.U(PtwL1EntrySize.W)) // valid
// val l1g = VecInit((ptwl1.map(_.perm.g))).asUInt
val l1g = Reg(UInt(PtwL1EntrySize.W))
// val ptwl2 = SyncReadMem(PtwL2EntrySize, new PtwEntry(tagLen = tagLen2)) // NOTE: the Mem could be only single port(r&w)
val ptwl2 = Module(new SRAMTemplate(new PtwEntry(tagLen = tagLen2), set = PtwL2EntrySize))
val l2v = RegInit(0.U(PtwL2EntrySize.W)) // valid
val l2g = Reg(UInt(PtwL2EntrySize.W)) // global
val ptwl2 = Module(new SRAMTemplate(new PtwEntries(num = PtwL2LineSize, tagLen = PtwL2TagLen), set = PtwL2LineNum)) // (total 256, one line is 4 => 64 lines)
val l2v = RegInit(0.U(PtwL2LineNum.W)) // valid
val l2g = Reg(UInt(PtwL2LineNum.W)) // global
// mem alias
// val memRdata = mem.d.bits.data
val memRdata = Wire(UInt(XLEN.W))
val memPte = memRdata.asTypeOf(new PteBundle)
val memRdata = mem.d.bits.data
val memSelData = Wire(UInt(XLEN.W))
val memPte = memSelData.asTypeOf(new PteBundle)
val memPtes =(0 until TlbL2LineSize).map(i => memRdata((i+1)*XLEN-1, i*XLEN).asTypeOf(new PteBundle))
val memValid = mem.d.valid
val memRespReady = mem.d.ready
val memRespFire = mem.d.fire()
......@@ -214,26 +262,24 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
* tlbl2
*/
val (tlbHit, tlbHitData) = {
// tlbl2 is by addr
// TODO: optimize tlbl2'l2 tag len
assert(tlbl2.io.r.req.ready)
val ridx = genTlbL2Idx(req.vpn)
val vidx = RegEnable(tlbv(ridx), validOneCycle)
tlbl2.io.r.req.valid := validOneCycle
tlbl2.io.r.req.bits.apply(setIdx = req.vpn(log2Up(TlbL2EntrySize-1), 0))
tlbl2.io.r.req.bits.apply(setIdx = ridx)
val ramData = tlbl2.io.r.resp.data(0)
// val ramData = tlbl2.r(req.vpn(log2Up(TlbL2EntrySize)-1, 0), validOneCycle)
val vidx = RegEnable(tlbv(req.vpn(log2Up(TlbL2EntrySize)-1, 0)), validOneCycle)
(ramData.hit(req.vpn) && vidx, ramData) // TODO: optimize tag
// TODO: add exception and refill
(ramData.hit(req.vpn) && vidx, ramData.get(req.vpn))
}
/*
* ptwl1
*/
val l1addr = MakeAddr(satp.ppn, getVpnn(req.vpn, 2))
val (l1Hit, l1HitData) = { // TODO: add excp
// 16 terms may cause long latency, so divide it into 2 stages, like l2tlb
val (l1Hit, l1HitData) = {
val hitVecT = ptwl1.zipWithIndex.map{case (a,b) => a.hit(l1addr) && l1v(b) }
val hitVec = hitVecT.map(RegEnable(_, validOneCycle)) // TODO: could have useless init value
val hitVec = hitVecT.map(RegEnable(_, validOneCycle))
val hitData = ParallelMux(hitVec zip ptwl1)
val hit = ParallelOR(hitVec).asBool
(hit, hitData)
......@@ -245,17 +291,18 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
val l1MemBack = memRespFire && state===state_wait_resp && level===0.U
val l1Res = Mux(l1Hit, l1HitData.ppn, RegEnable(memPte.ppn, l1MemBack))
val l2addr = MakeAddr(l1Res, getVpnn(req.vpn, 1))
val (l2Hit, l2HitData) = { // TODO: add excp
val (l2Hit, l2HitPPN) = {
val readRam = (l1Hit && level===0.U && state===state_req) || (memRespFire && state===state_wait_resp && level===0.U)
val ridx = l2addr(log2Up(PtwL2EntrySize)-1+log2Up(XLEN/8), log2Up(XLEN/8))
val ridx = genPtwL2Idx(l2addr)
val idx = RegEnable(l2addr(log2Up(PtwL2LineSize)+log2Up(XLEN/8)-1, log2Up(XLEN/8)), readRam)
val vidx = RegEnable(l2v(ridx), readRam)
assert(ptwl2.io.r.req.ready)
ptwl2.io.r.req.valid := readRam
ptwl2.io.r.req.bits.apply(setIdx = ridx)
val ramData = ptwl2.io.r.resp.data(0)
// val ramData = ptwl2.read(ridx, readRam)
val vidx = RegEnable(l2v(ridx), readRam)
(ramData.hit(l2addr) && vidx, ramData) // TODO: optimize tag
(ramData.hit(idx, l2addr) && vidx, ramData.get(idx)._2) // TODO: optimize tag
}
/* ptwl3
......@@ -264,7 +311,7 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
* if l2-tlb does not hit, ptwl3 would not hit (mostly)
*/
val l2MemBack = memRespFire && state===state_wait_resp && level===1.U
val l2Res = Mux(l2Hit, l2HitData.ppn, RegEnable(memPte.ppn, l2MemBack))
val l2Res = Mux(l2Hit, l2HitPPN, RegEnable(memPte.ppn, l2MemBack))
val l3addr = MakeAddr(l2Res, getVpnn(req.vpn, 0))
/*
......@@ -347,16 +394,19 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
mem.d.ready := state === state_wait_resp || sfenceLatch
val memAddrLatch = RegEnable(memAddr, mem.a.valid)
memRdata := (mem.d.bits.data >> (memAddrLatch(log2Up(l1BusDataWidth/8) - 1, log2Up(XLEN/8)) << log2Up(XLEN)))(XLEN - 1, 0)
memSelData := memRdata.asTypeOf(Vec(MemBandWidth/XLEN, UInt(XLEN.W)))(memAddrLatch(log2Up(l1BusDataWidth/8) - 1, log2Up(XLEN/8)))
/*
* resp
*/
val ptwFinish = (state===state_req && tlbHit && level===0.U) || ((memPte.isLeaf() || memPte.isPf(level) || (!memPte.isLeaf() && level===2.U)) && memRespFire && !sfenceLatch) || state===state_wait_ready
val ptwFinish = (state===state_req && tlbHit && level===0.U) ||
((memPte.isLeaf() || memPte.isPf(level) ||
(!memPte.isLeaf() && level===2.U)) && memRespFire && !sfenceLatch) ||
state===state_wait_ready
for(i <- 0 until PtwWidth) {
resp(i).valid := valid && arbChosen===i.U && ptwFinish // TODO: add resp valid logic
resp(i).bits.entry := Mux(tlbHit, tlbHitData,
Mux(state===state_wait_ready, latch.entry, new TlbEntry().genTlbEntry(memRdata, Mux(level===3.U, 2.U, level), req.vpn)))
Mux(state===state_wait_ready, latch.entry, new TlbEntry().genTlbEntry(memSelData, Mux(level===3.U, 2.U, level), req.vpn)))
resp(i).bits.pf := Mux(level===3.U || notFound, true.B, Mux(tlbHit, false.B, Mux(state===state_wait_ready, latch.pf, memPte.isPf(level))))
// TODO: the pf must not be correct, check it
}
......@@ -372,30 +422,40 @@ class PTWImp(outer: PTW) extends PtwModule(outer){
when (memRespFire && !memPte.isPf(level) && !sfenceLatch) {
when (level===0.U && !memPte.isLeaf) {
val refillIdx = LFSR64()(log2Up(PtwL1EntrySize)-1,0) // TODO: may be LRU
ptwl1(refillIdx).refill(l1addr, memRdata)
ptwl1(refillIdx).refill(l1addr, memSelData)
l1v := l1v | UIntToOH(refillIdx)
l1g := (l1g & ~UIntToOH(refillIdx)) | Mux(memPte.perm.g, UIntToOH(refillIdx), 0.U)
}
when (level===1.U && !memPte.isLeaf) {
val l2addrStore = RegEnable(l2addr, memReqFire && state===state_req && level===1.U)
val refillIdx = getVpnn(req.vpn, 1)(log2Up(PtwL2EntrySize)-1, 0)
val refillIdx = genPtwL2Idx(l2addrStore) //getVpnn(req.vpn, 1)(log2Up(PtwL2EntrySize)-1, 0)
//TODO: check why the old refillIdx is right
assert(ptwl2.io.w.req.ready)
// ptwl2.io.w.req.valid := true.B
ptwl2.io.w.apply(valid = true.B, setIdx = refillIdx, data = new PtwEntry(tagLen2).genPtwEntry(l2addrStore, memRdata), waymask = -1.S.asUInt)
// ptwl2.write(refillIdx, new PtwEntry(tagLen2).genPtwEntry(l2addrStore, memRdata))
val ps = new PtwEntries(PtwL2LineSize, PtwL2TagLen).genEntries(l2addrStore, memRdata, level)
ptwl2.io.w.apply(
valid = true.B,
setIdx = refillIdx,
data = ps,
waymask = -1.S.asUInt
)
l2v := l2v | UIntToOH(refillIdx)
l2g := (l2g & ~UIntToOH(refillIdx)) | Mux(memPte.perm.g, UIntToOH(refillIdx), 0.U)
l2g := (l2g & ~UIntToOH(refillIdx)) | Mux(Cat(memPtes.map(_.perm.g)).andR, UIntToOH(refillIdx), 0.U)
}
when (memPte.isLeaf()) {
val refillIdx = getVpnn(req.vpn, 0)(log2Up(TlbL2EntrySize)-1, 0)
val refillIdx = genTlbL2Idx(req.vpn)//getVpnn(req.vpn, 0)(log2Up(TlbL2EntrySize)-1, 0)
//TODO: check why the old refillIdx is right
assert(tlbl2.io.w.req.ready)
// tlbl2.io.w.req.valid := true.B
tlbl2.io.w.apply(valid = true.B, setIdx = refillIdx, data = new TlbEntry().genTlbEntry(memRdata, level, req.vpn), waymask = -1.S.asUInt)
// tlbl2.write(refillIdx, new TlbEntry().genTlbEntry(memRdata, level, req.vpn))
val ts = new TlbEntires(num = TlbL2LineSize, tagLen = TlbL2TagLen).genEntries(memRdata, level, req.vpn)
tlbl2.io.w.apply(
valid = true.B,
setIdx = refillIdx,
data = ts,
waymask = -1.S.asUInt
)
tlbv := tlbv | UIntToOH(refillIdx)
tlbg := (tlbg & ~UIntToOH(refillIdx)) | Mux(memPte.perm.g, UIntToOH(refillIdx), 0.U)
tlbg := (tlbg & ~UIntToOH(refillIdx)) | Mux(Cat(memPtes.map(_.perm.g)).andR, UIntToOH(refillIdx), 0.U)
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册