MantDivSqrt.scala 3.5 KB
Newer Older
1
package xiangshan.backend.fu.fpu.divsqrt
L
FPUv0.1  
LinJiawei 已提交
2 3 4

import chisel3._
import chisel3.util._
5 6
import xiangshan.backend.fu.fpu.util._
import xiangshan.backend.fu.fpu.util.FPUDebug
L
FPUv0.1  
LinJiawei 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126

class MantDivSqrt(len: Int) extends Module{
  val io = IO(new Bundle() {
    val in = Flipped(DecoupledIO(new Bundle() {
      val a, b = UInt(len.W)
      val isDiv = Bool()
    }))
    val out = DecoupledIO(new Bundle() {
      val quotient = UInt(len.W)
      val isZeroRem = Bool()
    })
  })

  val (a, b) = (io.in.bits.a, io.in.bits.b)
  val isDiv = io.in.bits.isDiv
  val isDivReg = RegEnable(isDiv, io.in.fire())
  val divisor = RegEnable(b, io.in.fire())

  val s_idle :: s_recurrence :: s_recovery :: s_finish :: Nil = Enum(4)
  val state = RegInit(s_idle)
  val cnt_next = Wire(UInt(log2Up((len+1)/2).W))
  val cnt = RegEnable(cnt_next, state===s_idle || state===s_recurrence)
  cnt_next := Mux(state === s_idle, (len/2).U, cnt - 1.U)


  val firstCycle = RegNext(io.in.fire())

  switch(state){
    is(s_idle){
      when(io.in.fire()){ state := s_recurrence }
    }
    is(s_recurrence){
      when(cnt_next === 0.U){ state := s_recovery }
    }
    is(s_recovery){
      state := s_finish
    }
    is(s_finish){
      when(io.out.fire()){ state := s_idle }
    }
  }

  val ws, wc = Reg(UInt((len+4).W))


  val table = Module(new SrtTable)
  val conv = Module(new OnTheFlyConv(len+3))
  val csa = Module(new CSA3_2(len+4))

  // partial square root
  val S = conv.io.Q >> 2
  val s0 :: s1 :: s2 :: s3 :: s4 :: Nil =  S(len-2, len-6).asBools().reverse
  val sqrt_d = Mux(firstCycle, "b101".U(3.W), Mux(s0, "b111".U(3.W), Cat(s2, s3, s4)))
  val div_d = divisor(len-2, len-4)
  val sqrt_y = ws(len+3, len-4) + wc(len+3, len-4)
  val div_y = ws(len+2, len-5) + wc(len+2, len-5)

  table.io.d := Mux(isDivReg, div_d, sqrt_d)
  table.io.y := Mux(isDivReg, div_y, sqrt_y)

  conv.io.resetSqrt := io.in.fire() && !isDiv
  conv.io.resetDiv := io.in.fire() && isDiv
  conv.io.enable := state===s_recurrence
  conv.io.qi := table.io.q

  val dx1, dx2, neg_dx1, neg_dx2 = Wire(UInt((len+4).W))
  dx1 := divisor
  dx2 := divisor << 1
  neg_dx1 := ~dx1
  neg_dx2 := neg_dx1 << 1

  val divCsaIn = MuxLookup(table.io.q.asUInt(), 0.U, Seq(
    -1 -> dx1,
    -2 -> dx2,
    1 -> neg_dx1,
    2 -> neg_dx2
  ).map(m => m._1.S(3.W).asUInt() -> m._2))

  csa.io.in(0) := ws
  csa.io.in(1) := Mux(isDivReg & !table.io.q(2),  wc | table.io.q(1, 0), wc)
  csa.io.in(2) := Mux(isDivReg, divCsaIn, conv.io.F)

  val divWsInit =  a
  val sqrtWsInit = Cat( Cat(0.U(2.W), a) - Cat(1.U(2.W), 0.U(len.W)), 0.U(2.W))

  when(io.in.fire()){
    ws := Mux(isDiv, divWsInit, sqrtWsInit)
    wc := 0.U
  }.elsewhen(state === s_recurrence){
    ws := Mux(cnt_next === 0.U, csa.io.out(0), csa.io.out(0) << 2)
    wc := Mux(cnt_next === 0.U, csa.io.out(1) << 1, csa.io.out(1) << 3)
  }

  val rem = ws + wc

  /** Remainder format:
    * Sqrt:
    * s s x x. x x x ... x
    * Div:
    * s s s x. x x x ... x
    */
  val remSignReg = RegEnable(rem.head(1).asBool(), state===s_recovery)
  val isZeroRemReg = RegEnable(rem===0.U, state===s_recovery)

  io.in.ready := state === s_idle
  io.out.valid := state === s_finish
  io.out.bits.quotient := Mux(remSignReg, conv.io.QM, conv.io.Q) >> !isDivReg
  io.out.bits.isZeroRem := isZeroRemReg

  FPUDebug(){
    when(io.in.fire()){
      printf(p"a:${Hexadecimal(io.in.bits.a)} b:${Hexadecimal(io.in.bits.b)}\n")
    }
    when(io.out.valid) {
      printf(p"Q:${Binary(conv.io.Q)} QM:${Binary(conv.io.QM)} isNegRem:${rem.head(1)}\n" +
        p"rem:${Hexadecimal(rem)}\n")
    }
  }
}