Implement rudimentary method weaving

This commit is contained in:
Gabriel Tofvesson 2020-04-19 16:13:04 +02:00
parent 4f694565c3
commit c2aa3202ba
3 changed files with 263 additions and 91 deletions

View File

@ -1,25 +1,25 @@
package dev.w1zzrd.asm;
import com.sun.istack.internal.Nullable;
import java.lang.annotation.Annotation;
import java.util.Map;
public final class AsmAnnotation {
private final String annotationType;
public final class AsmAnnotation<A extends Annotation> {
private final Class<A> annotationType;
private final Map<String, Object> entries;
public AsmAnnotation(String annotationType, Map<String, Object> entries) {
public AsmAnnotation(Class<A> annotationType, Map<String, Object> entries) {
this.annotationType = annotationType;
this.entries = entries;
}
public String getAnnotationType() {
public Class<A> getAnnotationType() {
return annotationType;
}
@Nullable
public <T> T getEntry(String name, Class<T> type) {
return hasEntry(name) ? (T)entries.get(name) : null;
public <T> T getEntry(String name) {
if (!hasEntry(name))
throw new IllegalArgumentException(String.format("No entry \"%s\" in asm annotation!", name));
return (T)entries.get(name);
}
public boolean hasEntry(String name) {

View File

@ -11,4 +11,6 @@ import static dev.w1zzrd.asm.InPlaceInjection.REPLACE;
@Target({ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.FIELD})
public @interface Inject {
InPlaceInjection value() default REPLACE;
String target() default "";
boolean acceptOriginalReturn() default false;
}

View File

@ -1,26 +1,31 @@
package dev.w1zzrd.asm;
import com.sun.istack.internal.Nullable;
import com.sun.org.apache.bcel.internal.generic.GETSTATIC;
import jdk.internal.org.objectweb.asm.*;
import jdk.internal.org.objectweb.asm.tree.*;
import java.io.*;
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static dev.w1zzrd.asm.Merger.SpecialCall.FIELD;
import static dev.w1zzrd.asm.Merger.SpecialCall.METHOD;
import static dev.w1zzrd.asm.Merger.SpecialCall.SUPER;
import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS;
public class Merger {
private static final Pattern re_methodSignature = Pattern.compile("((?:[a-zA-Z_$][a-zA-Z\\d_$]+)|(?:<init>))\\(((?:(?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D)*)\\)((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D|V)");
private static final Pattern re_types = Pattern.compile("((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D)");
private static final Pattern re_retTypes = Pattern.compile("((?:\\[*L(?:[a-zA-Z_$][a-zA-Z\\d_$]*/)*[a-zA-Z_$][a-zA-Z\\d_$]*;)|Z|B|C|S|I|J|F|D|V)");
protected final ClassNode targetNode;
//protected final List<MethodNode> injectMethods = new ArrayList<>();
//protected final List<FieldNode> injectFields = new ArrayList<>();
protected final ArrayList<MethodSig> extended = new ArrayList<>();
public Merger(String targetClass) throws IOException {
@ -84,20 +89,21 @@ public class Merger {
if (inject.visibleAnnotations != null && inject.interfaces != null) {
AsmAnnotation annot = getAnnotation("Ldev/w1zzrd/asm/InjectClass;", inject);
AsmAnnotation<InjectClass> injectAnnotation = getAnnotation(InjectClass.class, inject);
// If there is not injectMethods annotation or there is an
// explicit request to not injectMethods interfaces, just return
if (annot == null || (annot.hasEntry("injectInterfaces") && !annot.getEntry("injectInterfaces", Boolean.class)))
if (injectAnnotation == null ||
(injectAnnotation.hasEntry("injectInterfaces") &&
!(Boolean)injectAnnotation.getEntry("injectInterfaces")))
return;
if (targetNode.interfaces == null)
targetNode.interfaces = new ArrayList<>();
for (String iface : inject.interfaces)
if (!targetNode.interfaces.contains(iface))
targetNode.interfaces.add(iface);
inject.interfaces.stream().filter(it -> !targetNode.interfaces.contains(it)).forEach(targetNode.interfaces::add);
}
}
@ -175,10 +181,11 @@ public class Merger {
}
AsmAnnotation annotation = getAnnotation("Ldev/w1zzrd/asm/Inject;", inject);
MethodSig signature = getSignature(inject);
AsmAnnotation<Inject> annotation = getAnnotation(Inject.class, inject);
if (annotation != null && annotation.hasEntry("value")) {
InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value", InPlaceInjection.class));
InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value"));
Optional<MethodNode> adapt = targetNode.methods.stream().filter(it -> methodNodeEquals(it, inject)).findFirst();
if (injection != InPlaceInjection.REPLACE && adapt.isPresent()) {
@ -198,7 +205,7 @@ public class Merger {
}
// If no goto instructions were added, just remove the added label
if (removeReturn(instr, next) && created)
if (removeReturn(instr, next, true, null) && created)
toAdapt.remove(next);
else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame
toAdapt.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null));
@ -218,13 +225,59 @@ public class Merger {
created = true;
}
boolean isStaticMethod = (inject.access & Opcodes.ACC_STATIC) != 0;
boolean keepReturn = annotation.hasEntry("acceptOriginalReturn") && (Boolean)annotation.getEntry("acceptOriginalReturn");
boolean hasReturn = !signature.ret.equals("V");
LocalVariableNode retNode = keepReturn && hasReturn ? inject.localVariables.get((isStaticMethod ? 0 : 1) + signature.args.length) : null;
// If no goto instructions were added, just remove the added label
if (removeReturn(toAdapt, next) && created)
instr.remove(next);
boolean noGoto = removeReturn(toAdapt, next, !keepReturn && hasReturn, retNode);
if(noGoto) {
if(created)
instr.remove(next);
}
else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame
instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null));
/*
if(keepReturn && hasReturn) {
if (!extended.contains(signature)) {
Object[] locals = new Object[inject.localVariables.size()];
locals[0] = getTargetName();//resolveFrameType(inject.localVariables.get(0).desc);
locals[1] = Opcodes.TOP;
locals[2] = resolveFrameType(inject.localVariables.get(2).desc);
//for (int i = 0; i < inject.localVariables.size(); ++i)
// locals[i] = resolveFrameType(inject.localVariables.get(i).desc);
//instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null));
instr.add(1, new FrameNode(Opcodes.F_FULL, 3, locals, 1, new Object[]{ resolveFrameType(inject.localVariables.get(inject.localVariables.size() - 1).desc) }));
instr.add(2, new VarInsnNode(resolveStoreInstr(signature.ret), locals.length - 1));
extended.add(signature);
}
}
*/
instr.addAll(0, toAdapt);
if (keepReturn && hasReturn) {
// A little bit overkill, but I'm lazy
LabelNode first;
if (instr.get(0) instanceof LabelNode)
first = (LabelNode) instr.get(0);
else {
first = new LabelNode();
instr.add(0, first);
}
// Make the scope of received retVal span the entire method
retNode.start = first;
}
break;
}
}
@ -241,20 +294,14 @@ public class Merger {
if (var.desc.equals("L"+injectOwner+";"))
var.desc = "L"+getTargetName()+";";
});
inject.desc = '(' + signature.args_literal + ')' + signature.ret;
}
public boolean shouldInject(ClassNode inject) {
if (inject.visibleAnnotations != null) {
for (AnnotationNode aNode : inject.visibleAnnotations)
if (
aNode.desc.equals("Ldev/w1zzrd/asm/InjectClass;") &&
aNode.values.indexOf("value") != -1 &&
((Type) aNode.values.get(aNode.values.indexOf("value") + 1)).getClassName().equals(getTargetName())
)
return true;
}
return false;
AsmAnnotation<InjectClass> injectAnnotation = getAnnotation(InjectClass.class, inject);
return injectAnnotation != null &&
((Type)injectAnnotation.getEntry("value")).getClassName().equals(getTargetName());
}
public byte[] toByteArray() {
@ -263,26 +310,8 @@ public class Merger {
public byte[] toByteArray(int writerFlags) {
ClassWriter writer = new ClassWriter(writerFlags);
// Adapt nodes as necessary
//List<MethodNode> originalMethods = targetNode.methods;
//targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList());
//List<FieldNode> originalFields = targetNode.fields;
//targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList());
// Accept writer
targetNode.accept(writer);
// Restore originals
//targetNode.methods = originalMethods;
//targetNode.fields = originalFields;
// Inject methods and fields
//injectMethods.forEach(node -> node.accept(writer));
//injectFields.forEach(node -> node.accept(writer));
return writer.toByteArray();
}
@ -311,25 +340,6 @@ public class Merger {
return null;
}
/*
protected boolean isNotInjected(MethodNode node) {
for (MethodNode mNode : injectMethods)
if (methodNodeEquals(node, mNode))
return false;
return true;
}
protected boolean isNotInjected(FieldNode node) {
for (FieldNode mNode : injectFields)
if (fieldNodeEquals(node, mNode))
return false;
return true;
}
*/
// To be used instead of referencing object constructs
public static Object field(String name) {
throw new RuntimeException("Field not injected");
@ -341,7 +351,7 @@ public class Merger {
enum SpecialCall {
FIELD, METHOD, SUPER
FIELD, SUPER
}
@ -353,9 +363,6 @@ public class Merger {
case "field":
return FIELD;
case "method":
return METHOD;
case "superCall":
return SUPER;
@ -366,9 +373,14 @@ public class Merger {
@Nullable
protected static AsmAnnotation getAnnotation(String annotationType, ClassNode cNode) {
protected static <T extends Annotation> AsmAnnotation<T> getAnnotation(Class<T> annotationType, ClassNode cNode) {
if(cNode.visibleAnnotations == null)
return null;
String targetAnnot = 'L' + annotationType.getTypeName().replace('.', '/') + ';';
for (AnnotationNode aNode : cNode.visibleAnnotations)
if (aNode.desc.equals(annotationType)) {
if (aNode.desc.equals(targetAnnot)) {
HashMap<String, Object> map = new HashMap<>();
// Collect annotation values
@ -376,16 +388,21 @@ public class Merger {
for (int i = 1; i < aNode.values.size(); i+=2)
map.put((String)aNode.values.get(i - 1), aNode.values.get(i));
return new AsmAnnotation(annotationType, map);
return new AsmAnnotation<>(annotationType, map);
}
return null;
}
@Nullable
protected static AsmAnnotation getAnnotation(String annotationType, MethodNode cNode) {
protected static <T extends Annotation> AsmAnnotation<T> getAnnotation(Class<T> annotationType, MethodNode cNode) {
if(cNode.visibleAnnotations == null)
return null;
String targetAnnot = 'L' + annotationType.getTypeName().replace('.', '/') + ';';
for (AnnotationNode aNode : cNode.visibleAnnotations)
if (aNode.desc.equals(annotationType)) {
if (aNode.desc.equals(targetAnnot)) {
HashMap<String, Object> map = new HashMap<>();
// Collect annotation values
@ -419,24 +436,142 @@ public class Merger {
map.put(key, toPut);
}
return new AsmAnnotation(annotationType, map);
return new AsmAnnotation<>(annotationType, map);
}
return null;
}
protected static boolean methodNodeEquals(MethodNode a, MethodNode b) {
return a.name.equals(b.name) && Objects.equals(a.desc, b.desc);
return getSignature(a).equals(getSignature(b));
}
protected static MethodSig getSignature(MethodNode node) {
AsmAnnotation<Inject> annotation = getAnnotation(Inject.class, node);
// Attempt to parse a declared signature
if (annotation != null && annotation.hasEntry("target")) {
MethodSig sig = parseMethodSignature(annotation.getEntry("target"));
if (sig != null)
return sig;
}
// Parse implicit signature
return Objects.requireNonNull(parseMethodSignature(node.name+node.desc));
}
@Nullable
protected static MethodSig parseMethodSignature(String sig) {
Matcher signatureMatcher = re_methodSignature.matcher(sig);
if (sig.length() > 0 && signatureMatcher.matches()) {
String name = signatureMatcher.group(1);
String ret = signatureMatcher.group(3);
Matcher argMatcher = re_types.matcher(signatureMatcher.group(2));
ArrayList<String> args = new ArrayList<>();
while (argMatcher.find())
args.add(argMatcher.group(1));
return new MethodSig(name, ret, args.toArray(new String[args.size()]));
}
return null;
}
/**
* Data class for storing method signature and name
*/
protected static final class MethodSig {
public final String name;
public final String ret;
public final String[] args;
public final String args_literal;
public MethodSig(String name, String ret, String[] args) {
this.name = name;
this.ret = ret;
this.args = args;
StringBuilder builder = new StringBuilder();
for (String s : args)
builder.append(s);
args_literal = builder.toString();
}
@Override
public String toString() {
return name+'('+args_literal+')'+ret;
}
@Override
public boolean equals(Object obj) {
return obj instanceof MethodSig && toString().equals(obj.toString());
}
}
protected static Object resolveFrameType(String typeString) {
Type sigType = Type.getType(typeString);
switch (sigType.getSort()) {
case 1:
case 2:
case 3:
case 4:
case 5:
return Opcodes.INTEGER;
case 6:
return Opcodes.FLOAT;
case 7:
return Opcodes.LONG;
case 8:
return Opcodes.DOUBLE;
case 9:
return sigType.getDescriptor();
default:
return sigType.getInternalName();
}
}
protected static int resolveStoreInstr(String typeString) {
switch (typeString) {
case "Z":
case "I":
case "B":
case "C":
case "S":
return Opcodes.ISTORE;
case "J":
return Opcodes.LSTORE;
case "F":
return Opcodes.FSTORE;
case "D":
return Opcodes.DSTORE;
// Void has no store type
case "V":
return -1;
default:
return Opcodes.ASTORE;
}
}
protected static boolean fieldNodeEquals(FieldNode a, FieldNode b) {
return a.name.equals(b.name) && Objects.equals(a.signature, b.signature);
}
protected static boolean shouldInject(MethodNode node) {
if (node.visibleAnnotations == null) return false;
String targetDesc = 'L' + Inject.class.getTypeName().replace('.', '/') + ';';
for (AnnotationNode aNode : node.visibleAnnotations)
if (aNode.desc.equals("Ldev/w1zzrd/asm/Inject;"))
if (aNode.desc.equals(targetDesc))
return true;
return false;
@ -444,14 +579,17 @@ public class Merger {
protected static boolean shouldInject(FieldNode node) {
if (node.visibleAnnotations == null) return false;
String targetDesc = 'L' + Inject.class.getTypeName().replace('.', '/') + ';';
for (AnnotationNode aNode : node.visibleAnnotations)
if (aNode.desc.equals("Ldev/w1zzrd/asm/Inject;"))
if (aNode.desc.equals(targetDesc))
return true;
return false;
}
protected static boolean removeReturn(List<AbstractInsnNode> instr, LabelNode jumpReplace) {
protected static boolean removeReturn(List<AbstractInsnNode> instr, LabelNode jumpReplace, boolean popReturn, LocalVariableNode storeNode) {
ListIterator<AbstractInsnNode> iter = instr.listIterator();
JumpInsnNode finalJump = null;
int keepLabel = 0;
@ -460,12 +598,16 @@ public class Merger {
if (node instanceof InsnNode && node.getOpcode() >= Opcodes.IRETURN && node.getOpcode() <= Opcodes.RETURN) {
iter.remove();
// Make sure to properly pop values from the stack
// TODO: Optimize LDC's and field load calls here
if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN)
iter.add(new InsnNode(Opcodes.POP2));
else if (node.getOpcode() != Opcodes.RETURN)
iter.add(new InsnNode(Opcodes.POP));
// If we're not keeping the return value and the return
// value is gotten from a method call, just pop the result of the call
if(popReturn && !removeRedundantLoad(iter)) {
if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN)
iter.add(new InsnNode(Opcodes.POP2));
else if (node.getOpcode() != Opcodes.RETURN)
iter.add(new InsnNode(Opcodes.POP));
} else {
iter.add(new VarInsnNode(resolveStoreInstr(storeNode.desc), storeNode.index));
}
iter.add(finalJump = new JumpInsnNode(Opcodes.GOTO, jumpReplace));
++keepLabel;
@ -478,6 +620,34 @@ public class Merger {
return keepLabel <= 1;
}
protected static boolean removeRedundantLoad(ListIterator<AbstractInsnNode> iter) {
boolean hasEffects = false;
int iterCount = 0;
while (iter.hasPrevious()) {
AbstractInsnNode node = iter.previous();
++iterCount;
if (node instanceof MethodInsnNode) {
hasEffects = true;
break;
}
if ((node instanceof FieldInsnNode && node.getOpcode() == Opcodes.GETSTATIC) ||
(node instanceof InsnNode && (node.getOpcode() == Opcodes.LDC ||
(node.getOpcode() >= Opcodes.ILOAD && node.getOpcode() <= Opcodes.ALOAD) ||
(node.getOpcode() >= Opcodes.IALOAD && node.getOpcode() <= Opcodes.SALOAD))))
break;
}
for(int i = 0; i < iterCount; ++i) {
iter.next();
if (!hasEffects)
iter.remove();
}
return hasEffects;
}
public static ClassNode getClassNode(URL url) throws IOException {
return readClass(getClassBytes(url));
}