提交 43af7786 编写于 作者: 梦境迷离's avatar 梦境迷离

refactor

上级 6a85f75c
......@@ -153,32 +153,13 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
}
/**
* Expand the class and check whether the class is a case class.
* Check whether the class is a case class.
*
* @param annotateeClass classDef
* @return Return true if it is a case class
*/
def isCaseClass(annotateeClass: ClassDef): Boolean = {
annotateeClass match {
case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" =>
mods.asInstanceOf[Modifiers].hasFlag(Flag.CASE)
case _ => c.abort(c.enclosingPosition, ErrorMessage.ONLY_CLASS)
}
}
/**
* Expand the constructor and get the field TermName.
*
* @param field
* @return
*/
def getFieldTermName(field: Tree): TermName = {
field match {
case q"$mods val $tname: $tpt = $expr" => tname.asInstanceOf[TermName]
case q"$mods var $tname: $tpt = $expr" => tname.asInstanceOf[TermName]
// case q"$mods val $pat = $expr" => pat.asInstanceOf[TermName]
// case q"$mods var $pat = $expr" => pat.asInstanceOf[TermName]
}
annotateeClass.mods.hasFlag(Flag.CASE)
}
/**
......@@ -198,31 +179,28 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param tree a field or method
* @return
*/
def classMemberIsNotLocal(tree: Tree): Boolean = {
def isNotLocalClassMember(tree: Tree): Boolean = {
lazy val modifierNotLocal = (mods: Modifiers) => {
!(
mods.hasFlag(Flag.PRIVATE | Flag.LOCAL) | mods.hasFlag(Flag.PROTECTED | Flag.LOCAL)
)
}
tree match {
case q"$mods val $tname: $tpt = $expr" => modifierNotLocal(mods.asInstanceOf[Modifiers])
case q"$mods var $tname: $tpt = $expr" => modifierNotLocal(mods.asInstanceOf[Modifiers])
case _ => true
// case q"$mods val $pat = $expr" => modifierNotLocal(mods.asInstanceOf[Modifiers])
// case q"$mods var $pat = $expr" => modifierNotLocal(mods.asInstanceOf[Modifiers])
case v: ValDef => modifierNotLocal(v.mods)
case d: DefDef => modifierNotLocal(d.mods)
case _ => true
}
}
/**
* Expand the constructor and get the field with assign.
* Get the field TermName with type.
*
* @param annotteeClassParams
* @return
* @return {{ i: Int}}
*/
def getConstructorFieldAssignExprs(annotteeClassParams: Seq[Tree]): Seq[Tree] = {
def getConstructorFieldNameWithType(annotteeClassParams: Seq[Tree]): Seq[Tree] = {
annotteeClassParams.map {
case q"$mods var $tname: $tpt = $expr" => q"$tname: $tpt" //Ignore expr
case q"$mods val $tname: $tpt = $expr" => q"$tname: $tpt"
case v: ValDef => q"${v.name}: ${v.tpt}"
}
}
......@@ -262,7 +240,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param annotteeClassDefinitions
*/
def getClassMemberValDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
annotteeClassDefinitions.filter(p => p match {
annotteeClassDefinitions.filter(_ match {
case _: ValDef => true
case _ => false
})
......@@ -274,7 +252,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @param annotteeClassDefinitions
*/
def getClassMemberDefDefs(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
annotteeClassDefinitions.filter(p => p match {
annotteeClassDefinitions.filter(_ match {
case _: DefDef => true
case _ => false
})
......@@ -290,7 +268,9 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @example {{ new TestClass12(i)(j)(k)(t) }}
*/
def getConstructorWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], isCase: Boolean): Tree = {
val allFieldsTermName = fieldss.map(f => f.map(ff => getFieldTermName(ff)))
val allFieldsTermName = fieldss.map(f => f.map {
case v: ValDef => v.name.toTermName
})
// not currying
val constructor = if (fieldss.isEmpty || fieldss.size == 1) {
q"${if (isCase) q"${typeName.toTermName}(..${allFieldsTermName.flatten})" else q"new $typeName(..${allFieldsTermName.flatten})"}"
......@@ -313,7 +293,7 @@ abstract class AbstractMacroProcessor(val c: whitebox.Context) {
* @example {{ def apply(int: Int)(j: Int)(k: Option[String])(t: Option[Long]): B3 = new B3(int)(j)(k)(t) }}
*/
def getApplyMethodWithCurrying(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree]): Tree = {
val allFieldsTermName = fieldss.map(f => getConstructorFieldAssignExprs(f))
val allFieldsTermName = fieldss.map(f => getConstructorFieldNameWithType(f))
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
// not currying
val applyMethod = if (fieldss.isEmpty || fieldss.size == 1) {
......
......@@ -41,30 +41,24 @@ object builderMacro {
private def getFieldDefinition(field: Tree): Tree = {
field match {
case q"$mods val $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr"""
case q"$mods var $tname: $tpt = $expr" => q"""private var $tname: $tpt = $expr"""
case v: ValDef => q"private var ${v.name}: ${v.tpt} = ${v.rhs}"
}
}
private def getFieldSetMethod(typeName: TypeName, field: Tree, classTypeParams: List[Tree]): Tree = {
val builderClassName = getBuilderClassName(typeName)
val returnTypeParams = extractClassTypeParamsTypeName(classTypeParams)
field match {
case q"$mods var $tname: $tpt = $expr" =>
q"""
def $tname($tname: $tpt): $builderClassName[..$returnTypeParams] = {
this.$tname = $tname
this
}
"""
case q"$mods val $tname: $tpt = $expr" =>
q"""
def $tname($tname: $tpt): $builderClassName[..$returnTypeParams] = {
this.$tname = $tname
val valDefMapTo = (v: ValDef) => {
q"""
def ${v.name}(${v.name}: ${v.tpt}): $builderClassName[..$returnTypeParams] = {
this.${v.name} = ${v.name}
this
}
"""
}
field match {
case v: ValDef => valDefMapTo(v)
}
}
private def getBuilderClassAndMethod(typeName: TypeName, fieldss: List[List[Tree]], classTypeParams: List[Tree], isCase: Boolean): Tree = {
......@@ -91,7 +85,6 @@ object builderMacro {
val (className, fieldss, classTypeParams) = classDecl match {
// @see https://scala-lang.org/files/archive/spec/2.13/05-classes-and-objects.html
case q"$mods class $tpname[..$tparams](...$paramss) extends ..$bases { ..$body }" =>
c.info(c.enclosingPosition, s"modifiedDeclaration className: $tpname, paramss: $paramss", force = true)
(tpname.asInstanceOf[TypeName], paramss.asInstanceOf[List[List[Tree]]], tparams.asInstanceOf[List[Tree]])
case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl")
}
......
......@@ -50,8 +50,8 @@ object constructorMacro {
*/
private def getClassMemberVarDefOnlyAssignExpr(annotteeClassDefinitions: Seq[Tree]): Seq[Tree] = {
getClassMemberValDefs(annotteeClassDefinitions).filter(_ match {
case q"$mods var $tname: $tpt = $expr" if !extractArgumentsDetail._2.contains(tname.asInstanceOf[TermName].decodedName.toString) => true
case _ => false
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
}).map {
case q"$mods var $pat = $expr" =>
// TODO getClass RETURN a java type, maybe we can try use class reflect to get the fields type name.
......@@ -73,17 +73,19 @@ object constructorMacro {
val classFieldDefinitions = getClassMemberValDefs(annotteeClassDefinitions)
val annotteeClassFieldNames = classFieldDefinitions.filter(_ match {
case q"$mods var $tname: $tpt = $expr" if !extractArgumentsDetail._2.contains(tname.asInstanceOf[TermName].decodedName.toString) => true
case _ => false
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
}).map {
case q"$mods var $tname: $tpt = $expr" => tname.asInstanceOf[TermName]
case v: ValDef if v.mods.hasFlag(Flag.MUTABLE) => v.name
}
// Extract the field of the primary constructor.
val allFieldsTermName = annotteeClassParams.map(f => f.map(ff => getFieldTermName(ff)))
val allFieldsTermName = annotteeClassParams.map(f => f.map {
case v: ValDef => v.name.toTermName
})
// Extract the field of the primary constructor.
val classParamsAssignExpr = getConstructorFieldAssignExprs(annotteeClassParams.flatten)
val classParamsAssignExpr = getConstructorFieldNameWithType(annotteeClassParams.flatten)
val applyMethod = if (annotteeClassParams.isEmpty || annotteeClassParams.size == 1) {
q"""
def this(..${classParamsAssignExpr ++ classFieldDefinitionsOnlyAssignExpr}) = {
......@@ -93,7 +95,7 @@ object constructorMacro {
"""
} else {
// NOTE: currying constructor overload must be placed in the first bracket block.
val allClassParamsAssignExpr = annotteeClassParams.map(cc => getConstructorFieldAssignExprs(cc))
val allClassParamsAssignExpr = annotteeClassParams.map(cc => getConstructorFieldNameWithType(cc))
q"""
def this(..${allClassParamsAssignExpr.head ++ classFieldDefinitionsOnlyAssignExpr})(...${allClassParamsAssignExpr.tail}) = {
this(..${allFieldsTermName.head})(...${allFieldsTermName.tail})
......
......@@ -68,20 +68,21 @@ object equalsAndHashCodeMacro {
* Extract the internal fields of members belonging to the class.
*/
private def getInternalFieldTermNameExcludeLocal(annotteeClassDefinitions: Seq[Tree]): Seq[TermName] = {
getClassMemberValDefs(annotteeClassDefinitions).filter(p => classMemberIsNotLocal(p) && (p match {
case q"$mods var $tname: $tpt = $expr" =>
!extractArgumentsDetail._2.contains(tname.asInstanceOf[TermName].decodedName.toString)
//`val i = 1` will match `q"$mods val $tname: $tpt = $expr"` and tpt is `<type ?>`, not `q"$mods val $pat = $expr"`
case q"$mods val $tname: $tpt = $expr" =>
!extractArgumentsDetail._2.contains(tname.asInstanceOf[TermName].decodedName.toString)
case _ => false
})).map(f => getFieldTermName(f))
getClassMemberValDefs(annotteeClassDefinitions).filter(p => isNotLocalClassMember(p) && (p match {
case v: ValDef => !extractArgumentsDetail._2.contains(v.name.decodedName.toString)
case _ => false
})).map {
case v: ValDef => v.name.toTermName
}
}
// equals method
private def getEqualsMethod(className: TypeName, termNames: Seq[TermName], superClasses: Seq[Tree], annotteeClassDefinitions: Seq[Tree]): Tree = {
val existsCanEqual = getClassMemberDefDefs(annotteeClassDefinitions) exists {
case q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.asInstanceOf[TermName].decodedName.toString == "canEqual" && paramss.nonEmpty =>
case tree @ q"$mods def $tname[..$tparams](...$paramss): $tpt = $expr" if tname.asInstanceOf[TermName].decodedName.toString == "canEqual" && paramss.nonEmpty =>
if (!isNotLocalClassMember(tree)) {
c.info(c.enclosingPosition, "The canEqual method has been found in class, and method mods exists private[this] or protected[this]", extractArgumentsDetail._1)
}
val params = paramss.asInstanceOf[List[List[Tree]]].flatten.map(pp => getMethodParamName(pp))
params.exists(p => p.decodedName.toString == "Any")
case _ => false
......@@ -130,8 +131,10 @@ object equalsAndHashCodeMacro {
(tpname.asInstanceOf[TypeName], paramss.asInstanceOf[List[List[Tree]]], stats.asInstanceOf[Seq[Tree]], parents.asInstanceOf[Seq[Tree]])
case _ => c.abort(c.enclosingPosition, s"${ErrorMessage.ONLY_CLASS} classDef: $classDecl")
}
val ctorFieldNames = annotteeClassParams.flatten.filter(cf => classMemberIsNotLocal(cf))
val allFieldsTermName = ctorFieldNames.map(f => getFieldTermName(f))
val ctorFieldNames = annotteeClassParams.flatten.filter(cf => isNotLocalClassMember(cf))
val allFieldsTermName = ctorFieldNames.map {
case v: ValDef => v.name.toTermName
}
val allTernNames = allFieldsTermName ++ getInternalFieldTermNameExcludeLocal(annotteeClassDefinitions)
val hash = getHashcodeMethod(allTernNames, superClasses)
val equals = getEqualsMethod(className, allTernNames, superClasses, annotteeClassDefinitions)
......
......@@ -89,19 +89,17 @@ object toStringMacro {
if (argument.includeFieldNames) {
lastParam.fold(q"$field") { lp =>
field match {
case q"$mods var $tname: $tpt = $expr" =>
if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname"""
case q"$mods val $tname: $tpt = $expr" =>
if (tname.toString() != lp) q"""${tname.toString()}+${"="}+this.$tname+${", "}""" else q"""${tname.toString()}+${"="}+this.$tname"""
case v: ValDef =>
if (v.name.toTermName.decodedName.toString != lp) q"""${v.name.toTermName.decodedName.toString}+${"="}+this.${v.name}+${", "}"""
else q"""${v.name.toTermName.decodedName.toString}+${"="}+this.${v.name}"""
case _ => q"$field"
}
}
} else {
lastParam.fold(q"$field") { lp =>
field match {
case q"$mods var $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname"""
case q"$mods val $tname: $tpt = $expr" => if (tname.toString() != lp) q"""$tname+${", "}""" else q"""$tname"""
case _ => if (field.toString() != lp) q"""$field+${", "}""" else q"""$field"""
case v: ValDef => if (v.name.toTermName.decodedName.toString != lp) q"""${v.name}+${", "}""" else q"""${v.name}"""
case _ => if (field.toString() != lp) q"""$field+${", "}""" else q"""$field"""
}
}
}
......@@ -119,18 +117,14 @@ object toStringMacro {
val annotteeClassFieldDefinitions = annotteeClassDefinitions.filter(p => p match {
case _: ValDef => true
case mem: MemberDef =>
if (mem.toString().startsWith("override def toString")) { // TODO better way
if (mem.name.decodedName.toString.startsWith("toString")) { // TODO better way
c.abort(mem.pos, "'toString' method has already defined, please remove it or not use'@toString'")
}
false
case _ => false
})
// For the parameters of a given constructor, separate the parameter components and extract the constructor parameters containing val and var
val ctorParams = annotteeClassParams.flatten.map {
case tree @ q"$mods val $tname: $tpt = $expr" => tree
case tree @ q"$mods var $tname: $tpt = $expr" => tree
}
val ctorParams = annotteeClassParams.flatten
val member = if (argument.includeInternalFields) ctorParams ++ annotteeClassFieldDefinitions else ctorParams
val lastParam = member.lastOption.map {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册