未验证 提交 a7c63c2e 编写于 作者: S Skylot

fix: handle method override with several bases (#1234)

上级 081a0e21
......@@ -180,7 +180,7 @@ public class MethodGen {
if (overrideAttr == null) {
return;
}
if (!overrideAttr.isAtBaseMth()) {
if (!overrideAttr.getBaseMethods().contains(mth)) {
code.startLine("@Override");
if (mth.checkCommentsLevel(CommentsLevel.INFO)) {
code.add(" // ");
......
package jadx.core.dex.attributes.nodes;
import java.util.List;
import java.util.Set;
import java.util.SortedSet;
import org.jetbrains.annotations.Nullable;
import jadx.api.plugins.input.data.attributes.PinnedAttribute;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.Utils;
public class MethodOverrideAttr extends PinnedAttribute {
......@@ -23,32 +21,26 @@ public class MethodOverrideAttr extends PinnedAttribute {
*/
private SortedSet<MethodNode> relatedMthNodes;
public MethodOverrideAttr(List<IMethodDetails> overrideList, SortedSet<MethodNode> relatedMthNodes) {
private Set<IMethodDetails> baseMethods;
public MethodOverrideAttr(List<IMethodDetails> overrideList, SortedSet<MethodNode> relatedMthNodes, Set<IMethodDetails> baseMethods) {
this.overrideList = overrideList;
this.relatedMthNodes = relatedMthNodes;
}
public boolean isAtBaseMth() {
return overrideList.isEmpty();
}
@Nullable
public IMethodDetails getBaseMth() {
return Utils.last(overrideList);
this.baseMethods = baseMethods;
}
public List<IMethodDetails> getOverrideList() {
return overrideList;
}
public void setOverrideList(List<IMethodDetails> overrideList) {
this.overrideList = overrideList;
}
public SortedSet<MethodNode> getRelatedMthNodes() {
return relatedMthNodes;
}
public Set<IMethodDetails> getBaseMethods() {
return baseMethods;
}
public void setRelatedMthNodes(SortedSet<MethodNode> relatedMthNodes) {
this.relatedMthNodes = relatedMthNodes;
}
......@@ -60,6 +52,6 @@ public class MethodOverrideAttr extends PinnedAttribute {
@Override
public String toString() {
return "METHOD_OVERRIDE: " + overrideList;
return "METHOD_OVERRIDE: " + getBaseMethods();
}
}
......@@ -18,6 +18,7 @@ import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.Utils;
public class MethodUtils {
private final RootNode root;
......@@ -132,7 +133,7 @@ public class MethodUtils {
if (overrideAttr == null) {
return null;
}
return overrideAttr.getBaseMth();
return Utils.getOne(overrideAttr.getBaseMethods());
}
public ClassInfo getMethodOriginDeclClass(MethodNode mth) {
......
package jadx.core.dex.visitors;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
......@@ -33,6 +33,7 @@ import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.exceptions.JadxRuntimeException;
@JadxVisitor(
name = "OverrideMethodVisitor",
......@@ -46,70 +47,86 @@ public class OverrideMethodVisitor extends AbstractVisitor {
@Override
public boolean visit(ClassNode cls) throws JadxException {
processCls(cls);
return true;
}
private void processCls(ClassNode cls) {
List<ArgType> superTypes = collectSuperTypes(cls);
if (!superTypes.isEmpty()) {
SuperTypesData superData = collectSuperTypes(cls);
if (superData != null) {
for (MethodNode mth : cls.getMethods()) {
processMth(cls, superTypes, mth);
processMth(mth, superData);
}
}
return true;
}
private void processMth(ClassNode cls, List<ArgType> superTypes, MethodNode mth) {
private void processMth(MethodNode mth, SuperTypesData superData) {
if (mth.isConstructor() || mth.getAccessFlags().isStatic() || mth.getAccessFlags().isPrivate()) {
return;
}
MethodOverrideAttr attr = processOverrideMethods(cls, mth, superTypes);
MethodOverrideAttr attr = processOverrideMethods(mth, superData);
if (attr != null) {
if (attr.getBaseMethods().isEmpty()) {
throw new JadxRuntimeException("No base methods for override attribute: " + attr.getOverrideList());
}
mth.addAttr(attr);
IMethodDetails baseMth = attr.getBaseMth();
IMethodDetails baseMth = Utils.getOne(attr.getBaseMethods());
if (baseMth != null) {
boolean updated = fixMethodReturnType(mth, baseMth, superTypes);
updated |= fixMethodArgTypes(mth, baseMth, superTypes);
boolean updated = fixMethodReturnType(mth, baseMth, superData);
updated |= fixMethodArgTypes(mth, baseMth, superData);
if (updated) {
// check if new signature cause method collisions
checkMethodSignatureCollisions(mth, cls.root().getArgs().isRenameValid());
checkMethodSignatureCollisions(mth, mth.root().getArgs().isRenameValid());
}
}
}
}
private MethodOverrideAttr processOverrideMethods(ClassNode cls, MethodNode mth, List<ArgType> superTypes) {
private MethodOverrideAttr processOverrideMethods(MethodNode mth, SuperTypesData superData) {
MethodOverrideAttr result = mth.get(AType.METHOD_OVERRIDE);
if (result != null) {
return result;
}
ClassNode cls = mth.getParentClass();
String signature = mth.getMethodInfo().makeSignature(false);
List<IMethodDetails> overrideList = new ArrayList<>();
for (ArgType superType : superTypes) {
ClassNode classNode = cls.root().resolveClass(superType);
Set<IMethodDetails> baseMethods = new HashSet<>();
for (ArgType superType : superData.getSuperTypes()) {
ClassNode classNode = mth.root().resolveClass(superType);
if (classNode != null) {
MethodNode ovrdMth = searchOverriddenMethod(classNode, signature);
if (ovrdMth != null && isMethodVisibleInCls(ovrdMth, cls)) {
overrideList.add(ovrdMth);
MethodOverrideAttr attr = ovrdMth.get(AType.METHOD_OVERRIDE);
if (attr != null) {
return buildOverrideAttr(mth, overrideList, attr);
if (ovrdMth != null) {
if (isMethodVisibleInCls(ovrdMth, cls)) {
overrideList.add(ovrdMth);
MethodOverrideAttr attr = ovrdMth.get(AType.METHOD_OVERRIDE);
if (attr != null) {
addBaseMethod(superData, overrideList, baseMethods, superType);
return buildOverrideAttr(mth, overrideList, baseMethods, attr);
}
}
}
} else {
ClspClass clsDetails = cls.root().getClsp().getClsDetails(superType);
ClspClass clsDetails = mth.root().getClsp().getClsDetails(superType);
if (clsDetails != null) {
Map<String, ClspMethod> methodsMap = clsDetails.getMethodsMap();
for (Map.Entry<String, ClspMethod> entry : methodsMap.entrySet()) {
String mthShortId = entry.getKey();
if (mthShortId.startsWith(signature)) {
overrideList.add(entry.getValue());
break;
}
}
}
}
addBaseMethod(superData, overrideList, baseMethods, superType);
}
return buildOverrideAttr(mth, overrideList, baseMethods, null);
}
private void addBaseMethod(SuperTypesData superData, List<IMethodDetails> overrideList, Set<IMethodDetails> baseMethods,
ArgType superType) {
if (superData.getEndTypes().contains(superType.getObject())) {
IMethodDetails last = Utils.last(overrideList);
if (last != null) {
baseMethods.add(last);
}
}
return buildOverrideAttr(mth, overrideList, null);
}
@Nullable
......@@ -124,22 +141,24 @@ public class OverrideMethodVisitor extends AbstractVisitor {
@Nullable
private MethodOverrideAttr buildOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList,
@Nullable MethodOverrideAttr attr) {
Set<IMethodDetails> baseMethods, @Nullable MethodOverrideAttr attr) {
if (overrideList.isEmpty() && attr == null) {
return null;
}
if (attr == null) {
// traced to base method
List<IMethodDetails> cleanOverrideList = overrideList.stream().distinct().collect(Collectors.toList());
return applyOverrideAttr(mth, cleanOverrideList, false);
return applyOverrideAttr(mth, cleanOverrideList, baseMethods, false);
}
// trace stopped at already processed method -> start merging
List<IMethodDetails> mergedOverrideList = Utils.mergeLists(overrideList, attr.getOverrideList());
List<IMethodDetails> cleanOverrideList = mergedOverrideList.stream().distinct().collect(Collectors.toList());
return applyOverrideAttr(mth, cleanOverrideList, true);
Set<IMethodDetails> mergedBaseMethods = Utils.mergeSets(baseMethods, attr.getBaseMethods());
return applyOverrideAttr(mth, cleanOverrideList, mergedBaseMethods, true);
}
private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList, boolean update) {
private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList,
Set<IMethodDetails> baseMethods, boolean update) {
// don't rename method if override list contains not resolved method
boolean dontRename = overrideList.stream().anyMatch(m -> !(m instanceof MethodNode));
SortedSet<MethodNode> relatedMethods = null;
......@@ -189,10 +208,10 @@ public class OverrideMethodVisitor extends AbstractVisitor {
continue;
}
}
mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, depth), relatedMethods));
mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, depth), relatedMethods, baseMethods));
depth++;
}
return new MethodOverrideAttr(overrideList, relatedMethods);
return new MethodOverrideAttr(overrideList, relatedMethods, baseMethods);
}
@NotNull
......@@ -222,52 +241,92 @@ public class OverrideMethodVisitor extends AbstractVisitor {
return Objects.equals(superMth.getParentClass().getPackage(), cls.getPackage());
}
private List<ArgType> collectSuperTypes(ClassNode cls) {
Map<String, ArgType> superTypes = new LinkedHashMap<>();
collectSuperTypes(cls, superTypes);
private static final class SuperTypesData {
private final List<ArgType> superTypes;
private final Set<String> endTypes;
private SuperTypesData(List<ArgType> superTypes, Set<String> endTypes) {
this.superTypes = superTypes;
this.endTypes = endTypes;
}
public List<ArgType> getSuperTypes() {
return superTypes;
}
public Set<String> getEndTypes() {
return endTypes;
}
}
@Nullable
private SuperTypesData collectSuperTypes(ClassNode cls) {
List<ArgType> superTypes = new ArrayList<>();
Set<String> endTypes = new HashSet<>();
collectSuperTypes(cls, superTypes, endTypes);
if (superTypes.isEmpty()) {
return Collections.emptyList();
return null;
}
if (endTypes.isEmpty()) {
throw new JadxRuntimeException("No end types in class hierarchy: " + cls);
}
return new ArrayList<>(superTypes.values());
return new SuperTypesData(superTypes, endTypes);
}
private void collectSuperTypes(ClassNode cls, Map<String, ArgType> superTypes) {
private void collectSuperTypes(ClassNode cls, List<ArgType> superTypes, Set<String> endTypes) {
RootNode root = cls.root();
int k = 0;
ArgType superClass = cls.getSuperClass();
if (superClass != null && !Objects.equals(superClass, ArgType.OBJECT)) {
addSuperType(root, superTypes, superClass);
if (superClass != null) {
k += addSuperType(root, superTypes, endTypes, superClass);
}
for (ArgType iface : cls.getInterfaces()) {
addSuperType(root, superTypes, iface);
k += addSuperType(root, superTypes, endTypes, iface);
}
if (k == 0) {
endTypes.add(cls.getType().getObject());
}
}
private void addSuperType(RootNode root, Map<String, ArgType> superTypesMap, ArgType superType) {
superTypesMap.put(superType.getObject(), superType);
private int addSuperType(RootNode root, List<ArgType> superTypesMap, Set<String> endTypes, ArgType superType) {
if (Objects.equals(superType, ArgType.OBJECT)) {
return 0;
}
superTypesMap.add(superType);
ClassNode classNode = root.resolveClass(superType);
if (classNode == null) {
for (String superCls : root.getClsp().getSuperTypes(superType.getObject())) {
ArgType type = ArgType.object(superCls);
superTypesMap.put(type.getObject(), type);
if (classNode != null) {
collectSuperTypes(classNode, superTypesMap, endTypes);
return 1;
}
ClspClass clsDetails = root.getClsp().getClsDetails(superType);
if (clsDetails != null) {
int k = 0;
for (ArgType parentType : clsDetails.getParents()) {
k += addSuperType(root, superTypesMap, endTypes, parentType);
}
} else {
collectSuperTypes(classNode, superTypesMap);
if (k == 0) {
endTypes.add(superType.getObject());
}
return 1;
}
// no info found => treat as hierarchy end
endTypes.add(superType.getObject());
return 1;
}
private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
ArgType returnType = mth.getReturnType();
if (returnType == ArgType.VOID) {
return false;
}
boolean updated = updateReturnType(mth, baseMth, superTypes);
boolean updated = updateReturnType(mth, baseMth, superData);
if (updated) {
mth.addDebugComment("Return type fixed from '" + returnType + "' to match base method");
}
return updated;
}
private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
ArgType baseReturnType = baseMth.getReturnType();
if (mth.getReturnType().equals(baseReturnType)) {
return false;
......@@ -277,7 +336,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
}
TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare();
ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType();
for (ArgType superType : superTypes) {
for (ArgType superType : superData.getSuperTypes()) {
TypeCompareEnum compareResult = typeCompare.compareTypes(superType, baseCls);
if (compareResult == TypeCompareEnum.NARROW_BY_GENERIC) {
ArgType targetRetType = mth.root().getTypeUtils().replaceClassGenerics(superType, baseReturnType);
......@@ -292,7 +351,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
return false;
}
private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData) {
List<ArgType> mthArgTypes = mth.getArgTypes();
List<ArgType> baseArgTypes = baseMth.getArgTypes();
if (mthArgTypes.equals(baseArgTypes)) {
......@@ -305,7 +364,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
boolean changed = false;
List<ArgType> newArgTypes = new ArrayList<>(argCount);
for (int argNum = 0; argNum < argCount; argNum++) {
ArgType newType = updateArgType(mth, baseMth, superTypes, argNum);
ArgType newType = updateArgType(mth, baseMth, superData, argNum);
if (newType != null) {
changed = true;
newArgTypes.add(newType);
......@@ -319,7 +378,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
return changed;
}
private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes, int argNum) {
private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, SuperTypesData superData, int argNum) {
ArgType arg = mth.getArgTypes().get(argNum);
ArgType baseArg = baseMth.getArgTypes().get(argNum);
if (arg.equals(baseArg)) {
......@@ -330,7 +389,7 @@ public class OverrideMethodVisitor extends AbstractVisitor {
}
TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare();
ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType();
for (ArgType superType : superTypes) {
for (ArgType superType : superData.getSuperTypes()) {
TypeCompareEnum compareResult = typeCompare.compareTypes(superType, baseCls);
if (compareResult == TypeCompareEnum.NARROW_BY_GENERIC) {
ArgType targetArgType = mth.root().getTypeUtils().replaceClassGenerics(superType, baseArg);
......
......@@ -8,10 +8,12 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import org.jetbrains.annotations.Nullable;
......@@ -284,6 +286,19 @@ public class Utils {
return result;
}
public static <T> Set<T> mergeSets(Set<T> first, Set<T> second) {
if (isEmpty(first)) {
return second;
}
if (isEmpty(second)) {
return first;
}
Set<T> result = new HashSet<>(first.size() + second.size());
result.addAll(first);
result.addAll(second);
return result;
}
public static Map<String, String> newConstStringMap(String... parameters) {
int len = parameters.length;
if (len == 0) {
......@@ -338,6 +353,14 @@ public class Utils {
return list.get(0);
}
@Nullable
public static <T> T getOne(@Nullable Collection<T> collection) {
if (collection == null || collection.size() != 1) {
return null;
}
return collection.iterator().next();
}
@Nullable
public static <T> T first(List<T> list) {
if (list.isEmpty()) {
......
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestOverrideWithTwoBases extends IntegrationTest {
public static class TestCls {
public abstract static class BaseClass {
public abstract int a();
}
public interface I {
int a();
}
public static class Cls extends BaseClass implements I {
@Override
public int a() {
return 2;
}
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("@Override");
}
}
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestOverrideWithTwoBases2 extends IntegrationTest {
public static class TestCls {
public interface I {
int a();
}
public abstract static class BaseCls implements I {
}
public static class Cls extends BaseCls implements I {
@Override
public int a() {
return 2;
}
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("@Override");
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册