選択できるのは25トピックまでです。 トピックは、先頭が英数字で、英数字とダッシュ('-')を使用した35文字以内のものにしてください。

177 行
4.8KB

  1. package Ex0
  2. import chisel3._
  3. import chisel3.iotesters.PeekPokeTester
  4. import org.scalatest.{Matchers, FlatSpec}
  5. import TestUtils._
  6. class MatMulSpec extends FlatSpec with Matchers {
  7. import MatMulTests._
  8. val rowDims = 3
  9. val colDims = 7
  10. behavior of "MatMul"
  11. it should "Do a complete matrix multiplication" in {
  12. wrapTester(
  13. chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
  14. new FullMatMul(c)
  15. } should be(true)
  16. )
  17. }
  18. it should "Do a complete matrix multiplication multiple times" in {
  19. wrapTester(
  20. chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
  21. new MultipleFullMatMul(c)
  22. } should be(true)
  23. )
  24. }
  25. it should "Signal at the end" in {
  26. wrapTester(
  27. chisel3.iotesters.Driver(() => new MatMul(rowDims, colDims)) { c =>
  28. new SignalWhenDone(c)
  29. } should be(true)
  30. )
  31. }
  32. }
  33. object MatMulTests {
  34. val rand = new scala.util.Random(100)
  35. class SignalWhenDone(c: MatMul) extends PeekPokeTester(c) {
  36. val mA = genMatrix(c.rowDimsA, c.colDimsA)
  37. val mB = genMatrix(c.rowDimsA, c.colDimsA)
  38. val mC = matrixMultiply(mA, mB.transpose)
  39. for(ii <- 0 until c.colDimsA * c.rowDimsA){
  40. expect(c.io.outputValid, false, "Valid output during initialization")
  41. step(1)
  42. }
  43. for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
  44. for(kk <- 0 until c.colDimsA - 1){
  45. expect(c.io.outputValid, false, "Valid output mistimed")
  46. step(1)
  47. }
  48. expect(c.io.outputValid, true, "Valid output timing is wrong")
  49. step(1)
  50. }
  51. }
  52. class FullMatMul(c: MatMul) extends PeekPokeTester(c) {
  53. val mA = genMatrix(c.rowDimsA, c.colDimsA)
  54. val mB = genMatrix(c.rowDimsA, c.colDimsA)
  55. val mC = matrixMultiply(mA, mB.transpose)
  56. println("Multiplying")
  57. println(printMatrix(mA))
  58. println("With")
  59. println(printMatrix(mB.transpose))
  60. println("Expecting")
  61. println(printMatrix(mC))
  62. // Input data
  63. for(ii <- 0 until c.colDimsA * c.rowDimsA){
  64. val rowInputIdx = ii / c.colDimsA
  65. val colInputIdx = ii % c.colDimsA
  66. poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx))
  67. poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx))
  68. expect(c.io.outputValid, false, "Valid output during initialization")
  69. step(1)
  70. }
  71. // Perform calculation
  72. for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
  73. for(kk <- 0 until c.colDimsA - 1){
  74. expect(c.io.outputValid, false, "Valid output mistimed")
  75. step(1)
  76. }
  77. expect(c.io.outputValid, true, "Valid output timing is wrong")
  78. expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
  79. step(1)
  80. }
  81. }
  82. class MultipleFullMatMul(c: MatMul) extends PeekPokeTester(c) {
  83. val mA = genMatrix(c.rowDimsA, c.colDimsA)
  84. val mB = genMatrix(c.rowDimsA, c.colDimsA)
  85. val mC = matrixMultiply(mA, mB.transpose)
  86. println("Multiplying")
  87. println(printMatrix(mA))
  88. println("With")
  89. println(printMatrix(mB.transpose))
  90. println("Expecting")
  91. println(printMatrix(mC))
  92. // Input data
  93. for(ii <- 0 until c.colDimsA * c.rowDimsA){
  94. val rowInputIdx = ii / c.colDimsA
  95. val colInputIdx = ii % c.colDimsA
  96. poke(c.io.dataInA, mA(rowInputIdx)(colInputIdx))
  97. poke(c.io.dataInB, mB(rowInputIdx)(colInputIdx))
  98. expect(c.io.outputValid, false, "Valid output during initialization")
  99. step(1)
  100. }
  101. // Perform calculation
  102. for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
  103. for(kk <- 0 until c.colDimsA - 1){
  104. expect(c.io.outputValid, false, "Valid output mistimed")
  105. step(1)
  106. }
  107. expect(c.io.outputValid, true, "Valid output timing is wrong")
  108. expect(c.io.dataOut, mC(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
  109. step(1)
  110. }
  111. val mX = genMatrix(c.rowDimsA, c.colDimsA)
  112. val mY = genMatrix(c.rowDimsA, c.colDimsA)
  113. val mZ = matrixMultiply(mX, mY.transpose)
  114. println("Starting new calculation")
  115. println("Multiplying")
  116. println(printMatrix(mX))
  117. println("With")
  118. println(printMatrix(mY.transpose))
  119. println("Expecting")
  120. println(printMatrix(mZ))
  121. // Input data
  122. for(ii <- 0 until c.colDimsA * c.rowDimsA){
  123. val rowInputIdx = ii / c.colDimsA
  124. val colInputIdx = ii % c.colDimsA
  125. poke(c.io.dataInA, mX(rowInputIdx)(colInputIdx))
  126. poke(c.io.dataInB, mY(rowInputIdx)(colInputIdx))
  127. expect(c.io.outputValid, false, "Valid output during initialization")
  128. step(1)
  129. }
  130. // Perform calculation
  131. for(ii <- 0 until (c.rowDimsA * c.rowDimsA)){
  132. for(kk <- 0 until c.colDimsA - 1){
  133. expect(c.io.outputValid, false, "Valid output mistimed")
  134. step(1)
  135. }
  136. expect(c.io.outputValid, true, "Valid output timing is wrong")
  137. expect(c.io.dataOut, mZ(ii / c.rowDimsA)(ii % c.rowDimsA), "Wrong value calculated")
  138. step(1)
  139. }
  140. }
  141. }