diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java index 7a0e652ba64277f6d6d54c5a3292c96686295966..65adbdda78be9f238d09bb54a8c7b54a4c232b4c 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/RegionMaker.java @@ -663,7 +663,8 @@ public class RegionMaker { Map fallThroughCases = new LinkedHashMap(); - BitSet outs = new BitSet(mth.getBasicBlocks().size()); + List basicBlocks = mth.getBasicBlocks(); + BitSet outs = new BitSet(basicBlocks.size()); outs.or(block.getDomFrontier()); for (BlockNode s : block.getCleanSuccessors()) { BitSet df = s.getDomFrontier(); @@ -672,8 +673,8 @@ public class RegionMaker { if (df.cardinality() > 2) { LOG.debug("Unexpected case pattern, block: {}, mth: {}", s, mth); } else { - BlockNode first = mth.getBasicBlocks().get(df.nextSetBit(0)); - BlockNode second = mth.getBasicBlocks().get(df.nextSetBit(first.getId() + 1)); + BlockNode first = basicBlocks.get(df.nextSetBit(0)); + BlockNode second = basicBlocks.get(df.nextSetBit(first.getId() + 1)); if (second.getDomFrontier().get(first.getId())) { fallThroughCases.put(s, second); df = new BitSet(df.size()); @@ -687,6 +688,11 @@ public class RegionMaker { } outs.or(df); } + outs.clear(block.getId()); + if (loop != null) { + outs.clear(loop.getStart().getId()); + } + stack.push(sw); stack.addExits(BlockUtils.bitSetToBlocks(mth, outs)); @@ -709,9 +715,8 @@ public class RegionMaker { } if (outs.cardinality() > 1) { // filter loop start and successors of other blocks - List blocks = mth.getBasicBlocks(); for (int i = outs.nextSetBit(0); i >= 0; i = outs.nextSetBit(i + 1)) { - BlockNode b = blocks.get(i); + BlockNode b = basicBlocks.get(i); outs.andNot(b.getDomFrontier()); if (b.contains(AFlag.LOOP_START)) { outs.clear(b.getId()); @@ -745,7 +750,7 @@ public class RegionMaker { } BlockNode out = null; if (outs.cardinality() == 1) { - out = mth.getBasicBlocks().get(outs.nextSetBit(0)); + out = basicBlocks.get(outs.nextSetBit(0)); stack.addExit(out); } else if (loop == null && outs.cardinality() > 1) { LOG.warn("Can't detect out node for switch block: {} in {}", block, mth); diff --git a/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchInLoop.java b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchInLoop.java new file mode 100644 index 0000000000000000000000000000000000000000..c0a53abb4ee9520cdd47f6e235881e1950961dbf --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/switches/TestSwitchInLoop.java @@ -0,0 +1,44 @@ +package jadx.tests.integration.switches; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; + +public class TestSwitchInLoop extends IntegrationTest { + public static class TestCls { + public int test(int k) { + int a = 0; + while (true) { + switch (k) { + case 0: + return a; + default: + a++; + k >>= 1; + } + } + } + + public void check() { + assertEquals(1, test(1)); + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("switch (k) {")); + assertThat(code, containsOne("case 0:")); + assertThat(code, containsOne("return a;")); + assertThat(code, containsOne("default:")); + assertThat(code, containsOne("a++;")); + assertThat(code, containsOne("k >>= 1;")); + } +}