From c2aa3202bafc21ae8623df820cd8f723671fc0f7 Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Sun, 19 Apr 2020 16:13:04 +0200 Subject: [PATCH] Implement rudimentary method weaving --- src/dev/w1zzrd/asm/AsmAnnotation.java | 18 +- src/dev/w1zzrd/asm/Inject.java | 2 + src/dev/w1zzrd/asm/Merger.java | 334 +++++++++++++++++++------- 3 files changed, 263 insertions(+), 91 deletions(-) diff --git a/src/dev/w1zzrd/asm/AsmAnnotation.java b/src/dev/w1zzrd/asm/AsmAnnotation.java index 990ad74..83f422b 100644 --- a/src/dev/w1zzrd/asm/AsmAnnotation.java +++ b/src/dev/w1zzrd/asm/AsmAnnotation.java @@ -1,25 +1,25 @@ package dev.w1zzrd.asm; -import com.sun.istack.internal.Nullable; - +import java.lang.annotation.Annotation; import java.util.Map; -public final class AsmAnnotation { - private final String annotationType; +public final class AsmAnnotation { + private final Class annotationType; private final Map entries; - public AsmAnnotation(String annotationType, Map entries) { + public AsmAnnotation(Class annotationType, Map entries) { this.annotationType = annotationType; this.entries = entries; } - public String getAnnotationType() { + public Class getAnnotationType() { return annotationType; } - @Nullable - public T getEntry(String name, Class type) { - return hasEntry(name) ? (T)entries.get(name) : null; + public T getEntry(String name) { + if (!hasEntry(name)) + throw new IllegalArgumentException(String.format("No entry \"%s\" in asm annotation!", name)); + return (T)entries.get(name); } public boolean hasEntry(String name) { diff --git a/src/dev/w1zzrd/asm/Inject.java b/src/dev/w1zzrd/asm/Inject.java index 73ad5c2..4311d33 100644 --- a/src/dev/w1zzrd/asm/Inject.java +++ b/src/dev/w1zzrd/asm/Inject.java @@ -11,4 +11,6 @@ import static dev.w1zzrd.asm.InPlaceInjection.REPLACE; @Target({ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.FIELD}) public @interface Inject { InPlaceInjection value() default REPLACE; + String target() default ""; + boolean acceptOriginalReturn() default false; } diff --git a/src/dev/w1zzrd/asm/Merger.java b/src/dev/w1zzrd/asm/Merger.java index 0ad2fe9..d2a5aff 100644 --- a/src/dev/w1zzrd/asm/Merger.java +++ b/src/dev/w1zzrd/asm/Merger.java @@ -1,26 +1,31 @@ package dev.w1zzrd.asm; import com.sun.istack.internal.Nullable; +import com.sun.org.apache.bcel.internal.generic.GETSTATIC; import jdk.internal.org.objectweb.asm.*; import jdk.internal.org.objectweb.asm.tree.*; import java.io.*; +import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URL; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; - import static dev.w1zzrd.asm.Merger.SpecialCall.FIELD; -import static dev.w1zzrd.asm.Merger.SpecialCall.METHOD; import static dev.w1zzrd.asm.Merger.SpecialCall.SUPER; import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS; public class Merger { + private static final Pattern re_methodSignature = Pattern.compile("((?:[a-zA-Z_$][a-zA-Z\\d_$]+)|(?:))\\(((?:(?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D)*)\\)((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D|V)"); + private static final Pattern re_types = Pattern.compile("((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D)"); + private static final Pattern re_retTypes = Pattern.compile("((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D|V)"); + protected final ClassNode targetNode; - //protected final List injectMethods = new ArrayList<>(); - //protected final List injectFields = new ArrayList<>(); + protected final ArrayList extended = new ArrayList<>(); public Merger(String targetClass) throws IOException { @@ -84,20 +89,21 @@ public class Merger { if (inject.visibleAnnotations != null && inject.interfaces != null) { - AsmAnnotation annot = getAnnotation("Ldev/w1zzrd/asm/InjectClass;", inject); + AsmAnnotation injectAnnotation = getAnnotation(InjectClass.class, inject); // If there is not injectMethods annotation or there is an // explicit request to not injectMethods interfaces, just return - if (annot == null || (annot.hasEntry("injectInterfaces") && !annot.getEntry("injectInterfaces", Boolean.class))) + if (injectAnnotation == null || + (injectAnnotation.hasEntry("injectInterfaces") && + !(Boolean)injectAnnotation.getEntry("injectInterfaces"))) return; if (targetNode.interfaces == null) targetNode.interfaces = new ArrayList<>(); - for (String iface : inject.interfaces) - if (!targetNode.interfaces.contains(iface)) - targetNode.interfaces.add(iface); + + inject.interfaces.stream().filter(it -> !targetNode.interfaces.contains(it)).forEach(targetNode.interfaces::add); } } @@ -175,10 +181,11 @@ public class Merger { } - AsmAnnotation annotation = getAnnotation("Ldev/w1zzrd/asm/Inject;", inject); + MethodSig signature = getSignature(inject); + AsmAnnotation annotation = getAnnotation(Inject.class, inject); if (annotation != null && annotation.hasEntry("value")) { - InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value", InPlaceInjection.class)); + InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value")); Optional adapt = targetNode.methods.stream().filter(it -> methodNodeEquals(it, inject)).findFirst(); if (injection != InPlaceInjection.REPLACE && adapt.isPresent()) { @@ -198,7 +205,7 @@ public class Merger { } // If no goto instructions were added, just remove the added label - if (removeReturn(instr, next) && created) + if (removeReturn(instr, next, true, null) && created) toAdapt.remove(next); else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame toAdapt.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null)); @@ -218,13 +225,59 @@ public class Merger { created = true; } + boolean isStaticMethod = (inject.access & Opcodes.ACC_STATIC) != 0; + boolean keepReturn = annotation.hasEntry("acceptOriginalReturn") && (Boolean)annotation.getEntry("acceptOriginalReturn"); + boolean hasReturn = !signature.ret.equals("V"); + + LocalVariableNode retNode = keepReturn && hasReturn ? inject.localVariables.get((isStaticMethod ? 0 : 1) + signature.args.length) : null; + + // If no goto instructions were added, just remove the added label - if (removeReturn(toAdapt, next) && created) - instr.remove(next); + boolean noGoto = removeReturn(toAdapt, next, !keepReturn && hasReturn, retNode); + if(noGoto) { + if(created) + instr.remove(next); + } else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null)); + /* + if(keepReturn && hasReturn) { + if (!extended.contains(signature)) { + Object[] locals = new Object[inject.localVariables.size()]; + + locals[0] = getTargetName();//resolveFrameType(inject.localVariables.get(0).desc); + locals[1] = Opcodes.TOP; + locals[2] = resolveFrameType(inject.localVariables.get(2).desc); + + //for (int i = 0; i < inject.localVariables.size(); ++i) + // locals[i] = resolveFrameType(inject.localVariables.get(i).desc); + + //instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null)); + instr.add(1, new FrameNode(Opcodes.F_FULL, 3, locals, 1, new Object[]{ resolveFrameType(inject.localVariables.get(inject.localVariables.size() - 1).desc) })); + instr.add(2, new VarInsnNode(resolveStoreInstr(signature.ret), locals.length - 1)); + + extended.add(signature); + } + } + */ + instr.addAll(0, toAdapt); + + if (keepReturn && hasReturn) { + // A little bit overkill, but I'm lazy + LabelNode first; + if (instr.get(0) instanceof LabelNode) + first = (LabelNode) instr.get(0); + else { + first = new LabelNode(); + instr.add(0, first); + } + + // Make the scope of received retVal span the entire method + retNode.start = first; + } + break; } } @@ -241,20 +294,14 @@ public class Merger { if (var.desc.equals("L"+injectOwner+";")) var.desc = "L"+getTargetName()+";"; }); + + inject.desc = '(' + signature.args_literal + ')' + signature.ret; } public boolean shouldInject(ClassNode inject) { - if (inject.visibleAnnotations != null) { - for (AnnotationNode aNode : inject.visibleAnnotations) - if ( - aNode.desc.equals("Ldev/w1zzrd/asm/InjectClass;") && - aNode.values.indexOf("value") != -1 && - ((Type) aNode.values.get(aNode.values.indexOf("value") + 1)).getClassName().equals(getTargetName()) - ) - return true; - } - - return false; + AsmAnnotation injectAnnotation = getAnnotation(InjectClass.class, inject); + return injectAnnotation != null && + ((Type)injectAnnotation.getEntry("value")).getClassName().equals(getTargetName()); } public byte[] toByteArray() { @@ -263,26 +310,8 @@ public class Merger { public byte[] toByteArray(int writerFlags) { ClassWriter writer = new ClassWriter(writerFlags); - - // Adapt nodes as necessary - //List originalMethods = targetNode.methods; - //targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList()); - - //List originalFields = targetNode.fields; - //targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList()); - - - // Accept writer targetNode.accept(writer); - // Restore originals - //targetNode.methods = originalMethods; - //targetNode.fields = originalFields; - - // Inject methods and fields - //injectMethods.forEach(node -> node.accept(writer)); - //injectFields.forEach(node -> node.accept(writer)); - return writer.toByteArray(); } @@ -311,25 +340,6 @@ public class Merger { return null; } - - /* - protected boolean isNotInjected(MethodNode node) { - for (MethodNode mNode : injectMethods) - if (methodNodeEquals(node, mNode)) - return false; - - return true; - } - - protected boolean isNotInjected(FieldNode node) { - for (FieldNode mNode : injectFields) - if (fieldNodeEquals(node, mNode)) - return false; - - return true; - } - */ - // To be used instead of referencing object constructs public static Object field(String name) { throw new RuntimeException("Field not injected"); @@ -341,7 +351,7 @@ public class Merger { enum SpecialCall { - FIELD, METHOD, SUPER + FIELD, SUPER } @@ -353,9 +363,6 @@ public class Merger { case "field": return FIELD; - case "method": - return METHOD; - case "superCall": return SUPER; @@ -366,9 +373,14 @@ public class Merger { @Nullable - protected static AsmAnnotation getAnnotation(String annotationType, ClassNode cNode) { + protected static AsmAnnotation getAnnotation(Class annotationType, ClassNode cNode) { + if(cNode.visibleAnnotations == null) + return null; + + String targetAnnot = 'L' + annotationType.getTypeName().replace('.', '/') + ';'; + for (AnnotationNode aNode : cNode.visibleAnnotations) - if (aNode.desc.equals(annotationType)) { + if (aNode.desc.equals(targetAnnot)) { HashMap map = new HashMap<>(); // Collect annotation values @@ -376,16 +388,21 @@ public class Merger { for (int i = 1; i < aNode.values.size(); i+=2) map.put((String)aNode.values.get(i - 1), aNode.values.get(i)); - return new AsmAnnotation(annotationType, map); + return new AsmAnnotation<>(annotationType, map); } return null; } @Nullable - protected static AsmAnnotation getAnnotation(String annotationType, MethodNode cNode) { + protected static AsmAnnotation getAnnotation(Class annotationType, MethodNode cNode) { + if(cNode.visibleAnnotations == null) + return null; + + String targetAnnot = 'L' + annotationType.getTypeName().replace('.', '/') + ';'; + for (AnnotationNode aNode : cNode.visibleAnnotations) - if (aNode.desc.equals(annotationType)) { + if (aNode.desc.equals(targetAnnot)) { HashMap map = new HashMap<>(); // Collect annotation values @@ -419,24 +436,142 @@ public class Merger { map.put(key, toPut); } - return new AsmAnnotation(annotationType, map); + return new AsmAnnotation<>(annotationType, map); } return null; } protected static boolean methodNodeEquals(MethodNode a, MethodNode b) { - return a.name.equals(b.name) && Objects.equals(a.desc, b.desc); + return getSignature(a).equals(getSignature(b)); } + protected static MethodSig getSignature(MethodNode node) { + AsmAnnotation annotation = getAnnotation(Inject.class, node); + + // Attempt to parse a declared signature + if (annotation != null && annotation.hasEntry("target")) { + MethodSig sig = parseMethodSignature(annotation.getEntry("target")); + + if (sig != null) + return sig; + } + + // Parse implicit signature + return Objects.requireNonNull(parseMethodSignature(node.name+node.desc)); + } + + @Nullable + protected static MethodSig parseMethodSignature(String sig) { + Matcher signatureMatcher = re_methodSignature.matcher(sig); + + if (sig.length() > 0 && signatureMatcher.matches()) { + String name = signatureMatcher.group(1); + String ret = signatureMatcher.group(3); + + Matcher argMatcher = re_types.matcher(signatureMatcher.group(2)); + ArrayList args = new ArrayList<>(); + while (argMatcher.find()) + args.add(argMatcher.group(1)); + + return new MethodSig(name, ret, args.toArray(new String[args.size()])); + } + + return null; + } + + /** + * Data class for storing method signature and name + */ + protected static final class MethodSig { + public final String name; + public final String ret; + public final String[] args; + public final String args_literal; + + public MethodSig(String name, String ret, String[] args) { + this.name = name; + this.ret = ret; + this.args = args; + + StringBuilder builder = new StringBuilder(); + for (String s : args) + builder.append(s); + + args_literal = builder.toString(); + } + + @Override + public String toString() { + return name+'('+args_literal+')'+ret; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof MethodSig && toString().equals(obj.toString()); + } + } + + + + protected static Object resolveFrameType(String typeString) { + Type sigType = Type.getType(typeString); + switch (sigType.getSort()) { + case 1: + case 2: + case 3: + case 4: + case 5: + return Opcodes.INTEGER; + case 6: + return Opcodes.FLOAT; + case 7: + return Opcodes.LONG; + case 8: + return Opcodes.DOUBLE; + case 9: + return sigType.getDescriptor(); + default: + return sigType.getInternalName(); + } + } + + protected static int resolveStoreInstr(String typeString) { + switch (typeString) { + case "Z": + case "I": + case "B": + case "C": + case "S": + return Opcodes.ISTORE; + case "J": + return Opcodes.LSTORE; + case "F": + return Opcodes.FSTORE; + case "D": + return Opcodes.DSTORE; + + // Void has no store type + case "V": + return -1; + + default: + return Opcodes.ASTORE; + } + } + + protected static boolean fieldNodeEquals(FieldNode a, FieldNode b) { return a.name.equals(b.name) && Objects.equals(a.signature, b.signature); } protected static boolean shouldInject(MethodNode node) { if (node.visibleAnnotations == null) return false; + + String targetDesc = 'L' + Inject.class.getTypeName().replace('.', '/') + ';'; + for (AnnotationNode aNode : node.visibleAnnotations) - if (aNode.desc.equals("Ldev/w1zzrd/asm/Inject;")) + if (aNode.desc.equals(targetDesc)) return true; return false; @@ -444,14 +579,17 @@ public class Merger { protected static boolean shouldInject(FieldNode node) { if (node.visibleAnnotations == null) return false; + + String targetDesc = 'L' + Inject.class.getTypeName().replace('.', '/') + ';'; + for (AnnotationNode aNode : node.visibleAnnotations) - if (aNode.desc.equals("Ldev/w1zzrd/asm/Inject;")) + if (aNode.desc.equals(targetDesc)) return true; return false; } - protected static boolean removeReturn(List instr, LabelNode jumpReplace) { + protected static boolean removeReturn(List instr, LabelNode jumpReplace, boolean popReturn, LocalVariableNode storeNode) { ListIterator iter = instr.listIterator(); JumpInsnNode finalJump = null; int keepLabel = 0; @@ -460,12 +598,16 @@ public class Merger { if (node instanceof InsnNode && node.getOpcode() >= Opcodes.IRETURN && node.getOpcode() <= Opcodes.RETURN) { iter.remove(); - // Make sure to properly pop values from the stack - // TODO: Optimize LDC's and field load calls here - if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN) - iter.add(new InsnNode(Opcodes.POP2)); - else if (node.getOpcode() != Opcodes.RETURN) - iter.add(new InsnNode(Opcodes.POP)); + // If we're not keeping the return value and the return + // value is gotten from a method call, just pop the result of the call + if(popReturn && !removeRedundantLoad(iter)) { + if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN) + iter.add(new InsnNode(Opcodes.POP2)); + else if (node.getOpcode() != Opcodes.RETURN) + iter.add(new InsnNode(Opcodes.POP)); + } else { + iter.add(new VarInsnNode(resolveStoreInstr(storeNode.desc), storeNode.index)); + } iter.add(finalJump = new JumpInsnNode(Opcodes.GOTO, jumpReplace)); ++keepLabel; @@ -478,6 +620,34 @@ public class Merger { return keepLabel <= 1; } + protected static boolean removeRedundantLoad(ListIterator iter) { + boolean hasEffects = false; + int iterCount = 0; + while (iter.hasPrevious()) { + AbstractInsnNode node = iter.previous(); + ++iterCount; + + if (node instanceof MethodInsnNode) { + hasEffects = true; + break; + } + + if ((node instanceof FieldInsnNode && node.getOpcode() == Opcodes.GETSTATIC) || + (node instanceof InsnNode && (node.getOpcode() == Opcodes.LDC || + (node.getOpcode() >= Opcodes.ILOAD && node.getOpcode() <= Opcodes.ALOAD) || + (node.getOpcode() >= Opcodes.IALOAD && node.getOpcode() <= Opcodes.SALOAD)))) + break; + } + + for(int i = 0; i < iterCount; ++i) { + iter.next(); + if (!hasEffects) + iter.remove(); + } + + return hasEffects; + } + public static ClassNode getClassNode(URL url) throws IOException { return readClass(getClassBytes(url)); }