提交 bae36f97 编写于 作者: S Skylot

fix: merge const block before return (#699)

上级 11db454b
...@@ -151,6 +151,13 @@ public class RegisterArg extends InsnArg implements Named { ...@@ -151,6 +151,13 @@ public class RegisterArg extends InsnArg implements Named {
&& Objects.equals(sVar, reg.getSVar()); && Objects.equals(sVar, reg.getSVar());
} }
public boolean sameReg(InsnArg arg) {
if (!arg.isRegister()) {
return false;
}
return regNum == ((RegisterArg) arg).getRegNum();
}
public boolean sameCodeVar(RegisterArg arg) { public boolean sameCodeVar(RegisterArg arg) {
return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar(); return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar();
} }
......
...@@ -16,6 +16,7 @@ import jadx.core.dex.attributes.AFlag; ...@@ -16,6 +16,7 @@ import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType; import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo; import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.InsnType; import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg; import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg; import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.instructions.args.RegisterArg;
...@@ -29,6 +30,7 @@ import jadx.core.dex.trycatch.ExceptionHandler; ...@@ -29,6 +30,7 @@ import jadx.core.dex.trycatch.ExceptionHandler;
import jadx.core.dex.trycatch.TryCatchBlock; import jadx.core.dex.trycatch.TryCatchBlock;
import jadx.core.dex.visitors.AbstractVisitor; import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.BlockUtils; import jadx.core.utils.BlockUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException; import jadx.core.utils.exceptions.JadxRuntimeException;
import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect; import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect;
...@@ -413,7 +415,48 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -413,7 +415,48 @@ public class BlockProcessor extends AbstractVisitor {
return true; return true;
} }
} }
return splitReturn(mth); if (mergeConstReturn(mth)) {
return true;
}
return splitReturnBlocks(mth);
}
private static boolean mergeConstReturn(MethodNode mth) {
if (mth.getReturnType() == ArgType.VOID) {
return false;
}
boolean changed = false;
for (BlockNode exitBlock : new ArrayList<>(mth.getExitBlocks())) {
BlockNode pred = Utils.getOne(exitBlock.getPredecessors());
if (pred != null) {
InsnNode constInsn = Utils.getOne(pred.getInstructions());
if (constInsn != null && constInsn.isConstInsn()) {
RegisterArg constArg = constInsn.getResult();
InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock);
if (returnInsn != null) {
InsnArg retArg = returnInsn.getArg(0);
if (constArg.sameReg(retArg)) {
mergeConstAndReturnBlocks(mth, exitBlock, pred);
changed = true;
}
}
}
}
}
if (changed) {
removeMarkedBlocks(mth);
cleanExitNodes(mth);
}
return changed;
}
private static void mergeConstAndReturnBlocks(MethodNode mth, BlockNode exitBlock, BlockNode pred) {
pred.getInstructions().addAll(exitBlock.getInstructions());
pred.copyAttributesFrom(exitBlock);
BlockSplitter.removeConnection(pred, exitBlock);
exitBlock.getInstructions().clear();
exitBlock.add(AFlag.REMOVE);
} }
private static boolean independentBlockTreeMod(MethodNode mth) { private static boolean independentBlockTreeMod(MethodNode mth) {
...@@ -604,16 +647,25 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -604,16 +647,25 @@ public class BlockProcessor extends AbstractVisitor {
return true; return true;
} }
private static boolean splitReturnBlocks(MethodNode mth) {
boolean changed = false;
for (BlockNode exitBlock : mth.getExitBlocks()) {
if (splitReturn(mth, exitBlock)) {
changed = true;
}
}
if (changed) {
cleanExitNodes(mth);
}
return changed;
}
/** /**
* Splice return block if several predecessors presents * Splice return block if several predecessors presents
*/ */
private static boolean splitReturn(MethodNode mth) { private static boolean splitReturn(MethodNode mth, BlockNode exitBlock) {
if (mth.getExitBlocks().size() != 1) { if (exitBlock.contains(AFlag.SYNTHETIC)
return false; || exitBlock.contains(AFlag.ORIG_RETURN)
}
BlockNode exitBlock = mth.getExitBlocks().get(0);
if (exitBlock.getInstructions().size() != 1
|| exitBlock.contains(AFlag.SYNTHETIC)
|| exitBlock.contains(AType.SPLITTER_BLOCK)) { || exitBlock.contains(AType.SPLITTER_BLOCK)) {
return false; return false;
} }
...@@ -625,31 +677,38 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -625,31 +677,38 @@ public class BlockProcessor extends AbstractVisitor {
if (preds.size() < 2) { if (preds.size() < 2) {
return false; return false;
} }
InsnNode returnInsn = exitBlock.getInstructions().get(0); InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock);
if (returnInsn.getArgsCount() != 0 && !isReturnArgAssignInPred(preds, returnInsn)) { if (returnInsn == null) {
return false;
}
if (returnInsn.getArgsCount() == 1
&& exitBlock.getInstructions().size() == 1
&& !isReturnArgAssignInPred(preds, returnInsn)) {
return false; return false;
} }
boolean first = true; boolean first = true;
for (BlockNode pred : preds) { for (BlockNode pred : preds) {
BlockNode newRetBlock = BlockSplitter.startNewBlock(mth, -1); BlockNode newRetBlock = BlockSplitter.startNewBlock(mth, -1);
newRetBlock.add(AFlag.SYNTHETIC); newRetBlock.add(AFlag.SYNTHETIC);
InsnNode newRetInsn;
if (first) { if (first) {
newRetInsn = returnInsn;
newRetBlock.add(AFlag.ORIG_RETURN); newRetBlock.add(AFlag.ORIG_RETURN);
newRetBlock.getInstructions().addAll(exitBlock.getInstructions());
first = false; first = false;
} else { } else {
newRetInsn = duplicateReturnInsn(returnInsn); for (InsnNode oldInsn : exitBlock.getInstructions()) {
newRetBlock.getInstructions().add(oldInsn.copy());
}
} }
newRetBlock.getInstructions().add(newRetInsn);
BlockSplitter.replaceConnection(pred, exitBlock, newRetBlock); BlockSplitter.replaceConnection(pred, exitBlock, newRetBlock);
} }
cleanExitNodes(mth);
return true; return true;
} }
private static boolean isReturnArgAssignInPred(List<BlockNode> preds, InsnNode returnInsn) { private static boolean isReturnArgAssignInPred(List<BlockNode> preds, InsnNode returnInsn) {
RegisterArg arg = (RegisterArg) returnInsn.getArg(0); InsnArg retArg = returnInsn.getArg(0);
if (retArg.isRegister()) {
RegisterArg arg = (RegisterArg) retArg;
int regNum = arg.getRegNum(); int regNum = arg.getRegNum();
for (BlockNode pred : preds) { for (BlockNode pred : preds) {
for (InsnNode insnNode : pred.getInstructions()) { for (InsnNode insnNode : pred.getInstructions()) {
...@@ -659,6 +718,7 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -659,6 +718,7 @@ public class BlockProcessor extends AbstractVisitor {
} }
} }
} }
}
return false; return false;
} }
...@@ -673,18 +733,6 @@ public class BlockProcessor extends AbstractVisitor { ...@@ -673,18 +733,6 @@ public class BlockProcessor extends AbstractVisitor {
} }
} }
private static InsnNode duplicateReturnInsn(InsnNode returnInsn) {
InsnNode insn = new InsnNode(returnInsn.getType(), returnInsn.getArgsCount());
if (returnInsn.getArgsCount() == 1) {
RegisterArg arg = (RegisterArg) returnInsn.getArg(0);
insn.addArg(arg.duplicate());
}
insn.copyAttributesFrom(returnInsn);
insn.setOffset(returnInsn.getOffset());
insn.setSourceLine(returnInsn.getSourceLine());
return insn;
}
private static void removeMarkedBlocks(MethodNode mth) { private static void removeMarkedBlocks(MethodNode mth) {
mth.getBasicBlocks().removeIf(block -> { mth.getBasicBlocks().removeIf(block -> {
if (block.contains(AFlag.REMOVE)) { if (block.contains(AFlag.REMOVE)) {
......
...@@ -204,6 +204,14 @@ public class Utils { ...@@ -204,6 +204,14 @@ public class Utils {
return Collections.unmodifiableMap(result); return Collections.unmodifiableMap(result);
} }
@Nullable
public static <T> T getOne(@Nullable List<T> list) {
if (list == null || list.size() != 1) {
return null;
}
return list.get(0);
}
@Nullable @Nullable
public static <T> T last(List<T> list) { public static <T> T last(List<T> list) {
if (list.isEmpty()) { if (list.isEmpty()) {
......
...@@ -2,9 +2,11 @@ package jadx.tests.integration.conditions; ...@@ -2,9 +2,11 @@ package jadx.tests.integration.conditions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest; import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsLines;
import static jadx.tests.api.utils.JadxMatchers.containsOne; import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
...@@ -31,7 +33,20 @@ public class TestConditions18 extends SmaliTest { ...@@ -31,7 +33,20 @@ public class TestConditions18 extends SmaliTest {
ClassNode cls = getClassNodeFromSmali(); ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, containsOne("return this == obj" assertThat(code, containsLines(2,
+ " || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));")); "if (this != obj) {",
indent() + "return (obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map);",
"}",
"return true;"));
}
@Test
@NotYetImplemented
public void testNYI() {
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();
assertThat(code,
containsOne("return this == obj || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));"));
} }
} }
package jadx.tests.integration.conditions;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
public class TestConditions21 extends SmaliTest {
// @formatter:off
/*
public boolean check(Object obj) {
if (this == obj) {
return true;
}
if (obj instanceof List) {
List list = (List) obj;
if (!list.isEmpty() && list.contains(this)) {
return true;
}
}
return false;
}
*/
// @formatter:on
@Test
public void test() {
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();
assertThat(code, containsOne("!list.isEmpty() && list.contains(this)"));
}
}
...@@ -2,17 +2,19 @@ package jadx.tests.integration.conditions; ...@@ -2,17 +2,19 @@ package jadx.tests.integration.conditions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest; import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsLines; import static jadx.tests.api.utils.JadxMatchers.containsLines;
import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
public class TestTernaryInIf2 extends SmaliTest { public class TestTernaryInIf2 extends SmaliTest {
public static class TestCls { public static class TestCls {
private String a; private String a = "a";
private String b; private String b = "b";
public boolean equals(TestCls other) { public boolean equals(TestCls other) {
if (this.a == null ? other.a == null : this.a.equals(other.a)) { if (this.a == null ? other.a == null : this.a.equals(other.a)) {
...@@ -22,6 +24,22 @@ public class TestTernaryInIf2 extends SmaliTest { ...@@ -22,6 +24,22 @@ public class TestTernaryInIf2 extends SmaliTest {
} }
return false; return false;
} }
public void check() {
TestCls other = new TestCls();
other.a = "a";
other.b = "b";
assertThat(this.equals(other), is(true));
other.b = "not-b";
assertThat(this.equals(other), is(false));
other.b = null;
assertThat(this.equals(other), is(false));
this.b = null;
assertThat(this.equals(other), is(true));
}
} }
@Test @Test
...@@ -30,9 +48,20 @@ public class TestTernaryInIf2 extends SmaliTest { ...@@ -30,9 +48,20 @@ public class TestTernaryInIf2 extends SmaliTest {
String code = cls.getCode().toString(); String code = cls.getCode().toString();
assertThat(code, containsLines(2, "if (this.a != null ? this.a.equals(other.a) : other.a == null) {")); assertThat(code, containsLines(2, "if (this.a != null ? this.a.equals(other.a) : other.a == null) {"));
assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null) {")); // assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null)
assertThat(code, containsLines(4, "return true;")); // {"));
assertThat(code, containsLines(2, "return false;")); // assertThat(code, containsLines(4, "return true;"));
// assertThat(code, containsLines(2, "return false;"));
}
@Test
@NotYetImplemented
public void testNYI() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsLines(2, "return (this.a != null ? this.a.equals(other.a) : other.a == null) "
+ "&& (this.b == null ? other.b == null : this.b.equals(other.b));"));
} }
@Test @Test
......
.class public final Lconditions/TestConditions21;
.super Ljava/lang/Object;
.method public check(Ljava/lang/Object;)Z
.locals 2
if-eq p0, p1, :ret_true
instance-of v0, p1, Ljava/util/List;
if-eqz v0, :ret_false
check-cast p1, Ljava/util/List;
invoke-interface {p1}, Ljava/util/List;->isEmpty()Z
move-result v0
if-nez v0, :ret_false
invoke-interface {p1, p0}, Ljava/util/List;->contains(Ljava/lang/Object;)Z
move-result v0
if-eqz v0, :ret_false
goto :ret_true
:ret_false
const/4 p1, 0x0
return p1
:ret_true
const/4 p1, 0x1
return p1
.end method
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册