/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package utils

import chisel3._
import chisel3.util._
import scala.math.min

object WordShift {
  /** Shifts `data` left by a whole number of words.
    *
    * @param data      value to shift
    * @param wordIndex number of words to shift by (hardware value)
    * @param step      width of one word, in bits
    * @return `data << (wordIndex * step)`
    */
  def apply(data: UInt, wordIndex: UInt, step: Int): UInt = {
    val shamt = wordIndex * step.U
    data << shamt
  }
}

object MaskExpand {
  /** Expands a per-byte mask into a per-bit mask.
    *
    * Bit `i` of `m` is replicated eight times to cover bits `8*i+7 .. 8*i` of the result,
    * so the result is `8 * m.getWidth` bits wide.
    */
  def apply(m: UInt) = FillInterleaved(8, m)
}

object MaskData {
  /** Merges `newData` into `oldData` under a per-bit mask.
    *
    * Bits where `fullmask` is 1 are taken from `newData`; bits where it is 0 are taken from
    * `oldData`. Neither data operand may be wider than `fullmask`.
    *
    * @param oldData  value supplying the bits where the mask is 0
    * @param newData  value supplying the bits where the mask is 1
    * @param fullmask per-bit select mask (1 = take `newData`)
    * @throws IllegalArgumentException if either data operand is wider than `fullmask`
    */
  def apply(oldData: UInt, newData: UInt, fullmask: UInt): UInt = {
    // The messages describe the violated condition: the data operand is wider than the mask.
    require(oldData.getWidth <= fullmask.getWidth,
      s"oldData width ${oldData.getWidth} > fullmask width ${fullmask.getWidth}")
    require(newData.getWidth <= fullmask.getWidth,
      s"newData width ${newData.getWidth} > fullmask width ${fullmask.getWidth}")
    (newData & fullmask) | (oldData & ~fullmask)
  }
}

object SignExt {
  /** Resizes `a` to exactly `len` bits, treating it as a two's-complement value.
    *
    * Inputs at least `len` bits wide are truncated to their low `len` bits; narrower inputs
    * are padded on the left with copies of their most significant bit.
    */
  def apply(a: UInt, len: Int) = {
    val width = a.getWidth
    val msb = a(width - 1)
    if (width < len) Cat(Fill(len - width, msb), a)
    else a(len - 1, 0)
  }
}

object ZeroExt {
  /** Resizes `a` to exactly `len` bits, treating it as unsigned.
    *
    * Inputs at least `len` bits wide are truncated to their low `len` bits; narrower inputs
    * are padded on the left with zeros.
    */
  def apply(a: UInt, len: Int) = {
    val width = a.getWidth
    if (width < len) {
      val padding = 0.U((len - width).W)
      Cat(padding, a)
    } else {
      a(len - 1, 0)
    }
  }
}

object Or {
  // Fill 1s from low bits to high bits: set bits are copied towards the MSB. When
  // cap >= width, every bit at or above the lowest set bit of x becomes 1.
  def leftOR(x: UInt): UInt = leftOR(x, x.getWidth, x.getWidth)
  def leftOR(x: UInt, width: Integer, cap: Integer = 999999): UInt = {
    val stop = min(width, cap)
    // Shift distances 1, 2, 4, ... that are below `stop`. After the pass with distance d,
    // each originally set bit has been copied up to 2 * d - 1 positions upward.
    val distances = Iterator.iterate(1)(_ * 2).takeWhile(_ < stop)
    val filled = distances.foldLeft(x)((acc, d) => acc | (acc << d)(width - 1, 0))
    filled(width - 1, 0)
  }

  // Fill 1s from high bits to low bits: set bits are copied towards the LSB. When
  // cap >= width, every bit at or below the highest set bit of x becomes 1.
  def rightOR(x: UInt): UInt = rightOR(x, x.getWidth, x.getWidth)
  def rightOR(x: UInt, width: Integer, cap: Integer = 999999): UInt = {
    val stop = min(width, cap)
    // Same doubling schedule as leftOR, spreading in the opposite direction.
    val distances = Iterator.iterate(1)(_ * 2).takeWhile(_ < stop)
    val filled = distances.foldLeft(x)((acc, d) => acc | (acc >> d))
    filled(width - 1, 0)
  }
}

object OneHot {
  /** Converts a "one-hot minus one" (OH1) value to one-hot.
    *
    * An OH1 value with its low k bits set, (1 << k) - 1, becomes 1 << k. The result is one
    * bit wider than `x`, so k == x.getWidth (x all ones) is representable.
    */
  def OH1ToOH(x: UInt): UInt = {
    val shiftedUp = (x << 1) | 1.U
    val keepMask = ~Cat(0.U(1.W), x)
    shiftedUp & keepMask
  }

