2013-08-24 73 views
2

我想編寫一個 Scala 宏,根據 Map 中的條目覆蓋 case class 的字段值,並做簡單的類型檢查:如果原字段類型與覆蓋值的類型兼容,就設置新值,否則保留原值。

到目前爲止,我的代碼如下:

import language.experimental.macros 
    import scala.reflect.macros.Context 

    /** Macro utilities for case classes (version from the question).
      *
      * `withOverrides(entity, overrides)` expands into `entity.copy(f1 = ..., f2 = ..., ...)`,
      * where each case-accessor field takes the value stored under its name in `overrides`
      * when the pattern match below accepts it, and keeps its original value otherwise.
      *
      * NOTE(review): as reported alongside this code, this version produces an "abstract type
      * pattern K is unchecked" warning and crashes the compiler in the erasure phase.
      */
    object ProductUtils { 

     /** Returns a copy of `entity` with fields replaced from `overrides` (keyed by field name). */
     def withOverrides[T](entity: T, overrides: Map[String, Any]): T = 
      macro withOverridesImpl[T] 

     /** Macro implementation of [[withOverrides]]: builds the `copy(...)` call tree.
       *
       * @param entity    tree of the case-class instance to copy
       * @param overrides tree of the field-name -> new-value map
       * @return          a `copy` call with one named argument per case accessor of `T`
       */
     def withOverridesImpl[T: c.WeakTypeTag](c: Context) 
               (entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = { 
      import c.universe._ 

      // Same tree as `entity.tree`. It is reused in every field access and in the final
      // `copy` call, so the entity expression appears several times in the expansion.
      // NOTE(review): a side-effecting or expensive `entity` argument is evaluated more than once.
      val originalEntityTree = reify(entity.splice).tree 
      val originalEntityCopy = entity.actualType.member(newTermName("copy")) 

      // One (field name, tree reading that field, declared field type) triple per case accessor
      // declared on T. The tree is wrapped as c.Expr[T] only so it can be spliced later; its
      // actual type is m.returnType, not T.
      val originalEntity = 
       weakTypeOf[T].declarations.collect { 
        case m: MethodSymbol if m.isCaseAccessor => 
         (m.name, c.Expr[T](Select(originalEntityTree, m.name)), m.returnType) 
       } 

      // One named argument `field = <override-or-original>` per case accessor.
      val values = 
       originalEntity.map { 
        case (name, value, ctype) => 
         AssignOrNamedArg(
          Ident(name), 
          { 
           // K is an abstract type parameter while this macro is compiled, which is why the
           // `newValue : K` pattern triggers the "eliminated by erasure" warning.
           // The asInstanceOf is a no-op: `overrides` is already typed Map[String, Any].
           def reifyWithType[K: WeakTypeTag] = reify { 
            overrides 
             .splice 
             .asInstanceOf[Map[String, Any]] 
             .get(c.literal(name.decoded).splice) match { 
              case Some(newValue : K) => newValue 
              case _     => value.splice 
             } 
           } 

           // Instantiate K with this field's declared type.
           reifyWithType(c.WeakTypeTag(ctype)).tree 
          } 
         ) 
       }.toList 

      // Emit `entity.copy(...)` when the `copy` member is a method; otherwise fail compilation.
      originalEntityCopy match { 
       case s: MethodSymbol => 
        c.Expr[T](
         Apply(Select(originalEntityTree, originalEntityCopy), values)) 
       case _ => c.abort(c.enclosingPosition, "No eligible copy method!") 
      } 

     } 

    } 

調用方式如下:

import macros.ProductUtils 

    /** Sample case class used by MacrosTest.
      * NOTE(review): `filed3` looks like a typo for `field3`; renaming it would change the
      * constructor's named parameters, so it is left unchanged here.
      */
    case class Example(field1: String, field2: Int, filed3: String) 

    /** Small driver showing how `ProductUtils.withOverrides` is called. */
    object MacrosTest {

      /** Overrides `field1` with a compatible String value and attempts to override the Int
        * field `field2` with a String, which the macro is intended to ignore.
        */
      def main(args: Array[String]): Unit = {
        val overrides = Map("field1" -> "new value", "field2" -> "wrong type")
        // Intended result: Example("new value", 0, ""). println renders it as
        // Example(new value,0,) because case-class toString does not quote strings.
        println(ProductUtils.withOverrides(Example("", 0, ""), overrides))
      }
    }

如你所見,我已經成功取得了原始字段的類型,現在想在 reifyWithType 中對這個類型做模式匹配。

不幸的是,目前的實現在編譯時會產生如下警告:

warning: abstract type pattern K is unchecked since it is eliminated by erasure case Some(newValue : K) => newValue 

而且在 IntelliJ 中編譯器會崩潰:

Exception in thread "main" java.lang.NullPointerException 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseAsInstanceOf$1(Erasure.scala:1032) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseNormalApply(Erasure.scala:1083) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preEraseApply(Erasure.scala:1187) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.preErase(Erasure.scala:1193) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1268) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018) 
    at scala.reflect.internal.Trees$class.itransform(Trees.scala:1217) 
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13) 
    at scala.reflect.internal.SymbolTable.itransform(SymbolTable.scala:13) 
    at scala.reflect.api.Trees$Transformer.transform(Trees.scala:2897) 
    at scala.tools.nsc.transform.TypingTransformers$TypingTransformer.transform(TypingTransformers.scala:48) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1280) 
    at scala.tools.nsc.transform.Erasure$ErasureTransformer$$anon$1.transform(Erasure.scala:1018) 

