| @@ -17,11 +17,16 @@ class DotProd(val elements: Int) extends Module { | |||||
| val counter = Counter(elements) | val counter = Counter(elements) | ||||
| // The sum of the result from each column | |||||
| val accumulator = RegInit(0.U(32.W)) | val accumulator = RegInit(0.U(32.W)) | ||||
| val result = accumulator + io.dataInA * io.dataInB | val result = accumulator + io.dataInA * io.dataInB | ||||
| // Reset the accumulator when we roll over, ready for | |||||
| // a new calculation | |||||
| accumulator := Mux(counter.inc(), 0.U, result) | accumulator := Mux(counter.inc(), 0.U, result) | ||||
| io.dataOut := result | io.dataOut := result | ||||
| io.outputValid := counter.inc() | io.outputValid := counter.inc() | ||||
| } | } | ||||
| @@ -22,18 +22,37 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { | |||||
| val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io | val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io | ||||
| val dotProdCalc = Module(new DotProd(colDimsA)).io | val dotProdCalc = Module(new DotProd(colDimsA)).io | ||||
| val calculating = RegInit(false.B) | |||||
| val resultCol = Counter(rowDimsA) | |||||
| val col = Counter(colDimsA) | |||||
| val row = Counter(rowDimsA) | |||||
| // The number of elements in the matrices | |||||
| val matSize = colDimsA * rowDimsA | |||||
| matrixA.rowIdx := row.value | |||||
| matrixA.colIdx := col.value | |||||
| // We use a single counter that is incremented each tick. | |||||
| // It rolls over when the calculation is finished, and we're | |||||
| // ready for a new matrix. | |||||
| val counter = Counter(matSize * (1 + rowDimsA)) | |||||
| counter.inc() | |||||
| // We first go through all cells in the matrices and insert the values. | |||||
| // We then do the multiplication by multiplying each row of matrix A with | |||||
| // matrix B. This variable says if we are currently inserting or calculating. | |||||
| val calculating = counter.value >= matSize.U | |||||
| val calcOffset = counter.value - matSize.U | |||||
| // We go through row for row, so the column is always incremented | |||||
| val col = counter.value % colDimsA.U | |||||
| // While inserting, we use the same position in both matrices. | |||||
| // While calculating, matrix A stays at the same row while we go through | |||||
| // all of matrix B. | |||||
| val rowA = Mux(calculating, calcOffset / matSize.U, counter.value / colDimsA.U) | |||||
| val rowB = Mux(calculating, (calcOffset / colDimsA.U) % rowDimsA.U, rowA) | |||||
| matrixA.rowIdx := rowA | |||||
| matrixA.colIdx := col | |||||
| matrixA.dataIn := io.dataInA | matrixA.dataIn := io.dataInA | ||||
| matrixA.writeEnable := ~calculating | matrixA.writeEnable := ~calculating | ||||
| matrixB.rowIdx := Mux(calculating, resultCol.value, row.value) | |||||
| matrixB.colIdx := col.value | |||||
| matrixB.rowIdx := rowB | |||||
| matrixB.colIdx := col | |||||
| matrixB.dataIn := io.dataInB | matrixB.dataIn := io.dataInB | ||||
| matrixB.writeEnable := ~calculating | matrixB.writeEnable := ~calculating | ||||
| @@ -42,18 +61,4 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { | |||||
| io.dataOut := dotProdCalc.dataOut | io.dataOut := dotProdCalc.dataOut | ||||
| io.outputValid := dotProdCalc.outputValid & calculating | io.outputValid := dotProdCalc.outputValid & calculating | ||||
| when (col.inc()) { | |||||
| when (calculating) { | |||||
| when (resultCol.inc()) { | |||||
| when (row.inc()) { | |||||
| calculating := false.B | |||||
| } | |||||
| } | |||||
| }.otherwise { | |||||
| when (row.inc()) { | |||||
| calculating := true.B | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| @@ -22,14 +22,14 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module { | |||||
| val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io)) | val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io)) | ||||
| for(ii <- 0 until rowsDim){ | for(ii <- 0 until rowsDim){ | ||||
| rows(ii).idx := 0.U | |||||
| rows(ii).dataIn := 0.U | |||||
| rows(ii).writeEnable := false.B | |||||
| } | |||||
| // It doesn't matter what we use to drive idx and dataIn, as long as | |||||
| // writeEnable is low, so we always use the input values. | |||||
| rows(ii).idx := io.colIdx | |||||
| rows(ii).dataIn := io.dataIn | |||||
| rows(io.rowIdx).writeEnable := io.writeEnable | |||||
| rows(io.rowIdx).idx := io.colIdx | |||||
| rows(io.rowIdx).dataIn := io.dataIn | |||||
| // Enable writing when the current row is selected | |||||
| rows(ii).writeEnable := io.writeEnable && (io.rowIdx === ii.U) | |||||
| } | |||||
| io.dataOut := rows(io.rowIdx).dataOut | io.dataOut := rows(io.rowIdx).dataOut | ||||
| } | } | ||||
| @@ -22,6 +22,14 @@ class MatMulSpec extends FlatSpec with Matchers { | |||||
| ) | ) | ||||
| } | } | ||||
| it should "Do a complete matrix multiplication multiple times" in { | |||||
| wrapTester( | |||||
| chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => | |||||
| new MultipleFullMatMul(c) | |||||
| } should be(true) | |||||
| ) | |||||
| } | |||||
| it should "Signal at the end" in { | it should "Signal at the end" in { | ||||
| wrapTester( | wrapTester( | ||||
| chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => | chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => | ||||
| @@ -92,4 +100,77 @@ object MatMulTests { | |||||
| step(1) | step(1) | ||||
| } | } | ||||
| } | } | ||||
| class MultipleFullMatMul(c: MatMul) extends PeekPokeTester(c) { | |||||
| val mA = genMatrix(c.rowDimsA, c.colDimsA) | |||||
| val mB = genMatrix(c.rowDimsA, c.colDimsA) | |||||
| val mC = matrixMultiply(mA, mB.transpose) | |||||
| println("Multiplying") | |||||
| println(printMatrix(mA)) | |||||
| println("With") | |||||
| println(printMatrix(mB.transpose)) | |||||
| println("Expecting") | |||||
| println(printMatrix(mC)) | |||||
| // Input data | |||||
| for(ii <- 0 until c.colDimsA * c.rowDimsA){ | |||||
| val rowInputIdx = ii / c.colDimsA | |||||
| val colInputIdx = ii % c.colDimsA | |||||
| poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx)) | |||||
| poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx)) | |||||
| expect(c.io.outputValid, false, "Valid output during initialization") | |||||
| step(1) | |||||
| } | |||||
| // Perform calculation | |||||
| for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ | |||||
| for(kk <- 0 until c.colDimsA - 1){ | |||||
| expect(c.io.outputValid, false, "Valid output mistimed") | |||||
| step(1) | |||||
| } | |||||
| expect(c.io.outputValid, true, "Valid output timing is wrong") | |||||
| expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") | |||||
| step(1) | |||||
| } | |||||
| val mX = genMatrix(c.rowDimsA, c.colDimsA) | |||||
| val mY = genMatrix(c.rowDimsA, c.colDimsA) | |||||
| val mZ = matrixMultiply(mX, mY.transpose) | |||||
| println("Starting new calculation") | |||||
| println("Multiplying") | |||||
| println(printMatrix(mX)) | |||||
| println("With") | |||||
| println(printMatrix(mY.transpose)) | |||||
| println("Expecting") | |||||
| println(printMatrix(mZ)) | |||||
| // Input data | |||||
| for(ii <- 0 until c.colDimsA * c.rowDimsA){ | |||||
| val rowInputIdx = ii / c.colDimsA | |||||
| val colInputIdx = ii % c.colDimsA | |||||
| poke(c.io.dataInA, mX(rowInputIdx)(colInputIdx)) | |||||
| poke(c.io.dataInB, mY(rowInputIdx)(colInputIdx)) | |||||
| expect(c.io.outputValid, false, "Valid output during initialization") | |||||
| step(1) | |||||
| } | |||||
| // Perform calculation | |||||
| for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ | |||||
| for(kk <- 0 until c.colDimsA - 1){ | |||||
| expect(c.io.outputValid, false, "Valid output mistimed") | |||||
| step(1) | |||||
| } | |||||
| expect(c.io.outputValid, true, "Valid output timing is wrong") | |||||
| expect(c.io.dataOut, mZ(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated") | |||||
| step(1) | |||||
| } | |||||
| } | |||||
| } | } | ||||