  /** Index k of an OH1 value (1 << k) - 1, i.e. the number of its set low bits. */
  def OH1ToUInt(x: UInt): UInt = {
    val oneHot = OH1ToOH(x)
    OHToUInt(oneHot)
  }

  /** OH1 encoding of `x`: a `width`-bit value whose low `x` bits are set. */
  def UIntToOH1(x: UInt, width: Int): UInt = {
    val allOnes = (-1).S(width.W).asUInt
    val shiftedOnes = (allOnes << x)(width - 1, 0)
    ~shiftedOnes
  }

  /** OH1 encoding of `x`, sized for the largest value `x` can hold: (1 << x.getWidth) - 1 bits. */
  def UIntToOH1(x: UInt): UInt = {
    val maxValue = (1 << x.getWidth) - 1
    UIntToOH1(x, maxValue)
  }

  /** Asserts that at most one bit of `in` is set (all zeros passes). */
  def checkOneHot(in: Bits): Unit = {
    val setCount = PopCount(in)
    assert(setCount <= 1.U)
  }

  /** Asserts that at most one of `in` is set (all false passes). */
  def checkOneHot(in: Iterable[Bool]): Unit = {
    val setCount = PopCount(in)
    assert(setCount <= 1.U)
  }
}

object LowerMask {
  /** ORs `a` with `a >> 0`, `a >> 1`, ..., `a >> (len - 1)`.
    *
    * NOTE(review): assumes ParallelOR (defined elsewhere) is an OR-reduction. In that case each
    * set bit is smeared up to `len - 1` positions downward. With `len == a.getWidth`, every bit
    * at or below the highest set bit of `a` becomes 1.
    */
  def apply(a: UInt, len: Int) = {
    val shiftedCopies = (0 until len).map(amount => a >> amount.U)
    ParallelOR(shiftedCopies)
  }

  /** [[apply(a, len)]] with `len = a.getWidth`. */
  def apply(a: UInt): UInt = {
    apply(a, a.getWidth)
  }
}

object HigherMask {
  /** Mirror image of LowerMask.
    *
    * Bit-reverses `a`, applies LowerMask, and reverses the result back. So set bits are
    * smeared towards the MSB instead of the LSB.
    */
  def apply(a: UInt, len: Int) = {
    val mirrored = Reverse(a)
    Reverse(LowerMask(mirrored, len))
  }

  /** [[apply(a, len)]] with `len = a.getWidth`. */
  def apply(a: UInt): UInt = {
    apply(a, a.getWidth)
  }
}

object LowerMaskFromLowest {
  /** LowerMask applied to the one-hot of the lowest set bit of `a`.
    *
    * PriorityEncoderOH isolates the lowest set bit; LowerMask then smears it towards the LSB.
    */
  def apply(a: UInt) = {
    val lowestOH = PriorityEncoderOH(a)
    LowerMask(lowestOH)
  }
}

object HigherMaskFromHighest {
  /** Mirror image of LowerMaskFromLowest.
    *
    * On the bit-reversed input, the highest set bit of `a` is isolated and smeared towards
    * the MSB of the original orientation.
    */
  def apply(a: UInt) = {
    val reversed = Reverse(a)
    val highestOHReversed = PriorityEncoderOH(reversed)
    Reverse(LowerMask(highestOHReversed))
  }
}

object LowestBit {
  /** `len`-bit one-hot of the lowest set bit of `a(len - 1, 0)`, or 0 if none of those bits is set.
    *
    * Works on the bit-reversed input:
    *  - OR-ing it with its right shifts fills every bit below its highest set bit;
    *  - adding 1 carries into the position just above;
    *  - shifting right by one leaves a one-hot at that highest set bit;
    *  - reversing back maps it to the lowest set bit of `a`.
    * When `a(0)` is set the filled value is all ones and the increment wraps (Chisel `+` does
    * not widen), so that case is handled separately by the Mux.
    * NOTE(review): relies on ParallelOR (defined elsewhere) being a `len`-bit OR-reduction.
    */
  def apply(a: UInt, len: Int) = {
    val reversed = Reverse(a(len - 1, 0))
    val smeared = ParallelOR((0 until len).map(i => reversed >> i.U))
    val oneHotIfNotBit0 = Reverse((smeared + 1.U) >> 1.U)
    Mux(a(0), 1.U(len.W), oneHotIfNotBit0)
  }
}