所以問題是:
*是否有可能在宏中將接收到的類型與值的運行時類型進行比較?
*或者有沒有更好的方法來完成這個任務?

+0

我不確定宏在這裏能否幫到你,因爲你必須讓宏生成運行時反射代碼(這完全可行,但有點麻煩)。 –

回答

0

最終我得到了以下解決方案:

import language.experimental.macros 
import scala.reflect.macros.Context 

/** Macro utilities for case classes (version from the answer).
  *
  * `withOverrides(entity, overrides)` expands into `entity.copy(f1 = ..., f2 = ..., ...)`.
  * For every case accessor of `T`, the value stored under the field's name in `overrides` is
  * used if a runtime type check accepts it; otherwise the original field value is kept.
  *
  * NOTE(review): the check is only as precise as erasure allows. Generic fields can still
  * receive values with the wrong type arguments (e.g. a List[Date] for a List[Int] field).
  */
object ProductUtils { 

    /** Returns a copy of `entity` with fields replaced from `overrides` (keyed by field name)
      * when the new value passes the runtime type check against the field's type.
      */
    def withOverrides[T](entity: T, overrides: Map[String, Any]): T = 
     macro withOverridesImpl[T] 

    /** Macro implementation of [[withOverrides]].
      *
      * @param entity    tree of the case-class instance to copy
      * @param overrides tree of the field-name -> candidate-value map
      * @return          a `copy` call with one named argument per case accessor of `T`
      */
    def withOverridesImpl[T: c.WeakTypeTag](c: Context)(entity: c.Expr[T], overrides: c.Expr[Map[String, Any]]): c.Expr[T] = { 
     import c.universe._ 

     // Same tree as `entity.tree`. It is reused in every field access and in the final `copy`
     // call. NOTE(review): the entity expression is therefore evaluated several times.
     val originalEntityTree = reify(entity.splice).tree 
     val originalEntityCopy = entity.actualType.member(newTermName("copy")) 

     // One (field name, tree reading the field, declared field type) triple per case accessor
     // declared on T. resetAllAttrs strips symbols/types from the reused tree, presumably so it
     // can be re-typechecked at its new position. The c.Expr[T] label only enables splicing;
     // the tree's real type is m.returnType.
     val originalEntity = 
      weakTypeOf[T].declarations.collect { 
       case m: MethodSymbol if m.isCaseAccessor => 
        (m.name, c.Expr[T](Select(c.resetAllAttrs(originalEntityTree), m.name)), m.returnType) 
      } 

     // One named argument `field = <override-or-original>` per case accessor.
     val values = 
      originalEntity.map { 
       case (name, value, ctype) => 
        AssignOrNamedArg(
         Ident(name), 
         { 

          // Tree evaluating to the runtime Class of the field type.
          val ruClass = c.reifyRuntimeClass(ctype) 
          // Tree evaluating to the field type as a runtime-universe Type: reify it against the
          // runtime universe's rootMirror, then select `.tpe`.
          val mtag = c.reifyType(treeBuild.mkRuntimeUniverseRef, Select(treeBuild.mkRuntimeUniverseRef, newTermName("rootMirror")), ctype) 
          val mtree = Select(mtag, newTermName("tpe")) 

          def reifyWithType[K: c.WeakTypeTag] = reify { 

           // Returns Some(candidate value cast to the field type K) when the candidate passes
           // the type check below, and None when it is absent or rejected. A is the static type
           // of the looked-up value at the expansion site, captured through its TypeTag.
           def tryNewValue[A: scala.reflect.runtime.universe.TypeTag](candidate: Option[A]): Option[K] = 
            if (candidate.isEmpty) { 
             None 
            } else { 
             val cc = c.Expr[Class[_]](ruClass).splice 
             // Safe: emptiness was checked above.
             val candidateValue = candidate.get 
             val candidateType = scala.reflect.runtime.universe.typeOf[A] 
             val expectedType = c.Expr[scala.reflect.runtime.universe.Type](mtree).splice 

             // Primitive fields: the candidate arrives boxed, so match its wrapper class against
             // the field's primitive class (Unit is compared with java.lang.Void.TYPE).
             val ok = (cc.isPrimitive, candidateValue) match { 
              case (true, _: java.lang.Integer) => cc == java.lang.Integer.TYPE 
              case (true, _: java.lang.Long)  => cc == java.lang.Long.TYPE 
              case (true, _: java.lang.Double) => cc == java.lang.Double.TYPE 
              case (true, _: java.lang.Character) => cc == java.lang.Character.TYPE 
              case (true, _: java.lang.Float)  => cc == java.lang.Float.TYPE 
              case (true, _: java.lang.Byte)  => cc == java.lang.Byte.TYPE 
              case (true, _: java.lang.Short)  => cc == java.lang.Short.TYPE 
              case (true, _: java.lang.Boolean) => cc == java.lang.Boolean.TYPE 
              case (true, _: Unit)    => cc == java.lang.Void.TYPE 
              case _        => 
               // Non-primitive field, or no boxed wrapper matched. If neither the candidate's
               // static type nor its type arguments are Any, require exact type equality with
               // the field type. Otherwise fall back to an erased Class.isInstance check, which
               // ignores type arguments.
               // NOTE(review): the cast to TypeRefApi presumably throws if candidateType is not
               // a TypeRef (e.g. a refined or existential type); verify.
               val args = candidateType.asInstanceOf[scala.reflect.runtime.universe.TypeRefApi].args 
               if (!args.contains(scala.reflect.runtime.universe.typeOf[Any]) 
                 && !(candidateType =:= scala.reflect.runtime.universe.typeOf[Any])) 
                candidateType =:= expectedType 
               else cc.isInstance(candidateValue) 
             } 

             // The cast itself is unchecked; correctness relies on `ok` above.
             if (ok) 
              Some(candidateValue.asInstanceOf[K]) 
             else None 
           } 

           // Look up the override by the field's decoded name; fall back to the original value.
           tryNewValue(overrides.splice.get(c.literal(name.decoded).splice)).getOrElse(value.splice) 
          } 

          // Instantiate K with this field's declared type.
          reifyWithType(c.WeakTypeTag(ctype)).tree 
         } 
        ) 
      }.toList 

     // Emit `entity.copy(...)` when the `copy` member is a method; otherwise fail compilation.
     originalEntityCopy match { 
      case s: MethodSymbol => 
       c.Expr[T](
        Apply(Select(originalEntityTree, originalEntityCopy), values)) 
      case _ => c.abort(c.enclosingPosition, "No eligible copy method!") 
     } 

    } 

} 

