Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ExplicitOnnxOps.If implementation of captured values and initializers #359

Closed
wants to merge 8 commits into from
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
package oracle.code.onnx;

import java.lang.foreign.ValueLayout;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;

@@ -77,7 +78,7 @@ public static Tensor<Integer> Constant(

// @@@ Constants for value - TENSOR and sparse_value - SPARSE_TENSOR

public static <T> Tensor<T> If(Tensor<Boolean> cond, Supplier<Tensor<T>> elseBody, Supplier<Tensor<T>> thenBody) {
public static <T> List<Tensor<T>> If(Tensor<Boolean> cond, Supplier<List<Tensor<T>>> elseBody, Supplier<List<Tensor<T>>> thenBody) {
return cond.data().get(ValueLayout.JAVA_BOOLEAN, 0) ? thenBody.get() : elseBody.get();
}
}
182 changes: 0 additions & 182 deletions cr-examples/onnx/src/main/java/oracle/code/onnx/LambdaToFunc.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -54,14 +54,37 @@ public static Object interpret(Class<? extends OnnxOp> opClass,
}
var outTensors = OnnxRuntime.getInstance().runOp(Arena.ofAuto(), schema.name(),
inputs.stream().takeWhile(i -> !(i instanceof Optional o && o.isEmpty())) // @@@ assuming gaps in the optional inputs are not allowed
.map(i -> (Tensor)(i instanceof Optional o ? o.get() : i))
.toList(),
.map(i -> i instanceof Optional o ? o.get() : i)
.mapMulti((i, ic) -> {
if (i instanceof List li) {
li.forEach(ic);
} else {
ic.accept(i);
}
})
.map(Tensor.class::cast)
.toList(),
schema.outputs().size(),
attributeMap);
if (outTensors.size() == 1) {
return outTensors.getFirst();
var outputs = schema.outputs();
if (outputs.size() == 1) {
if (outputs.getLast().quantifier() == OnnxOp.OnnxParameter.Quantifier.VARIADIC) {
return outTensors; // single variadic
} else {
return outTensors.getFirst(); // single tensor
}
} else {
return outTensors.toArray();
// @@@ assuming only tail can be variadic
if (outputs.getLast().quantifier() == OnnxOp.OnnxParameter.Quantifier.VARIADIC) {
var outArray = new Object[schema.outputs().size()];
for (int i = 0; i < outArray.length - 1; i++) {
outArray[i] = outTensors.get(i);
}
outArray[outArray.length - 1] = outTensors.subList(outArray.length - 1, outTensors.size());
return outArray; // multiple tensors with variadic tail
} else {
return outTensors.toArray(); // multiple tensors
}
}
} catch (NoSuchFieldException | IllegalAccessException e) {
throw new RuntimeException(e);
Original file line number Diff line number Diff line change
@@ -8,8 +8,10 @@
import java.util.function.BiConsumer;
import java.util.stream.IntStream;
import jdk.incubator.code.Block;
import jdk.incubator.code.Op;
import jdk.incubator.code.Value;
import jdk.incubator.code.op.CoreOp;
import jdk.incubator.code.type.JavaType;
import oracle.code.onnx.ir.OnnxOp;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;
@@ -350,10 +352,10 @@ static byte[] build(GraphProto graph) {
static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers) {
var params = block.parameters();
params.forEach(indexer::getName);
int first = params.size() - initializers.size();
var args = params.isEmpty() || params.getFirst().type() instanceof OnnxType.TensorType ? params : params.subList(1, params.size());
int firstInitializer = params.size() - initializers.size();
var args = params.subList(params.isEmpty() || params.getFirst().type() instanceof OnnxType.TensorType ? 0 : 1, firstInitializer);
return graph(
IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.getName(params.get(i + first)), initializers.get(i))).toList(),
IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.getName(params.get(i + firstInitializer)), initializers.get(i))).toList(),
args.stream().map(v ->
tensorInfo(indexer.getName(v), ((OnnxType.TensorType)v.type()).eType().id())).toList(),
block.ops().stream().<NodeProto>mapMulti((op, opNodes) -> {
@@ -372,12 +374,24 @@ static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tens
onnxOp.operands().stream().map(v -> indexer.getName(v)).toList(),
IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.getName(onnxOp.result(), o)).toList(),
onnxOp.onnxAttributes()));
case CoreOp.ReturnOp _ -> { // skip
case CoreOp.ReturnOp _, CoreOp.ConstantOp _ -> { // skip
}
case CoreOp.TupleLoadOp tlo ->
indexer.put(tlo.result(), indexer.getName(tlo.operands().getFirst(), tlo.index()));
default ->
case CoreOp.InvokeOp io when io.invokeDescriptor().refType().equals(JavaType.type(List.class)) -> {
if (io.invokeDescriptor().name().equals("get") && io.operands().getLast() instanceof Op.Result or && or.op() instanceof CoreOp.ConstantOp co && co.value() instanceof Integer i) {
indexer.put(io.result(), indexer.getName(io.operands().getFirst(), i));
} else if (io.invokeDescriptor().name().equals("of")) {
for (int i = 0; i < io.operands().size(); i++) {
indexer.put(io.result(), indexer.getName(io.operands().get(i), i));
}
} else {
throw new UnsupportedOperationException(op.toText());
}
}
default -> {
throw new UnsupportedOperationException(op.toText());
}
}
}).toList(),
List.of(indexer.getName(block.terminatingOp().operands().getFirst())));
42 changes: 25 additions & 17 deletions cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import jdk.incubator.code.*;

