diff --git a/refactor/base/JavaRefactorListener.go b/refactor/base/JavaRefactorListener.go index a3db85a2a6edba1f8f41b6f2ed51f09dce99cdc8..243ba79e53b22e1bc2c88c643aa006bedf58fab4 100644 --- a/refactor/base/JavaRefactorListener.go +++ b/refactor/base/JavaRefactorListener.go @@ -4,6 +4,7 @@ import ( . "../../language/java" . "./models" "github.com/antlr/antlr4/runtime/Go/antlr" + "strings" ) var node *JFullIdentifier; @@ -35,6 +36,25 @@ func (s *JavaRefactorListener) EnterImportDeclaration(ctx *ImportDeclarationCont func (s *JavaRefactorListener) EnterClassDeclaration(ctx *ClassDeclarationContext) { node.Type = "Class" node.Name = ctx.IDENTIFIER().GetText() + + if ctx.IMPLEMENTS() != nil { + context := ctx.TypeList() + startLine := ctx.TypeList().GetStart().GetLine() + stopLine := ctx.TypeList().GetStart().GetLine() + + split := strings.Split(context.GetText(), ",") + for _, imp := range split { + field := &JField{imp, node.Pkg, startLine, stopLine} + node.AddField(*field) + } + } + + if ctx.EXTENDS() != nil { + startLine := ctx.TypeType().GetStart().GetLine() + stopLine := ctx.TypeType().GetStart().GetLine() + field := &JField{ctx.TypeType().GetText(), node.Pkg, startLine, stopLine} + node.AddField(*field) + } } func (s *JavaRefactorListener) EnterInterfaceMethodDeclaration(ctx *InterfaceMethodDeclarationContext) { diff --git a/refactor/unused/remove_unused_import.go b/refactor/unused/remove_unused_import.go index b26d31f8aa61c4dda19a0576362a00f314a6d7cb..2a1d1c1b2397148458e430e74f905e129b6cc4ef 100644 --- a/refactor/unused/remove_unused_import.go +++ b/refactor/unused/remove_unused_import.go @@ -56,20 +56,14 @@ func (j *RemoveUnusedImportApp) Analysis() { func handleNode(node *JFullIdentifier) { var fields = node.GetFields() - var imports []JImport = node.GetImports() + var imports = node.GetImports() - if len(fields) == 0 { - removeAllImports(imports) - return - } - - var errorCount = 0 + var errorLines []int for index := range imports { imp := imports[index] ss := strings.Split(imp.Name, ".") lastField := ss[len(ss)-1] - var isOk = false for _, field := range fields { if field.Name == lastField { @@ -78,16 +72,19 @@ func handleNode(node *JFullIdentifier) { } if !isOk { - removeImportByLineNum(imp, imp.StartLine-1 - errorCount) - errorCount++ + errorLines = append(errorLines, imp.StartLine) } } + + removeImportByLines(currentFile, errorLines) } -func removeAllImports(imports []JImport) { - for index := range imports { - imp := imports[index] - removeImportByLineNum(imp, imp.StartLine) +func removeImportByLines(file string, errorLines []int) { + removedErrorCount := 1 + for _, line := range errorLines { + newStart := line - removedErrorCount + removeLine(file, newStart) + removedErrorCount++ } }