提交 d1a6841c 编写于 作者: S Skylot

fix: inline assign in complex conditions (#699)

上级 600842a1
......@@ -104,8 +104,7 @@ public class InsnGen {
} else if (arg.isLiteral()) {
code.add(lit((LiteralArg) arg));
} else if (arg.isInsnWrap()) {
Flags flag = wrap ? Flags.BODY_ONLY : Flags.BODY_ONLY_NOWRAP;
makeInsn(((InsnWrapArg) arg).getWrapInsn(), code, flag);
addWrappedArg(code, (InsnWrapArg) arg, wrap);
} else if (arg.isNamed()) {
code.add(((Named) arg).getName());
} else {
......@@ -113,6 +112,18 @@ public class InsnGen {
}
}
private void addWrappedArg(CodeWriter code, InsnWrapArg arg, boolean wrap) throws CodegenException {
InsnNode wrapInsn = arg.getWrapInsn();
if (wrapInsn.contains(AFlag.FORCE_ASSIGN_INLINE)) {
code.add('(');
makeInsn(wrapInsn, code, Flags.INLINE);
code.add(')');
} else {
Flags flags = wrap ? Flags.BODY_ONLY : Flags.BODY_ONLY_NOWRAP;
makeInsn(wrapInsn, code, flags);
}
}
public void assignVar(CodeWriter code, InsnNode insn) throws CodegenException {
RegisterArg arg = insn.getResult();
if (insn.contains(AFlag.DECLARE_VAR)) {
......@@ -922,10 +933,7 @@ public class InsnGen {
if (parentInsn.contains(AFlag.WRAPPED)) {
return false;
}
if (callMthNode.getReturnType().equals(ArgType.VOID)) {
return false;
}
return true;
return !callMthNode.getReturnType().equals(ArgType.VOID);
}
private void makeTernary(TernaryInsn insn, CodeWriter code, Set<Flags> state) throws CodegenException {
......
......@@ -44,6 +44,11 @@ public enum AFlag {
*/
IMMUTABLE_TYPE,
/**
* Force inline instruction with inline assign
*/
FORCE_ASSIGN_INLINE,
CUSTOM_DECLARE, // variable for this register don't need declaration
DECLARE_VAR,
......
package jadx.core.dex.regions.conditions;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
public final class IfInfo {
private final IfCondition condition;
......@@ -11,32 +14,35 @@ public final class IfInfo {
private final BlockNode thenBlock;
private final BlockNode elseBlock;
private final Set<BlockNode> skipBlocks;
private final List<InsnNode> forceInlineInsns;
private BlockNode outBlock;
@Deprecated
private BlockNode ifBlock;
public IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock) {
this(condition, thenBlock, elseBlock, new HashSet<>(), new HashSet<>());
this(condition, thenBlock, elseBlock, new HashSet<>(), new HashSet<>(), new ArrayList<>());
}
public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) {
this(info.getCondition(), thenBlock, elseBlock, info.getMergedBlocks(), info.getSkipBlocks());
this(info.getCondition(), thenBlock, elseBlock,
info.getMergedBlocks(), info.getSkipBlocks(), info.getForceInlineInsns());
}
private IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock,
Set<BlockNode> mergedBlocks, Set<BlockNode> skipBlocks) {
Set<BlockNode> mergedBlocks, Set<BlockNode> skipBlocks, List<InsnNode> forceInlineInsns) {
this.condition = condition;
this.thenBlock = thenBlock;
this.elseBlock = elseBlock;
this.mergedBlocks = mergedBlocks;
this.skipBlocks = skipBlocks;
this.forceInlineInsns = forceInlineInsns;
}
public static IfInfo invert(IfInfo info) {
IfCondition invertedCondition = IfCondition.invert(info.getCondition());
IfInfo tmpIf = new IfInfo(invertedCondition,
info.getElseBlock(), info.getThenBlock(),
info.getMergedBlocks(), info.getSkipBlocks());
info.getMergedBlocks(), info.getSkipBlocks(), info.getForceInlineInsns());
tmpIf.setIfBlock(info.getIfBlock());
return tmpIf;
}
......@@ -45,6 +51,7 @@ public final class IfInfo {
for (IfInfo info : arr) {
mergedBlocks.addAll(info.getMergedBlocks());
skipBlocks.addAll(info.getSkipBlocks());
addInsnsForForcedInline(info.getForceInlineInsns());
}
}
......@@ -84,6 +91,18 @@ public final class IfInfo {
this.ifBlock = ifBlock;
}
public List<InsnNode> getForceInlineInsns() {
return forceInlineInsns;
}
public void resetForceInlineInsns() {
forceInlineInsns.clear();
}
public void addInsnsForForcedInline(List<InsnNode> insns) {
forceInlineInsns.addAll(insns);
}
@Override
public String toString() {
return "IfInfo: then: " + thenBlock + ", else: " + elseBlock;
......
package jadx.core.dex.visitors.regions;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -35,11 +37,13 @@ public class IfMakerHelper {
private IfMakerHelper() {
}
@Nullable
static IfInfo makeIfInfo(BlockNode ifBlock) {
IfNode ifNode = (IfNode) BlockUtils.getLastInsn(ifBlock);
if (ifNode == null) {
throw new JadxRuntimeException("Empty IF block: " + ifBlock);
InsnNode lastInsn = BlockUtils.getLastInsn(ifBlock);
if (lastInsn == null || lastInsn.getType() != InsnType.IF) {
return null;
}
IfNode ifNode = (IfNode) lastInsn;
IfCondition condition = IfCondition.fromIfNode(ifNode);
IfInfo info = new IfInfo(condition, ifNode.getThenBlock(), ifNode.getElseBlock());
info.setIfBlock(ifBlock);
......@@ -48,8 +52,11 @@ public class IfMakerHelper {
}
static IfInfo searchNestedIf(IfInfo info) {
IfInfo tmp = mergeNestedIfNodes(info);
return tmp != null ? tmp : info;
IfInfo next = mergeNestedIfNodes(info);
if (next != null) {
return next;
}
return info;
}
static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) {
......@@ -160,12 +167,24 @@ public class IfMakerHelper {
return null;
}
}
boolean assignInlineNeeded = !nextIf.getForceInlineInsns().isEmpty();
if (assignInlineNeeded) {
for (BlockNode mergedBlock : currentIf.getMergedBlocks()) {
if (mergedBlock.contains(AFlag.LOOP_START)) {
// don't inline assigns into loop condition
return currentIf;
}
}
}
if (isInversionNeeded(currentIf, nextIf)) {
// invert current node for match pattern
nextIf = IfInfo.invert(nextIf);
}
if (!isEqualPaths(curThen, nextIf.getThenBlock())
&& !isEqualPaths(curElse, nextIf.getElseBlock())) {
boolean thenPathSame = isEqualPaths(curThen, nextIf.getThenBlock());
boolean elsePathSame = isEqualPaths(curElse, nextIf.getElseBlock());
if (!thenPathSame && !elsePathSame) {
// complex condition, run additional checks
if (checkConditionBranches(curThen, curElse)
|| checkConditionBranches(curElse, curThen)) {
......@@ -191,6 +210,15 @@ public class IfMakerHelper {
} else {
return currentIf;
}
} else {
if (assignInlineNeeded) {
boolean sameOuts = (thenPathSame && !followThenBranch) || (elsePathSame && followThenBranch);
if (!sameOuts) {
// don't inline assigns inside simple condition
currentIf.resetForceInlineInsns();
return currentIf;
}
}
}
IfInfo result = mergeIfInfo(currentIf, nextIf, followThenBranch);
......@@ -315,36 +343,32 @@ public class IfMakerHelper {
}
info.getSkipBlocks().clear();
}
for (InsnNode forceInlineInsn : info.getForceInlineInsns()) {
forceInlineInsn.add(AFlag.FORCE_ASSIGN_INLINE);
}
}
private static IfInfo getNextIf(IfInfo info, BlockNode block) {
if (!canSelectNext(info, block)) {
return null;
}
BlockNode nestedIfBlock = getNextIfNode(block);
if (nestedIfBlock != null) {
return makeIfInfo(nestedIfBlock);
}
return null;
return getNextIfNodeInfo(info, block);
}
private static boolean canSelectNext(IfInfo info, BlockNode block) {
if (block.getPredecessors().size() == 1) {
return true;
}
if (info.getMergedBlocks().containsAll(block.getPredecessors())) {
return true;
}
return false;
return info.getMergedBlocks().containsAll(block.getPredecessors());
}
private static BlockNode getNextIfNode(BlockNode block) {
private static IfInfo getNextIfNodeInfo(IfInfo info, BlockNode block) {
if (block == null || block.contains(AType.LOOP) || block.contains(AFlag.ADDED_TO_REGION)) {
return null;
}
InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn != null && lastInsn.getType() == InsnType.IF) {
return block;
return makeIfInfo(block);
}
// skip this block and search in successors chain
List<BlockNode> successors = block.getSuccessors();
......@@ -358,6 +382,7 @@ public class IfMakerHelper {
}
List<InsnNode> insns = block.getInstructions();
boolean pass = true;
List<InsnNode> forceInlineInsns = new ArrayList<>();
if (!insns.isEmpty()) {
// check that all instructions can be inlined
for (InsnNode insn : insns) {
......@@ -367,7 +392,9 @@ public class IfMakerHelper {
break;
}
List<RegisterArg> useList = res.getSVar().getUseList();
if (useList.size() != 1) {
int useCount = useList.size();
if (useCount == 0) {
// TODO?
pass = false;
break;
}
......@@ -378,12 +405,20 @@ public class IfMakerHelper {
pass = false;
break;
}
if (useCount > 1) {
forceInlineInsns.add(insn);
}
}
}
if (pass) {
return getNextIfNode(next);
if (!pass) {
return null;
}
return null;
IfInfo nextInfo = makeIfInfo(next);
if (nextInfo == null) {
return getNextIfNodeInfo(info, next);
}
nextInfo.addInsnsForForcedInline(forceInlineInsns);
return nextInfo;
}
private static void skipSimplePath(BlockNode block, Set<BlockNode> skipped) {
......
......@@ -655,6 +655,9 @@ public class RegionMaker {
}
IfInfo currentIf = makeIfInfo(block);
if (currentIf == null) {
return null;
}
IfInfo mergedIf = mergeNestedIfNodes(currentIf);
if (mergedIf != null) {
currentIf = mergedIf;
......
......@@ -20,6 +20,7 @@ import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.exceptions.JadxRuntimeException;
@JadxVisitor(
......@@ -78,12 +79,15 @@ public class CodeShrinkVisitor extends AbstractVisitor {
if (sVar == null || sVar.getAssign().contains(AFlag.DONT_INLINE)) {
return;
}
// allow inline only one use arg
if (sVar.getVariableUseCount() != 1) {
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null
|| assignInsn.contains(AFlag.DONT_INLINE)
|| assignInsn.contains(AFlag.WRAPPED)) {
return;
}
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null || assignInsn.contains(AFlag.DONT_INLINE)) {
// allow inline only one use arg
boolean assignInline = assignInsn.contains(AFlag.FORCE_ASSIGN_INLINE);
if (!assignInline && sVar.getVariableUseCount() != 1) {
return;
}
List<RegisterArg> useList = sVar.getUseList();
......@@ -96,6 +100,10 @@ public class CodeShrinkVisitor extends AbstractVisitor {
int assignPos = insnList.getIndex(assignInsn);
if (assignPos != -1) {
if (assignInline) {
// TODO?
return;
}
WrapInfo wrapInfo = argsInfo.checkInline(assignPos, arg);
if (wrapInfo != null) {
wrapList.add(wrapInfo);
......@@ -106,11 +114,30 @@ public class CodeShrinkVisitor extends AbstractVisitor {
if (assignBlock != null
&& assignInsn != arg.getParentInsn()
&& canMoveBetweenBlocks(assignInsn, assignBlock, block, argsInfo.getInsn())) {
inline(mth, arg, assignInsn, assignBlock);
if (assignInline) {
assignInline(mth, arg, assignInsn, assignBlock);
} else {
inline(mth, arg, assignInsn, assignBlock);
}
}
}
}
private static void assignInline(MethodNode mth, RegisterArg arg, InsnNode assignInsn, BlockNode assignBlock) {
RegisterArg useArg = arg.getSVar().getUseList().get(0);
InsnNode useInsn = useArg.getParentInsn();
if (useInsn == null || useInsn.contains(AFlag.DONT_GENERATE)) {
return;
}
InsnArg replaceArg = InsnArg.wrapArg(assignInsn.copy());
useInsn.replaceArg(useArg, replaceArg);
assignInsn.add(AFlag.REMOVE);
assignInsn.add(AFlag.DONT_GENERATE);
InsnRemover.remove(mth, assignBlock, assignInsn);
}
private static boolean inline(MethodNode mth, RegisterArg arg, InsnNode insn, BlockNode block) {
InsnNode parentInsn = arg.getParentInsn();
if (parentInsn != null && parentInsn.getType() == InsnType.RETURN) {
......
......@@ -2,7 +2,6 @@ package jadx.tests.integration.conditions;
import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
......@@ -41,13 +40,13 @@ public class TestConditions19 extends IntegrationTest {
}
@Test
@NotYetImplemented("Inner assignment or labeled block with break")
public void test() {
noDebugInfo();
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("str.length()"));
assertThat(code, containsOne("System.out.println(\"done\");"));
}
}
......@@ -11,27 +11,7 @@ import static org.hamcrest.MatcherAssert.assertThat;
public class TestLoopCondition extends IntegrationTest {
public static class TestCls {
public String f;
private void setEnabled(boolean r1z) {
}
public void testIfInLoop() {
int j = 0;
for (int i = 0; i < f.length(); i++) {
char ch = f.charAt(i);
if (ch == '/') {
j++;
if (j == 2) {
setEnabled(true);
return;
}
}
}
setEnabled(false);
}
public void testMoreComplexIfInLoop(java.util.ArrayList<String> list) throws Exception {
public void test(java.util.ArrayList<String> list) {
for (int i = 0; i != 16 && i < 255; i++) {
list.set(i, "ABC");
if (i == 128) {
......@@ -47,12 +27,7 @@ public class TestLoopCondition extends IntegrationTest {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("i < this.f.length()"));
assertThat(code, containsOne("list.set(i, \"ABC\")"));
assertThat(code, containsOne("list.set(i, \"DEF\")"));
assertThat(code, containsOne("if (j == 2) {"));
assertThat(code, containsOne("setEnabled(true);"));
assertThat(code, containsOne("setEnabled(false);"));
}
}
......@@ -18,12 +18,12 @@ public class TestVariables5 extends IntegrationTest {
private boolean enabled;
private void testIfInLoop() {
int j = 0;
for (int i = 0; i < f.length(); i++) {
char ch = f.charAt(i);
int i = 0;
for (int i2 = 0; i2 < f.length(); i2++) {
char ch = f.charAt(i2);
if (ch == '/') {
j++;
if (j == 2) {
i++;
if (i == 2) {
setEnabled(true);
return;
}
......@@ -51,6 +51,8 @@ public class TestVariables5 extends IntegrationTest {
assertThat(code, not(containsString("int i2++;")));
assertThat(code, containsOne("int i = 0;"));
assertThat(code, containsOne("i++;"));
assertThat(code, containsOne("&& (i = i + 1) == 2"));
// assertThat(code, containsOne("i++;"));
// assertThat(code, containsOne("if (i == 2) {"));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册