import jdk.incubator.code.op.CoreOp;
import oracle.code.onnx.compiler.OnnxTransformer;
import oracle.code.onnx.foreign.OrtApi;
import oracle.code.onnx.foreign.OrtApiBase;

@@ -50,14 +51,12 @@ public final class OnnxRuntime {
public interface OnnxFunction<T> extends Supplier<T>, Quotable {
}

record CachedSession(Session session, int[] operandsMapping) {}

static class CachedSessionClassValue extends ClassValue<CachedSession> {
static class CachedSessionClassValue extends ClassValue<Session> {

private MethodHandles.Lookup l;
private Quoted q;

CachedSession computeIfAbsent(Class<?> lambdaClass, MethodHandles.Lookup l, Quoted q) {
Session computeIfAbsent(Class<?> lambdaClass, MethodHandles.Lookup l, Quoted q) {
try {
this.l = l;
this.q = q;
@@ -69,30 +68,42 @@ CachedSession computeIfAbsent(Class<?> lambdaClass, MethodHandles.Lookup l, Quo
}
}

@Override
protected CachedSession computeValue(Class<?> type) {
var mf = LambdaToFunc.fromLambda(l, (CoreOp.LambdaOp)q.op(), q.capturedValues());
// @@@ heuristic assumption the first non-tensor and non-varbox captured value is receiver
private static Object getReceiver(SequencedCollection<Object> values) {
for (var v : values) {
if (!(v instanceof Tensor || v instanceof CoreOp.Var)) return v;
}
return null;
}

List<Tensor> initializers = mf.func().initializers().stream().map(val -> (Tensor) val).toList();
byte[] protobufModel = OnnxProtoBuilder.build(mf.func().func().body().entryBlock(), initializers);
@Override
protected Session computeValue(Class<?> type) {
var trans = OnnxTransformer.ofLambda(l, (CoreOp.LambdaOp)q.op());
var func = trans.transform();
byte[] protobufModel = OnnxProtoBuilder.build(func.body().entryBlock(), trans.initializers(getReceiver(q.capturedValues().sequencedValues())));

if (DEBUG) {
System.out.println(mf.func().func().toText());
System.out.println(func.toText());
try {
var export = Path.of(type.getSimpleName().split("\\$")[0] + ".onnx");
Files.write(export, protobufModel);
System.out.println("Onnx model exported to: " + export.toAbsolutePath());
} catch (IOException _) {}
}

return new CachedSession(getInstance().createSession(
return getInstance().createSession(
Arena.ofAuto(), // cached session must be created under its own auto arena
protobufModel), mf.operandsMapping());
protobufModel);

}
}

private static final CachedSessionClassValue SESSION_CACHE = new CachedSessionClassValue();

public static <T> Tensor<T> execute(OnnxFunction<Tensor<T>> codeLambda) {
return execute(MethodHandles.lookup(), codeLambda);
}

public static <T> Tensor<T> execute(MethodHandles.Lookup l, OnnxFunction<Tensor<T>> codeLambda) {
return execute(Arena.ofAuto(), l, codeLambda);
}
@@ -103,18 +114,15 @@ public static <T> Tensor<T> execute(Arena arena, MethodHandles.Lookup l, OnnxFun

var model = SESSION_CACHE.computeIfAbsent(codeLambda.getClass(), l, q);

var captured = q.capturedValues().sequencedValues().toArray();
List<Tensor> arguments = IntStream.of(model.operandsMapping())
.mapToObj(i -> captured[i])
List<Tensor> arguments = q.capturedValues().sequencedValues().stream()
.map(val -> val instanceof CoreOp.Var<?> v ? v.value() : val)
.<Tensor>mapMulti((val, args) -> {
if (val instanceof Tensor t) {
args.accept(t);
}
})
.toList();

return model.session.run(arena, arguments).getFirst();
return model.run(arena, arguments).getFirst();
}

public static OnnxRuntime getInstance() {
Original file line number Diff line number Diff line change
@@ -37,32 +37,28 @@
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import oracle.code.onnx.Tensor;
import oracle.code.onnx.ir.ExplicitOnnxOps;

final class OnnxPartialEvaluator {

static final JavaType ONNX_OPERATORS_CLASS = JavaType.type(OnnxOperators.class);
static final TypeElement TENSOR_RAW_CLASS = JavaType.type(Tensor.class);
static final JavaType LIST_CLASS = JavaType.type(List.class);

// Map from ONNX operator invocation to evaluated attributes
Map<CoreOp.InvokeOp, List<Object>> evaluatedAttributes;
final Map<CoreOp.InvokeOp, List<Object>> evaluatedAttributes;

// Operations that depend directly or indirectly on input parameters
// The operations' results are not evaluated
Set<Op> unevaluatedOperations;

List<Object> initializers;
final Set<Op> unevaluatedOperations;

public OnnxPartialEvaluator() {
this.evaluatedAttributes = new HashMap<>();
this.unevaluatedOperations = new HashSet<>();
this.initializers = new ArrayList<>();
}

public <T extends Op & Op.Invokable>
void evaluate(MethodHandles.Lookup l, T op, Map<Value, Object> evaluatedValues) {
var ev = new HashMap(evaluatedValues);
void evaluate(MethodHandles.Lookup l, T op) {
var ev = new HashMap();

interpretEntryBlock(l, op.body().entryBlock(), new OpContext(), ev);

@@ -206,8 +202,8 @@ void interpretEntryBlock(MethodHandles.Lookup l, Block entry,
// an entry block with a parent body whose nearest ancestor body
// is the current context block's parent body
BlockContext yieldContext = oc.stack.peek();
assert yieldContext == null ||
yieldContext.b().parentBody() == entry.parentBody().parentOp().ancestorBody();
// assert yieldContext == null ||
// yieldContext.b().parentBody() == entry.parentBody().parentOp().ancestorBody();

// Note that first block cannot have any successors so the queue will have at least one entry
oc.stack.push(new BlockContext(entry, evaluatedValues));
@@ -307,7 +303,7 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {
OnnxOp.OnnxSchema schema = schemaFromOnnxOpClass(opClass);

List<OnnxOp.OnnxParameter> inputs = schema.inputs();
// assert o.operands().subList(0, inputs.size()).stream().noneMatch(oc::isValueDefined);
assert o.operands().subList(0, inputs.size()).stream().noneMatch(oc::isValueDefined);
List<OnnxOp.OnnxAttribute> attributes = schema.attributes();

if (opClass == OnnxOps.Constant.class && o.operands().size() == 1) {
@@ -321,12 +317,6 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {
}
}
evaluatedAttributes.put(io, attrs);
} else if (opClass == ExplicitOnnxOps.If.class) {
// @@@ hard-coded 2 extra undeclared attributes
List<Object> attrs = o.operands().subList(inputs.size(), inputs.size() + 2).stream()
.map(oc::getValue)
.toList();
evaluatedAttributes.put(io, attrs);
} else {
for (int i = 0; i < attributes.size(); i++) {
assert oc.isValueDefined(o.operands().get(inputs.size() + i)) : operatorName;
@@ -339,16 +329,8 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {

unevaluatedOperations.add(o);
return null;
} else if (o instanceof CoreOp.FieldAccessOp.FieldLoadOp fo && fo.fieldDescriptor().type() instanceof ClassType ct && ct.rawType().equals(TENSOR_RAW_CLASS)) {
try {
if (fo.operands().isEmpty()) {
initializers.add(fo.fieldDescriptor().resolveToHandle(l).get());
} else {
initializers.add(fo.fieldDescriptor().resolveToHandle(l).get(oc.getValue(fo.operands().getFirst())));
}
} catch (ReflectiveOperationException ex) {
throw interpreterException(ex);
}
} else if (o instanceof CoreOp.InvokeOp io && io.invokeDescriptor().refType().equals(LIST_CLASS) && io.invokeDescriptor().name().equals("get")) {
evaluatedAttributes.put(io, List.of(oc.getValue(io.operands().getLast())));
unevaluatedOperations.add(o);
return null;
} else if (!o.operands().stream().allMatch(oc::isValueDefined)) {
@@ -496,7 +478,9 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {
.collect(Collectors.joining());
}
case CoreOp.LambdaOp lambdaOp -> {
return lambdaOp;
interpretBody(l, lambdaOp.body(), oc, List.of());
unevaluatedOperations.add(o);
return null;
}
case null, default -> throw interpreterException(
new UnsupportedOperationException("Unsupported operation: " + o.opName()));

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/CNNTest.java
Original file line number Diff line number Diff line change
@@ -32,7 +32,6 @@
import jdk.incubator.code.type.FunctionType;
import jdk.incubator.code.type.TupleType;
import jdk.incubator.code.writer.OpWriter;
import oracle.code.onnx.compiler.OnnxTransformer;
import oracle.code.onnx.ir.OnnxOps;
import oracle.code.onnx.ir.OnnxType;
import org.junit.jupiter.api.Test;
@@ -46,8 +45,8 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.channels.FileChannel;
import java.util.HashMap;
import java.util.function.Function;
import oracle.code.onnx.compiler.OnnxTransformer;

import static java.util.Optional.empty;
import static java.util.Optional.of;
@@ -323,13 +322,13 @@ private Tensor<Float> floatTensor(Arena arena, String resource, long... shape) t
public void testModels() {
try (var arena = Arena.ofConfined()) {
CoreOp.FuncOp f = getFuncOp("cnn");
var onnxModel = OnnxTransformer.transform(MethodHandles.lookup(), new HashMap<>(), f);
System.out.println(onnxModel.func().toText());
var onnxModel = new OnnxTransformer(MethodHandles.lookup(), f).transform();
System.out.println(onnxModel.toText());

CoreOp.FuncOp expectedOnnxModel = cnnModel();
System.out.println(expectedOnnxModel.toText());

Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel.func()));
Assertions.assertEquals(serialize(expectedOnnxModel), serialize(onnxModel));
}
}

91 changes: 75 additions & 16 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/SimpleTest.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package oracle.code.onnx;

import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandles;
import java.util.List;
import java.util.Optional;
import jdk.incubator.code.CodeReflection;
import org.junit.jupiter.api.Assertions;
@@ -19,7 +19,7 @@ public void testAdd() throws Exception {
var a = Tensor.ofFlat(1f, 2, 3);
assertEquals(
add(a, a),
OnnxRuntime.execute(MethodHandles.lookup(), () -> add(a, a)));
OnnxRuntime.execute(() -> add(a, a)));
}

@CodeReflection
@@ -33,7 +33,7 @@ public void testSub() throws Exception {
var a = Tensor.ofFlat(1f, 2, 3);
assertEquals(
sub(a, b),
OnnxRuntime.execute(MethodHandles.lookup(), () -> sub(a, b)));
OnnxRuntime.execute(() -> sub(a, b)));
}

@CodeReflection
@@ -46,7 +46,7 @@ public void testFconstant() throws Exception {
// tests the numbers are encoded correctly
var expected = Tensor.ofScalar(-1f);
assertEquals(expected, fconstant());
assertEquals(expected, OnnxRuntime.execute(MethodHandles.lookup(), () -> fconstant()));
assertEquals(expected, OnnxRuntime.execute(() -> fconstant()));
}

@CodeReflection
@@ -59,7 +59,7 @@ public void testFconstants() throws Exception {
// tests the numbers are encoded correctly
var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
assertEquals(expected, fconstants());
assertEquals(expected, OnnxRuntime.execute(MethodHandles.lookup(), () -> fconstants()));
assertEquals(expected, OnnxRuntime.execute(() -> fconstants()));
}

@CodeReflection
@@ -72,7 +72,7 @@ public void testLconstant() throws Exception {
// tests the numbers are encoded correctly
var expected = Tensor.ofScalar(-1l);
assertEquals(expected, lconstant());
assertEquals(expected, OnnxRuntime.execute(MethodHandles.lookup(), () -> lconstant()));
assertEquals(expected, OnnxRuntime.execute(() -> lconstant()));
}

@CodeReflection
@@ -85,7 +85,7 @@ public void testLconstants() throws Exception {
// tests the numbers are encoded correctly
var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
assertEquals(expected, lconstants());
assertEquals(expected, OnnxRuntime.execute(MethodHandles.lookup(), () -> lconstants()));
assertEquals(expected, OnnxRuntime.execute(() -> lconstants()));
}

@CodeReflection
@@ -99,7 +99,7 @@ public void testReshapeAndShape() throws Exception {
var shape = Tensor.ofFlat(2l, 2, 2);
assertEquals(
reshapeAndShape(data, shape),
OnnxRuntime.execute(MethodHandles.lookup(), () -> reshapeAndShape(data, shape)));
OnnxRuntime.execute(() -> reshapeAndShape(data, shape)));
}

@CodeReflection
@@ -113,12 +113,40 @@ public void testIndicesOfMaxPool() throws Exception {
var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
assertEquals(
indicesOfMaxPool(x),
OnnxRuntime.execute(MethodHandles.lookup(), () -> indicesOfMaxPool(x)));
OnnxRuntime.execute(() -> indicesOfMaxPool(x)));
}

@CodeReflection
public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2, long axis) {
return OnnxOperators.Concat(List.of(input1, input2), axis);
}

@Test
public void testConcat() throws Exception {
var input1 = Tensor.ofFlat(1f, 2, 3);
var input2 = Tensor.ofFlat(4f, 5);
assertEquals(
concat(input1, input2, 0),
OnnxRuntime.execute(()-> concat(input1, input2, 0)));
}

@CodeReflection
public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
return OnnxOperators.Split(input, Optional.of(split), Optional.empty(), Optional.empty()).get(0);
}

@Test
public void testSplit() throws Exception {
var input = Tensor.ofFlat(1f, 2, 3, 4, 5);
var split = Tensor.ofFlat(5l);
assertEquals(
split(input, split),
OnnxRuntime.execute(()-> split(input, split)));
}

@CodeReflection
public Tensor<Float> ifConst(Tensor<Boolean> cond) {
return OnnxOperators.If(cond, () -> OnnxOperators.Constant(-1f), () -> OnnxOperators.Constant(1f));
return OnnxOperators.If(cond, () -> List.of(OnnxOperators.Constant(-1f)), () -> List.of(OnnxOperators.Constant(1f))).get(0);
}

@Test
@@ -129,16 +157,16 @@ public void testIfConst() throws Exception {
var expTrue = Tensor.ofScalar(1f);

assertEquals(expFalse, ifConst(condFalse));
assertEquals(expFalse, OnnxRuntime.execute(MethodHandles.lookup(), () -> ifConst(condFalse)));
assertEquals(expFalse, OnnxRuntime.execute(() -> ifConst(condFalse)));

assertEquals(expTrue, ifConst(condTrue));
assertEquals(expTrue, OnnxRuntime.execute(MethodHandles.lookup(), () -> ifConst(condTrue)));
assertEquals(expTrue, OnnxRuntime.execute(() -> ifConst(condTrue)));
}

@CodeReflection
public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
var falseValue = OnnxOperators.Constant(-1f);
return OnnxOperators.If(cond, () -> OnnxOperators.Identity(falseValue), () -> trueValue);
return OnnxOperators.If(cond, () -> List.of(OnnxOperators.Identity(falseValue)), () -> List.of(OnnxOperators.Identity(trueValue))).get(0);
}

@Test
@@ -149,10 +177,10 @@ public void testIfCapture() throws Exception {
var expTrue = Tensor.ofScalar(1f);

assertEquals(expFalse, ifCapture(condFalse, expTrue));
assertEquals(expFalse, OnnxRuntime.execute(MethodHandles.lookup(), () -> ifCapture(condFalse, expTrue)));
assertEquals(expFalse, OnnxRuntime.execute(() -> ifCapture(condFalse, expTrue)));

assertEquals(expTrue, ifCapture(condTrue, expTrue));
assertEquals(expTrue, OnnxRuntime.execute(MethodHandles.lookup(), () -> ifCapture(condTrue, expTrue)));
assertEquals(expTrue, OnnxRuntime.execute(() -> ifCapture(condTrue, expTrue)));
}

final Tensor<Float> initialized = Tensor.ofFlat(42f);
@@ -166,7 +194,38 @@ public Tensor<Float> initialized() {
public void testInitialized() throws Exception {

assertEquals(initialized(),
OnnxRuntime.execute(MethodHandles.lookup(), () -> initialized()));
OnnxRuntime.execute(() -> initialized()));
}

final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
final Tensor<Float> initialized3 = Tensor.ofFlat(-1f);
final Tensor<Float> initialized4 = Tensor.ofFlat(-99f);

@CodeReflection
public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
return OnnxOperators.If(cond1,
() -> OnnxOperators.If(cond2,
() -> List.of(OnnxOperators.Identity(initialized4)),
() -> List.of(OnnxOperators.Identity(initialized3))),
() -> OnnxOperators.If(cond2,
() -> List.of(OnnxOperators.Identity(initialized2)),
() -> List.of(OnnxOperators.Identity(initialized)))).get(0);
}

@Test
public void testIfInitialized() throws Exception {
var condFalse = Tensor.ofScalar(false);
var condTrue = Tensor.ofScalar(true);

assertEquals(initialized, ifInitialized(condTrue, condTrue));
assertEquals(initialized, OnnxRuntime.execute(() -> ifInitialized(condTrue, condTrue)));
assertEquals(initialized2, ifInitialized(condTrue, condFalse));
assertEquals(initialized2, OnnxRuntime.execute(() -> ifInitialized(condTrue, condFalse)));
assertEquals(initialized3, ifInitialized(condFalse, condTrue));
assertEquals(initialized3, OnnxRuntime.execute(() -> ifInitialized(condFalse, condTrue)));
assertEquals(initialized4, ifInitialized(condFalse, condFalse));
assertEquals(initialized4, OnnxRuntime.execute(() -> ifInitialized(condFalse, condFalse)));

}

static void assertEquals(Tensor expected, Tensor actual) {