Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExplicitOnnxOps.If implementation of captured values and initializers #359

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
@@ -66,10 +66,25 @@ public static Object interpret(Class<? extends OnnxOp> opClass,
.toList(),
schema.outputs().size(),
attributeMap);
if (outTensors.size() == 1) {
return outTensors.getFirst();
var outputs = schema.outputs();
if (outputs.size() == 1) {
if (outputs.getLast().quantifier() == OnnxOp.OnnxParameter.Quantifier.VARIADIC) {
return outTensors; // single variadic
} else {
return outTensors.getFirst(); // single tensor
}
} else {
return outTensors.toArray();
// @@@ assuming only tail can be variadic
if (outputs.getLast().quantifier() == OnnxOp.OnnxParameter.Quantifier.VARIADIC) {
var outArray = new Object[schema.outputs().size()];
for (int i = 0; i < outArray.length - 1; i++) {
outArray[i] = outputs.get(i);
}
outArray[outArray.length - 1] = outputs.subList(outArray.length - 1, outputs.size());
return outArray; // multiple tensors with variadic tail
} else {
return outTensors.toArray(); // multiple tensors
}
}
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
Original file line number Diff line number Diff line change
@@ -8,8 +8,10 @@
import java.util.function.BiConsumer;
import java.util.stream.IntStream;
import jdk.incubator.code.Block;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.op.CoreOp;
import jdk.incubator.code.type.JavaType;
import oracle.code.onnx.ir.OnnxOp;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;
@@ -372,12 +374,20 @@ static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tens
onnxOp.operands().stream().map(v -> indexer.getName(v)).toList(),
IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.getName(onnxOp.result(), o)).toList(),
onnxOp.onnxAttributes()));
case CoreOp.ReturnOp _ -> { // skip
case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
}
case CoreOp.TupleLoadOp tlo ->
indexer.put(tlo.result(), indexer.getName(tlo.operands().getFirst(), tlo.index()));
default ->
case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class))
&& io.invokeDescriptor().name().equals("get")
&& io.operands().getLast() instanceof Op.Result or
&& or.op() instanceof CoreOp.ConstantOp co
&& co.value() instanceof Integer i ->
indexer.put(io.result(), indexer.getName(io.operands().getFirst(), i));
default -> {
System.out.println(op.parent().parent().parent().toText());
throw new UnsupportedOperationException(op.toText());
}
}
}).toList(),
List.of(indexer.getName(block.terminatingOp().operands().getFirst())));
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@
final class OnnxPartialEvaluator {

static final JavaType ONNX_OPERATORS_CLASS = JavaType.type(OnnxOperators.class);
static final JavaType LIST_CLASS = JavaType.type(List.class);

// Map from ONNX operator invocation to evaluated attributes
final Map<CoreOp.InvokeOp, List<Object>> evaluatedAttributes;
@@ -326,6 +327,10 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {
evaluatedAttributes.put(io, attrs);
}

unevaluatedOperations.add(o);
return null;
} else if (o instanceof CoreOp.InvokeOp io && io.invokeDescriptor().refType().equals(LIST_CLASS) && io.invokeDescriptor().name().equals("get")) {
evaluatedAttributes.put(io, List.of(oc.getValue(io.operands().getLast())));
unevaluatedOperations.add(o);
return null;
} else if (!o.operands().stream().allMatch(oc::isValueDefined)) {
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ public class OnnxTransformer {


static final JavaType TENSOR_CLASS = JavaType.type(Tensor.class);
static final JavaType LIST_CLASS = JavaType.type(List.class);

private final MethodHandles.Lookup l;
private final CoreOp.FuncOp inputFunc;
@@ -224,6 +225,15 @@ OpTransformer bodyTransformer(OnnxPartialEvaluator pe) {
Op.Result result = bb.op(CoreOp.tupleLoad(bb.context().getValue(io.operands().getFirst()), index));
bb.context().mapValue(io.result(), result);
}
// Transform access to the result of an operator that is a list access
// @@@ raw use of List::get with constant argument
case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(LIST_CLASS) && io.invokeDescriptor().name().equals("get") -> {
Op.Result result = bb.op(CoreOp.invoke(
io.invokeDescriptor(),
bb.context().getValue(io.operands().getFirst()),
bb.op(CoreOp.constant(JavaType.INT, pe.evaluatedAttributes.get(io).getLast()))));
bb.context().mapValue(io.result(), result);
}
// Skip nested lambdas
case CoreOp.LambdaOp _ -> {
}
14 changes: 14 additions & 0 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/SimpleTest.java
Original file line number Diff line number Diff line change
@@ -130,6 +130,20 @@ public void testConcat() throws Exception {
OnnxRuntime.execute(()-> concat(input1, input2, 0)));
}

@CodeReflection
public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
return OnnxOperators.Split(input, Optional.of(split), Optional.empty(), Optional.empty()).get(0);
}

@Test
public void testSplit() throws Exception {
var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
var split = Tensor.ofFlat(5l);
assertEquals(
split(input, split),
OnnxRuntime.execute(()-> split(input, split)));
}

@CodeReflection
public Tensor<Float> ifConst(Tensor<Boolean> cond) {
return OnnxOperators.If(cond, () -> OnnxOperators.Constant(-1f), () -> OnnxOperators.Constant(1f));