object HighestBit {
  /** Mirror image of LowestBit: reverses `a`, applies LowestBit, and reverses the `len`-bit result.
    *
    * With `len == a.getWidth` this is a one-hot of the highest set bit of `a`. With a smaller
    * `len`, LowestBit only sees the low `len` bits of `Reverse(a)`, i.e. the top `len` bits of `a`.
    */
  def apply(a: UInt, len: Int) = {
    val mirrored = Reverse(a)
    Reverse(LowestBit(mirrored, len))
  }
}

object GenMask {
  /** Generates a read/write mask with bits `high` down to `low` (inclusive) set.
    *
    * The result is `high + 1` bits wide; bits below `low` are zero.
    *
    * @param high index of the most significant set bit
    * @param low  index of the least significant set bit; must satisfy 0 <= low <= high
    * @throws IllegalArgumentException if the range is empty or `low` is negative
    */
  def apply(high: Int, low: Int): UInt = {
    // high == low is a valid one-bit range; it yields the same mask as GenMask(high).
    require(low >= 0 && high >= low, s"invalid mask range [$high:$low]")
    val ones = (BigInt(1) << (high - low + 1)) - 1
    (ones << low).U((high + 1).W)
  }

  /** Generates a `pos + 1`-bit mask with only bit `pos` set.
    *
    * @throws IllegalArgumentException if `pos` is negative
    */
  def apply(pos: Int): UInt = {
    require(pos >= 0, s"invalid mask position $pos")
    (BigInt(1) << pos).U((pos + 1).W)
  }
}

object UIntToMask {
  /** Alias of [[leftmask]]. */
  def apply(ptr: UInt, length: Integer): UInt = leftmask(ptr, length)

  /** Bit-reverses `input`; the width is unchanged. */
  def reverseUInt(input: UInt): UInt = Reverse(input)

  /** `length`-bit mask with every bit below `ptr` set (bits `ptr-1 .. 0`).
    *
    * If `ptr >= length` the one-hot bit falls outside the extracted range. The extracted value is
    * then 0, and the non-widening subtraction wraps to all ones.
    */
  def leftmask(ptr: UInt, length: Integer): UInt = {
    val ptrOH = UIntToOH(ptr)(length - 1, 0)
    ptrOH - 1.U
  }

  /** `length`-bit mask with every bit above `ptr` set (bits `length-1 .. ptr+1`).
    *
    * Computed as leftmask in the bit-reversed domain.
    * NOTE(review): like leftmask, `ptr >= length` yields all ones here; confirm callers never
    * rely on that case.
    */
  def rightmask(ptr: UInt, length: Integer): UInt = {
    val ptrOH = UIntToOH(ptr)(length - 1, 0)
    reverseUInt(reverseUInt(ptrOH) - 1.U)
  }
}

object GetEvenBits {
  /** Gathers the even-indexed bits of `input`: result bit i is `input(2 * i)`.
    *
    * The result is `input.getWidth / 2` bits wide (rounded down). For an odd input width the
    * topmost bit, although even-indexed, is therefore not included.
    * NOTE(review): confirm that truncation is intended.
    */
  def apply(input: UInt): UInt = {
    val pairs = input.getWidth / 2
    VecInit(Seq.tabulate(pairs)(i => input(2 * i))).asUInt
  }
}


object GetOddBits {
  /** Gathers the odd-indexed bits of `input`: result bit i is `input(2 * i + 1)`.
    *
    * The result is `input.getWidth / 2` bits wide (rounded down).
    */
  def apply(input: UInt): UInt = {
    val pairs = input.getWidth / 2
    VecInit(Seq.tabulate(pairs)(i => input(2 * i + 1))).asUInt
  }
}

object XORFold {
  /** Folds `input` down to `resWidth` bits by XOR-ing together its `resWidth`-bit chunks.
    *
    * The input is zero-extended up to the next multiple of `resWidth`, so every input bit,
    * including those in a partial top chunk, contributes to the result.
    *
    * @param input    value to fold
    * @param resWidth width of each chunk; must be positive
    * @return ParallelXOR (defined elsewhere) of the chunks, low chunk first
    */
  def apply(input: UInt, resWidth: Int): UInt = {
    require(resWidth > 0)
    // Round up. With floor division ZeroExt would truncate rather than pad, dropping the
    // top `input.getWidth % resWidth` bits and failing for inputs narrower than resWidth.
    val fold_range = (input.getWidth + resWidth - 1) / resWidth
    val value = ZeroExt(input, fold_range * resWidth)
    ParallelXOR((0 until fold_range).map(i => value(i*resWidth+resWidth-1, i*resWidth)))
  }
}

