瀏覽代碼

First solution to EX0

sindre-ex0
Sindre Stephansen 6 年之前
父節點
當前提交
d235c5697b
共有 8 個文件被更改,包括 123 次插入68 次删除
  1. +5
    -9
      src/main/scala/DotProd.scala
  2. +37
    -29
      src/main/scala/MatMul.scala
  3. +8
    -8
      src/main/scala/Matrix.scala
  4. +2
    -12
      src/main/scala/Vector.scala
  5. +5
    -6
      src/main/scala/main.scala
  6. +43
    -0
      src/test/scala/DotProdSpec.scala
  7. +22
    -2
      src/test/scala/MatMulSpec.scala
  8. +1
    -2
      src/test/scala/MatrixSpec.scala

+ 5
- 9
src/main/scala/DotProd.scala 查看文件

@@ -16,16 +16,12 @@ class DotProd(val elements: Int) extends Module {
)


/**
* Your code here
*/
val counter = Counter(elements)
val accumulator = RegInit(UInt(32.W), 0.U)
val accumulator = RegInit(0.U(32.W))

// Please don't manually implement product!
val product = io.dataInA * io.dataInB
val result = accumulator + io.dataInA * io.dataInB

// placeholder
io.dataOut := 0.U
io.outputValid := false.B
accumulator := Mux(counter.inc(), 0.U, result)
io.dataOut := result
io.outputValid := counter.inc()
}

+ 37
- 29
src/main/scala/MatMul.scala 查看文件

@@ -1,10 +1,11 @@
package Ex0

import scala.math.max

import chisel3._
import chisel3.util.Counter
import chisel3.experimental.MultiIOModule

class MatMul(val rowDimsA: Int, val colDimsA: Int) extends MultiIOModule {
class MatMul(val rowDimsA: Int, val colDimsA: Int) extends Module {

val io = IO(
new Bundle {
@@ -16,36 +17,43 @@ class MatMul(val rowDimsA: Int, val colDimsA: Int) extends MultiIOModule {
}
)

val debug = IO(
new Bundle {
val myDebugSignal = Output(Bool())
}
)


/**
* Your code here
*/
val matrixA = Module(new Matrix(rowDimsA, colDimsA)).io
val matrixB = Module(new Matrix(rowDimsA, colDimsA)).io
val dotProdCalc = Module(new DotProd(colDimsA)).io

matrixA.dataIn := 0.U
matrixA.rowIdx := 0.U
matrixA.colIdx := 0.U
matrixA.writeEnable := false.B

matrixB.rowIdx := 0.U
matrixB.colIdx := 0.U
matrixB.dataIn := 0.U
matrixB.writeEnable := false.B

dotProdCalc.dataInA := 0.U
dotProdCalc.dataInB := 0.U

io.dataOut := 0.U
io.outputValid := false.B


debug.myDebugSignal := false.B
val calculating = RegInit(false.B)
val resultCol = Counter(rowDimsA)
val col = Counter(colDimsA)
val row = Counter(rowDimsA)

matrixA.rowIdx := row.value
matrixA.colIdx := col.value
matrixA.dataIn := io.dataInA
matrixA.writeEnable := ~calculating

matrixB.rowIdx := Mux(calculating, resultCol.value, row.value)
matrixB.colIdx := col.value
matrixB.dataIn := io.dataInB
matrixB.writeEnable := ~calculating

dotProdCalc.dataInA := matrixA.dataOut
dotProdCalc.dataInB := matrixB.dataOut

io.dataOut := dotProdCalc.dataOut
io.outputValid := dotProdCalc.outputValid & calculating

when (col.inc()) {
when (calculating) {
when (resultCol.inc()) {
when (row.inc()) {
calculating := false.B
}
}
}.otherwise {
when (row.inc()) {
calculating := true.B
}
}
}
}

+ 8
- 8
src/main/scala/Matrix.scala 查看文件

@@ -18,18 +18,18 @@ class Matrix(val rowsDim: Int, val colsDim: Int) extends Module {
}
)

/**
* Your code here
*/

// Creates a vector of zero-initialized registers
val rows = Vec.fill(rowsDim)(Module(new Vector(colsDim)).io)
val rows = VecInit(List.fill(rowsDim)(Module(new Vector(colsDim)).io))

// placeholders
io.dataOut := 0.U
for(ii <- 0 until rowsDim){
rows(ii).idx := 0.U
rows(ii).dataIn := 0.U
rows(ii).writeEnable := false.B
rows(ii).idx := 0.U
}

rows(io.rowIdx).writeEnable := io.writeEnable
rows(io.rowIdx).idx := io.colIdx
rows(io.rowIdx).dataIn := io.dataIn

io.dataOut := rows(io.rowIdx).dataOut
}

+ 2
- 12
src/main/scala/Vector.scala 查看文件

@@ -15,21 +15,11 @@ class Vector(val elements: Int) extends Module {
}
)

// Creates a vector of zero-initialized registers
val internalVector = RegInit(VecInit(List.fill(elements)(0.U(32.W))))


when(io.writeEnable){
// TODO:
// When writeEnable is true the content of internalVector at the index specified
// by idx should be set to the value of io.dataIn
internalVector(io.idx) := io.dataIn
}
// In this case we don't want an otherwise block, in writeEnable is low we don't change
// anything


