什么是在数组中找到最大值的低延迟硬件算法

我需要构建一个低延迟、单周期的硬件模块来查找数组中最大元素的索引和值。目前,我正在使用比较器树,但延迟并不令人满意。那么是否还有其他算法可能具有更低的延迟?

我希望输入数组很大(256 到 4096 个元素)并且值很小(3 到 5 位)。另外,我希望数组是稀疏的,即很多小值和很少的大值。

我主要关注延迟;区域没有那么重要。

我当前使用比较器树的实现如下所示:

implicit class reduceTreeOp[A](seq: Seq[A]) {
  def reduceTree[B >: A](op: (B,B) => B): B = {
    if(seq.length == 0)
      throw new NoSuchElementException("cannot reduce empty Seq")
    var rseq: Seq[B] = seq
    while(rseq.length != 1)
      rseq = rseq.grouped(2).toSeq
        .map(s => if(s.length == 1) s(0) else op(s(0),s(1)))
    rseq(0)
  }
}
val (value,index) = array
  .zipWithIndex
  .map{case (v,i) => (v,i.U)}
  .reduceTree[(UInt,UInt)]{case ((val1,idx1),(val2,idx2)) =>
    val is1 = val1 >= idx2
    ( Mux(is1,val1,idx2),Mux(is1,idx1,idx2))
  }

FWIW 这是为 7nm 硬件设计的;虽然我怀疑这对我的问题是否真的很重要。

tanichos 回答:什么是在数组中找到最大值的低延迟硬件算法

这在 chisel3 中非常简单。结果在单个周期中返回的约束将导致生成一大堆硬件。仔细评估您在此处需要什么可能是个好主意。

尽管如此,这是一个有趣的问题,并且展示了 chisel 的一些威力。我提供了一个可运行的示例 Scastie example

这是一个非常简单的测试套件的代码

import chisel3._
import chisel3.util.log2Ceil
import chiseltest._
import org.scalatest.freespec.AnyFreeSpec
import treadle.extremaOfUIntOfWidth

/** write only memory that continuously outputs the first index of the highest value in an array and
  * that highest value. It works by building a evaluation network of highest values
  *
  * It would be trivial to add ability to read the values in this memory
  *
  * @param depth
  * @param bitWidth
  */
class ArrayMax(val depth: Int,val bitWidth: Int) extends MultiIOModule {
  val writeEnable = IO(Input(Bool()))
  val writeAddress = IO(Input(UInt(log2Ceil(depth).W)))
  val writeData = IO(Input(UInt(bitWidth.W)))

  val maxValue = IO(Output(UInt(bitWidth.W)))
  val indexOfMaxValue = IO(Output(UInt(log2Ceil(depth).W)))

  val array = Reg(Vec(depth,UInt(bitWidth.W)))

  when(writeEnable) {
    array(writeAddress) := writeData
  }

  val valuesAndIndices = array.zipWithIndex.map { case (value,index) => (value,index.U)}.toList

  // Look through the array pair wise and return the value and index of the higher of the pair
  def compareAdjacentValues(valuesAndIndices: Seq[(UInt,UInt)]): Seq[(UInt,UInt)] = {
    val pairs = valuesAndIndices.sliding(2,2)
    pairs.map {
      case (aValue,aIndex) :: (bValue,bIndex) :: Nil =>
        val (higherValue,higherIndex) = (Wire(UInt(bitWidth.W)),Wire(UInt(log2Ceil(depth).W)))

        when(aValue < bValue) {
          higherValue := bValue
          higherIndex := bIndex
        } otherwise {
          higherValue := aValue
          higherIndex := aIndex
        }
        (higherValue,higherIndex)
      case (aValue,aIndex) :: Nil =>
        (aValue,aIndex)
      case a =>
        throw new Exception("Cannot get here,sliding should return list of size 1 or 2,$a")
    }.toList
  }

  def reduceToOne(pairs: Seq[(UInt,UInt)]): (UInt,UInt) = {
    if(pairs.length == 1) {
      pairs.head
    } else {
      reduceToOne(compareAdjacentValues(pairs))
    }
  }

  val (highestValue,index) = reduceToOne(valuesAndIndices)

  maxValue := highestValue
  indexOfMaxValue := index
}

/** Pumps random values at random indices into the dut and buffer that models it
  * Checks to see that the first highest value (there can be multiple occurrences)
  * is returned and the first index where that value appears.
  *
  * A simpler model might try to keep a record of the highest value in the array
  * but that will break down if that value is replaced by something lower
  */
class ArrayMaxSpec extends AnyFreeSpec with ChiselScalatestTester {
  val rand = new scala.util.Random

  "should act like a little write only memory" in {
    test(new ArrayMax(depth = 256,bitWidth = 8)) { dut =>

      val testArray = Array.fill(dut.depth)(0)

      for(i <- 0 until dut.depth * 10) {
        val index = rand.nextInt(dut.depth)
        dut.writeEnable.poke(true.B)
        dut.writeAddress.poke(index.U)
        val newValue = rand.nextInt(extremaOfUIntOfWidth(dut.bitWidth)._2.toInt)
        dut.writeData.poke(newValue.U)
        testArray(index) = newValue

        dut.clock.step()

        dut.maxValue.expect(testArray.max.U)
        dut.indexOfMaxValue.expect(testArray.indexOf(testArray.max).U)
      }
    }
  }
}
,

我找到了一种通过反转数组来工作的算法。本质上,创建一个新数组,每个可能的值都有一个槽,原始数组的每个元素根据值排序到一个槽中,最后,可以对新数组进行优先级编码以找到具有元素的最高槽。

实现如下所示:

// VALUE_SPACE: number of possible values
val exists = WireDefault(VecInit(Seq.fill(VALUE_SPACE)(false.B)))
val index = Wire(Vec(VALUE_SPACE,UInt(log2ceil(array.length).W)))
index := DontCare

array.zipWithIndex.foreach{case (v,i) =>
  exists(v) := true.B
  index(v) := i.U
}

maxVal := VALUE_SPACE.U - PriorityEncoder(exists.reverse)
maxidx := PriorityMux(exists.reverse,index.reverse)

当使用具有 256 个 4 位元素的数组时,该算法使我在比较器树上获得了大约 2.5 倍的加速。然而,它使用了更多的区域,并且可能只有在可能值的数量非常少的情况下才有好处。

本文链接:https://www.f2er.com/57399.html

大家都在问