object OnesMoreThan {
  /** True when at least `thres` of the `input` bits are set.
    *
    * Despite the name, the comparison is inclusive (`>=`), and `thres == 0` is always true.
    * The logic is the same recursion as before: a bit is counted or skipped at each position.
    * It is now memoised on `(start, t)`, so each sub-result is elaborated once, instead of once
    * per path through the recursion (a count that grows combinatorially with the input length
    * and threshold).
    *
    * @param input bits to count
    * @param thres minimum number of set bits; must be non-negative
    * @throws IllegalArgumentException if `thres` is negative
    */
  def apply(input: Seq[Bool], thres: Int): Bool = {
    require(thres >= 0, s"threshold $thres must be non-negative")
    val bits = input.toIndexedSeq
    val memo = scala.collection.mutable.HashMap.empty[(Int, Int), Bool]

    // atLeast(start, t): at least t of bits(start), bits(start + 1), ... are set.
    def atLeast(start: Int, t: Int): Bool = memo.get((start, t)) match {
      case Some(cached) => cached
      case None =>
        val result =
          if (t == 0) true.B
          else if (bits.length - start < t) false.B
          else bits(start) && atLeast(start + 1, t - 1) || atLeast(start + 1, t)
        memo((start, t)) = result
        result
    }

    atLeast(0, thres)
  }
}

abstract class SelectOne {
  // One-bit register that starts at 0 and flips every cycle. CircSelectOne uses it to
  // alternate which selection port is preferred when balancing.
  protected val balance2 = RegInit(false.B)
  balance2 := !balance2

  /** Current value of the alternating balance bit. */
  def getBalance2: Bool = balance2

  /** Selects the n-th (counting from 1) set input bit.
    *
    * @param n            which set bit to select
    * @param need_balance for balanced selections only (DO NOT use this if you don't know what it is)
    * @return (valid, one-hot selection vector over the input bits)
    */
  def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool])
}

class NaiveSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
  val n_bits = bits.length
  // A non-positive max_sel means "one selection per input bit".
  val n_sel = if (max_sel > 0) max_sel else n_bits
  require(n_bits > 0 && n_sel > 0 && n_bits >= n_sel)

  // matrix(i)(j): exactly j of the first i bits (bits(0) .. bits(i - 1)) are set.
  private val matrix = Wire(Vec(n_bits, Vec(n_sel, Bool())))
  for (i <- 0 until n_bits; j <- 0 until n_sel) {
    val exactlyJ = (i, j) match {
      // Zero bits trivially contain zero ones.
      case (0, 0) => true.B
      // No ones among the first i bits.
      case (_, 0) => !Cat(bits.take(i)).orR
      // It's impossible to select the j-th one from i elements.
      case _ if i < j => false.B
      // Either bits(i - 1) is set and the first i - 1 bits hold j - 1 ones,
      // or it is clear and they already hold j ones.
      case _ =>
        val prev = bits(i - 1)
        prev && matrix(i - 1)(j - 1) || !prev && matrix(i - 1)(j)
    }
    matrix(i)(j) := exactlyJ
  }

  /** Selects the n-th set bit, scanning from bits(0) upward. `need_balance` is ignored here. */
  def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
    require(n > 0, s"$n should be positive to select the n-th one")
    require(n <= n_sel, s"$n should not be larger than $n_sel")
    // Valid only when there are at least n set bits to choose from.
    val selValid = OnesMoreThan(bits, n)
    // bits(i) is the n-th one iff it is set and exactly n - 1 ones precede it.
    val sel = VecInit(bits.zipWithIndex.map { case (b, i) => b && matrix(i)(n - 1) })
    (selValid, sel)
  }
}

class CircSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
  val n_bits = bits.length
  // A non-positive max_sel means "one selection per input bit".
  val n_sel = if (max_sel > 0) max_sel else n_bits
  require(n_bits > 0 && n_sel > 0 && n_bits >= n_sel)

  // Odd-numbered selections (1st, 3rd, ...) scan from bits(0) upward; even-numbered ones
  // (2nd, 4th, ...) scan from the last bit downward.
  // NOTE(review): when n_sel == 1, n_sel / 2 == 0 and NaiveSelectOne treats a non-positive
  // max_sel as "all bits", so the backward selector is built at full size.
  val sel_forward = new NaiveSelectOne(bits, (n_sel + 1) / 2)
  val sel_backward = new NaiveSelectOne(bits.reverse, n_sel / 2)

  // moreThan.head: at least one bit set; moreThan.last: at least two bits set.
  val moreThan = Seq(OnesMoreThan(bits, 1), OnesMoreThan(bits, 2))

  /** Selects the n-th set bit, alternating between the low and the high end of `bits`.
    *
    * With need_balance (only supported for max_sel == 2), balance2 decides which port takes
    * a lone request. When it is set, port 1 is valid only with two or more requests, so a
    * single request goes to port 2.
    */
  def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
    require(!need_balance || max_sel == 2, s"does not support load balance between $max_sel selections")
    val selValid =
      if (need_balance) {
        n match {
          // When balance2 bit is set, we prefer the second selection port.
          case 1 => Mux(balance2, moreThan.last, moreThan.head)
          case _ =>
            require(n == 2)
            Mux(balance2, moreThan.head, moreThan.last)
        }
      } else {
        OnesMoreThan(bits, n)
      }
    val sel_index = (n + 1) / 2
    val oneHot =
      if (n % 2 == 1) sel_forward.getNthOH(sel_index, need_balance)._2
      else VecInit(sel_backward.getNthOH(sel_index, need_balance)._2.reverse)
    (selValid, oneHot)
  }
}

class OddEvenSelectOne(bits: Seq[Bool], max_sel: Int = -1) extends SelectOne {
  val n_bits = bits.length
  // A non-positive max_sel means "one selection per input bit".
  val n_sel = if (max_sel > 0) max_sel else n_bits
  require(n_bits > 0 && n_sel > 0 && n_bits >= n_sel)
  require(n_sel > 1, "Select only one entry via OddEven causes odd entries to be ignored")

  // getNthOH serves odd-numbered selections (1st, 3rd, ...) from the even-indexed bits and
  // even-numbered selections (2nd, 4th, ...) from the odd-indexed bits. The even half
  // therefore needs ceil(n_sel / 2) selections and the odd half floor(n_sel / 2). With the
  // budgets swapped, odd n_sel fails elaboration (e.g. n_sel = 5, or n_bits = n_sel = 3).
  val n_even = (n_bits + 1) / 2
  val sel_even = new CircSelectOne((0 until n_even).map(i => bits(2 * i)), (n_sel + 1) / 2)
  val n_odd = n_bits / 2
  val sel_odd = new CircSelectOne((0 until n_odd).map(i => bits(2 * i + 1)), n_sel / 2)

  /** Selects the n-th set bit, taking odd-numbered selections from even-indexed bits and
    * even-numbered selections from odd-indexed bits.
    *
    * @return (valid, one-hot over all `n_bits` inputs)
    */
  def getNthOH(n: Int, need_balance: Boolean = false): (Bool, Vec[Bool]) = {
    val sel_index = (n + 1) / 2
    if (n % 2 == 1) {
      val selected = sel_even.getNthOH(sel_index, need_balance)
      // Scatter the even-half one-hot back to positions 0, 2, 4, ...
      val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 0) selected._2(i / 2) else false.B))
      (selected._1, sel)
    }
    else {
      val selected = sel_odd.getNthOH(sel_index, need_balance)
      // Scatter the odd-half one-hot back to positions 1, 3, 5, ...
      val sel = VecInit((0 until n_bits).map(i => if (i % 2 == 1) selected._2(i / 2) else false.B))
      (selected._1, sel)
    }
  }
}

object SelectOne {
  /** Builds a selector over `bits` using the named policy.
    *
    * @param policy  "naive", "circ" or "oddeven" (case-insensitive)
    * @param bits    request bits to select from
    * @param max_sel maximum number of selections; non-positive means one per input bit
    * @throws IllegalArgumentException if `policy` is not one of the known names
    */
  def apply(policy: String, bits: Seq[Bool], max_sel: Int = -1): SelectOne = {
    // Locale.ROOT keeps the match independent of the JVM default locale; e.g. under a
    // Turkish locale "CIRC".toLowerCase() would not produce "circ".
    policy.toLowerCase(java.util.Locale.ROOT) match {
      case "naive" => new NaiveSelectOne(bits, max_sel)
      case "circ" => new CircSelectOne(bits, max_sel)
      case "oddeven" => new OddEvenSelectOne(bits, max_sel)
      case _ => throw new IllegalArgumentException(s"unknown select policy: $policy")
    }
  }
}