我在程式中加入了 Manifest 上下文界定（context bounds），得到了以下可以正常運行的程式碼：
import scala.virtualization.lms.common._
/** Abstract interface of the linear-algebra DSL.
  *
  * `Vector[T]` and `Matrix[T]` are left abstract so that each backend
  * (the unstaged interpreter, the staged IR) picks its own representation.
  * Every operation is generic in the element type `T`, which must provide a
  * `Manifest` (required by `unit` and the staging machinery) and a
  * `Numeric` instance (supplies the arithmetic).
  */
trait LinearAlgebra extends Base {
  // Scoped to this trait so the implicit conversion `any2rep` below does not
  // raise the "implicit conversion" feature warning.
  import scala.language.implicitConversions

  type Vector[T]
  type Matrix[T]

  /** Multiplies every element of `v` by the scalar `k`. */
  def vector_scale[T:Manifest:Numeric](v: Rep[Vector[T]], k: Rep[T]): Rep[Vector[T]]

  /** Tensor product between 2 matrices. */
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]]

  // ---- Concrete syntax ----

  /** Enables `v * k` as sugar for `vector_scale(v, k)`. */
  implicit class VectorOps[T:Numeric:Manifest](v: Rep[Vector[T]]) {
    def *(k: Rep[T]):Rep[Vector[T]] = vector_scale[T](v, k)
  }

  /** Enables `A |* B` as sugar for `tensor_prod(A, B)`. */
  implicit class MatrixOps[T:Numeric:Manifest](A:Rep[Matrix[T]]) {
    def |*(B:Rep[Matrix[T]]):Rep[Matrix[T]] = tensor_prod(A,B)
  }

  /** Lifts a plain value into `Rep` via `unit`.
    *
    * The result type is stated explicitly: implicit definitions with an
    * inferred result type are fragile and may not be visible to code that
    * precedes them.
    */
  implicit def any2rep[T:Manifest](t:T): Rep[T] = unit(t)
}
/** Unstaged backend: `Rep[A]` is simply `A`, so DSL programs execute
  * directly as ordinary Scala code.
  */
trait Interpreter extends Base {
  override type Rep[+A] = A

  /** With `Rep` as the identity type, lifting a value is the identity. */
  override protected def unit[A: Manifest](a: A): Rep[A] = a
}
/** Interpreter backend for [[LinearAlgebra]]: vectors are `Array[T]`,
  * matrices are `Array[Array[T]]` (outer array = rows), and every operation
  * is evaluated immediately.
  */
trait LinearAlgebraInterpreter extends LinearAlgebra with Interpreter {
  override type Vector[T] = Array[T]
  override type Matrix[T] = Array[Array[T]]

  /** Returns a new array with each element of `v` multiplied by `k`. */
  override def vector_scale[T:Manifest](v: Vector[T], k: T)(implicit num:Numeric[T]):Rep[Vector[T]] = v map {x => num.times(x,k)}

  /** Kronecker (tensor) product: each entry `s` of `A` is replaced by the
    * block `s * B`. For an `m×n` matrix `A` and a `p×q` matrix `B` the
    * result has `m*p` rows of `n*q` elements.
    *
    * A row of `A` with no columns contributes `B.length` empty rows (the
    * consistent `p×0` block) rather than throwing, which `reduce` did on the
    * empty list of blocks.
    */
  def tensor_prod[T:Manifest](A:Matrix[T],B:Matrix[T])(implicit num:Numeric[T]):Matrix[T] = {
    // Scale every entry of `m` by the scalar `s`.
    def smm(s:T,m:Matrix[T]) = m.map(_.map(x => num.times(x,s)))
    // Horizontally concatenate two matrices row by row.
    def concat(A:Matrix[T],B:Matrix[T]) = (A,B).zipped.map(_++_)
    A flatMap { row =>
      (row map (s => smm(s,B)))
        .reduceOption(concat)
        .getOrElse(Array.fill(B.length)(Array.empty[T]))
    }
  }
}
/** Staged backend: operations build IR nodes instead of computing values.
  * In the generated Scala code a `Vector[T]` is bound to `Array[T]` and a
  * `Matrix[T]` to `Array[Array[T]]`.
  */
trait LinearAlgebraExp extends LinearAlgebra with BaseExp {
  override type Vector[T] = Array[T]
  override type Matrix[T] = Array[Array[T]]

