// AbstractMacroProcessor.scala
/*
 * Copyright (c) 2021 jxnu-liguobin && contributors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

package io.github.dreamylost.macros
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import scala.reflect.macros.whitebox
/**
 * Common base class for macro-annotation processors.
 *
 * Provides helpers to take apart the annotated class and its companion object, to extract the macro's
 * arguments from `c.prefix`, and to build constructors, `apply` methods and class/object bodies with quasiquotes.
 *
 * @author 梦境迷离
 * @since 2021/7/24
 * @version 1.0
 */
abstract class AbstractMacroProcessor(val c: whitebox.Context) {

  import c.universe._

  /** Printed names of parents that do not count as a "real" super class (see `existsSuperClassExcludeSdkClass`). */
  protected lazy val SDKClasses = Set("java.lang.Object", "scala.AnyRef")

  /** Boxed Java type names mapped to their Scala names; built once instead of on every `toScalaType` call. */
  private lazy val JavaToScalaTypeNames: Map[String, String] = Map(
    "java.lang.Integer" -> "Int",
    "java.lang.Long" -> "Long",
    "java.lang.Double" -> "Double",
    "java.lang.Float" -> "Float",
    "java.lang.Short" -> "Short",
    "java.lang.Byte" -> "Byte",
    "java.lang.Boolean" -> "Boolean",
    "java.lang.Character" -> "Char",
    "java.lang.String" -> "String"
  )

  /**
   * Builds the final expansion (or a tree close to it) from the annotated class and its optional companion object.
   *
   * Subclasses override this when it fits their expansion. Very simple macros, and macros that need extra
   * arguments, may not use it at all, which is why the default implementation is `???`.
   *
   * @param classDecl   the annotated class
   * @param compDeclOpt the companion object, if one exists
   * @return the expansion. It is typed as `Any` (callers such as `collectCustomExpr` cast it to a `c.Expr`)
   *         because expressing the path-dependent `c.Expr` type here would need the Aux pattern in Scala 2.
   */
  def createCustomExpr(classDecl: ClassDef, compDeclOpt: Option[ModuleDef] = None): Any = ???

  /**
   * Macro entry point; every subclass must implement it.
   *
   * @param annottees the annotated definitions handed over by the compiler
   * @return the final expanded syntax tree
   */
  def impl(annottees: Expr[Any]*): Expr[Any]

  /**
   * Evaluates a tree at compile time.
   *
   * The tree is duplicated and untypechecked before evaluation, so the original tree is left untouched.
   *
   * @param tree the tree to evaluate
   * @tparam T the expected result type
   * @return the evaluated value
   */
  def evalTree[T: WeakTypeTag](tree: Tree): T = c.eval(c.Expr[T](c.untypecheck(tree.duplicate)))

  /**
   * Extracts one macro argument by applying `partialFunction` to the macro prefix (`c.prefix.tree`).
   * Fails if `partialFunction` is not defined at the prefix.
   */
  def extractArgumentsTuple1[T: WeakTypeTag](partialFunction: PartialFunction[Tree, Tuple1[T]]): Tuple1[T] = {
    partialFunction.apply(c.prefix.tree)
  }

  /**
   * Extracts two macro arguments by applying `partialFunction` to the macro prefix (`c.prefix.tree`).
   * Fails if `partialFunction` is not defined at the prefix.
   */
  def extractArgumentsTuple2[T1: WeakTypeTag, T2: WeakTypeTag](partialFunction: PartialFunction[Tree, (T1, T2)]): (T1, T2) = {
    partialFunction.apply(c.prefix.tree)
  }

  /**
   * Extracts four macro arguments by applying `partialFunction` to the macro prefix (`c.prefix.tree`).
   * Fails if `partialFunction` is not defined at the prefix.
   */
  def extractArgumentsTuple4[T1: WeakTypeTag, T2: WeakTypeTag, T3: WeakTypeTag, T4: WeakTypeTag](partialFunction: PartialFunction[Tree, (T1, T2, T3, T4)]): (T1, T2, T3, T4) = {
    partialFunction.apply(c.prefix.tree)
  }

