Implement method prepending

This commit is contained in:
Gabriel Tofvesson 2021-01-28 05:32:25 +01:00
parent 69b2194b23
commit 60540d21c0
4 changed files with 197 additions and 50 deletions

View File

@ -85,8 +85,6 @@ public class Combine {
final int graftArgCount = xsig.getArgCount() + (isStatic(extension) ? 0 : 1);
final int targetArgCount = msig.getArgCount() + (isStatic(target) ? 0 : 1);
List<AbstractInsnNode> targetInsns;
// If graft method cares about the return value of the original method, i.e. accepts it as an extra "argument"
if (acceptReturn && !msig.getRet().isVoidType()) {
//noinspection OptionalGetWithoutIsPresent
@ -102,21 +100,16 @@ public class Combine {
// Handle retvar specially
extension.localVariables.remove(retVar);
// Convert instructions into a more modifiable format
targetInsns = decomposeToList(target.instructions);
// Make space in the original frames for the return var
// This isn't an optimal solution, but it works for now
adjustFramesForRetVar(targetInsns, targetArgCount);
adjustFramesForRetVar(target.instructions, targetArgCount);
// Replace return instructions with GOTOs to the last instruction in the list
// Return values are stored in retVar
storeAndGotoFromReturn(target, targetInsns, retVar.index, xsig);
storeAndGotoFromReturn(target, target.instructions, retVar.index, xsig);
} else {
targetInsns = decomposeToList(target.instructions);
// If we don't care about the return value from the original, we can replace returns with pops
popAndGotoFromReturn(targetInsns, xsig);
popAndGotoFromReturn(target, target.instructions, xsig);
}
List<LocalVariableNode> extVars = getVarsOver(extension.localVariables, xsig.getArgCount());
@ -125,13 +118,7 @@ public class Combine {
target.localVariables.addAll(extVars);
// Add extension instructions to instruction list
targetInsns.addAll(decomposeToList(extension.instructions));
// Some weird TOP delimiter between "true" locals and args
//insertTopDelimiter(targetArgCount, targetInsns, target.localVariables, true);
// Convert instructions back to a InsnList
target.instructions = coalesceInstructions(targetInsns);
target.instructions.add(extension.instructions);
// Make sure we extend the scope of the original method arguments
for (int i = 0; i < targetArgCount; ++i)
@ -165,6 +152,16 @@ public class Combine {
);
adaptMethod(extension, source);
MethodSignature sig = new MethodSignature(extension.desc);
target.localVariables.addAll(getVarsOver(extension.localVariables, sig.getArgCount()));
extension.instructions.add(target.instructions);
target.instructions = extension.instructions;
// Extend argument scope to cover prepended code
for (int i = 0; i < sig.getArgCount(); ++i)
adjustArgument(target, getVarAt(target.localVariables, i), true, false);
finishGrafting(extension, source);
}
@ -285,7 +282,7 @@ public class Combine {
protected void adaptMethod(MethodNode node, GraftSource source) {
// Adapt instructions
for (AbstractInsnNode insn = node.instructions.getFirst(); insn != null; insn = insn.getNext()) {
if (insn instanceof MethodInsnNode) adaptMethodInsn((MethodInsnNode) insn, source);
if (insn instanceof MethodInsnNode) adaptMethodInsn((MethodInsnNode) insn, source, node);
else if (insn instanceof LdcInsnNode) adaptLdcInsn((LdcInsnNode) insn, source.getTypeName());
else if (insn instanceof FrameNode) adaptFrameNode((FrameNode) insn, source);
else if (insn instanceof FieldInsnNode) adaptFieldInsn((FieldInsnNode) insn, source);
@ -313,83 +310,102 @@ public class Combine {
return label;
}
private void storeAndGotoFromReturn(MethodNode source, List<AbstractInsnNode> nodes, int storeIndex, MethodSignature sig) {
protected static LabelNode findOrMakeEndLabel(InsnList nodes) {
AbstractInsnNode last = nodes.getLast();
while (last instanceof FrameNode) last = last.getPrevious();
if (last instanceof LabelNode)
return (LabelNode) last;
LabelNode label = new LabelNode();
nodes.add(label);
return label;
}
protected static boolean hasEndJumpFrame(InsnList nodes) {
return nodes.getLast() instanceof FrameNode && nodes.getLast().getPrevious() instanceof LabelNode;
}
protected LabelNode makeEndJumpFrame(InsnList nodes, MethodSignature sig, MethodNode source) {
LabelNode endLabel = findOrMakeEndLabel(nodes);
int frameInsert = nodes.indexOf(endLabel);
List<Object> local = makeFrameLocals(sig.getArgs());
if (!isStatic(source))
local.add(0, target.name);
//nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, 0, null, 0, null));
nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_FULL, local.size(), local.toArray(), 0, new Object[0]));
nodes.insert(endLabel, new FrameNode(Opcodes.F_FULL, local.size(), local.toArray(), 0, new Object[0]));
return endLabel;
}
private void storeAndGotoFromReturn(MethodNode source, InsnList nodes, int storeIndex, MethodSignature sig) {
// If we already have a final frame, there's no need to add one
LabelNode endLabel = hasEndJumpFrame(nodes) ? findOrMakeEndLabel(nodes) : makeEndJumpFrame(nodes, sig, source);
INSTRUCTION_LOOP:
for (int i = 0; i < nodes.size(); ++i) {
switch (nodes.get(i).getOpcode()) {
for (AbstractInsnNode current = nodes.getFirst(); current != null; current = current.getNext()) {
switch (current.getOpcode()) {
case Opcodes.IRETURN:
nodes.add(i, new IntInsnNode(Opcodes.ISTORE, storeIndex));
nodes.set(current, current = new IntInsnNode(Opcodes.ISTORE, storeIndex));
break;
case Opcodes.FRETURN:
nodes.add(i, new IntInsnNode(Opcodes.FSTORE, storeIndex));
nodes.set(current, current = new IntInsnNode(Opcodes.FSTORE, storeIndex));
break;
case Opcodes.ARETURN:
nodes.add(i, new IntInsnNode(Opcodes.ASTORE, storeIndex));
nodes.set(current, current = new IntInsnNode(Opcodes.ASTORE, storeIndex));
break;
case Opcodes.LRETURN:
nodes.add(i, new IntInsnNode(Opcodes.LSTORE, storeIndex));
nodes.set(current, current = new IntInsnNode(Opcodes.LSTORE, storeIndex));
break;
case Opcodes.DRETURN:
nodes.add(i, new IntInsnNode(Opcodes.DSTORE, storeIndex));
nodes.set(current, current = new IntInsnNode(Opcodes.DSTORE, storeIndex));
break;
case Opcodes.RETURN:
--i;
break;
nodes.set(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel));
// Fallthrough
default:
continue INSTRUCTION_LOOP;
}
nodes.set(i + 1, new JumpInsnNode(Opcodes.GOTO, endLabel));
nodes.insert(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel));
}
}
private static void popAndGotoFromReturn(List<AbstractInsnNode> nodes, MethodSignature sig) {
LabelNode endLabel = findOrMakeEndLabel(nodes);
int frameInsert = nodes.indexOf(endLabel);
List<Object> local = makeFrameLocals(sig.getArgs());
//nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, 0, null, 0, null));
nodes.add(frameInsert + 1, new FrameNode(Opcodes.F_SAME, local.size(), local.toArray(), 0, new Object[0]));
private void popAndGotoFromReturn(MethodNode source, InsnList nodes, MethodSignature sig) {
// If we already have a final frame, there's no need to add one
LabelNode endLabel = hasEndJumpFrame(nodes) ? findOrMakeEndLabel(nodes) : makeEndJumpFrame(nodes, sig, source);
INSTRUCTION_LOOP:
for (int i = 0; i < nodes.size(); ++i) {
switch (nodes.get(i).getOpcode()) {
for (AbstractInsnNode current = nodes.getFirst(); current != null; current = current.getNext()) {
switch (current.getOpcode()) {
case Opcodes.IRETURN:
case Opcodes.FRETURN:
case Opcodes.ARETURN:
nodes.add(i, new InsnNode(Opcodes.POP));
nodes.set(current, current = new InsnNode(Opcodes.POP));
break;
case Opcodes.LRETURN:
case Opcodes.DRETURN:
nodes.add(i, new InsnNode(Opcodes.POP2));
nodes.set(current, current = new InsnNode(Opcodes.POP2));
break;
case Opcodes.RETURN:
--i;
break;
nodes.set(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel));
// Fallthrough
default:
continue INSTRUCTION_LOOP;
}
nodes.set(i + 1, new JumpInsnNode(Opcodes.GOTO, endLabel));
nodes.insert(current, current = new JumpInsnNode(Opcodes.GOTO, endLabel));
}
}
@ -454,9 +470,9 @@ public class Combine {
}
}
protected static void adjustFramesForRetVar(List<AbstractInsnNode> nodes, int argc) {
protected static void adjustFramesForRetVar(InsnList nodes, int argc) {
boolean isFirst = true;
for (AbstractInsnNode node : nodes)
for (AbstractInsnNode node = nodes.getFirst(); node != null; node = node.getNext())
if (node instanceof FrameNode) {
if (isFirst) {
isFirst = false;
@ -482,7 +498,7 @@ public class Combine {
* @param node Grafted method instruction node
* @param source The {@link GraftSource} from which the instruction node will be adapted
*/
protected void adaptMethodInsn(MethodInsnNode node, GraftSource source) {
protected void adaptMethodInsn(MethodInsnNode node, GraftSource source, MethodNode sourceMethod) {
if (node.owner.equals(source.getTypeName())) {
final MethodNode injected = source.getInjectedMethod(node.name, node.desc);
if (injected != null) {
@ -490,6 +506,65 @@ public class Combine {
node.name = source.getMethodTargetName(injected);
node.desc = adaptMethodSignature(node.desc, source);
}
} else if (node.owner.equals("dev/w1zzrd/asm/Directives")) { // ASM target directives
if (node.name.equals(Directives.directiveNameByTarget(DirectiveTarget.TargetType.CALL_SUPER))) {
// We're attempting to redirect a call to a superclass
for (AbstractInsnNode prev = node.getPrevious(); prev != null; prev = prev.getPrevious()) {
if (prev instanceof MethodInsnNode &&
(((MethodInsnNode) prev).owner.equals(target.name) ||
((MethodInsnNode) prev).owner.equals(source.getTypeName()))) {
// Point method owner to superclass
((MethodInsnNode) prev).owner = target.superName;
// Since we're calling super, we want to make it a special call
if (prev.getOpcode() == Opcodes.INVOKEVIRTUAL)
((MethodInsnNode) prev).setOpcode(Opcodes.INVOKESPECIAL);
} else if (prev instanceof FieldInsnNode &&
(((FieldInsnNode) prev).owner.equals(target.name) ||
((FieldInsnNode) prev).owner.equals(source.getTypeName()))) {
// Just change the field we're accessing to the targets superclass' field
((FieldInsnNode) prev).owner = target.superName;
} else {
continue;
}
return;
}
throw new RuntimeException(String.format("Could not locate a target for directive %s", node.name));
} else if (node.name.equals(Directives.directiveNameByTarget(DirectiveTarget.TargetType.CALL_ORIGINAL))) {
// We want to redirect execution to the original method code
// The callOriginal method returns void, so the stack should be empty at this point
InsnList insnList = sourceMethod.instructions;
// If we already have a final frame, there's no need to add one
LabelNode endLabel = hasEndJumpFrame(insnList) ?
findOrMakeEndLabel(insnList) :
makeEndJumpFrame(insnList, new MethodSignature(sourceMethod.desc), sourceMethod);
AbstractInsnNode jumpInsn = new JumpInsnNode(Opcodes.GOTO, endLabel);
insnList.set(node, jumpInsn);
MethodSignature sig = new MethodSignature(sourceMethod.desc);
final Class<?>[] ignoredNodes = {LineNumberNode.class, LabelNode.class, FrameNode.class};
AbstractInsnNode afterJump = getNextNode(jumpInsn, ignoredNodes);
if (!sig.getRet().isVoidType()) {
// Now we want to remove extraneous (unreachable) return instructions
afterJump = getNextNode(afterJump, ignoredNodes);
// This should remove extraneous return instructions, along with any constants pushed to the stack
if (afterJump.getOpcode() >= Opcodes.IRETURN && afterJump.getOpcode() < Opcodes.RETURN) {
insnList.remove(afterJump);
insnList.remove(getNextNode(jumpInsn, ignoredNodes));
}
} else if (afterJump.getOpcode() == Opcodes.RETURN) {
// This should just remove the extraneous RETURN instruction
insnList.remove(afterJump);
}
}
}
}
@ -712,6 +787,25 @@ public class Combine {
inject.access ^= (inject.access & flags) ^ (target.access & flags);
}
protected static AbstractInsnNode getNextNode(AbstractInsnNode node, Class<?>... skipTypes) {
return traverseNode(node, AbstractInsnNode::getNext, skipTypes);
}
protected static AbstractInsnNode getPreviousNode(AbstractInsnNode node, Class<?>... skipTypes) {
return traverseNode(node, AbstractInsnNode::getPrevious, skipTypes);
}
private static AbstractInsnNode traverseNode(AbstractInsnNode node, INodeTraversal traversal, Class<?>[] skipTypes) {
TRAVERSAL:
for (AbstractInsnNode trav = traversal.traverse(node); trav != null; trav = traversal.traverse(trav)) {
for (Class<?> cls : skipTypes)
if (trav.getClass().equals(cls))
continue TRAVERSAL;
return trav;
}
return null;
}
@SuppressWarnings("unused") // Used for debugging

View File

@ -0,0 +1,17 @@
package dev.w1zzrd.asm;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface DirectiveTarget {
TargetType value();
enum TargetType {
CALL_ORIGINAL, CALL_SUPER;
}
}

View File

@ -0,0 +1,29 @@
package dev.w1zzrd.asm;
import dev.w1zzrd.asm.exception.DirectiveNotImplementedException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
public class Directives {
@DirectiveTarget(DirectiveTarget.TargetType.CALL_ORIGINAL)
public static void callOriginal() {
throw new DirectiveNotImplementedException("callOriginal");
}
@DirectiveTarget(DirectiveTarget.TargetType.CALL_SUPER)
public static void callSuper() {
throw new DirectiveNotImplementedException("callSuper");
}
static String directiveNameByTarget(DirectiveTarget.TargetType type) {
DirectiveTarget target;
for (Method f : Directives.class.getMethods())
if ((target = f.getDeclaredAnnotation(DirectiveTarget.class)) != null && target.value().equals(type))
return f.getName();
// This won't happen unless I'm dumb, so call it 50/50 odds
throw new RuntimeException("Could not find implementation of directive target: "+type.name());
}
}

View File

@ -0,0 +1,7 @@
package dev.w1zzrd.asm.exception;
public class DirectiveNotImplementedException extends RuntimeException {
public DirectiveNotImplementedException(String directiveName) {
super("Operation not implemented: "+directiveName);
}
}