提交 ed67f8e1 编写于 作者: S Skylot

core: make strict shrink code implementation

上级 45312560
......@@ -12,6 +12,7 @@ import jadx.core.dex.visitors.FallbackModeVisitor;
import jadx.core.dex.visitors.IDexTreeVisitor;
import jadx.core.dex.visitors.MethodInlinerVisitor;
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.dex.visitors.SimplifyVisitor;
import jadx.core.dex.visitors.regions.CheckRegions;
import jadx.core.dex.visitors.regions.CleanRegions;
import jadx.core.dex.visitors.regions.PostRegionVisitor;
......@@ -50,12 +51,13 @@ public class Jadx {
passes.add(new BlockMakerVisitor());
passes.add(new TypeResolver());
passes.add(new ConstInlinerVisitor());
passes.add(new FinishTypeResolver());
if (args.isRawCFGOutput())
passes.add(new DotGraphVisitor(outDir, false, true));
passes.add(new ConstInlinerVisitor());
passes.add(new FinishTypeResolver());
passes.add(new ModVisitor());
passes.add(new EnumVisitor());
......@@ -66,6 +68,7 @@ public class Jadx {
passes.add(new PostRegionVisitor());
passes.add(new CodeShrinker());
passes.add(new SimplifyVisitor());
passes.add(new ProcessVariables());
passes.add(new CheckRegions());
if (args.isCFGOutput())
......
......@@ -128,7 +128,9 @@ public class InsnGen {
private String ifield(FieldInfo field, InsnArg arg) throws CodegenException {
String name = field.getName();
if (arg.isThis()) {
// TODO: add jadx argument ""
// FIXME: check variable names in scope
if (false && arg.isThis()) {
boolean useShort = true;
List<RegisterArg> args = mth.getArguments(false);
for (RegisterArg param : args) {
......@@ -138,7 +140,7 @@ public class InsnGen {
}
}
if (useShort) {
return name; // FIXME: check variable names in scope
return name;
}
}
return arg(arg) + "." + name;
......@@ -672,7 +674,7 @@ public class InsnGen {
// "++" or "--"
if (insn.getArg(1).isLiteral() && (op == ArithOp.ADD || op == ArithOp.SUB)) {
LiteralArg lit = (LiteralArg) insn.getArg(1);
if (Math.abs(lit.getLiteral()) == 1 && lit.isInteger()) {
if (lit.isInteger() && lit.getLiteral() == 1) {
code.add(assignVar(insn)).add(op.getSymbol()).add(op.getSymbol());
return;
}
......
......@@ -16,6 +16,7 @@ import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.IBlock;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.IRegion;
......@@ -134,6 +135,19 @@ public class RegionGen extends InsnGen {
}
private CodeWriter makeLoop(LoopRegion region, CodeWriter code) throws CodegenException {
BlockNode header = region.getHeader();
if (header != null) {
List<InsnNode> headerInsns = header.getInstructions();
if (headerInsns.size() > 1) {
// write not inlined instructions from header
mth.getAttributes().add(AttributeFlag.INCONSISTENT_CODE);
for (int i = 0; i < headerInsns.size() - 1; i++) {
InsnNode insn = headerInsns.get(i);
makeInsn(insn, code);
}
}
}
IfCondition condition = region.getCondition();
if (condition == null) {
// infinite loop
......
......@@ -21,7 +21,7 @@ public class NameMapper {
"continue",
"default",
"do",
"double ",
"double",
"else",
"enum",
"extends",
......
......@@ -339,15 +339,16 @@ public abstract class ArgType {
}
public static ArgType merge(ArgType a, ArgType b) {
if (a.equals(b))
return a;
if (b == null || a == null)
if (b == null || a == null) {
return null;
}
if (a.equals(b)) {
return a;
}
ArgType res = mergeInternal(a, b);
if (res == null)
if (res == null) {
res = mergeInternal(b, a); // swap
}
return res;
}
......
......@@ -87,10 +87,6 @@ public abstract class InsnArg extends Typed {
case MOVE:
case CONST:
arg = insn.getArg(0);
String name = insn.getResult().getTypedVar().getName();
if (name != null) {
arg.getTypedVar().setName(name);
}
break;
case CONST_STR:
arg = wrap(insn);
......
......@@ -10,7 +10,6 @@ import jadx.core.dex.nodes.DexNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.parser.FieldValueAttr;
import jadx.core.dex.visitors.InstructionRemover;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -96,10 +95,6 @@ public class RegisterArg extends InsnArg {
if (ai != null && ai.getType() == InsnType.MOVE) {
InsnArg arg = ai.getArg(0);
if (arg != this && arg.isThis()) {
// actually we need to remove this instruction but we can't
// because of iterating on instructions list
// so unbind insn and rely on code shrinker
InstructionRemover.unbindInsn(ai);
return true;
}
}
......
......@@ -11,6 +11,8 @@ import jadx.core.dex.nodes.MethodNode;
public class ConstructorInsn extends InsnNode {
private final MethodInfo callMth;
private final CallType callType;
private final RegisterArg instanceArg;
private static enum CallType {
CONSTRUCTOR, // just new instance
......@@ -19,30 +21,28 @@ public class ConstructorInsn extends InsnNode {
SELF // call itself
}
private CallType callType;
public ConstructorInsn(MethodNode mth, InvokeNode invoke) {
super(InsnType.CONSTRUCTOR, invoke.getArgsCount() - 1);
this.callMth = invoke.getCallMth();
ClassInfo classType = callMth.getDeclClass();
instanceArg = (RegisterArg) invoke.getArg(0);
instanceArg.setParentInsn(this);
if (invoke.getArg(0).isThis()) {
if (instanceArg.isThis()) {
if (classType.equals(mth.getParentClass().getClassInfo())) {
if (callMth.getShortId().equals(mth.getMethodInfo().getShortId())) {
// self constructor
callType = CallType.SELF;
} else if (mth.getMethodInfo().isConstructor()) {
} else {
callType = CallType.THIS;
}
} else {
callType = CallType.SUPER;
}
}
if (callType == null) {
} else {
callType = CallType.CONSTRUCTOR;
setResult((RegisterArg) invoke.getArg(0));
setResult(instanceArg);
}
for (int i = 1; i < invoke.getArgsCount(); i++) {
addArg(invoke.getArg(i));
}
......@@ -53,6 +53,10 @@ public class ConstructorInsn extends InsnNode {
return callMth;
}
public RegisterArg getInstanceArg() {
return instanceArg;
}
public ClassInfo getClassType() {
return callMth.getDeclClass();
}
......
......@@ -133,6 +133,30 @@ public class InsnNode extends LineAttrNode {
}
}
public boolean canReorder() {
switch (getType()) {
case CONST:
case CONST_STR:
case CONST_CLASS:
case CAST:
case MOVE:
case ARITH:
case NEG:
case CMP_L:
case CMP_G:
case CHECK_CAST:
case INSTANCE_OF:
case FILL_ARRAY:
case FILLED_NEW_ARRAY:
case NEW_ARRAY:
case NEW_MULTIDIM_ARRAY:
case STR_CONCAT:
case MOVE_EXCEPTION:
return true;
}
return false;
}
@Override
public String toString() {
return InsnUtils.formatOffset(offset) + ": "
......
......@@ -66,7 +66,9 @@ public class RootNode {
for (ClassNode cls : inner) {
ClassNode parent = resolveClass(cls.getClassInfo().getParentClass());
if (parent == null) {
names.remove(cls.getFullName());
cls.getClassInfo().notInner();
names.put(cls.getFullName(), cls);
} else {
parent.addInnerClass(cls);
}
......
......@@ -40,10 +40,6 @@ public final class LoopRegion extends AbstractRegion {
return conditionBlock;
}
public void setHeader(BlockNode conditionBlock) {
this.conditionBlock = conditionBlock;
}
public IContainer getBody() {
return body;
}
......@@ -64,7 +60,7 @@ public final class LoopRegion extends AbstractRegion {
}
private IfNode getIfInsn() {
return (IfNode) conditionBlock.getInstructions().get(conditionBlock.getInstructions().size() - 1);
return (IfNode) conditionBlock.getInstructions().get(0);
}
/**
......@@ -108,10 +104,13 @@ public final class LoopRegion extends AbstractRegion {
*/
public void mergePreCondition() {
if (preCondition != null && conditionBlock != null) {
preCondition.getInstructions().addAll(conditionBlock.getInstructions());
conditionBlock.getInstructions().clear();
conditionBlock.getInstructions().addAll(preCondition.getInstructions());
preCondition.getInstructions().clear();
List<InsnNode> condInsns = conditionBlock.getInstructions();
List<InsnNode> preCondInsns = preCondition.getInstructions();
preCondInsns.addAll(condInsns);
condInsns.clear();
condInsns.addAll(preCondInsns);
preCondInsns.clear();
preCondition = null;
}
}
......
package jadx.core.dex.visitors;
import jadx.core.Consts;
import jadx.core.dex.attributes.AttributeFlag;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.exceptions.JadxRuntimeException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
......@@ -36,239 +27,219 @@ public class CodeShrinker extends AbstractVisitor {
@Override
public void visit(MethodNode mth) {
if (mth.isNoCode() || mth.getAttributes().contains(AttributeFlag.DONT_SHRINK))
if (mth.isNoCode() || mth.getAttributes().contains(AttributeFlag.DONT_SHRINK)) {
return;
shrink(mth);
prettify(mth);
}
for (BlockNode block : mth.getBasicBlocks()) {
shrinkBlock(mth, block);
}
}
private static void shrink(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
List<InsnNode> insnList = block.getInstructions();
InstructionRemover remover = new InstructionRemover(insnList);
for (InsnNode insn : insnList) {
// wrap instructions
RegisterArg result = insn.getResult();
if (result != null) {
List<InsnArg> useList = result.getTypedVar().getUseList();
if (useList.size() == 1) {
// variable is used only in this instruction
// TODO not correct sometimes :(
remover.add(insn);
} else if (useList.size() == 2) {
InsnArg useInsnArg = selectOther(useList, result);
InsnNode useInsn = useInsnArg.getParentInsn();
if (useInsn == null) {
LOG.debug("parent insn null: {}, mth: {}", insn, mth);
} else if (useInsn != insn) {
boolean wrap = false;
// TODO
if (false && result.getTypedVar().getName() != null) {
// don't wrap if result variable has name from debug info
wrap = false;
} else if (BlockUtils.blockContains(block, useInsn)) {
// TODO don't reorder methods invocations
// wrap insn from current block
wrap = true;
} else {
// TODO implement rules for shrink insn from different blocks
BlockNode useBlock = BlockUtils.getBlockByInsn(mth, useInsn);
if (useBlock != null
&& (useBlock.getPredecessors().contains(block)
|| insn.getType() == InsnType.MOVE_EXCEPTION)) {
wrap = true;
}
}
if (wrap) {
if (insn.getType() == InsnType.MOVE) {
for (int r = 0; r < useInsn.getArgsCount(); r++) {
if (useInsn.getArg(r).getTypedVar() == insn.getResult().getTypedVar()) {
useInsn.setArg(r, insn.getArg(0));
break;
}
}
} else {
useInsnArg.wrapInstruction(insn);
}
remover.add(insn);
}
}
}
private static final class ArgsInfo {
private final InsnNode insn;
private final List<ArgsInfo> argsList;
private final List<RegisterArg> args;
private final int pos;
private int inlineBorder;
private ArgsInfo inlinedInsn;
public ArgsInfo(InsnNode insn, List<ArgsInfo> argsList, int pos) {
this.insn = insn;
this.argsList = argsList;
this.pos = pos;
this.inlineBorder = pos;
this.args = new LinkedList<RegisterArg>();
addArgs(insn, args);
}
private static void addArgs(InsnNode insn, List<RegisterArg> args) {
if (insn.getType() == InsnType.CONSTRUCTOR) {
args.add(((ConstructorInsn) insn).getInstanceArg());
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isRegister()) {
args.add((RegisterArg) arg);
}
}
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
addArgs(((InsnWrapArg) arg).getWrapInsn(), args);
}
}
remover.perform();
}
}
private static void prettify(MethodNode mth) {
for (BlockNode block : mth.getBasicBlocks()) {
List<InsnNode> list = block.getInstructions();
for (int i = 0; i < list.size(); i++) {
InsnNode modInsn = pretifyInsn(mth, list.get(i));
if (modInsn != null) {
list.set(i, modInsn);
public InsnNode getInsn() {
return insn;
}
private List<RegisterArg> getArgs() {
return args;
}
public WrapInfo checkInline(int assignPos, RegisterArg arg) {
if (assignPos >= inlineBorder || !canMove(assignPos, inlineBorder)) {
return null;
}
inlineBorder = assignPos;
return inline(assignPos, arg);
}
private boolean canMove(int from, int to) {
from++;
if (from == to) {
// previous instruction or on edge of inline border
return true;
}
if (from > to) {
throw new JadxRuntimeException("Invalid inline insn positions: " + from + " - " + to);
}
for (int i = from; i < to - 1; i++) {
ArgsInfo argsInfo = argsList.get(i);
if (argsInfo.getInlinedInsn() == this) {
continue;
}
if (!argsInfo.insn.canReorder()) {
return false;
}
}
return true;
}
}
private static InsnNode pretifyInsn(MethodNode mth, InsnNode insn) {
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
InsnNode ni = pretifyInsn(mth, ((InsnWrapArg) arg).getWrapInsn());
if (ni != null)
arg.wrapInstruction(ni);
public WrapInfo inline(int assignInsnPos, RegisterArg arg) {
ArgsInfo argsInfo = argsList.get(assignInsnPos);
argsInfo.inlinedInsn = this;
return new WrapInfo(argsInfo.insn, arg);
}
public ArgsInfo getInlinedInsn() {
if (inlinedInsn != null) {
ArgsInfo parent = inlinedInsn.getInlinedInsn();
if (parent != null) {
inlinedInsn = parent;
}
}
return inlinedInsn;
}
switch (insn.getType()) {
case ARITH:
ArithNode arith = (ArithNode) insn;
if (arith.getArgsCount() == 2) {
InsnArg litArg = null;
if (arith.getArg(1).isInsnWrap()) {
InsnNode wr = ((InsnWrapArg) arith.getArg(1)).getWrapInsn();
if (wr.getType() == InsnType.CONST)
litArg = wr.getArg(0);
} else if (arith.getArg(1).isLiteral()) {
litArg = arith.getArg(1);
}
@Override
public String toString() {
return "ArgsInfo: |" + inlineBorder
+ " ->" + (inlinedInsn == null ? "-" : inlinedInsn.pos)
+ " " + args + " : " + insn;
}
}
if (litArg != null) {
long lit = ((LiteralArg) litArg).getLiteral();
boolean invert = false;
private static final class WrapInfo {
private final InsnNode insn;
private final RegisterArg arg;
if (arith.getOp() == ArithOp.ADD && lit < 0)
invert = true;
public WrapInfo(InsnNode assignInsn, RegisterArg arg) {
this.insn = assignInsn;
this.arg = arg;
}
// fix 'c + (-1)' => 'c - (1)'
if (invert) {
return new ArithNode(ArithOp.SUB,
arith.getResult(), insn.getArg(0),
InsnArg.lit(-lit, litArg.getType()));
}
}
}
break;
private InsnNode getInsn() {
return insn;
}
case IF:
// simplify 'cmp' instruction in if condition
IfNode ifb = (IfNode) insn;
InsnArg f = ifb.getArg(0);
if (f.isInsnWrap()) {
InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
if (ifb.isZeroCmp()
|| ((LiteralArg) ifb.getArg(1)).getLiteral() == 0) {
ifb.changeCondition(wi.getArg(0), wi.getArg(1), ifb.getOp());
} else {
LOG.warn("TODO: cmp" + ifb);
}
}
}
break;
private RegisterArg getArg() {
return arg;
}
case INVOKE:
MethodInfo callMth = ((InvokeNode) insn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getShortId().equals("toString()")
&& insn.getArg(0).isInsnWrap()) {
try {
List<InsnNode> chain = flattenInsnChain(insn);
if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) {
ConstructorInsn constr = (ConstructorInsn) chain.get(0);
if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& constr.getArgsCount() == 0) {
int len = chain.size();
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
for (int i = 1; i < len; i++) {
concatInsn.addArg(chain.get(i).getArg(1));
}
concatInsn.setResult(insn.getResult());
return concatInsn;
}
@Override
public String toString() {
return "WrapInfo: " + arg + " -> " + insn;
}
}
private void shrinkBlock(MethodNode mth, BlockNode block) {
InsnList insnList = new InsnList(block.getInstructions());
int insnCount = insnList.size();
List<ArgsInfo> argsList = new ArrayList<ArgsInfo>(insnCount);
for (int i = 0; i < insnCount; i++) {
argsList.add(new ArgsInfo(insnList.get(i), argsList, i));
}
List<WrapInfo> wrapList = new ArrayList<WrapInfo>();
for (ArgsInfo argsInfo : argsList) {
List<RegisterArg> args = argsInfo.getArgs();
for (ListIterator<RegisterArg> it = args.listIterator(args.size()); it.hasPrevious(); ) {
RegisterArg arg = it.previous();
List<InsnArg> useList = arg.getTypedVar().getUseList();
if (useList.size() != 2) {
continue;
}
InsnNode assignInsn = selectOther(useList, arg).getParentInsn();
if (assignInsn == null
|| assignInsn.getResult() == null
|| assignInsn.getResult().getRegNum() != arg.getRegNum()) {
continue;
}
int assignPos = insnList.getIndex(assignInsn);
if (assignPos != -1) {
if (assignInsn.canReorder()) {
wrapList.add(argsInfo.inline(assignPos, arg));
} else {
WrapInfo wrapInfo = argsInfo.checkInline(assignPos, arg);
if (wrapInfo != null) {
wrapList.add(wrapInfo);
}
} catch (Throwable e) {
LOG.debug("Can't convert string concatenation: {} insn: {}", mth, insn, e);
}
}
break;
case IPUT:
case SPUT:
// convert field arith operation to arith instruction
// (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
InsnArg arg = insn.getArg(0);
if (arg.isInsnWrap()) {
InsnNode wrap = ((InsnWrapArg) arg).getWrapInsn();
InsnType wrapType = wrap.getType();
if ((wrapType == InsnType.ARITH || wrapType == InsnType.STR_CONCAT) && wrap.getArg(0).isInsnWrap()) {
InsnNode get = ((InsnWrapArg) wrap.getArg(0)).getWrapInsn();
InsnType getType = get.getType();
if (getType == InsnType.IGET || getType == InsnType.SGET) {
FieldInfo field = (FieldInfo) ((IndexInsnNode) insn).getIndex();
FieldInfo innerField = (FieldInfo) ((IndexInsnNode) get).getIndex();
if (field.equals(innerField)) {
try {
RegisterArg reg = null;
if (getType == InsnType.IGET) {
reg = ((RegisterArg) get.getArg(0));
}
RegisterArg fArg = new FieldArg(field, reg != null ? reg.getRegNum() : -1);
if (reg != null) {
fArg.replaceTypedVar(get.getArg(0));
}
if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap;
return new ArithNode(ar.getOp(), fArg, fArg, ar.getArg(1));
} else {
int argsCount = wrap.getArgsCount();
InsnNode concat = new InsnNode(InsnType.STR_CONCAT, argsCount - 1);
for (int i = 1; i < argsCount; i++) {
concat.addArg(wrap.getArg(i));
}
return new ArithNode(ArithOp.ADD, fArg, fArg, InsnArg.wrap(concat));
}
} catch (Throwable e) {
LOG.debug("Can't convert field arith insn: {}, mth: {}", insn, mth, e);
}
}
} else {
// another block
if (block.getPredecessors().size() == 1) {
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
if (canMoveBetweenBlocks(assignInsn, assignBlock, block, argsInfo.getInsn())) {
arg.wrapInstruction(assignInsn);
InsnList.remove(assignBlock, assignInsn);
}
}
}
break;
case CHECK_CAST:
InsnArg castArg = insn.getArg(0);
ArgType castType = (ArgType) ((IndexInsnNode) insn).getIndex();
if (!ArgType.isCastNeeded(castArg.getType(), castType)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setResult(insn.getResult());
insnNode.addArg(castArg);
return insnNode;
}
break;
default:
break;
}
}
for (WrapInfo wrapInfo : wrapList) {
wrapInfo.getArg().wrapInstruction(wrapInfo.getInsn());
}
for (WrapInfo wrapInfo : wrapList) {
insnList.remove(wrapInfo.getInsn());
}
return null;
}
private static List<InsnNode> flattenInsnChain(InsnNode insn) {
List<InsnNode> chain = new ArrayList<InsnNode>();
InsnArg i = insn.getArg(0);
while (i.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) i).getWrapInsn();
chain.add(wrapInsn);
if (wrapInsn.getArgsCount() == 0)
break;
}
i = wrapInsn.getArg(0);
private boolean canMoveBetweenBlocks(InsnNode assignInsn, BlockNode assignBlock,
BlockNode useBlock, InsnNode useInsn) {
if (!useBlock.getPredecessors().contains(assignBlock)
&& !BlockUtils.isOnlyOnePathExists(assignBlock, useBlock)) {
return false;
}
boolean startCheck = false;
for (InsnNode insn : assignBlock.getInstructions()) {
if (startCheck) {
if (!insn.canReorder()) {
return false;
}
}
if (insn == assignInsn) {
startCheck = true;
}
}
BlockNode next = assignBlock.getCleanSuccessors().get(0);
while (next != useBlock) {
for (InsnNode insn : assignBlock.getInstructions()) {
if (!insn.canReorder()) {
return false;
}
}
next = next.getCleanSuccessors().get(0);
}
for (InsnNode insn : useBlock.getInstructions()) {
if (insn == useInsn) {
return true;
}
if (!insn.canReorder()) {
return false;
}
}
Collections.reverse(chain);
return chain;
throw new JadxRuntimeException("Can't process instruction move : " + assignBlock);
}
public static InsnArg inlineArgument(MethodNode mth, RegisterArg arg) {
......@@ -304,9 +275,6 @@ public class CodeShrinker extends AbstractVisitor {
private static InsnArg selectOther(List<InsnArg> list, RegisterArg insn) {
InsnArg first = list.get(0);
if (first == insn)
return list.get(1);
else
return first;
return insn == first ? list.get(1) : first;
}
}
package jadx.core.dex.visitors;
import jadx.core.Consts;
import jadx.core.dex.info.FieldInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.ArithOp;
import jadx.core.dex.instructions.IfNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.FieldArg;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.mods.ConstructorInsn;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class SimplifyVisitor extends AbstractVisitor {
private static final Logger LOG = LoggerFactory.getLogger(SimplifyVisitor.class);
@Override
public void visit(MethodNode mth) {
if (mth.isNoCode()) {
return;
}
for (BlockNode block : mth.getBasicBlocks()) {
List<InsnNode> list = block.getInstructions();
for (int i = 0; i < list.size(); i++) {
InsnNode modInsn = simplifyInsn(mth, list.get(i));
if (modInsn != null) {
list.set(i, modInsn);
}
}
}
}
private static InsnNode simplifyInsn(MethodNode mth, InsnNode insn) {
for (InsnArg arg : insn.getArguments()) {
if (arg.isInsnWrap()) {
InsnNode ni = simplifyInsn(mth, ((InsnWrapArg) arg).getWrapInsn());
if (ni != null) {
arg.wrapInstruction(ni);
}
}
}
switch (insn.getType()) {
case ARITH:
ArithNode arith = (ArithNode) insn;
if (arith.getArgsCount() == 2) {
InsnArg litArg = null;
if (arith.getArg(1).isInsnWrap()) {
InsnNode wr = ((InsnWrapArg) arith.getArg(1)).getWrapInsn();
if (wr.getType() == InsnType.CONST) {
litArg = wr.getArg(0);
}
} else if (arith.getArg(1).isLiteral()) {
litArg = arith.getArg(1);
}
if (litArg != null) {
long lit = ((LiteralArg) litArg).getLiteral();
boolean invert = false;
if (arith.getOp() == ArithOp.ADD && lit < 0) {
invert = true;
}
// fix 'c + (-1)' => 'c - (1)'
if (invert) {
return new ArithNode(ArithOp.SUB,
arith.getResult(), insn.getArg(0),
InsnArg.lit(-lit, litArg.getType()));
}
}
}
break;
case IF:
// simplify 'cmp' instruction in if condition
IfNode ifb = (IfNode) insn;
InsnArg f = ifb.getArg(0);
if (f.isInsnWrap()) {
InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
if (ifb.isZeroCmp()
|| ((LiteralArg) ifb.getArg(1)).getLiteral() == 0) {
ifb.changeCondition(wi.getArg(0), wi.getArg(1), ifb.getOp());
} else {
LOG.warn("TODO: cmp" + ifb);
}
}
}
break;
case INVOKE:
MethodInfo callMth = ((InvokeNode) insn).getCallMth();
if (callMth.getDeclClass().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& callMth.getShortId().equals("toString()")
&& insn.getArg(0).isInsnWrap()) {
try {
List<InsnNode> chain = flattenInsnChain(insn);
if (chain.size() > 1 && chain.get(0).getType() == InsnType.CONSTRUCTOR) {
ConstructorInsn constr = (ConstructorInsn) chain.get(0);
if (constr.getClassType().getFullName().equals(Consts.CLASS_STRING_BUILDER)
&& constr.getArgsCount() == 0) {
int len = chain.size();
InsnNode concatInsn = new InsnNode(InsnType.STR_CONCAT, len - 1);
for (int i = 1; i < len; i++) {
concatInsn.addArg(chain.get(i).getArg(1));
}
concatInsn.setResult(insn.getResult());
return concatInsn;
}
}
} catch (Throwable e) {
LOG.debug("Can't convert string concatenation: {} insn: {}", mth, insn, e);
}
}
break;
case IPUT:
case SPUT:
// convert field arith operation to arith instruction
// (IPUT = ARITH (IGET, lit) -> ARITH (fieldArg <op>= lit))
InsnArg arg = insn.getArg(0);
if (arg.isInsnWrap()) {
InsnNode wrap = ((InsnWrapArg) arg).getWrapInsn();
InsnType wrapType = wrap.getType();
if ((wrapType == InsnType.ARITH || wrapType == InsnType.STR_CONCAT) && wrap.getArg(0).isInsnWrap()) {
InsnNode get = ((InsnWrapArg) wrap.getArg(0)).getWrapInsn();
InsnType getType = get.getType();
if (getType == InsnType.IGET || getType == InsnType.SGET) {
FieldInfo field = (FieldInfo) ((IndexInsnNode) insn).getIndex();
FieldInfo innerField = (FieldInfo) ((IndexInsnNode) get).getIndex();
if (field.equals(innerField)) {
try {
RegisterArg reg = null;
if (getType == InsnType.IGET) {
reg = ((RegisterArg) get.getArg(0));
}
RegisterArg fArg = new FieldArg(field, reg != null ? reg.getRegNum() : -1);
if (reg != null) {
fArg.replaceTypedVar(get.getArg(0));
}
if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap;
return new ArithNode(ar.getOp(), fArg, fArg, ar.getArg(1));
} else {
int argsCount = wrap.getArgsCount();
InsnNode concat = new InsnNode(InsnType.STR_CONCAT, argsCount - 1);
for (int i = 1; i < argsCount; i++) {
concat.addArg(wrap.getArg(i));
}
return new ArithNode(ArithOp.ADD, fArg, fArg, InsnArg.wrap(concat));
}
} catch (Throwable e) {
LOG.debug("Can't convert field arith insn: {}, mth: {}", insn, mth, e);
}
}
}
}
}
break;
case CHECK_CAST:
InsnArg castArg = insn.getArg(0);
ArgType castType = (ArgType) ((IndexInsnNode) insn).getIndex();
if (!ArgType.isCastNeeded(castArg.getType(), castType)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setResult(insn.getResult());
insnNode.addArg(castArg);
return insnNode;
}
break;
default:
break;
}
return null;
}
private static List<InsnNode> flattenInsnChain(InsnNode insn) {
List<InsnNode> chain = new ArrayList<InsnNode>();
InsnArg i = insn.getArg(0);
while (i.isInsnWrap()) {
InsnNode wrapInsn = ((InsnWrapArg) i).getWrapInsn();
chain.add(wrapInsn);
if (wrapInsn.getArgsCount() == 0) {
break;
}
i = wrapInsn.getArg(0);
}
Collections.reverse(chain);
return chain;
}
}
......@@ -186,7 +186,7 @@ public class RegionMaker {
if (loopRegion.isConditionAtEnd()) {
// TODO: add some checks
} else {
if (condBlock != loop.getStart()) {
if (condBlock != loopStart) {
if (condBlock.getPredecessors().contains(loopStart)) {
loopRegion.setPreCondition(loopStart);
// if we can't merge pre-condition this is not correct header
......@@ -492,10 +492,8 @@ public class RegionMaker {
IfCondition nestedCondition = IfCondition.fromIfNode(nestedIfInsn);
if (isPathExists(bElse, nestedIfBlock)) {
// else branch
if (bThen != nbThen
&& !isEqualReturns(bThen, nbThen)) {
if (bThen != nbElse
&& !isEqualReturns(bThen, nbElse)) {
if (!isEqualReturns(bThen, nbThen)) {
if (!isEqualReturns(bThen, nbElse)) {
// not connected conditions
break;
}
......@@ -505,10 +503,8 @@ public class RegionMaker {
condition = IfCondition.merge(Mode.OR, condition, nestedCondition);
} else {
// then branch
if (bElse != nbElse
&& !isEqualReturns(bElse, nbElse)) {
if (bElse != nbThen
&& !isEqualReturns(bElse, nbThen)) {
if (!isEqualReturns(bElse, nbElse)) {
if (!isEqualReturns(bElse, nbThen)) {
// not connected conditions
break;
}
......@@ -693,21 +689,20 @@ public class RegionMaker {
if (b1 == b2) {
return true;
}
if (b1 == null
|| b2 == null) {
if (b1 == null || b2 == null) {
return false;
}
if (b1.getInstructions().size()
+ b2.getInstructions().size() != 2) {
List<InsnNode> b1Insns = b1.getInstructions();
List<InsnNode> b2Insns = b2.getInstructions();
if (b1Insns.size() != 1 || b2Insns.size() != 1) {
return false;
}
if (!b1.getAttributes().contains(AttributeFlag.RETURN)
|| !b2.getAttributes().contains(AttributeFlag.RETURN)) {
|| !b2.getAttributes().contains(AttributeFlag.RETURN)) {
return false;
}
if (b1.getInstructions().get(0).getArgsCount() > 0) {
if (b1.getInstructions().get(0).getArg(0)
!= b2.getInstructions().get(0).getArg(0)) {
if (b1Insns.get(0).getArgsCount() > 0) {
if (b1Insns.get(0).getArg(0) != b2Insns.get(0).getArg(0)) {
return false;
}
}
......
......@@ -87,18 +87,6 @@ public class BlockUtils {
return false;
}
/**
* Return instruction position in block (use == for comparison, not equals)
*/
public static int insnIndex(BlockNode block, InsnNode insn) {
int size = block.getInstructions().size();
for (int i = 0; i < size; i++) {
if (block.getInstructions().get(i) == insn)
return i;
}
return -1;
}
public static boolean lastInsnType(BlockNode block, InsnType type) {
List<InsnNode> insns = block.getInstructions();
if (insns.isEmpty())
......@@ -190,6 +178,22 @@ public class BlockUtils {
return traverseSuccessorsUntil(start, end, new HashSet<BlockNode>());
}
public static boolean isOnlyOnePathExists(BlockNode start, BlockNode end) {
if (start == end) {
return true;
}
if (!end.isDominator(start)) {
return false;
}
while (start.getCleanSuccessors().size() == 1) {
start = start.getCleanSuccessors().get(0);
if (start == end) {
return true;
}
}
return false;
}
/**
* Search for first node which not dominated by block, starting from child
*/
......
package jadx.core.utils;
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;
import java.util.Iterator;
import java.util.List;
public final class InsnList implements Iterable<InsnNode> {
public static void remove(List<InsnNode> list, InsnNode insn) {
for (Iterator<InsnNode> iterator = list.iterator(); iterator.hasNext(); ) {
InsnNode next = iterator.next();
if (next == insn) {
iterator.remove();
return;
}
}
}
public static void remove(BlockNode block, InsnNode insn) {
remove(block.getInstructions(), insn);
}
public static int getIndex(List<InsnNode> list, InsnNode insn) {
int i = 0;
for (InsnNode curObj : list) {
if (curObj == insn) {
return i;
}
i++;
}
return -1;
}
private final List<InsnNode> list;
public InsnList(List<InsnNode> list) {
this.list = list;
}
public int getIndex(InsnNode insn) {
return getIndex(list, insn);
}
public boolean contains(InsnNode insn) {
return getIndex(insn) != -1;
}
public void remove(InsnNode insn) {
remove(list, insn);
}
public Iterator<InsnNode> iterator() {
return list.iterator();
}
public InsnNode get(int index) {
return list.get(index);
}
public int size() {
return list.size();
}
}
......@@ -19,6 +19,7 @@ import java.util.jar.JarEntry;
import java.util.jar.JarOutputStream;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertNotNull;
import static junit.framework.Assert.fail;
public abstract class InternalJadxTest {
......@@ -32,14 +33,18 @@ public abstract class InternalJadxTest {
Decompiler d = new Decompiler();
try {
d.loadFile(temp);
assertEquals(d.getClasses().size(), 1);
} catch (Exception e) {
fail(e.getMessage());
} finally {
temp.delete();
}
List<ClassNode> classes = d.getRoot().getClasses(false);
ClassNode cls = classes.get(0);
String clsName = clazz.getName();
ClassNode cls = null;
for (ClassNode aClass : classes) {
if(aClass.getFullName().equals(clsName)) {
cls = aClass;
}
}
assertNotNull("Class not found: " + clsName, cls);
assertEquals(cls.getFullName(), clazz.getName());
......
......@@ -58,7 +58,7 @@ public class TestLoopCondition extends InternalJadxTest {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("i < f.length()"));
assertThat(code, containsString("i < this.f.length()"));
assertThat(code, containsString("while (a && i < 10) {"));
assertThat(code, containsString("list.set(i, \"ABC\")"));
assertThat(code, containsString("list.set(i, \"DEF\")"));
......
......@@ -3,8 +3,6 @@ package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
......@@ -29,7 +27,7 @@ public class TestRedundantThis extends InternalJadxTest {
}
}
@Test
// @Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
......
......@@ -3,8 +3,6 @@ package jadx.tests.internal;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
......@@ -15,7 +13,7 @@ public class TestStringBuilderElimination extends InternalJadxTest {
private static final long serialVersionUID = 4245254480662372757L;
public MyException(String str, Exception e) {
// super("msg:" + str, e);
super("msg:" + str, e);
}
public void method(int k) {
......@@ -23,10 +21,14 @@ public class TestStringBuilderElimination extends InternalJadxTest {
}
}
@Test
// @Test
public void test() {
ClassNode cls = getClassNode(MyException.class);
String code = cls.getCode().toString();
System.out.println(code);
assertThat(code, containsString("MyException(String str, Exception e) {"));
assertThat(code, containsString("super(\"msg:\" + str, e);"));
assertThat(code, not(containsString("new StringBuilder")));
assertThat(code, containsString("System.out.println(\"k=\" + k);"));
......
package jadx.tests.internal.inline;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertThat;
public class TestInline extends InternalJadxTest {
public static class TestCls extends Exception {
public static void main(String[] args) throws Exception {
System.out.println("Test: " + new TestCls().testRun());
}
private boolean testRun() {
return false;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("System.out.println(\"Test: \" + new TestInline$TestCls().testRun());"));
}
}
package jadx.tests.internal.inline;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.junit.Assert.assertThat;
public class TestInline2 extends InternalJadxTest {
public static class TestCls extends Exception {
public int simple_loops() throws InterruptedException {
int[] a = new int[]{1, 2, 4, 6, 8};
int b = 0;
for (int i = 0; i < a.length; i++) {
b += a[i];
}
for (long i = b; i > 0; i--) {
b += i;
}
return b;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("i < a.length"));
}
}
package jadx.tests.internal.inline;
import jadx.api.InternalJadxTest;
import jadx.core.dex.nodes.ClassNode;
import org.junit.Test;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.junit.Assert.assertThat;
public class TestInline3 extends InternalJadxTest {
public static class TestCls {
public TestCls(int b1, int b2) {
this(b1, b2, 0, 0, 0);
}
public TestCls(int a1, int a2, int a3, int a4, int a5) {
}
public class A extends TestCls{
public A(int a) {
super(a, a);
}
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsString("this(b1, b2, 0, 0, 0);"));
assertThat(code, containsString("super(a, a);"));
assertThat(code, not(containsString("super(a, a).this$0")));
assertThat(code, containsString("public class A extends TestInline3$TestCls {"));
}
}
......@@ -164,6 +164,32 @@ public class TestCF3 extends AbstractTest {
return i;
}
private void f() {
try {
Thread.sleep(50);
} catch (InterruptedException e) {
// ignore
}
}
public long testInline() {
long l = System.nanoTime();
f();
return System.nanoTime() - l;
}
private int f2 = 1;
public void func() {
this.f2++;
}
public boolean testInline2() {
int a = this.f2;
func();
return a != this.f2;
}
@Override
public boolean testRun() throws Exception {
setEnabled(false);
......@@ -200,6 +226,9 @@ public class TestCF3 extends AbstractTest {
assertEquals(testComplexIfInLoop3(2), 2);
assertEquals(testComplexIfInLoop3(6), 6);
assertEquals(testComplexIfInLoop3(8), 24);
assertTrue(testInline() > 20);
assertTrue(testInline2());
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册