  /**
   * Reports an expanded tree as a compiler info message at the macro position, framed by timestamped markers.
   *
   * @param force   when `false`, the message is only shown if the compiler runs with `-verbose`
   * @param resTree the tree to print
   */
  def printTree(force: Boolean, resTree: Tree): Unit = {
    c.info(
      c.enclosingPosition,
      s"\n###### Time: ${ZonedDateTime.now().format(DateTimeFormatter.ISO_ZONED_DATE_TIME)} Expanded macro start ######\n" + resTree.toString() + "\n###### Expanded macro end ######\n",
      force = force
    )
  }

  /**
   * Returns the annotated class.
   *
   * Accepts either a lone class or a class followed by its companion object. Any other shape aborts the macro
   * with `ErrorMessage.UNEXPECTED_PATTERN`.
   *
   * @param annottees the annotated definitions
   * @return the class definition
   */
  def checkGetClassDef(annottees: Seq[Expr[Any]]): ClassDef = {
    annottees.map(_.tree).toList match {
      case (classDecl: ClassDef) :: Nil                   => classDecl
      case (classDecl: ClassDef) :: (_: ModuleDef) :: Nil => classDecl
      case _                                              => c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN)
    }
  }

  /**
   * Returns the annotated object, or the companion object of the annotated class, if there is one.
   *
   * @param annottees the annotated definitions
   * @return the object definition, or `None` when there is no object or the shape is not recognised
   */
  def getModuleDefOption(annottees: Seq[Expr[Any]]): Option[ModuleDef] = {
    annottees.map(_.tree).toList match {
      case (moduleDef: ModuleDef) :: Nil                  => Some(moduleDef)
      case (_: ClassDef) :: Nil                           => None
      case (_: ClassDef) :: (compDecl: ModuleDef) :: Nil  => Some(compDecl)
      case (moduleDef: ModuleDef) :: (_: ClassDef) :: Nil => Some(moduleDef)
      case _                                              => None
    }
  }

  /**
   * Runs `modifyAction` on the annotated class and its optional companion object.
   *
   * @param annottees    the annotated definitions
   * @param modifyAction the actual processing function; its result must be a `c.Expr`
   * @return the result of `modifyAction`, cast to `Expr[Nothing]`
   */
  def collectCustomExpr(annottees: Seq[Expr[Any]])
    (modifyAction: (ClassDef, Option[ModuleDef]) => Any): Expr[Nothing] = {
    val classDef = checkGetClassDef(annottees)
    val compDecl = getModuleDefOption(annottees)
    modifyAction(classDef, compDecl).asInstanceOf[Expr[Nothing]]
  }

  /**
   * Checks whether the class is a case class.
   *
   * @param annotateeClass the class definition
   * @return true if the class has the `case` modifier
   */
  def isCaseClass(annotateeClass: ClassDef): Boolean = {
    annotateeClass.mods.hasFlag(Flag.CASE)
  }

  /**
   * Checks that a field or method is not `private[this]` or `protected[this]`. Such members cannot be accessed
   * from outside the current instance.
   *
   * @param tree a field (`ValDef`) or method (`DefDef`); any other tree counts as accessible
   * @return false if the member is `private[this]` or `protected[this]`
   */
  def isNotLocalClassMember(tree: Tree): Boolean = {
    // `private[this]` / `protected[this]` are PRIVATE or PROTECTED together with LOCAL.
    // `Modifiers.hasFlag(mask)` is true when ANY bit of `mask` is set, so `hasFlag(PRIVATE | LOCAL)`
    // would also match plain `private` members. Test each flag on its own.
    def isInstanceLocal(mods: Modifiers): Boolean =
      mods.hasFlag(Flag.LOCAL) && (mods.hasFlag(Flag.PRIVATE) || mods.hasFlag(Flag.PROTECTED))

    tree match {
      case v: ValDef => !isInstanceLocal(v.mods)
      case d: DefDef => !isInstanceLocal(d.mods)
      case _         => true
    }
  }

  /**
   * Turns constructor parameters into `name: Type` parameter trees, e.g. `i: Int`.
   *
   * @param annotteeClassParams one parameter list; every element must be a `ValDef`
   * @return the parameter trees, without modifiers or default values
   */
  def getConstructorParamsNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] = {
    annotteeClassParams.map(_.asInstanceOf[ValDef]).map(v => q"${v.name}: ${v.tpt}")
  }

  /**
   * Appends code to an existing object, or creates a new object when there is none.
   *
   * @param compDeclOpt the existing object, if any
   * @param codeBlocks  the statements to append to the object body
   * @param className   the name of the class; its term name is used when a new object is created
   * @return the modified or newly created object
   */
  def appendModuleBody(
    compDeclOpt: Option[ModuleDef],
    codeBlocks:  List[Tree], className: TypeName): Tree = {
    compDeclOpt.fold(q"object ${className.toTermName} { ..$codeBlocks }") {
      compDecl =>
        // Diagnostic only: not forced, so it stays quiet unless the compiler runs with -verbose.
        c.info(c.enclosingPosition, s"appendModuleBody className: $className, exists obj: $compDecl", force = false)
        val ModuleDef(mods, name, impl) = compDecl
        val Template(parents, self, body) = impl
        ModuleDef(mods, name, Template(parents, self, body ++ codeBlocks))
    }
  }

  /**
   * Extracts the fields declared in the class body (i.e. not in the primary constructor).
   *
   * @param annotteeClassDefinitions the statements of the class body
   * @return the `ValDef`s among them
   */
  def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[ValDef] = {
    annotteeClassDefinitions.collect { case v: ValDef => v }
  }

  /**
   * Extracts the constructor parameters and flattens the curried parameter lists into one sequence.
   *
   * @param annotteeClassParams the constructor parameter lists; every element must be a `ValDef`
   * @return all parameters in declaration order
   */
  def getClassConstructorValDefsFlatten(annotteeClassParams: List[List[Tree]]): Seq[ValDef] = {
    annotteeClassParams.flatten.map(_.asInstanceOf[ValDef])
  }

  /**
   * Extracts the constructor parameters, keeping one sequence per parameter list.
   *
   * @param annotteeClassParams the constructor parameter lists; every element must be a `ValDef`
   * @return the parameters grouped by parameter list
   */
  def getClassConstructorValDefsNotFlatten(annotteeClassParams: List[List[Tree]]): Seq[Seq[ValDef]] = {
    annotteeClassParams.map(_.map(_.asInstanceOf[ValDef]))
  }

  /**
   * Extracts the methods declared in the class body, including secondary constructors.
   *
   * @param annotteeClassDefinitions the statements of the class body
   * @return the `DefDef`s among them
   */
  def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[DefDef] = {
    annotteeClassDefinitions.collect { case d: DefDef => d }
  }

  /**
   * Builds a constructor call that passes every constructor parameter by name. Curried parameter lists are
   * preserved.
   *
   * @param typeName the class name
   * @param fieldss  the constructor parameter lists
   * @param isCase   true to call the companion `apply` (`T(...)`), false to use `new T(...)`
   * @return the constructor call, without type annotations, for use inside generated methods
   * @example {{{ new TestClass12(i)(j)(k)(t) }}}
   */
  def getConstructorWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = {
    val allFieldsTermName = getClassConstructorValDefsNotFlatten(fieldss).map(_.map(_.name.toTermName))
    if (fieldss.size <= 1) {
      // not currying: a single (possibly empty) argument list
      val args = allFieldsTermName.flatten
      if (isCase) q"${typeName.toTermName}(..$args)" else q"new $typeName(..$args)"
    } else {
      // currying: the first list is ONE argument list (`..`), the rest is a list of argument lists (`...`)
      val first = allFieldsTermName.head
      val rest = allFieldsTermName.tail
      if (isCase) q"${typeName.toTermName}(..$first)(...$rest)"
      else q"new $typeName(..$first)(...$rest)"
    }
  }

  /**
   * Builds an `apply` method whose parameter lists mirror the (possibly curried) constructor.
   *
   * @param typeName        the class name
   * @param fieldss         the constructor parameter lists
   * @param classTypeParams the class type parameters; they are reused as the method's type parameters
   * @return an `apply` method that delegates to `new`
   * @example {{{ def apply(int: Int)(j: Int)(k: Option[String])(t: Option[Long]): B3 = new B3(int)(j)(k)(t) }}}
   */
  def getApplyMethodWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree]): Tree = {
    val allFieldsTermName = fieldss.map(f => getConstructorParamsNameWithType(f))
    val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
    val constructor = getConstructorWithCurrying(typeName, fieldss, isCase = false)
    if (fieldss.size <= 1) {
      // not currying
      q"def apply[..$classTypeParams](..${allFieldsTermName.flatten}): $typeName[..$returnTypeParams] = $constructor"
    } else {
      // currying
      val first = allFieldsTermName.head
      q"def apply[..$classTypeParams](..$first)(...${allFieldsTermName.tail}): $typeName[..$returnTypeParams] = $constructor"
    }
  }

  /**
   * Maps a boxed Java type name to its Scala name. This only covers the primitive wrappers and `String`.
   *
   * @param jType the fully qualified Java type name
   * @return the Scala type name, or `jType` unchanged when it is not in the table
   */
  def toScalaType(jType: String): String = JavaToScalaTypeNames.getOrElse(jType, jType)

  /**
   * Returns the names of the class type parameters.
   *
   * The type parameter definitions cannot be used directly in a return type, so only their names are kept.
   *
   * @param tpParams the class type parameters; every element must be a `TypeDef`
   * @return the type parameter names
   */
  def extractClassTypeParamsTypeName(tpParams: List[Tree]): List[TypeName] = {
    tpParams.map(_.asInstanceOf[TypeDef].name)
  }

  /**
   * Checks whether the class has a parent other than the SDK classes listed in `SDKClasses`
   * (`AnyRef`, `Object`).
   *
   * @param superClasses the parent trees
   * @return true if at least one parent is not an SDK class
   */
  def existsSuperClassExcludeSdkClass(superClasses: Seq[Tree]): Boolean = {
    superClasses.exists(sc => !SDKClasses.contains(sc.toString()))
  }

  /**
   * A field of the annotated class, with lazy access to its typechecked type.
   *
   * @param mods the field modifiers
   * @param name the field name
   * @param tpt  the declared type tree
   * @param rhs  the default value / right-hand side (possibly `EmptyTree`)
   */
  private[macros] case class ValDefAccessor(
      mods: Modifiers,
      name: TermName,
      tpt:  Tree,
      rhs:  Tree
  ) {

    /** The declared type name, taken from the type's symbol. */
    def typeName: TypeName = symbol.name.toTypeName

    /** The symbol of the declared type. */
    def symbol: c.universe.Symbol = paramType.typeSymbol

    /** The declared type, obtained by typechecking `tpt` in type mode; computed once and memoised. */
    lazy val paramType: Type = c.typecheck(tq"$tpt", c.TYPEmode).tpe
  }

  /**
   * Wraps the given field definitions as `ValDefAccessor`s.
   *
   * @param params the field definitions; every element must be a `ValDef` (otherwise a `MatchError` is thrown)
   * @return one accessor per field, in the original order
   */
  def valDefAccessors(params: Seq[Tree]): Seq[ValDefAccessor] = {
    params.map {
      case ValDef(mods, name, tpt, rhs) => ValDefAccessor(mods, name, tpt, rhs)
    }
  }

  /**
   * Extracts the structural parts of a class definition for macro programming.
   *
   * @param classDecl the class definition
   * @return its self type, modifiers, name, constructor parameter lists, type parameters, body, parents and
   *         early definitions
   */
  def mapToClassDeclInfo(classDecl: ClassDef): ClassDefinition = {
    val q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = classDecl
    ClassDefinition(
      self.asInstanceOf[ValDef],
      mods.asInstanceOf[Modifiers],
      tpname.asInstanceOf[TypeName],
      paramss.asInstanceOf[List[List[Tree]]],
      tparams.asInstanceOf[List[Tree]],
      stats.asInstanceOf[List[Tree]],
      parents.asInstanceOf[List[Tree]],
      earlydefns.asInstanceOf[List[Tree]]
    )
  }

  /**
   * Extracts the structural parts of an object definition for macro programming.
   *
   * Objects have no constructor parameters or type parameters, so those fields are `Nil`.
   *
   * @param moduleDef the object definition
   * @return its self type, modifiers, name (as a type name), body, parents and early definitions
   */
  def mapToModuleDeclInfo(moduleDef: ModuleDef): ClassDefinition = {
    val q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = moduleDef
    ClassDefinition(
      self.asInstanceOf[ValDef],
      mods.asInstanceOf[Modifiers],
      tpname.asInstanceOf[TermName].toTypeName,
      Nil,
      Nil,
      stats.asInstanceOf[List[Tree]],
      parents.asInstanceOf[List[Tree]],
      earlydefns.asInstanceOf[List[Tree]]
    )
  }

  /**
   * Appends generated members (methods, fields, ...) to the end of a class body.
   *
   * Use this only when modifying the class definition itself.
   *
   * @param classDecl       the class definition
   * @param classInfoAction produces the statements to append from the class structure
   * @return the class with the new statements appended
   */
  def appendClassBody(classDecl: ClassDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.ClassDef = {
    val classInfo = mapToClassDeclInfo(classDecl)
    val ClassDef(mods, name, tparams, impl) = classDecl
    val Template(parents, self, body) = impl
    ClassDef(mods, name, tparams, Template(parents, self, body ++ classInfoAction(classInfo)))
  }

  /**
   * Inserts generated statements at the beginning of a class or object body.
   *
   * TODO: quasiquotes are used instead of `ClassDef.apply` here; the reason is not yet understood.
   *
   * @param implDef         the class or object definition
   * @param classInfoAction produces the statements to prepend from the definition's structure
   * @return the rebuilt definition
   */
  def prependImplDefBody(implDef: ImplDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.Tree = {
    implDef match {
      case classDecl: ClassDef =>
        val classInfo = mapToClassDeclInfo(classDecl)
        val q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = classDecl
        q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..${classInfoAction(classInfo) ++ stats} }"
      case moduleDef: ModuleDef =>
        val classInfo = mapToModuleDeclInfo(moduleDef)
        val q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = moduleDef
        q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..${classInfoAction(classInfo) ++ stats.toList} }"
    }
  }

  /**
   * Appends generated parents to the `extends ... with ...` clause of a class or object.
   *
   * @param implDef         the class or object definition
   * @param classInfoAction produces the parent trees to append from the definition's structure
   * @return the rebuilt definition
   */
  def appendImplDefSuper(implDef: ImplDef, classInfoAction: ClassDefinition => List[Tree]): c.universe.Tree = {
    implDef match {
      case classDecl: ClassDef =>
        val classInfo = mapToClassDeclInfo(classDecl)
        val q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = classDecl
        q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..${parents ++ classInfoAction(classInfo)} { $self => ..$stats }"
      case moduleDef: ModuleDef =>
        val classInfo = mapToModuleDeclInfo(moduleDef)
        val q"$mods object $tpname extends { ..$earlydefns } with ..$parents { $self => ..$stats }" = moduleDef
        q"$mods object $tpname extends { ..$earlydefns } with ..${parents.toList ++ classInfoAction(classInfo)} { $self => ..$stats }"
    }
  }

  /**
   * Replaces the body of a method, keeping its modifiers, name, type parameters, parameter lists and return type.
   *
   * @param defDef        the method definition
   * @param defBodyAction the new method body (evaluated once, by name)
   * @return the method with the new body
   */
  def mapToMethodDef(defDef: DefDef, defBodyAction: => Tree): c.universe.DefDef = {
    val DefDef(mods, name, tparams, vparamss, tpt, _) = defDef
    DefDef(mods, name, tparams, vparamss, tpt, defBodyAction)
  }

  /**
   * Structural information about a class or object, as produced by `mapToClassDeclInfo` / `mapToModuleDeclInfo`.
   *
   * @param self            the self-type definition
   * @param mods            the modifiers
   * @param className       the name (an object's term name converted to a type name)
   * @param classParamss    the constructor parameter lists (`Nil` for objects)
   * @param classTypeParams the type parameters (`Nil` for objects)
   * @param body            the statements of the template body
   * @param superClasses    the parents in the `extends ... with ...` clause
   * @param earlydefns      the early definitions (`extends { ... } with`)
   */
  private[macros] case class ClassDefinition(
      self:            ValDef,
      mods:            Modifiers,
      className:       TypeName,
      classParamss:    List[List[Tree]],
      classTypeParams: List[Tree],
      body:            List[Tree],
      superClasses:    List[Tree],
      earlydefns:      List[Tree]       = Nil
  )

}