ソースを参照

Improve ex0 implementation

sindre-ex0
コミット
eb66d7ffeb
4個のファイルの変更120行の追加29行の削除
  1. +5
    -0
      src/main/scala/DotProd.scala
  2. +27
    -22
      src/main/scala/MatMul.scala
  3. +7
    -7
      src/main/scala/Matrix.scala
  4. +81
    -0
      src/test/scala/MatMulSpec.scala

+ 5
- 0
src/main/scala/DotProd.scala ファイルの表示

@@ -17,11 +17,16 @@ class DotProd(val elements: Int) extends Module {


val counter = Counter(elements)

// The sum of the result from each column
val accumulator = RegInit(0.U(32.W))

val result = accumulator + io.dataInA * io.dataInB

// Reset the accumulator when we roll over, ready for
// a new calculation
accumulator := Mux(counter.inc(), 0.U, result)

io.dataOut := result
io.outputValid := counter.inc()
}

+ 27
- 22
src/main/scala/MatMul.scala ファイルの表示

@@ -22,18 +22,37 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module {
val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io
val dotProdCalc = Module(new DotProd(colDimsA)).io

val calculating = RegInit(false.B)
val resultCol = Counter(rowDimsA)
val col = Counter(colDimsA)
val row = Counter(rowDimsA)
// The number of elements in the matrices
val matSize = colDimsA * rowDimsA

matrixA.rowIdx := row.value
matrixA.colIdx := col.value
// We use a single counter that is incremented each tick.
// It rolls over when the calculation is finished, and we're
// ready for a new matrix.
val counter = Counter(matSize * (1 + rowDimsA))
counter.inc()

// We first go through all cells in the matrices and insert the values.
// We then do the multiplication by multiplying each row of matrix A with
// matrix B. This variable says if we are currently inserting or calculating.
val calculating = counter.value >= matSize.U
val calcOffset = counter.value - matSize.U

// We go through row for row, so the column is always incremented
val col = counter.value % colDimsA.U

// While inserting, we use the same position in both matrices.
// While calculating, matrix A stays at the same row while we go through
// all of matrix B.
val rowA = Mux(calculating, calcOffset / matSize.U, counter.value / colDimsA.U)
val rowB = Mux(calculating, (calcOffset / colDimsA.U) % rowDimsA.U, rowA)

matrixA.rowIdx := rowA
matrixA.colIdx := col
matrixA.dataIn := io.dataInA
matrixA.writeEnable := ~calculating

matrixB.rowIdx := Mux(calculating, resultCol.value, row.value)
matrixB.colIdx := col.value
matrixB.rowIdx := rowB
matrixB.colIdx := col
matrixB.dataIn := io.dataInB
matrixB.writeEnable := ~calculating

@@ -42,18 +61,4 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module {

io.dataOut := dotProdCalc.dataOut
io.outputValid := dotProdCalc.outputValid & calculating

when (col.inc()) {
when (calculating) {
when (resultCol.inc()) {
when (row.inc()) {
calculating := false.B
}
}
}.otherwise {
when (row.inc()) {
calculating := true.B
}
}
}
}

+ 7
- 7
src/main/scala/Matrix.scala ファイルの表示

@@ -22,14 +22,14 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module {
val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io))

for(ii <- 0 until rowsDim){
rows(ii).idx := 0.U
rows(ii).dataIn := 0.U
rows(ii).writeEnable := false.B
}
// It doesn't matter what we use to drive idx and dataIn, as long as
// writeEnable is low, so we always use the input values.
rows(ii).idx := io.colIdx
rows(ii).dataIn := io.dataIn

rows(io.rowIdx).writeEnable := io.writeEnable
rows(io.rowIdx).idx := io.colIdx
rows(io.rowIdx).dataIn := io.dataIn
// Enable writing when the current row is selected
rows(ii).writeEnable := io.writeEnable && (io.rowIdx === ii.U)
}

io.dataOut := rows(io.rowIdx).dataOut
}

+ 81
- 0
src/test/scala/MatMulSpec.scala ファイルの表示

@@ -22,6 +22,14 @@ class MatMulSpec extends FlatSpec with Matchers {
)
}

it should "Do a complete matrix multiplication multiple times" in {
wrapTester(
chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
new MultipleFullMatMul(c)
} should be(true)
)
}

it should "Signal at the end" in {
wrapTester(
chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
@@ -92,4 +100,77 @@ object MatMulTests {
step(1)
}
}

class MultipleFullMatMul(c: MatMul) extends PeekPokeTester(c) {
val mA = genMatrix(c.rowDimsA, c.colDimsA)
val mB = genMatrix(c.rowDimsA, c.colDimsA)
val mC = matrixMultiply(mA, mB.transpose)

println("Multiplying")
println(printMatrix(mA))
println("With")
println(printMatrix(mB.transpose))
println("Expecting")
println(printMatrix(mC))

// Input data
for(ii <- 0 until c.colDimsA * c.rowDimsA){

val rowInputIdx = ii / c.colDimsA
val colInputIdx = ii % c.colDimsA

poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx))
poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx))
expect(c.io.outputValid, false, "Valid output during initialization")

step(1)
}

// Perform calculation
for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
for(kk <- 0 until c.colDimsA - 1){
expect(c.io.outputValid, false, "Valid output mistimed")
step(1)
}
expect(c.io.outputValid, true, "Valid output timing is wrong")
expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
step(1)
}

val mX = genMatrix(c.rowDimsA, c.colDimsA)
val mY = genMatrix(c.rowDimsA, c.colDimsA)
val mZ = matrixMultiply(mX, mY.transpose)

println("Starting new calculation")
println("Multiplying")
println(printMatrix(mX))
println("With")
println(printMatrix(mY.transpose))
println("Expecting")
println(printMatrix(mZ))

// Input data
for(ii <- 0 until c.colDimsA * c.rowDimsA){

val rowInputIdx = ii / c.colDimsA
val colInputIdx = ii % c.colDimsA

poke(c.io.dataInA, mX(rowInputIdx)(colInputIdx))
poke(c.io.dataInB, mY(rowInputIdx)(colInputIdx))
expect(c.io.outputValid, false, "Valid output during initialization")

step(1)
}

// Perform calculation
for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
for(kk <- 0 until c.colDimsA - 1){
expect(c.io.outputValid, false, "Valid output mistimed")
step(1)
}
expect(c.io.outputValid, true, "Valid output timing is wrong")
expect(c.io.dataOut, mZ(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
step(1)
}
}
}

読み込み中…
キャンセル
保存