diff --git a/src/dev/w1zzrd/asm/Combine.java b/src/dev/w1zzrd/asm/Combine.java index 0f5ead9..bfb7eb6 100644 --- a/src/dev/w1zzrd/asm/Combine.java +++ b/src/dev/w1zzrd/asm/Combine.java @@ -1,5 +1,6 @@ package dev.w1zzrd.asm; +import com.sun.org.apache.bcel.internal.generic.GotoInstruction; import dev.w1zzrd.asm.analysis.AsmAnnotation; import dev.w1zzrd.asm.exception.MethodNodeResolutionException; import dev.w1zzrd.asm.exception.SignatureCheckException; @@ -19,6 +20,10 @@ import java.util.stream.Collectors; import static jdk.internal.org.objectweb.asm.ClassWriter.COMPUTE_MAXS; public class Combine { + public static final String VAR_ASSERT_NAME = "$assertionsDisabled"; + public static final int VAR_ASSERT_FLAGS = Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC | Opcodes.ACC_FINAL; + + private final ArrayList graftSources = new ArrayList<>(); private final ClassNode target; @@ -323,6 +328,88 @@ public class Combine { return target; } + protected void ensureLoadClassAssertionState() { + if (!hasDeclaredAssertionState()) + target.fields.add(new FieldNode( + VAR_ASSERT_FLAGS, + VAR_ASSERT_NAME, + "Z", + null, + null + )); + + // Check if state is loaded + if (target.methods.stream().noneMatch(it -> it.name.equals(""))) { + MethodNode mnode = new MethodNode(Opcodes.ACC_STATIC, "", "()V", null, null); + injectAssertionLoad(mnode, true); + target.methods.add(mnode); + } else { + MethodNode clinit = target.methods.stream().filter(it -> it.name.equals("")).findAny().get(); + + for (AbstractInsnNode node = clinit.instructions.getFirst(); node != null; node = node.getNext()) + if (node instanceof FieldInsnNode && + node.getOpcode() == Opcodes.PUTSTATIC && + VAR_ASSERT_NAME.equals(((FieldInsnNode) node).name)) + return; + + // Assertion state not loaded in the current clinit. Add it to the start of the clinit + injectAssertionLoad(clinit, false); + } + } + + /** + * For internal use. Adds instructions to the given method to load assertion flag to a static final field + * @param node Method to inject instructions into + */ + private void injectAssertionLoad(MethodNode node, boolean insertReturn) { + if (node.instructions == null) { + node.instructions = new InsnList(); + insertReturn = true; + } + + AbstractInsnNode current; + if (node.instructions.getFirst() == null || !(node.instructions.getFirst() instanceof LabelNode)) + node.instructions.insert(current = new LabelNode()); + else + current = node.instructions.getFirst(); + + node.instructions.insert(current, current = new LdcInsnNode(Type.getType("L"+target.name+";"))); + node.instructions.insert(current, current = new MethodInsnNode( + Opcodes.INVOKEVIRTUAL, + "java/lang/Class", + "desiredAssertionStatus", + "()Z", + false + )); + + final LabelNode jumpNE = new LabelNode(); + final LabelNode jumpGOTO = new LabelNode(); + + node.instructions.insert(current, current = new JumpInsnNode(Opcodes.IFNE, jumpNE)); + node.instructions.insert(current, current = new InsnNode(Opcodes.ICONST_1)); + node.instructions.insert(current, current = new JumpInsnNode(Opcodes.GOTO, jumpGOTO)); + node.instructions.insert(current, current = jumpNE); + node.instructions.insert(current, current = new FrameNode(Opcodes.F_SAME, 0, new Object[0], 0, new Object[0])); + node.instructions.insert(current, current = new InsnNode(Opcodes.ICONST_0)); + node.instructions.insert(current, current = jumpGOTO); + node.instructions.insert(current, current = new FrameNode( + Opcodes.F_SAME1, + 0, + new Object[0], + 1, + new Object[]{ Opcodes.INTEGER } + )); + + node.instructions.insert(current, current = new FieldInsnNode(Opcodes.PUTSTATIC, target.name, VAR_ASSERT_NAME, "Z")); + + if (insertReturn) + node.instructions.insert(current, new InsnNode(Opcodes.RETURN)); + } + + protected boolean hasDeclaredAssertionState() { + return target.fields.stream().anyMatch(it -> VAR_ASSERT_NAME.equals(it.name)); + } + /** * Prepares a {@link MethodNode} for grafting on to a given method and into the targeted {@link ClassNode} * @param node Node to adapt @@ -334,8 +421,14 @@ public class Combine { if (insn instanceof MethodInsnNode) insn = adaptMethodInsn((MethodInsnNode) insn, source, node); else if (insn instanceof LdcInsnNode) adaptLdcInsn((LdcInsnNode) insn, source.getTypeName()); else if (insn instanceof FrameNode) adaptFrameNode((FrameNode) insn, source); - else if (insn instanceof FieldInsnNode) adaptFieldInsn((FieldInsnNode) insn, source); else if (insn instanceof InvokeDynamicInsnNode) adaptInvokeDynamicInsn((InvokeDynamicInsnNode) insn, source); + else if (insn instanceof FieldInsnNode) { + adaptFieldInsn((FieldInsnNode) insn, source); + + // If a method is trying to access the assertion state of the class, ensure the target class declares the state + if (VAR_ASSERT_NAME.equals(((FieldInsnNode) insn).name) && insn.getOpcode() == Opcodes.GETSTATIC) + ensureLoadClassAssertionState(); + } } // Adapt variable types