From cea7fa06bddd7a92eebc423dc591639c8cccf82e Mon Sep 17 00:00:00 2001 From: Gabriel Tofvesson Date: Thu, 21 Jan 2021 06:40:26 +0100 Subject: [PATCH] Implement method appending --- src/dev/w1zzrd/asm/Combine.java | 380 +++++++++++++++++- src/dev/w1zzrd/asm/GraftSource.java | 10 + .../w1zzrd/asm/signature/MethodSignature.java | 10 +- .../w1zzrd/asm/signature/TypeSignature.java | 11 + 4 files changed, 405 insertions(+), 6 deletions(-) diff --git a/src/dev/w1zzrd/asm/Combine.java b/src/dev/w1zzrd/asm/Combine.java index eebe0a3..3d5cb2d 100644 --- a/src/dev/w1zzrd/asm/Combine.java +++ b/src/dev/w1zzrd/asm/Combine.java @@ -4,20 +4,30 @@ import dev.w1zzrd.asm.analysis.AsmAnnotation; import dev.w1zzrd.asm.exception.MethodNodeResolutionException; import dev.w1zzrd.asm.exception.SignatureInstanceMismatchException; import dev.w1zzrd.asm.signature.MethodSignature; +import dev.w1zzrd.asm.signature.TypeSignature; import jdk.internal.org.objectweb.asm.ClassWriter; +import jdk.internal.org.objectweb.asm.Handle; import jdk.internal.org.objectweb.asm.Opcodes; import jdk.internal.org.objectweb.asm.Type; import jdk.internal.org.objectweb.asm.tree.*; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Arrays; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS; public class Combine { + private static final String VAR_NAME_CHARS = "$_qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM"; + private static final String VAR_NAME_CHARS1 = "$_qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM1234567890"; + + private final ArrayList graftSources = new ArrayList<>(); + private final ClassNode target; @@ -58,18 +68,88 @@ public class Combine { * @param acceptReturn Whether or not the grafted method should "receive" the original method's return value as an "argument" */ public void append(MethodNode extension, GraftSource source, boolean acceptReturn) { + if (initiateGrafting(extension, source)) + return; + final MethodNode target = checkMethodExists(source.getMethodTargetName(extension), source.getMethodTargetSignature(extension)); adaptMethod(extension, source); + MethodSignature msig = new MethodSignature(target.desc); + MethodSignature xsig = new MethodSignature(extension.desc); + + + List targetInsns; + + // If graft method cares about the return value of the original method + if (acceptReturn) { + LocalVariableNode retVar = null; + + // If return of original is not void, we need to capture and store it to pass to the extension code + if (!msig.getRet().isVoidType()) { + // Generate a random return var name + String name; + GEN_NAME: + do { + name = getRandomString(1, 16); + for (LocalVariableNode vNode : target.localVariables) + if (name.equals(vNode.name)) + continue GEN_NAME; + + break; + } while (true); + + // Create return variable + retVar = insertRetvarNode(target, name, msig.getRet()); + } + + // Convert instructions into a more modifiable format + targetInsns = decomposeToList(target.instructions); + + // Replace return instructions with GOTOs to the last instruction in the list + // Return values are stored in retVar + storeAndGotoFromReturn(targetInsns, retVar == null ? -1 : retVar.index); + + // We need to extend the scope of the retVar into the grafted code + if (retVar != null) + //noinspection OptionalGetWithoutIsPresent + retVar.end = extension.localVariables + .stream() + .filter(it -> it.index == xsig.getArgCount() - 1) + .findFirst() + .get() // This should never fail + .end; + } else { + targetInsns = decomposeToList(target.instructions); + + // If we don't care about the return value from the original, we can replace returns with pops + popAndGotoFromReturn(targetInsns); + } + + List extVars = getVarsOver(extension.localVariables, xsig.getArgCount()); + + // Add extension vars to target + target.localVariables.addAll(extVars); + + // Add extension instructions to instruction list + targetInsns.addAll(decomposeToList(extension.instructions)); + + // Convert instructions back to a InsnList + target.instructions = coalesceInstructions(targetInsns); } public void prepend(MethodNode extension, GraftSource source) { + if (initiateGrafting(extension, source)) + return; + final MethodNode target = checkMethodExists(source.getMethodTargetName(extension), source.getMethodTargetSignature(extension)); adaptMethod(extension, source); } public void replace(MethodNode inject, GraftSource source, boolean preserveOriginalAccess) { + if (initiateGrafting(inject, source)) + return; + final MethodNode remove = checkMethodExists(source.getMethodTargetName(inject), source.getMethodTargetSignature(inject)); ensureMatchingSignatures(remove, inject, Opcodes.ACC_STATIC); if (preserveOriginalAccess) @@ -79,11 +159,17 @@ public class Combine { } public void insert(MethodNode inject, GraftSource source) { + if (initiateGrafting(inject, source)) + return; + checkMethodNotExists(source.getMethodTargetName(inject), source.getMethodTargetSignature(inject)); insertOrReplace(inject, source); } protected void insertOrReplace(MethodNode inject, GraftSource source) { + if (initiateGrafting(inject, source)) + return; + MethodNode replace = findMethodNode(source.getMethodTargetName(inject), source.getMethodTargetSignature(inject)); if (replace != null) @@ -95,6 +181,16 @@ public class Combine { } + private boolean initiateGrafting(MethodNode node, GraftSource source) { + DynamicSourceUnit unit = new DynamicSourceUnit(source, node); + boolean alreadyGrafting = graftSources.contains(unit); + + if (!alreadyGrafting) + graftSources.add(unit); + + return alreadyGrafting; + } + /** * Compile target class data to a byte array @@ -158,16 +254,191 @@ public class Combine { * @param source The {@link GraftSource} from which the node will be adapted */ protected void adaptMethod(MethodNode node, GraftSource source) { - final AbstractInsnNode last = node.instructions.getLast(); - for (AbstractInsnNode insn = node.instructions.getFirst(); insn != last; insn = insn.getNext()) { + // Adapt instructions + ADAPT: + for (AbstractInsnNode insn = node.instructions.getFirst(); insn != null; insn = insn.getNext()) { if (insn instanceof MethodInsnNode) adaptMethodInsn((MethodInsnNode) insn, source); else if (insn instanceof LdcInsnNode) adaptLdcInsn((LdcInsnNode) insn, source.getTypeName()); else if (insn instanceof FrameNode) adaptFrameNode((FrameNode) insn, node, source); + else if (insn instanceof InvokeDynamicInsnNode && ((Handle)((InvokeDynamicInsnNode) insn).bsmArgs[1]).getOwner().equals(source.getTypeName())) { + // We have an INVOKEDYNAMIC to a method in the graft source class. The target has to be injected into the target + Handle handle = (Handle)((InvokeDynamicInsnNode) insn).bsmArgs[1]; + + for (MethodNode mNode : target.methods) + if (mNode.name.equals(handle.getName()) && mNode.desc.equals(handle.getDesc())) + continue ADAPT; // The target has already been injected + + MethodNode inject = source.getMethodNode(handle.getName(), handle.getDesc()); + if (inject == null) + throw new MethodNodeResolutionException(String.format( + "Could not locate lambda target %s%s in graft source %s", + handle.getName(), + handle.getDesc(), + source.getTypeName() + )); + + // Attempt to inject lambda target site into target class + insert(inject, source); + + // The INVOKEDYNAMIC now points to a call site in the target class + ((InvokeDynamicInsnNode) insn).bsmArgs[1] = new Handle( + handle.getTag(), + target.name, + handle.getName(), + handle.getDesc() + ); + } } node.name = source.getMethodTargetName(node); } + private static LabelNode findOrMakeEndLabel(List nodes) { + AbstractInsnNode last = nodes.get(nodes.size() - 1); + + if (last instanceof LabelNode) + return (LabelNode) last; + + LabelNode label = new LabelNode(); + + nodes.add(label); + return label; + } + + private static void storeAndGotoFromReturn(List nodes, int storeIndex) { + LabelNode endLabel = findOrMakeEndLabel(nodes); + + INSTRUCTION_LOOP: + for (int i = 0; i < nodes.size(); ++i) { + switch (nodes.get(i).getOpcode()) { + case Opcodes.IRETURN: + nodes.add(i, new IntInsnNode(Opcodes.ISTORE, storeIndex)); + break; + + case Opcodes.FRETURN: + nodes.add(i, new IntInsnNode(Opcodes.FSTORE, storeIndex)); + break; + + case Opcodes.ARETURN: + nodes.add(i, new IntInsnNode(Opcodes.ASTORE, storeIndex)); + break; + + case Opcodes.LRETURN: + nodes.add(i, new IntInsnNode(Opcodes.LSTORE, storeIndex)); + break; + + case Opcodes.DRETURN: + nodes.add(i, new IntInsnNode(Opcodes.DSTORE, storeIndex)); + break; + + case Opcodes.RETURN: + --i; + break; + + default: + continue INSTRUCTION_LOOP; + } + + nodes.set(i, new JumpInsnNode(Opcodes.GOTO, endLabel)); + } + } + + private static void popAndGotoFromReturn(List nodes) { + LabelNode endLabel = findOrMakeEndLabel(nodes); + + INSTRUCTION_LOOP: + for (int i = 0; i < nodes.size(); ++i) { + switch (nodes.get(i).getOpcode()) { + case Opcodes.IRETURN: + case Opcodes.FRETURN: + case Opcodes.ARETURN: + nodes.add(i, new InsnNode(Opcodes.POP)); + break; + + case Opcodes.LRETURN: + case Opcodes.DRETURN: + nodes.add(i, new InsnNode(Opcodes.POP2)); + break; + + case Opcodes.RETURN: + --i; + break; + + default: + continue INSTRUCTION_LOOP; + } + + nodes.set(i, new JumpInsnNode(Opcodes.GOTO, endLabel)); + } + } + + private static List decomposeToList(InsnList insns) { + try { + // This should save us some overhead + Field elementData = ArrayList.class.getDeclaredField("elementData"); + elementData.setAccessible(true); + Field size = ArrayList.class.getDeclaredField("size"); + size.setAccessible(true); + + // Make arraylist and get array of instructions + ArrayList decomposed = new ArrayList<>(); + AbstractInsnNode[] nodes = insns.toArray(); + + // Copy instructions to arraylist + elementData.set(decomposed, nodes); + size.set(decomposed, nodes.length); + + return decomposed; + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); // Probably Java 9+ + } + } + + + private static LocalVariableNode insertRetvarNode(MethodNode node, String name, TypeSignature type) { + // Finds first label or creates it + LabelNode firstLabel = findLabelBeforeReturn(node.instructions.getFirst(), AbstractInsnNode::getNext); + + // No label found before a return: create one + if (firstLabel == null) + node.instructions.insert(firstLabel = new LabelNode()); + + // Finds last label or creates it + LabelNode lastLabel = findLabelBeforeReturn(node.instructions.getLast(), AbstractInsnNode::getPrevious); + + // No label found after a return: create one + if (lastLabel == null) + node.instructions.add(lastLabel = new LabelNode()); + + + // Put new variable immediately after the method arguments + MethodSignature msig = new MethodSignature(node.desc); + LocalVariableNode varNode = new LocalVariableNode( + name, + type.getSig(), + null, + firstLabel, + lastLabel, + msig.getArgCount() + ); + + // Increment existing variable indices by 1 + for (LocalVariableNode vNode : node.localVariables) + if (vNode.index >= varNode.index) + ++vNode.index; + + // Update instructions referencing local variables + for (AbstractInsnNode insn = node.instructions.getFirst(); insn != null; insn = insn.getNext()) { + if (insn instanceof VarInsnNode && ((VarInsnNode) insn).var >= varNode.index) + ++((VarInsnNode) insn).var; + } + + // Add variable to locals + node.localVariables.add(varNode); + + return varNode; + } + /** * Adapts a grafted method instruction node to fit its surrogate @@ -195,10 +466,42 @@ public class Combine { } protected void adaptFrameNode(FrameNode node, MethodNode method, GraftSource source) { - for (int i = 0; i < node.stack.size(); ++i) - if (node.stack.get(i) instanceof Type) { + adaptFrameTypes(node.stack, source); + adaptFrameTypes(node.local, source); + } + + protected void adaptFrameTypes(List types, GraftSource source) { + if (types == null) + return; + + for (int i = 0; i < types.size(); ++i) { + if (types.get(i) instanceof Type) { + Type t = (Type) types.get(i); + + if (t.getSort() == Type.OBJECT && source.getTypeName().equals(t.getInternalName())) + types.set(i, Type.getType(source.getTypeName())); + else if (t.getSort() == Type.METHOD) { + TypeSignature sourceSig = new TypeSignature("L"+source.getTypeName()+";"); + TypeSignature targetSig = new TypeSignature("L"+target.name+";"); + MethodSignature mDesc = new MethodSignature(t.getDescriptor()); + for (int j = 0; j < mDesc.getArgCount(); ++j) + if (mDesc.getArg(j).getArrayAtomType().equals(sourceSig)) + mDesc.setArg(j, new TypeSignature( + targetSig.getSig(), + mDesc.getArg(j).getArrayDepth(), + false + )); + + if (mDesc.getRet().getArrayAtomType().equals(sourceSig)) + mDesc.setRet(new TypeSignature( + targetSig.getSig(), + mDesc.getRet().getArrayDepth(), + false + )); + } } + } } /** @@ -283,4 +586,71 @@ public class Combine { public static java.util.List dumpInsns(MethodNode node) { return Arrays.asList(node.instructions.toArray()); } + + + private static String getRandomString(int minLen, int maxLen) { + Random r = ThreadLocalRandom.current(); + + // Select a random length + char[] str = new char[r.nextInt(maxLen - minLen) + minLen]; + + // Generate random string + str[0] = VAR_NAME_CHARS.charAt(r.nextInt(VAR_NAME_CHARS.length())); + for(int i = 1; i < str.length; ++i) + str[i] = VAR_NAME_CHARS1.charAt(r.nextInt(VAR_NAME_CHARS1.length())); + + return new String(str); + } + + protected static InsnList coalesceInstructions(List nodes) { + InsnList insns = new InsnList(); + + for(AbstractInsnNode node : nodes) + insns.add(node); + + return insns; + } + + protected static List getVarsOver(List varNodes, int minIndex) { + return varNodes.stream().filter(it -> it.index >= minIndex).collect(Collectors.toList()); + } + + + protected static @Nullable LabelNode findLabelBeforeReturn(AbstractInsnNode start, INodeTraversal traverse) { + for (AbstractInsnNode cur = start; cur != null; cur = traverse.traverse(cur)) + if (cur instanceof LabelNode) // Traversal hit label + return (LabelNode) cur; + else if (cur.getOpcode() >= Opcodes.IRETURN && cur.getOpcode() <= Opcodes.RETURN) // Traversal hit return + return null; + + return null; // Nothing was found + } + + + protected interface INodeTraversal { + AbstractInsnNode traverse(AbstractInsnNode cur); + } + + private static class DynamicSourceUnit { + public final GraftSource source; + public final MethodNode node; + + private DynamicSourceUnit(GraftSource source, MethodNode node) { + this.source = source; + this.node = node; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DynamicSourceUnit that = (DynamicSourceUnit) o; + return source.equals(that.source) && node.equals(that.node); + } + + @Override + public int hashCode() { + return Objects.hash(source, node); + } + } } diff --git a/src/dev/w1zzrd/asm/GraftSource.java b/src/dev/w1zzrd/asm/GraftSource.java index 9290cd0..6d9860d 100644 --- a/src/dev/w1zzrd/asm/GraftSource.java +++ b/src/dev/w1zzrd/asm/GraftSource.java @@ -18,8 +18,10 @@ public final class GraftSource { private final String typeName; private final HashMap>> methodAnnotations; private final HashMap>> fieldAnnotations; + private final ClassNode source; public GraftSource(ClassNode source) { + this.source = source; this.typeName = source.name; methodAnnotations = new HashMap<>(); @@ -43,6 +45,14 @@ public final class GraftSource { } } + public MethodNode getMethodNode(String name, String desc) { + for (MethodNode node : source.methods) + if (node.name.equals(name) && node.desc.equals(desc)) + return node; + + return null; + } + public String getTypeName() { return typeName; } diff --git a/src/dev/w1zzrd/asm/signature/MethodSignature.java b/src/dev/w1zzrd/asm/signature/MethodSignature.java index d1ea553..b824297 100644 --- a/src/dev/w1zzrd/asm/signature/MethodSignature.java +++ b/src/dev/w1zzrd/asm/signature/MethodSignature.java @@ -6,7 +6,7 @@ import java.util.Objects; public class MethodSignature { private final TypeSignature[] args; - private final TypeSignature ret; + private TypeSignature ret; public MethodSignature(String sig) { // Minimal signature size is 3. For example: "()V". With name, minimal length is 4: "a()V" @@ -94,6 +94,14 @@ public class MethodSignature { return args[index]; } + public void setArg(int idx, TypeSignature sig) { + args[idx] = sig; + } + + public void setRet(TypeSignature sig) { + ret = sig; + } + public TypeSignature getRet() { return ret; } diff --git a/src/dev/w1zzrd/asm/signature/TypeSignature.java b/src/dev/w1zzrd/asm/signature/TypeSignature.java index 9d6ef68..003e665 100644 --- a/src/dev/w1zzrd/asm/signature/TypeSignature.java +++ b/src/dev/w1zzrd/asm/signature/TypeSignature.java @@ -173,6 +173,17 @@ public class TypeSignature { return new TypeSignature(sig.substring(1)); } + /** + * Get the type signature of the innermost element type. For example, "[[[I" would produce "I" + * @return Innermost element type for array types, this for non-array types + */ + public TypeSignature getArrayAtomType() { + if (!isArray()) + return this; + + return new TypeSignature(sig.substring(arrayDepth)); + } + /** * Check whether or not this type represents a Top type. * @return True if it is a Top, else false