diff --git a/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeGenerator.java b/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeGenerator.java index 129c980a0b9..9de55075c0c 100644 --- a/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeGenerator.java +++ b/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeGenerator.java @@ -33,6 +33,7 @@ import java.lang.reflect.code.op.CoreOp.*; import java.lang.classfile.ClassBuilder; +import java.lang.classfile.ClassModel; import java.lang.classfile.Opcode; import java.lang.classfile.TypeKind; import java.lang.classfile.attribute.ConstantValueAttribute; @@ -129,7 +130,9 @@ public static <O extends Op & Op.Invokable> MethodHandle generate(MethodHandles. * @return the class file bytes */ public static byte[] generateClassData(MethodHandles.Lookup lookup, FuncOp fop) { - return generateClassData(lookup, fop.funcName(), fop); + ClassModel generatedModel = ClassFile.of().parse(generateClassData(lookup, fop.funcName(), fop)); + // Compact locals of the generated bytecode + return ClassFile.of().transform(generatedModel, LocalsCompactor.INSTANCE); } /** @@ -204,6 +207,7 @@ private record ExceptionRegionWithBlocks(ExceptionRegionEnter ere, BitSet blocks private final Map<Block.Parameter, Value> singlePredecessorsValues; private final List<LambdaOp> lambdaSink; private final BitSet quotable; + private final Map<Op, Boolean> deferCache; private Value oprOnStack; private BytecodeGenerator(MethodHandles.Lookup lookup, @@ -230,6 +234,7 @@ private BytecodeGenerator(MethodHandles.Lookup lookup, this.singlePredecessorsValues = new HashMap<>(); this.lambdaSink = lambdaSink; this.quotable = quotable; + this.deferCache = new HashMap<>(); } private void setExceptionRegionStack(Block.Reference target, BitSet activeRegionStack) { @@ -347,13 +352,18 @@ private void processOperands(List<Value> operands) { } // Some of the operations can be deferred - private static boolean canDefer(Op op) { - return switch (op) { - case ConstantOp cop -> canDefer(cop); - case VarOp vop -> canDefer(vop); - case VarAccessOp.VarLoadOp vlop -> canDefer(vlop); - default -> false; - }; + private boolean canDefer(Op op) { + Boolean can = deferCache.get(op); + if (can == null) { + can = switch (op) { + case ConstantOp cop -> canDefer(cop); + case VarOp vop -> canDefer(vop); + case VarAccessOp.VarLoadOp vlop -> canDefer(vlop); + default -> false; + }; + deferCache.put(op, can); + } + return can; } // Constant can be deferred, except for loading of a class constant, which may throw an exception @@ -413,7 +423,7 @@ private static boolean isDominatedBy(Op.Result n, Set<Op.Result> doms) { } // Var load can be deferred when not used as immediate operand - private static boolean canDefer(VarAccessOp.VarLoadOp op) { + private boolean canDefer(VarAccessOp.VarLoadOp op) { return !isNextUse(op.result()); } @@ -445,7 +455,7 @@ case ConditionalBranchOp op when getConditionForCondBrOp(op) instanceof BinaryTe } // Determines if the operation result is immediatelly used by the next operation and so can stay on stack - private static boolean isNextUse(Value opr) { + private boolean isNextUse(Value opr) { Op nextOp = switch (opr) { case Block.Parameter p -> p.declaringBlock().firstOp(); case Op.Result r -> r.declaringBlock().nextOp(r.op()); diff --git a/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeLift.java b/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeLift.java index 8928d7cadb0..39e3dd21047 100644 --- a/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeLift.java +++ b/src/java.base/share/classes/java/lang/reflect/code/bytecode/BytecodeLift.java @@ -73,10 +73,13 @@ import java.util.stream.Stream; import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*; +import java.lang.invoke.MethodHandles; import java.util.BitSet; public final class BytecodeLift { + public static boolean DUMP = false; // @@@ only for debugging purpose + private record ExceptionRegion(Label startLabel, Label endLabel, Label handlerLabel) {} private record ExceptionRegionEntry(Op.Result enter, Block.Builder startBlock, ExceptionRegion region) {} @@ -134,7 +137,7 @@ private BytecodeLift(Block.Builder entryBlock, ClassModel classModel, CodeModel initLocalValues.add(null); } }); - this.codeTracker = new LocalsTypeMapper(classModel.thisClass().asSymbol(), initLocalTypes, codeModel.exceptionHandlers(), smta, elements); + this.codeTracker = new LocalsTypeMapper(classModel.thisClass().asSymbol(), initLocalTypes, codeModel.exceptionHandlers(), smta, elements, codeAttribtue); this.blockMap = smta.map(sma -> sma.entries().stream().collect(Collectors.toUnmodifiableMap( StackMapFrameInfo::target, @@ -423,35 +426,13 @@ private void liftBody() { endOfFlow(); } case LoadInstruction inst -> { - LocalsTypeMapper.Variable var = codeTracker.getVarOf(i); - if (var.isSingleValue) { - assert var.value != null; - stack.push(var.value); - } else { - assert var.value instanceof Op.Result r && r.op() instanceof CoreOp.VarOp; - stack.push(op(CoreOp.varLoad(var.value))); - } + stack.push(load(i)); } case StoreInstruction inst -> { - LocalsTypeMapper.Variable var = codeTracker.getVarOf(i); - if (var.isSingleValue) { - assert var.value == null; - var.value = stack.pop(); - } else { - if (var.value == null) { - var.value = op(CoreOp.var("slot#" + inst.slot(), var.type(), stack.pop())); - } else { - assert var.value instanceof Op.Result r && r.op() instanceof CoreOp.VarOp; - op(CoreOp.varStore(var.value, stack.pop())); - } - } + store(i, inst.slot(), stack.pop()); } case IncrementInstruction inst -> { - LocalsTypeMapper.Variable var = codeTracker.getVarOf(i); - assert !var.isSingleValue && var.value instanceof Op.Result r && r.op() instanceof CoreOp.VarOp; - op(CoreOp.varStore(var.value, op(CoreOp.add( - op(CoreOp.varLoad(var.value)), - liftConstant(inst.constant()))))); + store(i, inst.slot(), op(CoreOp.add(load(-i - 1), liftConstant(inst.constant())))); } case ConstantInstruction inst -> { stack.push(liftConstant(inst.constantValue())); @@ -835,6 +816,32 @@ yield op(CoreOp._new( } } + private Value load(int i) { + LocalsTypeMapper.Variable var = codeTracker.getVarOf(i); + if (var.isSingleValue) { + assert var.value != null; + return var.value; + } else { + assert var.value instanceof Op.Result r && r.op() instanceof CoreOp.VarOp; + return op(CoreOp.varLoad(var.value)); + } + } + + private void store(int i, int slot, Value value) { + LocalsTypeMapper.Variable var = codeTracker.getVarOf(i); + if (var.isSingleValue) { + assert var.value == null; + var.value = value; + } else { + if (var.value == null) { + var.value = op(CoreOp.var("slot#" + slot, var.type(), value)); + } else { + assert var.value instanceof Op.Result r && r.op() instanceof CoreOp.VarOp; + op(CoreOp.varStore(var.value, value)); + } + } + } + private Op.Result lookup() { return constantCache.computeIfAbsent(LOOKUP, _ -> op(CoreOp.invoke(LOOKUP))); } diff --git a/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsCompactor.java b/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsCompactor.java new file mode 100644 index 00000000000..160ba3dcb48 --- /dev/null +++ b/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsCompactor.java @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package java.lang.reflect.code.bytecode; + +import java.lang.classfile.Attributes; +import java.lang.classfile.ClassTransform; +import java.lang.classfile.CodeModel; +import java.lang.classfile.Label; +import java.lang.classfile.MethodModel; +import java.lang.classfile.TypeKind; +import java.lang.classfile.attribute.StackMapFrameInfo; +import java.lang.classfile.instruction.BranchInstruction; +import java.lang.classfile.instruction.IncrementInstruction; +import java.lang.classfile.instruction.LabelTarget; +import java.lang.classfile.instruction.LoadInstruction; +import java.lang.classfile.instruction.LookupSwitchInstruction; +import java.lang.classfile.instruction.TableSwitchInstruction; +import java.lang.classfile.instruction.StoreInstruction; +import java.lang.constant.ClassDesc; +import java.lang.reflect.AccessFlag; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*; +import static java.lang.constant.ConstantDescs.CD_double; +import static java.lang.constant.ConstantDescs.CD_long; + +/** + * LocalsCompactor is a CodeTransform reducing maxLocals. + */ +public final class LocalsCompactor { + + public static final ClassTransform INSTANCE = (clb,cle) -> { + if (cle instanceof MethodModel mm) { + clb.transformMethod(mm, (mb, me) -> { + if (me instanceof CodeModel com) { + int[] slotMap = new LocalsCompactor(com, countParamSlots(mm)).slotMap; + mb.transformCode(com, (cob, coe) -> { + switch (coe) { + case LoadInstruction li -> + cob.loadLocal(li.typeKind(), slotMap[li.slot()]); + case StoreInstruction si -> + cob.storeLocal(si.typeKind(), slotMap[si.slot()]); + case IncrementInstruction ii -> + cob.iinc(slotMap[ii.slot()], ii.constant()); + default -> + cob.with(coe); + } + }); + } else { + mb.with(me); + } + }); + } else { + clb.with(cle); + } + }; + + private static int countParamSlots(MethodModel mm) { + int slots = mm.flags().has(AccessFlag.STATIC) ? 0 : 1; + for (ClassDesc p : mm.methodTypeSymbol().parameterList()) { + slots += p == CD_long || p == CD_double ? 2 : 1; + } + return slots; + } + + static final class Slot { + final BitSet map = new BitSet(); // Liveness map of the slot + int flags; // 0 - single slot, 1 - first of double slots, 2 - second of double slots, 3 - mixed + } + + private final List<Slot> maps; // Intermediate slots liveness maps + private final Map<Label, List<StackMapFrameInfo.VerificationTypeInfo>> frames; + private final int[] slotMap; // Output mapping of the slots + + private LocalsCompactor(CodeModel com, int fixedSlots) { + frames = com.findAttribute(Attributes.stackMapTable()).map( + smta -> smta.entries().stream().collect( + Collectors.toMap(StackMapFrameInfo::target, StackMapFrameInfo::locals))) + .orElse(Map.of()); + var exceptionHandlers = com.exceptionHandlers(); + maps = new ArrayList<>(); + int pc = 0; + // Initialization of fixed slots + for (int slot = 0; slot < fixedSlots; slot++) { + getMap(slot).map.set(0); + } + // Filling the slots liveness maps + for (var e : com) { + switch(e) { + case LabelTarget lt -> { + for (var eh : exceptionHandlers) { + if (eh.tryStart() == lt.label()) { + mergeFrom(pc, eh.handler()); + } + } + } + case LoadInstruction li -> + load(pc, li.slot(), li.typeKind()); + case StoreInstruction si -> + store(pc, si.slot(), si.typeKind()); + case IncrementInstruction ii -> + loadSingle(pc, ii.slot()); + case BranchInstruction bi -> + mergeFrom(pc, bi.target()); + case LookupSwitchInstruction si -> { + mergeFrom(pc, si.defaultTarget()); + for (var sc : si.cases()) { + mergeFrom(pc, sc.target()); + } + } + case TableSwitchInstruction si -> { + mergeFrom(pc, si.defaultTarget()); + for (var sc : si.cases()) { + mergeFrom(pc, sc.target()); + } + } + default -> pc--; + } + pc++; + } + // Initialization of slots mapping + slotMap = new int[maps.size()]; + for (int slot = 0; slot < slotMap.length; slot++) { + slotMap[slot] = slot; + } + // Iterative merging of slots + for (int targetSlot = 0; targetSlot < maps.size() - 1; targetSlot++) { + for (int sourceSlot = Math.max(targetSlot + 1, fixedSlots); sourceSlot < maps.size(); sourceSlot++) { + Slot source = maps.get(sourceSlot); + // Re-mapping single slot + if (source.flags == 0) { + Slot target = maps.get(targetSlot); + if (!target.map.intersects(source.map)) { + // Single re-mapping, merge of the liveness maps and shift of the following slots by 1 left + target.map.or(source.map); + maps.remove(sourceSlot); + for (int slot = 0; slot < slotMap.length; slot++) { + if (slotMap[slot] == sourceSlot) { + slotMap[slot] = targetSlot; + } else if (slotMap[slot] > sourceSlot) { + slotMap[slot]--; + } + } + } + } else if (source.flags == 1 && sourceSlot > targetSlot + 1) { + Slot source2 = maps.get(sourceSlot + 1); + // Re-mapping distinct double slot + if (source2.flags == 2) { + Slot target = maps.get(targetSlot); + Slot target2 = maps.get(targetSlot + 1); + if (!target.map.intersects(source.map) && !target2.map.intersects(source2.map)) { + // Double re-mapping, merge of the liveness maps and shift of the following slots by 2 left + target.map.or(source.map); + target2.map.or(source2.map); + maps.remove(sourceSlot + 1); + maps.remove(sourceSlot); + for (int slot = 0; slot < slotMap.length; slot++) { + if (slotMap[slot] == sourceSlot) { + slotMap[slot] = targetSlot; + } else if (slotMap[slot] == sourceSlot + 1) { + slotMap[slot] = targetSlot + 1; + } else if (slotMap[slot] > sourceSlot + 1) { + slotMap[slot] -= 2; + } + } + } + } + } + } + } + } + + private Slot getMap(int slot) { + while (slot >= maps.size()) { + maps.add(new Slot()); + } + return maps.get(slot); + } + + private Slot loadSingle(int pc, int slot) { + Slot s = getMap(slot); + int start = s.map.nextSetBit(0) + 1; + s.map.set(start, pc + 1); + return s; + } + + private void load(int pc, int slot, TypeKind tk) { + load(pc, slot, tk.slotSize() == 2); + } + + private void load(int pc, int slot, boolean dual) { + if (dual) { + loadSingle(pc, slot).flags |= 1; + loadSingle(pc, slot + 1).flags |= 2; + } else { + loadSingle(pc, slot); + } + } + + private void mergeFrom(int pc, Label target) { + int slot = 0; + for (var vti : frames.get(target)) { + if (vti != ITEM_TOP) { + if (vti == ITEM_LONG || vti == ITEM_DOUBLE) { + load(pc, slot++, true); + } else { + loadSingle(pc, slot); + } + } + slot++; + } + } + + private Slot storeSingle(int pc, int slot) { + Slot s = getMap(slot); + s.map.set(pc); + return s; + } + + private void store(int pc, int slot, TypeKind tk) { + if (tk.slotSize() == 2) { + storeSingle(pc, slot).flags |= 1; + storeSingle(pc, slot + 1).flags |= 2; + } else { + storeSingle(pc, slot); + } + } +} diff --git a/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsTypeMapper.java b/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsTypeMapper.java index 0f5f5a25cd0..0dfe5fd192d 100644 --- a/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsTypeMapper.java +++ b/src/java.base/share/classes/java/lang/reflect/code/bytecode/LocalsTypeMapper.java @@ -29,6 +29,7 @@ import java.lang.classfile.Label; import java.lang.classfile.Opcode; import java.lang.classfile.TypeKind; +import java.lang.classfile.attribute.CodeAttribute; import java.lang.classfile.attribute.StackMapFrameInfo; import java.lang.classfile.attribute.StackMapFrameInfo.*; import java.lang.classfile.attribute.StackMapTableAttribute; @@ -47,12 +48,14 @@ import java.util.stream.Collectors; import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.*; +import java.lang.classfile.components.ClassPrinter; import static java.lang.constant.ConstantDescs.*; import java.lang.reflect.code.Value; import java.lang.reflect.code.type.JavaType; import java.util.ArrayDeque; -import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashSet; +import java.util.NoSuchElementException; import java.util.Set; final class LocalsTypeMapper { @@ -80,17 +83,26 @@ Object defaultValue() { default -> throw new IllegalStateException("Invalid type " + type.displayName()); }; } + + @Override + public String toString() { + return Integer.toHexString(hashCode()).substring(0, 2) + " " + isSingleValue; + } } - static class Slot { + static final class Slot { + + enum Kind { + STORE, LOAD, FRAME; + } - record Link(Slot slot, Link other) {} + private record Link(Slot slot, Link other) {} + int bci, sl; // @@@ only for debugging purpose + Kind kind; ClassDesc type; - Link up, down; Variable var; - boolean newValue; - Slot previous; // Previous Slot, not necessary of the same variable + private Link up, down; void link(Slot target) { if (this != target) { @@ -98,6 +110,40 @@ void link(Slot target) { this.down = new Link(target, this.down); } } + + Iterable<Slot> upSlots() { + return () -> new LinkIterator(up); + } + + Iterable<Slot> downSlots() { + return () -> new LinkIterator(down); + } + + @Override + public String toString() { + // @@@ only for debugging purpose + return "%d: #%d %s %s var:%s".formatted(bci, sl, kind, type.displayName(), var == null ? null : var.toString()); + } + + static final class LinkIterator implements Iterator<Slot> { + Link l; + public LinkIterator(Link l) { + this.l = l; + } + + @Override + public boolean hasNext() { + return l != null; + } + + @Override + public Slot next() { + if (l == null) throw new NoSuchElementException(); + Slot s = l.slot(); + l = l.other(); + return s; + } + } } record Frame(List<ClassDesc> stack, List<Slot> locals) {} @@ -107,41 +153,46 @@ record Frame(List<ClassDesc> stack, List<Slot> locals) {} private final LinkedHashSet<Slot> allSlots; private final ClassDesc thisClass; private final List<ExceptionCatch> exceptionHandlers; + private final Set<ExceptionCatch> handlersStack; private final List<ClassDesc> stack; private final List<Slot> locals; private final Map<Label, Frame> stackMap; private final Map<Label, ClassDesc> newMap; + private final CodeAttribute ca; private boolean frameDirty; final List<Slot> slotsToInitialize; LocalsTypeMapper(ClassDesc thisClass, - List<ClassDesc> initFrameLocals, - List<ExceptionCatch> exceptionHandlers, - Optional<StackMapTableAttribute> stackMapTableAttribute, - List<CodeElement> codeElements) { + List<ClassDesc> initFrameLocals, + List<ExceptionCatch> exceptionHandlers, + Optional<StackMapTableAttribute> stackMapTableAttribute, + List<CodeElement> codeElements, + CodeAttribute ca) { this.insMap = new HashMap<>(); this.thisClass = thisClass; this.exceptionHandlers = exceptionHandlers; + this.handlersStack = new LinkedHashSet<>(); this.stack = new ArrayList<>(); this.locals = new ArrayList<>(); this.allSlots = new LinkedHashSet<>(); this.newMap = computeNewMap(codeElements); this.slotsToInitialize = new ArrayList<>(); + this.ca = ca; // @@@ only for debugging purpose this.stackMap = stackMapTableAttribute.map(a -> a.entries().stream().collect(Collectors.toMap( StackMapFrameInfo::target, this::toFrame))).orElse(Map.of()); for (ClassDesc cd : initFrameLocals) { - slotsToInitialize.add(cd == null ? null : newSlot(cd, true)); + slotsToInitialize.add(cd == null ? null : newSlot(cd, Slot.Kind.STORE, -1, slotsToInitialize.size())); } int initSize = allSlots.size(); do { + handlersStack.clear(); // Slot states reset if running additional rounds with adjusted frames if (allSlots.size() > initSize) { while (allSlots.size() > initSize) allSlots.removeLast(); allSlots.forEach(sl -> { sl.up = null; sl.down = null; - sl.previous = null; sl.var = null; }); } @@ -149,55 +200,81 @@ record Frame(List<ClassDesc> stack, List<Slot> locals) {} store(i, slotsToInitialize.get(i), locals); } this.frameDirty = false; + int bci = 0; for (int i = 0; i < codeElements.size(); i++) { - accept(i, codeElements.get(i)); + var ce = codeElements.get(i); + accept(i, ce, bci); + if (ce instanceof Instruction ins) bci += ins.sizeInBytes(); } endOfFlow(); } while (this.frameDirty); - // Assign variable to slots, calculate var type, detect single value variables and dominant slot + // Pull LOADs up the FRAMEs + boolean changed = true; + while (changed) { + changed = false; + for (Slot slot : allSlots) { + if (slot.kind == Slot.Kind.FRAME) { + for (Slot down : slot.downSlots()) { + if (down.kind == Slot.Kind.LOAD) { + changed = true; + slot.kind = Slot.Kind.LOAD; + break; + } + } + } + } + } + + // Assign variable to slots, calculate var type + Set<Slot> stores = new LinkedHashSet<>(); ArrayDeque<Slot> q = new ArrayDeque<>(); - Set<Slot> initialSlots = new HashSet<>(); + Set<Slot> visited = new LinkedHashSet<>(); for (Slot slot : allSlots) { - if (slot.var == null) { + if (slot.var == null && slot.kind != Slot.Kind.FRAME) { Variable var = new Variable(); q.add(slot); - int sources = 0; var.type = slot.type; while (!q.isEmpty()) { Slot sl = q.pop(); if (sl.var == null) { - if (sl.newValue) { - sources++; - if (sl.up == null) { - initialSlots.add(sl); - } - } sl.var = var; - Slot.Link l = sl.up; - while (l != null) { - if (var.type == NULL_TYPE) var.type = l.slot.type; - if (l.slot.var == null) q.add(l.slot); - l = l.other; + for (Slot down : sl.downSlots()) { + if (down.kind == Slot.Kind.LOAD) { + if (var.type == NULL_TYPE) var.type = down.type; + if (down.var == null) q.add(down); + } } - l = sl.down; - while (l != null) { - if (var.type == NULL_TYPE) var.type = l.slot.type; - if (l.slot.var == null) q.add(l.slot); - l = l.other; + if (sl.kind == Slot.Kind.LOAD) { + for (Slot up : sl.upSlots()) { + if (up.kind != Slot.Kind.FRAME) { + if (var.type == NULL_TYPE) var.type = up.type; + if (up.var == null) { + q.add(up); + } + } + } } } + if (sl.var == var && sl.kind == Slot.Kind.STORE) { + stores.add(sl); + } } - var.isSingleValue = sources < 2; - // Filter out slots, which are not initial (store into the same variable) - for (var tsit = initialSlots.iterator(); tsit.hasNext();) { - Slot sl = tsit.next(); - if (sl.previous != null && sl.previous.var == sl.var) { - tsit.remove(); + // Detect single value + var.isSingleValue = stores.size() < 2; + + // Filter initial stores + for (var it = stores.iterator(); it.hasNext();) { + visited.clear(); + Slot s = it.next(); + if (s.up != null && preceedsWithTheVar(s, var, visited)) { + it.remove(); } } - if (initialSlots.size() > 1) { + + // Insert var initialization if necessary + if (stores.size() > 1) { // Add synthetic dominant slot, which needs to be initialized with a default value Slot initialSlot = new Slot(); initialSlot.var = var; @@ -206,9 +283,42 @@ record Frame(List<ClassDesc> stack, List<Slot> locals) {} slotsToInitialize.add(null); } } - initialSlots.clear(); + stores.clear(); } } + + // @@@ only for debugging purpose + if (BytecodeLift.DUMP) { + ClassPrinter.toYaml(ca, ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print); + System.out.println("digraph {"); + for (Slot s : allSlots) { + System.out.println(" S" + Integer.toHexString(s.hashCode()) + " [label=\"" + s.toString() + "\"]"); + } + System.out.println(); + for (Slot s : allSlots) { + var it = s.downSlots().iterator(); + if (it.hasNext()) { + System.out.print(" S" + Integer.toHexString(s.hashCode()) + " -> {S" + Integer.toHexString(it.next().hashCode())); + while (it.hasNext()) { + System.out.print(", S" + Integer.toHexString(it.next().hashCode())); + } + System.out.println("};"); + } + } + System.out.println("}"); + } + } + + // Detects if all of the preceding slots belong to the var + private static boolean preceedsWithTheVar(Slot slot, Variable var, Set<Slot> visited) { + if (visited.add(slot)) { + for (Slot up : slot.upSlots()) { + if (up.var == null ? up.kind != Slot.Kind.FRAME || !preceedsWithTheVar(up, var, visited) : up.var != var) { + return false; + } + } + } + return true; } private Frame toFrame(StackMapFrameInfo smfi) { @@ -218,8 +328,9 @@ private Frame toFrame(StackMapFrameInfo smfi) { fstack.add(vtiToStackType(vti)); } int i = 0; + int bci = ca.labelToBci(smfi.target()); //@@@ only for debugging purpose for (var vti : smfi.locals()) { - store(i, vtiToStackType(vti), flocals, false); + store(i, vtiToStackType(vti), flocals, Slot.Kind.FRAME, bci); i += vti == ITEM_DOUBLE || vti == ITEM_LONG ? 2 : 1; } return new Frame(fstack, flocals); @@ -247,10 +358,12 @@ Variable getVarOf(int li) { return insMap.get(li).var; } - private Slot newSlot(ClassDesc type, boolean newValue) { + private Slot newSlot(ClassDesc type, Slot.Kind kind, int bci, int sl) { Slot s = new Slot(); + s.kind = kind; s.type = type; - s.newValue = newValue; + s.bci = bci; + s.sl = sl; // @@@ only for debugging purpose allSlots.add(s); return s; } @@ -308,26 +421,32 @@ private LocalsTypeMapper pop(int i) { return this; } - private void store(int slot, ClassDesc type) { - store(slot, type, locals, true); + private void store(int slot, ClassDesc type, int bci) { + store(slot, type, locals, Slot.Kind.STORE, bci); } - private void store(int slot, ClassDesc type, List<Slot> where, boolean newValue) { - store(slot, type == null ? null : newSlot(type, newValue), where); + private void store(int slot, ClassDesc type, List<Slot> where, Slot.Kind kind, int bci) { + store(slot, type == null ? null : newSlot(type, kind, bci, slot), where); } private void store(int slot, Slot s, List<Slot> where) { if (s != null) { for (int i = where.size(); i <= slot; i++) where.add(null); - s.previous = where.set(slot, s); + Slot prev = where.set(slot, s); + if (prev != null) { + prev.link(s); + } } } - private ClassDesc load(int slot) { - return locals.get(slot).type; + private ClassDesc load(int slot, int bci) { + Slot sl = locals.get(slot); + Slot nsl = newSlot(sl.type, Slot.Kind.LOAD, bci, slot); + sl.link(nsl); + return sl.type; } - private void accept(int elIndex, CodeElement el) { + private void accept(int elIndex, CodeElement el, int bci) { switch (el) { case ArrayLoadInstruction _ -> pop(1).push(pop().componentType()); @@ -378,10 +497,13 @@ private void accept(int elIndex, CodeElement el) { } } case IncrementInstruction i -> { - Slot v = locals.get(i.slot()); - store(i.slot(), load(i.slot())); - v.link(locals.get(i.slot())); - insMap.put(elIndex, v); + load(i.slot(), bci); + insMap.put(-elIndex - 1, locals.get(i.slot())); + store(i.slot(), CD_int, bci); + insMap.put(elIndex, locals.get(i.slot())); + for (var ec : handlersStack) { + mergeLocalsToTargetFrame(stackMap.get(ec.handler())); + } } case InvokeDynamicInstruction i -> pop(i.typeSymbol().parameterCount()).push(i.typeSymbol().returnType()); @@ -389,12 +511,15 @@ private void accept(int elIndex, CodeElement el) { pop(i.typeSymbol().parameterCount() + (i.opcode() == Opcode.INVOKESTATIC ? 0 : 1)) .push(i.typeSymbol().returnType()); case LoadInstruction i -> { - push(load(i.slot())); + push(load(i.slot(), bci)); insMap.put(elIndex, locals.get(i.slot())); } case StoreInstruction i -> { - store(i.slot(), pop()); + store(i.slot(), pop(), bci); insMap.put(elIndex, locals.get(i.slot())); + for (var ec : handlersStack) { + mergeLocalsToTargetFrame(stackMap.get(ec.handler())); + } } case MonitorInstruction _ -> pop(1); @@ -456,8 +581,12 @@ private void accept(int elIndex, CodeElement el) { } for (ExceptionCatch ec : exceptionHandlers) { if (lt.label() == ec.tryStart()) { + handlersStack.add(ec); mergeLocalsToTargetFrame(stackMap.get(ec.handler())); } + if (lt.label() == ec.tryEnd()) { + handlersStack.remove(ec); + } } } case ReturnInstruction _ , ThrowInstruction _ -> { @@ -519,8 +648,6 @@ private void mergeLocalsToTargetFrame(Frame targetFrame) { if (le.type.isPrimitive() && CD_int.equals(fe.type) ) { fe.type = le.type; // Override int target frame type with more specific int sub-type this.frameDirty = true; - } else { - le.type = fe.type; // Override var type with target frame type } } } diff --git a/test/jdk/java/lang/reflect/code/bytecode/TestSmallCorpus.java b/test/jdk/java/lang/reflect/code/bytecode/TestSmallCorpus.java index 8096eda4b11..3ffb39c5ab4 100644 --- a/test/jdk/java/lang/reflect/code/bytecode/TestSmallCorpus.java +++ b/test/jdk/java/lang/reflect/code/bytecode/TestSmallCorpus.java @@ -21,13 +21,13 @@ * questions. */ -import java.io.PrintWriter; import java.io.StringWriter; import java.lang.classfile.ClassFile; import java.lang.classfile.Instruction; import java.lang.classfile.Label; import java.lang.classfile.MethodModel; import java.lang.classfile.Opcode; +import java.lang.classfile.attribute.CodeAttribute; import java.lang.classfile.components.ClassPrinter; import java.lang.classfile.instruction.*; import java.lang.invoke.MethodHandles; @@ -57,11 +57,17 @@ */ public class TestSmallCorpus { + private static final String ROOT_PATH = "modules/java.base/"; + private static final String CLASS_NAME_SUFFIX = ".class"; + private static final String METHOD_NAME = null; + private static final int ROUNDS = 3; + private static final FileSystem JRT = FileSystems.getFileSystem(URI.create("jrt:/")); private static final ClassFile CF = ClassFile.of(); private static final int COLUMN_WIDTH = 150; private static final MethodHandles.Lookup TRUSTED_LOOKUP; static { + BytecodeLift.DUMP = false; try { var lf = MethodHandles.Lookup.class.getDeclaredField("IMPL_LOOKUP"); lf.setAccessible(true); @@ -71,122 +77,89 @@ public class TestSmallCorpus { } } - private int stable, unstable; - private Map<String, Map<String, Integer>> errorStats; + private MethodModel bytecode; + CoreOp.FuncOp reflection; + private int stable, unstable, originalMaxLocals, maxLocals; @Ignore @Test - public void testTripleRoundtripStability() throws Exception { + public void testRoundTripStability() throws Exception { stable = 0; unstable = 0; - errorStats = new LinkedHashMap<>(); - for (Path p : Files.walk(JRT.getPath("modules/java.base/")) - .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(".class")) + originalMaxLocals = 0; + maxLocals = 0; + for (Path p : Files.walk(JRT.getPath(ROOT_PATH)) + .filter(p -> Files.isRegularFile(p) && p.toString().endsWith(CLASS_NAME_SUFFIX)) .toList()) { - testDoubleRoundtripStability(p); - } - - for (var stats : errorStats.entrySet()) { - System.out.println(String.format(""" - - %s errors: - ----------------------------------------------------- - """, stats.getKey())); - stats.getValue().entrySet().stream().sorted((e1, e2) -> Integer.compare(e2.getValue(), e1.getValue())).forEach(e -> System.out.println(e.getValue() +"x " + e.getKey() + "\n")); + testRoundTripStability(p); } // Roundtrip is >99% stable, no exceptions, no verification errors - Assert.assertTrue(stable > 65240 && unstable < 110 && errorStats.isEmpty(), String.format(""" - - stable: %d - unstable: %d - %s - """, - stable, - unstable, - errorStats.entrySet().stream().map(e -> e.getKey() + - " errors: " - + e.getValue().values().stream().mapToInt(Integer::intValue).sum()).collect(Collectors.joining("\n ")) - )); + Assert.assertTrue(stable > 65290 && unstable < 100, String.format("stable: %d unstable: %d original maxLocals: %d maxLocals: %d", stable, unstable, originalMaxLocals, maxLocals)); } - private void testDoubleRoundtripStability(Path path) throws Exception { + private void testRoundTripStability(Path path) throws Exception { var clm = CF.parse(path); for (var originalModel : clm.methods()) { - if (originalModel.code().isPresent()) try { - CoreOp.FuncOp firstLift = lift(originalModel); - verify("first lift verify", firstLift); - try { - MethodModel firstModel = lower(firstLift); - verify("first gen verify", firstModel); - try { - CoreOp.FuncOp secondLift = lift(firstModel); - verify("second lift verify", firstLift); - try { - MethodModel secondModel = lower(secondLift); - verify("second gen verify", secondModel); - try { - CoreOp.FuncOp thirdLift = lift(secondModel); - verify("third lift verify", firstLift); - try { - MethodModel thirdModel = lower(thirdLift); - verify("third gen verify", thirdModel); - // testing only methods passing through - var secondNormalized = normalize(secondModel); - var thirdNormalized = normalize(thirdModel); - if (!thirdNormalized.equals(secondNormalized)) { - unstable++; - System.out.println(clm.thisClass().asInternalName() + "::" + originalModel.methodName().stringValue() + originalModel.methodTypeSymbol().displayDescriptor()); - printInColumns(secondLift, thirdLift); - printInColumns(secondNormalized, thirdNormalized); - System.out.println(); - } else { - stable++; - } - } catch (Throwable t) { - error("third gen", t); - } - } catch (Throwable t) { - error("third lift", t); - } - } catch (Throwable t) { - error("second gen", t); - } - } catch (Throwable t) { - error("second lift", t); - } + if (originalModel.code().isPresent() && (METHOD_NAME == null || originalModel.methodName().equalsString(METHOD_NAME))) { + bytecode = originalModel; + reflection = null; + MethodModel prevBytecode = null; + CoreOp.FuncOp prevReflection = null; + for (int round = 1; round <= ROUNDS; round++) try { + prevBytecode = bytecode; + prevReflection = reflection; + lift(); + verifyReflection(); + generate(); + verifyBytecode(); } catch (Throwable t) { - error("first gen", t); + System.out.println(" at " + path + " " + originalModel.methodName() + originalModel.methodType() + " round " + round); + throw t; + } + if (ROUNDS > 0) { + var normPrevBytecode = normalize(prevBytecode); + var normBytecode = normalize(bytecode); + if (normPrevBytecode.equals(normBytecode)) { + stable++; + } else { + unstable++; + System.out.println("Unstable code " + path + " " + originalModel.methodName() + originalModel.methodType() + " after " + ROUNDS +" round(s)"); + printInColumns(normPrevBytecode, normBytecode); + printInColumns(prevReflection, reflection); + System.out.println(); + } + originalMaxLocals += ((CodeAttribute)originalModel.code().get()).maxLocals(); + maxLocals += ((CodeAttribute)bytecode.code().get()).maxLocals(); } - } catch (Throwable t) { - error("first lift", t); } } } - private void verify(String category, CoreOp.FuncOp func) { - OpWriter.CodeItemNamerOption naming = func.traverse(null, CodeElement.opVisitor((n, op) -> { + private void verifyReflection() { + reflection.traverse(null, CodeElement.opVisitor((n, op) -> { for (Value v : op.operands()) { // Verify operands dominance if (!op.result().isDominatedBy(v)) { - if (n == null) { - n = OpWriter.CodeItemNamerOption.of(OpWriter.computeGlobalNames(func)); - } - error(category, "block_%d %s is not dominated by its operand declaration in block_%d".formatted( - op.parentBlock().index(), OpWriter.toText(op, n), v.declaringBlock().index())); + printBytecode(); + var naming = OpWriter.CodeItemNamerOption.of(OpWriter.computeGlobalNames(reflection)); + System.out.println(OpWriter.toText(reflection, naming)); + System.out.println("Reflection verification failed"); + throw new AssertionError("block_%d %s is not dominated by its operand declaration in block_%d".formatted( + op.parentBlock().index(), OpWriter.toText(op, naming), v.declaringBlock().index())); } } - return n; + return null; })); - if (naming != null) { - System.out.println(OpWriter.toText(func, naming)); - } } - private void verify(String category, MethodModel model) { - for (var e : ClassFile.of().verify(model.parent().get())) { + private void verifyBytecode() { + for (var e : ClassFile.of().verify(bytecode.parent().get())) { if (!e.getMessage().contains("Illegal call to internal method")) { - error(category, e.getMessage()); + printReflection(); + printBytecode(); + System.out.println("Bytecode verification failed"); + throw new AssertionError(e.getMessage()); } } } @@ -208,16 +181,37 @@ private static void printInColumns(List<String> first, List<String> second) { } } - private static CoreOp.FuncOp lift(MethodModel mm) { - return BytecodeLift.lift(mm); + private void lift() { + try { + reflection = BytecodeLift.lift(bytecode); + } catch (Throwable t) { + printReflection(); + printBytecode(); + System.out.println("Lift failed"); + throw t; + } } - private static MethodModel lower(CoreOp.FuncOp func) { - return CF.parse(BytecodeGenerator.generateClassData( + private void generate() { + try { + bytecode = CF.parse(BytecodeGenerator.generateClassData( TRUSTED_LOOKUP, - func)).methods().get(0); + reflection)).methods().getFirst(); + } catch (Throwable t) { + printBytecode(); + printReflection(); + System.out.println("Generation failed"); + throw t; + } } + private void printBytecode() { + ClassPrinter.toYaml(bytecode, ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print); + } + + private void printReflection() { + if (reflection != null) System.out.println(reflection.toText()); + } public static List<String> normalize(MethodModel mm) { record El(int index, String format, Label... targets) { @@ -293,15 +287,4 @@ private static String trim(Opcode opcode) { int i = name.indexOf('_'); return i > 2 ? name.substring(0, i) : name; } - - private void error(String category, Throwable t) { - StringWriter sw = new StringWriter(); - t.printStackTrace(new PrintWriter(sw)); - error(category, sw.toString()); - } - - private void error(String category, String msg) { - errorStats.computeIfAbsent(category, _ -> new HashMap<>()) - .compute(msg, (_, i) -> i == null ? 1 : i + 1); - } }