ConstructorProcessor.scala 2.5 KB
Newer Older
I
IceMimosa 已提交
1 2 3 4
package io.github.dreamylost.plugin.processor.clazz

import io.github.dreamylost.plugin.ScalaMacroNames
import io.github.dreamylost.plugin.processor.ProcessType.ProcessType
I
IceMimosa 已提交
5
import io.github.dreamylost.plugin.processor.{ AbsProcessor, ProcessType }
I
IceMimosa 已提交
6 7
import org.jetbrains.plugins.scala.lang.psi.api.expr.ScMethodCall
import org.jetbrains.plugins.scala.lang.psi.api.statements.ScVariableDefinition
I
IceMimosa 已提交
8
import org.jetbrains.plugins.scala.lang.psi.api.toplevel.typedef.{ ScClass, ScTypeDefinition }
I
IceMimosa 已提交
9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
import org.jetbrains.plugins.scala.lang.psi.types.ScLiteralType

/**
 * Desc: Processor for the `constructor` annotation.
 *
 * For an annotated class it synthesizes one auxiliary constructor. Its parameter list is the
 * primary-constructor parameters followed by the class body's `var` fields, minus any field
 * named in the annotation's `excludeFields` argument. It delegates to the primary constructor.
 *
 * Mail: chk19940609@gmail.com
 * Created by IceMimosa
 * Date: 2021/7/8
 */
class ConstructorProcessor extends AbsProcessor {

  /** Name of the annotation argument listing `var` fields to leave out of the generated constructor. */
  private val excludeFieldsName = "excludeFields"

  override def needCompanion: Boolean = true

  /**
   * Generates synthetic member signatures for `source`.
   *
   * @param source the annotated type definition
   * @param typ    the kind of members being requested
   * @return a single `def this(...)` signature when `typ` is `ProcessType.Method`, `source` is a
   *         class, and at least one non-excluded `var` field exists; `Nil` otherwise
   */
  override def process(source: ScTypeDefinition, typ: ProcessType): Seq[String] =
    (typ, source) match {
      case (ProcessType.Method, clazz: ScClass) =>
        val consFields = getConstructorParameters(clazz, withSecond = false)
        val excluded = excludedFieldNames(clazz)
        // Exact name match: a substring test would also drop vars whose names occur inside an excluded name.
        val varFields = varFieldsOf(clazz).filterNot(f => excluded.contains(f._1))
        if (varFields.isEmpty) {
          // With no extra vars the signature would repeat the primary constructor's and clash with it.
          Nil
        } else {
          val consFieldsStr = consFields.map(_._1).mkString(", ")
          val allFieldsStr = (consFields ++ varFields).map(f => s"${f._1}: ${f._2}").mkString(", ")
          Seq(s"def this($allFieldsStr) = this($consFieldsStr)")
        }
      case _ => Nil
    }

  /**
   * Field names listed in the `excludeFields` argument of the class's last `constructor` annotation.
   *
   * Only a call-shaped argument (e.g. `excludeFields = Seq("a", "b")`) is inspected. Only its
   * arguments whose type resolves to a literal type are collected. Anything else contributes no
   * names, and an absent argument yields an empty set.
   */
  private def excludedFieldNames(clazz: ScClass): Set[String] =
    clazz.annotations(ScalaMacroNames.CONSTRUCTOR).lastOption
      .flatMap(_.getParameterList.getAttributes.findLast(_.getAttributeName == excludeFieldsName))
      .flatMap(attr => Option(attr.getDetachedValue)) // Java PSI API: value may be null
      .collect {
        case call: ScMethodCall =>
          call.argumentExpressions
            .flatMap(_.`type`().toOption)
            .collect { case lit: ScLiteralType => lit.value.value.toString }
            .toSet
      }
      .getOrElse(Set.empty[String])

  /**
   * `(name, type)` pairs for every name declared by a `var` definition in the class body.
   * Vals, defs and other members are ignored; a var whose type cannot be resolved is typed as `Unit`.
   */
  private def varFieldsOf(clazz: ScClass): Seq[(String, String)] =
    clazz.extendsBlock.members
      .collect { case v: ScVariableDefinition => v }
      .flatMap { v =>
        val tpe = v.`type`().toOption.map(_.toString).getOrElse("Unit")
        v.declaredNames.map(n => (n, tpe))
      }
}