它基本滿足了最初的需求:

/** ScalaTest suite for ProductUtils.withOverrides.
  *
  * NOTE(review): `FunSuite` and `Date` are not imported in the visible snippet. `Date` is
  * presumably java.util.Date; confirm.
  */
class ProductUtilsTest extends FunSuite { 

    // Fixture case classes with String, Int, generic List, Map, Double and nested case-class fields.
    case class A(a: String, b: String) 
    case class B(a: String, b: Int) 
    case class C(a: List[Int], b: List[String]) 
    case class D(a: Map[Int, String], b: Double) 
    case class E(a: A, b: B) 

    // String values replace both String fields.
    test("simple overrides works"){ 
     val overrides = Map("a" -> "A", "b" -> "B") 
     assert(ProductUtils.withOverrides(A("", ""), overrides) === A("A", "B")) 
    } 

    // A String and an Int are applied to the String and Int fields respectively.
    test("simple overrides works 1"){ 
     val overrides = Map("a" -> "A", "b" -> 1) 
     assert(ProductUtils.withOverrides(B("", 0), overrides) === B("A", 1)) 
    } 

    // An Int for the String field and a List for the Int field are both rejected.
    test("do not override if types do not match"){ 
     val overrides = Map("a" -> 0, "b" -> List("B")) 
     assert(ProductUtils.withOverrides(B("", 0), overrides) === B("", 0)) 
    } 

