| @@ -17,11 +17,16 @@ class DotProd(val elements: Int) extends Module { | |||
| val counter = Counter(elements) | |||
| // The sum of the result from each column | |||
| val accumulator = RegInit(0.U(32.W)) | |||
| val result = accumulator + io.dataInA * io.dataInB | |||
| // Reset the accumulator when we roll over, ready for | |||
| // a new calculation | |||
| accumulator := Mux(counter.inc(), 0.U, result) | |||
| io.dataOut := result | |||
| io.outputValid := counter.inc() | |||
| } | |||
| @@ -22,18 +22,37 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { | |||
| val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io | |||
| val dotProdCalc = Module(new DotProd(colDimsA)).io | |||
| val calculating = RegInit(false.B) | |||
| val resultCol = Counter(rowDimsA) | |||
| val col = Counter(colDimsA) | |||
| val row = Counter(rowDimsA) | |||
| // The number of elements in the matrices | |||
| val matSize = colDimsA * rowDimsA | |||
| matrixA.rowIdx := row.value | |||
| matrixA.colIdx := col.value | |||
| // We use a single counter that is incremented each tick. | |||
| // It rolls over when the calculation is finished, and we're | |||
| // ready for a new matrix. | |||
| val counter = Counter(matSize * (1 + rowDimsA)) | |||
| counter.inc() | |||
| // We first go through all cells in the matrices and insert the values. | |||
| // We then do the multiplication by multiplying each row of matrix A with | |||
| // matrix B. This variable says if we are currently inserting or calculating. | |||
| val calculating = counter.value >= matSize.U | |||
| val calcOffset = counter.value - matSize.U | |||
| // We go through row for row, so the column is always incremented | |||
| val col = counter.value % colDimsA.U | |||
| // While inserting, we use the same position in both matrices. | |||
| // While calculating, matrix A stays at the same row while we go through | |||
| // all of matrix B. | |||
| val rowA = Mux(calculating, calcOffset / matSize.U, counter.value / colDimsA.U) | |||
| val rowB = Mux(calculating, (calcOffset / colDimsA.U) % rowDimsA.U, rowA) | |||
| matrixA.rowIdx := rowA | |||
| matrixA.colIdx := col | |||
| matrixA.dataIn := io.dataInA | |||
| matrixA.writeEnable := ~calculating | |||
| matrixB.rowIdx := Mux(calculating, resultCol.value, row.value) | |||
| matrixB.colIdx := col.value | |||
| matrixB.rowIdx := rowB | |||
| matrixB.colIdx := col | |||
| matrixB.dataIn := io.dataInB | |||
| matrixB.writeEnable := ~calculating | |||
| @@ -42,18 +61,4 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { | |||
| io.dataOut := dotProdCalc.dataOut | |||
| io.outputValid := dotProdCalc.outputValid & calculating | |||
| when (col.inc()) { | |||
| when (calculating) { | |||
| when (resultCol.inc()) { | |||
| when (row.inc()) { | |||
| calculating := false.B | |||
| } | |||
| } | |||
| }.otherwise { | |||
| when (row.inc()) { | |||
| calculating := true.B | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -22,14 +22,14 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module { | |||
| val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io)) | |||
| for(ii <- 0 until rowsDim){ | |||
| rows(ii).idx := 0.U | |||
| rows(ii).dataIn := 0.U | |||
| rows(ii).writeEnable := false.B | |||
| } | |||
| // It doesn't matter what we use to drive idx and dataIn, as long as | |||
| // writeEnable is low, so we always use the input values. | |||
| rows(ii).idx := io.colIdx | |||
| rows(ii).dataIn := io.dataIn | |||
| rows(io.rowIdx).writeEnable := io.writeEnable | |||
| rows(io.rowIdx).idx := io.colIdx | |||
| rows(io.rowIdx).dataIn := io.dataIn | |||
| // Enable writing when the current row is selected | |||
| rows(ii).writeEnable := io.writeEnable && (io.rowIdx === ii.U) | |||
| } | |||
| io.dataOut := rows(io.rowIdx).dataOut | |||
| } | |||
| @@ -22,6 +22,14 @@ class MatMulSpec extends FlatSpec with Matchers { | |||
| ) | |||
| } | |||
| it should "Do a complete matrix multiplication multiple times" in { | |||
| wrapTester( | |||
| chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => | |||
| new MultipleFullMatMul(c) | |||
| } should be(true) | |||
| ) | |||
| } | |||
| it should "Signal at the end" in { | |||
| wrapTester( | |||
| chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => | |||
| @@ -92,4 +100,77 @@ object MatMulTests { | |||
| step(1) | |||
| } | |||
| } | |||
| class MultipleFullMatMul(c: MatMul) extends PeekPokeTester(c) { | |||
| val mA = genMatrix(c.rowDimsA, c.colDimsA) | |||
| val mB = genMatrix(c.rowDimsA, c.colDimsA) | |||
| val mC = matrixMultiply(mA, mB.transpose) | |||
| println("Multiplying") | |||
| println(printMatrix(mA)) | |||
| println("With") | |||
| println(printMatrix(mB.transpose)) | |||
| println("Expecting") | |||
| println(printMatrix(mC)) | |||
| // Input data | |||
| for(ii <- 0 until c.colDimsA * c.rowDimsA){ | |||
| val rowInputIdx = ii / c.colDimsA | |||
| val colInputIdx = ii % c.colDimsA | |||
| poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx)) | |||
| poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx)) | |||
| expect(c.io.outputValid, false, "Valid output during initialization") | |||
| step(1) | |||
| } | |||
| // Perform calculation | |||
| for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ | |||
| for(kk <- 0 until c.colDimsA - 1){ | |||
| expect(c.io.outputValid, false, "Valid output mistimed") | |||
| step(1) | |||
| } | |||
| expect(c.io.outputValid, true, "Valid output timing is wrong") | |||
| expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") | |||
| step(1) | |||
| } | |||
| val mX = genMatrix(c.rowDimsA, c.colDimsA) | |||
| val mY = genMatrix(c.rowDimsA, c.colDimsA) | |||
| val mZ = matrixMultiply(mX, mY.transpose) | |||
| println("Starting new calculation") | |||
| println("Multiplying") | |||
| println(printMatrix(mX)) | |||
| println("With") | |||
| println(printMatrix(mY.transpose)) | |||
| println("Expecting") | |||
| println(printMatrix(mZ)) | |||
| // Input data | |||
| for(ii <- 0 until c.colDimsA * c.rowDimsA){ | |||
| val rowInputIdx = ii / c.colDimsA | |||
| val colInputIdx = ii % c.colDimsA | |||
| poke(c.io.dataInA, mX(rowInputIdx)(colInputIdx)) | |||
| poke(c.io.dataInB, mY(rowInputIdx)(colInputIdx)) | |||
| expect(c.io.outputValid, false, "Valid output during initialization") | |||
| step(1) | |||
| } | |||
| // Perform calculation | |||
| for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ | |||
| for(kk <- 0 until c.colDimsA - 1){ | |||
| expect(c.io.outputValid, false, "Valid output mistimed") | |||
| step(1) | |||
| } | |||
| expect(c.io.outputValid, true, "Valid output timing is wrong") | |||
| expect(c.io.dataOut, mZ(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") | |||
| step(1) | |||
| } | |||
| } | |||
| } | |||