From 60540d21c066d4c6dc01aab386477dc74a20231f Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Thu, 28 Jan 2021 05:32:25 +0100 Subject: [PATCH] Implement method prepending --- src/dev/w1zzrd/asm/Combine.java | 194 +++++++++++++----- src/dev/w1zzrd/asm/DirectiveTarget.java | 17 ++ src/dev/w1zzrd/asm/Directives.java | 29 +++ .../DirectiveNotImplementedException.java | 7 + 4 files changed, 197 insertions(+), 50 deletions(-) create mode 100644 src/dev/w1zzrd/asm/DirectiveTarget.java create mode 100644 src/dev/w1zzrd/asm/Directives.java create mode 100644 src/dev/w1zzrd/asm/exception/DirectiveNotImplementedException.java diff --git a/src/dev/w1zzrd/asm/Combine.java b/src/dev/w1zzrd/asm/Combine.java index a64bcee..c5b51c9 100644 --- a/src/dev/w1zzrd/asm/Combine.java +++ b/src/dev/w1zzrd/asm/Combine.java @@ -85,8 +85,6 @@ public class Combine { final int graftArgCount = xsig.getArgCount() + (isStatic(extension) ? 0 : 1); final int targetArgCount = msig.getArgCount() + (isStatic(target) ? 0 : 1); - List targetInsns; - // If graft method cares about the return value of the original method, i.e. accepts it as an extra "argument" if (acceptReturn && !msig.getRet().isVoidType()) { //noinspection OptionalGetWithoutIsPresent @@ -102,21 +100,16 @@ public class Combine { // Handle retvar specially extension.localVariables.remove(retVar); - // Convert instructions into a more modifiable format - targetInsns = decomposeToList(target.instructions); - // Make space in the original frames for the return var // This isn't an optimal solution, but it works for now - adjustFramesForRetVar(targetInsns, targetArgCount); + adjustFramesForRetVar(target.instructions, targetArgCount); // Replace return instructions with GOTOs to the last instruction in the list // Return values are stored in retVar - storeAndGotoFromReturn(target, targetInsns, retVar.index, xsig); + storeAndGotoFromReturn(target, target.instructions, retVar.index, xsig); } else { - targetInsns = decomposeToList(target.instructions); - // If we don't care about the return value from the original, we can replace returns with pops - popAndGotoFromReturn(targetInsns, xsig); + popAndGotoFromReturn(target, target.instructions, xsig); } List extVars = getVarsOver(extension.localVariables, xsig.getArgCount()); @@ -125,13 +118,7 @@ public class Combine { target.localVariables.addAll(extVars); // Add extension instructions to instruction list - targetInsns.addAll(decomposeToList(extension.instructions)); - - // Some weird TOP delimiter between "true" locals and args - //insertTopDelimiter(targetArgCount, targetInsns, target.localVariables, true); - - // Convert instructions back to a InsnList - target.instructions = coalesceInstructions(targetInsns); + target.instructions.add(extension.instructions); // Make sure we extend the scope of the original method arguments for (int i = 0; i < targetArgCount; ++i) @@ -165,6 +152,16 @@ public class Combine { ); adaptMethod(extension, source); + MethodSignature sig = new MethodSignature(extension.desc); + + target.localVariables.addAll(getVarsOver(extension.localVariables, sig.getArgCount())); + extension.instructions.add(target.instructions); + + target.instructions = extension.instructions; + + // Extend argument scope to cover prepended code + for (int i = 0; i < sig.getArgCount(); ++i) + adjustArgument(target, getVarAt(target.localVariables, i), true, false); finishGrafting(extension, source); } @@ -285,7 +282,7 @@ public class Combine { protected void adaptMethod(MethodNode node, GraftSource source) { // Adapt instructions for (AbstractInsnNode insn = node.instructions.getFirst(); insn != null; insn = insn.getNext()) { - if (insn instanceof MethodInsnNode) adaptMethodInsn((MethodInsnNode) insn, source); + if (insn instanceof MethodInsnNode) adaptMethodInsn((MethodInsnNode) insn, source, node); else if (insn instanceof LdcInsnNode) adaptLdcInsn((LdcInsnNode) insn, source.getTypeName()); else if (insn instanceof FrameNode) adaptFrameNode((FrameNode) insn, source); else if (insn instanceof FieldInsnNode) adaptFieldInsn((FieldInsnNode) insn, source); @@ -313,83 +310,102 @@ public class Combine { return label; } - private void storeAndGotoFromReturn(MethodNode source, List nodes, int storeIndex, MethodSignature sig) { + protected static LabelNode findOrMakeEndLabel(InsnList nodes) { + AbstractInsnNode last = nodes.getLast(); + + while (last instanceof FrameNode) last = last.getPrevious(); + + if (last instanceof LabelNode) + return (LabelNode) last; + + LabelNode label = new LabelNode(); + + nodes.add(label); + return label; + } + + protected static boolean hasEndJumpFrame(InsnList nodes) { + return nodes.getLast() instanceof FrameNode && nodes.getLast().getPrevious() instanceof LabelNode; + } + + protected LabelNode makeEndJumpFrame(InsnList nodes, MethodSignature sig, MethodNode source) { LabelNode endLabel = findOrMakeEndLabel(nodes); - - int frameInsert = nodes.indexOf(endLabel); List local = makeFrameLocals(sig.getArgs()); if (!isStatic(source)) local.add(0, target.name); + //nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, 0, null, 0, null)); - nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_FULL, local.size(), local.toArray(), 0, new Object[0])); + nodes.insert(endLabel, new FrameNode(Opcodes.F_FULL, local.size(), local.toArray(), 0, new Object[0])); + + return endLabel; + } + + private void storeAndGotoFromReturn(MethodNode source, InsnList nodes, int storeIndex, MethodSignature sig) { + // If we already have a final frame, there's no need to add one + LabelNode endLabel = hasEndJumpFrame(nodes) ? findOrMakeEndLabel(nodes) : makeEndJumpFrame(nodes, sig, source); INSTRUCTION_LOOP: - for (int i = 0; i < nodes.size(); ++i) { - switch (nodes.get(i).getOpcode()) { + for (AbstractInsnNode current = nodes.getFirst(); current != null; current = current.getNext()) { + switch (current.getOpcode()) { case Opcodes.IRETURN: - nodes.add(i, new IntInsnNode(Opcodes.ISTORE, storeIndex)); + nodes.set(current, current = new IntInsnNode(Opcodes.ISTORE, storeIndex)); break; case Opcodes.FRETURN: - nodes.add(i, new IntInsnNode(Opcodes.FSTORE, storeIndex)); + nodes.set(current, current = new IntInsnNode(Opcodes.FSTORE, storeIndex)); break; case Opcodes.ARETURN: - nodes.add(i, new IntInsnNode(Opcodes.ASTORE, storeIndex)); + nodes.set(current, current = new IntInsnNode(Opcodes.ASTORE, storeIndex)); break; case Opcodes.LRETURN: - nodes.add(i, new IntInsnNode(Opcodes.LSTORE, storeIndex)); + nodes.set(current, current = new IntInsnNode(Opcodes.LSTORE, storeIndex)); break; case Opcodes.DRETURN: - nodes.add(i, new IntInsnNode(Opcodes.DSTORE, storeIndex)); + nodes.set(current, current = new IntInsnNode(Opcodes.DSTORE, storeIndex)); break; case Opcodes.RETURN: - --i; - break; + nodes.set(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel)); + // Fallthrough default: continue INSTRUCTION_LOOP; } - - nodes.set(i + 1, new JumpInsnNode(Opcodes.GOTO, endLabel)); + nodes.insert(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel)); } } - private static void popAndGotoFromReturn(List nodes, MethodSignature sig) { - LabelNode endLabel = findOrMakeEndLabel(nodes); - - int frameInsert = nodes.indexOf(endLabel); - List local = makeFrameLocals(sig.getArgs()); - //nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, 0, null, 0, null)); - nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, local.size(), local.toArray(), 0, new Object[0])); + private void popAndGotoFromReturn(MethodNode source, InsnList nodes, MethodSignature sig) { + // If we already have a final frame, there's no need to add one + LabelNode endLabel = hasEndJumpFrame(nodes) ? findOrMakeEndLabel(nodes) : makeEndJumpFrame(nodes, sig, source); INSTRUCTION_LOOP: - for (int i = 0; i < nodes.size(); ++i) { - switch (nodes.get(i).getOpcode()) { + for (AbstractInsnNode current = nodes.getFirst(); current != null; current = current.getNext()) { + switch (current.getOpcode()) { case Opcodes.IRETURN: case Opcodes.FRETURN: case Opcodes.ARETURN: - nodes.add(i, new InsnNode(Opcodes.POP)); + nodes.set(current, current = new InsnNode(Opcodes.POP)); break; case Opcodes.LRETURN: case Opcodes.DRETURN: - nodes.add(i, new InsnNode(Opcodes.POP2)); + nodes.set(current, current = new InsnNode(Opcodes.POP2)); break; case Opcodes.RETURN: - --i; - break; + nodes.set(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel)); + // Fallthrough default: continue INSTRUCTION_LOOP; } - nodes.set(i + 1, new JumpInsnNode(Opcodes.GOTO, endLabel)); + nodes.insert(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel)); } } @@ -454,9 +470,9 @@ public class Combine { } } - protected static void adjustFramesForRetVar(List nodes, int argc) { + protected static void adjustFramesForRetVar(InsnList nodes, int argc) { boolean isFirst = true; - for (AbstractInsnNode node : nodes) + for (AbstractInsnNode node = nodes.getFirst(); node != null; node = node.getNext()) if (node instanceof FrameNode) { if (isFirst) { isFirst = false; @@ -482,7 +498,7 @@ public class Combine { * @param node Grafted method instruction node * @param source The {@link GraftSource} from which the instruction node will be adapted */ - protected void adaptMethodInsn(MethodInsnNode node, GraftSource source) { + protected void adaptMethodInsn(MethodInsnNode node, GraftSource source, MethodNode sourceMethod) { if (node.owner.equals(source.getTypeName())) { final MethodNode injected = source.getInjectedMethod(node.name, node.desc); if (injected != null) { @@ -490,6 +506,65 @@ public class Combine { node.name = source.getMethodTargetName(injected); node.desc = adaptMethodSignature(node.desc, source); } + } else if (node.owner.equals("dev/w1zzrd/asm/Directives")) { // ASM target directives + if (node.name.equals(Directives.directiveNameByTarget(DirectiveTarget.TargetType.CALL_SUPER))) { + // We're attempting to redirect a call to a superclass + for (AbstractInsnNode prev = node.getPrevious(); prev != null; prev = prev.getPrevious()) { + + if (prev instanceof MethodInsnNode && + (((MethodInsnNode) prev).owner.equals(target.name) || + ((MethodInsnNode) prev).owner.equals(source.getTypeName()))) { + // Point method owner to superclass + ((MethodInsnNode) prev).owner = target.superName; + + // Since we're calling super, we want to make it a special call + if (prev.getOpcode() == Opcodes.INVOKEVIRTUAL) + ((MethodInsnNode) prev).setOpcode(Opcodes.INVOKESPECIAL); + } else if (prev instanceof FieldInsnNode && + (((FieldInsnNode) prev).owner.equals(target.name) || + ((FieldInsnNode) prev).owner.equals(source.getTypeName()))) { + // Just change the field we're accessing to the targets superclass' field + ((FieldInsnNode) prev).owner = target.superName; + } else { + continue; + } + + return; + } + + throw new RuntimeException(String.format("Could not locate a target for directive %s", node.name)); + } else if (node.name.equals(Directives.directiveNameByTarget(DirectiveTarget.TargetType.CALL_ORIGINAL))) { + // We want to redirect execution to the original method code + // The callOriginal method returns void, so the stack should be empty at this point + + InsnList insnList = sourceMethod.instructions; + + // If we already have a final frame, there's no need to add one + LabelNode endLabel = hasEndJumpFrame(insnList) ? + findOrMakeEndLabel(insnList) : + makeEndJumpFrame(insnList, new MethodSignature(sourceMethod.desc), sourceMethod); + + AbstractInsnNode jumpInsn = new JumpInsnNode(Opcodes.GOTO, endLabel); + insnList.set(node, jumpInsn); + + MethodSignature sig = new MethodSignature(sourceMethod.desc); + final Class[] ignoredNodes = {LineNumberNode.class, LabelNode.class, FrameNode.class}; + AbstractInsnNode afterJump = getNextNode(jumpInsn, ignoredNodes); + + if (!sig.getRet().isVoidType()) { + // Now we want to remove extraneous (unreachable) return instructions + afterJump = getNextNode(afterJump, ignoredNodes); + + // This should remove extraneous return instructions, along with any constants pushed to the stack + if (afterJump.getOpcode() >= Opcodes.IRETURN && afterJump.getOpcode() < Opcodes.RETURN) { + insnList.remove(afterJump); + insnList.remove(getNextNode(jumpInsn, ignoredNodes)); + } + } else if (afterJump.getOpcode() == Opcodes.RETURN) { + // This should just remove the extraneous RETURN instruction + insnList.remove(afterJump); + } + } } } @@ -712,6 +787,25 @@ public class Combine { inject.access ^= (inject.access & flags) ^ (target.access & flags); } + protected static AbstractInsnNode getNextNode(AbstractInsnNode node, Class... skipTypes) { + return traverseNode(node, AbstractInsnNode::getNext, skipTypes); + } + + protected static AbstractInsnNode getPreviousNode(AbstractInsnNode node, Class... skipTypes) { + return traverseNode(node, AbstractInsnNode::getPrevious, skipTypes); + } + + private static AbstractInsnNode traverseNode(AbstractInsnNode node, INodeTraversal traversal, Class[] skipTypes) { + TRAVERSAL: + for (AbstractInsnNode trav = traversal.traverse(node); trav != null; trav = traversal.traverse(trav)) { + for (Class cls : skipTypes) + if (trav.getClass().equals(cls)) + continue TRAVERSAL; + return trav; + } + return null; + } + @SuppressWarnings("unused") // Used for debugging diff --git a/src/dev/w1zzrd/asm/DirectiveTarget.java b/src/dev/w1zzrd/asm/DirectiveTarget.java new file mode 100644 index 0000000..8d88626 --- /dev/null +++ b/src/dev/w1zzrd/asm/DirectiveTarget.java @@ -0,0 +1,17 @@ +package dev.w1zzrd.asm; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +@interface DirectiveTarget { + + TargetType value(); + + enum TargetType { + CALL_ORIGINAL, CALL_SUPER; + } +} diff --git a/src/dev/w1zzrd/asm/Directives.java b/src/dev/w1zzrd/asm/Directives.java new file mode 100644 index 0000000..bfd9754 --- /dev/null +++ b/src/dev/w1zzrd/asm/Directives.java @@ -0,0 +1,29 @@ +package dev.w1zzrd.asm; + +import dev.w1zzrd.asm.exception.DirectiveNotImplementedException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; + +public class Directives { + + @DirectiveTarget(DirectiveTarget.TargetType.CALL_ORIGINAL) + public static void callOriginal() { + throw new DirectiveNotImplementedException("callOriginal"); + } + + @DirectiveTarget(DirectiveTarget.TargetType.CALL_SUPER) + public static void callSuper() { + throw new DirectiveNotImplementedException("callSuper"); + } + + + static String directiveNameByTarget(DirectiveTarget.TargetType type) { + DirectiveTarget target; + for (Method f : Directives.class.getMethods()) + if ((target = f.getDeclaredAnnotation(DirectiveTarget.class)) != null && target.value().equals(type)) + return f.getName(); + + // This won't happen unless I'm dumb, so call it 50/50 odds + throw new RuntimeException("Could not find implementation of directive target: "+type.name()); + } +} diff --git a/src/dev/w1zzrd/asm/exception/DirectiveNotImplementedException.java b/src/dev/w1zzrd/asm/exception/DirectiveNotImplementedException.java new file mode 100644 index 0000000..58a39ff --- /dev/null +++ b/src/dev/w1zzrd/asm/exception/DirectiveNotImplementedException.java @@ -0,0 +1,7 @@ +package dev.w1zzrd.asm.exception; + +public class DirectiveNotImplementedException extends RuntimeException { + public DirectiveNotImplementedException(String directiveName) { + super("Operation not implemented: "+directiveName); + } +}