Skip to content

Commit 22a2404

Browse files
committedFeb 21, 2025
OnnxProtoBuilder simplifications
1 parent 576b09a commit 22a2404

File tree

8 files changed

+87
-104
lines changed

8 files changed

+87
-104
lines changed
 

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxInterpreter.java

+21-7
Original file line numberDiff line numberDiff line change
@@ -25,25 +25,39 @@
2525

2626
package oracle.code.onnx;
2727

28+
import java.util.LinkedHashMap;
2829
import oracle.code.onnx.ir.OnnxOp;
2930

3031
import java.util.List;
3132
import java.util.Optional;
3233

3334
public class OnnxInterpreter {
35+
3436
public static Object interpret(Class<? extends OnnxOp> opClass,
3537
List<Object> inputs,
3638
List<Object> attributes) {
3739
try {
3840
// @@@ assuming tensor inputs and outputs
41+
var schema = (OnnxOp.OnnxSchema)opClass.getDeclaredField("SCHEMA").get(null);
42+
var attrSchema = schema.attributes();
43+
var attributeMap = new LinkedHashMap<String, Object>(attributes.size());
44+
for (int i = 0; i < attributes.size(); i++) {
45+
var a = attributes.get(i);
46+
if (a instanceof Optional o) {
47+
if (o.isPresent()) {
48+
attributeMap.put(attrSchema.get(i).name(), o.get());
49+
}
50+
} else {
51+
attributeMap.put(attrSchema.get(i).name(), a);
52+
}
53+
}
3954
var outTensors = OnnxRuntime.getInstance().runOp(
40-
(OnnxOp.OnnxSchema)opClass.getDeclaredField("SCHEMA").get(null),
41-
inputs.stream().map(o -> Optional.ofNullable(switch (o) {
42-
case Tensor t -> t.tensorAddr;
43-
case Optional ot when ot.isPresent() && ot.get() instanceof Tensor t -> t.tensorAddr;
44-
default -> null;
45-
})).toList(),
46-
attributes);
55+
schema.name(),
56+
inputs.stream().takeWhile(i -> !(i instanceof Optional o && o.isEmpty())) // @@@ assuming gaps in the optional inputs are not allowed
57+
.map(i -> ((Tensor)(i instanceof Optional o ? o.get() : i)).tensorAddr)
58+
.toList(),
59+
schema.outputs().size(),
60+
attributeMap);
4761
if (outTensors.size() == 1) {
4862
return new Tensor(outTensors.getFirst());
4963
} else {

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxProtoBuilder.java

+35-65
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
import java.io.ByteArrayOutputStream;
44
import java.nio.ByteBuffer;
5-
import java.nio.ByteOrder;
65
import java.nio.charset.StandardCharsets;
76
import java.util.IdentityHashMap;
87
import java.util.List;
98
import java.util.function.BiConsumer;
9+
import java.util.stream.IntStream;
1010
import jdk.incubator.code.Value;
1111
import jdk.incubator.code.op.CoreOp;
1212
import jdk.incubator.code.op.CoreOp.FuncOp;
1313
import oracle.code.onnx.ir.OnnxOp;
14-
import oracle.code.onnx.Tensor.ElementType;
1514
import oracle.code.onnx.ir.OnnxType;
1615

1716
// Generated from onnx.proto3
@@ -262,42 +261,6 @@ <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) {
262261
static final int IR_VERSION = 10;
263262
static final int OPSET_VERSION = 21;
264263

265-
// @@@ tensors only
266-
static ByteBuffer buildOpModel(OnnxOp.OnnxSchema schema, List<java.util.Optional<ElementType>> inputElementTypes, List<Object> attributes) {
267-
var bytes = new ModelProto()
268-
.ir_version(IR_VERSION)
269-
.graph(new GraphProto()
270-
.forEach(schema.inputs(), (g, i) -> {
271-
if (inputElementTypes.get(i.ordinal()).isPresent()) {
272-
g.input(new ValueInfoProto()
273-
.name(i.name())
274-
.type(new TypeProto()
275-
// inputValues match schema inputs by OnnxParameter::ordinal
276-
.tensor_type(new Tensor().elem_type(inputElementTypes.get(i.ordinal()).get().id))));
277-
}
278-
})
279-
.node(new NodeProto()
280-
.forEach(schema.inputs(), (n, i) -> n.input(i.name()))
281-
.forEach(schema.outputs(), (n, o) -> n.output(o.name()))
282-
.op_type(schema.name())
283-
.forEach(schema.attributes(), (n, a) -> {
284-
// attributes match schema by OnnxAttribute::ordinal
285-
var attrValue = attributes.get(a.ordinal());
286-
if (a.isOptional()) {
287-
if (attrValue instanceof java.util.Optional o && o.isPresent()) {
288-
n.attribute(buildAttribute(a.name(), o.get()));
289-
}
290-
} else {
291-
n.attribute(buildAttribute(a.name(), attrValue));
292-
}
293-
}))
294-
.forEach(schema.outputs(), (g, o) -> g.output(new ValueInfoProto()
295-
.name(o.name()))))
296-
.opset_import(new OperatorSetIdProto().version(OPSET_VERSION))
297-
.buf.toByteArray();
298-
return ByteBuffer.allocateDirect(bytes.length).put(bytes).asReadOnlyBuffer();
299-
}
300-
301264
// @@@ unchecked constraints:
302265
// tensor FuncOp parameters and single tensor return type
303266
// OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
@@ -313,36 +276,43 @@ String getName(Value v, int subIndex) {
313276
return name;
314277
}
315278
};
316-
var entryBlock = model.body().entryBlock();
279+
return buildModel(
280+
model.body().entryBlock().parameters().stream().map(v -> new Input(indexer.getName(v), ((OnnxType.TensorType)v.type()).eType().id())).toList(),
281+
model.body().entryBlock().ops().stream().<OpNode>mapMulti((op, opNodes) -> {
282+
switch (op) {
283+
case OnnxOp onnxOp ->
284+
opNodes.accept(new OpNode(
285+
onnxOp.opName(),
286+
onnxOp.operands().stream().map(v -> indexer.getName(v)).toList(),
287+
IntStream.range(0, onnxOp.onnxOutputs().size()).mapToObj(o -> indexer.getName(onnxOp.result(), o)).toList(),
288+
onnxOp.onnxAttributes()));
289+
case CoreOp.ReturnOp _ -> { // skip
290+
}
291+
case CoreOp.TupleLoadOp tlo ->
292+
indexer.put(tlo.result(), indexer.getName(tlo.operands().getFirst(), tlo.index()));
293+
default ->
294+
throw new UnsupportedOperationException(op.toText());
295+
}
296+
}).toList(),
297+
List.of(indexer.getName(model.body().entryBlock().terminatingOp().operands().getFirst())));
298+
}
299+
300+
record Input(String name, int tensorElementType) {}
301+
record OpNode(String opName, List<String> inputNames, List<String> outputNames, java.util.Map<String, Object> attributes) {}
302+
303+
static ByteBuffer buildModel(List<Input> inputs, List<OpNode> ops, List<String> outputNames) {
317304
var bytes = new ModelProto()
318305
.ir_version(IR_VERSION)
319306
.graph(new GraphProto()
320-
.forEach(entryBlock.parameters(), (g, p) -> g.input(new ValueInfoProto()
321-
.name(indexer.getName(p))
322-
.type(new TypeProto()
323-
.tensor_type(new Tensor().elem_type(((OnnxType.TensorType)p.type()).eType().id())))))
324-
.forEach(entryBlock.ops(), (g, op) -> {
325-
switch (op) {
326-
case OnnxOp onnxOp ->
327-
g.node(new NodeProto()
328-
.forEach(op.operands(), (n, i) -> n.input(indexer.getName(i)))
329-
.forEach(onnxOp.onnxOutputs(), (n, o) -> n.output(indexer.getName(op.result(), o.ordinal())))
330-
.op_type(op.opName())
331-
.forEach(onnxOp.onnxAttributes().entrySet(), (n, ae) -> n.attribute(buildAttribute(ae.getKey(), ae.getValue()))));
332-
case CoreOp.ReturnOp _ -> {
333-
// skip
334-
}
335-
case CoreOp.TupleLoadOp tlo -> {
336-
indexer.put(op.result(), indexer.getName(op.operands().getFirst(), tlo.index()));
337-
}
338-
default ->
339-
throw new UnsupportedOperationException(op.toText());
340-
}
341-
})
342-
.output(new ValueInfoProto()
343-
.name(indexer.getName(entryBlock.terminatingOp().operands().getFirst()))
344-
.type(new TypeProto()
345-
.tensor_type(new Tensor().elem_type(((OnnxType.TensorType)model.body().yieldType()).eType().id())))))
307+
.forEach(inputs, (g, input) -> g
308+
.input(new ValueInfoProto().name(input.name())
309+
.type(new TypeProto().tensor_type(new Tensor().elem_type(input.tensorElementType())))))
310+
.forEach(ops, (g, op) -> g.node(new NodeProto()
311+
.forEach(op.inputNames(), (n, iName) -> n.input(iName))
312+
.forEach(op.outputNames(), (n, oName) -> n.output(oName))
313+
.op_type(op.opName())
314+
.forEach(op.attributes().entrySet(), (n, ae) -> n.attribute(buildAttribute(ae.getKey(), ae.getValue())))))
315+
.forEach(outputNames, (g, oName) -> g.output(new ValueInfoProto().name(oName))))
346316
.opset_import(new OperatorSetIdProto().version(OPSET_VERSION))
347317
.buf.toByteArray();
348318
// OnnxProtoPrinter.printModel(ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN));

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java

+17-15
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
import java.nio.file.StandardCopyOption;
1515
import java.util.List;
1616
import java.util.Locale;
17-
import java.util.Optional;
17+
import java.util.Map;
18+
import java.util.stream.IntStream;
1819
import jdk.incubator.code.op.CoreOp;
19-
import oracle.code.onnx.ir.OnnxOp;
2020

2121
import static java.lang.foreign.ValueLayout.*;
2222

@@ -161,18 +161,22 @@ private static MethodHandle handle(int methodIndex, MemoryLayout... args) {
161161
return mh.asType(mh.type().changeReturnType(Object.class));
162162
}
163163

164-
private List<Optional<Tensor.ElementType>> toElementTypes(List<Optional<MemorySegment>> values) {
165-
return values.stream().map(ot -> ot.map(this::tensorElementType)).toList();
166-
}
167-
168-
public List<MemorySegment> runOp(OnnxOp.OnnxSchema schema, List<Optional<MemorySegment>> inputValues, List<Object> attributes) {
169-
var protoModel = OnnxProtoBuilder.buildOpModel(schema, toElementTypes(inputValues), attributes);
164+
public List<MemorySegment> runOp(String opName, List<MemorySegment> inputValues, int numOutputs, Map<String, Object> attributes) {
165+
var outputNames = IntStream.range(0, numOutputs).mapToObj(o -> "o" + o).toList();
166+
var protoModel = OnnxProtoBuilder.buildModel(
167+
IntStream.range(0, inputValues.size()).mapToObj(i -> new OnnxProtoBuilder.Input("i" + i, tensorElementType(inputValues.get(i)).id)).toList(),
168+
List.of(new OnnxProtoBuilder.OpNode(
169+
opName,
170+
IntStream.range(0, inputValues.size()).mapToObj(i -> "i" + i).toList(),
171+
outputNames,
172+
attributes)),
173+
outputNames);
170174
try (var session = createSession(protoModel)) {
171175
return session.run(inputValues);
172176
}
173177
}
174178

175-
public List<MemorySegment> runFunc(CoreOp.FuncOp model, List<Optional<MemorySegment>> inputValues) {
179+
public List<MemorySegment> runFunc(CoreOp.FuncOp model, List<MemorySegment> inputValues) {
176180
var protoModel = OnnxProtoBuilder.buildFuncModel(model);
177181
try (var session = createSession(protoModel)) {
178182
return session.run(inputValues);
@@ -244,18 +248,16 @@ public String getOutputName(int inputIndex) {
244248
}
245249

246250
// @@@ only tensors are supported yet
247-
public List<MemorySegment> run(List<Optional<MemorySegment>> inputValues) {
251+
public List<MemorySegment> run(List<MemorySegment> inputValues) {
248252
var runOptions = MemorySegment.NULL;
249253
int inputLen = getNumberOfInputs();
250254
int outputLen = getNumberOfOutputs();
251255
var inputNames = arena.allocate(ADDRESS, inputLen);
252256
var inputs = arena.allocate(ADDRESS, inputLen);
253257
long index = 0;
254258
for (int i = 0; i < inputLen; i++) {
255-
if (inputValues.get(i).isPresent()) {
256-
inputNames.setAtIndex(ADDRESS, index, arena.allocateFrom(getInputName(i)));
257-
inputs.setAtIndex(ADDRESS, index++, inputValues.get(i).get());
258-
}
259+
inputNames.setAtIndex(ADDRESS, index, arena.allocateFrom(getInputName(i)));
260+
inputs.setAtIndex(ADDRESS, index++, inputValues.get(i));
259261
}
260262
var outputNames = arena.allocate(ADDRESS, outputLen);
261263
var outputs = arena.allocate(ADDRESS, outputLen);
@@ -278,7 +280,7 @@ public List<MemorySegment> run(List<Optional<MemorySegment>> inputValues) {
278280
@Override
279281
public void close() {
280282
try {
281-
checkStatus(releaseSession.invokeExact(runtimeAddress, sessionAddress));
283+
Object o = releaseSession.invokeExact(runtimeAddress, sessionAddress);
282284
} catch (Throwable t) {
283285
throw wrap(t);
284286
}

‎cr-examples/onnx/src/main/java/oracle/code/onnx/ir/OnnxOp.java

-4
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ public abstract class OnnxOp extends ExternalizableOp {
3737
public interface OnnxAttribute {
3838
String name();
3939

40-
int ordinal();
41-
4240
Class<?> type();
4341

4442
Object defaultValue();
@@ -146,8 +144,6 @@ public boolean isVariadoc() {
146144

147145
String name();
148146

149-
int ordinal();
150-
151147
OnnxType type();
152148

153149
Quantifier quantifier();

‎cr-examples/onnx/src/test/java/oracle/code/onnx/CNNTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ public void testProtobufModel() throws Exception {
390390
test(inputImage -> new Tensor(OnnxRuntime.getInstance().runFunc(
391391
OnnxTransformer.transform(MethodHandles.lookup(), getFuncOp("cnn")),
392392
Stream.concat(weights.stream(), Stream.of(inputImage))
393-
.map(t -> Optional.of(t.tensorAddr)).toList()).getFirst()));
393+
.map(t -> t.tensorAddr).toList()).getFirst()));
394394
}
395395

396396
private void test(Function<Tensor<Byte>, Tensor<Float>> executor) throws Exception {

‎cr-examples/onnx/src/test/java/oracle/code/onnx/MNISTDemo.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.awt.image.BufferedImage;
2929
import java.io.*;
3030
import jdk.incubator.code.CodeReflection;
31-
import java.util.Optional;
3231
import java.lang.foreign.MemorySegment;
3332
import java.lang.invoke.MethodHandles;
3433
import java.nio.ByteBuffer;
@@ -126,7 +125,7 @@ public static void main(String[] args) throws Exception {
126125
var scaledImage = new BufferedImage(IMAGE_SIZE, IMAGE_SIZE, BufferedImage.TYPE_BYTE_GRAY);
127126
var scaledGraphics = scaledImage.createGraphics();
128127
var scaledImageDataBuffer = ByteBuffer.allocateDirect(IMAGE_SIZE * IMAGE_SIZE * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
129-
var inputArguments = List.of(Optional.of(new Tensor(MemorySegment.ofBuffer(scaledImageDataBuffer), FLOAT, 1, 1, IMAGE_SIZE, IMAGE_SIZE).tensorAddr));
128+
var inputArguments = List.of(new Tensor(MemorySegment.ofBuffer(scaledImageDataBuffer), FLOAT, 1, 1, IMAGE_SIZE, IMAGE_SIZE).tensorAddr);
130129
var sampleArray = new float[IMAGE_SIZE * IMAGE_SIZE];
131130

132131
drawPane.setPreferredSize(new Dimension(DRAW_AREA_SIZE, DRAW_AREA_SIZE));

‎cr-examples/onnx/src/test/java/oracle/code/onnx/RuntimeTest.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package oracle.code.onnx;
22

33
import java.util.List;
4-
import java.util.Optional;
5-
import oracle.code.onnx.Tensor.ElementType;
6-
import oracle.code.onnx.ir.OnnxOps;
4+
import java.util.Map;
75
import org.junit.jupiter.api.Test;
86

97
import static oracle.code.onnx.Tensor.ElementType.*;
@@ -12,13 +10,17 @@
1210

1311
public class RuntimeTest {
1412

15-
static final Optional<ElementType> OF_FLOAT = Optional.of(FLOAT);
16-
1713
@Test
1814
public void test() throws Exception {
1915
var ort = OnnxRuntime.getInstance();
20-
try (var absOp = ort.createSession(OnnxProtoBuilder.buildOpModel(OnnxOps.Abs.SCHEMA, List.of(OF_FLOAT, OF_FLOAT), List.of()));
21-
var addOp = ort.createSession(OnnxProtoBuilder.buildOpModel(OnnxOps.Add.SCHEMA, List.of(OF_FLOAT, OF_FLOAT), List.of()))) {
16+
try (var absOp = ort.createSession(OnnxProtoBuilder.buildModel(
17+
List.of(new OnnxProtoBuilder.Input("x", FLOAT.id)),
18+
List.of(new OnnxProtoBuilder.OpNode("Abs", List.of("x"), List.of("y"), Map.of())),
19+
List.of("y")));
20+
var addOp = ort.createSession(OnnxProtoBuilder.buildModel(
21+
List.of(new OnnxProtoBuilder.Input("a", FLOAT.id), new OnnxProtoBuilder.Input("b", FLOAT.id)),
22+
List.of(new OnnxProtoBuilder.OpNode("Add", List.of("a", "b"), List.of("y"), Map.of())),
23+
List.of("y")))) {
2224

2325
assertEquals(1, absOp.getNumberOfInputs());
2426
assertEquals(1, absOp.getNumberOfOutputs());
@@ -30,15 +32,15 @@ public void test() throws Exception {
3032

3133
var absExpectedTensor = Tensor.ofFlat(1f, 2, 3, 4, 5, 6);
3234

33-
var absResult = absOp.run(List.of(Optional.of(inputTensor.tensorAddr)));
35+
var absResult = absOp.run(List.of(inputTensor.tensorAddr));
3436

3537
assertEquals(1, absResult.size());
3638

3739
var absOutputTensor = new Tensor(absResult.getFirst());
3840

3941
SimpleTest.assertEquals(absExpectedTensor, absOutputTensor);
4042

41-
var addResult = addOp.run(List.of(Optional.of(inputTensor.tensorAddr), Optional.of(absOutputTensor.tensorAddr)));
43+
var addResult = addOp.run(List.of(inputTensor.tensorAddr, absOutputTensor.tensorAddr));
4244

4345
assertEquals(1, addResult.size());
4446

‎cr-examples/onnx/src/test/java/oracle/code/onnx/SimpleTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public void testIndicesOfMaxPool() throws Exception {
113113
private static Tensor runModel(String name, Tensor... params) throws NoSuchMethodException {
114114
return new Tensor(OnnxRuntime.getInstance().runFunc(
115115
getOnnxModel(name),
116-
Stream.of(params).map(t -> Optional.ofNullable(t.tensorAddr)).toList()).getFirst());
116+
Stream.of(params).map(t -> t.tensorAddr).toList()).getFirst());
117117
}
118118

119119
private static CoreOp.FuncOp getOnnxModel(String name) throws NoSuchMethodException {

0 commit comments

Comments
 (0)
Please sign in to comment.