/*
 * Copyright (c) 2021 org.bitlap
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy of
 * this software and associated documentation files (the "Software"), to deal in
 * the Software without restriction, including without limitation the rights to
 * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
 * the Software, and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
 * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */
package org.bitlap.tools.macros
import scala.reflect.macros.whitebox

/**
 * Macro implementation behind the `@equalsAndHashCode` annotation.
 *
 * @author 梦境迷离
 * @since 2021/7/18
 * @version 1.0
 */
object equalsAndHashCodeMacro {

  /**
   * Appends `canEqual`, `equals` and `hashCode` to a plain (non-case) class annotated with `@equalsAndHashCode`.
   *
   * The fields taking part in equality are the constructor parameters and body fields accepted by
   * `isNotLocalClassMember` (presumably: not `private[this]` — confirm in `AbstractMacroProcessor`),
   * minus the body fields whose names are listed in the annotation's `excludeFields`.
   */
  class EqualsAndHashCodeProcessor(override val c: whitebox.Context) extends AbstractMacroProcessor(c) {

    import c.universe._

    /**
     * Annotation arguments as `(verbose, excludeFields)`.
     * `excludeFields` holds field names (strings), hence `Seq[String]` rather than `Nil.type`.
     */
    private val extractArgumentsDetail: (Boolean, Seq[String]) = extractArgumentsTuple2 {
      case q"new equalsAndHashCode(verbose=$verbose)" =>
        (evalTree[Boolean](verbose), Nil)
      case q"new equalsAndHashCode(excludeFields=$excludeFields)" =>
        (false, evalTree[Seq[String]](excludeFields))
      case q"new equalsAndHashCode(verbose=$verbose, excludeFields=$excludeFields)" =>
        (evalTree[Boolean](verbose), evalTree[Seq[String]](excludeFields))
      case q"new equalsAndHashCode()" =>
        (false, Nil)
      case _ =>
        c.abort(c.enclosingPosition, ErrorMessage.UNEXPECTED_PATTERN)
    }

    /**
     * Runs the parent checks, then aborts unless the annottee is a plain class:
     * case classes are rejected with `ErrorMessage.ONLY_CLASS`.
     */
    override def checkAnnottees(annottees: Seq[c.universe.Expr[Any]]): Unit = {
      super.checkAnnottees(annottees)
      val annotteeClass: ClassDef = checkGetClassDef(annottees)
      if (isCaseClass(annotteeClass)) {
        c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
      }
    }

    override val verbose: Boolean = extractArgumentsDetail._1

    /**
     * Extract the internal fields of members belonging to the class.
     *
     * Emits a compiler info message (forced when `verbose`) if any body definition passes `isNotLocalClassMember`.
     *
     * @param annotteeClassDefinitions the statements of the class body
     * @return names of the body `ValDef`s that pass `isNotLocalClassMember` and are not listed in `excludeFields`
     */
    private def getInternalFieldsTermNameExcludeLocal(annotteeClassDefinitions: Seq[Tree]): Seq[TermName] = {
      if (annotteeClassDefinitions.exists(definition => isNotLocalClassMember(definition))) {
        c.info(c.enclosingPosition, "There is a non private class definition inside the class", verbose)
      }
      val excludeFields = extractArgumentsDetail._2
      getClassMemberValDefs(annotteeClassDefinitions)
        .filter(valDef => isNotLocalClassMember(valDef) && !excludeFields.contains(valDef.name.decodedName.toString))
        .map(_.name.toTermName)
    }

