diff --git a/src/main/scala/DotProd.scala b/src/main/scala/DotProd.scala index 08ee57f..7319904 100644 --- a/src/main/scala/DotProd.scala +++ b/src/main/scala/DotProd.scala @@ -17,11 +17,16 @@ class DotProd(val elements: Int) extends Module { val counter = Counter(elements) + + // The sum of the result from each column val accumulator = RegInit(0.U(32.W)) val result = accumulator + io.dataInA * io.dataInB + // Reset the accumulator when we roll over, ready for + // a new calculation accumulator := Mux(counter.inc(), 0.U, result) + io.dataOut := result io.outputValid := counter.inc() } diff --git a/src/main/scala/MatMul.scala b/src/main/scala/MatMul.scala index 617c251..da58940 100644 --- a/src/main/scala/MatMul.scala +++ b/src/main/scala/MatMul.scala @@ -22,18 +22,37 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io val dotProdCalc = Module(new DotProd(colDimsA)).io - val calculating = RegInit(false.B) - val resultCol = Counter(rowDimsA) - val col = Counter(colDimsA) - val row = Counter(rowDimsA) + // The number of elements in the matrices + val matSize = colDimsA * rowDimsA - matrixA.rowIdx := row.value - matrixA.colIdx := col.value + // We use a single counter that is incremented each tick. + // It rolls over when the calculation is finished, and we're + // ready for a new matrix. + val counter = Counter(matSize * (1 + rowDimsA)) + counter.inc() + + // We first go through all cells in the matrices and insert the values. + // We then do the multiplication by multiplying each row of matrix A with + // matrix B. This variable says if we are currently inserting or calculating. + val calculating = counter.value >= matSize.U + val calcOffset = counter.value - matSize.U + + // We go through row for row, so the column is always incremented + val col = counter.value % colDimsA.U + + // While inserting, we use the same position in both matrices. + // While calculating, matrix A stays at the same row while we go through + // all of matrix B. + val rowA = Mux(calculating, calcOffset / matSize.U, counter.value / colDimsA.U) + val rowB = Mux(calculating, (calcOffset / colDimsA.U) % rowDimsA.U, rowA) + + matrixA.rowIdx := rowA + matrixA.colIdx := col matrixA.dataIn := io.dataInA matrixA.writeEnable := ~calculating - matrixB.rowIdx := Mux(calculating, resultCol.value, row.value) - matrixB.colIdx := col.value + matrixB.rowIdx := rowB + matrixB.colIdx := col matrixB.dataIn := io.dataInB matrixB.writeEnable := ~calculating @@ -42,18 +61,4 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { io.dataOut := dotProdCalc.dataOut io.outputValid := dotProdCalc.outputValid & calculating - - when (col.inc()) { - when (calculating) { - when (resultCol.inc()) { - when (row.inc()) { - calculating := false.B - } - } - }.otherwise { - when (row.inc()) { - calculating := true.B - } - } - } } diff --git a/src/main/scala/Matrix.scala b/src/main/scala/Matrix.scala index c1658b4..589ab9a 100644 --- a/src/main/scala/Matrix.scala +++ b/src/main/scala/Matrix.scala @@ -22,14 +22,14 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module { val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io)) for(ii <- 0 until rowsDim){ - rows(ii).idx := 0.U - rows(ii).dataIn := 0.U - rows(ii).writeEnable := false.B - } + // It doesn't matter what we use to drive idx and dataIn, as long as + // writeEnable is low, so we always use the input values. + rows(ii).idx := io.colIdx + rows(ii).dataIn := io.dataIn - rows(io.rowIdx).writeEnable := io.writeEnable - rows(io.rowIdx).idx := io.colIdx - rows(io.rowIdx).dataIn := io.dataIn + // Enable writing when the current row is selected + rows(ii).writeEnable := io.writeEnable && (io.rowIdx === ii.U) + } io.dataOut := rows(io.rowIdx).dataOut } diff --git a/src/test/scala/MatMulSpec.scala b/src/test/scala/MatMulSpec.scala index 17f97b8..7ab7bfb 100644 --- a/src/test/scala/MatMulSpec.scala +++ b/src/test/scala/MatMulSpec.scala @@ -22,6 +22,14 @@ class MatMulSpec extends FlatSpec with Matchers { ) } + it should "Do a complete matrix multiplication multiple times" in { + wrapTester( + chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => + new MultipleFullMatMul(c) + } should be(true) + ) + } + it should "Signal at the end" in { wrapTester( chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => @@ -92,4 +100,77 @@ object MatMulTests { step(1) } } + + class MultipleFullMatMul(c: MatMul) extends PeekPokeTester(c) { + val mA = genMatrix(c.rowDimsA, c.colDimsA) + val mB = genMatrix(c.rowDimsA, c.colDimsA) + val mC = matrixMultiply(mA, mB.transpose) + + println("Multiplying") + println(printMatrix(mA)) + println("With") + println(printMatrix(mB.transpose)) + println("Expecting") + println(printMatrix(mC)) + + // Input data + for(ii <- 0 until c.colDimsA * c.rowDimsA){ + + val rowInputIdx = ii / c.colDimsA + val colInputIdx = ii % c.colDimsA + + poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx)) + poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx)) + expect(c.io.outputValid, false, "Valid output during initialization") + + step(1) + } + + // Perform calculation + for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ + for(kk <- 0 until c.colDimsA - 1){ + expect(c.io.outputValid, false, "Valid output mistimed") + step(1) + } + expect(c.io.outputValid, true, "Valid output timing is wrong") + expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") + step(1) + } + + val mX = genMatrix(c.rowDimsA, c.colDimsA) + val mY = genMatrix(c.rowDimsA, c.colDimsA) + val mZ = matrixMultiply(mX, mY.transpose) + + println("Starting new calculation") + println("Multiplying") + println(printMatrix(mX)) + println("With") + println(printMatrix(mY.transpose)) + println("Expecting") + println(printMatrix(mZ)) + + // Input data + for(ii <- 0 until c.colDimsA * c.rowDimsA){ + + val rowInputIdx = ii / c.colDimsA + val colInputIdx = ii % c.colDimsA + + poke(c.io.dataInA, mX(rowInputIdx)(colInputIdx)) + poke(c.io.dataInB, mY(rowInputIdx)(colInputIdx)) + expect(c.io.outputValid, false, "Valid output during initialization") + + step(1) + } + + // Perform calculation + for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ + for(kk <- 0 until c.colDimsA - 1){ + expect(c.io.outputValid, false, "Valid output mistimed") + step(1) + } + expect(c.io.outputValid, true, "Valid output timing is wrong") + expect(c.io.dataOut, mZ(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") + step(1) + } + } }