forked from scala/scala
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathspec-matrix-new.scala
80 lines (67 loc) · 1.91 KB
/
spec-matrix-new.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
/** Test matrix multiplication with specialization.
 *
 *  A fixed-size, mutable matrix stored as one inner array per row
 *  (`Array[Array[A]]`). `A` is `@specialized` so the compiler can generate
 *  primitive variants of the accessors for primitive element types.
 *
 *  @tparam A   element type; an implicit `ArrayTag[A]` is needed to allocate
 *              the backing arrays
 *  @param rows number of rows
 *  @param cols number of columns
 */
class Matrix[@specialized A: ArrayTag](val rows: Int, val cols: Int) {
  // One inner array per row, allocated once by Array.ofDim.
  private val arr: Array[Array[A]] = Array.ofDim[A](rows, cols)

  /** Returns the element at row `i`, column `j`.
   *
   *  @throws NoSuchElementException if `(i, j)` lies outside the matrix
   */
  def apply(i: Int, j: Int): A = {
    if (i < 0 || i >= rows || j < 0 || j >= cols)
      throw new NoSuchElementException("Indexes out of bounds: " + ((i, j)))
    arr(i)(j)
  }

  /** Stores `e` at row `i`, column `j` (backs the `m(i, j) = e` syntax).
   *
   *  Unlike `apply`, this does no explicit bounds check; an out-of-range
   *  index surfaces as the backing array's own exception.
   */
  def update(i: Int, j: Int, e: A): Unit = {
    arr(i)(j) = e
  }

  /** Iterates over the rows in order.
   *
   *  The yielded arrays are the live backing rows (not copies), so writes to
   *  them are visible through `apply`.
   */
  def rowsIterator: Iterator[Array[A]] = new Iterator[Array[A]] {
    private var idx = 0

    def hasNext: Boolean = idx < rows

    def next(): Array[A] = {
      // Iterator contract: report exhaustion with NoSuchElementException and
      // leave the cursor untouched, rather than advancing past the end and
      // failing with ArrayIndexOutOfBoundsException.
      if (!hasNext)
        throw new NoSuchElementException("next on exhausted rowsIterator")
      val row = arr(idx)
      idx += 1
      row
    }
  }
}
object Test {
  /** Multiplies a random 200x100 matrix by a random 100x200 matrix, prints
   *  one cell of the product, and prints the runtime's boxed-`Double` counter.
   */
  def main(args: Array[String]): Unit = {
    val m = randomMatrix(200, 100)
    val n = randomMatrix(100, 200)
    val p = mult(m, n)
    println(p(0, 0))
    println("Boxed doubles: " + runtime.BoxesRunTime.doubleBoxCount)
    // println("Boxed integers: " + runtime.BoxesRunTime.integerBoxCount)
  }

  /** Builds an `n` x `m` matrix filled with whole-number doubles in
   *  [-999, 999] (`nextInt % 1000`).
   *
   *  @param seed generator seed; the default (10) keeps runs reproducible
   */
  def randomMatrix(n: Int, m: Int, seed: Long = 10L): Matrix[Double] = {
    val r = new util.Random(seed)
    val x = new Matrix[Double](n, m)
    for (i <- 0 until n; j <- 0 until m)
      x(i, j) = (r.nextInt % 1000).toDouble
    x
  }

  /** Prints `m` one row per line, each cell formatted with `%5.3f`.
   *
   *  Cells must be floating-point values (e.g. a `Matrix[Double]`);
   *  `String.format` rejects other types for `%f` at runtime.
   */
  def printMatrix[T](m: Matrix[T]): Unit = {
    for (i <- 0 until m.rows) {
      for (j <- 0 until m.cols)
        print("%5.3f ".format(m(i, j)))
      println()
    }
  }

  /** Generic product `m * n` for any `Numeric` element type (specialized for `Int`).
   *
   *  @return the `m.rows` x `n.cols` product
   *  @throws IllegalArgumentException if `m.cols != n.rows`
   */
  def multTag[@specialized(Int) T](m: Matrix[T], n: Matrix[T])(implicit at: ArrayTag[T], num: Numeric[T]): Matrix[T] = {
    requireConformable(m, n)
    val p = new Matrix[T](m.rows, n.cols)
    import num._
    for (i <- 0 until m.rows)
      for (j <- 0 until n.cols) {
        var sum = num.zero
        for (k <- 0 until m.cols)
          sum += m(i, k) * n(k, j)
        p(i, j) = sum
      }
    // Previously declared as a procedure, which silently discarded `p`.
    p
  }

  /** `Double`-specific product `m * n`, written monomorphically.
   *
   *  @return the `m.rows` x `n.cols` product
   *  @throws IllegalArgumentException if `m.cols != n.rows`
   */
  def mult(m: Matrix[Double], n: Matrix[Double]): Matrix[Double] = {
    requireConformable(m, n)
    val p = new Matrix[Double](m.rows, n.cols)
    for (i <- 0 until m.rows)
      for (j <- 0 until n.cols) {
        var sum = 0.0
        for (k <- 0 until m.cols)
          sum += m(i, k) * n(k, j)
        p(i, j) = sum
      }
    p
  }

  /** Fails fast when `m * n` is undefined. Without this check, a left operand
   *  wider than the right's height would yield a silently truncated sum.
   */
  private def requireConformable(m: Matrix[_], n: Matrix[_]): Unit =
    require(m.cols == n.rows,
      "Incompatible dimensions: " + m.rows + "x" + m.cols + " * " + n.rows + "x" + n.cols)
}