未验证 提交 cb741db6 编写于 作者: S Skylot

fix: improve usage search, refactor java nodes creation (#1489)

上级 1df217c4
......@@ -99,7 +99,7 @@ public final class JadxDecompiler implements Closeable {
private final Map<MethodNode, JavaMethod> methodsMap = new ConcurrentHashMap<>();
private final Map<FieldNode, JavaField> fieldsMap = new ConcurrentHashMap<>();
private final IDecompileScheduler decompileScheduler = new DecompilerScheduler(this);
private final IDecompileScheduler decompileScheduler = new DecompilerScheduler();
private final List<ILoadResult> customLoads = new ArrayList<>();
......@@ -202,6 +202,7 @@ public final class JadxDecompiler implements Closeable {
}
}
@SuppressWarnings("unused")
public void registerPlugin(JadxPlugin plugin) {
pluginManager.register(plugin);
}
......@@ -467,23 +468,10 @@ public final class JadxDecompiler implements Closeable {
return protoXmlParser;
}
private void loadJavaClass(JavaClass javaClass) {
javaClass.getMethods().forEach(mth -> methodsMap.put(mth.getMethodNode(), mth));
javaClass.getFields().forEach(fld -> fieldsMap.put(fld.getFieldNode(), fld));
for (JavaClass innerCls : javaClass.getInnerClasses()) {
classesMap.put(innerCls.getClassNode(), innerCls);
loadJavaClass(innerCls);
}
for (JavaClass inlinedCls : javaClass.getInlinedClasses()) {
classesMap.put(inlinedCls.getClassNode(), inlinedCls);
loadJavaClass(inlinedCls);
}
}
/**
* Get JavaClass by ClassNode without loading and decompilation
*/
@ApiStatus.Internal
JavaClass convertClassNode(ClassNode cls) {
return classesMap.compute(cls, (node, prevJavaCls) -> {
if (prevJavaCls != null && prevJavaCls.getClassNode() == cls) {
......@@ -497,66 +485,23 @@ public final class JadxDecompiler implements Closeable {
});
}
@Nullable("For not generated classes")
@ApiStatus.Internal
public JavaClass getJavaClassByNode(ClassNode cls) {
JavaClass javaClass = classesMap.get(cls);
if (javaClass != null && javaClass.getClassNode() == cls) {
return javaClass;
}
// load parent class if inner
ClassNode parentClass = cls.getTopParentClass();
if (parentClass.contains(AFlag.DONT_GENERATE)) {
return null;
}
JavaClass parentJavaClass = classesMap.get(parentClass);
if (parentJavaClass == null) {
getClasses();
parentJavaClass = classesMap.get(parentClass);
}
if (parentJavaClass != null) {
loadJavaClass(parentJavaClass);
javaClass = classesMap.get(cls);
if (javaClass != null) {
return javaClass;
}
}
// class or parent classes can be excluded from generation
if (cls.hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaClass not found by ClassNode: " + cls);
JavaField convertFieldNode(FieldNode field) {
return fieldsMap.computeIfAbsent(field, fldNode -> {
JavaClass parentCls = convertClassNode(fldNode.getParentClass());
return new JavaField(parentCls, fldNode);
});
}
@ApiStatus.Internal
@Nullable
public JavaMethod getJavaMethodByNode(MethodNode mth) {
JavaMethod javaMethod = methodsMap.get(mth);
if (javaMethod != null && javaMethod.getMethodNode() == mth) {
return javaMethod;
}
if (mth.contains(AFlag.DONT_GENERATE)) {
return null;
}
// parent class not loaded yet
ClassNode parentClass = mth.getParentClass();
ClassNode codeCls = getCodeParentClass(parentClass);
JavaClass javaClass = getJavaClassByNode(codeCls);
if (javaClass == null) {
return null;
}
loadJavaClass(javaClass);
javaMethod = methodsMap.get(mth);
if (javaMethod != null) {
return javaMethod;
}
if (parentClass.hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaMethod not found by MethodNode: " + mth);
JavaMethod convertMethodNode(MethodNode method) {
return methodsMap.computeIfAbsent(method, mthNode -> {
ClassNode codeCls = getCodeParentClass(mthNode.getParentClass());
return new JavaMethod(convertClassNode(codeCls), mthNode);
});
}
private ClassNode getCodeParentClass(ClassNode cls) {
private static ClassNode getCodeParentClass(ClassNode cls) {
ClassNode codeCls;
InlinedAttr inlinedAttr = cls.get(AType.INLINED);
if (inlinedAttr != null) {
......@@ -570,35 +515,12 @@ public final class JadxDecompiler implements Closeable {
return getCodeParentClass(codeCls);
}
@ApiStatus.Internal
@Nullable
public JavaField getJavaFieldByNode(FieldNode fld) {
JavaField javaField = fieldsMap.get(fld);
if (javaField != null && javaField.getFieldNode() == fld) {
return javaField;
}
// parent class not loaded yet
JavaClass javaClass = getJavaClassByNode(fld.getParentClass().getTopParentClass());
if (javaClass == null) {
return null;
}
loadJavaClass(javaClass);
javaField = fieldsMap.get(fld);
if (javaField != null) {
return javaField;
}
if (fld.getParentClass().hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaField not found by FieldNode: " + fld);
}
@Nullable
public JavaClass searchJavaClassByOrigFullName(String fullName) {
return getRoot().getClasses().stream()
.filter(cls -> cls.getClassInfo().getFullName().equals(fullName))
.findFirst()
.map(this::getJavaClassByNode)
.map(this::convertClassNode)
.orElse(null);
}
......@@ -619,9 +541,9 @@ public final class JadxDecompiler implements Closeable {
.orElse(null);
if (node != null) {
if (node.contains(AFlag.DONT_GENERATE)) {
return getJavaClassByNode(node.getTopParentClass());
return convertClassNode(node.getTopParentClass());
} else {
return getJavaClassByNode(node);
return convertClassNode(node);
}
}
return null;
......@@ -632,7 +554,7 @@ public final class JadxDecompiler implements Closeable {
return getRoot().getClasses().stream()
.filter(cls -> cls.getClassInfo().getAliasFullName().equals(fullName))
.findFirst()
.map(this::getJavaClassByNode)
.map(this::convertClassNode)
.orElse(null);
}
......@@ -650,9 +572,9 @@ public final class JadxDecompiler implements Closeable {
case CLASS:
return convertClassNode((ClassNode) ann);
case METHOD:
return getJavaMethodByNode((MethodNode) ann);
return convertMethodNode((MethodNode) ann);
case FIELD:
return getJavaFieldByNode((FieldNode) ann);
return convertFieldNode((FieldNode) ann);
case DECLARATION:
return getJavaNodeByCodeAnnotation(codeInfo, ((NodeDeclareRef) ann).getNode());
case VAR:
......@@ -670,7 +592,7 @@ public final class JadxDecompiler implements Closeable {
@Nullable
private JavaVariable resolveVarNode(VarNode varNode) {
MethodNode mthNode = varNode.getMth();
JavaMethod mth = getJavaMethodByNode(mthNode);
JavaMethod mth = convertMethodNode(mthNode);
if (mth == null) {
return null;
}
......
......@@ -6,7 +6,6 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
......@@ -15,7 +14,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.metadata.ICodeAnnotation;
import jadx.api.metadata.ICodeAnnotation.AnnType;
import jadx.api.metadata.ICodeNodeRef;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
......@@ -24,6 +22,7 @@ import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.ListUtils;
public final class JavaClass implements JavaNode {
private static final Logger LOG = LoggerFactory.getLogger(JavaClass.class);
......@@ -88,6 +87,14 @@ public final class JavaClass implements JavaNode {
return cls.getDisassembledCode();
}
@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.CLASS) {
return ann.equals(cls);
}
return false;
}
/**
* Internal API. Not Stable!
*/
......@@ -99,7 +106,6 @@ public final class JavaClass implements JavaNode {
/**
* Decompile class and loads internal lists of fields, methods, etc.
* Do nothing if already loaded.
* Return not null on first call only (for actual loading)
*/
@Nullable
private synchronized void load() {
......@@ -140,10 +146,9 @@ public final class JavaClass implements JavaNode {
if (fieldsCount != 0) {
List<JavaField> flds = new ArrayList<>(fieldsCount);
for (FieldNode f : cls.getFields()) {
// if (!f.contains(AFlag.DONT_GENERATE)) {
JavaField javaField = new JavaField(this, f);
flds.add(javaField);
// }
if (!f.contains(AFlag.DONT_GENERATE)) {
flds.add(rootDecompiler.convertFieldNode(f));
}
}
this.fields = Collections.unmodifiableList(flds);
}
......@@ -153,8 +158,7 @@ public final class JavaClass implements JavaNode {
List<JavaMethod> mths = new ArrayList<>(methodsCount);
for (MethodNode m : cls.getMethods()) {
if (!m.contains(AFlag.DONT_GENERATE)) {
JavaMethod javaMethod = new JavaMethod(this, m);
mths.add(javaMethod);
mths.add(rootDecompiler.convertMethodNode(m));
}
}
mths.sort(Comparator.comparing(JavaMethod::getName));
......@@ -162,7 +166,7 @@ public final class JavaClass implements JavaNode {
}
}
protected JadxDecompiler getRootDecompiler() {
JadxDecompiler getRootDecompiler() {
if (parent != null) {
return parent.getRootDecompiler();
}
......@@ -193,27 +197,16 @@ public final class JavaClass implements JavaNode {
}
public List<Integer> getUsePlacesFor(ICodeInfo codeInfo, JavaNode javaNode) {
Map<Integer, ICodeAnnotation> map = codeInfo.getCodeMetadata().getAsMap();
if (map.isEmpty() || decompiler == null) {
if (!codeInfo.hasMetadata()) {
return Collections.emptyList();
}
JadxDecompiler rootDec = getRootDecompiler();
List<Integer> result = new ArrayList<>();
for (Map.Entry<Integer, ICodeAnnotation> entry : map.entrySet()) {
ICodeAnnotation ann = entry.getValue();
AnnType annType = ann.getAnnType();
if (annType == AnnType.DECLARATION || annType == AnnType.OFFSET) {
// ignore declarations and offset annotations
continue;
}
JavaNode annNode = rootDec.getJavaNodeByCodeAnnotation(codeInfo, ann);
if (annNode == null && LOG.isDebugEnabled()) {
LOG.debug("Failed to resolve code annotation, cls: {}, pos: {}, ann: {}", this, entry.getKey(), ann);
codeInfo.getCodeMetadata().searchDown(0, (pos, ann) -> {
if (javaNode.isOwnCodeAnnotation(ann)) {
result.add(pos);
}
if (Objects.equals(annNode, javaNode)) {
result.add(entry.getKey());
}
}
return null;
});
return result;
}
......@@ -294,7 +287,16 @@ public final class JavaClass implements JavaNode {
if (methodNode == null) {
return null;
}
return new JavaMethod(this, methodNode);
return getRootDecompiler().convertMethodNode(methodNode);
}
public List<JavaClass> getDependencies() {
JadxDecompiler d = getRootDecompiler();
return ListUtils.map(cls.getDependencies(), d::convertClassNode);
}
public int getTotalDepsCount() {
return cls.getTotalDepsCount();
}
@Override
......
......@@ -4,6 +4,7 @@ import java.util.List;
import org.jetbrains.annotations.ApiStatus;
import jadx.api.metadata.ICodeAnnotation;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.FieldNode;
......@@ -65,6 +66,14 @@ public final class JavaField implements JavaNode {
this.field.getFieldInfo().removeAlias();
}
@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.FIELD) {
return ann.equals(field);
}
return false;
}
/**
* Internal API. Not Stable!
*/
......
......@@ -9,6 +9,7 @@ import org.jetbrains.annotations.ApiStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.metadata.ICodeAnnotation;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.info.AccessInfo;
......@@ -18,6 +19,7 @@ import jadx.core.utils.Utils;
public final class JavaMethod implements JavaNode {
private static final Logger LOG = LoggerFactory.getLogger(JavaMethod.class);
private final MethodNode mth;
private final JavaClass parent;
......@@ -78,7 +80,7 @@ public final class JavaMethod implements JavaNode {
JadxDecompiler decompiler = getDeclaringClass().getRootDecompiler();
return ovrdAttr.getRelatedMthNodes().stream()
.map(m -> {
JavaMethod javaMth = decompiler.getJavaMethodByNode(m);
JavaMethod javaMth = decompiler.convertMethodNode(m);
if (javaMth == null) {
LOG.warn("Failed convert to java method: {}", m);
}
......@@ -106,6 +108,14 @@ public final class JavaMethod implements JavaNode {
this.mth.getMethodInfo().removeAlias();
}
@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.METHOD) {
return ann.equals(mth);
}
return false;
}
/**
* Internal API. Not Stable!
*/
......
......@@ -2,6 +2,8 @@ package jadx.api;
import java.util.List;
import jadx.api.metadata.ICodeAnnotation;
public interface JavaNode {
String getName();
......@@ -18,4 +20,6 @@ public interface JavaNode {
default void removeAlias() {
}
boolean isOwnCodeAnnotation(ICodeAnnotation ann);
}
......@@ -5,6 +5,8 @@ import java.util.List;
import org.jetbrains.annotations.NotNull;
import jadx.api.metadata.ICodeAnnotation;
public final class JavaPackage implements JavaNode, Comparable<JavaPackage> {
private final String name;
private final List<JavaClass> classes;
......@@ -49,6 +51,11 @@ public final class JavaPackage implements JavaNode, Comparable<JavaPackage> {
return Collections.emptyList();
}
@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
return false;
}
@Override
public int compareTo(@NotNull JavaPackage o) {
return name.compareTo(o.name);
......
......@@ -5,7 +5,9 @@ import java.util.List;
import org.jetbrains.annotations.ApiStatus;
import jadx.api.metadata.ICodeAnnotation;
import jadx.api.metadata.annotations.VarNode;
import jadx.api.metadata.annotations.VarRef;
import jadx.core.dex.instructions.args.ArgType;
public class JavaVariable implements JavaNode {
......@@ -68,6 +70,15 @@ public class JavaVariable implements JavaNode {
return Collections.singletonList(mth);
}
@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.VAR_REF) {
VarRef varRef = (VarRef) ann;
return varRef.getRefPos() == getDefPos();
}
return false;
}
@Override
public int hashCode() {
return varNode.hashCode();
......
......@@ -13,9 +13,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.api.IDecompileScheduler;
import jadx.api.JadxDecompiler;
import jadx.api.JavaClass;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class DecompilerScheduler implements IDecompileScheduler {
......@@ -24,18 +22,11 @@ public class DecompilerScheduler implements IDecompileScheduler {
private static final int MERGED_BATCH_SIZE = 16;
private static final boolean DEBUG_BATCHES = false;
private final JadxDecompiler decompiler;
public DecompilerScheduler(JadxDecompiler decompiler) {
this.decompiler = decompiler;
}
@Override
public List<List<JavaClass>> buildBatches(List<JavaClass> classes) {
try {
long start = System.currentTimeMillis();
List<List<ClassNode>> batches = internalBatches(Utils.collectionMap(classes, JavaClass::getClassNode));
List<List<JavaClass>> result = Utils.collectionMap(batches, l -> Utils.collectionMapNoNull(l, decompiler::getJavaClassByNode));
List<List<JavaClass>> result = internalBatches(classes);
if (LOG.isDebugEnabled()) {
LOG.debug("Build decompilation batches in {}ms", System.currentTimeMillis() - start);
}
......@@ -53,14 +44,14 @@ public class DecompilerScheduler implements IDecompileScheduler {
* Put classes with many dependencies at the end.
* Build batches for dependencies of single class to avoid locking from another thread.
*/
public List<List<ClassNode>> internalBatches(List<ClassNode> classes) {
public List<List<JavaClass>> internalBatches(List<JavaClass> classes) {
List<DepInfo> deps = sumDependencies(classes);
Set<ClassNode> added = new HashSet<>(classes.size());
Comparator<ClassNode> cmpDepSize = Comparator.comparingInt(ClassNode::getTotalDepsCount);
List<List<ClassNode>> result = new ArrayList<>();
List<ClassNode> mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
Set<JavaClass> added = new HashSet<>(classes.size());
Comparator<JavaClass> cmpDepSize = Comparator.comparingInt(JavaClass::getTotalDepsCount);
List<List<JavaClass>> result = new ArrayList<>();
List<JavaClass> mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
for (DepInfo depInfo : deps) {
ClassNode cls = depInfo.getCls();
JavaClass cls = depInfo.getCls();
if (!added.add(cls)) {
continue;
}
......@@ -73,9 +64,9 @@ public class DecompilerScheduler implements IDecompileScheduler {
mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
}
} else {
List<ClassNode> batch = new ArrayList<>(depsSize + 1);
for (ClassNode dep : cls.getDependencies()) {
ClassNode topDep = dep.getTopParentClass();
List<JavaClass> batch = new ArrayList<>(depsSize + 1);
for (JavaClass dep : cls.getDependencies()) {
JavaClass topDep = dep.getTopParentClass();
if (!added.contains(topDep)) {
batch.add(topDep);
added.add(topDep);
......@@ -95,11 +86,11 @@ public class DecompilerScheduler implements IDecompileScheduler {
return result;
}
private static List<DepInfo> sumDependencies(List<ClassNode> classes) {
private static List<DepInfo> sumDependencies(List<JavaClass> classes) {
List<DepInfo> deps = new ArrayList<>(classes.size());
for (ClassNode cls : classes) {
for (JavaClass cls : classes) {
int count = 0;
for (ClassNode dep : cls.getDependencies()) {
for (JavaClass dep : cls.getDependencies()) {
count += 1 + dep.getTotalDepsCount();
}
deps.add(new DepInfo(cls, count));
......@@ -109,15 +100,15 @@ public class DecompilerScheduler implements IDecompileScheduler {
}
private static final class DepInfo implements Comparable<DepInfo> {
private final ClassNode cls;
private final JavaClass cls;
private final int depsCount;
private DepInfo(ClassNode cls, int depsCount) {
private DepInfo(JavaClass cls, int depsCount) {
this.cls = cls;
this.depsCount = depsCount;
}
public ClassNode getCls() {
public JavaClass getCls() {
return cls;
}
......@@ -129,7 +120,7 @@ public class DecompilerScheduler implements IDecompileScheduler {
public int compareTo(@NotNull DecompilerScheduler.DepInfo o) {
int deps = Integer.compare(depsCount, o.depsCount);
if (deps == 0) {
return cls.compareTo(o.cls);
return cls.getClassNode().compareTo(o.cls.getClassNode());
}
return deps;
}
......@@ -147,9 +138,9 @@ public class DecompilerScheduler implements IDecompileScheduler {
.collect(Collectors.toList());
}
private void dumpBatchesStats(List<ClassNode> classes, List<List<ClassNode>> result, List<DepInfo> deps) {
private void dumpBatchesStats(List<JavaClass> classes, List<List<JavaClass>> result, List<DepInfo> deps) {
double avg = result.stream().mapToInt(List::size).average().orElse(-1);
int maxSingleDeps = classes.stream().mapToInt(ClassNode::getTotalDepsCount).max().orElse(-1);
int maxSingleDeps = classes.stream().mapToInt(JavaClass::getTotalDepsCount).max().orElse(-1);
int maxSubDeps = deps.stream().mapToInt(DepInfo::getDepsCount).max().orElse(-1);
LOG.info("Batches stats:"
+ "\n input classes: " + classes.size()
......
package jadx.api;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
public class JadxInternalAccess {
......@@ -7,4 +10,16 @@ public class JadxInternalAccess {
public static RootNode getRoot(JadxDecompiler d) {
return d.getRoot();
}
public static JavaClass convertClassNode(JadxDecompiler d, ClassNode clsNode) {
return d.convertClassNode(clsNode);
}
public static JavaMethod convertMethodNode(JadxDecompiler d, MethodNode mthNode) {
return d.convertMethodNode(mthNode);
}
public static JavaField convertFieldNode(JadxDecompiler d, FieldNode fldNode) {
return d.convertFieldNode(fldNode);
}
}
package jadx.tests.integration.others;
import java.util.List;
import java.util.Objects;
import org.junit.jupiter.api.Test;
import jadx.api.JadxInternalAccess;
import jadx.api.JavaClass;
import jadx.api.JavaMethod;
import jadx.api.metadata.ICodeAnnotation;
......@@ -46,8 +46,8 @@ public class TestCodeMetadata extends IntegrationTest {
int callDefPos = callMth.getDefPosition();
assertThat(callDefPos).isNotZero();
JavaClass javaClass = Objects.requireNonNull(jadxDecompiler.getJavaClassByNode(cls));
JavaMethod callJavaMethod = Objects.requireNonNull(jadxDecompiler.getJavaMethodByNode(callMth));
JavaClass javaClass = JadxInternalAccess.convertClassNode(jadxDecompiler, cls);
JavaMethod callJavaMethod = JadxInternalAccess.convertMethodNode(jadxDecompiler, callMth);
List<Integer> callUsePlaces = javaClass.getUsePlacesFor(javaClass.getCodeInfo(), callJavaMethod);
assertThat(callUsePlaces).hasSize(1);
int callUse = callUsePlaces.get(0);
......
......@@ -20,7 +20,6 @@ import jadx.api.JadxDecompiler;
import jadx.api.JavaClass;
import jadx.api.JavaNode;
import jadx.api.metadata.ICodeAnnotation;
import jadx.core.dex.nodes.ClassNode;
import jadx.gui.settings.JadxProject;
import jadx.gui.treemodel.JClass;
import jadx.gui.treemodel.JNode;
......@@ -244,8 +243,8 @@ public final class CodeArea extends AbstractCodeArea {
ICodeInfo codeInfo = getCodeInfo();
if (codeInfo.hasMetadata()) {
ICodeAnnotation ann = codeInfo.getCodeMetadata().getAt(pos);
if (ann instanceof ClassNode) {
return getDecompiler().getJavaClassByNode(((ClassNode) ann));
if (ann != null && ann.getAnnType() == ICodeAnnotation.AnnType.CLASS) {
return (JavaClass) getDecompiler().getJavaNodeByCodeAnnotation(codeInfo, ann);
}
}
} catch (Exception e) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册