Implement basic method instruction merging

This commit is contained in:
Gabriel Tofvesson 2020-04-18 21:24:24 +02:00
parent 2745f391e9
commit a53d94e407
3 changed files with 255 additions and 26 deletions

View File

@ -0,0 +1,5 @@
package dev.w1zzrd.asm;
public enum InPlaceInjection {
BEFORE, AFTER, REPLACE
}

View File

@ -5,6 +5,10 @@ import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import static dev.w1zzrd.asm.InPlaceInjection.REPLACE;
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.FIELD})
public @interface Inject { }
public @interface Inject {
InPlaceInjection value() default REPLACE;
}

View File

@ -11,11 +11,16 @@ import java.net.URL;
import java.util.*;
import java.util.stream.Collectors;
import static dev.w1zzrd.asm.Merger.SpecialCall.FIELD;
import static dev.w1zzrd.asm.Merger.SpecialCall.METHOD;
import static dev.w1zzrd.asm.Merger.SpecialCall.SUPER;
import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS;
public class Merger {
protected final ClassNode targetNode;
protected final List<MethodNode> injectMethods = new ArrayList<>();
protected final List<FieldNode> injectFields = new ArrayList<>();
//protected final List<MethodNode> injectMethods = new ArrayList<>();
//protected final List<FieldNode> injectFields = new ArrayList<>();
public Merger(String targetClass) throws IOException {
@ -39,13 +44,38 @@ public class Merger {
return targetNode.name;
}
public String getTargetSuperName() { return targetNode.superName; }
public void inject(MethodNode inject, String injectOwner) {
transformInjection(inject, injectOwner);
injectMethods.add(inject);
targetNode
.methods
.stream()
.filter(it -> methodNodeEquals(it, inject))
.findFirst()
.ifPresent(targetNode.methods::remove);
targetNode.methods.add(inject);
}
public void inject(FieldNode inject) {
injectFields.add(inject);
targetNode
.fields
.stream()
.filter(it -> fieldNodeEquals(it, inject))
.findFirst()
.ifPresent(targetNode.fields::remove);
targetNode.fields.add(inject);
}
public void inject(String className, ClassLoader loader) throws IOException {
inject(getClassNode(loader.getResource(className.replace('.', '/')+".class")));
}
public void inject(String className) throws IOException {
inject(className, ClassLoader.getSystemClassLoader());
}
public void inject(ClassNode inject) {
@ -71,6 +101,10 @@ public class Merger {
}
}
public void inject(Class<?> inject) throws IOException {
inject(getClassNode(inject.getResource(inject.getSimpleName()+".class")));
}
protected String resolveField(String fieldName) {
for(FieldNode fNode : targetNode.fields)
if (fNode.name.equals(fieldName))
@ -85,16 +119,36 @@ public class Merger {
for (int i = 0; i < inject.instructions.size(); ++i) {
AbstractInsnNode node = inject.instructions.get(i);
if (!(node instanceof LineNumberNode)) {
if (node instanceof MethodInsnNode && ((MethodInsnNode) node).owner.equals("dev/w1zzrd/asm/Merger") && ((MethodInsnNode) node).name.equals("field")) {
// field access
AbstractInsnNode loadNode = instr.get(instr.size() - 1);
if(loadNode instanceof LdcInsnNode) {
instr.remove(instr.size() - 1);
SpecialCall call = node instanceof MethodInsnNode ? getSpecialCall((MethodInsnNode) node) : null;
if (call != null) {
switch (call) {
case FIELD: {
// field access
AbstractInsnNode loadNode = instr.remove(instr.size() - 1);
String constant = (String) ((LdcInsnNode) loadNode).cst;
String constant = (String) ((LdcInsnNode) loadNode).cst;
instr.add(new VarInsnNode(Opcodes.ALOAD, 0));
instr.add(new FieldInsnNode(Opcodes.GETFIELD, getTargetName(), constant, resolveField(constant)));
instr.add(new VarInsnNode(Opcodes.ALOAD, 0));
instr.add(new FieldInsnNode(Opcodes.GETFIELD, getTargetName(), constant, resolveField(constant)));
break;
}
case SUPER: {
// super call
AbstractInsnNode loadNode = instr.remove(instr.size() - 1);
do {
node = inject.instructions.get(++i);
if (!(node instanceof MethodInsnNode && ((MethodInsnNode) node).name.equals(((LdcInsnNode)loadNode).cst) && ((MethodInsnNode) node).owner.equals(getTargetName())))
instr.add(node);
else break;
} while(true);
((MethodInsnNode) node).owner = getTargetSuperName();
instr.add(node);
break;
}
}
} else {
// Attempt to fix injector ownership
@ -107,16 +161,86 @@ public class Merger {
}
}
if (node instanceof FrameNode) {
if (((FrameNode) node).local != null)
((FrameNode) node).local = ((FrameNode) node).local.stream().map(it -> Objects.equals(it, injectOwner) ? getTargetName() : it).collect(Collectors.toList());
if (((FrameNode) node).stack != null)
((FrameNode) node).stack = ((FrameNode) node).stack.stream().map(it -> Objects.equals(it, injectOwner) ? getTargetName() : it).collect(Collectors.toList());
}
instr.add(node);
}
}
}
AsmAnnotation annotation = getAnnotation("Ldev/w1zzrd/asm/Inject;", inject);
if (annotation != null && annotation.hasEntry("value")) {
InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value", InPlaceInjection.class));
Optional<MethodNode> adapt = targetNode.methods.stream().filter(it -> methodNodeEquals(it, inject)).findFirst();
if (injection != InPlaceInjection.REPLACE && adapt.isPresent()) {
ArrayList<AbstractInsnNode> toAdapt = new ArrayList<>();
adapt.get().instructions.iterator().forEachRemaining(toAdapt::add);
switch (injection) {
case BEFORE: {
LabelNode next;
boolean created = false;
if (toAdapt.size() > 0 && toAdapt.get(0) instanceof LabelNode)
next = (LabelNode)toAdapt.get(0);
else {
next = new LabelNode();
toAdapt.add(0, next);
created = true;
}
// If no goto instructions were added, just remove the added label
if (removeReturn(instr, next) && created)
toAdapt.remove(next);
else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame
toAdapt.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null));
instr.addAll(toAdapt);
break;
}
case AFTER: {
LabelNode next;
boolean created = false;
if (toAdapt.size() > 0 && instr.get(0) instanceof LabelNode)
next = (LabelNode)instr.get(0);
else {
next = new LabelNode();
instr.add(0, next);
created = true;
}
// If no goto instructions were added, just remove the added label
if (removeReturn(toAdapt, next) && created)
instr.remove(next);
else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame
instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null));
instr.addAll(0, toAdapt);
break;
}
}
}
}
InsnList collect = new InsnList();
for(AbstractInsnNode node : instr)
collect.add(node);
inject.instructions = collect;
inject.localVariables.forEach(var -> {
if (var.desc.equals("L"+injectOwner+";"))
var.desc = "L"+getTargetName()+";";
});
}
public boolean shouldInject(ClassNode inject) {
@ -134,26 +258,30 @@ public class Merger {
}
public byte[] toByteArray() {
ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
return toByteArray(COMPUTE_MAXS);
}
public byte[] toByteArray(int writerFlags) {
ClassWriter writer = new ClassWriter(writerFlags);
// Adapt nodes as necessary
List<MethodNode> originalMethods = targetNode.methods;
targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList());
//List<MethodNode> originalMethods = targetNode.methods;
//targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList());
List<FieldNode> originalFields = targetNode.fields;
targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList());
//List<FieldNode> originalFields = targetNode.fields;
//targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList());
// Accept writer
targetNode.accept(writer);
// Restore originals
targetNode.methods = originalMethods;
targetNode.fields = originalFields;
//targetNode.methods = originalMethods;
//targetNode.fields = originalFields;
// Inject methods and fields
injectMethods.forEach(node -> node.accept(writer));
injectFields.forEach(node -> node.accept(writer));
//injectMethods.forEach(node -> node.accept(writer));
//injectFields.forEach(node -> node.accept(writer));
return writer.toByteArray();
}
@ -184,6 +312,7 @@ public class Merger {
}
/*
protected boolean isNotInjected(MethodNode node) {
for (MethodNode mNode : injectMethods)
if (methodNodeEquals(node, mNode))
@ -199,21 +328,42 @@ public class Merger {
return true;
}
*/
// To be used instead of referencing object constructs
public static Object field(String name) {
throw new RuntimeException("Field not injected");
}
// TODO: Implement
public static Object method(String name, Object... args) {
throw new RuntimeException("Method not injected");
public static void superCall(String superMethodName){
throw new RuntimeException("Super call not injected");
}
enum SpecialCall {
FIELD, METHOD, SUPER
}
@Nullable
protected static SpecialCall getSpecialCall(MethodInsnNode node) {
if (!node.owner.equals("dev/w1zzrd/asm/Merger")) return null;
switch (node.name) {
case "field":
return FIELD;
case "method":
return METHOD;
case "superCall":
return SUPER;
default:
return null;
}
}
@Nullable
protected static AsmAnnotation getAnnotation(String annotationType, ClassNode cNode) {
@ -232,6 +382,49 @@ public class Merger {
return null;
}
@Nullable
protected static AsmAnnotation getAnnotation(String annotationType, MethodNode cNode) {
for (AnnotationNode aNode : cNode.visibleAnnotations)
if (aNode.desc.equals(annotationType)) {
HashMap<String, Object> map = new HashMap<>();
// Collect annotation values
if (aNode.values != null)
NODE_LOOP:
for (int i = 1; i < aNode.values.size(); i+=2) {
String key = (String) aNode.values.get(i - 1);
Object toPut = aNode.values.get(i);
if (toPut instanceof String[] && ((String[]) toPut).length == 2) {
String enumType = ((String[])toPut)[0];
String enumName = ((String[])toPut)[1];
if (enumType.startsWith("L") && enumType.endsWith(";"))
try{
Class<?> type = Class.forName(enumType.substring(1, enumType.length()-1).replace('/', '.'));
Method m = Enum.class.getDeclaredMethod("name");
Object[] values = (Object[]) type.getDeclaredMethod("values").invoke(null);
for (Object value : values)
if (m.invoke(value).equals(enumName)) {
map.put(key, value);
continue NODE_LOOP;
}
} catch (Throwable e) {
/* Just ignore */
}
}
// Default insertion policy
map.put(key, toPut);
}
return new AsmAnnotation(annotationType, map);
}
return null;
}
protected static boolean methodNodeEquals(MethodNode a, MethodNode b) {
return a.name.equals(b.name) && Objects.equals(a.desc, b.desc);
}
@ -258,6 +451,33 @@ public class Merger {
return false;
}
protected static boolean removeReturn(List<AbstractInsnNode> instr, LabelNode jumpReplace) {
ListIterator<AbstractInsnNode> iter = instr.listIterator();
JumpInsnNode finalJump = null;
int keepLabel = 0;
while (iter.hasNext()) {
AbstractInsnNode node = iter.next();
if (node instanceof InsnNode && node.getOpcode() >= Opcodes.IRETURN && node.getOpcode() <= Opcodes.RETURN) {
iter.remove();
// Make sure to properly pop values from the stack
// TODO: Optimize LDC's and field load calls here
if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN)
iter.add(new InsnNode(Opcodes.POP2));
else if (node.getOpcode() != Opcodes.RETURN)
iter.add(new InsnNode(Opcodes.POP));
iter.add(finalJump = new JumpInsnNode(Opcodes.GOTO, jumpReplace));
++keepLabel;
}
}
if (finalJump != null) // This *should* always be true
instr.remove(finalJump);
return keepLabel <= 1;
}
public static ClassNode getClassNode(URL url) throws IOException {
return readClass(getClassBytes(url));
}