|
- package Ex0
-
- import chisel3._
- import chisel3.iotesters.PeekPokeTester
- import org.scalatest.{Matchers, FlatSpec}
- import TestUtils._
-
- class MatMulSpec extends FlatSpec with Matchers {
-   import MatMulTests._
-
-   val rowDims = 3
-   val colDims = 7
-
-   /** Builds a fresh `MatMul(rowDims, colDims)` and drives it with the tester produced by
-     * `mkTester`. Asserts that the iotesters driver reports success (returns `true`).
-     * The assertion is evaluated inside `wrapTester`, which comes from TestUtils; its exact
-     * behaviour is not visible here (presumably it adds failure reporting — confirm).
-     */
-   private def runOnFreshDut(mkTester: MatMul => PeekPokeTester[MatMul]) =
-     wrapTester(
-       chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims))(mkTester) should be(true)
-     )
-
-   behavior of "MatMul"
-
-   it should "Do a complete matrix multiplication" in {
-     runOnFreshDut(dut => new FullMatMul(dut))
-   }
-
-   it should "Do a complete matrix multiplication multiple times" in {
-     runOnFreshDut(dut => new MultipleFullMatMul(dut))
-   }
-
-   it should "Signal at the end" in {
-     runOnFreshDut(dut => new SignalWhenDone(dut))
-   }
- }
-
- object MatMulTests {
-
-   val rand = new scala.util.Random(100)
-
-   /** Shared driving/checking logic for the MatMul testers below.
-     *
-     * Timing contract checked by every tester:
-     *  - load phase: `rowDimsA * colDimsA` cycles. One element of A and one of B are poked per
-     *    cycle in row-major order, and `outputValid` must stay low throughout.
-     *  - compute phase: for each of the `rowDimsA * rowDimsA` result elements (row-major),
-     *    `outputValid` must be low for `colDimsA - 1` cycles and high on the next cycle.
-     */
-   abstract class MatMulTester(c: MatMul) extends PeekPokeTester(c) {
-
-     /** Cycles needed to stream both operands in, one element of each per cycle. */
-     protected val loadCycles: Int = c.rowDimsA * c.colDimsA
-
-     /** Number of result elements: A (rowDimsA x colDimsA) times B^T (colDimsA x rowDimsA). */
-     protected val outputCount: Int = c.rowDimsA * c.rowDimsA
-
-     /** Checks the cycles of one dot product.
-       * `outputValid` must be low for `colDimsA - 1` cycles and then high. Returns while still
-       * on the valid cycle (no step past it), so the caller can inspect `dataOut` first.
-       */
-     protected def awaitValidOutput(): Unit = {
-       for (_ <- 0 until c.colDimsA - 1) {
-         expect(c.io.outputValid, false, "Valid output mistimed")
-         step(1)
-       }
-       expect(c.io.outputValid, true, "Valid output timing is wrong")
-     }
-
-     /** Runs one complete multiplication.
-       * Generates random operands, prints them together with the expected product, streams
-       * them into the DUT, and checks every result element and its valid timing.
-       *
-       * @param header optional line printed before the operand dump
-       * @return the operands and expected product as `(A, B, A * B^T)`
-       */
-     protected def runRound(header: Option[String] = None) = {
-       val mA = genMatrix(c.rowDimsA, c.colDimsA)
-       // B is generated with A's shape and streamed untransposed; the expected product uses B^T.
-       val mB = genMatrix(c.rowDimsA, c.colDimsA)
-       val mC = matrixMultiply(mA, mB.transpose)
-
-       header.foreach(line => println(line))
-       println("Multiplying")
-       println(printMatrix(mA))
-       println("With")
-       println(printMatrix(mB.transpose))
-       println("Expecting")
-       println(printMatrix(mC))
-
-       // Load phase: one element of each operand per cycle, row-major.
-       for (ii <- 0 until loadCycles) {
-         val row = ii / c.colDimsA
-         val col = ii % c.colDimsA
-         poke(c.io.dataInA, mA(row)(col))
-         poke(c.io.dataInB, mB(row)(col))
-         expect(c.io.outputValid, false, "Valid output during initialization")
-         step(1)
-       }
-
-       // Compute phase: results are checked row-major over the rowDimsA x rowDimsA product.
-       for (ii <- 0 until outputCount) {
-         awaitValidOutput()
-         expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
-         step(1)
-       }
-
-       (mA, mB, mC)
-     }
-   }
-
-   /** Checks only the `outputValid` timing over one full run. The data inputs are never poked. */
-   class SignalWhenDone(c: MatMul) extends MatMulTester(c) {
-     // NOTE(review): these operands are never driven into the DUT. They are kept so that these
-     // public fields remain, and so genMatrix is called as often as before (it may draw from a
-     // shared RNG in TestUtils — not visible here).
-     val mA = genMatrix(c.rowDimsA, c.colDimsA)
-     val mB = genMatrix(c.rowDimsA, c.colDimsA)
-     val mC = matrixMultiply(mA, mB.transpose)
-
-     for (_ <- 0 until loadCycles) {
-       expect(c.io.outputValid, false, "Valid output during initialization")
-       step(1)
-     }
-
-     for (_ <- 0 until outputCount) {
-       awaitValidOutput()
-       step(1)
-     }
-   }
-
-   /** One complete multiplication with both value and timing checks. */
-   class FullMatMul(c: MatMul) extends MatMulTester(c) {
-     val (mA, mB, mC) = runRound()
-   }
-
-   /** Two complete multiplications back to back on the same DUT, with no reset in between. */
-   class MultipleFullMatMul(c: MatMul) extends MatMulTester(c) {
-     val (mA, mB, mC) = runRound()
-     val (mX, mY, mZ) = runRound(header = Some("Starting new calculation"))
-   }
- }
|