提交 37857e88 编写于 作者: S Skylot

core: fix switch statement processing (issue #9 case 2)

上级 6fbcf46a
...@@ -199,7 +199,7 @@ public class RegionGen extends InsnGen { ...@@ -199,7 +199,7 @@ public class RegionGen extends InsnGen {
SwitchNode insn = (SwitchNode) sw.getHeader().getInstructions().get(0); SwitchNode insn = (SwitchNode) sw.getHeader().getInstructions().get(0);
InsnArg arg = insn.getArg(0); InsnArg arg = insn.getArg(0);
code.startLine("switch ("); code.startLine("switch (");
addArg(code, arg); addArg(code, arg, false);
code.add(") {"); code.add(") {");
code.incIndent(); code.incIndent();
......
...@@ -627,7 +627,7 @@ public class RegionMaker { ...@@ -627,7 +627,7 @@ public class RegionMaker {
BitSet succ = BlockUtils.blocksToBitSet(mth, block.getSuccessors()); BitSet succ = BlockUtils.blocksToBitSet(mth, block.getSuccessors());
BitSet domsOn = BlockUtils.blocksToBitSet(mth, block.getDominatesOn()); BitSet domsOn = BlockUtils.blocksToBitSet(mth, block.getDominatesOn());
domsOn.and(succ); // filter 'out' block domsOn.xor(succ); // filter 'out' block
BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors()); BlockNode defCase = getBlockByOffset(insn.getDefaultCaseOffset(), block.getSuccessors());
if (defCase != null) { if (defCase != null) {
...@@ -642,13 +642,11 @@ public class RegionMaker { ...@@ -642,13 +642,11 @@ public class RegionMaker {
} }
if (outCount > 1) { if (outCount > 1) {
// filter successors of other blocks // filter successors of other blocks
List<BlockNode> blocks = mth.getBasicBlocks();
for (int i = domsOn.nextSetBit(0); i >= 0; i = domsOn.nextSetBit(i + 1)) { for (int i = domsOn.nextSetBit(0); i >= 0; i = domsOn.nextSetBit(i + 1)) {
BlockNode b = mth.getBasicBlocks().get(i); BlockNode b = blocks.get(i);
for (BlockNode s : b.getCleanSuccessors()) { for (BlockNode s : b.getCleanSuccessors()) {
int id = s.getId(); domsOn.clear(s.getId());
if (domsOn.get(id)) {
domsOn.clear(id);
}
} }
} }
outCount = domsOn.cardinality(); outCount = domsOn.cardinality();
...@@ -658,19 +656,27 @@ public class RegionMaker { ...@@ -658,19 +656,27 @@ public class RegionMaker {
if (outCount == 1) { if (outCount == 1) {
out = mth.getBasicBlocks().get(domsOn.nextSetBit(0)); out = mth.getBasicBlocks().get(domsOn.nextSetBit(0));
} else if (outCount == 0) { } else if (outCount == 0) {
// default and out blocks are same // one or several case blocks are empty,
out = defCase; // run expensive algorithm for find 'out' block
for (BlockNode maybeOut : block.getSuccessors()) {
boolean allReached = true;
for (BlockNode s : block.getSuccessors()) {
if (!BlockUtils.isPathExists(s, maybeOut)) {
allReached = false;
break;
}
}
if (allReached) {
out = maybeOut;
break;
}
}
} }
stack.push(sw); stack.push(sw);
if (out != null) { if (out != null) {
stack.addExit(out); stack.addExit(out);
} }
// else {
// for (BlockNode e : BlockUtils.bitSetToBlocks(mth, domsOn)) {
// stack.addExit(e);
// }
// }
if (!stack.containsExit(defCase)) { if (!stack.containsExit(defCase)) {
sw.setDefaultCase(makeRegion(defCase, stack)); sw.setDefaultCase(makeRegion(defCase, stack));
......
...@@ -203,10 +203,7 @@ public class BlockUtils { ...@@ -203,10 +203,7 @@ public class BlockUtils {
} }
public static boolean isPathExists(BlockNode start, BlockNode end) { public static boolean isPathExists(BlockNode start, BlockNode end) {
if (start == end) { if (start == end || end.isDominator(start)) {
return true;
}
if (end.isDominator(start)) {
return true; return true;
} }
return traverseSuccessorsUntil(start, end, new BitSet()); return traverseSuccessorsUntil(start, end, new BitSet());
......
package jadx.tests.internal; package jadx.tests.internal.switches;
import jadx.api.InternalJadxTest; import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
...@@ -6,6 +6,7 @@ import jadx.core.dex.nodes.ClassNode; ...@@ -6,6 +6,7 @@ import jadx.core.dex.nodes.ClassNode;
import org.junit.Test; import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat; import static org.junit.Assert.assertThat;
public class TestSwitch extends InternalJadxTest { public class TestSwitch extends InternalJadxTest {
...@@ -45,5 +46,9 @@ public class TestSwitch extends InternalJadxTest { ...@@ -45,5 +46,9 @@ public class TestSwitch extends InternalJadxTest {
assertThat(code, containsString("case '/':")); assertThat(code, containsString("case '/':"));
assertThat(code, containsString(indent(5) + "break;")); assertThat(code, containsString(indent(5) + "break;"));
assertThat(code, containsString(indent(4) + "default:"));
assertEquals(1, count(code, "i++"));
assertEquals(4, count(code, "break;"));
} }
} }
package jadx.tests.internal; package jadx.tests.internal.switches;
import jadx.api.InternalJadxTest; import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode; import jadx.core.dex.nodes.ClassNode;
......
package jadx.tests.internal.switches;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
public class TestSwitchNoDefault extends InternalJadxTest {
public static class TestCls {
public void test(int a) {
String s = null;
switch (a) {
case 1:
s = "1";
break;
case 2:
s = "2";
break;
case 3:
s = "3";
break;
case 4:
s = "4";
break;
}
System.out.println(s);
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertEquals(4, count(code, "break;"));
assertEquals(1, count(code, "System.out.println(s);"));
}
}
package jadx.tests.internal.switches;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
public class TestSwitchSimple extends InternalJadxTest {
public static class TestCls {
public void test(int a) {
String s = null;
switch (a % 4) {
case 1:
s = "1";
break;
case 2:
s = "2";
break;
case 3:
s = "3";
break;
case 4:
s = "4";
break;
default:
System.out.println("Not Reach");
break;
}
System.out.println(s);
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
System.out.println(code);
assertEquals(5, count(code, "break;"));
assertEquals(1, count(code, "System.out.println(s);"));
assertEquals(1, count(code, "System.out.println(\"Not Reach\");"));
assertThat(code, not(containsString("switch ((a % 4)) {")));
assertThat(code, containsString("switch (a % 4) {"));
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册