    // List values are applied to the List-typed fields.
    test("complex types also works"){ 
     val overrides = Map("a" -> List(1), "b" -> List("A")) 
     assert(ProductUtils.withOverrides(C(List(0), List("")), overrides) === C(List(1), List("A"))) 
    } 

    // A List is rejected for the Map field, while a Double replaces the Double field.
    test("complex types also works 1"){ 
     val overrides = Map("a" -> List(new Date()), "b" -> 2.0d) 
     assert(ProductUtils.withOverrides(D(Map(), 1.0), overrides) === D(Map(), 2.0)) 
    } 

    // An A value replaces the A-typed field; a Double for the B-typed field is rejected.
    test("complex types also works 2"){ 
     val overrides = Map("a" -> A("AA", "BB"), "b" -> 2.0d) 
     assert(ProductUtils.withOverrides(E(A("", ""), B("", 0)), overrides) === E(A("AA", "BB"), B("", 0))) 
    } 

} 

可惜的是,由於 Java/Scala 的類型擦除,在把字段值替換爲新值之前很難可靠地檢查類型是否相等,因此下面這樣的代碼仍然能通過檢查:

scala> case class C(a: List[Int], b: List[String]) 
defined class C 

scala> val overrides = Map("a" -> List(new Date()), "b" -> List(1.0)) 
overrides: scala.collection.immutable.Map[String,List[Any]] = Map(a -> List(Mon Aug 26 15:52:27 CEST 2013), b -> List(1.0)) 

scala> ProductUtils.withOverrides(C(List(0), List("")), overrides) 
res0: C = C(List(Mon Aug 26 15:52:27 CEST 2013),List(1.0)) 

scala> res0.a.head + 1 
java.lang.ClassCastException: java.util.Date cannot be cast to java.lang.Integer 
    at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106) 
    at .<init>(<console>:14) 
    at .<clinit>(<console>) 
    at .<init>(<console>:7) 
    at .<clinit>(<console>) 
    at $print(<console>) 
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) 
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:57) 
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) 
    at java.lang.reflect.Method.invoke(Method.java:606) 
    at scala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:734) 
    at scala.tools.nsc.interpreter.IMain$Request.loadAndRun(IMain.scala:983) 
    at scala.tools.nsc.interpreter.IMain.loadAndRunReq$1(IMain.scala:573) 
    at scala.tools.nsc.interpreter.IMain.interpret(IMain.scala:604)