未验证 提交 69a65c2b 编写于 作者: L ljw 提交者: GitHub

Merge pull request #380 from RISCVERS/hardfloat

Use hardfloat instead xs-fpu
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
[submodule "block-inclusivecache-sifive"] [submodule "block-inclusivecache-sifive"]
path = block-inclusivecache-sifive path = block-inclusivecache-sifive
url = https://github.com/RISCVERS/block-inclusivecache-sifive.git url = https://github.com/RISCVERS/block-inclusivecache-sifive.git
branch = 5491dcc937ed3c6f7722bef9db448653daab75e8 branch = 0315ccf27963d7fe4b5e850c709fb66298f8390c
[submodule "chiseltest"] [submodule "chiseltest"]
path = chiseltest path = chiseltest
url = https://github.com/ucb-bar/chisel-testers2.git url = https://github.com/ucb-bar/chisel-testers2.git
...@@ -15,4 +15,5 @@ ...@@ -15,4 +15,5 @@
url = https://github.com/chipsalliance/api-config-chipsalliance url = https://github.com/chipsalliance/api-config-chipsalliance
[submodule "berkeley-hardfloat"] [submodule "berkeley-hardfloat"]
path = berkeley-hardfloat path = berkeley-hardfloat
url = https://github.com/ucb-bar/berkeley-hardfloat url = https://github.com/RISCVERS/berkeley-hardfloat.git
branch = 759d99d90dc119d071a5fdbf35ee4578d0613a2f
Subproject commit 267357bdae5973a30565da6ebc728d513827ca5e Subproject commit 759d99d90dc119d071a5fdbf35ee4578d0613a2f
...@@ -4,7 +4,6 @@ import chisel3._ ...@@ -4,7 +4,6 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan.backend.SelImm import xiangshan.backend.SelImm
import xiangshan.backend.brq.BrqPtr import xiangshan.backend.brq.BrqPtr
import xiangshan.backend.fu.fpu.Fflags
import xiangshan.backend.rename.FreeListPtr import xiangshan.backend.rename.FreeListPtr
import xiangshan.backend.roq.RoqPtr import xiangshan.backend.roq.RoqPtr
import xiangshan.backend.decode.XDecode import xiangshan.backend.decode.XDecode
...@@ -183,6 +182,23 @@ class CtrlFlow extends XSBundle { ...@@ -183,6 +182,23 @@ class CtrlFlow extends XSBundle {
val crossPageIPFFix = Bool() val crossPageIPFFix = Bool()
} }
class FPUCtrlSignals extends XSBundle {
val isAddSub = Bool() // swap23
val typeTagIn = UInt(2.W)
val typeTagOut = UInt(2.W)
val fromInt = Bool()
val wflags = Bool()
val fpWen = Bool()
val fmaCmd = UInt(2.W)
val div = Bool()
val sqrt = Bool()
val fcvt = Bool()
val typ = UInt(2.W)
val fmt = UInt(2.W)
val ren3 = Bool() //TODO: remove SrcType.fp
}
// Decode DecodeWidth insts at Decode Stage // Decode DecodeWidth insts at Decode Stage
class CtrlSignals extends XSBundle { class CtrlSignals extends XSBundle {
val src1Type, src2Type, src3Type = SrcType() val src1Type, src2Type, src3Type = SrcType()
...@@ -200,6 +216,7 @@ class CtrlSignals extends XSBundle { ...@@ -200,6 +216,7 @@ class CtrlSignals extends XSBundle {
val selImm = SelImm() val selImm = SelImm()
val imm = UInt(XLEN.W) val imm = UInt(XLEN.W)
val commitType = CommitType() val commitType = CommitType()
val fpu = new FPUCtrlSignals
def decode(inst: UInt, table: Iterable[(BitPat, List[BitPat])]) = { def decode(inst: UInt, table: Iterable[(BitPat, List[BitPat])]) = {
val decoder = freechips.rocketchip.rocket.DecodeLogic(inst, XDecode.decodeDefault, table) val decoder = freechips.rocketchip.rocket.DecodeLogic(inst, XDecode.decodeDefault, table)
...@@ -271,7 +288,7 @@ class ExuInput extends XSBundle { ...@@ -271,7 +288,7 @@ class ExuInput extends XSBundle {
class ExuOutput extends XSBundle { class ExuOutput extends XSBundle {
val uop = new MicroOp val uop = new MicroOp
val data = UInt((XLEN+1).W) val data = UInt((XLEN+1).W)
val fflags = new Fflags val fflags = UInt(5.W)
val redirectValid = Bool() val redirectValid = Bool()
val redirect = new Redirect val redirect = new Redirect
val brUpdate = new CfiUpdateInfo val brUpdate = new CfiUpdateInfo
...@@ -297,6 +314,7 @@ class RoqCommitInfo extends XSBundle { ...@@ -297,6 +314,7 @@ class RoqCommitInfo extends XSBundle {
val ldest = UInt(5.W) val ldest = UInt(5.W)
val rfWen = Bool() val rfWen = Bool()
val fpWen = Bool() val fpWen = Bool()
val wflags = Bool()
val commitType = CommitType() val commitType = CommitType()
val pdest = UInt(PhyRegIdxWidth.W) val pdest = UInt(PhyRegIdxWidth.W)
val old_pdest = UInt(PhyRegIdxWidth.W) val old_pdest = UInt(PhyRegIdxWidth.W)
......
...@@ -10,13 +10,14 @@ import xiangshan.backend.exu.Exu._ ...@@ -10,13 +10,14 @@ import xiangshan.backend.exu.Exu._
import xiangshan.frontend._ import xiangshan.frontend._
import xiangshan.mem._ import xiangshan.mem._
import xiangshan.backend.fu.HasExceptionNO import xiangshan.backend.fu.HasExceptionNO
import xiangshan.cache.{ICache, DCache, L1plusCache, DCacheParameters, ICacheParameters, L1plusCacheParameters, PTW, Uncache} import xiangshan.cache.{DCache, DCacheParameters, ICache, ICacheParameters, L1plusCache, L1plusCacheParameters, PTW, Uncache}
import chipsalliance.rocketchip.config import chipsalliance.rocketchip.config
import freechips.rocketchip.diplomacy.{LazyModule, LazyModuleImp, AddressSet} import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp}
import freechips.rocketchip.tilelink.{TLBundleParameters, TLCacheCork, TLBuffer, TLClientNode, TLIdentityNode, TLXbar, TLWidthWidget, TLFilter, TLToAXI4} import freechips.rocketchip.tilelink.{TLBuffer, TLBundleParameters, TLCacheCork, TLClientNode, TLFilter, TLIdentityNode, TLToAXI4, TLWidthWidget, TLXbar}
import freechips.rocketchip.devices.tilelink.{TLError, DevNullParams} import freechips.rocketchip.devices.tilelink.{DevNullParams, TLError}
import sifive.blocks.inclusivecache.{CacheParameters, InclusiveCache, InclusiveCacheMicroParameters} import sifive.blocks.inclusivecache.{CacheParameters, InclusiveCache, InclusiveCacheMicroParameters}
import freechips.rocketchip.amba.axi4.{AXI4ToTL, AXI4IdentityNode, AXI4UserYanker, AXI4Fragmenter, AXI4IdIndexer, AXI4Deinterleaver} import freechips.rocketchip.amba.axi4.{AXI4Deinterleaver, AXI4Fragmenter, AXI4IdIndexer, AXI4IdentityNode, AXI4ToTL, AXI4UserYanker}
import freechips.rocketchip.tile.HasFPUParameters
import utils._ import utils._
case class XSCoreParameters case class XSCoreParameters
...@@ -96,7 +97,10 @@ trait HasXSParameter { ...@@ -96,7 +97,10 @@ trait HasXSParameter {
val core = Parameters.get.coreParameters val core = Parameters.get.coreParameters
val env = Parameters.get.envParameters val env = Parameters.get.envParameters
val XLEN = core.XLEN val XLEN = 64
val minFLen = 32
val fLen = 64
def xLen = 64
val HasMExtension = core.HasMExtension val HasMExtension = core.HasMExtension
val HasCExtension = core.HasCExtension val HasCExtension = core.HasCExtension
val HasDiv = core.HasDiv val HasDiv = core.HasDiv
...@@ -214,6 +218,7 @@ abstract class XSModule extends MultiIOModule ...@@ -214,6 +218,7 @@ abstract class XSModule extends MultiIOModule
with HasXSParameter with HasXSParameter
with HasExceptionNO with HasExceptionNO
with HasXSLog with HasXSLog
with HasFPUParameters
{ {
def io: Record def io: Record
} }
......
...@@ -8,7 +8,6 @@ import xiangshan.backend.exu.{AluExeUnit, ExuConfig, JumpExeUnit, MulDivExeUnit, ...@@ -8,7 +8,6 @@ import xiangshan.backend.exu.{AluExeUnit, ExuConfig, JumpExeUnit, MulDivExeUnit,
import xiangshan.backend.fu.FenceToSbuffer import xiangshan.backend.fu.FenceToSbuffer
import xiangshan.backend.issue.{ReservationStationCtrl, ReservationStationData} import xiangshan.backend.issue.{ReservationStationCtrl, ReservationStationData}
import xiangshan.backend.regfile.Regfile import xiangshan.backend.regfile.Regfile
import xiangshan.backend.fu.fpu.Fflags
class WakeUpBundle(numFast: Int, numSlow: Int) extends XSBundle { class WakeUpBundle(numFast: Int, numSlow: Int) extends XSBundle {
val fastUops = Vec(numFast, Flipped(ValidIO(new MicroOp))) val fastUops = Vec(numFast, Flipped(ValidIO(new MicroOp)))
...@@ -72,7 +71,7 @@ class IntegerBlock ...@@ -72,7 +71,7 @@ class IntegerBlock
val wakeUpIntOut = Flipped(new WakeUpBundle(fastIntOut.size, slowIntOut.size)) val wakeUpIntOut = Flipped(new WakeUpBundle(fastIntOut.size, slowIntOut.size))
val csrio = new Bundle { val csrio = new Bundle {
val fflags = Input(new Fflags) // from roq val fflags = Flipped(Valid(UInt(5.W))) // from roq
val dirty_fs = Input(Bool()) // from roq val dirty_fs = Input(Bool()) // from roq
val frm = Output(UInt(3.W)) // to float val frm = Output(UInt(3.W)) // to float
val exception = Flipped(ValidIO(new MicroOp)) // from roq val exception = Flipped(ValidIO(new MicroOp)) // from roq
......
...@@ -73,7 +73,8 @@ class MemBlock ...@@ -73,7 +73,8 @@ class MemBlock
atomicsUnit.io.out.ready := ldOut0.ready atomicsUnit.io.out.ready := ldOut0.ready
loadUnits.head.io.ldout.ready := ldOut0.ready loadUnits.head.io.ldout.ready := ldOut0.ready
val exeWbReqs = ldOut0 +: loadUnits.tail.map(_.io.ldout) val intExeWbReqs = ldOut0 +: loadUnits.tail.map(_.io.ldout)
val fpExeWbReqs = loadUnits.map(_.io.fpout)
val reservationStations = (loadExuConfigs ++ storeExuConfigs).zipWithIndex.map({ case (cfg, i) => val reservationStations = (loadExuConfigs ++ storeExuConfigs).zipWithIndex.map({ case (cfg, i) =>
var certainLatency = -1 var certainLatency = -1
...@@ -90,7 +91,7 @@ class MemBlock ...@@ -90,7 +91,7 @@ class MemBlock
.map(_._2.bits.data) .map(_._2.bits.data)
val wakeupCnt = writeBackData.length val wakeupCnt = writeBackData.length
val inBlockListenPorts = exeWbReqs val inBlockListenPorts = intExeWbReqs ++ fpExeWbReqs
val extraListenPorts = inBlockListenPorts ++ val extraListenPorts = inBlockListenPorts ++
slowWakeUpIn.zip(io.wakeUpIn.slow) slowWakeUpIn.zip(io.wakeUpIn.slow)
.filter(x => (x._1.writeIntRf && readIntRf) || (x._1.writeFpRf && readFpRf)) .filter(x => (x._1.writeIntRf && readIntRf) || (x._1.writeFpRf && readFpRf))
...@@ -139,20 +140,12 @@ class MemBlock ...@@ -139,20 +140,12 @@ class MemBlock
io.wakeUpIn.fast.foreach(_.ready := true.B) io.wakeUpIn.fast.foreach(_.ready := true.B)
io.wakeUpIn.slow.foreach(_.ready := true.B) io.wakeUpIn.slow.foreach(_.ready := true.B)
io.wakeUpFpOut.slow <> exeWbReqs.map(x => { io.wakeUpFpOut.slow <> fpExeWbReqs
val raw = WireInit(x) io.wakeUpIntOut.slow <> intExeWbReqs
raw.valid := x.valid && x.bits.uop.ctrl.fpWen
raw
})
io.wakeUpIntOut.slow <> exeWbReqs.map(x => {
val raw = WireInit(x)
raw.valid := x.valid && x.bits.uop.ctrl.rfWen
raw
})
// load always ready // load always ready
exeWbReqs.foreach(_.ready := true.B) fpExeWbReqs.foreach(_.ready := true.B)
intExeWbReqs.foreach(_.ready := true.B)
val dtlb = Module(new TLB(Width = DTLBWidth, isDtlb = true)) val dtlb = Module(new TLB(Width = DTLBWidth, isDtlb = true))
val lsq = Module(new LsqWrappper) val lsq = Module(new LsqWrappper)
...@@ -185,16 +178,22 @@ class MemBlock ...@@ -185,16 +178,22 @@ class MemBlock
// StoreUnit // StoreUnit
for (i <- 0 until exuParameters.StuCnt) { for (i <- 0 until exuParameters.StuCnt) {
storeUnits(i).io.redirect <> io.fromCtrlBlock.redirect val stu = storeUnits(i)
storeUnits(i).io.tlbFeedback <> reservationStations(exuParameters.LduCnt + i).io.feedback val rs = reservationStations(exuParameters.LduCnt + i)
storeUnits(i).io.dtlb <> dtlb.io.requestor(exuParameters.LduCnt + i) val dtlbReq = dtlb.io.requestor(exuParameters.LduCnt + i)
// get input form dispatch
storeUnits(i).io.stin <> reservationStations(exuParameters.LduCnt + i).io.deq stu.io.redirect <> io.fromCtrlBlock.redirect
// passdown to lsq stu.io.tlbFeedback <> rs.io.feedback
storeUnits(i).io.lsq <> lsq.io.storeIn(i) stu.io.dtlb <> dtlbReq
io.toCtrlBlock.stOut(i).valid := storeUnits(i).io.stout.valid
io.toCtrlBlock.stOut(i).bits := storeUnits(i).io.stout.bits // get input from dispatch
storeUnits(i).io.stout.ready := true.B stu.io.stin <> rs.io.deq
// passdown to lsq
stu.io.lsq <> lsq.io.storeIn(i)
io.toCtrlBlock.stOut(i).valid := stu.io.stout.valid
io.toCtrlBlock.stOut(i).bits := stu.io.stout.bits
stu.io.stout.ready := true.B
} }
// mmio store writeback will use store writeback port 0 // mmio store writeback will use store writeback port 0
...@@ -296,4 +295,4 @@ class MemBlock ...@@ -296,4 +295,4 @@ class MemBlock
lsq.io.exceptionAddr.isStore := io.lsqio.exceptionAddr.isStore lsq.io.exceptionAddr.isStore := io.lsqio.exceptionAddr.isStore
io.lsqio.exceptionAddr.vaddr := Mux(atomicsUnit.io.exceptionAddr.valid, atomicsUnit.io.exceptionAddr.bits, lsq.io.exceptionAddr.vaddr) io.lsqio.exceptionAddr.vaddr := Mux(atomicsUnit.io.exceptionAddr.valid, atomicsUnit.io.exceptionAddr.bits, lsq.io.exceptionAddr.vaddr)
} }
\ No newline at end of file
...@@ -17,7 +17,6 @@ import xiangshan._ ...@@ -17,7 +17,6 @@ import xiangshan._
import utils._ import utils._
import xiangshan.backend._ import xiangshan.backend._
import xiangshan.backend.decode.Instructions._ import xiangshan.backend.decode.Instructions._
import xiangshan.backend.fu.fpu.FPUOpType
import freechips.rocketchip.tile.RocketTile import freechips.rocketchip.tile.RocketTile
/** /**
...@@ -199,81 +198,81 @@ object XDecode extends DecodeConstants { ...@@ -199,81 +198,81 @@ object XDecode extends DecodeConstants {
object FDecode extends DecodeConstants{ object FDecode extends DecodeConstants{
val table: Array[(BitPat, List[BitPat])] = Array( val table: Array[(BitPat, List[BitPat])] = Array(
FLW -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.ldu, LSUOpType.flw, N, Y, N, N, N, N, Y, SelImm.IMM_I), FLW -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.ldu, LSUOpType.lw, N, Y, N, N, N, N, Y, SelImm.IMM_I),
FLD -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.ldu, LSUOpType.ld, N, Y, N, N, N, N, N, SelImm.IMM_I), FLD -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.ldu, LSUOpType.ld, N, Y, N, N, N, N, N, SelImm.IMM_I),
FSW -> List(SrcType.reg, SrcType.fp, SrcType.DC, FuType.stu, LSUOpType.sw, N, N, N, N, N, N, Y, SelImm.IMM_S), FSW -> List(SrcType.reg, SrcType.fp, SrcType.DC, FuType.stu, LSUOpType.sw, N, N, N, N, N, N, Y, SelImm.IMM_S),
FSD -> List(SrcType.reg, SrcType.fp, SrcType.DC, FuType.stu, LSUOpType.sd, N, N, N, N, N, N, N, SelImm.IMM_S), FSD -> List(SrcType.reg, SrcType.fp, SrcType.DC, FuType.stu, LSUOpType.sd, N, N, N, N, N, N, N, SelImm.IMM_S),
FCLASS_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fclass, Y, N, N, N, N, N, Y, SelImm.IMM_X), FCLASS_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FCLASS_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fclass, Y, N, N, N, N, N, N, SelImm.IMM_X), FCLASS_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FMV_D_X -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.fmv_i2f, N, Y, N, N, N, N, N, SelImm.IMM_X), FMV_D_X -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FMV_X_D -> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fmv_f2i, Y, N, N, N, N, N, N, SelImm.IMM_X), FMV_X_D -> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FMV_X_W -> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fmv_f2i, Y, N, N, N, N, N, Y, SelImm.IMM_X), FMV_X_W -> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FMV_W_X -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.fmv_i2f, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMV_W_X -> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSGNJ_S -> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnj, N, Y, N, N, N, N, Y, SelImm.IMM_X), FSGNJ_S -> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSGNJ_D -> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnj, N, Y, N, N, N, N, N, SelImm.IMM_X), FSGNJ_D -> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FSGNJX_S-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnjx, N, Y, N, N, N, N, Y, SelImm.IMM_X), FSGNJX_S-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSGNJX_D-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnjx, N, Y, N, N, N, N, N, SelImm.IMM_X), FSGNJX_D-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FSGNJN_S-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnjn, N, Y, N, N, N, N, Y, SelImm.IMM_X), FSGNJN_S-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSGNJN_D-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fsgnjn, N, Y, N, N, N, N, N, SelImm.IMM_X), FSGNJN_D-> List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
// FP to FP // FP to FP
FCVT_S_D-> List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.d2s, N, Y, N, N, N, N, Y, SelImm.IMM_X), FCVT_S_D-> List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FCVT_D_S-> List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.s2d, N, Y, N, N, N, N, N, SelImm.IMM_X), FCVT_D_S-> List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
// Int to FP // Int to FP
FCVT_S_W-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.w2f, N, Y, N, N, N, N, Y, SelImm.IMM_X), FCVT_S_W-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FCVT_S_WU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.wu2f, N, Y, N, N, N, N, Y, SelImm.IMM_X), FCVT_S_WU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FCVT_S_L-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.l2f, N, Y, N, N, N, N, Y, SelImm.IMM_X), FCVT_S_L-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FCVT_S_LU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.lu2f, N, Y, N, N, N, N, Y, SelImm.IMM_X), FCVT_S_LU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FCVT_D_W-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.w2f, N, Y, N, N, N, N, N, SelImm.IMM_X), FCVT_D_W-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FCVT_D_WU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.wu2f, N, Y, N, N, N, N, N, SelImm.IMM_X), FCVT_D_WU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FCVT_D_L-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.l2f, N, Y, N, N, N, N, N, SelImm.IMM_X), FCVT_D_L-> List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FCVT_D_LU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, FPUOpType.lu2f, N, Y, N, N, N, N, N, SelImm.IMM_X), FCVT_D_LU->List(SrcType.reg, SrcType.imm, SrcType.DC, FuType.i2f, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
// FP to Int // FP to Int
FCVT_W_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2w, Y, N, N, N, N, N, Y, SelImm.IMM_X), FCVT_W_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FCVT_WU_S->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2wu, Y, N, N, N, N, N, Y, SelImm.IMM_X), FCVT_WU_S->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FCVT_L_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2l, Y, N, N, N, N, N, Y, SelImm.IMM_X), FCVT_L_S-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FCVT_LU_S->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2lu, Y, N, N, N, N, N, Y, SelImm.IMM_X), FCVT_LU_S->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FCVT_W_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2w, Y, N, N, N, N, N, N, SelImm.IMM_X), FCVT_W_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FCVT_WU_D->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2wu, Y, N, N, N, N, N, N, SelImm.IMM_X), FCVT_WU_D->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FCVT_L_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2l, Y, N, N, N, N, N, N, SelImm.IMM_X), FCVT_L_D-> List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FCVT_LU_D->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.f2lu, Y, N, N, N, N, N, N, SelImm.IMM_X), FCVT_LU_D->List(SrcType.fp , SrcType.imm, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
// "fp_single" is used for wb_data formatting (and debugging) // "fp_single" is used for wb_data formatting (and debugging)
FEQ_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.feq, Y, N, N, N, N, N, Y, SelImm.IMM_X), FEQ_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FLT_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.flt, Y, N, N, N, N, N, Y, SelImm.IMM_X), FLT_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FLE_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fle, Y, N, N, N, N, N, Y, SelImm.IMM_X), FLE_S ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, Y, SelImm.IMM_X),
FEQ_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.feq, Y, N, N, N, N, N, N, SelImm.IMM_X), FEQ_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FLT_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.flt, Y, N, N, N, N, N, N, SelImm.IMM_X), FLT_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FLE_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fle, Y, N, N, N, N, N, N, SelImm.IMM_X), FLE_D ->List(SrcType.fp , SrcType.fp, SrcType.DC, FuType.fmisc, X, Y, N, N, N, N, N, N, SelImm.IMM_X),
FMIN_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fmin, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMIN_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FMAX_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fmax, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMAX_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FMIN_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fmin, N, Y, N, N, N, N, N, SelImm.IMM_X), FMIN_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FMAX_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fmax, N, Y, N, N, N, N, N, SelImm.IMM_X), FMAX_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FADD_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fadd, N, Y, N, N, N, N, Y, SelImm.IMM_X), FADD_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fsub, N, Y, N, N, N, N, Y, SelImm.IMM_X), FSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FMUL_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fmul, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMUL_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FADD_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fadd, N, Y, N, N, N, N, N, SelImm.IMM_X), FADD_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fsub, N, Y, N, N, N, N, N, SelImm.IMM_X), FSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FMUL_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, FPUOpType.fmul, N, Y, N, N, N, N, N, SelImm.IMM_X), FMUL_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FMADD_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fmadd, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMADD_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FMSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fmsub, N, Y, N, N, N, N, Y, SelImm.IMM_X), FMSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FNMADD_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fnmadd, N, Y, N, N, N, N, Y, SelImm.IMM_X), FNMADD_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FNMSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fnmsub, N, Y, N, N, N, N, Y, SelImm.IMM_X), FNMSUB_S ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FMADD_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fmadd, N, Y, N, N, N, N, N, SelImm.IMM_X), FMADD_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FMSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fmsub, N, Y, N, N, N, N, N, SelImm.IMM_X), FMSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FNMADD_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fnmadd, N, Y, N, N, N, N, N, SelImm.IMM_X), FNMADD_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FNMSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, FPUOpType.fnmsub, N, Y, N, N, N, N, N, SelImm.IMM_X) FNMSUB_D ->List(SrcType.fp, SrcType.fp, SrcType.fp, FuType.fmac, X, N, Y, N, N, N, N, N, SelImm.IMM_X)
) )
} }
...@@ -282,10 +281,10 @@ object FDecode extends DecodeConstants{ ...@@ -282,10 +281,10 @@ object FDecode extends DecodeConstants{
*/ */
object FDivSqrtDecode extends DecodeConstants { object FDivSqrtDecode extends DecodeConstants {
val table: Array[(BitPat, List[BitPat])] = Array( val table: Array[(BitPat, List[BitPat])] = Array(
FDIV_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fdiv, N, Y, N, N, N, N, Y, SelImm.IMM_X), FDIV_S ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FDIV_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, FPUOpType.fdiv, N, Y, N, N, N, N, N, SelImm.IMM_X), FDIV_D ->List(SrcType.fp, SrcType.fp, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X),
FSQRT_S ->List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fsqrt, N, Y, N, N, N, N, Y, SelImm.IMM_X), FSQRT_S ->List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, Y, SelImm.IMM_X),
FSQRT_D ->List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, FPUOpType.fsqrt, N, Y, N, N, N, N, N, SelImm.IMM_X) FSQRT_D ->List(SrcType.fp, SrcType.imm, SrcType.DC, FuType.fmisc, X, N, Y, N, N, N, N, N, SelImm.IMM_X)
) )
} }
...@@ -375,6 +374,10 @@ class DecodeUnit extends XSModule with DecodeUnitConstants { ...@@ -375,6 +374,10 @@ class DecodeUnit extends XSModule with DecodeUnitConstants {
cf_ctrl.brTag := DontCare cf_ctrl.brTag := DontCare
val cs = Wire(new CtrlSignals()).decode(ctrl_flow.instr, decode_table) val cs = Wire(new CtrlSignals()).decode(ctrl_flow.instr, decode_table)
val fpDecoder = Module(new FPDecoder)
fpDecoder.io.instr := io.enq.ctrl_flow.instr
cs.fpu := fpDecoder.io.fpCtrl
// read src1~3 location // read src1~3 location
cs.lsrc1 := Mux(ctrl_flow.instr === LUI || cs.src1Type === SrcType.pc, 0.U, ctrl_flow.instr(RS1_MSB,RS1_LSB)) cs.lsrc1 := Mux(ctrl_flow.instr === LUI || cs.src1Type === SrcType.pc, 0.U, ctrl_flow.instr(RS1_MSB,RS1_LSB))
cs.lsrc2 := ctrl_flow.instr(RS2_MSB,RS2_LSB) cs.lsrc2 := ctrl_flow.instr(RS2_MSB,RS2_LSB)
......
package xiangshan.backend.decode
import chisel3._
import chisel3.util._
import freechips.rocketchip.rocket.DecodeLogic
import xiangshan.backend.decode.Instructions._
import xiangshan.{FPUCtrlSignals, XSModule}
class FPDecoder extends XSModule{
val io = IO(new Bundle() {
val instr = Input(UInt(32.W))
val fpCtrl = Output(new FPUCtrlSignals)
})
def X = BitPat("b?")
def N = BitPat("b0")
def Y = BitPat("b1")
val s = BitPat(S)
val d = BitPat(D)
val default = List(X,X,X,X,N,N,X,X,X)
// isAddSub tagIn tagOut fromInt wflags fpWen div sqrt fcvt
val single: Array[(BitPat, List[BitPat])] = Array(
FMV_W_X -> List(N,s,d,Y,N,Y,N,N,N),
FCVT_S_W -> List(N,s,s,Y,Y,Y,N,N,Y),
FCVT_S_WU-> List(N,s,s,Y,Y,Y,N,N,Y),
FCVT_S_L -> List(N,s,s,Y,Y,Y,N,N,Y),
FCVT_S_LU-> List(N,s,s,Y,Y,Y,N,N,Y),
FMV_X_W -> List(N,s,X,N,N,N,N,N,N),
FCLASS_S -> List(N,s,X,N,N,N,N,N,N),
FCVT_W_S -> List(N,s,X,N,Y,N,N,N,Y),
FCVT_WU_S-> List(N,s,X,N,Y,N,N,N,Y),
FCVT_L_S -> List(N,s,X,N,Y,N,N,N,Y),
FCVT_LU_S-> List(N,s,X,N,Y,N,N,N,Y),
FEQ_S -> List(N,s,X,N,Y,N,N,N,N),
FLT_S -> List(N,s,X,N,Y,N,N,N,N),
FLE_S -> List(N,s,X,N,Y,N,N,N,N),
FSGNJ_S -> List(N,s,s,N,N,Y,N,N,N),
FSGNJN_S -> List(N,s,s,N,N,Y,N,N,N),
FSGNJX_S -> List(N,s,s,N,N,Y,N,N,N),
FMIN_S -> List(N,s,s,N,Y,Y,N,N,N),
FMAX_S -> List(N,s,s,N,Y,Y,N,N,N),
FADD_S -> List(Y,s,s,N,Y,Y,N,N,N),
FSUB_S -> List(Y,s,s,N,Y,Y,N,N,N),
FMUL_S -> List(N,s,s,N,Y,Y,N,N,N),
FMADD_S -> List(N,s,s,N,Y,Y,N,N,N),
FMSUB_S -> List(N,s,s,N,Y,Y,N,N,N),
FNMADD_S -> List(N,s,s,N,Y,Y,N,N,N),
FNMSUB_S -> List(N,s,s,N,Y,Y,N,N,N),
FDIV_S -> List(N,s,s,N,Y,Y,Y,N,N),
FSQRT_S -> List(N,s,s,N,Y,Y,N,Y,N)
)
// isAddSub tagIn tagOut fromInt wflags fpWen div sqrt fcvt
val double: Array[(BitPat, List[BitPat])] = Array(
FMV_D_X -> List(N,d,d,Y,N,Y,N,N,N),
FCVT_D_W -> List(N,d,d,Y,Y,Y,N,N,Y),
FCVT_D_WU-> List(N,d,d,Y,Y,Y,N,N,Y),
FCVT_D_L -> List(N,d,d,Y,Y,Y,N,N,Y),
FCVT_D_LU-> List(N,d,d,Y,Y,Y,N,N,Y),
FMV_X_D -> List(N,d,X,N,N,N,N,N,N),
FCLASS_D -> List(N,d,X,N,N,N,N,N,N),
FCVT_W_D -> List(N,d,X,N,Y,N,N,N,Y),
FCVT_WU_D-> List(N,d,X,N,Y,N,N,N,Y),
FCVT_L_D -> List(N,d,X,N,Y,N,N,N,Y),
FCVT_LU_D-> List(N,d,X,N,Y,N,N,N,Y),
FCVT_S_D -> List(N,d,s,N,Y,Y,N,N,Y),
FCVT_D_S -> List(N,s,d,N,Y,Y,N,N,Y),
FEQ_D -> List(N,d,X,N,Y,N,N,N,N),
FLT_D -> List(N,d,X,N,Y,N,N,N,N),
FLE_D -> List(N,d,X,N,Y,N,N,N,N),
FSGNJ_D -> List(N,d,d,N,N,Y,N,N,N),
FSGNJN_D -> List(N,d,d,N,N,Y,N,N,N),
FSGNJX_D -> List(N,d,d,N,N,Y,N,N,N),
FMIN_D -> List(N,d,d,N,Y,Y,N,N,N),
FMAX_D -> List(N,d,d,N,Y,Y,N,N,N),
FADD_D -> List(Y,d,d,N,Y,Y,N,N,N),
FSUB_D -> List(Y,d,d,N,Y,Y,N,N,N),
FMUL_D -> List(N,d,d,N,Y,Y,N,N,N),
FMADD_D -> List(N,d,d,N,Y,Y,N,N,N),
FMSUB_D -> List(N,d,d,N,Y,Y,N,N,N),
FNMADD_D -> List(N,d,d,N,Y,Y,N,N,N),
FNMSUB_D -> List(N,d,d,N,Y,Y,N,N,N),
FDIV_D -> List(N,d,d,N,Y,Y,Y,N,N),
FSQRT_D -> List(N,d,d,N,Y,Y,N,Y,N)
)
val table = single ++ double
val decoder = DecodeLogic(io.instr, default, table)
val ctrl = io.fpCtrl
val sigs = Seq(
ctrl.isAddSub, ctrl.typeTagIn, ctrl.typeTagOut,
ctrl.fromInt, ctrl.wflags, ctrl.fpWen,
ctrl.div, ctrl.sqrt, ctrl.fcvt
)
sigs.zip(decoder).foreach({case (s, d) => s := d})
ctrl.typ := io.instr(21,20)
ctrl.fmt := io.instr(26,25)
val fmaTable: Array[(BitPat, List[BitPat])] = Array(
FADD_S -> List(BitPat("b00"),N),
FADD_D -> List(BitPat("b00"),N),
FSUB_S -> List(BitPat("b01"),N),
FSUB_D -> List(BitPat("b01"),N),
FMUL_S -> List(BitPat("b00"),N),
FMUL_D -> List(BitPat("b00"),N),
FMADD_S -> List(BitPat("b00"),Y),
FMADD_D -> List(BitPat("b00"),Y),
FMSUB_S -> List(BitPat("b01"),Y),
FMSUB_D -> List(BitPat("b01"),Y),
FNMADD_S-> List(BitPat("b11"),Y),
FNMADD_D-> List(BitPat("b11"),Y),
FNMSUB_S-> List(BitPat("b10"),Y),
FNMSUB_D-> List(BitPat("b10"),Y)
)
val fmaDefault = List(BitPat("b??"), N)
Seq(ctrl.fmaCmd, ctrl.ren3).zip(
DecodeLogic(io.instr, fmaDefault, fmaTable)
).foreach({
case (s, d) => s := d
})
}
...@@ -120,7 +120,7 @@ abstract class Exu(val config: ExuConfig) extends XSModule { ...@@ -120,7 +120,7 @@ abstract class Exu(val config: ExuConfig) extends XSModule {
def writebackArb(in: Seq[DecoupledIO[FuOutput]], out: DecoupledIO[ExuOutput]): Arbiter[FuOutput] = { def writebackArb(in: Seq[DecoupledIO[FuOutput]], out: DecoupledIO[ExuOutput]): Arbiter[FuOutput] = {
if (needArbiter) { if (needArbiter) {
val arb = Module(new Arbiter(new FuOutput, in.size)) val arb = Module(new Arbiter(new FuOutput(in.head.bits.len), in.size))
arb.io.in <> in arb.io.in <> in
arb.io.out.ready := out.ready arb.io.out.ready := out.ready
out.bits.data := arb.io.out.bits.data out.bits.data := arb.io.out.bits.data
...@@ -203,7 +203,7 @@ object Exu { ...@@ -203,7 +203,7 @@ object Exu {
val fmacExeUnitCfg = ExuConfig("FmacExeUnit", Seq(fmacCfg), Int.MaxValue, 0) val fmacExeUnitCfg = ExuConfig("FmacExeUnit", Seq(fmacCfg), Int.MaxValue, 0)
val fmiscExeUnitCfg = ExuConfig( val fmiscExeUnitCfg = ExuConfig(
"FmiscExeUnit", "FmiscExeUnit",
Seq(fcmpCfg, fminCfg, fmvCfg, fsgnjCfg, f2iCfg, s2dCfg, d2sCfg, fdivSqrtCfg), Seq(f2iCfg, f2fCfg, fdivSqrtCfg),
Int.MaxValue, 1 Int.MaxValue, 1
) )
val ldExeUnitCfg = ExuConfig("LoadExu", Seq(lduCfg), wbIntPriority = 0, wbFpPriority = 0) val ldExeUnitCfg = ExuConfig("LoadExu", Seq(lduCfg), wbIntPriority = 0, wbFpPriority = 0)
......
...@@ -4,7 +4,6 @@ import chisel3._ ...@@ -4,7 +4,6 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan.backend.exu.Exu.fmacExeUnitCfg import xiangshan.backend.exu.Exu.fmacExeUnitCfg
import xiangshan.backend.fu.fpu._ import xiangshan.backend.fu.fpu._
import xiangshan.backend.fu.fpu.fma.FMA
class FmacExeUnit extends Exu(fmacExeUnitCfg) class FmacExeUnit extends Exu(fmacExeUnitCfg)
{ {
...@@ -15,15 +14,13 @@ class FmacExeUnit extends Exu(fmacExeUnitCfg) ...@@ -15,15 +14,13 @@ class FmacExeUnit extends Exu(fmacExeUnitCfg)
val input = io.fromFp.bits val input = io.fromFp.bits
val fmaOut = fma.io.out.bits val fmaOut = fma.io.out.bits
val isRVD = !io.fromFp.bits.uop.ctrl.isRVF val isRVD = !io.fromFp.bits.uop.ctrl.isRVF
fma.io.in.bits.src := VecInit(Seq(input.src1, input.src2, input.src3).map( fma.io.in.bits.src := VecInit(Seq(input.src1, input.src2, input.src3))
src => Mux(isRVD, src, unboxF64ToF32(src))
))
val instr_rm = io.fromFp.bits.uop.cf.instr(14, 12) val instr_rm = io.fromFp.bits.uop.cf.instr(14, 12)
fma.rm := Mux(instr_rm =/= 7.U, instr_rm, frm) fma.rm := Mux(instr_rm =/= 7.U, instr_rm, frm)
fma.io.redirectIn := io.redirect fma.io.redirectIn := io.redirect
fma.io.out.ready := io.toFp.ready fma.io.out.ready := io.toFp.ready
io.toFp.bits.data := Mux(fmaOut.uop.ctrl.isRVF, boxF32ToF64(fmaOut.data), fmaOut.data) io.toFp.bits.data := box(fma.io.out.bits.data, fma.io.out.bits.uop.ctrl.fpu.typeTagOut)
io.toFp.bits.fflags := fma.fflags io.toFp.bits.fflags := fma.fflags
} }
...@@ -4,58 +4,39 @@ import chisel3._ ...@@ -4,58 +4,39 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import utils._ import utils._
import xiangshan.backend.exu.Exu.fmiscExeUnitCfg import xiangshan.backend.exu.Exu.fmiscExeUnitCfg
import xiangshan.backend.fu.fpu.FPUOpType._
import xiangshan.backend.fu.fpu._ import xiangshan.backend.fu.fpu._
class FmiscExeUnit extends Exu(fmiscExeUnitCfg) { class FmiscExeUnit extends Exu(fmiscExeUnitCfg) {
val frm = IO(Input(UInt(3.W))) val frm = IO(Input(UInt(3.W)))
val fcmp :: fmin :: fmv :: fsgnj :: f2i :: f32toF64 :: f64toF32 :: fdivSqrt :: Nil = supportedFunctionUnits.map(fu => fu.asInstanceOf[FPUSubModule]) val f2i :: f2f :: fdivSqrt :: Nil = supportedFunctionUnits.map(fu => fu.asInstanceOf[FPUSubModule])
val toFpUnits = Seq(fmin, fsgnj, f32toF64, f64toF32, fdivSqrt) val toFpUnits = Seq(f2f, fdivSqrt)
val toIntUnits = Seq(fcmp, fmv, f2i) val toIntUnits = Seq(f2i)
assert(fpArb.io.in.length == toFpUnits.size) assert(fpArb.io.in.length == toFpUnits.size)
assert(intArb.io.in.length == toIntUnits.size) assert(intArb.io.in.length == toIntUnits.size)
val input = io.fromFp val input = io.fromFp
val fuOp = input.bits.uop.ctrl.fuOpType
assert(fuOp.getWidth == 7) // when fuOp's WIDTH change, here must change too
val fu = fuOp.head(4)
val op = fuOp.tail(4)
val isRVF = input.bits.uop.ctrl.isRVF val isRVF = input.bits.uop.ctrl.isRVF
val instr_rm = input.bits.uop.cf.instr(14, 12) val instr_rm = input.bits.uop.cf.instr(14, 12)
val (src1, src2) = (input.bits.src1, input.bits.src2) val (src1, src2) = (input.bits.src1, input.bits.src2)
supportedFunctionUnits.foreach { module => supportedFunctionUnits.foreach { module =>
module.io.in.bits.src(0) := Mux( module.io.in.bits.src(0) := src1
(isRVF && fuOp =/= d2s && fuOp =/= fmv_f2i) || fuOp === s2d, module.io.in.bits.src(1) := src2
unboxF64ToF32(src1),
src1
)
module.io.in.bits.src(1) := Mux(isRVF, unboxF64ToF32(src2), src2)
module.asInstanceOf[FPUSubModule].rm := Mux(instr_rm =/= 7.U, instr_rm, frm) module.asInstanceOf[FPUSubModule].rm := Mux(instr_rm =/= 7.U, instr_rm, frm)
} }
io.toFp.bits.fflags := MuxCase( io.toFp.bits.fflags := MuxCase(
0.U.asTypeOf(new Fflags), 0.U,
toFpUnits.map(x => x.io.out.fire() -> x.fflags) toFpUnits.map(x => x.io.out.fire() -> x.fflags)
) )
val fpOutCtrl = io.toFp.bits.uop.ctrl val fpOutCtrl = io.toFp.bits.uop.ctrl.fpu
io.toFp.bits.data := Mux(fpOutCtrl.isRVF, io.toFp.bits.data := box(fpArb.io.out.bits.data, fpOutCtrl.typeTagOut)
boxF32ToF64(fpArb.io.out.bits.data),
fpArb.io.out.bits.data
)
val intOutCtrl = io.toInt.bits.uop.ctrl
io.toInt.bits.data := Mux(
(intOutCtrl.isRVF && intOutCtrl.fuOpType === fmv_f2i) ||
intOutCtrl.fuOpType === f2w ||
intOutCtrl.fuOpType === f2wu,
SignExt(intArb.io.out.bits.data(31, 0), XLEN),
intArb.io.out.bits.data
)
io.toInt.bits.fflags := MuxCase( io.toInt.bits.fflags := MuxCase(
0.U.asTypeOf(new Fflags), 0.U,
toIntUnits.map(x => x.io.out.fire() -> x.fflags) toIntUnits.map(x => x.io.out.fire() -> x.fflags)
) )
} }
...@@ -5,14 +5,13 @@ import chisel3._ ...@@ -5,14 +5,13 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan._ import xiangshan._
import xiangshan.backend.exu.Exu.jumpExeUnitCfg import xiangshan.backend.exu.Exu.jumpExeUnitCfg
import xiangshan.backend.fu.fpu.FPUOpType.FU_I2F import xiangshan.backend.fu.fpu.IntToFP
import xiangshan.backend.fu.{CSR, Fence, FenceToSbuffer, FunctionUnit, Jump} import xiangshan.backend.fu.{CSR, Fence, FenceToSbuffer, FunctionUnit, Jump}
import xiangshan.backend.fu.fpu.{Fflags, IntToFloatSingleCycle, boxF32ToF64}
class JumpExeUnit extends Exu(jumpExeUnitCfg) class JumpExeUnit extends Exu(jumpExeUnitCfg)
{ {
val csrio = IO(new Bundle { val csrio = IO(new Bundle {
val fflags = Input(new Fflags) val fflags = Flipped(ValidIO(UInt(5.W)))
val dirty_fs = Input(Bool()) val dirty_fs = Input(Bool())
val frm = Output(UInt(3.W)) val frm = Output(UInt(3.W))
val exception = Flipped(ValidIO(new MicroOp)) val exception = Flipped(ValidIO(new MicroOp))
...@@ -39,7 +38,7 @@ class JumpExeUnit extends Exu(jumpExeUnitCfg) ...@@ -39,7 +38,7 @@ class JumpExeUnit extends Exu(jumpExeUnitCfg)
case f: Fence => f case f: Fence => f
}.get }.get
val i2f = supportedFunctionUnits.collectFirst { val i2f = supportedFunctionUnits.collectFirst {
case i: IntToFloatSingleCycle => i case i: IntToFP => i
}.get }.get
csr.csrio.perf <> DontCare csr.csrio.perf <> DontCare
...@@ -66,17 +65,6 @@ class JumpExeUnit extends Exu(jumpExeUnitCfg) ...@@ -66,17 +65,6 @@ class JumpExeUnit extends Exu(jumpExeUnitCfg)
val isDouble = !uop.ctrl.isRVF val isDouble = !uop.ctrl.isRVF
when(i2f.io.in.valid){
when(uop.ctrl.fuOpType.head(4)===s"b$FU_I2F".U){
io.toFp.bits.data := Mux(isDouble, i2f.io.out.bits.data, boxF32ToF64(i2f.io.out.bits.data))
io.toFp.bits.fflags := i2f.fflags
}.otherwise({
// a mov.(s/d).x instruction
io.toFp.bits.data := Mux(isDouble, io.fromInt.bits.src1, boxF32ToF64(io.fromInt.bits.src1))
io.toFp.bits.fflags := 0.U.asTypeOf(new Fflags)
})
}
when(csr.io.out.valid){ when(csr.io.out.valid){
io.toInt.bits.redirectValid := csr.csrio.redirectOut.valid io.toInt.bits.redirectValid := csr.csrio.redirectOut.valid
io.toInt.bits.redirect.brTag := uop.brTag io.toInt.bits.redirect.brTag := uop.brTag
......
...@@ -17,8 +17,7 @@ class Alu extends FunctionUnit with HasRedirectOut { ...@@ -17,8 +17,7 @@ class Alu extends FunctionUnit with HasRedirectOut {
io.in.bits.uop io.in.bits.uop
) )
val redirectHit = uop.roqIdx.needFlush(io.redirectIn) val valid = io.in.valid
val valid = io.in.valid && !redirectHit
val isAdderSub = (func =/= ALUOpType.add) && (func =/= ALUOpType.addw) val isAdderSub = (func =/= ALUOpType.add) && (func =/= ALUOpType.addw)
val adderRes = (src1 +& (src2 ^ Fill(XLEN, isAdderSub))) + isAdderSub val adderRes = (src1 +& (src2 ^ Fill(XLEN, isAdderSub))) + isAdderSub
......
...@@ -3,7 +3,6 @@ package xiangshan.backend.fu ...@@ -3,7 +3,6 @@ package xiangshan.backend.fu
import chisel3._ import chisel3._
import chisel3.ExcitingUtils.{ConnectionType, Debug} import chisel3.ExcitingUtils.{ConnectionType, Debug}
import chisel3.util._ import chisel3.util._
import fpu.Fflags
import utils._ import utils._
import xiangshan._ import xiangshan._
import xiangshan.backend._ import xiangshan.backend._
...@@ -165,7 +164,7 @@ trait HasExceptionNO { ...@@ -165,7 +164,7 @@ trait HasExceptionNO {
} }
class FpuCsrIO extends XSBundle { class FpuCsrIO extends XSBundle {
val fflags = Output(new Fflags) val fflags = Output(Valid(UInt(5.W)))
val isIllegal = Output(Bool()) val isIllegal = Output(Bool())
val dirty_fs = Output(Bool()) val dirty_fs = Output(Bool())
val frm = Input(UInt(3.W)) val frm = Input(UInt(3.W))
...@@ -386,11 +385,16 @@ class CSR extends FunctionUnit with HasCSRConst ...@@ -386,11 +385,16 @@ class CSR extends FunctionUnit with HasCSRConst
} }
def frm_rfn(rdata: UInt): UInt = rdata(7,5) def frm_rfn(rdata: UInt): UInt = rdata(7,5)
def fflags_wfn(wdata: UInt): UInt = { def fflags_wfn(update: Boolean)(wdata: UInt): UInt = {
val fcsrOld = WireInit(fcsr.asTypeOf(new FcsrStruct)) val fcsrOld = fcsr.asTypeOf(new FcsrStruct)
val fcsrNew = WireInit(fcsrOld)
csrw_dirty_fp_state := true.B csrw_dirty_fp_state := true.B
fcsrOld.fflags := wdata(4,0) if(update){
fcsrOld.asUInt() fcsrNew.fflags := wdata(4,0) | fcsrOld.fflags
} else {
fcsrNew.fflags := wdata(4,0)
}
fcsrNew.asUInt()
} }
def fflags_rfn(rdata:UInt): UInt = rdata(4,0) def fflags_rfn(rdata:UInt): UInt = rdata(4,0)
...@@ -401,7 +405,7 @@ class CSR extends FunctionUnit with HasCSRConst ...@@ -401,7 +405,7 @@ class CSR extends FunctionUnit with HasCSRConst
} }
val fcsrMapping = Map( val fcsrMapping = Map(
MaskedRegMap(Fflags, fcsr, wfn = fflags_wfn, rfn = fflags_rfn), MaskedRegMap(Fflags, fcsr, wfn = fflags_wfn(update = false), rfn = fflags_rfn),
MaskedRegMap(Frm, fcsr, wfn = frm_wfn, rfn = frm_rfn), MaskedRegMap(Frm, fcsr, wfn = frm_wfn, rfn = frm_rfn),
MaskedRegMap(Fcsr, fcsr, wfn = fcsr_wfn) MaskedRegMap(Fcsr, fcsr, wfn = fcsr_wfn)
) )
...@@ -538,8 +542,8 @@ class CSR extends FunctionUnit with HasCSRConst ...@@ -538,8 +542,8 @@ class CSR extends FunctionUnit with HasCSRConst
val rdataDummy = Wire(UInt(XLEN.W)) val rdataDummy = Wire(UInt(XLEN.W))
MaskedRegMap.generate(fixMapping, addr, rdataDummy, wen, wdata) MaskedRegMap.generate(fixMapping, addr, rdataDummy, wen, wdata)
when(csrio.fpu.fflags.asUInt() =/= 0.U){ when(csrio.fpu.fflags.valid){
fcsr := fflags_wfn(csrio.fpu.fflags.asUInt()) fcsr := fflags_wfn(update = true)(csrio.fpu.fflags.bits)
} }
// set fs and sd in mstatus // set fs and sd in mstatus
when(csrw_dirty_fp_state || csrio.fpu.dirty_fs){ when(csrw_dirty_fp_state || csrio.fpu.dirty_fs){
......
...@@ -4,15 +4,7 @@ import chisel3._ ...@@ -4,15 +4,7 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan._ import xiangshan._
import xiangshan.backend.MDUOpType import xiangshan.backend.MDUOpType
import xiangshan.backend.fu.fpu.FPUOpType.{FU_D2S, FU_DIVSQRT, FU_F2I, FU_FCMP, FU_FMV, FU_S2D}
import xiangshan.backend.fu.fpu.divsqrt.DivSqrt
import xiangshan.backend.fu.fpu._ import xiangshan.backend.fu.fpu._
import xiangshan.backend.fu.fpu.fma.FMA
/*
XiangShan Function Unit
A Exu can have one or more function units
*/
trait HasFuLatency { trait HasFuLatency {
val latencyVal: Option[Int] val latencyVal: Option[Int]
...@@ -43,24 +35,21 @@ case class FuConfig ...@@ -43,24 +35,21 @@ case class FuConfig
} }
class FuOutput extends XSBundle { class FuOutput(val len: Int) extends XSBundle {
val data = UInt(XLEN.W) val data = UInt(len.W)
val uop = new MicroOp val uop = new MicroOp
} }
class FunctionUnitIO(len: Int) extends XSBundle { class FunctionUnitIO(val len: Int) extends XSBundle {
val in = Flipped(DecoupledIO(new Bundle() { val in = Flipped(DecoupledIO(new Bundle() {
val src = Vec(3, UInt(len.W)) val src = Vec(3, UInt(len.W))
val uop = new MicroOp val uop = new MicroOp
})) }))
val out = DecoupledIO(new FuOutput) val out = DecoupledIO(new FuOutput(len))
val redirectIn = Flipped(ValidIO(new Redirect)) val redirectIn = Flipped(ValidIO(new Redirect))
override def cloneType: FunctionUnitIO.this.type =
new FunctionUnitIO(len).asInstanceOf[this.type]
} }
abstract class FunctionUnit(len: Int = 64) extends XSModule { abstract class FunctionUnit(len: Int = 64) extends XSModule {
...@@ -81,7 +70,9 @@ trait HasPipelineReg { ...@@ -81,7 +70,9 @@ trait HasPipelineReg {
val uopVec = io.in.bits.uop +: Array.fill(latency)(Reg(new MicroOp)) val uopVec = io.in.bits.uop +: Array.fill(latency)(Reg(new MicroOp))
val flushVec = uopVec.zip(validVec).map(x => x._2 && x._1.roqIdx.needFlush(io.redirectIn)) // if flush(0), valid 0 will not given, so set flushVec(0) to false.B
val flushVec = WireInit(false.B) +:
validVec.zip(uopVec).tail.map(x => x._1 && x._2.roqIdx.needFlush(io.redirectIn))
for (i <- 0 until latency) { for (i <- 0 until latency) {
rdyVec(i) := !validVec(i + 1) || rdyVec(i + 1) rdyVec(i) := !validVec(i + 1) || rdyVec(i + 1)
...@@ -97,12 +88,14 @@ trait HasPipelineReg { ...@@ -97,12 +88,14 @@ trait HasPipelineReg {
} }
io.in.ready := rdyVec(0) io.in.ready := rdyVec(0)
io.out.valid := validVec.last && !flushVec.last io.out.valid := validVec.last
io.out.bits.uop := uopVec.last io.out.bits.uop := uopVec.last
def regEnable(i: Int): Bool = validVec(i - 1) && rdyVec(i - 1) && !flushVec(i - 1)
def PipelineReg[TT <: Data](i: Int)(next: TT) = RegEnable( def PipelineReg[TT <: Data](i: Int)(next: TT) = RegEnable(
next, next,
enable = validVec(i - 1) && rdyVec(i - 1) && !flushVec(i - 1) enable = regEnable(i)
) )
def S1Reg[TT <: Data](next: TT): TT = PipelineReg[TT](1)(next) def S1Reg[TT <: Data](next: TT): TT = PipelineReg[TT](1)(next)
...@@ -130,24 +123,32 @@ object FunctionUnit extends HasXSParameter { ...@@ -130,24 +123,32 @@ object FunctionUnit extends HasXSParameter {
def csr = new CSR def csr = new CSR
def i2f = new IntToFloatSingleCycle def i2f = new IntToFP
def fmac = new FMA def fmac = new FMA
def fcmp = new FCMP def f2i = new FPToInt
def fmv = new FMV(XLEN) def f2f = new FPToFP
def f2i = new FloatToInt def fdivSqrt = new FDivSqrt
def f32toF64 = new F32toF64 def f2iSel(x: FunctionUnit): Bool = {
x.io.in.bits.uop.ctrl.rfWen
}
def f64toF32 = new F64toF32 def i2fSel(x: FunctionUnit): Bool = {
x.io.in.bits.uop.ctrl.fpu.fromInt
}
def fdivSqrt = new DivSqrt def f2fSel(x: FunctionUnit): Bool = {
val ctrl = x.io.in.bits.uop.ctrl.fpu
ctrl.fpWen && !ctrl.div && !ctrl.sqrt
}
def fmiscSel(fu: String)(x: FunctionUnit): Bool = { def fdivSqrtSel(x: FunctionUnit): Bool = {
x.io.in.bits.uop.ctrl.fuOpType.head(4) === s"b$fu".U val ctrl = x.io.in.bits.uop.ctrl.fpu
ctrl.div || ctrl.sqrt
} }
val aluCfg = FuConfig( val aluCfg = FuConfig(
...@@ -192,7 +193,7 @@ object FunctionUnit extends HasXSParameter { ...@@ -192,7 +193,7 @@ object FunctionUnit extends HasXSParameter {
val i2fCfg = FuConfig( val i2fCfg = FuConfig(
fuGen = i2f _, fuGen = i2f _,
fuSel = (x: FunctionUnit) => x.io.in.bits.uop.ctrl.fuType === FuType.i2f, fuSel = i2fSel,
FuType.i2f, FuType.i2f,
numIntSrc = 1, numIntSrc = 1,
numFpSrc = 0, numFpSrc = 0,
...@@ -229,54 +230,24 @@ object FunctionUnit extends HasXSParameter { ...@@ -229,54 +230,24 @@ object FunctionUnit extends HasXSParameter {
val fmacCfg = FuConfig( val fmacCfg = FuConfig(
fuGen = fmac _, fuGen = fmac _,
fuSel = _ => true.B, fuSel = _ => true.B,
FuType.fmac, 0, 3, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(5) FuType.fmac, 0, 3, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(4)
)
val fcmpCfg = FuConfig(
fuGen = fcmp _,
fuSel = (x: FunctionUnit) => fmiscSel(FU_FCMP)(x) && x.io.in.bits.uop.ctrl.rfWen,
FuType.fmisc, 0, 2, writeIntRf = true, writeFpRf = false, hasRedirect = false, CertainLatency(2)
)
val fminCfg = FuConfig(
fuGen = fcmp _,
fuSel = (x: FunctionUnit) => fmiscSel(FU_FCMP)(x) && x.io.in.bits.uop.ctrl.fpWen,
FuType.fmisc, 0, 2, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(2)
)
val fsgnjCfg = FuConfig(
fuGen = fmv _,
fuSel = (x: FunctionUnit) => fmiscSel(FU_FMV)(x) && x.io.in.bits.uop.ctrl.fpWen,
FuType.fmisc, 0, 2, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(1)
)
val fmvCfg = FuConfig(
fuGen = fmv _,
fuSel = (x: FunctionUnit) => fmiscSel(FU_FMV)(x) && x.io.in.bits.uop.ctrl.rfWen,
FuType.fmisc, 0, 2, writeIntRf = true, writeFpRf = false, hasRedirect = false, CertainLatency(1)
) )
val f2iCfg = FuConfig( val f2iCfg = FuConfig(
fuGen = f2i _, fuGen = f2i _,
fuSel = fmiscSel(FU_F2I), fuSel = f2iSel,
FuType.fmisc, 0, 1, writeIntRf = true, writeFpRf = false, hasRedirect = false, CertainLatency(2) FuType.fmisc, 0, 1, writeIntRf = true, writeFpRf = false, hasRedirect = false, CertainLatency(2)
) )
val s2dCfg = FuConfig( val f2fCfg = FuConfig(
fuGen = f32toF64 _, fuGen = f2f _,
fuSel = fmiscSel(FU_S2D), fuSel = f2fSel,
FuType.fmisc, 0, 1, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(2)
)
val d2sCfg = FuConfig(
fuGen = f64toF32 _,
fuSel = fmiscSel(FU_D2S),
FuType.fmisc, 0, 1, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(2) FuType.fmisc, 0, 1, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(2)
) )
val fdivSqrtCfg = FuConfig( val fdivSqrtCfg = FuConfig(
fuGen = fdivSqrt _, fuGen = fdivSqrt _,
fuSel = fmiscSel(FU_DIVSQRT), fuSel = fdivSqrtSel,
FuType.fDivSqrt, 0, 2, writeIntRf = false, writeFpRf = true, hasRedirect = false, UncertainLatency() FuType.fDivSqrt, 0, 2, writeIntRf = false, writeFpRf = true, hasRedirect = false, UncertainLatency()
) )
......
...@@ -25,7 +25,7 @@ class Jump extends FunctionUnit with HasRedirectOut { ...@@ -25,7 +25,7 @@ class Jump extends FunctionUnit with HasRedirectOut {
) )
val redirectHit = uop.roqIdx.needFlush(io.redirectIn) val redirectHit = uop.roqIdx.needFlush(io.redirectIn)
val valid = io.in.valid && !redirectHit val valid = io.in.valid
val isRVC = uop.cf.brUpdate.pd.isRVC val isRVC = uop.cf.brUpdate.pd.isRVC
val snpc = Mux(isRVC, pc + 2.U, pc + 4.U) val snpc = Mux(isRVC, pc + 2.U, pc + 4.U)
......
...@@ -4,7 +4,7 @@ import chisel3._ ...@@ -4,7 +4,7 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan._ import xiangshan._
import utils._ import utils._
import xiangshan.backend.fu.fpu.util.{C22, C32, C53} import xiangshan.backend.fu.util.{C22, C32, C53}
class MulDivCtrl extends Bundle{ class MulDivCtrl extends Bundle{
val sign = Bool() val sign = Bool()
...@@ -165,7 +165,7 @@ class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len) ...@@ -165,7 +165,7 @@ class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len)
for(i <- 1 to latency){ for(i <- 1 to latency){
ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1)) ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
} }
val xlen = io.out.bits.data.getWidth val xlen = len - 1
val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0)) val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))
io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res) io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)
......
...@@ -3,7 +3,7 @@ package xiangshan.backend.fu ...@@ -3,7 +3,7 @@ package xiangshan.backend.fu
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
import utils.SignExt import utils.SignExt
import xiangshan.backend.fu.fpu.util.CSA3_2 import xiangshan.backend.fu.util.CSA3_2
/** A Radix-4 SRT Integer Divider /** A Radix-4 SRT Integer Divider
* *
......
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
class Classify(expWidth: Int, mantWidth: Int) extends Module{
val io = IO(new Bundle() {
val in = Input(UInt((1 + expWidth + mantWidth).W))
val isNegInf = Output(Bool())
val isNegNormal = Output(Bool())
val isNegSubnormal = Output(Bool())
val isNegZero = Output(Bool())
val isPosZero = Output(Bool())
val isPosSubnormal = Output(Bool())
val isPosNormal = Output(Bool())
val isPosInf = Output(Bool())
val isSNaN = Output(Bool())
val isQNaN = Output(Bool())
val isNaN = Output(Bool())
val isInf = Output(Bool())
val isInfOrNaN = Output(Bool())
val isSubnormal = Output(Bool())
val isZero = Output(Bool())
val isSubnormalOrZero = Output(Bool())
})
val flpt = io.in.asTypeOf(new FloatPoint(expWidth, mantWidth))
val (sign, exp, mant) = (flpt.sign, flpt.exp, flpt.mant)
val isSubnormOrZero = exp === 0.U
val mantIsZero = mant === 0.U
val isInfOrNaN = (~exp).asUInt() === 0.U
io.isNegInf := sign && io.isInf
io.isNegNormal := sign && !isSubnormOrZero && !isInfOrNaN
io.isNegSubnormal := sign && io.isSubnormal
io.isNegZero := sign && io.isZero
io.isPosZero := !sign && io.isZero
io.isPosSubnormal := !sign && io.isSubnormal
io.isPosNormal := !sign && !isSubnormOrZero && !isInfOrNaN
io.isPosInf := !sign && io.isInf
io.isSNaN := io.isNaN && !mant.head(1)
io.isQNaN := io.isNaN && mant.head(1).asBool()
io.isNaN := isInfOrNaN && !mantIsZero
io.isInf := isInfOrNaN && mantIsZero
io.isInfOrNaN := isInfOrNaN
io.isSubnormal := isSubnormOrZero && !mantIsZero
io.isZero := isSubnormOrZero && mantIsZero
io.isSubnormalOrZero := isSubnormOrZero
}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.FunctionUnit._
class F32toF64 extends FPUPipelineModule {
override def latency: Int = FunctionUnit.s2dCfg.latency.latencyVal.get
val a = io.in.bits.src(0)
val f32 = Float32(a)
val classify = Module(new Classify(Float32.expWidth, Float32.mantWidth))
classify.io.in := a
val isNaN = classify.io.isNaN
val isSNaN = classify.io.isSNaN
val isSubnormal = classify.io.isSubnormal
val isSubnormalOrZero = classify.io.isSubnormalOrZero
val invalid = isSNaN
val isInfOrNaN = classify.io.isInfOrNaN
val isInf = classify.io.isInf
val f32Mant = f32.mant // not include hidden bit here
val f32MantLez = PriorityEncoder(f32Mant.asBools().reverse)
val exp = Mux(isSubnormalOrZero,
0.U(Float64.expWidth.W),
Mux(isInfOrNaN,
Cat("b111".U(3.W), f32.exp),
Cat("b0111".U(4.W) + f32.exp.head(1), f32.exp.tail(1))
)
)
val s1_isNaN = S1Reg(isNaN)
val s1_isSNaN = S1Reg(isSNaN)
val s1_isSubnormal = S1Reg(isSubnormal)
val s1_mantLez = S1Reg(f32MantLez)
val s1_mant = S1Reg(f32Mant)
val s1_exp = S1Reg(exp)
val s1_sign = S1Reg(f32.sign)
// MantNorm: 1.xx...x * 2^(-127 - lez)
val f32MantFromDenorm = Wire(UInt(Float32.mantWidth.W))
f32MantFromDenorm := Cat(s1_mant.tail(1) << s1_mantLez, 0.U(1.W))
val f64ExpFromDenorm = Wire(UInt(Float64.expWidth.W)) // -127 - lez + 1023 = 0x380 - lez
f64ExpFromDenorm := "h380".U - s1_mantLez
val commonResult = Cat(
s1_sign,
Mux(s1_isSubnormal, f64ExpFromDenorm, s1_exp),
Mux(s1_isSubnormal, f32MantFromDenorm, s1_mant),
0.U((Float64.mantWidth-Float32.mantWidth).W)
)
val result = Mux(s1_isNaN, Float64.defaultNaN, commonResult)
io.out.bits.data := S2Reg(result)
fflags.invalid := S2Reg(s1_isSNaN)
fflags.overflow := false.B
fflags.underflow := false.B
fflags.infinite := false.B
fflags.inexact := false.B
}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.fpu.util.ShiftRightJam
class F64toF32 extends FPUPipelineModule {
override def latency = FunctionUnit.d2sCfg.latency.latencyVal.get
def SEXP_WIDTH = Float64.expWidth + 2
val a = io.in.bits.src(0)
val classify = Module(new Classify(Float64.expWidth, Float64.mantWidth))
classify.io.in := a
val isNaN = classify.io.isNaN
val isSNaN = classify.io.isSNaN
val isInf = classify.io.isInf
val f64 = Float64(a)
val f64sign = f64.sign
val f64exp = Wire(SInt(SEXP_WIDTH.W))
f64exp := f64.exp.toSInt
val f64mant = f64.mantExt
val f32exp = f64exp - (Float64.expBiasInt - Float32.expBiasInt).S
val shiftAmt = 1.S - f32exp
val needDenorm = shiftAmt > 0.S
val mantShifted = ShiftRightJam(f64mant,
Mux(needDenorm, shiftAmt.asUInt(), 0.U),
Float32.mantWidth+4
)
val s1_mantShifted = S1Reg(mantShifted)
val s1_shiftAmt = S1Reg(shiftAmt)
val s1_sign = S1Reg(f64sign)
val s1_exp = S1Reg(f32exp)
val s1_rm = S1Reg(rm)
val s1_isNaN = S1Reg(isNaN)
val s1_isInf = S1Reg(isInf)
val s1_isSNaN = S1Reg(isSNaN)
val rounding = Module(new RoundF64AndF32WithExceptions(expInHasBias = true))
rounding.io.isDouble := false.B
rounding.io.denormShiftAmt := s1_shiftAmt
rounding.io.sign := s1_sign
rounding.io.expNorm := s1_exp
rounding.io.mantWithGRS := s1_mantShifted
rounding.io.rm := s1_rm
rounding.io.specialCaseHappen := s1_isNaN || s1_isInf
val inexact = rounding.io.inexact
val underflow = rounding.io.underflow
val overflow = rounding.io.overflow
val ovSetInf = rounding.io.ovSetInf
val expRounded = rounding.io.expRounded
val mantRounded = rounding.io.mantRounded
val result = Mux(s1_isNaN,
Float32.defaultNaN,
Mux(overflow || s1_isInf,
Cat(s1_sign, Mux(ovSetInf || s1_isInf, Float32.posInf, Float32.maxNorm).tail(1)),
Cat(s1_sign, expRounded(Float32.expWidth-1, 0), mantRounded(Float32.mantWidth-1, 0))
)
)
io.out.bits.data := S2Reg(result)
fflags.invalid := S2Reg(s1_isSNaN)
fflags.overflow := S2Reg(overflow)
fflags.underflow := S2Reg(underflow)
fflags.infinite := false.B
fflags.inexact := S2Reg(inexact)
}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.FunctionUnit._
class FCMP extends FPUPipelineModule {
override def latency = FunctionUnit.fcmpCfg.latency.latencyVal.get
val src = io.in.bits.src.map(x => Mux(isDouble, x, extF32ToF64(x)))
val sign = src.map(_(63))
val aSign = sign(0)
val subRes = src(0).toSInt - src(1).toSInt
val classify = Array.fill(2)(Module(new Classify(Float64.expWidth, Float64.mantWidth)).io)
classify.zip(src).foreach({case (c, s) => c.in := s})
val srcIsNaN = classify.map(_.isNaN)
val srcIsSNaN = classify.map(_.isSNaN)
val isDoubleReg = S1Reg(isDouble)
val opReg = S1Reg(op)
val srcReg = io.in.bits.src.map(S1Reg)
val (aSignReg, bSignReg) = (S1Reg(sign(0)), S1Reg(sign(1)))
val hasNaNReg = S1Reg(srcIsNaN(0) || srcIsNaN(1))
val bothNaNReg = S1Reg(srcIsNaN(0) && srcIsNaN(1))
val hasSNaNReg = S1Reg(srcIsSNaN(0) || srcIsSNaN(1))
val aIsNaNReg = S1Reg(srcIsNaN(0))
val bothZeroReg = S1Reg(src(0).tail(1)===0.U && src(1).tail(1)===0.U)
val uintEqReg = S1Reg(subRes===0.S)
val uintLessReg = S1Reg(aSign ^ (subRes < 0.S))
val invalid = Mux(opReg(2) || !opReg(1), hasSNaNReg, hasNaNReg)
val le,lt,eq = Wire(Bool())
eq := uintEqReg || bothZeroReg
le := Mux(aSignReg =/= bSignReg, aSignReg || bothZeroReg, uintEqReg || uintLessReg)
lt := Mux(aSignReg =/= bSignReg, aSignReg && !bothZeroReg, !uintEqReg && uintLessReg)
val fcmpResult = Mux(hasNaNReg,
false.B,
Mux(opReg(2), eq, Mux(opReg(0), lt, le))
)
val sel_a = lt || (eq && aSignReg)
val defaultNaN = Mux(isDoubleReg, Float64.defaultNaN, Float32.defaultNaN)
val min = Mux(bothNaNReg, defaultNaN, Mux(sel_a && !aIsNaNReg, srcReg(0), srcReg(1)))
val max = Mux(bothNaNReg, defaultNaN, Mux(!sel_a && !aIsNaNReg, srcReg(0), srcReg(1)))
fflags.inexact := false.B
fflags.underflow := false.B
fflags.overflow := false.B
fflags.infinite := false.B
fflags.invalid := S2Reg(invalid)
io.out.bits.data := S2Reg(Mux(opReg===0.U, min, Mux(opReg===1.U, max, fcmpResult)))
}
\ No newline at end of file
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import freechips.rocketchip.tile.FType
import hardfloat.{DivSqrtRecFNToRaw_small, RoundAnyRawFNToRecFN}
class FDivSqrt extends FPUSubModule {
val s_idle :: s_div :: s_finish :: Nil = Enum(3)
val state = RegInit(s_idle)
val divSqrt = Module(new DivSqrtRecFNToRaw_small(FType.D.exp, FType.D.sig, 0))
val divSqrtRawValid = divSqrt.io.rawOutValid_sqrt || divSqrt.io.rawOutValid_div
val fpCtrl = io.in.bits.uop.ctrl.fpu
val tag = fpCtrl.typeTagIn
val uopReg = RegEnable(io.in.bits.uop, io.in.fire())
val single = RegEnable(tag === S, io.in.fire())
val kill = uopReg.roqIdx.needFlush(io.redirectIn)
val killReg = RegInit(false.B)
switch(state){
is(s_idle){
when(io.in.fire() && !io.in.bits.uop.roqIdx.needFlush(io.redirectIn)){ state := s_div }
}
is(s_div){
when(divSqrtRawValid){
when(kill || killReg){
state := s_idle
}.otherwise({
state := s_finish
})
}.elsewhen(kill){
killReg := true.B
}
}
is(s_finish){
state := s_idle
killReg := false.B
}
}
val src1 = unbox(io.in.bits.src(0), tag, None)
val src2 = unbox(io.in.bits.src(1), tag, None)
divSqrt.io.inValid := io.in.fire()
divSqrt.io.sqrtOp := fpCtrl.sqrt
divSqrt.io.a := src1
divSqrt.io.b := src2
divSqrt.io.roundingMode := rm
val round32 = Module(new RoundAnyRawFNToRecFN(
FType.D.exp, FType.D.sig+2, FType.S.exp, FType.S.sig, 0
))
val round64 = Module(new RoundAnyRawFNToRecFN(
FType.D.exp, FType.D.sig+2, FType.D.exp, FType.D.sig, 0
))
for(rounder <- Seq(round32, round64)){
rounder.io.invalidExc := divSqrt.io.invalidExc
rounder.io.infiniteExc := divSqrt.io.infiniteExc
rounder.io.in := divSqrt.io.rawOut
rounder.io.roundingMode := rm
rounder.io.detectTininess := hardfloat.consts.tininess_afterRounding
}
val data = Mux(single, round32.io.out, round64.io.out)
val flags = Mux(single, round32.io.exceptionFlags, round64.io.exceptionFlags)
io.in.ready := state===s_idle
io.out.valid := state===s_finish && !(killReg || kill)
io.out.bits.uop := uopReg
io.out.bits.data := RegNext(data, divSqrtRawValid)
fflags := RegNext(flags, divSqrtRawValid)
}
package xiangshan.backend.fu.fpu
import chisel3._
import freechips.rocketchip.tile.FType
import hardfloat.{MulAddRecFN_pipeline_stage1, MulAddRecFN_pipeline_stage2, MulAddRecFN_pipeline_stage3, MulAddRecFN_pipeline_stage4, RoundAnyRawFNToRecFN}
import xiangshan.backend.fu.FunctionUnit
class FMA extends FPUPipelineModule {
override def latency: Int = FunctionUnit.fmacCfg.latency.latencyVal.get
val fpCtrl = io.in.bits.uop.ctrl.fpu
val typeTagIn = fpCtrl.typeTagIn
val src1 = unbox(io.in.bits.src(0), typeTagIn, None)
val src2 = unbox(io.in.bits.src(1), typeTagIn, None)
val src3 = unbox(io.in.bits.src(2), typeTagIn, None)
val (in1, in2, in3) = (
WireInit(src1), WireInit(src2), WireInit(Mux(fpCtrl.isAddSub, src2, src3))
)
val one = 1.U << (FType.D.sig + FType.D.exp - 1)
val zero = (src1 ^ src2) & (1.U << (FType.D.sig + FType.D.exp))
when(fpCtrl.isAddSub){ in2 := one }
when(!(fpCtrl.isAddSub || fpCtrl.ren3)){ in3 := zero }
val stage1 = Module(new MulAddRecFN_pipeline_stage1(maxExpWidth, maxSigWidth))
val stage2 = Module(new MulAddRecFN_pipeline_stage2(maxExpWidth, maxSigWidth))
val stage3 = Module(new MulAddRecFN_pipeline_stage3(maxExpWidth, maxSigWidth))
val stage4 = Module(new MulAddRecFN_pipeline_stage4(maxExpWidth, maxSigWidth))
val mul = Module(new hardfloat.ArrayMultiplier(
maxSigWidth+1,
regDepth = 0,
realArraryMult = true,
hasReg = true
))
mul.io.a := stage1.io.mulAddA
mul.io.b := stage1.io.mulAddB
mul.io.reg_en := regEnable(1)
stage2.io.mulSum := mul.io.sum
stage2.io.mulCarry := mul.io.carry
stage1.io.in.valid := DontCare
stage1.io.toStage2.ready := DontCare
stage2.io.fromStage1.valid := DontCare
stage2.io.toStage3.ready := DontCare
stage3.io.fromStage2.valid := DontCare
stage3.io.toStage4.ready := DontCare
stage4.io.fromStage3.valid := DontCare
stage4.io.toStage5.ready := DontCare
stage1.io.in.bits.a := in1
stage1.io.in.bits.b := in2
stage1.io.in.bits.c := in3
stage1.io.in.bits.op := fpCtrl.fmaCmd
stage1.io.in.bits.roundingMode := rm
stage1.io.in.bits.detectTininess := hardfloat.consts.tininess_afterRounding
stage2.io.fromStage1.bits <> S1Reg(stage1.io.toStage2.bits)
stage3.io.fromStage2.bits <> S2Reg(stage2.io.toStage3.bits)
stage4.io.fromStage3.bits <> S3Reg(stage3.io.toStage4.bits)
val stage4toStage5 = S4Reg(stage4.io.toStage5.bits)
val rounders = Seq(FType.S, FType.D).map(t => {
val rounder = Module(new RoundAnyRawFNToRecFN(FType.D.exp, FType.D.sig+2, t.exp, t.sig, 0))
rounder.io.invalidExc := stage4toStage5.invalidExc
rounder.io.infiniteExc := false.B
rounder.io.in := stage4toStage5.rawOut
rounder.io.roundingMode := stage4toStage5.roundingMode
rounder.io.detectTininess := stage4toStage5.detectTininess
rounder
})
val singleOut = io.out.bits.uop.ctrl.fpu.typeTagOut === S
io.out.bits.data := Mux(singleOut,
sanitizeNaN(rounders(0).io.out, FType.S),
sanitizeNaN(rounders(1).io.out, FType.D)
)
fflags := Mux(singleOut,
rounders(0).io.exceptionFlags,
rounders(1).io.exceptionFlags
)
}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
class FMV(XLEN: Int) extends FPUPipelineModule {
override def latency = FunctionUnit.fmvCfg.latency.latencyVal.get
val src = io.in.bits.src.map(x =>
Mux(isDouble || op(2,1)==="b00".U, x, extF32ToF64(x))
)
val aSign = Mux(op(2,1)==="b00".U && !isDouble, src(0)(31), src(0)(63))
val bSign = Mux(op(2,1)==="b00".U && !isDouble, src(1)(31), src(1)(63))
val sgnjSign = Mux(op(1),
bSign,
Mux(op(0), !bSign, aSign ^ bSign)
)
val resSign = Mux(op(2), sgnjSign, aSign)
val cls = Module(new Classify(Float64.expWidth, Float64.mantWidth)).io
cls.in := src(0)
val classifyResult = Cat(
cls.isQNaN, // 9
cls.isSNaN, // 8
cls.isPosInf, // 7
cls.isPosNormal, // 6
cls.isPosSubnormal, // 5
cls.isPosZero, // 4
cls.isNegZero, // 3
cls.isNegSubnormal, // 2
cls.isNegNormal, // 1
cls.isNegInf // 0
)
val result = Mux(op === "b010".U,
classifyResult,
Mux(isDouble,
Cat(resSign, io.in.bits.src(0)(62, 0)),
Cat(resSign, io.in.bits.src(0)(30 ,0))
)
)
val resultReg = S1Reg(result)
io.out.bits.data := resultReg
fflags := 0.U.asTypeOf(new Fflags)
}
// See LICENSE.Berkeley for license details.
// See LICENSE.SiFive for license details.
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import hardfloat.CompareRecFN
import xiangshan.backend.fu.FunctionUnit
class FPToFP extends FPUPipelineModule{
override def latency: Int = FunctionUnit.f2iCfg.latency.latencyVal.get
val ctrl = io.in.bits.uop.ctrl.fpu
val inTag = ctrl.typeTagIn
val outTag = ctrl.typeTagOut
val src1 = unbox(io.in.bits.src(0), inTag, None)
val src2 = unbox(io.in.bits.src(1), inTag, None)
val wflags = ctrl.wflags
val signNum = Mux(rm(1), src1 ^ src2, Mux(rm(0), ~src2, src2))
val fsgnj = Cat(signNum(fLen), src1(fLen-1, 0))
val fsgnjMux = Wire(new Bundle() {
val data = UInt((XLEN+1).W)
val exc = UInt(5.W)
})
fsgnjMux.data := fsgnj
fsgnjMux.exc := 0.U
val dcmp = Module(new CompareRecFN(maxExpWidth, maxSigWidth))
dcmp.io.a := src1
dcmp.io.b := src2
dcmp.io.signaling := !rm(1)
val lt = dcmp.io.lt || (dcmp.io.a.asSInt() < 0.S && dcmp.io.b.asSInt() >= 0.S)
when(wflags){
val isnan1 = maxType.isNaN(src1)
val isnan2 = maxType.isNaN(src2)
val isInvalid = maxType.isSNaN(src1) || maxType.isSNaN(src2)
val isNaNOut = isnan1 && isnan2
val isLHS = isnan2 || rm(0) =/= lt && !isnan1
fsgnjMux.exc := isInvalid << 4
fsgnjMux.data := Mux(isNaNOut, maxType.qNaN, Mux(isLHS, src1, src2))
}
val mux = WireInit(fsgnjMux)
for(t <- floatTypes.init){
when(outTag === typeTag(t).U){
mux.data := Cat(fsgnjMux.data >> t.recodedWidth, maxType.unsafeConvert(fsgnjMux.data, t))
}
}
when(ctrl.fcvt){
if(floatTypes.size > 1){
// widening conversions simply canonicalize NaN operands
val widened = Mux(maxType.isNaN(src1), maxType.qNaN, src1)
fsgnjMux.data := widened
fsgnjMux.exc := maxType.isSNaN(src1) << 4
// narrowing conversions require rounding (for RVQ, this could be
// optimized to use a single variable-position rounding unit, rather
// than two fixed-position ones)
for(outType <- floatTypes.init){
when(outTag === typeTag(outType).U && (typeTag(outType) == 0).B || (outTag < inTag)){
val narrower = Module(new hardfloat.RecFNToRecFN(maxType.exp, maxType.sig, outType.exp, outType.sig))
narrower.io.in := src1
narrower.io.roundingMode := rm
narrower.io.detectTininess := hardfloat.consts.tininess_afterRounding
val narrowed = sanitizeNaN(narrower.io.out, outType)
mux.data := Cat(fsgnjMux.data >> narrowed.getWidth, narrowed)
mux.exc := narrower.io.exceptionFlags
}
}
}
}
var resVec = Seq(mux)
for(i <- 1 to latency){
resVec = resVec :+ PipelineReg(i)(resVec(i-1))
}
io.out.bits.data := resVec.last.data
fflags := resVec.last.exc
}
// See LICENSE.Berkeley for license details.
// See LICENSE.SiFive for license details.
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import freechips.rocketchip.tile.FType
import hardfloat.RecFNToIN
import utils.SignExt
import xiangshan.backend.fu.FunctionUnit
class FPToInt extends FPUPipelineModule {
override def latency = FunctionUnit.f2iCfg.latency.latencyVal.get
val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))
val ctrl = io.in.bits.uop.ctrl.fpu
val src1_s = unbox(src1, S, Some(FType.S))
val src1_d = unbox(src1, ctrl.typeTagIn, None)
val src2_d = unbox(src2, ctrl.typeTagIn, None)
val src1_ieee = ieee(src1)
val move_out = Mux(ctrl.typeTagIn === S, src1_ieee(31, 0), src1_ieee)
val classify_out = Mux(ctrl.typeTagIn === S,
FType.S.classify(src1_s),
FType.D.classify(src1)
)
val dcmp = Module(new hardfloat.CompareRecFN(maxExpWidth, maxSigWidth))
dcmp.io.a := src1_d
dcmp.io.b := src2_d
dcmp.io.signaling := !rm(1)
val dcmp_out = ((~rm).asUInt() & Cat(dcmp.io.lt, dcmp.io.eq)).orR()
val dcmp_exc = dcmp.io.exceptionFlags
val conv = Module(new RecFNToIN(maxExpWidth, maxSigWidth, XLEN))
conv.io.in := src1_d
conv.io.roundingMode := rm
conv.io.signedOut := ~ctrl.typ(0)
val conv_out = WireInit(conv.io.out)
val conv_exc = WireInit(Cat(
conv.io.intExceptionFlags(2, 1).orR(),
0.U(3.W),
conv.io.intExceptionFlags(0)
))
val narrow = Module(new RecFNToIN(maxExpWidth, maxSigWidth, 32))
narrow.io.in := src1_d
narrow.io.roundingMode := rm
narrow.io.signedOut := ~ctrl.typ(0)
when(!ctrl.typ(1)) { // fcvt.w/wu.fp
val excSign = src1_d(maxExpWidth + maxSigWidth) && !maxType.isNaN(src1_d)
val excOut = Cat(conv.io.signedOut === excSign, Fill(32 - 1, !excSign))
val invalid = conv.io.intExceptionFlags(2) || narrow.io.intExceptionFlags(1)
when(invalid) {
conv_out := Cat(conv.io.out >> 32, excOut)
}
conv_exc := Cat(invalid, 0.U(3.W), !invalid && conv.io.intExceptionFlags(0))
}
val intData = Wire(UInt(XLEN.W))
intData := Mux(ctrl.wflags,
Mux(ctrl.fcvt, conv_out, dcmp_out),
Mux(rm(0), classify_out, move_out)
)
val doubleOut = Mux(ctrl.fcvt, ctrl.typ(1), ctrl.fmt(0))
val intValue = Mux(doubleOut,
SignExt(intData, XLEN),
SignExt(intData(31, 0), XLEN)
)
val exc = Mux(ctrl.fcvt, conv_exc, dcmp_exc)
var dataVec = Seq(intValue)
var excVec = Seq(exc)
for (i <- 1 to latency) {
dataVec = dataVec :+ PipelineReg(i)(dataVec(i - 1))
excVec = excVec :+ PipelineReg(i)(excVec(i - 1))
}
io.out.bits.data := dataVec.last
fflags := excVec.last
}
...@@ -4,43 +4,17 @@ import chisel3._ ...@@ -4,43 +4,17 @@ import chisel3._
import chisel3.util._ import chisel3.util._
import xiangshan.backend.fu.{FuConfig, FunctionUnit, HasPipelineReg} import xiangshan.backend.fu.{FuConfig, FunctionUnit, HasPipelineReg}
class FPUSubModuleInput extends Bundle{
val op = UInt(3.W)
val isDouble = Bool()
val a, b, c = UInt(64.W)
val rm = UInt(3.W)
}
class FPUSubModuleOutput extends Bundle{
val fflags = new Fflags
val result = UInt(64.W)
}
class FPUSubModuleIO extends Bundle{
val in = Flipped(DecoupledIO(new FPUSubModuleInput))
val out = DecoupledIO(new FPUSubModuleOutput)
}
trait HasUIntToSIntHelper { trait HasUIntToSIntHelper {
implicit class UIntToSIntHelper(x: UInt){ implicit class UIntToSIntHelper(x: UInt){
def toSInt: SInt = Cat(0.U(1.W), x).asSInt() def toSInt: SInt = Cat(0.U(1.W), x).asSInt()
} }
} }
trait HasFPUSigs { this: FPUSubModule => abstract class FPUSubModule extends FunctionUnit(len = 65)
val op = io.in.bits.uop.ctrl.fuOpType(2, 0)
// 'op' must change with fuOpType
require(io.in.bits.uop.ctrl.fuOpType.getWidth == 7)
val isDouble = !io.in.bits.uop.ctrl.isRVF
}
abstract class FPUSubModule extends FunctionUnit
with HasUIntToSIntHelper with HasUIntToSIntHelper
with HasFPUSigs
{ {
val rm = IO(Input(UInt(3.W))) val rm = IO(Input(UInt(3.W)))
val fflags = IO(Output(new Fflags)) val fflags = IO(Output(UInt(5.W)))
} }
abstract class FPUPipelineModule abstract class FPUPipelineModule
......
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.fpu.util.{ORTree, ShiftRightJam}
//def f2w:UInt = FpuOp("011", "000")
//def f2wu:UInt = FpuOp("011", "001")
//def f2l:UInt = FpuOp("011", "010")
//def f2lu:UInt = FpuOp("011", "011")
class FloatToInt extends FPUPipelineModule {
override def latency = FunctionUnit.f2iCfg.latency.latencyVal.get
def SEXP_WIDTH = Float64.expWidth + 2
/** Stage 1: Shift Operand
*/
val a = Mux(isDouble, io.in.bits.src(0), extF32ToF64(io.in.bits.src(0)))
val f64 = Float64(a)
val cls = Module(new Classify(Float64.expWidth, Float64.mantWidth))
cls.io.in := a
val isNaN = cls.io.isNaN
val sign = f64.sign
val exp = Wire(SInt(SEXP_WIDTH.W))
exp := f64.exp.toSInt
val mant = f64.mantExt
val leftShiftAmt = exp - (Float64.expBiasInt + Float64.mantWidth).S
val rightShiftAmt = -leftShiftAmt.asUInt()
val needRightShift = leftShiftAmt.head(1).asBool() // exp - 52 < 0
val expOv = leftShiftAmt > Mux(op(1), 11.S, (-21).S) // exp > 63 / exp > 31
val uintUnrounded = Wire(UInt((64+3).W)) // 64 + g r s
uintUnrounded := Mux(needRightShift,
ShiftRightJam(Cat(mant, 0.U(3.W)), rightShiftAmt, Float64.mantWidth+4),
Cat((mant << leftShiftAmt(3, 0))(63, 0), 0.U(3.W))
)
val s1_uint = S1Reg(uintUnrounded)
val s1_sign = S1Reg(sign)
val s1_rm = S1Reg(rm)
val s1_op = S1Reg(op)
val s1_isNaN = S1Reg(isNaN)
val s1_expOv = S1Reg(expOv)
/** Stage 2: Rounding
*/
val rounding = Module(new RoundingUnit(64))
rounding.io.in.rm := s1_rm
rounding.io.in.sign := s1_sign
rounding.io.in.mant := s1_uint.head(64)
rounding.io.in.guard := s1_uint.tail(64).head(1)
rounding.io.in.round := s1_uint.tail(65).head(1)
rounding.io.in.sticky := s1_uint.tail(66).head(1)
val uint = rounding.io.out.mantRounded
val int = Mux(s1_sign, -uint, uint)
val commonResult = Mux(s1_op(1), int, int(31, 0))
val orHi = ORTree(uint.head(32))
val orLo = ORTree(uint.tail(32))
val diffSign = (orHi | orLo) && Mux(s1_op(0),
s1_sign,
Mux(s1_op(1),
int(63),
int(31)
) ^ s1_sign
)
val max64 = Cat(s1_op(0), Fill(63, 1.U(1.W)))
val min64 = Cat(!s1_op(0), 0.U(63.W))
val specialResult = Mux(s1_isNaN || !s1_sign,
Mux(s1_op(1), max64, max64.head(32)),
Mux(s1_op(1), min64, min64.head(32))
)
val invalid = s1_isNaN || s1_expOv || diffSign || (!s1_op(1) && orHi)
val s2_invalid = S2Reg(invalid)
val s2_result = S2Reg(Mux(invalid, specialResult, commonResult))
val s2_inexact =S2Reg(!invalid && rounding.io.out.inexact)
/** Assign Outputs
*/
io.out.bits.data := s2_result
fflags.invalid := s2_invalid
fflags.overflow := false.B
fflags.underflow := false.B
fflags.infinite := false.B
fflags.inexact := s2_inexact
}
\ No newline at end of file
// See LICENSE.Berkeley for license details.
// See LICENSE.SiFive for license details.
package xiangshan.backend.fu.fpu
import chisel3._
import hardfloat.INToRecFN
import utils.{SignExt, ZeroExt}
class IntToFP extends FPUSubModule {
val ctrl = io.in.bits.uop.ctrl.fpu
val tag = ctrl.typeTagIn
val typ = ctrl.typ
val wflags = ctrl.wflags
val src1 = io.in.bits.src(0)(XLEN-1, 0)
val mux = Wire(new Bundle() {
val data = UInt((XLEN+1).W)
val exc = UInt(5.W)
})
mux.data := recode(src1, tag)
mux.exc := 0.U
val intValue = Mux(typ(1),
Mux(typ(0), ZeroExt(src1, XLEN), SignExt(src1, XLEN)),
Mux(typ(0), ZeroExt(src1(31, 0), XLEN), SignExt(src1(31, 0), XLEN))
)
when(wflags){
val i2fResults = for(t <- floatTypes) yield {
val i2f = Module(new INToRecFN(XLEN, t.exp, t.sig))
i2f.io.signedIn := ~typ(0)
i2f.io.in := intValue
i2f.io.roundingMode := rm
i2f.io.detectTininess := hardfloat.consts.tininess_afterRounding
(sanitizeNaN(i2f.io.out, t), i2f.io.exceptionFlags)
}
val (data, exc) = i2fResults.unzip
mux.data := VecInit(data)(tag)
mux.exc := VecInit(exc)(tag)
}
fflags := mux.exc
io.out.bits.uop := io.in.bits.uop
io.out.bits.data := box(mux.data, io.in.bits.uop.ctrl.fpu.typeTagOut)
io.out.valid := io.in.valid
io.in.ready := io.out.ready
}
//package xiangshan.backend.fu.fpu
//
//import chisel3._
//import chisel3.util._
//import xiangshan.FuType
//import xiangshan.backend.fu.{CertainLatency, FuConfig}
//import xiangshan.backend.fu.fpu.util.ORTree
//
//class IntToFloat extends FPUPipelineModule(
// FuConfig(FuType.i2f, 1, 0, writeIntRf = false, writeFpRf = true, hasRedirect = false, CertainLatency(2))
//) {
// /** Stage 1: Count leading zeros and shift
// */
//
// val a = io.in.bits.src(0)
// val aNeg = (~a).asUInt()
// val aComp = aNeg + 1.U
// val aSign = Mux(op(0), false.B, Mux(op(1), a(63), a(31)))
//
// val leadingZerosComp = PriorityEncoder(Mux(op(1), aComp, aComp(31, 0)).asBools().reverse)
// val leadingZerosNeg = PriorityEncoder(Mux(op(1), aNeg, aNeg(31, 0)).asBools().reverse)
// val leadingZerosPos = PriorityEncoder(Mux(op(1), a, a(31,0)).asBools().reverse)
//
// val aVal = Mux(aSign, Mux(op(1), aComp, aComp(31, 0)), Mux(op(1), a, a(31, 0)))
// val leadingZeros = Mux(aSign, leadingZerosNeg, leadingZerosPos)
//
// // exp = xlen - 1 - leadingZeros + bias
// val expUnrounded = S1Reg(
// Mux(isDouble,
// (64 - 1 + Float64.expBiasInt).U - leadingZeros,
// (64 - 1 + Float32.expBiasInt).U - leadingZeros
// )
// )
// val leadingZeroHasError = S1Reg(aSign && (leadingZerosComp=/=leadingZerosNeg))
// val rmReg = S1Reg(rm)
// val opReg = S1Reg(op)
// val isDoubleReg = S1Reg(isDouble)
// val aIsZeroReg = S1Reg(a===0.U)
// val aSignReg = S1Reg(aSign)
// val aShifted = S1Reg((aVal << leadingZeros)(63, 0))
//
// /** Stage 2: Rounding
// */
// val aShiftedFix = Mux(leadingZeroHasError, aShifted(63, 1), aShifted(62, 0))
// val mantD = aShiftedFix(62, 62-51)
// val mantS = aShiftedFix(62, 62-22)
//
// val g = Mux(isDoubleReg, aShiftedFix(62-52), aShiftedFix(62-23))
// val r = Mux(isDoubleReg, aShiftedFix(62-53), aShiftedFix(62-24))
// val s = Mux(isDoubleReg, ORTree(aShiftedFix(62-54, 0)), ORTree(aShiftedFix(62-25, 0)))
//
// val roudingUnit = Module(new RoundingUnit(Float64.mantWidth))
// roudingUnit.io.in.rm := rmReg
// roudingUnit.io.in.mant := Mux(isDoubleReg, mantD, mantS)
// roudingUnit.io.in.sign := aSignReg
// roudingUnit.io.in.guard := g
// roudingUnit.io.in.round := r
// roudingUnit.io.in.sticky := s
//
// val mantRounded = roudingUnit.io.out.mantRounded
// val expRounded = Mux(isDoubleReg,
// expUnrounded + roudingUnit.io.out.mantCout,
// expUnrounded + mantRounded(Float32.mantWidth)
// ) + leadingZeroHasError
//
// val resS = Cat(
// aSignReg,
// expRounded(Float32.expWidth-1, 0),
// mantRounded(Float32.mantWidth-1, 0)
// )
// val resD = Cat(aSignReg, expRounded, mantRounded)
//
// io.out.bits.data := S2Reg(Mux(aIsZeroReg, 0.U, Mux(isDoubleReg, resD, resS)))
// fflags.inexact := S2Reg(roudingUnit.io.out.inexact)
// fflags.underflow := false.B
// fflags.overflow := false.B
// fflags.infinite := false.B
// fflags.invalid := false.B
//}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.fpu.util.ORTree
class IntToFloatSingleCycle extends FPUSubModule {
val a = io.in.bits.src(0)
val aNeg = (~a).asUInt()
val aComp = aNeg + 1.U
val aSign = Mux(op(0), false.B, Mux(op(1), a(63), a(31)))
val leadingZerosComp = PriorityEncoder(Mux(op(1), aComp, aComp(31, 0)).asBools().reverse)
val leadingZerosNeg = PriorityEncoder(Mux(op(1), aNeg, aNeg(31, 0)).asBools().reverse)
val leadingZerosPos = PriorityEncoder(Mux(op(1), a, a(31,0)).asBools().reverse)
val aVal = Mux(aSign, Mux(op(1), aComp, aComp(31, 0)), Mux(op(1), a, a(31, 0)))
val leadingZeros = Mux(aSign, leadingZerosNeg, leadingZerosPos)
// exp = xlen - 1 - leadingZeros + bias
val expUnrounded = Mux(isDouble,
(64 - 1 + Float64.expBiasInt).U - leadingZeros,
(64 - 1 + Float32.expBiasInt).U - leadingZeros
)
val leadingZeroHasError = aSign && (leadingZerosComp=/=leadingZerosNeg)
val rmReg = rm
val opReg = op
val isDoubleReg = isDouble
val aIsZeroReg = a===0.U
val aSignReg = aSign
val aShifted = (aVal << leadingZeros)(63, 0)
/** Stage 2: Rounding
*/
val aShiftedFix = Mux(leadingZeroHasError, aShifted(63, 1), aShifted(62, 0))
val mantD = aShiftedFix(62, 62-51)
val mantS = aShiftedFix(62, 62-22)
val g = Mux(isDoubleReg, aShiftedFix(62-52), aShiftedFix(62-23))
val r = Mux(isDoubleReg, aShiftedFix(62-53), aShiftedFix(62-24))
val s = Mux(isDoubleReg, ORTree(aShiftedFix(62-54, 0)), ORTree(aShiftedFix(62-25, 0)))
val roudingUnit = Module(new RoundingUnit(Float64.mantWidth))
roudingUnit.io.in.rm := rmReg
roudingUnit.io.in.mant := Mux(isDoubleReg, mantD, mantS)
roudingUnit.io.in.sign := aSignReg
roudingUnit.io.in.guard := g
roudingUnit.io.in.round := r
roudingUnit.io.in.sticky := s
val mantRounded = roudingUnit.io.out.mantRounded
val expRounded = Mux(isDoubleReg,
expUnrounded + roudingUnit.io.out.mantCout,
expUnrounded + mantRounded(Float32.mantWidth)
) + leadingZeroHasError
val resS = Cat(
aSignReg,
expRounded(Float32.expWidth-1, 0),
mantRounded(Float32.mantWidth-1, 0)
)
val resD = Cat(aSignReg, expRounded, mantRounded)
io.in.ready := io.out.ready
io.out.valid := io.in.valid
io.out.bits.uop := io.in.bits.uop
io.out.bits.data := Mux(aIsZeroReg, 0.U, Mux(isDoubleReg, resD, resS))
fflags.inexact := roudingUnit.io.out.inexact
fflags.underflow := false.B
fflags.overflow := false.B
fflags.infinite := false.B
fflags.invalid := false.B
}
# NOOP-FPU
一个完全符合IEEE754-2008标准的混合精度(Float/Double)RISCV-FPU
FPU除法/开方模块使用了SRT-4算法,采用多周期设计,其余部件均为流水线结构,具体情况如下:
| 功能部件 | 流水级数 |
| :----: | :----: |
|FMA | 5 |
|F32toF64 | 2 |
|F64toF32 | 2 |
|FCMP | 2 |
|FloatToInt| 2 |
|IntToFloat| 2 |
不同功能部件之间相互独立,不共享硬件资源;
同一部件内部,双精度/单精度运算共享硬件资源。
FPU中所有部件都已通过
berkeley-testfloat和riscv-tests中的rvd/rvf测试,
在axu3cg上运行频率超过200MHz
## 开启/关闭FPU
`HasNOOPParameter`中的`HasFPU`定义为`true`/`false`即可
## FPU单元测试
### 使用berkeley-testfloat测试FPU中的所有模块:
```
cd deug
make fputest FPU_TEST_ARGS=-Pn
```
`n`为线程数
### 自定义测试:
`src/test/fpu/FPUSubModuleTester`中修改测试配置
```
配置格式
case class FpuTest
(
name: String,
roundingModes: Seq[UInt],
backend: String = "verilator",
writeVcd: Boolean = false,
pipeline: Boolean = true
)
```
`backend`可选`verilator`/`treadle`/`vcs`
`verilator`编译较慢但仿真运行速度最快;
`treadle`输出格式较为整齐,适合debug
`pipeline``false`时每执行完一个测例才开始输入下一个
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
import xiangshan.backend.fu.fpu.RoudingMode._
import xiangshan.backend.fu.fpu.util.ORTree
class RoundingUnit(mantWidth: Int) extends Module{
val io = IO(new Bundle() {
val in = Input(new Bundle() {
val rm = UInt(3.W)
val mant = UInt(mantWidth.W)
val sign, guard, round, sticky = Bool()
})
val out = Output(new Bundle() {
val mantRounded = UInt(mantWidth.W)
val inexact = Bool()
val mantCout = Bool()
val roundUp = Bool()
})
})
val inexact = io.in.guard | io.in.round | io.in.sticky
val lsb = io.in.mant(0)
val roundUp = MuxLookup(io.in.rm, false.B, Seq(
RNE -> (io.in.guard && (io.in.round | io.in.sticky | lsb)),
RTZ -> false.B,
RUP -> (inexact & (!io.in.sign)),
RDN -> (inexact & io.in.sign),
RMM -> io.in.guard
))
val mantRoundUp = io.in.mant +& 1.U
val cout = mantRoundUp(mantWidth)
val mantRounded = Mux(roundUp,
mantRoundUp(mantWidth-1, 0),
io.in.mant
)
io.out.inexact := inexact
io.out.mantRounded := mantRounded
io.out.mantCout := cout & roundUp
io.out.roundUp := roundUp
}
class RoundWithExceptionsIO(sexpWidth:Int, mantWidth:Int) extends Bundle {
val isDouble = Input(Bool())
val denormShiftAmt = Input(SInt(sexpWidth.W))
val sign = Input(Bool())
val expNorm = Input(SInt(sexpWidth.W))
val mantWithGRS = Input(UInt((mantWidth+3).W))
val rm = Input(UInt(3.W))
val specialCaseHappen = Input(Bool())
val expRounded = Output(SInt(sexpWidth.W))
val mantRounded = Output(UInt(mantWidth.W))
val inexact = Output(Bool())
val overflow = Output(Bool())
val underflow = Output(Bool())
val ovSetInf = Output(Bool())
val isZeroResult = Output(Bool())
override def cloneType: RoundWithExceptionsIO.this.type =
new RoundWithExceptionsIO(sexpWidth, mantWidth).asInstanceOf[this.type]
}
class RoundF64AndF32WithExceptions
(
expInHasBias: Boolean = false,
D_MANT_WIDTH: Int = Float64.mantWidth + 1,
D_SEXP_WIDTH: Int = Float64.expWidth + 2,
D_EXP_WIDTH: Int = Float64.expWidth,
S_MANT_WIDTH: Int = Float32.mantWidth + 1,
S_SEXP_WIDTH: Int = Float32.expWidth + 2,
S_EXP_WIDTH: Int = Float32.expWidth
) extends Module with HasUIntToSIntHelper {
val io = IO(new RoundWithExceptionsIO(D_SEXP_WIDTH, D_MANT_WIDTH))
val isDouble = io.isDouble
val rounding = Module(new RoundingUnit(D_MANT_WIDTH))
val mantUnrounded = io.mantWithGRS.head(D_MANT_WIDTH)
rounding.io.in.sign := io.sign
rounding.io.in.mant := mantUnrounded
rounding.io.in.rm := io.rm
rounding.io.in.guard := io.mantWithGRS(2)
rounding.io.in.round := io.mantWithGRS(1)
rounding.io.in.sticky := io.mantWithGRS(0)
val mantRounded = rounding.io.out.mantRounded
val mantCout = Mux(isDouble,
Mux(!mantUnrounded(D_MANT_WIDTH-1),
mantRounded(D_MANT_WIDTH-1),
rounding.io.out.mantCout
),
Mux(!mantUnrounded(S_MANT_WIDTH-1),
mantRounded(S_MANT_WIDTH-1),
mantRounded(S_MANT_WIDTH)
)
)
val isZeroResult = !ORTree(Cat(mantCout, mantRounded))
val expRounded = Mux(io.denormShiftAmt > 0.S || isZeroResult,
0.S,
if(expInHasBias) io.expNorm
else io.expNorm + Mux(isDouble, Float64.expBias, Float32.expBias).toSInt
) + mantCout.toSInt
val common_inexact = rounding.io.out.inexact
val roundingInc = MuxLookup(io.rm, "b10".U(2.W), Seq(
RoudingMode.RDN -> Mux(io.sign, "b11".U, "b00".U),
RoudingMode.RUP -> Mux(io.sign, "b00".U, "b11".U),
RoudingMode.RTZ -> "b00".U
))
val isDenormalMant = (io.mantWithGRS + roundingInc) < Mux(isDouble,
Cat(1.U(1.W), 0.U((D_MANT_WIDTH+2).W)),
Cat(1.U(1.W), 0.U((S_MANT_WIDTH+2).W))
)
val common_underflow = (
io.denormShiftAmt > 1.S ||
io.denormShiftAmt===1.S && isDenormalMant ||
isZeroResult
) && common_inexact
val common_overflow = Mux(isDouble,
expOverflow(expRounded, D_EXP_WIDTH),
expOverflow(expRounded, S_EXP_WIDTH)
)
val ovSetInf = io.rm === RoudingMode.RNE ||
io.rm === RoudingMode.RMM ||
(io.rm === RoudingMode.RDN && io.sign) ||
(io.rm === RoudingMode.RUP && !io.sign)
io.expRounded := expRounded
io.mantRounded := mantRounded
io.inexact := !io.specialCaseHappen && (common_inexact || common_overflow || common_underflow)
io.overflow := !io.specialCaseHappen && common_overflow
io.underflow := !io.specialCaseHappen && common_underflow
io.ovSetInf := ovSetInf
io.isZeroResult := isZeroResult
}
\ No newline at end of file
package xiangshan.backend.fu.fpu.divsqrt
import xiangshan.backend.fu.fpu._
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{FuConfig, FunctionUnit, UncertainLatency}
import xiangshan.backend.fu.fpu.util.{FPUDebug, ORTree, ShiftRightJam}
class DivSqrt extends FPUSubModule {
def SEXP_WIDTH: Int = Float64.expWidth + 2
def D_MANT_WIDTH: Int = Float64.mantWidth + 1
def S_MANT_WIDTH: Int = Float32.mantWidth+1
val s_idle :: s_norm :: s_start :: s_compute :: s_round:: s_finish :: Nil = Enum(6)
val state = RegInit(s_idle)
val uopReg = RegEnable(io.in.bits.uop, io.in.fire())
val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn)
val rmReg = RegEnable(rm, io.in.fire())
val isDiv = !op(0)
val isDivReg = RegEnable(isDiv, io.in.fire())
val isDoubleReg = RegEnable(isDouble, io.in.fire())
val (a, b) = (
Mux(isDouble, io.in.bits.src(0), extF32ToF64(io.in.bits.src(0))),
Mux(isDouble, io.in.bits.src(1), extF32ToF64(io.in.bits.src(1)))
)
/** Detect special cases
*/
val classify_a = Module(new Classify(Float64.expWidth, Float64.mantWidth))
classify_a.io.in := a
val aIsSubnormalOrZero = classify_a.io.isSubnormalOrZero
val classify_b = Module(new Classify(Float64.expWidth, Float64.mantWidth))
classify_b.io.in := b
val bIsSubnormalOrZero = classify_b.io.isSubnormalOrZero
def decode(x: UInt, expIsZero: Bool) = {
val f64 = Float64(x)
val exp = Cat(0.U(1.W), f64.exp) - Float64.expBias
val mantExt = Cat(!expIsZero, f64.mant)
(f64.sign, exp.asSInt(), mantExt)
}
val (aSign, aExp, aMant) = decode(a, aIsSubnormalOrZero)
val (bSign, bExp, bMant) = decode(b, bIsSubnormalOrZero)
val resSign = Mux(isDiv, aSign ^ bSign, aSign)
val resSignReg = RegEnable(resSign, io.in.fire())
val aExpReg = Reg(SInt(SEXP_WIDTH.W))
val aMantReg = Reg(UInt(D_MANT_WIDTH.W))
val aIsOddExp = aExpReg(0)
val bMantReg = Reg(UInt(D_MANT_WIDTH.W))
val bExpReg = Reg(SInt(SEXP_WIDTH.W))
val aIsNaN = classify_a.io.isNaN
val aIsSNaN = classify_a.io.isSNaN
val aIsInf = classify_a.io.isInf
val aIsPosInf = classify_a.io.isPosInf
val aIsInfOrNaN = classify_a.io.isInfOrNaN
val aIsSubnormal = classify_a.io.isSubnormal
val aIsSubnormalReg = RegEnable(aIsSubnormal, io.in.fire())
val aIsZero = classify_a.io.isZero
val sel_NaN_OH = UIntToOH(2.U, 3)
val sel_Zero_OH = UIntToOH(1.U, 3)
val sel_Inf_OH = UIntToOH(0.U, 3)
val sqrtSepcialResSel = MuxCase(sel_NaN_OH, Seq(
aIsZero -> sel_Zero_OH,
aIsPosInf -> sel_Inf_OH
))
val sqrtInvalid = ((aSign && !aIsNaN && !aIsZero) || aIsSNaN) && !isDiv
val sqrtSpecial = (aSign || aIsInfOrNaN || aIsZero) && !isDiv
val sqrtInvalidReg = RegEnable(sqrtInvalid, io.in.fire())
val bIsZero = classify_b.io.isZero
val bIsNaN = classify_b.io.isNaN
val bIsSNaN = classify_b.io.isSNaN
val bIsSubnormal = classify_b.io.isSubnormal && isDiv
val bIsSubnormalReg = RegEnable(bIsSubnormal, io.in.fire())
val bIsInf = classify_b.io.isInf
val hasNaN = aIsNaN || bIsNaN
val bothZero = aIsZero && bIsZero
val bothInf = aIsInf && bIsInf
val divInvalid = bothZero || aIsSNaN || bIsSNaN || bothInf
val divInf = !divInvalid && !aIsNaN && bIsZero && !aIsInf
val divSepcial = (aIsZero || bIsZero || hasNaN || bIsInf || aIsInf) && isDiv
val divZeroReg = RegEnable(bIsZero, io.in.fire())
val divInvalidReg = RegEnable(divInvalid, io.in.fire())
val divInfReg = RegEnable(divInf, io.in.fire())
val divSepcialResSel = PriorityMux(Seq(
(divInvalid || hasNaN) -> sel_NaN_OH,
bIsZero -> sel_Inf_OH,
(aIsZero || bIsInf) -> sel_Zero_OH,
aIsInf -> sel_Inf_OH
))
val specialCaseHappen = sqrtSpecial || divSepcial
val specialCaseHappenReg = RegEnable(specialCaseHappen, io.in.fire())
val specialResSel = Mux(sqrtSpecial, sqrtSepcialResSel, divSepcialResSel)
val sel_NaN :: sel_Zero :: sel_Inf :: Nil = specialResSel.asBools().reverse
val specialResult = RegEnable(
Mux(sel_NaN,
Mux(isDouble,
Float64.defaultNaN,
Float32.defaultNaN
),
Mux(sel_Zero,
Mux(isDouble,
Cat(resSign, 0.U((Float64.getWidth-1).W)),
Cat(resSign, 0.U((Float32.getWidth-1).W))
),
Mux(isDouble,
Cat(resSign, Float64.posInf.tail(1)),
Cat(resSign, Float32.posInf.tail(1))
)
)
),
io.in.fire()
)
// used in 's_norm' to normalize a subnormal number to normal
val aMantLez = PriorityEncoder(aMantReg(51, 0).asBools().reverse)
val bMantLez = PriorityEncoder(bMantReg(51, 0).asBools().reverse)
// 53 + 2 + 2 = 57 bits are needed, but 57 % log2(4) != 0, use 58 bits instead
val mantDivSqrt = Module(new MantDivSqrt(D_MANT_WIDTH+2+2+1))
mantDivSqrt.io.kill := kill
mantDivSqrt.io.out.ready := true.B
mantDivSqrt.io.in.valid := state === s_start
mantDivSqrt.io.in.bits.a := Mux(isDivReg || aIsOddExp, Cat(aMantReg, 0.U(5.W)), Cat(0.U(1.W), aMantReg, 0.U(4.W)))
mantDivSqrt.io.in.bits.b := Cat(bMantReg, 0.U(5.W))
mantDivSqrt.io.in.bits.isDiv := isDivReg
/** Output format:
*
* 57 56 55 4 3 2 1 0
* 0 x. x x x ... x | x x x x
*
*/
val mantDivSqrtResult = mantDivSqrt.io.out.bits.quotient
val needNormalize = !mantDivSqrtResult(56)
val mantNorm = Mux(needNormalize, mantDivSqrtResult<<1, mantDivSqrtResult)(56, 0)
val expNorm = (aExpReg.asUInt() - (Mux(needNormalize, 2.U, 1.U) - isDivReg)).asSInt()
val denormalizeShift = Mux(
isDoubleReg,
(-Float64.expBiasInt+1).S,
(-Float32.expBiasInt+1).S
) - expNorm
val denormShiftReg = RegEnable(denormalizeShift, mantDivSqrt.io.out.fire())
val mantShifted = ShiftRightJam(mantNorm,
Mux(denormalizeShift.head(1).asBool(), 0.U, denormalizeShift.asUInt()),
D_MANT_WIDTH+3
)
val mantPostNorm = Mux(isDoubleReg,
mantShifted.head(D_MANT_WIDTH),
mantShifted.head(S_MANT_WIDTH)
)
val g = Mux(isDoubleReg,
mantShifted.tail(D_MANT_WIDTH).head(1),
mantShifted.tail(S_MANT_WIDTH).head(1)
).asBool()
val r = Mux(isDoubleReg,
mantShifted.tail(D_MANT_WIDTH+1).head(1),
mantShifted.tail(S_MANT_WIDTH+1).head(1)
).asBool()
val s = !mantDivSqrt.io.out.bits.isZeroRem || ORTree(Mux(isDoubleReg,
mantShifted.tail(D_MANT_WIDTH+2),
mantShifted.tail(S_MANT_WIDTH+2)
))
val gReg = RegNext(g)
val rReg = RegNext(r) // false.B
val sReg = RegNext(s)
/** state === s_round
*
*/
val rounding = Module(new RoundF64AndF32WithExceptions)
rounding.io.isDouble := isDoubleReg
rounding.io.denormShiftAmt := denormShiftReg
rounding.io.sign := resSignReg
rounding.io.expNorm := aExpReg
rounding.io.mantWithGRS := Cat(aMantReg, gReg, rReg, sReg)
rounding.io.rm := rmReg
rounding.io.specialCaseHappen := false.B
val expRounded = rounding.io.expRounded
val mantRounded = rounding.io.mantRounded
val overflowReg = RegEnable(rounding.io.overflow, state===s_round)
val underflowReg = RegEnable(rounding.io.underflow, state===s_round)
val inexactReg = RegEnable(rounding.io.inexact, state===s_round)
val ovSetInfReg = RegEnable(rounding.io.ovSetInf, state===s_round)
switch(state){
is(s_idle){
when(io.in.fire()){
when(sqrtSpecial || divSepcial){
state := s_finish
}.elsewhen(aIsSubnormal || bIsSubnormal){
state := s_norm
}.otherwise({
state := s_start
})
}
}
is(s_norm){
state := s_start
}
is(s_start){
state := s_compute
}
is(s_compute){
when(mantDivSqrt.io.out.fire()){
state := s_round
}
}
is(s_round){
state := s_finish
}
is(s_finish){
when(io.out.fire()){
state := s_idle
}
}
}
when(kill){ state := s_idle }
switch(state){
is(s_idle){
when(io.in.fire()){
aExpReg := aExp
aMantReg := aMant
bExpReg := bExp
bMantReg := bMant
}
}
is(s_norm){
when(aIsSubnormalReg){
aExpReg := (Mux(isDoubleReg, aExpReg, (-Float32.expBiasInt).S(SEXP_WIDTH.W)).asUInt() - aMantLez).asSInt()
aMantReg := (aMantReg << aMantLez) << 1 // use 'Cat' instead ?
}
when(bIsSubnormalReg){
bExpReg := (Mux(isDoubleReg, bExpReg, (-Float32.expBiasInt).S(SEXP_WIDTH.W)).asUInt() - bMantLez).asSInt()
bMantReg := (bMantReg << bMantLez) << 1 // use 'Cat' instead ?
}
}
is(s_start){
aExpReg := Mux(isDivReg, aExpReg - bExpReg, (aExpReg >> 1).asSInt() + 1.S)
}
is(s_compute){
when(mantDivSqrt.io.out.fire()){
aExpReg := expNorm
aMantReg := mantPostNorm
}
}
is(s_round){
aExpReg := expRounded
aMantReg := mantRounded
}
}
val commonResult = Mux(isDoubleReg,
Cat(resSignReg, aExpReg(Float64.expWidth-1, 0), aMantReg(Float64.mantWidth-1, 0)),
Cat(resSignReg, aExpReg(Float32.expWidth-1, 0), aMantReg(Float32.mantWidth-1, 0))
)
io.in.ready := (state === s_idle) && io.out.ready
io.out.valid := (state === s_finish) && !kill
io.out.bits.data := Mux(specialCaseHappenReg,
specialResult,
Mux(overflowReg,
Mux(isDoubleReg,
Cat(resSignReg, Mux(ovSetInfReg, Float64.posInf.tail(1), Float64.maxNorm.tail(1))),
Cat(resSignReg, Mux(ovSetInfReg, Float32.posInf.tail(1), Float32.maxNorm.tail(1)))
),
commonResult
)
)
io.out.bits.uop := uopReg
fflags.invalid := Mux(isDivReg, divInvalidReg, sqrtInvalidReg)
fflags.underflow := !specialCaseHappenReg && underflowReg
fflags.overflow := !specialCaseHappenReg && overflowReg
fflags.infinite := Mux(isDivReg, divInfReg, false.B)
fflags.inexact := !specialCaseHappenReg && (inexactReg || overflowReg || underflowReg)
// FPUDebug() {
// // printf(p"$cnt in:${Hexadecimal(io.in.bits.src0)} \n")
// when(io.in.fire()) {
// printf(p"[In.fire] " +
// p"a:${Hexadecimal(io.in.bits.a)} aexp:${aExp.asSInt()} amant:${Hexadecimal(aMant)} " +
// p"b:${Hexadecimal(io.in.bits.b)} bexp:${bExp.asSInt()} bmant:${Hexadecimal(bMant)}\n")
// }
//// when(state === s_norm) {
//// printf(p"[norm] lz:$aMantLez\n")
//// }
// when(state === s_compute){
//// when(sqrt.io.out.fire()){
//// printf(p"[compute] ")
//// }
// }
// when(state === s_start) {
// printf(p"[start] sign:$resSignReg mant:${Hexadecimal(aMantReg)} exp:${aExpReg.asSInt()}\n")
// }
// when(state === s_round){
// printf(p"[round] exp before round:${aExpReg} g:$gReg r:$rReg s:$sReg mant:${Hexadecimal(aMantReg)}\n" +
// p"[round] mantRounded:${Hexadecimal(mantRounded)}\n")
// }
// when(io.out.valid) {
// printf(p"[Out.valid] " +
// p"invalid:$sqrtInvalidReg result:${Hexadecimal(commonResult)}\n" +
// p"output:${Hexadecimal(io.out.bits.result)} " +
// p"exp:${aExpReg.asSInt()} \n")
// }
// }
}
package xiangshan.backend.fu.fpu.divsqrt
import chisel3._
import chisel3.util._
import xiangshan.backend.fu.fpu.util._
import xiangshan.backend.fu.fpu.util.FPUDebug
class MantDivSqrt(len: Int) extends Module{
val io = IO(new Bundle() {
val in = Flipped(DecoupledIO(new Bundle() {
val a, b = UInt(len.W)
val isDiv = Bool()
}))
val kill = Input(Bool())
val out = DecoupledIO(new Bundle() {
val quotient = UInt(len.W)
val isZeroRem = Bool()
})
})
val (a, b) = (io.in.bits.a, io.in.bits.b)
val isDiv = io.in.bits.isDiv
val isDivReg = RegEnable(isDiv, io.in.fire())
val divisor = RegEnable(b, io.in.fire())
val s_idle :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(4)
val state = RegInit(s_idle)
val cnt_next = Wire(UInt(log2Up((len+1)/2).W))
val cnt = RegEnable(cnt_next, state===s_idle || state===s_recurrence)
cnt_next := Mux(state === s_idle, (len/2).U, cnt - 1.U)
val firstCycle = RegNext(io.in.fire())
switch(state){
is(s_idle){
when(io.in.fire()){ state := s_recurrence }
}
is(s_recurrence){
when(cnt_next === 0.U){ state := s_recovery }
}
is(s_recovery){
state := s_finish
}
is(s_finish){
when(io.out.fire()){ state := s_idle }
}
}
when(io.kill){ state := s_idle }
val ws, wc = Reg(UInt((len+4).W))
val table = Module(new SrtTable)
val conv = Module(new OnTheFlyConv(len+3))
val csa = Module(new CSA3_2(len+4))
// partial square root
val S = conv.io.Q >> 2
val s0 :: s1 :: s2 :: s3 :: s4 :: Nil = S(len-2, len-6).asBools().reverse
val sqrt_d = Mux(firstCycle, "b101".U(3.W), Mux(s0, "b111".U(3.W), Cat(s2, s3, s4)))
val div_d = divisor(len-2, len-4)
val sqrt_y = ws(len+3, len-4) + wc(len+3, len-4)
val div_y = ws(len+2, len-5) + wc(len+2, len-5)
table.io.d := Mux(isDivReg, div_d, sqrt_d)
table.io.y := Mux(isDivReg, div_y, sqrt_y)
conv.io.resetSqrt := io.in.fire() && !isDiv
conv.io.resetDiv := io.in.fire() && isDiv
conv.io.enable := state===s_recurrence
conv.io.qi := table.io.q
val dx1, dx2, neg_dx1, neg_dx2 = Wire(UInt((len+4).W))
dx1 := divisor
dx2 := divisor << 1
neg_dx1 := ~dx1
neg_dx2 := neg_dx1 << 1
val divCsaIn = MuxLookup(table.io.q.asUInt(), 0.U, Seq(
-1 -> dx1,
-2 -> dx2,
1 -> neg_dx1,
2 -> neg_dx2
).map(m => m._1.S(3.W).asUInt() -> m._2))
csa.io.in(0) := ws
csa.io.in(1) := Mux(isDivReg & !table.io.q(2), wc | table.io.q(1, 0), wc)
csa.io.in(2) := Mux(isDivReg, divCsaIn, conv.io.F)
val divWsInit = a
val sqrtWsInit = Cat( Cat(0.U(2.W), a) - Cat(1.U(2.W), 0.U(len.W)), 0.U(2.W))
when(io.in.fire()){
ws := Mux(isDiv, divWsInit, sqrtWsInit)
wc := 0.U
}.elsewhen(state === s_recurrence){
ws := Mux(cnt_next === 0.U, csa.io.out(0), csa.io.out(0) << 2)
wc := Mux(cnt_next === 0.U, csa.io.out(1) << 1, csa.io.out(1) << 3)
}
val rem = ws + wc
/** Remainder format:
* Sqrt:
* s s x x. x x x ... x
* Div:
* s s s x. x x x ... x
*/
val remSignReg = RegEnable(rem.head(1).asBool(), state===s_recovery)
val isZeroRemReg = RegEnable(rem===0.U, state===s_recovery)
io.in.ready := state === s_idle
io.out.valid := state === s_finish
io.out.bits.quotient := Mux(remSignReg, conv.io.QM, conv.io.Q) >> !isDivReg
io.out.bits.isZeroRem := isZeroRemReg
FPUDebug(){
when(io.in.fire()){
printf(p"a:${Hexadecimal(io.in.bits.a)} b:${Hexadecimal(io.in.bits.b)}\n")
}
when(io.out.valid) {
printf(p"Q:${Binary(conv.io.Q)} QM:${Binary(conv.io.QM)} isNegRem:${rem.head(1)}\n" +
p"rem:${Hexadecimal(rem)}\n")
}
}
}
package xiangshan.backend.fu.fpu.divsqrt
import chisel3._
import chisel3.util._
import utils._
import xiangshan.backend.fu.fpu._
import xiangshan.backend.fu.fpu.util.FPUDebug
class OnTheFlyConv(len: Int) extends Module {
val io = IO(new Bundle() {
val resetSqrt = Input(Bool())
val resetDiv = Input(Bool())
val enable = Input(Bool())
val qi = Input(SInt(3.W))
val QM = Output(UInt(len.W))
val Q = Output(UInt(len.W))
val F = Output(UInt(len.W))
})
val Q, QM = Reg(UInt(len.W))
/** FGen:
* use additional regs to avoid
* big width shifter since FGen is in cirtical path
*/
val mask = Reg(SInt(len.W))
val b_111, b_1100 = Reg(UInt(len.W))
when(io.resetSqrt){
mask := Cat("b1".U(1.W), 0.U((len-1).W)).asSInt()
b_111 := "b111".U(3.W) << (len-5)
b_1100 := "b1100".U(4.W) << (len-5)
}.elsewhen(io.enable){
mask := mask >> 2
b_111 := b_111 >> 2
b_1100 := b_1100 >> 2
}
val b_00, b_01, b_10, b_11 = Reg(UInt((len-3).W))
b_00 := 0.U
when(io.resetDiv || io.resetSqrt){
b_01 := Cat("b01".U(2.W), 0.U((len-5).W))
b_10 := Cat("b10".U(2.W), 0.U((len-5).W))
b_11 := Cat("b11".U(2.W), 0.U((len-5).W))
}.elsewhen(io.enable){
b_01 := b_01 >> 2
b_10 := b_10 >> 2
b_11 := b_11 >> 2
}
val negQ = ~Q
val sqrtToCsaMap = Seq(
1 -> (negQ, b_111),
2 -> (negQ, b_1100),
-1 -> (QM, b_111),
-2 -> (QM, b_1100)
).map(
m => m._1.S(3.W).asUInt() ->
( ((m._2._1 << Mux(io.qi(0), 1.U, 2.U)).asUInt() & (mask >> io.qi(0)).asUInt()) | m._2._2 )
)
val sqrtToCsa = MuxLookup(io.qi.asUInt(), 0.U, sqrtToCsaMap)
val Q_load_00 = Q | b_00
val Q_load_01 = Q | b_01
val Q_load_10 = Q | b_10
val QM_load_01 = QM | b_01
val QM_load_10 = QM | b_10
val QM_load_11 = QM | b_11
when(io.resetSqrt){
Q := Cat(1.U(3.W), 0.U((len-3).W))
QM := 0.U
}.elsewhen(io.resetDiv){
Q := 0.U
QM := 0.U
}.elsewhen(io.enable){
val QConvMap = Seq(
0 -> Q_load_00,
1 -> Q_load_01,
2 -> Q_load_10,
-1 -> QM_load_11,
-2 -> QM_load_10
).map(m => m._1.S(3.W).asUInt() -> m._2)
val QMConvMap = Seq(
0 -> QM_load_11,
1 -> Q_load_00,
2 -> Q_load_01,
-1 -> QM_load_10,
-2 -> QM_load_01
).map(m => m._1.S(3.W).asUInt() -> m._2)
Q := MuxLookup(io.qi.asUInt(), DontCare, QConvMap)
QM := MuxLookup(io.qi.asUInt(), DontCare, QMConvMap)
}
io.F := sqrtToCsa
io.QM := QM
io.Q := Q
FPUDebug(){
when(io.enable){
printf(p"[on the fly conv] q:${io.qi} A:${Binary(Q)} B:${Binary(QM)} \n")
}
}
}
package xiangshan.backend.fu.fpu.divsqrt
import chisel3._
import chisel3.util._
import utils._
import xiangshan.backend.fu.fpu._
class SrtTable extends Module {
val io = IO(new Bundle() {
val d = Input(UInt(3.W))
val y = Input(UInt(8.W))
val q = Output(SInt(3.W))
})
val qSelTable = Array(
Array(12, 4, -4, -13),
Array(14, 4, -5, -14),
Array(16, 4, -6, -16),
Array(16, 4, -6, -17),
Array(18, 6, -6, -18),
Array(20, 6, -8, -20),
Array(20, 8, -8, -22),
Array(24, 8, -8, -23)
).map(_.map(_ * 2))
var ge = Map[Int, Bool]()
for(row <- qSelTable){
for(k <- row){
if(!ge.contains(k)) ge = ge + (k -> (io.y.asSInt() >= k.S(8.W)))
}
}
io.q := MuxLookup(io.d, 0.S,
qSelTable.map(x =>
MuxCase((-2).S(3.W), Seq(
ge(x(0)) -> 2.S(3.W),
ge(x(1)) -> 1.S(3.W),
ge(x(2)) -> 0.S(3.W),
ge(x(3)) -> (-1).S(3.W)
))
).zipWithIndex.map({case(v, i) => i.U -> v})
)
}
package xiangshan.backend.fu.fpu.fma
import chisel3._
import chisel3.util._
import xiangshan.backend.fu.fpu.util._
import utils.SignExt
class ArrayMultiplier(len: Int, regDepth: Int = 0, realArraryMult: Boolean = false) extends Module {
val io = IO(new Bundle() {
val a, b = Input(UInt(len.W))
val reg_en = Input(Bool())
val carry, sum = Output(UInt((2*len).W))
})
val (a, b) = (io.a, io.b)
val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
b_sext := SignExt(b, len+1)
bx2 := b_sext << 1
neg_b := (~b_sext).asUInt()
neg_bx2 := neg_b << 1
val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())
var last_x = WireInit(0.U(3.W))
for(i <- Range(0, len, 2)){
val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
val pp_temp = MuxLookup(x, 0.U, Seq(
1.U -> b_sext,
2.U -> b_sext,
3.U -> bx2,
4.U -> neg_bx2,
5.U -> neg_b,
6.U -> neg_b
))
val s = pp_temp(len)
val t = MuxLookup(last_x, 0.U(2.W), Seq(
4.U -> 2.U(2.W),
5.U -> 1.U(2.W),
6.U -> 1.U(2.W)
))
last_x = x
val (pp, weight) = i match {
case 0 =>
(Cat(~s, s, s, pp_temp), 0)
case n if (n==len-1) || (n==len-2) =>
(Cat(~s, pp_temp, t), i-2)
case _ =>
(Cat(1.U(1.W), ~s, pp_temp, t), i-2)
}
for(j <- columns.indices){
if(j >= weight && j < (weight + pp.getWidth)){
columns(j) = columns(j) :+ pp(j-weight)
}
}
}
def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
var sum = Seq[Bool]()
var cout1 = Seq[Bool]()
var cout2 = Seq[Bool]()
col.size match {
case 1 => // do nothing
sum = col ++ cin
case 2 =>
val c22 = Module(new C22)
c22.io.in := col
sum = c22.io.out(0).asBool() +: cin
cout2 = Seq(c22.io.out(1).asBool())
case 3 =>
val c32 = Module(new C32)
c32.io.in := col
sum = c32.io.out(0).asBool() +: cin
cout2 = Seq(c32.io.out(1).asBool())
case 4 =>
val c53 = Module(new C53)
for((x, y) <- c53.io.in.take(4) zip col){
x := y
}
c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
cout1 = Seq(c53.io.out(1).asBool())
cout2 = Seq(c53.io.out(2).asBool())
case n =>
val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
sum = s_1 ++ s_2
cout1 = c_1_1 ++ c_2_1
cout2 = c_1_2 ++ c_2_2
}
(sum, cout1, cout2)
}
def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)
def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
if(max(cols.map(_.size)) <= 2){
val sum = Cat(cols.map(_(0)).reverse)
var k = 0
while(cols(k).size == 1) k = k+1
val carry = Cat(cols.drop(k).map(_(1)).reverse)
(sum, Cat(carry, 0.U(k.W)))
} else {
val columns_next = Array.fill(2*len)(Seq[Bool]())
var cout1, cout2 = Seq[Bool]()
for( i <- cols.indices){
val (s, c1, c2) = addOneColumn(cols(i), cout1)
columns_next(i) = s ++ cout2
cout1 = c1
cout2 = c2
}
val needReg = depth == regDepth
val toNextLayer = if(needReg) columns_next.map(_.map(RegEnable(_, io.reg_en))) else columns_next
addAll(toNextLayer, depth+1)
}
}
val (sum, carry) = if(realArraryMult) addAll(cols = columns, depth = 0) else (RegEnable(a*b, io.reg_en), 0.U)
io.sum := sum
io.carry := carry
}
package xiangshan.backend.fu.fpu.fma
import chisel3._
import chisel3.util._
import xiangshan.FuType
import xiangshan.backend.fu.{CertainLatency, FuConfig, FunctionUnit}
import xiangshan.backend.fu.fpu._
import xiangshan.backend.fu.fpu.util.{CSA3_2, FPUDebug, ORTree, ShiftLeftJam, ShiftRightJam}
class FMA extends FPUPipelineModule {
override def latency = FunctionUnit.fmacCfg.latency.latencyVal.get
def UseRealArraryMult = false
def SEXP_WIDTH: Int = Float64.expWidth + 2
def D_MANT_WIDTH: Int = Float64.mantWidth + 1
def S_MANT_WIDTH: Int = Float32.mantWidth + 1
def INITIAL_EXP_DIFF: Int = Float64.mantWidth + 4
def ADD_WIDTH: Int = 3*D_MANT_WIDTH + 2
/******************************************************************
* Stage 1: Decode Operands
*****************************************************************/
val rs0 = io.in.bits.src(0)
val rs1 = io.in.bits.src(1)
val rs2 = io.in.bits.src(2)
val zero = 0.U(Float64.getWidth.W)
val one = Mux(isDouble,
Cat(0.U(1.W), Float64.expBiasInt.U(Float64.expWidth.W), 0.U(Float64.mantWidth.W)),
Cat(0.U(1.W), Float32.expBiasInt.U(Float32.expWidth.W), 0.U(Float32.mantWidth.W))
)
val a = {
val x = Mux(op(2),
rs2,
Mux(op(1),
zero,
rs1
)
)
val sign = Mux(isDouble, x.head(1), x.tail(32).head(1)).asBool() ^ op(0)
Mux(isDouble,
Cat(sign, x.tail(1)),
Cat(sign, x.tail(32).tail(1))
)
}
val b = rs0
val c = Mux(op(2,1) === 0.U, one, rs1)
val operands = Seq(a, b, c).map(x => Mux(isDouble, x, extF32ToF64(x)))
val classify = Array.fill(3)(Module(new Classify(Float64.expWidth, Float64.mantWidth)).io)
classify.zip(operands).foreach({case (cls, x) => cls.in := x})
def decode(x: UInt, isSubnormal: Bool, isZero: Bool) = {
val f64 = Float64(x)
val exp = Mux(isSubnormal,
Mux(isDouble, (-Float64.expBiasInt+1).S, (-Float32.expBiasInt+1).S),
f64.exp.toSInt - Float64.expBias.toSInt
)
val mantExt = Mux(isZero, 0.U, Cat(!isSubnormal, f64.mant))
(f64.sign, exp, mantExt)
}
val signs = Array.fill(3)(Wire(Bool()))
val exps = Array.fill(3)(Wire(SInt(SEXP_WIDTH.W)))
val mants = Array.fill(3)(Wire(UInt(D_MANT_WIDTH.W)))
for(i <- 0 until 3){
val (s, e, m) = decode(operands(i), classify(i).isSubnormal, classify(i).isZero)
signs(i) := s
exps(i) := e
mants(i) := m
}
val aIsSubnormal = classify(0).isSubnormal
val bIsSubnormal = classify(1).isSubnormal
val cIsSubnormal = classify(2).isSubnormal
val prodHasSubnormal = bIsSubnormal || cIsSubnormal
val aSign = signs(0)
val aExpRaw = exps(0)
val prodIsZero = classify.drop(1).map(_.isZero).reduce(_||_)
val aIsZero = classify.head.isZero
val prodSign = signs(1) ^ signs(2) ^ (op(2,1)==="b11".U)
val prodExpRaw = Mux(prodIsZero,
Mux(isDouble,
(-Float64.expBiasInt).S,
(-Float32.expBiasInt).S),
exps(1) + exps(2)
)
val zeroResultSign = Mux(op(2,1) === "b01".U,
prodSign,
(aSign & prodSign) | ((aSign | prodSign) & rm===RoudingMode.RDN)
)
val hasNaN = classify.map(_.isNaN).reduce(_||_)
val hasSNaN = classify.map(_.isSNaN).reduce(_||_)
val isInf = classify.map(_.isInf)
val aIsInf = isInf(0)
val prodHasInf = isInf.drop(1).reduce(_||_)
val hasInf = isInf(0) || prodHasInf
val addInfInvalid = (aIsInf & prodHasInf & (aSign ^ prodSign)) & !(aIsInf ^ prodHasInf)
val zeroMulInf = prodIsZero && prodHasInf
val infInvalid = addInfInvalid || zeroMulInf
val invalid = hasSNaN || infInvalid
val specialCaseHappen = hasNaN || hasInf
val specialOutput = PriorityMux(Seq(
(hasNaN || infInvalid) -> Mux(isDouble,
Float64.defaultNaN,
Float32.defaultNaN
),
aIsInf -> Mux(isDouble,
Cat(aSign, Float64.posInf.tail(1)),
Cat(aSign, Float32.posInf.tail(1))
),
prodHasInf -> Mux(isDouble,
Cat(prodSign, Float64.posInf.tail(1)),
Cat(prodSign, Float32.posInf.tail(1))
)
))
val prodExpAdj = prodExpRaw + INITIAL_EXP_DIFF.S
val expDiff = prodExpAdj - aExpRaw
val mult = Module(new ArrayMultiplier(D_MANT_WIDTH+1, 0, UseRealArraryMult))
mult.io.a := mants(1)
mult.io.b := mants(2)
mult.io.reg_en := io.in.fire()
val s1_isDouble = S1Reg(isDouble)
val s1_rm = S1Reg(rm)
val s1_zeroSign = S1Reg(zeroResultSign)
val s1_specialCaseHappen = S1Reg(specialCaseHappen)
val s1_specialOutput = S1Reg(specialOutput)
val s1_aSign = S1Reg(aSign)
val s1_aExpRaw = S1Reg(aExpRaw)
val s1_aMant = S1Reg(mants(0))
val s1_prodSign = S1Reg(prodSign)
val s1_prodExpAdj = S1Reg(prodExpAdj)
val s1_expDiff = S1Reg(expDiff)
val s1_discardProdMant = S1Reg(prodIsZero || expDiff.head(1).asBool()) //expDiff < 0.S
val s1_discardAMant = S1Reg(aIsZero || expDiff > (ADD_WIDTH+3).S)
val s1_invalid = S1Reg(invalid)
// FPUDebug(){
// when(valids(1) && ready){
// printf(p"[s1] prodExp+56:${s1_prodExpAdj} aExp:${s1_aExpRaw} diff:${s1_expDiff}\n")
// }
// }
/******************************************************************
* Stage 2: align A | compute product (B*C)
*****************************************************************/
val alignedAMant = Wire(UInt((ADD_WIDTH+4).W))
alignedAMant := Cat(
0.U(1.W), // sign bit
ShiftRightJam(s1_aMant, Mux(s1_discardProdMant, 0.U, s1_expDiff.asUInt()), ADD_WIDTH+3)
)
val alignedAMantNeg = -alignedAMant
val effSub = s1_prodSign ^ s1_aSign
val mul_prod = mult.io.carry.tail(1) + mult.io.sum.tail(1)
val s2_isDouble = S2Reg(s1_isDouble)
val s2_rm = S2Reg(s1_rm)
val s2_zeroSign = S2Reg(s1_zeroSign)
val s2_specialCaseHappen = S2Reg(s1_specialCaseHappen)
val s2_specialOutput = S2Reg(s1_specialOutput)
val s2_aSign = S2Reg(s1_aSign)
val s2_prodSign = S2Reg(s1_prodSign)
val s2_expPreNorm = S2Reg(Mux(s1_discardAMant || !s1_discardProdMant, s1_prodExpAdj, s1_aExpRaw))
val s2_invalid = S2Reg(s1_invalid)
val s2_prod = S2Reg(mul_prod)
val s2_aMantNeg = S2Reg(alignedAMantNeg)
val s2_aMant = S2Reg(alignedAMant)
val s2_effSub = S2Reg(effSub)
// FPUDebug(){
// when(valids(1) && ready){
// printf(p"[s2] discardAMant:${s1_discardAMant} discardProd:${s1_discardProdMant} \n")
// }
// }
/******************************************************************
* Stage 3: A + Prod => adder result
*****************************************************************/
val prodMinusA = Cat(s2_prod, 0.U(3.W)) + s2_aMantNeg
val prodMinusA_Sign = prodMinusA.head(1).asBool()
val aMinusProd = -prodMinusA
val prodAddA = Cat(s2_prod, 0.U(3.W)) + s2_aMant
val lza = Module(new LZA(ADD_WIDTH+4))
lza.io.a := s2_aMant
lza.io.b := Cat(s2_prod, 0.U(3.W))
val effSubLez = lza.io.out - 1.U
val effAddLez = PriorityEncoder(prodAddA.tail(1).asBools().reverse)
val res = Mux(s2_effSub,
Mux(prodMinusA_Sign,
aMinusProd,
prodMinusA
),
prodAddA
)
val resSign = Mux(s2_prodSign,
Mux(s2_aSign,
true.B, // -(b*c) - a
!prodMinusA_Sign // -(b*c) + a
),
Mux(s2_aSign,
prodMinusA_Sign, // b*c - a
false.B // b*c + a
)
)
val mantPreNorm = res.tail(1)
val normShift = Mux(s2_effSub, effSubLez, effAddLez)
val roundingInc = MuxLookup(s2_rm, "b10".U(2.W), Seq(
RoudingMode.RDN -> Mux(resSign, "b11".U, "b00".U),
RoudingMode.RUP -> Mux(resSign, "b00".U, "b11".U),
RoudingMode.RTZ -> "b00".U
))
val ovSetInf = rm === RoudingMode.RNE ||
rm === RoudingMode.RMM ||
(rm === RoudingMode.RDN && resSign) ||
(rm === RoudingMode.RUP && !resSign)
val s3_ovSetInf = S3Reg(ovSetInf)
val s3_roundingInc = S3Reg(roundingInc)
val s3_isDouble = S3Reg(s2_isDouble)
val s3_rm = S3Reg(s2_rm)
val s3_zeroSign = S3Reg(s2_zeroSign)
val s3_specialCaseHappen = S3Reg(s2_specialCaseHappen)
val s3_specialOutput = S3Reg(s2_specialOutput)
val s3_resSign = S3Reg(resSign)
val s3_mantPreNorm = S3Reg(mantPreNorm)
val s3_expPreNorm = S3Reg(s2_expPreNorm)
val s3_normShift = S3Reg(normShift)
val s3_invalid = S3Reg(s2_invalid)
/******************************************************************
* Stage 4: Normalize/Denormalize Shift
*****************************************************************/
val expPostNorm = s3_expPreNorm - s3_normShift.toSInt
val denormShift = Mux(
s3_isDouble,
(-Float64.expBiasInt+1).S,
(-Float32.expBiasInt+1).S
) - expPostNorm
val leftShift = s3_normShift.toSInt - Mux(denormShift.head(1).asBool(), 0.S, denormShift)
val rightShift = denormShift - s3_normShift.toSInt
val mantShifted = Mux(rightShift.head(1).asBool(), // < 0
ShiftLeftJam(s3_mantPreNorm, leftShift.asUInt(), D_MANT_WIDTH+3),
ShiftRightJam(s3_mantPreNorm, rightShift.asUInt(), D_MANT_WIDTH+3)
)
val s4_isDouble = S4Reg(s3_isDouble)
val s4_rm = S4Reg(s3_rm)
val s4_roundingInc = S4Reg(s3_roundingInc)
val s4_zeroSign = S4Reg(s3_zeroSign)
val s4_specialCaseHappen = S4Reg(s3_specialCaseHappen)
val s4_specialOutput = S4Reg(s3_specialOutput)
val s4_ovSetInf = S4Reg(s3_ovSetInf)
val s4_resSign = S4Reg(s3_resSign)
val s4_mantShifted = S4Reg(mantShifted)
val s4_denormShift = S4Reg(denormShift)
val s4_expPostNorm = S4Reg(expPostNorm)
val s4_invalid = S4Reg(s3_invalid)
// FPUDebug(){
// when(valids(3) && ready){
// printf(p"[s4] expPreNorm:${s3_expPreNorm} normShift:${s3_normShift} expPostNorm:${expPostNorm} " +
// p"denormShift:${denormShift}" +
// p"" +
// p" \n")
// }
// }
/******************************************************************
* Stage 5: Rounding
*****************************************************************/
val mantUnrounded = Mux(s4_isDouble,
s4_mantShifted.head(D_MANT_WIDTH),
s4_mantShifted.head(S_MANT_WIDTH)
)
val g = Mux(s4_isDouble,
s4_mantShifted.tail(D_MANT_WIDTH).head(1),
s4_mantShifted.tail(S_MANT_WIDTH).head(1)
).asBool()
val r = Mux(s4_isDouble,
s4_mantShifted.tail(D_MANT_WIDTH+1).head(1),
s4_mantShifted.tail(S_MANT_WIDTH+1).head(1)
).asBool()
val s = ORTree(Mux(s4_isDouble,
s4_mantShifted.tail(D_MANT_WIDTH+2),
s4_mantShifted.tail(S_MANT_WIDTH+2)
))
val rounding = Module(new RoundF64AndF32WithExceptions)
rounding.io.isDouble := s4_isDouble
rounding.io.denormShiftAmt := s4_denormShift
rounding.io.sign := s4_resSign
rounding.io.expNorm := s4_expPostNorm
rounding.io.mantWithGRS := Cat(mantUnrounded, g, r, s)
rounding.io.rm := s4_rm
rounding.io.specialCaseHappen := s4_specialCaseHappen
val isZeroResult = rounding.io.isZeroResult
val expRounded = rounding.io.expRounded
val mantRounded = rounding.io.mantRounded
val overflow = rounding.io.overflow
val underflow = rounding.io.underflow
val inexact = rounding.io.inexact
val s5_isDouble = S5Reg(s4_isDouble)
val s5_sign = S5Reg(Mux(isZeroResult, s4_zeroSign, s4_resSign))
val s5_exp = S5Reg(expRounded)
val s5_mant = S5Reg(mantRounded)
val s5_specialCaseHappen = S5Reg(s4_specialCaseHappen)
val s5_specialOutput = S5Reg(s4_specialOutput)
val s5_invalid = S5Reg(s4_invalid)
val s5_overflow = S5Reg(overflow)
val s5_underflow = S5Reg(underflow)
val s5_inexact = S5Reg(inexact)
val s5_ovSetInf = S5Reg(s4_ovSetInf)
// FPUDebug(){
// when(valids(4) && ready){
// printf(p"[s5] expPostNorm:${s4_expPostNorm} expRounded:${expRounded}\n")
// }
// }
/******************************************************************
* Assign Outputs
*****************************************************************/
val commonResult = Mux(s5_isDouble,
Cat(
s5_sign,
s5_exp(Float64.expWidth-1, 0),
s5_mant(Float64.mantWidth-1, 0)
),
Cat(
s5_sign,
s5_exp(Float32.expWidth-1, 0),
s5_mant(Float32.mantWidth-1, 0)
)
)
val result = Mux(s5_specialCaseHappen,
s5_specialOutput,
Mux(s5_overflow,
Mux(s5_isDouble,
Cat(s5_sign, Mux(s5_ovSetInf, Float64.posInf, Float64.maxNorm).tail(1)),
Cat(s5_sign, Mux(s5_ovSetInf, Float32.posInf, Float32.maxNorm).tail(1))
),
commonResult
)
)
io.out.bits.data := result
fflags.invalid := s5_invalid
fflags.inexact := s5_inexact
fflags.overflow := s5_overflow
fflags.underflow := s5_underflow
fflags.infinite := false.B
// FPUDebug(){
// //printf(p"v0:${valids(0)} v1:${valids(1)} v2:${valids(2)} v3:${valids(3)} v4:${valids(4)} v5:${valids(5)}\n")
// when(io.in.fire()){
// printf(p"[in] a:${Hexadecimal(a)} b:${Hexadecimal(b)} c:${Hexadecimal(c)}\n")
// }
// when(io.out.fire()){
// printf(p"[out] res:${Hexadecimal(io.out.bits.result)}\n")
// }
// }
}
package xiangshan.backend.fu.fpu.fma
import chisel3._
import chisel3.util._
class LzaIO(len: Int) extends Bundle {
val a, b = Input(UInt(len.W))
val out = Output(UInt(log2Up(len+1).W))
override def cloneType: LzaIO.this.type = new LzaIO(len).asInstanceOf[this.type]
}
// Leading Zero Anticipator
class LZA(len: Int) extends Module {
val io = IO(new LzaIO(len))
/** msb lsb
* 0 1 2 ... n-1
*/
val (a, b) = (io.a.asBools().reverse, io.b.asBools().reverse)
//
val g, s, e, f = Wire(Vec(len, Bool()))
for(i <- 0 until len){
g(i) := a(i) & !b(i)
s(i) := !a(i) & b(i)
e(i) := a(i) === b(i)
}
f(0) := (s(0) & !s(1)) | (g(0) & !g(1))
f(len-1) := false.B
for(i <- 1 until len-1){
f(i) := (e(i-1) & g(i) & !s(i+1)) |
(!e(i-1) & s(i) & !s(i+1)) |
(e(i-1) & s(i) & !g(i+1)) |
(!e(i-1) & g(i) & !g(i+1))
}
val res = PriorityEncoder(f)
val p, n, z = Wire(Vec(len, Bool()))
p(0) := g(0)
n(0) := s(0)
p(1) := g(1)
n(1) := s(1)
for(i <- 2 until len){
p(i) := (e(i-1) | e(i-2) & g(i-1) | !e(i-2) & s(i-1)) & g(i)
n(i) := (e(i-1) | e(i-2) & s(i-1) | !e(i-2) & g(i-1)) & s(i)
}
for(i <- 0 until len){
z(i) := !(p(i) | n(i))
}
class TreeNode extends Bundle {
val Z, P, N = Bool()
}
def buildOneLevel(nodes: Seq[TreeNode]): Seq[TreeNode] = {
nodes match {
case Seq(_) => nodes
case Seq(_, _) => nodes
case Seq(left, mid, right) =>
val next_l, next_r = Wire(new TreeNode)
next_l.P := left.P | left.Z & mid.P
next_l.N := left.N | left.Z & mid.N
next_l.Z := left.Z & mid.Z
next_r.P := !left.Z & mid.P | right.P & (left.Z | mid.Z)
next_r.N := !left.Z & mid.N | right.N & (left.Z | mid.Z)
next_r.Z := right.Z & (left.Z | mid.Z)
Seq(next_l, next_r)
case _ =>
buildOneLevel(nodes.take(3)) ++ buildOneLevel(nodes.drop(3))
}
}
def detectionTree(nodes: Seq[TreeNode]): Bool = {
assert(nodes.size >= 2)
nodes match {
case Seq(left, right) =>
left.P & right.N | left.N & right.P
case _ =>
val nextLevel = buildOneLevel(nodes)
detectionTree(nextLevel)
}
}
val nodes = (0 until len).map(i => {
val treeNode = Wire(new TreeNode)
treeNode.P := p(i)
treeNode.N := n(i)
treeNode.Z := z(i)
treeNode
})
val error = detectionTree(nodes)
io.out := res + error
}
package xiangshan.backend.fu.fpu
import chisel3._
import chisel3.util._
object FPUOpType {
def funcWidth = 6
def FpuOp(fu: String, op: String): UInt = ("b" + fu + op).U(funcWidth.W)
def FU_FMAC = "000"
def FU_FCMP = "001"
def FU_FMV = "010"
def FU_F2I = "011"
def FU_I2F = "100"
def FU_S2D = "101"
def FU_D2S = "110"
def FU_DIVSQRT = "111"
// FMA
def fadd:UInt = FpuOp(FU_FMAC, "000")
def fsub:UInt = FpuOp(FU_FMAC, "001")
def fmadd:UInt = FpuOp(FU_FMAC, "100")
def fmsub:UInt = FpuOp(FU_FMAC, "101")
def fnmsub:UInt = FpuOp(FU_FMAC, "110")
def fnmadd:UInt = FpuOp(FU_FMAC, "111")
def fmul:UInt = FpuOp(FU_FMAC, "010")
// FCMP
def fmin:UInt = FpuOp(FU_FCMP, "000")
def fmax:UInt = FpuOp(FU_FCMP, "001")
def fle:UInt = FpuOp(FU_FCMP, "010")
def flt:UInt = FpuOp(FU_FCMP, "011")
def feq:UInt = FpuOp(FU_FCMP, "100")
// FMV
def fmv_f2i:UInt= FpuOp(FU_FMV, "000")
def fmv_i2f:UInt= FpuOp(FU_FMV, "001")
def fclass:UInt = FpuOp(FU_FMV, "010")
def fsgnj:UInt = FpuOp(FU_FMV, "110")
def fsgnjn:UInt = FpuOp(FU_FMV, "101")
def fsgnjx:UInt = FpuOp(FU_FMV, "100")
// FloatToInt
def f2w:UInt = FpuOp(FU_F2I, "000")
def f2wu:UInt = FpuOp(FU_F2I, "001")
def f2l:UInt = FpuOp(FU_F2I, "010")
def f2lu:UInt = FpuOp(FU_F2I, "011")
// IntToFloat
def w2f:UInt = FpuOp(FU_I2F, "000")
def wu2f:UInt = FpuOp(FU_I2F, "001")
def l2f:UInt = FpuOp(FU_I2F, "010")
def lu2f:UInt = FpuOp(FU_I2F, "011")
// FloatToFloat
def s2d:UInt = FpuOp(FU_S2D, "000")
def d2s:UInt = FpuOp(FU_D2S, "000")
// Div/Sqrt
def fdiv:UInt = FpuOp(FU_DIVSQRT, "000")
def fsqrt:UInt = FpuOp(FU_DIVSQRT, "001")
}
object FPUIOFunc {
def in_raw = 0.U(1.W)
def in_unbox = 1.U(1.W)
def out_raw = 0.U(2.W)
def out_box = 1.U(2.W)
def out_sext = 2.U(2.W)
def out_zext = 3.U(2.W)
def apply(inputFunc: UInt, outputFunc:UInt) = Cat(inputFunc, outputFunc)
}
class Fflags extends Bundle {
val invalid = Bool() // 4
val infinite = Bool() // 3
val overflow = Bool() // 2
val underflow = Bool() // 1
val inexact = Bool() // 0
}
object RoudingMode {
val RNE = "b000".U(3.W)
val RTZ = "b001".U(3.W)
val RDN = "b010".U(3.W)
val RUP = "b011".U(3.W)
val RMM = "b100".U(3.W)
}
class FloatPoint(val expWidth: Int, val mantWidth:Int) extends Bundle{
val sign = Bool()
val exp = UInt(expWidth.W)
val mant = UInt(mantWidth.W)
def defaultNaN: UInt = Cat(0.U(1.W), Fill(expWidth+1,1.U(1.W)), Fill(mantWidth-1,0.U(1.W)))
def posInf: UInt = Cat(0.U(1.W), Fill(expWidth, 1.U(1.W)), 0.U(mantWidth.W))
def negInf: UInt = Cat(1.U(1.W), posInf.tail(1))
def maxNorm: UInt = Cat(0.U(1.W), Fill(expWidth-1, 1.U(1.W)), 0.U(1.W), Fill(mantWidth, 1.U(1.W)))
def expBias: UInt = Fill(expWidth-1, 1.U(1.W))
def expBiasInt: Int = (1 << (expWidth-1)) - 1
def mantExt: UInt = Cat(exp=/=0.U, mant)
def apply(x: UInt): FloatPoint = x.asTypeOf(new FloatPoint(expWidth, mantWidth))
}
object Float32 extends FloatPoint(8, 23)
object Float64 extends FloatPoint(11, 52)
object expOverflow {
def apply(sexp: SInt, expWidth: Int): Bool =
sexp >= Cat(0.U(1.W), Fill(expWidth, 1.U(1.W))).asSInt()
def apply(uexp: UInt, expWidth: Int): Bool =
expOverflow(Cat(0.U(1.W), uexp).asSInt(), expWidth)
}
object boxF32ToF64 {
def apply(x: UInt): UInt = Cat(Fill(32, 1.U(1.W)), x(31, 0))
}
object unboxF64ToF32 {
def apply(x: UInt): UInt =
Mux(x(63, 32)===Fill(32, 1.U(1.W)), x(31, 0), Float32.defaultNaN)
}
object extF32ToF64 {
def apply(x: UInt): UInt = {
val f32 = Float32(x)
Cat(
f32.sign,
Mux(f32.exp === 0.U,
0.U(Float64.expWidth.W),
Mux((~f32.exp).asUInt() === 0.U,
Cat("b111".U(3.W), f32.exp),
Cat("b0111".U(4.W) + f32.exp.head(1), f32.exp.tail(1))
)
),
Cat(f32.mant, 0.U((Float64.mantWidth - Float32.mantWidth).W))
)
}
}
package xiangshan.backend.fu.fpu.util
import chisel3._
object FPUDebug {
// don't care GTimer in FPU tests
def apply(flag: Boolean = false, cond: Bool = true.B)(body: => Unit): Any =
if (flag) { when (cond) { body } }
}
package xiangshan.backend.fu.fpu.util
import chisel3._
object ORTree {
def apply(x: Seq[Bool]): Bool = {
// Is 'x =/= 0' enough ?
x.size match {
case 1 => x.head
case n => ORTree(x.take(n/2)) | ORTree(x.drop(n/2))
}
}
def apply[T <: Bits](x: T): Bool = {
apply(x.asBools())
}
}
package xiangshan.backend.fu.fpu.util
import chisel3._
import chisel3.util._
object ShiftLeftJam {
def apply(x: UInt, shiftAmt: UInt, w:Int): UInt = {
val xLen = if(x.getWidth < w) w else x.getWidth
val x_shifted = Wire(UInt(xLen.W))
x_shifted := Mux(shiftAmt > (xLen-1).U,
0.U,
x << shiftAmt(log2Up(xLen)-1, 0)
)
val sticky = ORTree(x_shifted.tail(w))
x_shifted.head(w) | sticky
}
}
package xiangshan.backend.fu.fpu.util
import chisel3._
import chisel3.util._
object ShiftRightJam {
def apply(x: UInt, shiftAmt:UInt, w: Int): UInt ={
val xLen = if(x.getWidth < w) w else x.getWidth
val x_ext = Wire(UInt(xLen.W))
x_ext := (if(x.getWidth < w) Cat(x, 0.U((w-x.getWidth).W)) else x)
val realShiftAmt = Mux(shiftAmt > (w-1).U,
w.U,
shiftAmt(log2Up(w) - 1, 0)
)
val mask = ((-1).S(xLen.W).asUInt() >> (w.U - realShiftAmt)).asUInt()
val sticky = ORTree(mask & x_ext)
val x_shifted = Wire(UInt(xLen.W))
x_shifted := x_ext >> realShiftAmt
x_shifted.head(w) | sticky
}
}
package xiangshan.backend.fu.fpu.util package xiangshan.backend.fu.util
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
......
...@@ -329,7 +329,7 @@ class ReservationStationData ...@@ -329,7 +329,7 @@ class ReservationStationData
// listen to write back data bus(certain latency) // listen to write back data bus(certain latency)
// and extra wrtie back(uncertan latency) // and extra wrtie back(uncertan latency)
val writeBackedData = Vec(wakeupCnt, Input(UInt(XLEN.W))) val writeBackedData = Vec(wakeupCnt, Input(UInt((XLEN+1).W)))
val extraListenPorts = Vec(extraListenPortsCnt, Flipped(ValidIO(new ExuOutput))) val extraListenPorts = Vec(extraListenPortsCnt, Flipped(ValidIO(new ExuOutput)))
// tlb feedback // tlb feedback
...@@ -337,7 +337,7 @@ class ReservationStationData ...@@ -337,7 +337,7 @@ class ReservationStationData
}) })
val uop = Reg(Vec(iqSize, new MicroOp)) val uop = Reg(Vec(iqSize, new MicroOp))
val data = Reg(Vec(iqSize, Vec(srcNum, UInt(XLEN.W)))) val data = Reg(Vec(iqSize, Vec(srcNum, UInt((XLEN+1).W))))
// TODO: change srcNum // TODO: change srcNum
......
...@@ -124,9 +124,6 @@ package object backend { ...@@ -124,9 +124,6 @@ package object backend {
def isLoad(op: UInt): Bool = !op(3) def isLoad(op: UInt): Bool = !op(3)
def isStore(op: UInt): Bool = op(3) def isStore(op: UInt): Bool = op(3)
// float/double load store
def flw = "b010110".U
// atomics // atomics
// bit(1, 0) are size // bit(1, 0) are size
// since atomics use a different fu type // since atomics use a different fu type
......
...@@ -54,7 +54,11 @@ class Regfile ...@@ -54,7 +54,11 @@ class Regfile
) )
val debugArchReg = WireInit(VecInit(debugArchRat.zipWithIndex.map( val debugArchReg = WireInit(VecInit(debugArchRat.zipWithIndex.map(
x => if(hasZero && x._2==0) 0.U else mem(x._1) x => if(hasZero){
if(x._2 == 0) 0.U else mem(x._1)
} else {
ieee(mem(x._1))
}
))) )))
ExcitingUtils.addSource( ExcitingUtils.addSource(
debugArchReg, debugArchReg,
......
...@@ -6,7 +6,6 @@ import chisel3.util._ ...@@ -6,7 +6,6 @@ import chisel3.util._
import xiangshan._ import xiangshan._
import utils._ import utils._
import xiangshan.backend.LSUOpType import xiangshan.backend.LSUOpType
import xiangshan.backend.fu.fpu.Fflags
import xiangshan.mem.{LqPtr, SqPtr} import xiangshan.mem.{LqPtr, SqPtr}
object roqDebugId extends Function0[Integer] { object roqDebugId extends Function0[Integer] {
...@@ -37,7 +36,7 @@ class RoqCSRIO extends XSBundle { ...@@ -37,7 +36,7 @@ class RoqCSRIO extends XSBundle {
val intrBitSet = Input(Bool()) val intrBitSet = Input(Bool())
val trapTarget = Input(UInt(VAddrBits.W)) val trapTarget = Input(UInt(VAddrBits.W))
val fflags = Output(new Fflags) val fflags = Output(Valid(UInt(5.W)))
val dirty_fs = Output(Bool()) val dirty_fs = Output(Bool())
} }
...@@ -50,19 +49,7 @@ class RoqEnqIO extends XSBundle { ...@@ -50,19 +49,7 @@ class RoqEnqIO extends XSBundle {
val resp = Vec(RenameWidth, Output(new RoqPtr)) val resp = Vec(RenameWidth, Output(new RoqPtr))
} }
class RoqDispatchData extends XSBundle { class RoqDispatchData extends RoqCommitInfo {
// commit info
val ldest = UInt(5.W)
val rfWen = Bool()
val fpWen = Bool()
val commitType = CommitType()
val pdest = UInt(PhyRegIdxWidth.W)
val old_pdest = UInt(PhyRegIdxWidth.W)
val lqIdx = new LqPtr
val sqIdx = new SqPtr
// exception info
val pc = UInt(VAddrBits.W)
val crossPageIPFFix = Bool() val crossPageIPFFix = Bool()
val exceptionVec = Vec(16, Bool()) val exceptionVec = Vec(16, Bool())
} }
...@@ -70,7 +57,7 @@ class RoqDispatchData extends XSBundle { ...@@ -70,7 +57,7 @@ class RoqDispatchData extends XSBundle {
class RoqWbData extends XSBundle { class RoqWbData extends XSBundle {
// mostly for exceptions // mostly for exceptions
val exceptionVec = Vec(16, Bool()) val exceptionVec = Vec(16, Bool())
val fflags = new Fflags val fflags = UInt(5.W)
val flushPipe = Bool() val flushPipe = Bool()
} }
...@@ -324,6 +311,7 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper { ...@@ -324,6 +311,7 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
} }
} }
} }
// debug info for enqueue (dispatch) // debug info for enqueue (dispatch)
val dispatchNum = Mux(io.enq.canAccept, PopCount(Cat(io.enq.req.map(_.valid))), 0.U) val dispatchNum = Mux(io.enq.canAccept, PopCount(Cat(io.enq.req.map(_.valid))), 0.U)
XSDebug(p"(ready, valid): ${io.enq.canAccept}, ${Binary(Cat(io.enq.req.map(_.valid)))}\n") XSDebug(p"(ready, valid): ${io.enq.canAccept}, ${Binary(Cat(io.enq.req.map(_.valid)))}\n")
...@@ -404,8 +392,17 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper { ...@@ -404,8 +392,17 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
val usedSpaceForMPR = Reg(Vec(RenameWidth, Bool())) val usedSpaceForMPR = Reg(Vec(RenameWidth, Bool()))
// wiring to csr // wiring to csr
val fflags = WireInit(0.U.asTypeOf(new Fflags)) val (wflags, fpWen) = (0 until CommitWidth).map(i => {
val dirty_fs = Mux(io.commits.isWalk, false.B, Cat(io.commits.valid.zip(io.commits.info.map(_.fpWen)).map{case (v, w) => v & w}).orR) val v = io.commits.valid(i)
val info = io.commits.info(i)
(v & info.wflags, v & info.fpWen)
}).unzip
val fflags = Wire(Valid(UInt(5.W)))
fflags.valid := Mux(io.commits.isWalk, false.B, Cat(wflags).orR())
fflags.bits := wflags.zip(writebackDataRead.map(_.fflags)).map({
case (w, f) => Mux(w, f, 0.U)
}).reduce(_|_)
val dirty_fs = Mux(io.commits.isWalk, false.B, Cat(fpWen).orR())
io.commits.isWalk := state =/= s_idle io.commits.isWalk := state =/= s_idle
val commit_v = Mux(state === s_idle, VecInit(deqPtrVec.map(ptr => valid(ptr.value))), VecInit(walkPtrVec.map(ptr => valid(ptr.value)))) val commit_v = Mux(state === s_idle, VecInit(deqPtrVec.map(ptr => valid(ptr.value))), VecInit(walkPtrVec.map(ptr => valid(ptr.value))))
...@@ -419,12 +416,6 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper { ...@@ -419,12 +416,6 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
io.commits.valid(i) := commit_v(i) && commit_w(i) && !isBlocked && !commit_exception(i) io.commits.valid(i) := commit_v(i) && commit_w(i) && !isBlocked && !commit_exception(i)
io.commits.info(i) := dispatchDataRead(i) io.commits.info(i) := dispatchDataRead(i)
when (state === s_idle) {
when (io.commits.valid(i) && writebackDataRead(i).fflags.asUInt.orR()) {
fflags := writebackDataRead(i).fflags
}
}
when (state === s_walk) { when (state === s_walk) {
io.commits.valid(i) := commit_v(i) && shouldWalkVec(i) io.commits.valid(i) := commit_v(i) && shouldWalkVec(i)
}.elsewhen(state === s_extrawalk) { }.elsewhen(state === s_extrawalk) {
...@@ -475,6 +466,7 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper { ...@@ -475,6 +466,7 @@ class Roq(numWbPorts: Int) extends XSModule with HasCircularQueuePtrHelper {
wdata.ldest := req.ctrl.ldest wdata.ldest := req.ctrl.ldest
wdata.rfWen := req.ctrl.rfWen wdata.rfWen := req.ctrl.rfWen
wdata.fpWen := req.ctrl.fpWen wdata.fpWen := req.ctrl.fpWen
wdata.wflags := req.ctrl.fpu.wflags
wdata.commitType := req.ctrl.commitType wdata.commitType := req.ctrl.commitType
wdata.pdest := req.pdest wdata.pdest := req.pdest
wdata.old_pdest := req.old_pdest wdata.old_pdest := req.old_pdest
......
...@@ -35,7 +35,7 @@ class LsPipelineBundle extends XSBundle { ...@@ -35,7 +35,7 @@ class LsPipelineBundle extends XSBundle {
val paddr = UInt(PAddrBits.W) val paddr = UInt(PAddrBits.W)
val func = UInt(6.W) //fixme??? val func = UInt(6.W) //fixme???
val mask = UInt(8.W) val mask = UInt(8.W)
val data = UInt(XLEN.W) val data = UInt((XLEN+1).W)
val uop = new MicroOp val uop = new MicroOp
val miss = Bool() val miss = Bool()
...@@ -59,4 +59,4 @@ class LoadForwardQueryIO extends XSBundle { ...@@ -59,4 +59,4 @@ class LoadForwardQueryIO extends XSBundle {
// val lqIdx = Output(UInt(LoadQueueIdxWidth.W)) // val lqIdx = Output(UInt(LoadQueueIdxWidth.W))
val sqIdx = Output(new SqPtr) val sqIdx = Output(new SqPtr)
} }
\ No newline at end of file
...@@ -244,7 +244,7 @@ class LsqWrappper extends XSModule with HasDCacheParameters { ...@@ -244,7 +244,7 @@ class LsqWrappper extends XSModule with HasDCacheParameters {
val loadIn = Vec(LoadPipelineWidth, Flipped(Valid(new LsPipelineBundle))) val loadIn = Vec(LoadPipelineWidth, Flipped(Valid(new LsPipelineBundle)))
val storeIn = Vec(StorePipelineWidth, Flipped(Valid(new LsPipelineBundle))) val storeIn = Vec(StorePipelineWidth, Flipped(Valid(new LsPipelineBundle)))
val sbuffer = Vec(StorePipelineWidth, Decoupled(new DCacheWordReq)) val sbuffer = Vec(StorePipelineWidth, Decoupled(new DCacheWordReq))
val ldout = Vec(2, DecoupledIO(new ExuOutput)) // writeback store val ldout = Vec(2, DecoupledIO(new ExuOutput)) // writeback int load
val mmioStout = DecoupledIO(new ExuOutput) // writeback uncached store val mmioStout = DecoupledIO(new ExuOutput) // writeback uncached store
val forward = Vec(LoadPipelineWidth, Flipped(new LoadForwardQueryIO)) val forward = Vec(LoadPipelineWidth, Flipped(new LoadForwardQueryIO))
val commits = Flipped(new RoqCommitIO) val commits = Flipped(new RoqCommitIO)
......
...@@ -2,14 +2,14 @@ package xiangshan.mem ...@@ -2,14 +2,14 @@ package xiangshan.mem
import chisel3._ import chisel3._
import chisel3.util._ import chisel3.util._
import freechips.rocketchip.tile.HasFPUParameters
import utils._ import utils._
import xiangshan._ import xiangshan._
import xiangshan.cache._ import xiangshan.cache._
import xiangshan.cache.{DCacheWordIO, DCacheLineIO, TlbRequestIO, MemoryOpConstants} import xiangshan.cache.{DCacheLineIO, DCacheWordIO, MemoryOpConstants, TlbRequestIO}
import xiangshan.backend.LSUOpType import xiangshan.backend.LSUOpType
import xiangshan.mem._ import xiangshan.mem._
import xiangshan.backend.roq.RoqPtr import xiangshan.backend.roq.RoqPtr
import xiangshan.backend.fu.fpu.boxF32ToF64
class LqPtr extends CircularQueuePtr(LqPtr.LoadQueueSize) { } class LqPtr extends CircularQueuePtr(LqPtr.LoadQueueSize) { }
...@@ -23,6 +23,28 @@ object LqPtr extends HasXSParameter { ...@@ -23,6 +23,28 @@ object LqPtr extends HasXSParameter {
} }
} }
trait HasLoadHelper { this: XSModule =>
def rdataHelper(uop: MicroOp, rdata: UInt): UInt = {
val fpWen = uop.ctrl.fpWen
LookupTree(uop.ctrl.fuOpType, List(
LSUOpType.lb -> SignExt(rdata(7, 0) , XLEN),
LSUOpType.lh -> SignExt(rdata(15, 0), XLEN),
LSUOpType.lw -> Mux(fpWen, rdata, SignExt(rdata(31, 0), XLEN)),
LSUOpType.ld -> Mux(fpWen, rdata, SignExt(rdata(63, 0), XLEN)),
LSUOpType.lbu -> ZeroExt(rdata(7, 0) , XLEN),
LSUOpType.lhu -> ZeroExt(rdata(15, 0), XLEN),
LSUOpType.lwu -> ZeroExt(rdata(31, 0), XLEN),
))
}
def fpRdataHelper(uop: MicroOp, rdata: UInt): UInt = {
LookupTree(uop.ctrl.fuOpType, List(
LSUOpType.lw -> recode(rdata(31, 0), S),
LSUOpType.ld -> recode(rdata(63, 0), D)
))
}
}
class LqEnqIO extends XSBundle { class LqEnqIO extends XSBundle {
val canAccept = Output(Bool()) val canAccept = Output(Bool())
val sqCanAccept = Input(Bool()) val sqCanAccept = Input(Bool())
...@@ -32,13 +54,17 @@ class LqEnqIO extends XSBundle { ...@@ -32,13 +54,17 @@ class LqEnqIO extends XSBundle {
} }
// Load Queue // Load Queue
class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueuePtrHelper { class LoadQueue extends XSModule
with HasDCacheParameters
with HasCircularQueuePtrHelper
with HasLoadHelper
{
val io = IO(new Bundle() { val io = IO(new Bundle() {
val enq = new LqEnqIO val enq = new LqEnqIO
val brqRedirect = Input(Valid(new Redirect)) val brqRedirect = Input(Valid(new Redirect))
val loadIn = Vec(LoadPipelineWidth, Flipped(Valid(new LsPipelineBundle))) val loadIn = Vec(LoadPipelineWidth, Flipped(Valid(new LsPipelineBundle)))
val storeIn = Vec(StorePipelineWidth, Flipped(Valid(new LsPipelineBundle))) // FIXME: Valid() only val storeIn = Vec(StorePipelineWidth, Flipped(Valid(new LsPipelineBundle))) // FIXME: Valid() only
val ldout = Vec(2, DecoupledIO(new ExuOutput)) // writeback load val ldout = Vec(2, DecoupledIO(new ExuOutput)) // writeback int load
val load_s1 = Vec(LoadPipelineWidth, Flipped(new LoadForwardQueryIO)) val load_s1 = Vec(LoadPipelineWidth, Flipped(new LoadForwardQueryIO))
val commits = Flipped(new RoqCommitIO) val commits = Flipped(new RoqCommitIO)
val rollback = Output(Valid(new Redirect)) // replay now starts from load instead of store val rollback = Output(Valid(new Redirect)) // replay now starts from load instead of store
...@@ -274,7 +300,8 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP ...@@ -274,7 +300,8 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
(0 until StorePipelineWidth).map(i => { (0 until StorePipelineWidth).map(i => {
// data select // data select
val rdata = dataModule.io.rdata(loadWbSel(i)).data val rdata = dataModule.io.rdata(loadWbSel(i)).data
val func = uop(loadWbSel(i)).ctrl.fuOpType val seluop = uop(loadWbSel(i))
val func = seluop.ctrl.fuOpType
val raddr = dataModule.io.rdata(loadWbSel(i)).paddr val raddr = dataModule.io.rdata(loadWbSel(i)).paddr
val rdataSel = LookupTree(raddr(2, 0), List( val rdataSel = LookupTree(raddr(2, 0), List(
"b000".U -> rdata(63, 0), "b000".U -> rdata(63, 0),
...@@ -286,17 +313,14 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP ...@@ -286,17 +313,14 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
"b110".U -> rdata(63, 48), "b110".U -> rdata(63, 48),
"b111".U -> rdata(63, 56) "b111".U -> rdata(63, 56)
)) ))
val rdataPartialLoad = LookupTree(func, List( val rdataPartialLoad = rdataHelper(seluop, rdataSel)
LSUOpType.lb -> SignExt(rdataSel(7, 0) , XLEN),
LSUOpType.lh -> SignExt(rdataSel(15, 0), XLEN), val validWb = loadWbSelVec(loadWbSel(i)) && loadWbSelV(i)
LSUOpType.lw -> SignExt(rdataSel(31, 0), XLEN),
LSUOpType.ld -> SignExt(rdataSel(63, 0), XLEN), // writeback missed int/fp load
LSUOpType.lbu -> ZeroExt(rdataSel(7, 0) , XLEN), //
LSUOpType.lhu -> ZeroExt(rdataSel(15, 0), XLEN), // Int load writeback will finish (if not blocked) in one cycle
LSUOpType.lwu -> ZeroExt(rdataSel(31, 0), XLEN), io.ldout(i).bits.uop := seluop
LSUOpType.flw -> boxF32ToF64(rdataSel(31, 0))
))
io.ldout(i).bits.uop := uop(loadWbSel(i))
io.ldout(i).bits.uop.cf.exceptionVec := dataModule.io.rdata(loadWbSel(i)).exception.asBools io.ldout(i).bits.uop.cf.exceptionVec := dataModule.io.rdata(loadWbSel(i)).exception.asBools
io.ldout(i).bits.uop.lqIdx := loadWbSel(i).asTypeOf(new LqPtr) io.ldout(i).bits.uop.lqIdx := loadWbSel(i).asTypeOf(new LqPtr)
io.ldout(i).bits.data := rdataPartialLoad io.ldout(i).bits.data := rdataPartialLoad
...@@ -305,10 +329,14 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP ...@@ -305,10 +329,14 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
io.ldout(i).bits.brUpdate := DontCare io.ldout(i).bits.brUpdate := DontCare
io.ldout(i).bits.debug.isMMIO := dataModule.io.rdata(loadWbSel(i)).mmio io.ldout(i).bits.debug.isMMIO := dataModule.io.rdata(loadWbSel(i)).mmio
io.ldout(i).bits.fflags := DontCare io.ldout(i).bits.fflags := DontCare
io.ldout(i).valid := loadWbSelVec(loadWbSel(i)) && loadWbSelV(i) io.ldout(i).valid := validWb
when(io.ldout(i).fire()) {
when(io.ldout(i).fire()){
writebacked(loadWbSel(i)) := true.B writebacked(loadWbSel(i)) := true.B
XSInfo("load miss write to cbd roqidx %d lqidx %d pc 0x%x paddr %x data %x mmio %x\n", }
when(io.ldout(i).fire()) {
XSInfo("int load miss write to cbd roqidx %d lqidx %d pc 0x%x paddr %x data %x mmio %x\n",
io.ldout(i).bits.uop.roqIdx.asUInt, io.ldout(i).bits.uop.roqIdx.asUInt,
io.ldout(i).bits.uop.lqIdx.asUInt, io.ldout(i).bits.uop.lqIdx.asUInt,
io.ldout(i).bits.uop.cf.pc, io.ldout(i).bits.uop.cf.pc,
...@@ -317,6 +345,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP ...@@ -317,6 +345,7 @@ class LoadQueue extends XSModule with HasDCacheParameters with HasCircularQueueP
dataModule.io.rdata(loadWbSel(i)).mmio dataModule.io.rdata(loadWbSel(i)).mmio
) )
} }
}) })
/** /**
......
...@@ -7,7 +7,6 @@ import xiangshan._ ...@@ -7,7 +7,6 @@ import xiangshan._
import xiangshan.cache._ import xiangshan.cache._
// import xiangshan.cache.{DCacheWordIO, TlbRequestIO, TlbCmd, MemoryOpConstants, TlbReq, DCacheLoadReq, DCacheWordResp} // import xiangshan.cache.{DCacheWordIO, TlbRequestIO, TlbCmd, MemoryOpConstants, TlbReq, DCacheLoadReq, DCacheWordResp}
import xiangshan.backend.LSUOpType import xiangshan.backend.LSUOpType
import xiangshan.backend.fu.fpu.boxF32ToF64
class LoadToLsqIO extends XSBundle { class LoadToLsqIO extends XSBundle {
val loadIn = ValidIO(new LsPipelineBundle) val loadIn = ValidIO(new LsPipelineBundle)
...@@ -133,7 +132,7 @@ class LoadUnit_S1 extends XSModule { ...@@ -133,7 +132,7 @@ class LoadUnit_S1 extends XSModule {
// Load Pipeline Stage 2 // Load Pipeline Stage 2
// DCache resp // DCache resp
class LoadUnit_S2 extends XSModule { class LoadUnit_S2 extends XSModule with HasLoadHelper {
val io = IO(new Bundle() { val io = IO(new Bundle() {
val in = Flipped(Decoupled(new LsPipelineBundle)) val in = Flipped(Decoupled(new LsPipelineBundle))
val out = Decoupled(new LsPipelineBundle) val out = Decoupled(new LsPipelineBundle)
...@@ -175,16 +174,7 @@ class LoadUnit_S2 extends XSModule { ...@@ -175,16 +174,7 @@ class LoadUnit_S2 extends XSModule {
"b110".U -> rdata(63, 48), "b110".U -> rdata(63, 48),
"b111".U -> rdata(63, 56) "b111".U -> rdata(63, 56)
)) ))
val rdataPartialLoad = LookupTree(s2_uop.ctrl.fuOpType, List( val rdataPartialLoad = rdataHelper(s2_uop, rdataSel)
LSUOpType.lb -> SignExt(rdataSel(7, 0) , XLEN),
LSUOpType.lh -> SignExt(rdataSel(15, 0), XLEN),
LSUOpType.lw -> SignExt(rdataSel(31, 0), XLEN),
LSUOpType.ld -> SignExt(rdataSel(63, 0), XLEN),
LSUOpType.lbu -> ZeroExt(rdataSel(7, 0) , XLEN),
LSUOpType.lhu -> ZeroExt(rdataSel(15, 0), XLEN),
LSUOpType.lwu -> ZeroExt(rdataSel(31, 0), XLEN),
LSUOpType.flw -> boxF32ToF64(rdataSel(31, 0))
))
// TODO: ECC check // TODO: ECC check
...@@ -218,13 +208,13 @@ class LoadUnit_S2 extends XSModule { ...@@ -218,13 +208,13 @@ class LoadUnit_S2 extends XSModule {
s2_uop.cf.pc, rdataPartialLoad, io.dcacheResp.bits.data, s2_uop.cf.pc, rdataPartialLoad, io.dcacheResp.bits.data,
io.out.bits.forwardData.asUInt, io.out.bits.forwardMask.asUInt io.out.bits.forwardData.asUInt, io.out.bits.forwardMask.asUInt
) )
} }
class LoadUnit extends XSModule { class LoadUnit extends XSModule with HasLoadHelper {
val io = IO(new Bundle() { val io = IO(new Bundle() {
val ldin = Flipped(Decoupled(new ExuInput)) val ldin = Flipped(Decoupled(new ExuInput))
val ldout = Decoupled(new ExuOutput) val ldout = Decoupled(new ExuOutput)
val fpout = Decoupled(new ExuOutput)
val redirect = Flipped(ValidIO(new Redirect)) val redirect = Flipped(ValidIO(new Redirect))
val tlbFeedback = ValidIO(new TlbFeedback) val tlbFeedback = ValidIO(new TlbFeedback)
val dcache = new DCacheLoadIO val dcache = new DCacheLoadIO
...@@ -267,33 +257,49 @@ class LoadUnit extends XSModule { ...@@ -267,33 +257,49 @@ class LoadUnit extends XSModule {
// writeback to LSQ // writeback to LSQ
// Current dcache use MSHR // Current dcache use MSHR
// Load queue will be updated at s2 for both hit/miss int/fp load
io.lsq.loadIn.valid := load_s2.io.out.valid io.lsq.loadIn.valid := load_s2.io.out.valid
io.lsq.loadIn.bits := load_s2.io.out.bits io.lsq.loadIn.bits := load_s2.io.out.bits
val s2Valid = load_s2.io.out.valid && (!load_s2.io.out.bits.miss || load_s2.io.out.bits.uop.cf.exceptionVec.asUInt.orR)
val refillFpLoad = io.lsq.ldout.bits.uop.ctrl.fpWen
// Int load, if hit, will be writebacked at s2
val intHitLoadOut = Wire(Valid(new ExuOutput))
intHitLoadOut.valid := s2Valid && !load_s2.io.out.bits.uop.ctrl.fpWen
intHitLoadOut.bits.uop := load_s2.io.out.bits.uop
intHitLoadOut.bits.data := load_s2.io.out.bits.data
intHitLoadOut.bits.redirectValid := false.B
intHitLoadOut.bits.redirect := DontCare
intHitLoadOut.bits.brUpdate := DontCare
intHitLoadOut.bits.debug.isMMIO := load_s2.io.out.bits.mmio
intHitLoadOut.bits.fflags := DontCare
val hitLoadOut = Wire(Valid(new ExuOutput))
hitLoadOut.valid := load_s2.io.out.valid && (!load_s2.io.out.bits.miss || load_s2.io.out.bits.uop.cf.exceptionVec.asUInt.orR)
hitLoadOut.bits.uop := load_s2.io.out.bits.uop
hitLoadOut.bits.data := load_s2.io.out.bits.data
hitLoadOut.bits.redirectValid := false.B
hitLoadOut.bits.redirect := DontCare
hitLoadOut.bits.brUpdate := DontCare
hitLoadOut.bits.debug.isMMIO := load_s2.io.out.bits.mmio
hitLoadOut.bits.fflags := DontCare
// TODO: arbiter
// if hit, writeback result to CDB
// val ldout = Vec(2, Decoupled(new ExuOutput))
// when io.loadIn(i).fire() && !io.io.loadIn(i).miss, commit load to cdb
// val cdbArb = Module(new Arbiter(new ExuOutput, 2))
// io.ldout <> cdbArb.io.out
// hitLoadOut <> cdbArb.io.in(0)
// io.lsq.ldout <> cdbArb.io.in(1) // missLoadOut
load_s2.io.out.ready := true.B load_s2.io.out.ready := true.B
io.lsq.ldout.ready := !hitLoadOut.valid
io.ldout.bits := Mux(hitLoadOut.valid, hitLoadOut.bits, io.lsq.ldout.bits) io.ldout.bits := Mux(intHitLoadOut.valid, intHitLoadOut.bits, io.lsq.ldout.bits)
io.ldout.valid := hitLoadOut.valid || io.lsq.ldout.valid io.ldout.valid := intHitLoadOut.valid || io.lsq.ldout.valid && !refillFpLoad
// Fp load, if hit, will be send to recoder at s2, then it will be recoded & writebacked at s3
val fpHitLoadOut = Wire(Valid(new ExuOutput))
fpHitLoadOut.valid := s2Valid && load_s2.io.out.bits.uop.ctrl.fpWen
fpHitLoadOut.bits := intHitLoadOut.bits
val fpLoadOut = Wire(Valid(new ExuOutput))
fpLoadOut.bits := Mux(fpHitLoadOut.valid, fpHitLoadOut.bits, io.lsq.ldout.bits)
fpLoadOut.valid := fpHitLoadOut.valid || io.lsq.ldout.valid && refillFpLoad
val fpLoadOutReg = RegNext(fpLoadOut)
io.fpout.bits := fpLoadOutReg.bits
io.fpout.bits.data := fpRdataHelper(fpLoadOutReg.bits.uop, fpLoadOutReg.bits.data) // recode
io.fpout.valid := RegNext(fpLoadOut.valid && !load_s2.io.out.bits.uop.roqIdx.needFlush(io.redirect))
io.lsq.ldout.ready := Mux(refillFpLoad, !fpLoadOut.valid, !intHitLoadOut.valid)
when(io.ldout.fire()){ when(io.ldout.fire()){
XSDebug("ldout %x iw %x fw %x\n", io.ldout.bits.uop.cf.pc, io.ldout.bits.uop.ctrl.rfWen, io.ldout.bits.uop.ctrl.fpWen) XSDebug("ldout %x\n", io.ldout.bits.uop.cf.pc)
}
when(io.fpout.fire()){
XSDebug("fpout %x\n", io.fpout.bits.uop.cf.pc)
} }
} }
\ No newline at end of file
...@@ -28,6 +28,9 @@ class StoreUnit_S0 extends XSModule { ...@@ -28,6 +28,9 @@ class StoreUnit_S0 extends XSModule {
io.out.bits.vaddr := saddr io.out.bits.vaddr := saddr
io.out.bits.data := genWdata(io.in.bits.src2, io.in.bits.uop.ctrl.fuOpType(1,0)) io.out.bits.data := genWdata(io.in.bits.src2, io.in.bits.uop.ctrl.fuOpType(1,0))
when(io.in.bits.uop.ctrl.src2Type === SrcType.fp){
io.out.bits.data := io.in.bits.src2
} // not not touch fp store raw data
io.out.bits.uop := io.in.bits.uop io.out.bits.uop := io.in.bits.uop
io.out.bits.miss := DontCare io.out.bits.miss := DontCare
io.out.bits.mask := genWmask(io.out.bits.vaddr, io.in.bits.uop.ctrl.fuOpType(1,0)) io.out.bits.mask := genWmask(io.out.bits.vaddr, io.in.bits.uop.ctrl.fuOpType(1,0))
...@@ -74,6 +77,7 @@ class StoreUnit_S1 extends XSModule { ...@@ -74,6 +77,7 @@ class StoreUnit_S1 extends XSModule {
io.tlbFeedback.bits.roqIdx.asUInt io.tlbFeedback.bits.roqIdx.asUInt
) )
// get paddr from dtlb, check if rollback is needed // get paddr from dtlb, check if rollback is needed
// writeback store inst to lsq // writeback store inst to lsq
io.lsq.valid := io.in.valid && !s1_tlb_miss// TODO: && ! FP io.lsq.valid := io.in.valid && !s1_tlb_miss// TODO: && ! FP
...@@ -88,9 +92,10 @@ class StoreUnit_S1 extends XSModule { ...@@ -88,9 +92,10 @@ class StoreUnit_S1 extends XSModule {
io.out.valid := io.in.valid && (!io.out.bits.mmio || hasException) && !s1_tlb_miss io.out.valid := io.in.valid && (!io.out.bits.mmio || hasException) && !s1_tlb_miss
io.out.bits := io.lsq.bits io.out.bits := io.lsq.bits
// if fp // encode data for fp store
// io.fp_out.valid := ... when(io.in.bits.uop.ctrl.src2Type === SrcType.fp){
// io.fp_out.bits := ... io.lsq.bits.data := genWdata(ieee(io.in.bits.data), io.in.bits.uop.ctrl.fuOpType(1,0))
}
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册