diff --git a/src/dev/w1zzrd/asm/InPlaceInjection.java b/src/dev/w1zzrd/asm/InPlaceInjection.java new file mode 100644 index 0000000..87d76a8 --- /dev/null +++ b/src/dev/w1zzrd/asm/InPlaceInjection.java @@ -0,0 +1,5 @@ +package dev.w1zzrd.asm; + +public enum InPlaceInjection { + BEFORE, AFTER, REPLACE +} diff --git a/src/dev/w1zzrd/asm/Inject.java b/src/dev/w1zzrd/asm/Inject.java index bf440bc..73ad5c2 100644 --- a/src/dev/w1zzrd/asm/Inject.java +++ b/src/dev/w1zzrd/asm/Inject.java @@ -5,6 +5,10 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import static dev.w1zzrd.asm.InPlaceInjection.REPLACE; + @Retention(RetentionPolicy.RUNTIME) @Target({ElementType.CONSTRUCTOR, ElementType.METHOD, ElementType.FIELD}) -public @interface Inject { } +public @interface Inject { + InPlaceInjection value() default REPLACE; +} diff --git a/src/dev/w1zzrd/asm/Merger.java b/src/dev/w1zzrd/asm/Merger.java index 2c5d4b3..0ad2fe9 100644 --- a/src/dev/w1zzrd/asm/Merger.java +++ b/src/dev/w1zzrd/asm/Merger.java @@ -11,11 +11,16 @@ import java.net.URL; import java.util.*; import java.util.stream.Collectors; +import static dev.w1zzrd.asm.Merger.SpecialCall.FIELD; +import static dev.w1zzrd.asm.Merger.SpecialCall.METHOD; +import static dev.w1zzrd.asm.Merger.SpecialCall.SUPER; +import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS; + public class Merger { protected final ClassNode targetNode; - protected final List injectMethods = new ArrayList<>(); - protected final List injectFields = new ArrayList<>(); + //protected final List injectMethods = new ArrayList<>(); + //protected final List injectFields = new ArrayList<>(); public Merger(String targetClass) throws IOException { @@ -39,13 +44,38 @@ public class Merger { return targetNode.name; } + public String getTargetSuperName() { return targetNode.superName; } + public void inject(MethodNode inject, String injectOwner) { transformInjection(inject, injectOwner); - injectMethods.add(inject); + + targetNode + .methods + .stream() + .filter(it -> methodNodeEquals(it, inject)) + .findFirst() + .ifPresent(targetNode.methods::remove); + + targetNode.methods.add(inject); } public void inject(FieldNode inject) { - injectFields.add(inject); + targetNode + .fields + .stream() + .filter(it -> fieldNodeEquals(it, inject)) + .findFirst() + .ifPresent(targetNode.fields::remove); + + targetNode.fields.add(inject); + } + + public void inject(String className, ClassLoader loader) throws IOException { + inject(getClassNode(loader.getResource(className.replace('.', '/')+".class"))); + } + + public void inject(String className) throws IOException { + inject(className, ClassLoader.getSystemClassLoader()); } public void inject(ClassNode inject) { @@ -71,6 +101,10 @@ public class Merger { } } + public void inject(Class inject) throws IOException { + inject(getClassNode(inject.getResource(inject.getSimpleName()+".class"))); + } + protected String resolveField(String fieldName) { for(FieldNode fNode : targetNode.fields) if (fNode.name.equals(fieldName)) @@ -85,16 +119,36 @@ public class Merger { for (int i = 0; i < inject.instructions.size(); ++i) { AbstractInsnNode node = inject.instructions.get(i); if (!(node instanceof LineNumberNode)) { - if (node instanceof MethodInsnNode && ((MethodInsnNode) node).owner.equals("dev/w1zzrd/asm/Merger") && ((MethodInsnNode) node).name.equals("field")) { - // field access - AbstractInsnNode loadNode = instr.get(instr.size() - 1); - if(loadNode instanceof LdcInsnNode) { - instr.remove(instr.size() - 1); + SpecialCall call = node instanceof MethodInsnNode ? getSpecialCall((MethodInsnNode) node) : null; + if (call != null) { + switch (call) { + case FIELD: { + // field access + AbstractInsnNode loadNode = instr.remove(instr.size() - 1); - String constant = (String) ((LdcInsnNode) loadNode).cst; + String constant = (String) ((LdcInsnNode) loadNode).cst; - instr.add(new VarInsnNode(Opcodes.ALOAD, 0)); - instr.add(new FieldInsnNode(Opcodes.GETFIELD, getTargetName(), constant, resolveField(constant))); + instr.add(new VarInsnNode(Opcodes.ALOAD, 0)); + instr.add(new FieldInsnNode(Opcodes.GETFIELD, getTargetName(), constant, resolveField(constant))); + break; + } + + case SUPER: { + // super call + AbstractInsnNode loadNode = instr.remove(instr.size() - 1); + + do { + node = inject.instructions.get(++i); + if (!(node instanceof MethodInsnNode && ((MethodInsnNode) node).name.equals(((LdcInsnNode)loadNode).cst) && ((MethodInsnNode) node).owner.equals(getTargetName()))) + instr.add(node); + else break; + } while(true); + + ((MethodInsnNode) node).owner = getTargetSuperName(); + instr.add(node); + + break; + } } } else { // Attempt to fix injector ownership @@ -107,16 +161,86 @@ public class Merger { } } + if (node instanceof FrameNode) { + if (((FrameNode) node).local != null) + ((FrameNode) node).local = ((FrameNode) node).local.stream().map(it -> Objects.equals(it, injectOwner) ? getTargetName() : it).collect(Collectors.toList()); + + if (((FrameNode) node).stack != null) + ((FrameNode) node).stack = ((FrameNode) node).stack.stream().map(it -> Objects.equals(it, injectOwner) ? getTargetName() : it).collect(Collectors.toList()); + } + instr.add(node); } } } + + AsmAnnotation annotation = getAnnotation("Ldev/w1zzrd/asm/Inject;", inject); + if (annotation != null && annotation.hasEntry("value")) { + + InPlaceInjection injection = Objects.requireNonNull(annotation.getEntry("value", InPlaceInjection.class)); + Optional adapt = targetNode.methods.stream().filter(it -> methodNodeEquals(it, inject)).findFirst(); + + if (injection != InPlaceInjection.REPLACE && adapt.isPresent()) { + ArrayList toAdapt = new ArrayList<>(); + adapt.get().instructions.iterator().forEachRemaining(toAdapt::add); + + switch (injection) { + case BEFORE: { + LabelNode next; + boolean created = false; + if (toAdapt.size() > 0 && toAdapt.get(0) instanceof LabelNode) + next = (LabelNode)toAdapt.get(0); + else { + next = new LabelNode(); + toAdapt.add(0, next); + created = true; + } + + // If no goto instructions were added, just remove the added label + if (removeReturn(instr, next) && created) + toAdapt.remove(next); + else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame + toAdapt.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null)); + + instr.addAll(toAdapt); + break; + } + + case AFTER: { + LabelNode next; + boolean created = false; + if (toAdapt.size() > 0 && instr.get(0) instanceof LabelNode) + next = (LabelNode)instr.get(0); + else { + next = new LabelNode(); + instr.add(0, next); + created = true; + } + + // If no goto instructions were added, just remove the added label + if (removeReturn(toAdapt, next) && created) + instr.remove(next); + else // A goto call was added. Make sure we inform the JVM of stack and locals with a frame + instr.add(1, new FrameNode(Opcodes.F_SAME, -1, null, -1, null)); + + instr.addAll(0, toAdapt); + break; + } + } + } + } + InsnList collect = new InsnList(); for(AbstractInsnNode node : instr) collect.add(node); inject.instructions = collect; + + inject.localVariables.forEach(var -> { + if (var.desc.equals("L"+injectOwner+";")) + var.desc = "L"+getTargetName()+";"; + }); } public boolean shouldInject(ClassNode inject) { @@ -134,26 +258,30 @@ public class Merger { } public byte[] toByteArray() { - ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS); + return toByteArray(COMPUTE_MAXS); + } + + public byte[] toByteArray(int writerFlags) { + ClassWriter writer = new ClassWriter(writerFlags); // Adapt nodes as necessary - List originalMethods = targetNode.methods; - targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList()); + //List originalMethods = targetNode.methods; + //targetNode.methods = targetNode.methods.stream().filter(this::isNotInjected).collect(Collectors.toList()); - List originalFields = targetNode.fields; - targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList()); + //List originalFields = targetNode.fields; + //targetNode.fields = targetNode.fields.stream().filter(this::isNotInjected).collect(Collectors.toList()); // Accept writer targetNode.accept(writer); // Restore originals - targetNode.methods = originalMethods; - targetNode.fields = originalFields; + //targetNode.methods = originalMethods; + //targetNode.fields = originalFields; // Inject methods and fields - injectMethods.forEach(node -> node.accept(writer)); - injectFields.forEach(node -> node.accept(writer)); + //injectMethods.forEach(node -> node.accept(writer)); + //injectFields.forEach(node -> node.accept(writer)); return writer.toByteArray(); } @@ -184,6 +312,7 @@ public class Merger { } + /* protected boolean isNotInjected(MethodNode node) { for (MethodNode mNode : injectMethods) if (methodNodeEquals(node, mNode)) @@ -199,21 +328,42 @@ public class Merger { return true; } - + */ // To be used instead of referencing object constructs public static Object field(String name) { throw new RuntimeException("Field not injected"); } - // TODO: Implement - public static Object method(String name, Object... args) { - throw new RuntimeException("Method not injected"); + public static void superCall(String superMethodName){ + throw new RuntimeException("Super call not injected"); } + enum SpecialCall { + FIELD, METHOD, SUPER + } + @Nullable + protected static SpecialCall getSpecialCall(MethodInsnNode node) { + if (!node.owner.equals("dev/w1zzrd/asm/Merger")) return null; + + switch (node.name) { + case "field": + return FIELD; + + case "method": + return METHOD; + + case "superCall": + return SUPER; + + default: + return null; + } + } + @Nullable protected static AsmAnnotation getAnnotation(String annotationType, ClassNode cNode) { @@ -232,6 +382,49 @@ public class Merger { return null; } + @Nullable + protected static AsmAnnotation getAnnotation(String annotationType, MethodNode cNode) { + for (AnnotationNode aNode : cNode.visibleAnnotations) + if (aNode.desc.equals(annotationType)) { + HashMap map = new HashMap<>(); + + // Collect annotation values + if (aNode.values != null) + NODE_LOOP: + for (int i = 1; i < aNode.values.size(); i+=2) { + String key = (String) aNode.values.get(i - 1); + Object toPut = aNode.values.get(i); + + if (toPut instanceof String[] && ((String[]) toPut).length == 2) { + String enumType = ((String[])toPut)[0]; + String enumName = ((String[])toPut)[1]; + if (enumType.startsWith("L") && enumType.endsWith(";")) + try{ + Class type = Class.forName(enumType.substring(1, enumType.length()-1).replace('/', '.')); + Method m = Enum.class.getDeclaredMethod("name"); + Object[] values = (Object[]) type.getDeclaredMethod("values").invoke(null); + + for (Object value : values) + if (m.invoke(value).equals(enumName)) { + map.put(key, value); + continue NODE_LOOP; + } + + } catch (Throwable e) { + /* Just ignore */ + } + } + + // Default insertion policy + map.put(key, toPut); + } + + return new AsmAnnotation(annotationType, map); + } + + return null; + } + protected static boolean methodNodeEquals(MethodNode a, MethodNode b) { return a.name.equals(b.name) && Objects.equals(a.desc, b.desc); } @@ -258,6 +451,33 @@ public class Merger { return false; } + protected static boolean removeReturn(List instr, LabelNode jumpReplace) { + ListIterator iter = instr.listIterator(); + JumpInsnNode finalJump = null; + int keepLabel = 0; + while (iter.hasNext()) { + AbstractInsnNode node = iter.next(); + if (node instanceof InsnNode && node.getOpcode() >= Opcodes.IRETURN && node.getOpcode() <= Opcodes.RETURN) { + iter.remove(); + + // Make sure to properly pop values from the stack + // TODO: Optimize LDC's and field load calls here + if (node.getOpcode() == Opcodes.LRETURN || node.getOpcode() == Opcodes.DRETURN) + iter.add(new InsnNode(Opcodes.POP2)); + else if (node.getOpcode() != Opcodes.RETURN) + iter.add(new InsnNode(Opcodes.POP)); + + iter.add(finalJump = new JumpInsnNode(Opcodes.GOTO, jumpReplace)); + ++keepLabel; + } + } + + if (finalJump != null) // This *should* always be true + instr.remove(finalJump); + + return keepLabel <= 1; + } + public static ClassNode getClassNode(URL url) throws IOException { return readClass(getClassBytes(url)); }