  /** IR node reifying "scale vector `v` by factor `k`". */
  case class VectorScale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]) extends Def[Vector[T]]

  /** Wraps a [[VectorScale]] node with `toAtom` to obtain its `Exp`. */
  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]): Exp[Vector[T]] =
    toAtom(VectorScale(v, k))

  // NOTE(review): no IR node exists for the tensor product yet; calling this
  // while staging throws scala.NotImplementedError.
  def tensor_prod[T:Manifest:Numeric](A:Rep[Matrix[T]],B:Rep[Matrix[T]]):Rep[Matrix[T]] = ???
}
/** Scala code generator for the IR nodes declared by [[LinearAlgebraExp]]. */
trait ScalaGenLinearAlgebra extends ScalaGenBase {
  val IR: LinearAlgebraExp
  import IR._

  /** Emits a [[VectorScale]] node as `v.map(x => x * k)` bound via
    * `emitValDef`; every other node is delegated to the parent generator.
    */
  override def emitNode(sym: Sym[Any], node: Def[Any]): Unit = node match {
    case VectorScale(vec, factor) =>
      val rhs = s"${quote(vec)}.map(x => x * ${quote(factor)})"
      emitValDef(sym, rhs)
    case other =>
      super.emitNode(sym, other)
  }
}
/** Staged backend with a peephole optimisation for vector scaling. */
trait LinearAlgebraExpOpt extends LinearAlgebraExp {
  /** Scaling by the multiplicative identity is a no-op: when `k` is a
    * constant equal to `Numeric[T].one`, `v` is returned unchanged and no
    * [[VectorScale]] node is created.
    *
    * The comparison uses `Numeric[T].one` rather than the `Double` literal
    * `1.0`, so the rewrite applies to any numeric element type, including
    * user-defined `Numeric` instances.
    */
  override def vector_scale[T:Manifest:Numeric](v: Exp[Vector[T]], k: Exp[T]): Exp[Vector[T]] = k match {
    case Const(c) if c == implicitly[Numeric[T]].one => v
    case _ => super.vector_scale(v, k)
  }
}
/** Sample DSL programs, written once against the abstract [[LinearAlgebra]]
  * interface and runnable on whichever backend is mixed in.
  */
trait Prog extends LinearAlgebra {
  /** Scales `v` by 3. */
  def f[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = v * unit(num.fromInt(3))

  /** Scales `v` by 1 (the case the scale-by-one rewrite targets). */
  def g[T:Manifest](v: Rep[Vector[T]])(implicit num:Numeric[T]): Rep[Vector[T]] = v * unit(num.fromInt(1))

  /** Tensor product of `A` and `B`, using the generic `Matrix[T]` type.
    *
    * NOTE(review): `LinearAlgebraExp.tensor_prod` is still `???`, so staging
    * this method on that backend throws; the interpreter backend supports it.
    */
  def h[T:Manifest:Numeric](A: Rep[Matrix[T]], B: Rep[Matrix[T]]): Rep[Matrix[T]] = A |* B
}
/** Runs `g` through the interpreter, prints the staged source of `g` with and
  * without the scale-by-one optimisation, and reports whether all backends
  * agree on the same input.
  */
object TestLinAlg extends App {
  // Unstaged: `g` executes directly on a Scala array.
  val interpretedProg = new Prog with LinearAlgebraInterpreter {
    println(g(Array(1.0, 2.0)).mkString(","))
  }

  // Staged with the peephole optimisation; emits the generated source of `g` to stdout.
  val optProg = new Prog with LinearAlgebraExpOpt with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "optimizedG", new java.io.PrintWriter(System.out))
  }

  // Staged without the optimisation; emits the generated source of `g` to stdout.
  val nonOptProg = new Prog with LinearAlgebraExp with EffectExp with CompileScala { self =>
    override val codegen = new ScalaGenEffect with ScalaGenLinearAlgebra { val IR: self.type = self }
    codegen.emitSource(g[Double], "nonOptimizedG", new java.io.PrintWriter(System.out))
  }

  /** Evaluates `g` on one input via the interpreter and via both compiled
    * backends; returns `true` iff all three results are element-wise equal.
    */
  def compareInterpCompiled: Boolean = {
    val input = Array(1.0, 2.0)
    val interpreted = interpretedProg.g[Double](input).toList
    val optComp = optProg.compile(optProg.g[Double])
    val nonOptComp = nonOptProg.compile(nonOptProg.g[Double])
    interpreted == optComp(input).toList && interpreted == nonOptComp(input).toList
  }

  println(compareInterpCompiled)
}
我的目標是採用 https://github.com/julienrf/lms-tutorial/wiki 中的範例，並將其修改為使用 Numeric 泛型型別，而不是固定使用 Double。我想確認傳遞隱式參數（implicit）所帶來的開銷是否會被完全消除。上述程式（成功執行）的輸出見此處（原文此處為連結）。
從輸出可以看到，對 num.times 的呼叫確實被完全消除了。