package Ex0

import chisel3._
import chisel3.iotesters.PeekPokeTester
import org.scalatest.{Matchers, FlatSpec}
import TestUtils._

/** ScalaTest entry point for the [[MatMul]] exercise.
  *
  * Each case elaborates a `MatMul(rowDims, colDims)` and runs one of the
  * peek/poke testers defined in [[MatMulTests]] against it.
  */
class MatMulSpec extends FlatSpec with Matchers {
  import MatMulTests._

  val rowDims = 3
  val colDims = 7

  behavior of "MatMul"

  it should "Do a complete matrix multiplication" in {
    wrapTester(
      chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
        new FullMatMul(c)
      } should be(true)
    )
  }

  it should "Do a complete matrix multiplication multiple times" in {
    wrapTester(
      chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
        new MultipleFullMatMul(c)
      } should be(true)
    )
  }

  it should "Signal at the end" in {
    wrapTester(
      chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
        new SignalWhenDone(c)
      } should be(true)
    )
  }
}

/** Peek/poke testers for [[MatMul]].
  *
  * The expectations below encode this protocol:
  *  1. Load phase: for `rowDimsA * colDimsA` cycles, one element of A and the
  *     element of B at the same (row, col) are presented on `dataInA` /
  *     `dataInB` in row-major order; `outputValid` must stay low.
  *  2. Compute phase: for each of the `rowDimsA * rowDimsA` cells of
  *     `A * B^T` (row-major), `outputValid` is low for `colDimsA - 1` cycles
  *     and high on the `colDimsA`-th, when `dataOut` must hold that cell.
  */
object MatMulTests {

  val rand = new scala.util.Random(100)

  /** Shared load/compute driving logic for the MatMul testers below. */
  abstract class MatMulTester(c: MatMul) extends PeekPokeTester(c) {

    /** Steps through the load phase without driving the data inputs,
      * expecting `outputValid` to stay low the whole time.
      */
    protected def idleThroughLoadPhase(): Unit =
      for (_ <- 0 until c.colDimsA * c.rowDimsA) {
        expect(c.io.outputValid, false, "Valid output during initialization")
        step(1)
      }

    /** Steps through the compute phase, checking that `outputValid` is raised
      * exactly once every `colDimsA` cycles.
      *
      * @param onValid called with the row-major index of the result cell on
      *                the cycle that cell is flagged valid, before stepping.
      */
    protected def checkComputePhase(onValid: Int => Unit): Unit =
      for (cell <- 0 until c.rowDimsA * c.rowDimsA) {
        for (_ <- 0 until c.colDimsA - 1) {
          expect(c.io.outputValid, false, "Valid output mistimed")
          step(1)
        }
        expect(c.io.outputValid, true, "Valid output timing is wrong")
        onValid(cell)
        step(1)
      }

    /** Generates random A and B (both `rowDimsA x colDimsA`), loads them, and
      * checks every cell of `A * B^T` as it is emitted.
      */
    protected def multiplyRandomMatrices(): Unit = {
      val mA = genMatrix(c.rowDimsA, c.colDimsA)
      val mB = genMatrix(c.rowDimsA, c.colDimsA)
      val mC = matrixMultiply(mA, mB.transpose)

      println("Multiplying")
      println(printMatrix(mA))
      println("With")
      println(printMatrix(mB.transpose))
      println("Expecting")
      println(printMatrix(mC))

      // Load phase: stream A and B element by element, row-major.
      for (ii <- 0 until c.colDimsA * c.rowDimsA) {
        val row = ii / c.colDimsA
        val col = ii % c.colDimsA
        poke(c.io.dataInA, mA(row)(col))
        poke(c.io.dataInB, mB(row)(col))
        expect(c.io.outputValid, false, "Valid output during initialization")
        step(1)
      }

      // Compute phase: the result is rowDimsA x rowDimsA, emitted row-major.
      checkComputePhase { cell =>
        expect(c.io.dataOut, mC(cell / c.rowDimsA)(cell % c.rowDimsA), "Wrong value calculated")
      }
    }
  }

  /** Checks only the timing of `outputValid`; the data inputs are never driven. */
  class SignalWhenDone(c: MatMul) extends MatMulTester(c) {
    idleThroughLoadPhase()
    checkComputePhase(_ => ())
  }

  /** Performs one complete multiplication and checks every output value. */
  class FullMatMul(c: MatMul) extends MatMulTester(c) {
    multiplyRandomMatrices()
  }

  /** Performs two back-to-back multiplications, checking that the module can
    * accept a new computation immediately after finishing the previous one.
    */
  class MultipleFullMatMul(c: MatMul) extends MatMulTester(c) {
    multiplyRandomMatrices()
    println("Starting new calculation")
    multiplyRandomMatrices()
  }
}