提交 2b7d7ce2 编写于 作者: S Skylot

fix: additional casts at use place to help type inference (#1002)

上级 a22efc2e
......@@ -67,7 +67,7 @@ public enum AFlag {
*/
EXPLICIT_PRIMITIVE_TYPE,
EXPLICIT_CAST,
SOFT_CAST, // synthetic cast to help type inference
SOFT_CAST, // synthetic cast to help type inference (allow unchecked casts for generics)
INCONSISTENT_CODE, // warning about incorrect decompilation
......
......@@ -206,6 +206,10 @@ public abstract class InsnArg extends Typed {
return arg;
}
public boolean isZeroLiteral() {
return isLiteral() && (((LiteralArg) this)).getLiteral() == 0;
}
public boolean isThis() {
return contains(AFlag.THIS);
}
......
......@@ -29,7 +29,6 @@ import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn;
......@@ -209,7 +208,7 @@ public class EnumVisitor extends AbstractVisitor {
case NEW_ARRAY:
InsnArg arg = wrappedInsn.getArg(0);
if (arg.isLiteral() && ((LiteralArg) arg).getLiteral() == 0) {
if (arg.isZeroLiteral()) {
// empty enum
return Collections.emptyList();
}
......
......@@ -260,8 +260,7 @@ public class SimplifyVisitor extends AbstractVisitor {
if (f.isInsnWrap()) {
InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
if (insn.getArg(1).isLiteral()
&& ((LiteralArg) insn.getArg(1)).getLiteral() == 0) {
if (insn.getArg(1).isZeroLiteral()) {
insn.changeCondition(insn.getOp(), wi.getArg(0), wi.getArg(1));
} else {
LOG.warn("TODO: cmp {}", insn);
......
......@@ -10,6 +10,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -332,6 +333,10 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
return invokeUseBound;
}
}
if (insn.getType() == InsnType.CHECK_CAST && insn.contains(AFlag.SOFT_CAST)) {
// ignore
return null;
}
return new TypeBoundConst(BoundEnum.USE, regArg.getInitType(), regArg);
}
......@@ -499,20 +504,30 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
if (insertAssignCast(mth, var, boundType)) {
return 1;
}
// TODO: check if use casts are needed
return 0;
return insertUseCasts(mth, var);
}
}
return 0;
}
private int insertUseCasts(MethodNode mth, SSAVar var) {
List<RegisterArg> useList = var.getUseList();
if (useList.isEmpty()) {
return 0;
}
int useCasts = 0;
for (RegisterArg useReg : new ArrayList<>(useList)) {
if (insertSoftUseCast(mth, useReg)) {
useCasts++;
}
}
return useCasts;
}
private boolean insertAssignCast(MethodNode mth, SSAVar var, ArgType castType) {
RegisterArg assignArg = var.getAssign();
InsnNode assignInsn = assignArg.getParentInsn();
if (assignInsn == null) {
return false;
}
if (assignInsn.getType() == InsnType.PHI) {
if (assignInsn == null || assignInsn.getType() == InsnType.PHI) {
return false;
}
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
......@@ -521,14 +536,38 @@ public final class TypeInferenceVisitor extends AbstractVisitor {
}
RegisterArg newAssignArg = assignArg.duplicateWithNewSSAVar(mth);
assignInsn.setResult(newAssignArg);
IndexInsnNode castInsn = makeSoftCastInsn(assignArg, newAssignArg, castType);
return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn);
}
private boolean insertSoftUseCast(MethodNode mth, RegisterArg useArg) {
InsnNode useInsn = useArg.getParentInsn();
if (useInsn == null || useInsn.getType() == InsnType.PHI) {
return false;
}
if (useInsn.getType() == InsnType.IF && useInsn.getArg(1).isZeroLiteral()) {
// cast not needed if compare with null
return false;
}
BlockNode useBlock = BlockUtils.getBlockByInsn(mth, useInsn);
if (useBlock == null) {
return false;
}
RegisterArg newUseArg = useArg.duplicateWithNewSSAVar(mth);
useInsn.replaceArg(useArg, newUseArg);
IndexInsnNode castInsn = makeSoftCastInsn(newUseArg, useArg, useArg.getInitType());
return BlockUtils.insertBeforeInsn(useBlock, useInsn, castInsn);
}
@NotNull
private IndexInsnNode makeSoftCastInsn(RegisterArg result, RegisterArg arg, ArgType castType) {
IndexInsnNode castInsn = new IndexInsnNode(InsnType.CHECK_CAST, castType, 1);
castInsn.setResult(assignArg.duplicate());
castInsn.addArg(newAssignArg.duplicate());
castInsn.setResult(result.duplicate());
castInsn.addArg(arg.duplicate());
castInsn.add(AFlag.SOFT_CAST);
castInsn.add(AFlag.SYNTHETIC);
return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn);
return castInsn;
}
private boolean trySplitConstInsns(MethodNode mth) {
......
......@@ -670,17 +670,33 @@ public class BlockUtils {
return false;
}
public static boolean insertBeforeInsn(BlockNode block, InsnNode insn, InsnNode newInsn) {
int index = getInsnIndexInBlock(block, insn);
if (index == -1) {
return false;
}
block.getInstructions().add(index, newInsn);
return true;
}
public static boolean insertAfterInsn(BlockNode block, InsnNode insn, InsnNode newInsn) {
int index = getInsnIndexInBlock(block, insn);
if (index == -1) {
return false;
}
block.getInstructions().add(index + 1, newInsn);
return true;
}
public static int getInsnIndexInBlock(BlockNode block, InsnNode insn) {
List<InsnNode> instructions = block.getInstructions();
int size = instructions.size();
for (int i = 0; i < size; i++) {
InsnNode instruction = instructions.get(i);
if (instruction == insn) {
instructions.add(i + 1, newInsn);
return true;
if (instructions.get(i) == insn) {
return i;
}
}
return false;
return -1;
}
public static boolean replaceInsn(MethodNode mth, InsnNode oldInsn, InsnNode newInsn) {
......
......@@ -169,4 +169,8 @@ public class DebugUtils {
LOG.debug(" {}: {}", entry.getKey(), entry.getValue());
}
}
public static void printStackTrace(String label) {
LOG.debug("StackTrace: {}\n{}", label, Utils.getStackTrace(new Exception()));
}
}
package jadx.tests.integration.types;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
/**
* Issue 1002
* Insertion of additional cast (at use place) needed for successful type inference
*/
public class TestTypeResolver16 extends SmaliTest {
// @formatter:off
/*
public final <T, K> List<T> test(List<? extends T> list, Set<? extends T> set, Function<? super T, ? extends K> function) {
checkParameterIsNotNull(function, "distinctBy");
if (set != null) {
List<? extends T> union = list != null ? union(list, set, function) : null;
if (union != null) {
list = union;
}
}
return list != null ? (List<T>) list : emptyList();
}
*/
// @formatter:on
@Test
public void test() {
assertThat(getClassNodeFromSmali())
.code()
.containsOne("(List<T>) list");
}
}
.class public Ltypes/TestTypeResolver16;
.super Ljava/lang/Object;
.method public final test(Ljava/util/List;Ljava/util/Set;Ljava/util/function/Function;)Ljava/util/List;
.locals 1
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
"K:",
"Ljava/lang/Object;",
">(",
"Ljava/util/List<",
"+TT;>;",
"Ljava/util/Set<",
"+TT;>;",
"Ljava/util/function/Function<",
"-TT;+TK;>;)",
"Ljava/util/List<",
"TT;>;"
}
.end annotation
const-string v0, "distinctBy"
invoke-static {p3, v0}, Ltypes/TestTypeResolver16;->checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V
if-eqz p2, :cond_1
if-eqz p1, :cond_0
.line 85
move-object v0, p1
check-cast v0, Ljava/util/Collection;
check-cast p2, Ljava/lang/Iterable;
invoke-static {v0, p2, p3}, Ltypes/TestTypeResolver16;->union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List;
move-result-object p2
goto :goto_0
:cond_0
const/4 p2, 0x0
:goto_0
if-eqz p2, :cond_1
move-object p1, p2
:cond_1
if-eqz p1, :cond_2
goto :goto_1
:cond_2
invoke-static {}, Ltypes/TestTypeResolver16;->emptyList()Ljava/util/List;
move-result-object p1
:goto_1
return-object p1
.end method
.method public static final union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List;
.locals 4
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
"K:",
"Ljava/lang/Object;",
">(",
"Ljava/util/Collection<",
"+TT;>;",
"Ljava/lang/Iterable<",
"+TT;>;",
"Ljava/util/function/Function<",
"-TT;+TK;>;)",
"Ljava/util/List<",
"TT;>;"
}
.end annotation
const/4 v0, 0x0
return-object v0
.end method
.method public static checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V
.locals 0
return-void
.end method
.method public static final emptyList()Ljava/util/List;
.locals 1
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
">()",
"Ljava/util/List<",
"TT;>;"
}
.end annotation
const/4 v0, 0x0
return-object v0
.end method
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册