From a7c63c2eb3c8b2ac3e920651c22cd9dd4b9d12f6 Mon Sep 17 00:00:00 2001 From: Skylot Date: Tue, 18 Jan 2022 18:27:09 +0000 Subject: [PATCH] fix: handle method override with several bases (#1234) --- .../java/jadx/core/codegen/MethodGen.java | 2 +- .../attributes/nodes/MethodOverrideAttr.java | 28 +-- .../core/dex/nodes/utils/MethodUtils.java | 3 +- .../dex/visitors/OverrideMethodVisitor.java | 173 ++++++++++++------ .../src/main/java/jadx/core/utils/Utils.java | 23 +++ .../others/TestOverrideWithTwoBases.java | 34 ++++ .../others/TestOverrideWithTwoBases2.java | 33 ++++ 7 files changed, 219 insertions(+), 77 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases2.java diff --git a/jadx-core/src/main/java/jadx/core/codegen/MethodGen.java b/jadx-core/src/main/java/jadx/core/codegen/MethodGen.java index 10aa04f5..e38c8ce7 100644 --- a/jadx-core/src/main/java/jadx/core/codegen/MethodGen.java +++ b/jadx-core/src/main/java/jadx/core/codegen/MethodGen.java @@ -180,7 +180,7 @@ public class MethodGen { if (overrideAttr == null) { return; } - if (!overrideAttr.isAtBaseMth()) { + if (!overrideAttr.getBaseMethods().contains(mth)) { code.startLine("@Override"); if (mth.checkCommentsLevel(CommentsLevel.INFO)) { code.add(" // "); diff --git a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodOverrideAttr.java b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodOverrideAttr.java index 8491fc96..82b840f1 100644 --- a/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodOverrideAttr.java +++ b/jadx-core/src/main/java/jadx/core/dex/attributes/nodes/MethodOverrideAttr.java @@ -1,15 +1,13 @@ package jadx.core.dex.attributes.nodes; import java.util.List; +import java.util.Set; import java.util.SortedSet; -import org.jetbrains.annotations.Nullable; - import jadx.api.plugins.input.data.attributes.PinnedAttribute; import jadx.core.dex.attributes.AType; import jadx.core.dex.nodes.IMethodDetails; import jadx.core.dex.nodes.MethodNode; -import jadx.core.utils.Utils; public class MethodOverrideAttr extends PinnedAttribute { @@ -23,32 +21,26 @@ public class MethodOverrideAttr extends PinnedAttribute { */ private SortedSet relatedMthNodes; - public MethodOverrideAttr(List overrideList, SortedSet relatedMthNodes) { + private Set baseMethods; + + public MethodOverrideAttr(List overrideList, SortedSet relatedMthNodes, Set baseMethods) { this.overrideList = overrideList; this.relatedMthNodes = relatedMthNodes; - } - - public boolean isAtBaseMth() { - return overrideList.isEmpty(); - } - - @Nullable - public IMethodDetails getBaseMth() { - return Utils.last(overrideList); + this.baseMethods = baseMethods; } public List getOverrideList() { return overrideList; } - public void setOverrideList(List overrideList) { - this.overrideList = overrideList; - } - public SortedSet getRelatedMthNodes() { return relatedMthNodes; } + public Set getBaseMethods() { + return baseMethods; + } + public void setRelatedMthNodes(SortedSet relatedMthNodes) { this.relatedMthNodes = relatedMthNodes; } @@ -60,6 +52,6 @@ public class MethodOverrideAttr extends PinnedAttribute { @Override public String toString() { - return "METHOD_OVERRIDE: " + overrideList; + return "METHOD_OVERRIDE: " + getBaseMethods(); } } diff --git a/jadx-core/src/main/java/jadx/core/dex/nodes/utils/MethodUtils.java b/jadx-core/src/main/java/jadx/core/dex/nodes/utils/MethodUtils.java index 3d1132ec..e18f97ed 100644 --- a/jadx-core/src/main/java/jadx/core/dex/nodes/utils/MethodUtils.java +++ b/jadx-core/src/main/java/jadx/core/dex/nodes/utils/MethodUtils.java @@ -18,6 +18,7 @@ import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.IMethodDetails; import jadx.core.dex.nodes.MethodNode; import jadx.core.dex.nodes.RootNode; +import jadx.core.utils.Utils; public class MethodUtils { private final RootNode root; @@ -132,7 +133,7 @@ public class MethodUtils { if (overrideAttr == null) { return null; } - return overrideAttr.getBaseMth(); + return Utils.getOne(overrideAttr.getBaseMethods()); } public ClassInfo getMethodOriginDeclClass(MethodNode mth) { diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/OverrideMethodVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/OverrideMethodVisitor.java index 3f261afd..d55998e5 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/OverrideMethodVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/OverrideMethodVisitor.java @@ -1,11 +1,11 @@ package jadx.core.dex.visitors; import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.SortedSet; import java.util.TreeSet; import java.util.stream.Collectors; @@ -33,6 +33,7 @@ import jadx.core.dex.visitors.typeinference.TypeCompareEnum; import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor; import jadx.core.utils.Utils; import jadx.core.utils.exceptions.JadxException; +import jadx.core.utils.exceptions.JadxRuntimeException; @JadxVisitor( name = "OverrideMethodVisitor", @@ -46,70 +47,86 @@ public class OverrideMethodVisitor extends AbstractVisitor { @Override public boolean visit(ClassNode cls) throws JadxException { - processCls(cls); - return true; - } - - private void processCls(ClassNode cls) { - List superTypes = collectSuperTypes(cls); - if (!superTypes.isEmpty()) { + SuperTypesData superData = collectSuperTypes(cls); + if (superData != null) { for (MethodNode mth : cls.getMethods()) { - processMth(cls, superTypes, mth); + processMth(mth, superData); } } + return true; } - private void processMth(ClassNode cls, List superTypes, MethodNode mth) { + private void processMth(MethodNode mth, SuperTypesData superData) { if (mth.isConstructor() || mth.getAccessFlags().isStatic() || mth.getAccessFlags().isPrivate()) { return; } - MethodOverrideAttr attr = processOverrideMethods(cls, mth, superTypes); + MethodOverrideAttr attr = processOverrideMethods(mth, superData); if (attr != null) { + if (attr.getBaseMethods().isEmpty()) { + throw new JadxRuntimeException("No base methods for override attribute: " + attr.getOverrideList()); + } mth.addAttr(attr); - IMethodDetails baseMth = attr.getBaseMth(); + IMethodDetails baseMth = Utils.getOne(attr.getBaseMethods()); if (baseMth != null) { - boolean updated = fixMethodReturnType(mth, baseMth, superTypes); - updated |= fixMethodArgTypes(mth, baseMth, superTypes); + boolean updated = fixMethodReturnType(mth, baseMth, superData); + updated |= fixMethodArgTypes(mth, baseMth, superData); if (updated) { // check if new signature cause method collisions - checkMethodSignatureCollisions(mth, cls.root().getArgs().isRenameValid()); + checkMethodSignatureCollisions(mth, mth.root().getArgs().isRenameValid()); } } } } - private MethodOverrideAttr processOverrideMethods(ClassNode cls, MethodNode mth, List superTypes) { + private MethodOverrideAttr processOverrideMethods(MethodNode mth, SuperTypesData superData) { MethodOverrideAttr result = mth.get(AType.METHOD_OVERRIDE); if (result != null) { return result; } + ClassNode cls = mth.getParentClass(); String signature = mth.getMethodInfo().makeSignature(false); List overrideList = new ArrayList<>(); - for (ArgType superType : superTypes) { - ClassNode classNode = cls.root().resolveClass(superType); + Set baseMethods = new HashSet<>(); + for (ArgType superType : superData.getSuperTypes()) { + ClassNode classNode = mth.root().resolveClass(superType); if (classNode != null) { MethodNode ovrdMth = searchOverriddenMethod(classNode, signature); - if (ovrdMth != null && isMethodVisibleInCls(ovrdMth, cls)) { - overrideList.add(ovrdMth); - MethodOverrideAttr attr = ovrdMth.get(AType.METHOD_OVERRIDE); - if (attr != null) { - return buildOverrideAttr(mth, overrideList, attr); + if (ovrdMth != null) { + if (isMethodVisibleInCls(ovrdMth, cls)) { + overrideList.add(ovrdMth); + MethodOverrideAttr attr = ovrdMth.get(AType.METHOD_OVERRIDE); + if (attr != null) { + addBaseMethod(superData, overrideList, baseMethods, superType); + return buildOverrideAttr(mth, overrideList, baseMethods, attr); + } } } } else { - ClspClass clsDetails = cls.root().getClsp().getClsDetails(superType); + ClspClass clsDetails = mth.root().getClsp().getClsDetails(superType); if (clsDetails != null) { Map methodsMap = clsDetails.getMethodsMap(); for (Map.Entry entry : methodsMap.entrySet()) { String mthShortId = entry.getKey(); if (mthShortId.startsWith(signature)) { overrideList.add(entry.getValue()); + break; } } } } + addBaseMethod(superData, overrideList, baseMethods, superType); + } + return buildOverrideAttr(mth, overrideList, baseMethods, null); + } + + private void addBaseMethod(SuperTypesData superData, List overrideList, Set baseMethods, + ArgType superType) { + if (superData.getEndTypes().contains(superType.getObject())) { + IMethodDetails last = Utils.last(overrideList); + if (last != null) { + baseMethods.add(last); + } } - return buildOverrideAttr(mth, overrideList, null); } @Nullable @@ -124,22 +141,24 @@ public class OverrideMethodVisitor extends AbstractVisitor { @Nullable private MethodOverrideAttr buildOverrideAttr(MethodNode mth, List overrideList, - @Nullable MethodOverrideAttr attr) { + Set baseMethods, @Nullable MethodOverrideAttr attr) { if (overrideList.isEmpty() && attr == null) { return null; } if (attr == null) { // traced to base method List cleanOverrideList = overrideList.stream().distinct().collect(Collectors.toList()); - return applyOverrideAttr(mth, cleanOverrideList, false); + return applyOverrideAttr(mth, cleanOverrideList, baseMethods, false); } // trace stopped at already processed method -> start merging List mergedOverrideList = Utils.mergeLists(overrideList, attr.getOverrideList()); List cleanOverrideList = mergedOverrideList.stream().distinct().collect(Collectors.toList()); - return applyOverrideAttr(mth, cleanOverrideList, true); + Set mergedBaseMethods = Utils.mergeSets(baseMethods, attr.getBaseMethods()); + return applyOverrideAttr(mth, cleanOverrideList, mergedBaseMethods, true); } - private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List overrideList, boolean update) { + private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List overrideList, + Set baseMethods, boolean update) { // don't rename method if override list contains not resolved method boolean dontRename = overrideList.stream().anyMatch(m -> !(m instanceof MethodNode)); SortedSet relatedMethods = null; @@ -189,10 +208,10 @@ public class OverrideMethodVisitor extends AbstractVisitor { continue; } } - mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, depth), relatedMethods)); + mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, depth), relatedMethods, baseMethods)); depth++; } - return new MethodOverrideAttr(overrideList, relatedMethods); + return new MethodOverrideAttr(overrideList, relatedMethods, baseMethods); } @NotNull @@ -222,52 +241,92 @@ public class OverrideMethodVisitor extends AbstractVisitor { return Objects.equals(superMth.getParentClass().getPackage(), cls.getPackage()); } - private List collectSuperTypes(ClassNode cls) { - Map superTypes = new LinkedHashMap<>(); - collectSuperTypes(cls, superTypes); + private static final class SuperTypesData { + private final List superTypes; + private final Set endTypes; + + private SuperTypesData(List superTypes, Set endTypes) { + this.superTypes = superTypes; + this.endTypes = endTypes; + } + + public List getSuperTypes() { + return superTypes; + } + + public Set getEndTypes() { + return endTypes; + } + } + + @Nullable + private SuperTypesData collectSuperTypes(ClassNode cls) { + List superTypes = new ArrayList<>(); + Set endTypes = new HashSet<>(); + collectSuperTypes(cls, superTypes, endTypes); if (superTypes.isEmpty()) { - return Collections.emptyList(); + return null; + } + if (endTypes.isEmpty()) { + throw new JadxRuntimeException("No end types in class hierarchy: " + cls); } - return new ArrayList<>(superTypes.values()); + return new SuperTypesData(superTypes, endTypes); } - private void collectSuperTypes(ClassNode cls, Map superTypes) { + private void collectSuperTypes(ClassNode cls, List superTypes, Set endTypes) { RootNode root = cls.root(); + int k = 0; ArgType superClass = cls.getSuperClass(); - if (superClass != null && !Objects.equals(superClass, ArgType.OBJECT)) { - addSuperType(root, superTypes, superClass); + if (superClass != null) { + k += addSuperType(root, superTypes, endTypes, superClass); } for (ArgType iface : cls.getInterfaces()) { - addSuperType(root, superTypes, iface); + k += addSuperType(root, superTypes, endTypes, iface); + } + if (k == 0) { + endTypes.add(cls.getType().getObject()); } } - private void addSuperType(RootNode root, Map superTypesMap, ArgType superType) { - superTypesMap.put(superType.getObject(), superType); + private int addSuperType(RootNode root, List superTypesMap, Set endTypes, ArgType superType) { + if (Objects.equals(superType, ArgType.OBJECT)) { + return 0; + } + superTypesMap.add(superType); ClassNode classNode = root.resolveClass(superType); - if (classNode == null) { - for (String superCls : root.getClsp().getSuperTypes(superType.getObject())) { - ArgType type = ArgType.object(superCls); - superTypesMap.put(type.getObject(), type); + if (classNode != null) { + collectSuperTypes(classNode, superTypesMap, endTypes); + return 1; + } + ClspClass clsDetails = root.getClsp().getClsDetails(superType); + if (clsDetails != null) { + int k = 0; + for (ArgType parentType : clsDetails.getParents()) { + k += addSuperType(root, superTypesMap, endTypes, parentType); } - } else { - collectSuperTypes(classNode, superTypesMap); + if (k == 0) { + endTypes.add(superType.getObject()); + } + return 1; } + // no info found => treat as hierarchy end + endTypes.add(superType.getObject()); + return 1; } - private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, List superTypes) { + private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) { ArgType returnType = mth.getReturnType(); if (returnType == ArgType.VOID) { return false; } - boolean updated = updateReturnType(mth, baseMth, superTypes); + boolean updated = updateReturnType(mth, baseMth, superData); if (updated) { mth.addDebugComment("Return type fixed from '" + returnType + "' to match base method"); } return updated; } - private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, List superTypes) { + private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) { ArgType baseReturnType = baseMth.getReturnType(); if (mth.getReturnType().equals(baseReturnType)) { return false; @@ -277,7 +336,7 @@ public class OverrideMethodVisitor extends AbstractVisitor { } TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare(); ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType(); - for (ArgType superType : superTypes) { + for (ArgType superType : superData.getSuperTypes()) { TypeCompareEnum compareResult = typeCompare.compareTypes(superType, baseCls); if (compareResult == TypeCompareEnum.NARROW_BY_GENERIC) { ArgType targetRetType = mth.root().getTypeUtils().replaceClassGenerics(superType, baseReturnType); @@ -292,7 +351,7 @@ public class OverrideMethodVisitor extends AbstractVisitor { return false; } - private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, List superTypes) { + private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) { List mthArgTypes = mth.getArgTypes(); List baseArgTypes = baseMth.getArgTypes(); if (mthArgTypes.equals(baseArgTypes)) { @@ -305,7 +364,7 @@ public class OverrideMethodVisitor extends AbstractVisitor { boolean changed = false; List newArgTypes = new ArrayList<>(argCount); for (int argNum = 0; argNum < argCount; argNum++) { - ArgType newType = updateArgType(mth, baseMth, superTypes, argNum); + ArgType newType = updateArgType(mth, baseMth, superData, argNum); if (newType != null) { changed = true; newArgTypes.add(newType); @@ -319,7 +378,7 @@ public class OverrideMethodVisitor extends AbstractVisitor { return changed; } - private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, List superTypes, int argNum) { + private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData, int argNum) { ArgType arg = mth.getArgTypes().get(argNum); ArgType baseArg = baseMth.getArgTypes().get(argNum); if (arg.equals(baseArg)) { @@ -330,7 +389,7 @@ public class OverrideMethodVisitor extends AbstractVisitor { } TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare(); ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType(); - for (ArgType superType : superTypes) { + for (ArgType superType : superData.getSuperTypes()) { TypeCompareEnum compareResult = typeCompare.compareTypes(superType, baseCls); if (compareResult == TypeCompareEnum.NARROW_BY_GENERIC) { ArgType targetArgType = mth.root().getTypeUtils().replaceClassGenerics(superType, baseArg); diff --git a/jadx-core/src/main/java/jadx/core/utils/Utils.java b/jadx-core/src/main/java/jadx/core/utils/Utils.java index 38576a90..040373e0 100644 --- a/jadx-core/src/main/java/jadx/core/utils/Utils.java +++ b/jadx-core/src/main/java/jadx/core/utils/Utils.java @@ -8,10 +8,12 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.Function; import org.jetbrains.annotations.Nullable; @@ -284,6 +286,19 @@ public class Utils { return result; } + public static Set mergeSets(Set first, Set second) { + if (isEmpty(first)) { + return second; + } + if (isEmpty(second)) { + return first; + } + Set result = new HashSet<>(first.size() + second.size()); + result.addAll(first); + result.addAll(second); + return result; + } + public static Map newConstStringMap(String... parameters) { int len = parameters.length; if (len == 0) { @@ -338,6 +353,14 @@ public class Utils { return list.get(0); } + @Nullable + public static T getOne(@Nullable Collection collection) { + if (collection == null || collection.size() != 1) { + return null; + } + return collection.iterator().next(); + } + @Nullable public static T first(List list) { if (list.isEmpty()) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases.java new file mode 100644 index 00000000..ee42a7d2 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases.java @@ -0,0 +1,34 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestOverrideWithTwoBases extends IntegrationTest { + + public static class TestCls { + public abstract static class BaseClass { + public abstract int a(); + } + + public interface I { + int a(); + } + + public static class Cls extends BaseClass implements I { + @Override + public int a() { + return 2; + } + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .containsOne("@Override"); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases2.java b/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases2.java new file mode 100644 index 00000000..acd4a9cd --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/others/TestOverrideWithTwoBases2.java @@ -0,0 +1,33 @@ +package jadx.tests.integration.others; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestOverrideWithTwoBases2 extends IntegrationTest { + + public static class TestCls { + public interface I { + int a(); + } + + public abstract static class BaseCls implements I { + } + + public static class Cls extends BaseCls implements I { + @Override + public int a() { + return 2; + } + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .containsOne("@Override"); + } +} -- GitLab