Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exception regions fix #12

Closed
wants to merge 13 commits into from
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package java.lang.reflect.code.bytecode;

import java.lang.classfile.CodeBuilder;
import java.lang.classfile.CodeElement;
import java.lang.classfile.CodeTransform;
import java.lang.classfile.PseudoInstruction;
import java.lang.classfile.instruction.BranchInstruction;
import java.lang.classfile.instruction.LabelTarget;
import java.util.ArrayList;
import java.util.List;

/**
* BranchCompactor is a CodeTransform skipping redundant branches to immediate targets.
*/
public final class BranchCompactor implements CodeTransform {

private BranchInstruction branch;
private final List<PseudoInstruction> buffer = new ArrayList<>();

public BranchCompactor() {
}

@Override
public void accept(CodeBuilder cob, CodeElement coe) {
if (branch == null) {
if (coe instanceof BranchInstruction bi && bi.opcode().isUnconditionalBranch()) {
//unconditional branch is stored
branch = bi;
} else {
//all other elements are passed
cob.with(coe);
}
} else {
switch (coe) {
case LabelTarget lt when branch.target() == lt.label() -> {
//skip branch to immediate target
branch = null;
//flush the buffer
atEnd(cob);
//pass the target
cob.with(lt);
}
case PseudoInstruction pi -> {
//buffer pseudo instructions
buffer.add(pi);
}
default -> {
//any other instruction flushes the branch and buffer
atEnd(cob);
//replay the code element
accept(cob, coe);
}
}
}
}

@Override
public void atEnd(CodeBuilder cob) {
if (branch != null) {
//flush the branch
cob.with(branch);
branch = null;
}
//flush the buffer
buffer.forEach(cob::with);
buffer.clear();
}
}
Original file line number Diff line number Diff line change
@@ -35,16 +35,11 @@
import java.io.File;
import java.io.FileOutputStream;
import java.lang.classfile.CodeBuilder.BlockCodeBuilder;
import java.lang.classfile.CodeElement;
import java.lang.classfile.CodeTransform;
import java.lang.classfile.Instruction;
import java.lang.classfile.Opcode;
import java.lang.classfile.TypeKind;
import static java.lang.classfile.TypeKind.DoubleType;
import static java.lang.classfile.TypeKind.FloatType;
import static java.lang.classfile.TypeKind.LongType;
import java.lang.classfile.instruction.BranchInstruction;
import java.lang.classfile.instruction.LabelTarget;
import java.lang.constant.ClassDesc;
import java.lang.constant.ConstantDesc;
import java.lang.constant.ConstantDescs;
@@ -57,7 +52,6 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.code.Block.Reference;
import java.lang.reflect.code.Value;
import java.lang.reflect.code.analysis.Liveness;
import java.lang.reflect.code.descriptor.FieldDesc;
@@ -66,7 +60,6 @@
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
@@ -116,22 +109,6 @@ private static void print(byte[] classBytes) {
ClassPrinter.toYaml(cm, ClassPrinter.Verbosity.CRITICAL_ATTRIBUTES, System.out::print);
}

private static class DebugTransform implements CodeTransform {
private int bci = 0;

@Override
public void accept(CodeBuilder cob, CodeElement coe) {
cob.accept(coe);
switch (coe) {
case LabelTarget lt -> System.out.println(bci + ": " + lt + " (" + lt.label().hashCode() + ")");
case BranchInstruction bi -> System.out.println(bci + ": " + bi + " (" + bi.target().hashCode() + ")");
default -> System.out.println(bci + ": " + coe);
}
if (coe instanceof Instruction ins) bci += ins.sizeInBytes();
}

}