    /**
     * Builds `canEqual` (unless the class already declares one) and `equals`.
     *
     * @param className                the annotated class
     * @param termNames                the fields compared by `equals`
     * @param superClasses             the parents of the annotated class
     * @param annotteeClassDefinitions the statements of the class body, searched for a user-defined `canEqual`
     * @return `List(canEqual, equals)`, where `canEqual` is an empty tree if the user already defined one
     */
    private def getEqualsMethod(className: TypeName, termNames: Seq[TermName], superClasses: Seq[Tree], annotteeClassDefinitions: Seq[Tree]): List[Tree] = {
      // Any single-parameter `canEqual(x: Any)` would clash with the generated `canEqual(that: Any)`,
      // whatever the parameter is called, so the parameter name is deliberately not checked.
      val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions).exists {
        case defDef: DefDef if defDef.name.decodedName.toString == "canEqual" && defDef.vparamss.flatten.size == 1 =>
          valDefAccessors(defDef.vparamss.flatten).exists(_.paramType.toString == "Any")
        case _ => false
      }
      val hasUserSuperClass = existsSuperClassExcludeSdkClass(superClasses)
      // Make a rough judgment on whether override is needed: a non-SDK parent is assumed
      // to already declare `canEqual` (NOTE(review): not verified against the parent's members).
      val modifiers = Modifiers(if (hasUserSuperClass) Flag.OVERRIDE else NoFlags, typeNames.EMPTY, List())
      val canEqual =
        if (existsCanEqual) q""
        else q"$modifiers def canEqual(that: Any) = that.isInstanceOf[$className]"
      // `==` is null-safe and uses Scala's cooperative equality, like case classes;
      // chaining with `&&` short-circuits on the first differing field.
      val fieldsEqual = termNames
        .map(termName => q"this.$termName == t.$termName")
        .reduceOption((left, right) => q"$left && $right")
        .getOrElse(q"true")
      val superEqual = if (hasUserSuperClass) q"super.equals(that)" else q"true"
      val equalsMethod =
        q"""
          override def equals(that: Any): Boolean =
            that match {
              case t: $className => t.canEqual(this) && $fieldsEqual && $superEqual
              case _ => false
            }
         """
      List(canEqual, equalsMethod)
    }

    /**
     * Builds `hashCode`: starting from 0, each field is combined as `31 * acc + field.##`,
     * then `super.hashCode` is added when the class has a non-SDK parent.
     * The algorithm see https://alvinalexander.com/scala/how-to-define-equals-hashcode-methods-in-scala-object-equality/
     *
     * @param termNames    the fields hashed, in order
     * @param superClasses the parents of the annotated class
     */
    private def getHashcodeMethod(termNames: Seq[TermName], superClasses: Seq[Tree]): Tree = {
      // Folded at expansion time: the generated code introduces no local names that could shadow a field
      // (the old `val state = Seq(...)` broke classes with a field called `state`).
      // `##` is null-safe and agrees with the `==` used by the generated `equals`.
      val fieldsHash = termNames.foldLeft[Tree](q"0")((acc, termName) => q"31 * $acc + this.$termName.##")
      val superHash = if (existsSuperClassExcludeSdkClass(superClasses)) q"super.hashCode" else q"0"
      q"""
         override def hashCode(): Int = {
            $fieldsHash + $superHash
          }
       """
    }

    /**
     * Appends the generated `canEqual`/`equals`/`hashCode` to the annotated class and re-emits
     * its companion object (if any) alongside it.
     */
    override def createCustomExpr(classDecl: ClassDef, compDeclOpt: Option[ModuleDef]): Any = {
      // Constructor parameters first, then body fields; both filtered by `isNotLocalClassMember`.
      def fieldTermNames(classInfo: ClassDefinition): Seq[TermName] =
        getClassConstructorValDefsFlatten(classInfo.classParamss)
          .filter(param => isNotLocalClassMember(param))
          .map(_.name.toTermName) ++ getInternalFieldsTermNameExcludeLocal(classInfo.body)

      val classDefinition = mapToClassDeclInfo(classDecl)
      val res = appendClassBody(classDecl, classInfo => {
        // Computed once, so the info message from getInternalFieldsTermNameExcludeLocal is not reported twice.
        val termNames = fieldTermNames(classInfo)
        getEqualsMethod(classDefinition.className, termNames, classDefinition.superClasses, classDefinition.body) :+
          getHashcodeMethod(termNames, classDefinition.superClasses)
      })

      c.Expr(
        q"""
          ${compDeclOpt.getOrElse(EmptyTree)}
          $res
         """)
    }
  }

}