// TODO:
// io.dataOut should be driven by the contents of internalVector at the index specified
// by idx
io.dataOut := 0.U
io.dataOut := internalVector(io.idx)
}

+ 5
- 6
src/main/scala/main.scala 查看文件

@@ -7,16 +7,15 @@ object main {
val s = """
| Attempting to "run" a chisel program is rather meaningless.
| Instead, try running the tests, for instance with "test" or "testOnly Examples.MyIncrementTest
|
| If you want to create chisel graphs, simply remove this message and comment in the code underneath
|
| If you want to create chisel graphs, simply remove this message and comment in the code underneath
| to generate the modules you're interested in.
""".stripMargin
println(s)
//println(s)
}

// Uncomment to dump .fir file
// val f = new File("MatMul.fir")
// chisel3.Driver.dumpFirrtl(chisel3.Driver.elaborate(() => new MatMul(5, 4)), Option(f))
val f = new File("MatMul.fir")
chisel3.Driver.dumpFirrtl(chisel3.Driver.elaborate(() => new MatMul(5, 4)), Option(f))

}


+ 43
- 0
src/test/scala/DotProdSpec.scala 查看文件

@@ -31,6 +31,15 @@ class DotProdSpec extends FlatSpec with Matchers {
}


it should "Calculate the correct output multiple times" in {
wrapTester(
chisel3.iotesters.Driver(() => new DotProd(elements)) { c =>
new CalculatesMultiple(c)
} should be(true)
)
}


it should "Calculate the correct output and signal when appropriate" in {
wrapTester(
chisel3.iotesters.Driver(() => new DotProd(elements)) { c =>
@@ -81,6 +90,40 @@ object DotProdTests {
}


class CalculatesMultiple(c: DotProd) extends PeekPokeTester(c) {

val inputsA = List.fill(c.elements)(rand.nextInt(10))
val inputsB = List.fill(c.elements)(rand.nextInt(10))
val inputsC = List.fill(c.elements)(rand.nextInt(10))
val inputsD = List.fill(c.elements)(rand.nextInt(10))

println("runnign dot prod calc with multiple inputs:")
println(inputsA.mkString("[", "] [", "]"))
println(inputsB.mkString("[", "] [", "]"))
println()
println(inputsC.mkString("[", "] [", "]"))
println(inputsD.mkString("[", "] [", "]"))
val expectedOutput1 = (for ((a, b) <- inputsA zip inputsB) yield a * b) sum
val expectedOutput2 = (for ((a, b) <- inputsC zip inputsD) yield a * b) sum

for(ii <- 0 until c.elements){
poke(c.io.dataInA, inputsA(ii))
poke(c.io.dataInB, inputsB(ii))
if(ii == c.elements - 1)
expect(c.io.dataOut, expectedOutput1)
step(1)
}

for(ii <- 0 until c.elements){
poke(c.io.dataInA, inputsC(ii))
poke(c.io.dataInB, inputsD(ii))
if(ii == c.elements - 1)
expect(c.io.dataOut, expectedOutput2)
step(1)
}
}


class CalculatesCorrectResultAndSignals(c: DotProd) extends PeekPokeTester(c) {

val inputsA = List.fill(c.elements)(rand.nextInt(10))


+ 22
- 2
src/test/scala/MatMulSpec.scala 查看文件

@@ -14,25 +14,45 @@ class MatMulSpec extends FlatSpec with Matchers {

behavior of "MatMul"

it should "Do shit" in {
it should "Do a complete matrix multiplication" in {
wrapTester(
chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
new FullMatMul(c)
} should be(true)
)
}

it should "Signal at the end" in {
wrapTester(
chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
new SignalWhenDone(c)
} should be(true)
)
}
}

object MatMulTests {

val rand = new scala.util.Random(100)

class TestExample(c: MatMul) extends PeekPokeTester(c) {
class SignalWhenDone(c: MatMul) extends PeekPokeTester(c) {
val mA = genMatrix(c.rowDimsA, c.colDimsA)
val mB = genMatrix(c.rowDimsA, c.colDimsA)
val mC = matrixMultiply(mA, mB.transpose)

for(ii <- 0 until c.colDimsA * c.rowDimsA){
expect(c.io.outputValid, false, "Valid output during initialization")
step(1)
}

for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
for(kk <- 0 until c.colDimsA - 1){
expect(c.io.outputValid, false, "Valid output mistimed")
step(1)
}
expect(c.io.outputValid, true, "Valid output timing is wrong")
step(1)
}
}

class FullMatMul(c: MatMul) extends PeekPokeTester(c) {


+ 1
- 2
src/test/scala/MatrixSpec.scala 查看文件

@@ -12,7 +12,6 @@ class MatrixSpec extends FlatSpec with Matchers {

behavior of "Matrix"

val rand = new scala.util.Random(100)
val rowDims = 5
val colDims = 3

@@ -37,7 +36,7 @@ class MatrixSpec extends FlatSpec with Matchers {
it should "Retain its contents when writeEnable is low" in {
wrapTester(
chisel3.iotesters.Driver(() => new Matrix(rowDims, colDims)) { c =>
new UpdatesData(c)
new RetainsData(c)
} should be(true)
)
}


Loading…
取消
儲存