diff --git a/src/main/scala/xiangshan/backend/exu/MulDivExeUnit.scala b/src/main/scala/xiangshan/backend/exu/MulDivExeUnit.scala index ba807ac175c62bf0bf1bfad1035db6487a50af79..628e2db6bb695c803307aec8464c425d45b62fa8 100644 --- a/src/main/scala/xiangshan/backend/exu/MulDivExeUnit.scala +++ b/src/main/scala/xiangshan/backend/exu/MulDivExeUnit.scala @@ -5,7 +5,7 @@ import chisel3.util._ import xiangshan._ import utils._ import xiangshan.backend.MDUOpType -import xiangshan.backend.fu.{ArrayMultiplier, Divider, FunctionUnit} +import xiangshan.backend.fu.{AbstractDivider, ArrayMultiplier, FunctionUnit, Radix2Divider} class MulDivExeUnit(hasDiv: Boolean = true) extends Exu( exuName = if(hasDiv) "MulDivExeUnit" else "MulExeUnit", @@ -31,7 +31,7 @@ class MulDivExeUnit(hasDiv: Boolean = true) extends Exu( }.get val div = supportedFunctionUnits.collectFirst { - case d: Divider => d + case d: AbstractDivider => d }.orNull // override inputs diff --git a/src/main/scala/xiangshan/backend/fu/FunctionUnit.scala b/src/main/scala/xiangshan/backend/fu/FunctionUnit.scala index 0add91e2c50d24c90ce9068c2506b38d25fe385e..52589e936ca934b2261f2e8fa9cad5d860e3eba5 100644 --- a/src/main/scala/xiangshan/backend/fu/FunctionUnit.scala +++ b/src/main/scala/xiangshan/backend/fu/FunctionUnit.scala @@ -129,8 +129,9 @@ trait HasPipelineReg { this: FunctionUnit => object FunctionUnit extends HasXSParameter { + def divider = new SRT4Divider(XLEN) def multiplier = new ArrayMultiplier(XLEN+1, Seq(0, 2)) - def divider = new Divider(XLEN) + def alu = new Alu def jmp = new Jump diff --git a/src/main/scala/xiangshan/backend/fu/Divider.scala b/src/main/scala/xiangshan/backend/fu/Radix2Divider.scala similarity index 91% rename from src/main/scala/xiangshan/backend/fu/Divider.scala rename to src/main/scala/xiangshan/backend/fu/Radix2Divider.scala index b5ef5dec994ff4486bc01f5eb346f6e59923d0f5..84b052589c8138821665c6b8df24b2c6d1ae57d9 100644 --- a/src/main/scala/xiangshan/backend/fu/Divider.scala +++ b/src/main/scala/xiangshan/backend/fu/Radix2Divider.scala @@ -5,13 +5,15 @@ import chisel3.util._ import xiangshan._ import utils._ -class Divider(len: Int) extends FunctionUnit( +abstract class AbstractDivider(len: Int) extends FunctionUnit( FuConfig(FuType.div, 2, 0, writeIntRf = true, writeFpRf = false, hasRedirect = false, UncertainLatency()), len -) { - +){ val ctrl = IO(Input(new MulDivCtrl)) val sign = ctrl.sign +} + +class Radix2Divider(len: Int) extends AbstractDivider(len) { def abs(a: UInt, sign: Bool): (Bool, UInt) = { val s = a(len - 1) && sign @@ -72,7 +74,8 @@ class Divider(len: Int) extends FunctionUnit( } } - when(state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn)){ + val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn) + when(kill){ state := s_idle } @@ -85,6 +88,6 @@ class Divider(len: Int) extends FunctionUnit( io.out.bits.data := Mux(ctrlReg.isW, SignExt(res(31,0),xlen), res) io.out.bits.uop := uopReg - io.out.valid := state === s_finish + io.out.valid := state === s_finish && !kill io.in.ready := state === s_idle } \ No newline at end of file diff --git a/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala b/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala new file mode 100644 index 0000000000000000000000000000000000000000..df947e72a5896782831bf85772cb01924ed3a350 --- /dev/null +++ b/src/main/scala/xiangshan/backend/fu/SRT4Divider.scala @@ -0,0 +1,230 @@ +package xiangshan.backend.fu + +import chisel3._ +import chisel3.util._ +import utils.SignExt +import xiangshan.backend.fu.fpu.util.CSA3_2 + +/** A Radix-4 SRT Integer Divider + * + * 2 ~ (5 + (len+3)/2) cycles are needed for each division. + */ +class SRT4Divider(len: Int) extends AbstractDivider(len) { + + val s_idle :: s_lzd :: s_normlize :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(6) + val state = RegInit(s_idle) + val newReq = (state === s_idle) && io.in.fire() + val cnt_next = Wire(UInt(log2Up((len+3)/2).W)) + val cnt = RegEnable(cnt_next, state===s_normlize || state===s_recurrence) + val rec_enough = cnt_next === 0.U + + def abs(a: UInt, sign: Bool): (Bool, UInt) = { + val s = a(len - 1) && sign + (s, Mux(s, -a, a)) + } + val (a, b) = (io.in.bits.src(0), io.in.bits.src(1)) + val uop = io.in.bits.uop + val (aSign, aVal) = abs(a, sign) + val (bSign, bVal) = abs(b, sign) + val aSignReg = RegEnable(aSign, newReq) + val qSignReg = RegEnable(aSign ^ bSign, newReq) + val uopReg = RegEnable(uop, newReq) + val ctrlReg = RegEnable(ctrl, newReq) + val divZero = b === 0.U + val divZeroReg = RegEnable(divZero, newReq) + + val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn) + + switch(state){ + is(s_idle){ + when(io.in.fire()){ state := Mux(divZero, s_finish, s_lzd) } + } + is(s_lzd){ // leading zero detection + state := s_normlize + } + is(s_normlize){ // shift a/b + state := s_recurrence + } + is(s_recurrence){ // (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d + when(rec_enough){ state := s_recovery } + } + is(s_recovery){ // if rem < 0, rem = rem + d + state := s_finish + } + is(s_finish){ + when(io.out.fire()){ state := s_idle } + } + } + when(kill){ + state := s_idle + } + + /** Calculate abs(a)/abs(b) by recurrence + * + * ws, wc: partial remainder in carry-save form, + * in recurrence steps, ws/wc = 4ws[j]/4wc[j]; + * in recovery step, ws/wc = ws[j]/wc[j]; + * in final step, ws = abs(a)/abs(b). + * + * d: normlized divisor(1/2<=d<1) + * + * wLen = 3 integer bits + (len+1) frac bits + */ + def wLen = 3 + len + 1 + val ws, wc = Reg(UInt(wLen.W)) + val ws_next, wc_next = Wire(UInt(wLen.W)) + val d = Reg(UInt(wLen.W)) + + val aLeadingZeros = RegEnable( + next = PriorityEncoder(ws(len-1, 0).asBools().reverse), + enable = state===s_lzd + ) + val bLeadingZeros = RegEnable( + next = PriorityEncoder(d(len-1, 0).asBools().reverse), + enable = state===s_lzd + ) + val diff = Cat(0.U(1.W), bLeadingZeros).asSInt() - Cat(0.U(1.W), aLeadingZeros).asSInt() + val isNegDiff = diff(diff.getWidth - 1) + val quotientBits = Mux(isNegDiff, 0.U, diff.asUInt()) + val qBitsIsOdd = quotientBits(0) + val recoveryShift = RegEnable(len.U - bLeadingZeros, state===s_normlize) + val a_shifted, b_shifted = Wire(UInt(len.W)) + a_shifted := Mux(isNegDiff, + ws(len-1, 0) << bLeadingZeros, + ws(len-1, 0) << aLeadingZeros + ) + b_shifted := d(len-1, 0) << bLeadingZeros + + val rem_temp = ws + wc + val rem_fixed = Mux(rem_temp(wLen-1), rem_temp + d, rem_temp) + val rem_abs = (rem_fixed << recoveryShift)(2*len, len+1) + + when(newReq){ + ws := Cat(0.U(4.W), Mux(divZero, a, aVal)) + wc := 0.U + d := Cat(0.U(4.W), bVal) + }.elsewhen(state === s_normlize){ + d := Cat(0.U(3.W), b_shifted, 0.U(1.W)) + ws := Mux(qBitsIsOdd, a_shifted, a_shifted << 1) + }.elsewhen(state === s_recurrence){ + ws := Mux(rec_enough, ws_next, ws_next << 2) + wc := Mux(rec_enough, wc_next, wc_next << 2) + }.elsewhen(state === s_recovery){ + ws := rem_abs + } + + cnt_next := Mux(state === s_normlize, (quotientBits + 3.U) >> 1, cnt - 1.U) + + /** Quotient selection + * + * the quotient selection table use truncated 7-bit remainder + * and 3-bit divisor + */ + val sel_0 :: sel_d :: sel_dx2 :: sel_neg_d :: sel_neg_dx2 :: Nil = Enum(5) + val dx2, neg_d, neg_dx2 = Wire(UInt(wLen.W)) + dx2 := d << 1 + neg_d := (~d).asUInt() // add '1' in carry-save adder later + neg_dx2 := neg_d << 1 + + val q_sel = Wire(UInt(3.W)) + val wc_adj = MuxLookup(q_sel, 0.U(2.W), Seq( + sel_d -> 1.U(2.W), + sel_dx2 -> 2.U(2.W) + )) + + val w_truncated = (ws(wLen-1, wLen-1-6) + wc(wLen-1, wLen-1-6)).asSInt() + val d_truncated = d(len-1, len-3) + + val qSelTable = Array( + Array(12, 4, -4, -13), + Array(14, 4, -6, -15), + Array(15, 4, -6, -16), + Array(16, 4, -6, -18), + Array(18, 6, -8, -20), + Array(20, 6, -8, -20), + Array(20, 8, -8, -22), + Array(24, 8, -8, -24) + ) + + // ge(x): w_truncated >= x + var ge = Map[Int, Bool]() + for(row <- qSelTable){ + for(k <- row){ + if(!ge.contains(k)) ge = ge + (k -> (w_truncated >= k.S(7.W))) + } + } + q_sel := MuxLookup(d_truncated, sel_0, + qSelTable.map(x => + MuxCase(sel_neg_dx2, Seq( + ge(x(0)) -> sel_dx2, + ge(x(1)) -> sel_d, + ge(x(2)) -> sel_0, + ge(x(3)) -> sel_neg_d + )) + ).zipWithIndex.map({case(v, i) => i.U -> v}) + ) + + /** Calculate (ws[j+1],wc[j+1]) by a [3-2]carry-save adder + * + * (ws[j+1], wc[j+1]) = 4(ws[j],wc[j]) - q(j+1)*d + */ + val csa = Module(new CSA3_2(wLen)) + csa.io.in(0) := ws + csa.io.in(1) := Cat(wc(wLen-1, 2), wc_adj) + csa.io.in(2) := MuxLookup(q_sel, 0.U, Seq( + sel_d -> neg_d, + sel_dx2 -> neg_dx2, + sel_neg_d -> d, + sel_neg_dx2 -> dx2 + )) + ws_next := csa.io.out(0) + wc_next := csa.io.out(1) << 1 + + // On the fly quotient conversion + val q, qm = Reg(UInt(len.W)) + when(newReq){ + q := 0.U + qm := 0.U + }.elsewhen(state === s_recurrence){ + val qMap = Seq( + sel_0 -> (q, 0), + sel_d -> (q, 1), + sel_dx2 -> (q, 2), + sel_neg_d -> (qm, 3), + sel_neg_dx2 -> (qm, 2) + ) + q := MuxLookup(q_sel, 0.U, + qMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W))) + ) + val qmMap = Seq( + sel_0 -> (qm, 3), + sel_d -> (q, 0), + sel_dx2 -> (q, 1), + sel_neg_d -> (qm, 2), + sel_neg_dx2 -> (qm, 1) + ) + qm := MuxLookup(q_sel, 0.U, + qmMap.map(m => m._1 -> Cat(m._2._1(len-3, 0), m._2._2.U(2.W))) + ) + }.elsewhen(state === s_recovery){ + q := Mux(rem_temp(wLen-1), qm, q) + } + + + val remainder = Mux(aSignReg, -ws(len-1, 0), ws(len-1, 0)) + val quotient = Mux(qSignReg, -q, q) + + val res = Mux(ctrlReg.isHi, + Mux(divZeroReg, ws(len-1, 0), remainder), + Mux(divZeroReg, Fill(len, 1.U(1.W)), quotient) + ) + + io.in.ready := state===s_idle + io.out.valid := state===s_finish && !kill + io.out.bits.data := Mux(ctrlReg.isW, + SignExt(res(31, 0), len), + res + ) + io.out.bits.uop := uopReg + +} diff --git a/src/main/scala/xiangshan/backend/fu/fpu/divsqrt/DivSqrt.scala b/src/main/scala/xiangshan/backend/fu/fpu/divsqrt/DivSqrt.scala index bf4414df69754da134e1cabe00006db004052453..85a4e8e963d5f5de08ae65add3a423b01ab8d83e 100644 --- a/src/main/scala/xiangshan/backend/fu/fpu/divsqrt/DivSqrt.scala +++ b/src/main/scala/xiangshan/backend/fu/fpu/divsqrt/DivSqrt.scala @@ -21,7 +21,7 @@ class DivSqrt extends FPUSubModule( val uopReg = RegEnable(io.in.bits.uop, io.in.fire()) - val kill = uopReg.roqIdx.needFlush(io.redirectIn) + val kill = state=/=s_idle && uopReg.roqIdx.needFlush(io.redirectIn) val rmReg = RegEnable(rm, io.in.fire()) val isDiv = !op(0) val isDivReg = RegEnable(isDiv, io.in.fire())