public static byte[] generateClassData(MethodHandles.Lookup lookup, CoreOps.FuncOp fop) {
String packageName = lookup.lookupClass().getPackageName();
String className = packageName.isEmpty()
@@ -143,7 +120,7 @@ public static byte[] generateClassData(MethodHandles.Lookup lookup, CoreOps.Func
fop.funcName(),
fop.funcDescriptor().toNominalDescriptor(),
ClassFile.ACC_PUBLIC | ClassFile.ACC_STATIC,
cb -> cb.transforming(new DebugTransform(), cob -> {
cb -> cb.transforming(new BranchCompactor(), cob -> {
ConversionContext c = new ConversionContext(lookup, liveness, cob);
generateBody(fop.body(), cob, c);
})));
@@ -167,7 +144,6 @@ static final class ConversionContext implements BytecodeInstructionOps.MethodVis
final Map<Block, LiveSlotSet> liveSet;
Block current;
final Set<Block> catchingBlocks;
final Map<Block, ExceptionRegionNode> coveredBlocks;

public ConversionContext(MethodHandles.Lookup lookup, Liveness liveness, CodeBuilder cb) {
this.lookup = lookup;
@@ -177,7 +153,6 @@ public ConversionContext(MethodHandles.Lookup lookup, Liveness liveness, CodeBui
this.labels = new HashMap<>();
this.liveSet = new HashMap<>();
this.catchingBlocks = new HashSet<>();
this.coveredBlocks = new HashMap<>();
}

@Override
@@ -207,15 +182,10 @@ int getSlot(Value v) {
return liveSlotSet().getSlot(v);
}


int getOrAssignSlot(Value v, boolean assignIfUnused) {
return liveSlotSet().getOrAssignSlot(v, assignIfUnused);
}

int getOrAssignSlot(Value v) {
return getOrAssignSlot(v, false);
}

int assignSlot(Value v) {
return liveSlotSet().assignSlot(v);
}
@@ -345,88 +315,77 @@ private static TypeKind toTypeKind(TypeDesc t) {
}
}

record ExceptionRegionNode(CoreOps.ExceptionRegionEnter ere, int size, ExceptionRegionNode next) {
}

static final ExceptionRegionNode NIL = new ExceptionRegionNode(null, 0, null);

private static void computeExceptionRegionMembership(Body r, ConversionContext c) {
Set<Block> visited = new HashSet<>();
Deque<Block> stack = new ArrayDeque<>();
stack.push(r.entryBlock());

// Set of catching blocks
Set<Block> catchingBlocks = c.catchingBlocks;
// Map of block to stack of covered exception regions
Map<Block, ExceptionRegionNode> coveredBlocks = c.coveredBlocks;
private static void computeExceptionRegionMembership(Body body, CodeBuilder cob, ConversionContext c) {
record ExceptionRegionWithBlocks(CoreOps.ExceptionRegionEnter ere, BitSet blocks) {
}
// List of all regions
final List<ExceptionRegionWithBlocks> allRegions = new ArrayList<>();
class BlockWithActiveExceptionRegions {
final Block block;
final BitSet activeRegionStack;
BlockWithActiveExceptionRegions(Block block, BitSet activeRegionStack) {
this.block = block;
this.activeRegionStack = activeRegionStack;
activeRegionStack.stream().forEach(r -> allRegions.get(r).blocks.set(block.index()));
}
}
final Set<Block> visited = new HashSet<>();
final Deque<BlockWithActiveExceptionRegions> stack = new ArrayDeque<>();
stack.push(new BlockWithActiveExceptionRegions(body.entryBlock(), new BitSet()));
// Compute exception region membership
while (!stack.isEmpty()) {
Block b = stack.pop();
BlockWithActiveExceptionRegions bm = stack.pop();
Block b = bm.block;
if (!visited.add(b)) {
continue;
}

Op top = b.terminatingOp();
ExceptionRegionNode bRegions = coveredBlocks.get(b);
if (top instanceof CoreOps.BranchOp bop) {
if (bRegions != null) {
coveredBlocks.put(bop.branch().targetBlock(), bRegions);
}

stack.push(bop.branch().targetBlock());
stack.push(new BlockWithActiveExceptionRegions(bop.branch().targetBlock(), bm.activeRegionStack));
} else if (top instanceof CoreOps.ConditionalBranchOp cop) {
if (bRegions != null) {
coveredBlocks.put(cop.falseBranch().targetBlock(), bRegions);
coveredBlocks.put(cop.trueBranch().targetBlock(), bRegions);
}

stack.push(cop.falseBranch().targetBlock());
stack.push(cop.trueBranch().targetBlock());
stack.push(new BlockWithActiveExceptionRegions(cop.falseBranch().targetBlock(), bm.activeRegionStack));
stack.push(new BlockWithActiveExceptionRegions(cop.trueBranch().targetBlock(), bm.activeRegionStack));
} else if (top instanceof CoreOps.ExceptionRegionEnter er) {
ArrayList<Block.Reference> catchBlocks = new ArrayList<>(er.catchBlocks());
Collections.reverse(catchBlocks);
for (Block.Reference catchBlock : catchBlocks) {
catchingBlocks.add(catchBlock.targetBlock());
if (bRegions != null) {
coveredBlocks.put(catchBlock.targetBlock(), bRegions);
}

stack.push(catchBlock.targetBlock());
for (Block.Reference catchBlock : er.catchBlocks().reversed()) {
c.catchingBlocks.add(catchBlock.targetBlock());
stack.push(new BlockWithActiveExceptionRegions(catchBlock.targetBlock(), bm.activeRegionStack));
}

ExceptionRegionNode n;
if (bRegions != null) {
n = new ExceptionRegionNode(er, bRegions.size + 1, bRegions);
} else {
n = new ExceptionRegionNode(er, 1, NIL);
}
coveredBlocks.put(er.start().targetBlock(), n);

stack.push(er.start().targetBlock());
BitSet activeRegionStack = (BitSet)bm.activeRegionStack.clone();
activeRegionStack.set(allRegions.size());
ExceptionRegionWithBlocks newNode = new ExceptionRegionWithBlocks(er, new BitSet());
allRegions.add(newNode);
stack.push(new BlockWithActiveExceptionRegions(er.start().targetBlock(), activeRegionStack));
} else if (top instanceof CoreOps.ExceptionRegionExit er) {
assert bRegions != null;

if (bRegions.size() > 1) {
coveredBlocks.put(er.end().targetBlock(), bRegions.next());
}

stack.push(er.end().targetBlock());
BitSet activeRegionStack = (BitSet)bm.activeRegionStack.clone();
activeRegionStack.clear(activeRegionStack.length() - 1);
stack.push(new BlockWithActiveExceptionRegions(er.end().targetBlock(), activeRegionStack));
}
}
}

private static void branch(CodeBuilder cob, ConversionContext c, List<Block> blocks, Block source, Block target) {
int bi = blocks.indexOf(source);
int si = blocks.indexOf(target);
// If successor occurs immediately after this block,
// then no need for goto instruction
if (bi != si - 1) {
cob.goto_(c.getLabel(target));
// Declare the exception regions
final List<Block> blocks = body.blocks();
for (ExceptionRegionWithBlocks erNode : allRegions.reversed()) {
int start = erNode.blocks.nextSetBit(0);
while (start >= 0) {
int end = erNode.blocks.nextClearBit(start);
Label startLabel = c.getLabel(blocks.get(start));
Label endLabel = c.getLabel(blocks.get(end));
for (Block.Reference cbr : erNode.ere.catchBlocks()) {
Block cb = cbr.targetBlock();
if (!cb.parameters().isEmpty()) {
ClassDesc type = cb.parameters().get(0).type().toNominalDescriptor();
cob.exceptionCatch(startLabel, endLabel, c.getLabel(cb), type);
} else {
cob.exceptionCatchAll(startLabel, endLabel, c.getLabel(cb));
}
}
start = erNode.blocks.nextSetBit(end);
}
}
}

private static void generateBody(Body body, CodeBuilder cob, ConversionContext c) {
computeExceptionRegionMembership(body, c);
computeExceptionRegionMembership(body, cob, c);

// Process blocks in topological order
// A jump instruction assumes the false successor block is
@@ -452,8 +411,6 @@ private static void generateBody(Body body, CodeBuilder cob, ConversionContext c
if (c.catchingBlocks.contains(b)) {
// Retain block argument for exception table generation
Block.Parameter ex = b.parameters().get(0);
// clb.parameter(ex.type());

// Store in slot if used, otherwise pop
if (!ex.uses().isEmpty()) {
int slot = c.getSlot(ex);
@@ -738,7 +695,7 @@ private static void generateBody(Body body, CodeBuilder cob, ConversionContext c
}
case BranchOp op -> {
assignBlockArguments(op, op.branch(), cob, c);
branch(cob, c, blocks, b, op.branch().targetBlock());
cob.goto_(c.getLabel(op.branch().targetBlock()));
}
case ConditionalBranchOp op -> {
if (getConditionForCondBrOp(op) instanceof CoreOps.BinaryTestOp btop) {
@@ -747,22 +704,22 @@ private static void generateBody(Body body, CodeBuilder cob, ConversionContext c
conditionalBranch(cob, btop,
trueBuilder -> {
assignBlockArguments(btop, op.trueBranch(), trueBuilder, c);
branch(trueBuilder, c, blocks, b, op.trueBranch().targetBlock());
trueBuilder.goto_(c.getLabel(op.trueBranch().targetBlock()));
},
falseBuilder -> {
assignBlockArguments(btop, op.falseBranch(), falseBuilder, c);
branch(falseBuilder, c, blocks, b, op.falseBranch().targetBlock());
falseBuilder.goto_(c.getLabel(op.falseBranch().targetBlock()));
});
} else {
processOperands(cob, c, op, isLastOpResultOnStack);
cob.ifThenElse(
trueBuilder -> {
assignBlockArguments(op, op.trueBranch(), trueBuilder, c);
branch(trueBuilder, c, blocks, b, op.trueBranch().targetBlock());
trueBuilder.goto_(c.getLabel(op.trueBranch().targetBlock()));
},
falseBuilder -> {
assignBlockArguments(op, op.falseBranch(), falseBuilder, c);
branch(falseBuilder, c, blocks, b, op.falseBranch().targetBlock());
falseBuilder.goto_(c.getLabel(op.falseBranch().targetBlock()));
});
}
}
@@ -774,19 +731,7 @@ private static void generateBody(Body body, CodeBuilder cob, ConversionContext c
}
case ExceptionRegionExit op -> {
assignBlockArguments(op, op.end(), cob, c);
ExceptionRegionEnter enterOp = op.regionStart();
Label start = c.getLabel(enterOp.start().targetBlock());
Label end = cob.newBoundLabel();
for (Reference cbr : enterOp.catchBlocks()) {
Block cb = cbr.targetBlock();
if (!cb.parameters().isEmpty()) {
ClassDesc type = cb.parameters().get(0).type().toNominalDescriptor();
cob.exceptionCatch(start, end, c.getLabel(cb), type);
} else {
cob.exceptionCatchAll(start, end, c.getLabel(cb));
}
}
branch(cob, c, blocks, b, op.end().targetBlock());
cob.goto_(c.getLabel(op.end().targetBlock()));
}
default ->
throw new UnsupportedOperationException("Terminating operation not supported: " + top);
Loading