未验证 提交 8969d11a 编写于 作者: S Skylot

fix: restore fields order on init code move (#678)

上级 eedf32d1
......@@ -6,6 +6,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jetbrains.annotations.Nullable;
......@@ -304,6 +305,28 @@ public class InsnNode extends LineAttrNode {
}
}
/**
* Visit this instruction and all inner (wrapped) instructions
* To terminate visiting return non-null value
*/
@Nullable
public <R> R visitInsns(Function<InsnNode, R> visitor) {
R result = visitor.apply(this);
if (result != null) {
return result;
}
for (InsnArg arg : this.getArguments()) {
if (arg.isInsnWrap()) {
InsnNode innerInsn = ((InsnWrapArg) arg).getWrapInsn();
R res = innerInsn.visitInsns(visitor);
if (res != null) {
return res;
}
}
}
return null;
}
/**
* 'Soft' equals, don't compare arguments, only instruction specific parameters.
*/
......
......@@ -4,7 +4,9 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import jadx.api.plugins.input.data.attributes.JadxAttrType;
import jadx.core.dex.attributes.AFlag;
......@@ -25,6 +27,7 @@ import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.visitors.shrink.CodeShrinkVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
@JadxVisitor(
......@@ -40,193 +43,191 @@ public class ExtractFieldInit extends AbstractVisitor {
for (ClassNode inner : cls.getInnerClasses()) {
visit(inner);
}
checkStaticFieldsInit(cls);
moveStaticFieldsInit(cls);
moveCommonFieldsInit(cls);
return false;
}
private static void checkStaticFieldsInit(ClassNode cls) {
MethodNode clinit = cls.getClassInitMth();
if (clinit == null
|| !clinit.getAccessFlags().isStatic()
|| clinit.isNoCode()
|| clinit.getBasicBlocks() == null) {
return;
}
private static final class FieldInitInfo {
final FieldNode fieldNode;
final IndexInsnNode putInsn;
final boolean singlePath;
for (BlockNode block : clinit.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == InsnType.SPUT) {
processStaticFieldAssign(cls, (IndexInsnNode) insn);
}
}
public FieldInitInfo(FieldNode fieldNode, IndexInsnNode putInsn, boolean singlePath) {
this.fieldNode = fieldNode;
this.putInsn = putInsn;
this.singlePath = singlePath;
}
}
/**
* Remove a final field in place initialization if it an assign found in class init method
*/
private static void processStaticFieldAssign(ClassNode cls, IndexInsnNode insn) {
FieldInfo field = (FieldInfo) insn.getIndex();
if (field.getDeclClass().equals(cls.getClassInfo())) {
FieldNode fn = cls.searchField(field);
if (fn != null && fn.getAccessFlags().isFinal()) {
fn.remove(JadxAttrType.CONSTANT_VALUE);
}
private static final class ConstructorInitInfo {
final MethodNode constructorMth;
final List<FieldInitInfo> fieldInits;
private ConstructorInitInfo(MethodNode constructorMth, List<FieldInitInfo> fieldInits) {
this.constructorMth = constructorMth;
this.fieldInits = fieldInits;
}
}
private static void moveStaticFieldsInit(ClassNode cls) {
MethodNode classInitMth = cls.getClassInitMth();
if (classInitMth == null) {
if (classInitMth == null
|| !classInitMth.getAccessFlags().isStatic()
|| classInitMth.isNoCode()
|| classInitMth.getBasicBlocks() == null) {
return;
}
while (processFields(cls, classInitMth)) {
while (processStaticFields(cls, classInitMth)) {
// sometimes instructions moved to field init prevent from vars inline -> inline and try again
CodeShrinkVisitor.shrinkMethod(classInitMth);
}
}
private static boolean processFields(ClassNode cls, MethodNode classInitMth) {
boolean changed = false;
for (FieldNode field : cls.getFields()) {
if (field.contains(AFlag.DONT_GENERATE) || field.contains(AType.FIELD_INIT_INSN)) {
continue;
}
if (field.getAccessFlags().isStatic()) {
List<InsnNode> initInsns = getFieldAssigns(classInitMth, field, InsnType.SPUT);
if (initInsns.size() == 1) {
InsnNode insn = initInsns.get(0);
if (checkInsn(cls, insn)) {
InsnArg arg = insn.getArg(0);
if (arg instanceof InsnWrapArg) {
((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR);
}
InsnRemover.remove(classInitMth, insn);
addFieldInitAttr(classInitMth, field, insn);
changed = true;
}
}
}
private static boolean processStaticFields(ClassNode cls, MethodNode classInitMth) {
List<FieldInitInfo> inits = collectFieldsInit(cls, classInitMth, InsnType.SPUT);
if (inits.isEmpty()) {
return false;
}
return changed;
}
private static class InitInfo {
private final MethodNode constrMth;
private final List<InsnNode> putInsns = new ArrayList<>();
private InitInfo(MethodNode constrMth) {
this.constrMth = constrMth;
// ignore field init constant if field initialized in class init method
for (FieldInitInfo fieldInit : inits) {
FieldNode field = fieldInit.fieldNode;
if (field.getAccessFlags().isFinal()) {
field.remove(JadxAttrType.CONSTANT_VALUE);
}
}
public MethodNode getConstrMth() {
return constrMth;
filterFieldsInit(inits);
if (inits.isEmpty()) {
return false;
}
public List<InsnNode> getPutInsns() {
return putInsns;
for (FieldInitInfo fieldInit : inits) {
IndexInsnNode insn = fieldInit.putInsn;
InsnArg arg = insn.getArg(0);
if (arg instanceof InsnWrapArg) {
((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR);
}
InsnRemover.remove(classInitMth, insn);
addFieldInitAttr(classInitMth, fieldInit.fieldNode, insn);
}
fixFieldsOrder(cls, inits);
return true;
}
private static void moveCommonFieldsInit(ClassNode cls) {
List<MethodNode> constrList = getConstructorsList(cls);
if (constrList.isEmpty()) {
List<MethodNode> constructors = getConstructorsList(cls);
if (constructors.isEmpty()) {
return;
}
List<InitInfo> infoList = new ArrayList<>(constrList.size());
for (MethodNode constrMth : constrList) {
if (constrMth.isNoCode()) {
List<ConstructorInitInfo> infoList = new ArrayList<>(constructors.size());
for (MethodNode constructorMth : constructors) {
if (constructorMth.isNoCode()) {
return;
}
List<BlockNode> enterBlocks = constrMth.getEnterBlock().getCleanSuccessors();
if (enterBlocks.isEmpty()) {
List<FieldInitInfo> inits = collectFieldsInit(cls, constructorMth, InsnType.IPUT);
filterFieldsInit(inits);
if (inits.isEmpty()) {
return;
}
InitInfo info = new InitInfo(constrMth);
infoList.add(info);
// TODO: check not only first block
BlockNode blockNode = enterBlocks.get(0);
for (InsnNode insn : blockNode.getInstructions()) {
if (insn.getType() == InsnType.IPUT && checkInsn(cls, insn)) {
info.getPutInsns().add(insn);
} else if (!info.getPutInsns().isEmpty()) {
break;
}
}
infoList.add(new ConstructorInitInfo(constructorMth, inits));
}
// compare collected instructions
InitInfo common = null;
for (InitInfo info : infoList) {
ConstructorInitInfo common = null;
for (ConstructorInitInfo info : infoList) {
if (common == null) {
common = info;
} else if (!compareInsns(common.getPutInsns(), info.getPutInsns())) {
continue;
}
if (!compareFieldInits(common.fieldInits, info.fieldInits)) {
return;
}
}
if (common == null) {
return;
}
Set<FieldInfo> fields = new HashSet<>();
for (InsnNode insn : common.getPutInsns()) {
FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex();
FieldNode field = cls.root().resolveField(fieldInfo);
if (field == null) {
return;
}
if (!fields.add(fieldInfo)) {
return;
}
}
// all checks passed
for (InitInfo info : infoList) {
for (InsnNode putInsn : info.getPutInsns()) {
for (ConstructorInitInfo info : infoList) {
for (FieldInitInfo fieldInit : info.fieldInits) {
IndexInsnNode putInsn = fieldInit.putInsn;
InsnArg arg = putInsn.getArg(0);
if (arg instanceof InsnWrapArg) {
((InsnWrapArg) arg).getWrapInsn().add(AFlag.DECLARE_VAR);
}
InsnRemover.remove(info.getConstrMth(), putInsn);
InsnRemover.remove(info.constructorMth, putInsn);
}
}
for (InsnNode insn : common.getPutInsns()) {
FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex();
FieldNode field = cls.root().resolveField(fieldInfo);
addFieldInitAttr(common.getConstrMth(), field, insn);
for (FieldInitInfo fieldInit : common.fieldInits) {
addFieldInitAttr(common.constructorMth, fieldInit.fieldNode, fieldInit.putInsn);
}
fixFieldsOrder(cls, common.fieldInits);
}
private static boolean compareInsns(List<InsnNode> base, List<InsnNode> other) {
if (base.size() != other.size()) {
return false;
}
int count = base.size();
for (int i = 0; i < count; i++) {
InsnNode baseInsn = base.get(i);
InsnNode otherInsn = other.get(i);
if (!baseInsn.isSame(otherInsn)) {
return false;
private static List<FieldInitInfo> collectFieldsInit(ClassNode cls, MethodNode mth, InsnType putType) {
List<FieldInitInfo> fieldsInit = new ArrayList<>();
Set<BlockNode> singlePathBlocks = new HashSet<>();
BlockUtils.visitSinglePath(mth.getEnterBlock(), singlePathBlocks::add);
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == putType) {
IndexInsnNode putInsn = (IndexInsnNode) insn;
FieldInfo field = (FieldInfo) putInsn.getIndex();
if (field.getDeclClass().equals(cls.getClassInfo())) {
FieldNode fn = cls.searchField(field);
if (fn != null) {
boolean singlePath = singlePathBlocks.contains(block);
fieldsInit.add(new FieldInitInfo(fn, putInsn, singlePath));
}
}
}
}
}
return true;
return fieldsInit;
}
private static boolean checkInsn(ClassNode cls, InsnNode insn) {
if (insn instanceof IndexInsnNode) {
FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) insn).getIndex();
if (!fieldInfo.getDeclClass().equals(cls.getClassInfo())) {
// exclude fields from super classes
return false;
}
FieldNode fieldNode = cls.root().resolveField(fieldInfo);
if (fieldNode == null) {
// exclude inherited fields (not declared in this class)
return false;
private static void filterFieldsInit(List<FieldInitInfo> inits) {
// exclude fields initialized several times
Set<FieldInfo> excludedFields = inits
.stream()
.collect(Collectors.toMap(fi -> fi.fieldNode, fi -> 1, Integer::sum))
.entrySet()
.stream()
.filter(v -> v.getValue() > 1)
.map(v -> v.getKey().getFieldInfo())
.collect(Collectors.toSet());
for (FieldInitInfo initInfo : inits) {
if (!checkInsn(initInfo)) {
excludedFields.add(initInfo.fieldNode.getFieldInfo());
}
} else {
return false;
}
if (!excludedFields.isEmpty()) {
boolean changed;
do {
changed = false;
for (FieldInitInfo initInfo : inits) {
FieldInfo fieldInfo = initInfo.fieldNode.getFieldInfo();
if (excludedFields.contains(fieldInfo)) {
continue;
}
if (insnUseExcludedField(initInfo, excludedFields)) {
excludedFields.add(fieldInfo);
changed = true;
}
}
} while (changed);
}
// apply
if (!excludedFields.isEmpty()) {
inits.removeIf(fi -> excludedFields.contains(fi.fieldNode.getFieldInfo()));
}
}
private static boolean checkInsn(FieldInitInfo initInfo) {
if (!initInfo.singlePath) {
return false;
}
IndexInsnNode insn = initInfo.putInsn;
InsnArg arg = insn.getArg(0);
if (arg.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
......@@ -248,6 +249,52 @@ public class ExtractFieldInit extends AbstractVisitor {
return true;
}
private static boolean insnUseExcludedField(FieldInitInfo initInfo, Set<FieldInfo> excludedFields) {
if (excludedFields.isEmpty()) {
return false;
}
IndexInsnNode insn = initInfo.putInsn;
boolean staticField = insn.getType() == InsnType.SPUT;
InsnType useType = staticField ? InsnType.SGET : InsnType.IGET;
// exclude if init code use any excluded field
Boolean exclude = insn.visitInsns(innerInsn -> {
if (innerInsn.getType() == useType) {
FieldInfo fieldInfo = (FieldInfo) ((IndexInsnNode) innerInsn).getIndex();
if (excludedFields.contains(fieldInfo)) {
return true;
}
}
return null;
});
return Objects.equals(exclude, Boolean.TRUE);
}
private static void fixFieldsOrder(ClassNode cls, List<FieldInitInfo> fieldsInit) {
List<FieldNode> clsFields = cls.getFields();
List<FieldNode> orderedFields = Utils.collectionMap(fieldsInit, v -> v.fieldNode);
// check if already ordered
boolean ordered = Collections.indexOfSubList(clsFields, orderedFields) != -1;
if (!ordered) {
clsFields.removeAll(orderedFields);
clsFields.addAll(orderedFields);
}
}
private static boolean compareFieldInits(List<FieldInitInfo> base, List<FieldInitInfo> other) {
if (base.size() != other.size()) {
return false;
}
int count = base.size();
for (int i = 0; i < count; i++) {
InsnNode baseInsn = base.get(i).putInsn;
InsnNode otherInsn = other.get(i).putInsn;
if (!baseInsn.isSame(otherInsn)) {
return false;
}
}
return true;
}
private static List<MethodNode> getConstructorsList(ClassNode cls) {
List<MethodNode> list = new ArrayList<>();
for (MethodNode mth : cls.getMethods()) {
......@@ -262,26 +309,8 @@ public class ExtractFieldInit extends AbstractVisitor {
return list;
}
private static List<InsnNode> getFieldAssigns(MethodNode mth, FieldNode field, InsnType putInsn) {
if (mth.isNoCode() || mth.getBasicBlocks() == null) {
return Collections.emptyList();
}
List<InsnNode> assignInsns = new ArrayList<>();
for (BlockNode block : mth.getBasicBlocks()) {
for (InsnNode insn : block.getInstructions()) {
if (insn.getType() == putInsn) {
FieldInfo putNode = (FieldInfo) ((IndexInsnNode) insn).getIndex();
if (putNode.equals(field.getFieldInfo())) {
assignInsns.add(insn);
}
}
}
}
return assignInsns;
}
private static void addFieldInitAttr(MethodNode classInitMth, FieldNode field, InsnNode insn) {
private static void addFieldInitAttr(MethodNode mth, FieldNode field, InsnNode insn) {
InsnNode assignInsn = InsnNode.wrapArg(insn.getArg(0));
field.addAttr(new FieldInitInsnAttr(classInitMth, assignInsn));
field.addAttr(new FieldInitInsnAttr(mth, assignInsn));
}
}
......@@ -761,6 +761,30 @@ public class BlockUtils {
}
}
/**
* Visit blocks on path without branching or merging paths.
*/
public static void visitSinglePath(BlockNode startBlock, Consumer<BlockNode> visitor) {
if (startBlock == null) {
return;
}
visitor.accept(startBlock);
BlockNode next = getNextSinglePathBlock(startBlock);
while (next != null) {
visitor.accept(next);
next = getNextSinglePathBlock(next);
}
}
@Nullable
public static BlockNode getNextSinglePathBlock(BlockNode block) {
if (block == null || block.getPredecessors().size() > 1) {
return null;
}
List<BlockNode> successors = block.getSuccessors();
return successors.size() == 1 ? successors.get(0) : null;
}
public static List<BlockNode> buildSimplePath(BlockNode block) {
if (block == null) {
return Collections.emptyList();
......
......@@ -24,6 +24,9 @@ public class TestUtils {
}
public static int count(String string, String substring) {
if (substring == null || substring.isEmpty()) {
throw new IllegalArgumentException("Substring can't be null or empty");
}
int count = 0;
int idx = 0;
while ((idx = string.indexOf(substring, idx)) != -1) {
......
......@@ -82,6 +82,8 @@ public class TestFieldInitInTryCatch extends IntegrationTest {
ClassNode cls = getClassNode(TestCls3.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("public static final String[] A = {\"a\"};"));
// don't move code from try/catch
assertThat(code, containsOne("public static final String[] A;"));
assertThat(code, containsOne("A = new String[]{\"a\"};"));
}
}
package jadx.tests.integration.others;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestFieldInitOrder extends IntegrationTest {
public static class TestCls {
private final StringBuilder sb = new StringBuilder();
private final String a = sb.append("a").toString();
private final String b = sb.append("b").toString();
private final String c = sb.append("c").toString();
private final String result = sb.toString();
public void check() {
assertThat(result).isEqualTo("abc");
assertThat(a).isEqualTo("a");
assertThat(b).isEqualTo("ab");
assertThat(c).isEqualTo("abc");
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.doesNotContain("TestCls() {") // constructor removed
.doesNotContain("String result;")
.containsOne("String result = this.sb.toString();");
}
}
package jadx.tests.integration.others;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestFieldInit extends IntegrationTest {
public class TestFieldInitOrderStatic extends IntegrationTest {
@SuppressWarnings("ConstantName")
public static class TestCls {
public class A {
}
public static List<String> s = new ArrayList<>();
public A a = new A();
public int i = 1 + Random.class.getSimpleName().length();
public int n = 0;
public TestCls(int z) {
this.n = z;
this.n = 0;
private static final StringBuilder sb = new StringBuilder();
private static final String a = sb.append("a").toString();
private static final String b = sb.append("b").toString();
private static final String c = sb.append("c").toString();
private static final String result = sb.toString();
public void check() {
assertThat(result).isEqualTo("abc");
assertThat(a).isEqualTo("a");
assertThat(b).isEqualTo("ab");
assertThat(c).isEqualTo("abc");
}
}
......@@ -33,14 +28,8 @@ public class TestFieldInit extends IntegrationTest {
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("List<String> s = new ArrayList")
.containsOne("A a = new A();")
.containsOneOf(
"int i = (Random.class.getSimpleName().length() + 1);",
"int i = (1 + Random.class.getSimpleName().length());")
.containsOne("int n = 0;")
.doesNotContain("static {")
.containsOne("this.n = z;")
.containsOne("this.n = 0;");
.doesNotContain("String result;")
.containsOne("String result = sb.toString();");
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册