From d235c5697bc4a7d8b41ffb481318773251c71df6 Mon Sep 17 00:00:00 2001 From: Sindre Stephansen Date: Wed, 28 Aug 2019 08:11:52 +0200 Subject: [PATCH] First solution to EX0 --- src/main/scala/DotProd.scala | 14 +++---- src/main/scala/MatMul.scala | 66 ++++++++++++++++++-------------- src/main/scala/Matrix.scala | 16 ++++---- src/main/scala/Vector.scala | 14 +------ src/main/scala/main.scala | 11 +++--- src/test/scala/DotProdSpec.scala | 43 +++++++++++++++++++++ src/test/scala/MatMulSpec.scala | 24 +++++++++++- src/test/scala/MatrixSpec.scala | 3 +- 8 files changed, 123 insertions(+), 68 deletions(-) diff --git a/src/main/scala/DotProd.scala b/src/main/scala/DotProd.scala index 9a65cc2..08ee57f 100644 --- a/src/main/scala/DotProd.scala +++ b/src/main/scala/DotProd.scala @@ -16,16 +16,12 @@ class DotProd(val elements: Int) extends Module { ) - /** - * Your code here - */ val counter = Counter(elements) - val accumulator = RegInit(UInt(32.W), 0.U) + val accumulator = RegInit(0.U(32.W)) - // Please don't manually implement product! - val product = io.dataInA * io.dataInB + val result = accumulator + io.dataInA * io.dataInB - // placeholder - io.dataOut := 0.U - io.outputValid := false.B + accumulator := Mux(counter.inc(), 0.U, result) + io.dataOut := result + io.outputValid := counter.inc() } diff --git a/src/main/scala/MatMul.scala b/src/main/scala/MatMul.scala index d69b41b..617c251 100644 --- a/src/main/scala/MatMul.scala +++ b/src/main/scala/MatMul.scala @@ -1,10 +1,11 @@ package Ex0 +import scala.math.max + import chisel3._ import chisel3.util.Counter -import chisel3.experimental.MultiIOModule -class MatMul(val rowDimsA: Int, val colDimsA: Int) extends MultiIOModule { +class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module { val io = IO( new Bundle { @@ -16,36 +17,43 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends MultiIOModule { } ) - val debug = IO( - new Bundle { - val myDebugSignal = Output(Bool()) - } - ) - - /** - * Your code here - */ val matrixA = Module(new Matrix(rowDimsA, colDimsA)).io val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io val dotProdCalc = Module(new DotProd(colDimsA)).io - matrixA.dataIn := 0.U - matrixA.rowIdx := 0.U - matrixA.colIdx := 0.U - matrixA.writeEnable := false.B - - matrixB.rowIdx := 0.U - matrixB.colIdx := 0.U - matrixB.dataIn := 0.U - matrixB.writeEnable := false.B - - dotProdCalc.dataInA := 0.U - dotProdCalc.dataInB := 0.U - - io.dataOut := 0.U - io.outputValid := false.B - - - debug.myDebugSignal := false.B + val calculating = RegInit(false.B) + val resultCol = Counter(rowDimsA) + val col = Counter(colDimsA) + val row = Counter(rowDimsA) + + matrixA.rowIdx := row.value + matrixA.colIdx := col.value + matrixA.dataIn := io.dataInA + matrixA.writeEnable := ~calculating + + matrixB.rowIdx := Mux(calculating, resultCol.value, row.value) + matrixB.colIdx := col.value + matrixB.dataIn := io.dataInB + matrixB.writeEnable := ~calculating + + dotProdCalc.dataInA := matrixA.dataOut + dotProdCalc.dataInB := matrixB.dataOut + + io.dataOut := dotProdCalc.dataOut + io.outputValid := dotProdCalc.outputValid & calculating + + when (col.inc()) { + when (calculating) { + when (resultCol.inc()) { + when (row.inc()) { + calculating := false.B + } + } + }.otherwise { + when (row.inc()) { + calculating := true.B + } + } + } } diff --git a/src/main/scala/Matrix.scala b/src/main/scala/Matrix.scala index 126f436..c1658b4 100644 --- a/src/main/scala/Matrix.scala +++ b/src/main/scala/Matrix.scala @@ -18,18 +18,18 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module { } ) - /** - * Your code here - */ - // Creates a vector of zero-initialized registers - val rows = Vec.fill(rowsDim)(Module(new Vector(colsDim)).io) + val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io)) - // placeholders - io.dataOut := 0.U for(ii <- 0 until rowsDim){ + rows(ii).idx := 0.U rows(ii).dataIn := 0.U rows(ii).writeEnable := false.B - rows(ii).idx := 0.U } + + rows(io.rowIdx).writeEnable := io.writeEnable + rows(io.rowIdx).idx := io.colIdx + rows(io.rowIdx).dataIn := io.dataIn + + io.dataOut := rows(io.rowIdx).dataOut } diff --git a/src/main/scala/Vector.scala b/src/main/scala/Vector.scala index 10afbc7..6533ee1 100644 --- a/src/main/scala/Vector.scala +++ b/src/main/scala/Vector.scala @@ -15,21 +15,11 @@ class Vector(val elements: Int) extends Module { } ) - // Creates a vector of zero-initialized registers val internalVector = RegInit(VecInit(List.fill(elements)(0.U(32.W)))) - when(io.writeEnable){ - // TODO: - // When writeEnable is true the content of internalVector at the index specified - // by idx should be set to the value of io.dataIn + internalVector(io.idx) := io.dataIn } - // In this case we don't want an otherwise block, in writeEnable is low we don't change - // anything - - // TODO: - // io.dataOut should be driven by the contents of internalVector at the index specified - // by idx - io.dataOut := 0.U + io.dataOut := internalVector(io.idx) } diff --git a/src/main/scala/main.scala b/src/main/scala/main.scala index 9b0a7c9..11cda3d 100644 --- a/src/main/scala/main.scala +++ b/src/main/scala/main.scala @@ -7,16 +7,15 @@ object main { val s = """ | Attempting to "run" a chisel program is rather meaningless. | Instead, try running the tests, for instance with "test" or "testOnly Examples.MyIncrementTest - | - | If you want to create chisel graphs, simply remove this message and comment in the code underneath + | + | If you want to create chisel graphs, simply remove this message and comment in the code underneath | to generate the modules you're interested in. """.stripMargin - println(s) + //println(s) } // Uncomment to dump .fir file - // val f = new File("MatMul.fir") - // chisel3.Driver.dumpFirrtl(chisel3.Driver.elaborate(() => new MatMul(5, 4)), Option(f)) + val f = new File("MatMul.fir") + chisel3.Driver.dumpFirrtl(chisel3.Driver.elaborate(() => new MatMul(5, 4)), Option(f)) } - diff --git a/src/test/scala/DotProdSpec.scala b/src/test/scala/DotProdSpec.scala index d19d1ee..c0b09dc 100644 --- a/src/test/scala/DotProdSpec.scala +++ b/src/test/scala/DotProdSpec.scala @@ -31,6 +31,15 @@ class DotProdSpec extends FlatSpec with Matchers { } + it should "Calculate the correct output multiple times" in { + wrapTester( + chisel3.iotesters.Driver(() => new DotProd(elements)) { c => + new CalculatesMultiple(c) + } should be(true) + ) + } + + it should "Calculate the correct output and signal when appropriate" in { wrapTester( chisel3.iotesters.Driver(() => new DotProd(elements)) { c => @@ -81,6 +90,40 @@ object DotProdTests { } + class CalculatesMultiple(c: DotProd) extends PeekPokeTester(c) { + + val inputsA = List.fill(c.elements)(rand.nextInt(10)) + val inputsB = List.fill(c.elements)(rand.nextInt(10)) + val inputsC = List.fill(c.elements)(rand.nextInt(10)) + val inputsD = List.fill(c.elements)(rand.nextInt(10)) + + println("runnign dot prod calc with multiple inputs:") + println(inputsA.mkString("[", "] [", "]")) + println(inputsB.mkString("[", "] [", "]")) + println() + println(inputsC.mkString("[", "] [", "]")) + println(inputsD.mkString("[", "] [", "]")) + val expectedOutput1 = (for ((a, b) <- inputsA zip inputsB) yield a * b) sum + val expectedOutput2 = (for ((a, b) <- inputsC zip inputsD) yield a * b) sum + + for(ii <- 0 until c.elements){ + poke(c.io.dataInA, inputsA(ii)) + poke(c.io.dataInB, inputsB(ii)) + if(ii == c.elements - 1) + expect(c.io.dataOut, expectedOutput1) + step(1) + } + + for(ii <- 0 until c.elements){ + poke(c.io.dataInA, inputsC(ii)) + poke(c.io.dataInB, inputsD(ii)) + if(ii == c.elements - 1) + expect(c.io.dataOut, expectedOutput2) + step(1) + } + } + + class CalculatesCorrectResultAndSignals(c: DotProd) extends PeekPokeTester(c) { val inputsA = List.fill(c.elements)(rand.nextInt(10)) diff --git a/src/test/scala/MatMulSpec.scala b/src/test/scala/MatMulSpec.scala index 9e1c0a5..17f97b8 100644 --- a/src/test/scala/MatMulSpec.scala +++ b/src/test/scala/MatMulSpec.scala @@ -14,25 +14,45 @@ class MatMulSpec extends FlatSpec with Matchers { behavior of "MatMul" - it should "Do shit" in { + it should "Do a complete matrix multiplication" in { wrapTester( chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => new FullMatMul(c) } should be(true) ) } + + it should "Signal at the end" in { + wrapTester( + chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c => + new SignalWhenDone(c) + } should be(true) + ) + } } object MatMulTests { val rand = new scala.util.Random(100) - class TestExample(c: MatMul) extends PeekPokeTester(c) { + class SignalWhenDone(c: MatMul) extends PeekPokeTester(c) { val mA = genMatrix(c.rowDimsA, c.colDimsA) val mB = genMatrix(c.rowDimsA, c.colDimsA) val mC = matrixMultiply(mA, mB.transpose) + for(ii <- 0 until c.colDimsA * c.rowDimsA){ + expect(c.io.outputValid, false, "Valid output during initialization") + step(1) + } + for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){ + for(kk <- 0 until c.colDimsA - 1){ + expect(c.io.outputValid, false, "Valid output mistimed") + step(1) + } + expect(c.io.outputValid, true, "Valid output timing is wrong") + step(1) + } } class FullMatMul(c: MatMul) extends PeekPokeTester(c) { diff --git a/src/test/scala/MatrixSpec.scala b/src/test/scala/MatrixSpec.scala index 580881e..01839f6 100644 --- a/src/test/scala/MatrixSpec.scala +++ b/src/test/scala/MatrixSpec.scala @@ -12,7 +12,6 @@ class MatrixSpec extends FlatSpec with Matchers { behavior of "Matrix" - val rand = new scala.util.Random(100) val rowDims = 5 val colDims = 3 @@ -37,7 +36,7 @@ class MatrixSpec extends FlatSpec with Matchers { it should "Retain its contents when writeEnable is low" in { wrapTester( chisel3.iotesters.Driver(() => new Matrix(rowDims, colDims)) { c => - new UpdatesData(c) + new RetainsData(c) } should be(true) ) }