Skip to content

Commit 1f9416b

Browse files
committedFeb 12, 2025
ONNX FFM Runtime initial work
1 parent 7ce041b commit 1f9416b

File tree

8 files changed

+1102
-6
lines changed

8 files changed

+1102
-6
lines changed
 

‎cr-examples/onnx/pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ questions.
3939
</properties>
4040

4141
<dependencies>
42+
<dependency>
43+
<groupId>com.microsoft.onnxruntime</groupId>
44+
<artifactId>onnxruntime</artifactId>
45+
<version>1.20.0</version>
46+
<scope>runtime</scope>
47+
</dependency>
4248
<dependency>
4349
<groupId>com.google.protobuf</groupId>
4450
<artifactId>protobuf-java</artifactId>

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxInterpreter.java

+13-1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,18 @@ public class OnnxInterpreter {
3333
public static Object interpret(Class<? extends OnnxOp> opClass,
3434
List<Object> inputs,
3535
List<Object> attributes) {
36-
throw new UnsupportedOperationException();
36+
try {
37+
// @@@ assuming tensor inputs and outputs
38+
var outTensors = OnnxRuntime.getInstance().runOp(
39+
(OnnxOp.OnnxSchema)opClass.getDeclaredField("SCHEMA").get(null),
40+
inputs.stream().map(o -> ((Tensor)o).rtTensor).toList());
41+
if (outTensors.size() == 1) {
42+
return new Tensor<>(outTensors.getFirst());
43+
} else {
44+
return outTensors.stream().map(Tensor::new).toArray();
45+
}
46+
} catch (NoSuchFieldException | IllegalAccessException e) {
47+
throw new RuntimeException(e);
48+
}
3749
}
3850
}

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxProtoBuilder.java

+370
Large diffs are not rendered by default.

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxRuntime.java

+535
Large diffs are not rendered by default.

‎cr-examples/onnx/src/main/java/oracle/code/onnx/Tensor.java

+26-5
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,24 @@ class DataType(enum.IntEnum):
5757
FLOAT4E2M1 = 23
5858
*/
5959

60-
import java.util.NoSuchElementException;
61-
import java.util.Optional;
62-
import java.util.stream.Stream;
63-
6460
public class Tensor<T> extends OnnxNumber {
6561
// element type
6662
// dim
6763
// runtime representation
6864
// defer to ONNX runtime?
6965

70-
Tensor() {
66+
final OnnxRuntime.OrtTensor rtTensor;
67+
68+
public Tensor(long... data) {
69+
this(OnnxRuntime.getInstance().createFlatTensor(data));
70+
}
71+
72+
public Tensor(float... data) {
73+
this(OnnxRuntime.getInstance().createFlatTensor(data));
74+
}
75+
76+
Tensor(OnnxRuntime.OrtTensor rtTensor) {
77+
this.rtTensor = rtTensor;
7178
}
7279

7380
enum ElementType {
@@ -111,8 +118,22 @@ public String onnxName() {
111118
return name().toLowerCase();
112119
}
113120

121+
int size() {
122+
return switch (this) {
123+
case UINT8, INT8, BOOL, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ -> 1;
124+
case UINT16, INT16, FLOAT16, BFLOAT16 -> 2;
125+
case UINT32, INT32, FLOAT -> 4;
126+
case UINT64, INT64, DOUBLE -> 8;
127+
default -> 0;
128+
};
129+
}
130+
114131
public static ElementType fromOnnxName(String name) {
115132
return ElementType.valueOf(name.toUpperCase());
116133
}
134+
135+
public static ElementType fromOnnxId(int id) {
136+
return values()[id - 1];
137+
}
117138
}
118139
}

‎cr-examples/onnx/src/main/java/oracle/code/onnx/ir/OnnxOp.java

+2
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ public boolean isVariadoc() {
144144

145145
String name();
146146

147+
int ordinal();
148+
147149
OnnxType type();
148150

149151
Quantifier quantifier();
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package oracle.code.onnx;
2+
3+
import java.nio.FloatBuffer;
4+
import java.nio.LongBuffer;
5+
import java.util.List;
6+
import oracle.code.onnx.ir.OnnxOps;
7+
import org.junit.jupiter.api.Test;
8+
9+
import static oracle.code.onnx.Tensor.ElementType.*;
10+
11+
import static org.junit.jupiter.api.Assertions.*;
12+
13+
public class RuntimeTest {
14+
15+
@Test
16+
public void test() throws Exception {
17+
var ort = OnnxRuntime.getInstance();
18+
try (var absOp = ort.createSession(OnnxProtoBuilder.buildOpModel(OnnxOps.Abs.SCHEMA, List.of(FLOAT, FLOAT)));
19+
var addOp = ort.createSession(OnnxProtoBuilder.buildOpModel(OnnxOps.Add.SCHEMA, List.of(FLOAT, FLOAT)))) {
20+
21+
assertEquals(1, absOp.getNumberOfInputs());
22+
assertEquals(1, absOp.getNumberOfOutputs());
23+
24+
assertEquals(2, addOp.getNumberOfInputs());
25+
assertEquals(1, addOp.getNumberOfOutputs());
26+
27+
var inputTensor = ort.createFlatTensor(-1f, 2, -3, 4, -5, 6);
28+
29+
var absExpectedTensor = ort.createFlatTensor(1f, 2, 3, 4, 5, 6);
30+
31+
var absResult = absOp.run(List.of(inputTensor));
32+
33+
assertEquals(1, absResult.size());
34+
35+
var absOutputTensor = (OnnxRuntime.OrtTensor)absResult.getFirst();
36+
37+
assertTensorEquals(absExpectedTensor, absOutputTensor);
38+
39+
var addResult = addOp.run(List.of(inputTensor, absOutputTensor));
40+
41+
assertEquals(1, addResult.size());
42+
43+
var addOutputTensor = (OnnxRuntime.OrtTensor)addResult.getFirst();
44+
45+
var addExpectedTensor = ort.createFlatTensor(0f, 4, 0, 8, 0, 12);
46+
47+
assertTensorEquals(addExpectedTensor, addOutputTensor);
48+
}
49+
}
50+
51+
static void assertTensorEquals(OnnxRuntime.OrtTensor expectedTensor, OnnxRuntime.OrtTensor actualTensor) {
52+
var expectedType = expectedTensor.getTensorTypeAndShape();
53+
var expectedShape = expectedType.getShape();
54+
55+
var actualType = actualTensor.getTensorTypeAndShape();
56+
var actualShape = actualType.getShape();
57+
58+
assertEquals(expectedShape.getDimensionsCount(), actualShape.getDimensionsCount());
59+
for (int i = 0; i < expectedShape.getDimensionsCount(); i++) {
60+
assertEquals(expectedShape.getDimension(i), actualShape.getDimension(i));
61+
}
62+
63+
assertEquals(expectedType.getTensorElementType(), actualType.getTensorElementType());
64+
assertEquals(expectedType.getTensorShapeElementCount(), actualType.getTensorShapeElementCount());
65+
66+
assertEqualData(expectedTensor.asByteBuffer().asFloatBuffer(), actualTensor.asByteBuffer().asFloatBuffer());
67+
}
68+
69+
static void assertEqualData(FloatBuffer expectedData, FloatBuffer actualData) {
70+
assertEquals(expectedData.capacity(), actualData.capacity());
71+
for (int i = 0; i < expectedData.capacity(); i++) {
72+
assertEquals(expectedData.get(i), actualData.get(i), 1e-6f);
73+
}
74+
}
75+
76+
static void assertEqualData(LongBuffer expectedData, LongBuffer actualData) {
77+
assertEquals(expectedData.capacity(), actualData.capacity());
78+
for (int i = 0; i < expectedData.capacity(); i++) {
79+
assertEquals(expectedData.get(i), actualData.get(i));
80+
}
81+
}
82+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package oracle.code.onnx;
2+
3+
import java.lang.invoke.MethodHandles;
4+
import java.util.List;
5+
import java.util.Optional;
6+
import jdk.incubator.code.CodeReflection;
7+
import jdk.incubator.code.Op;
8+
import jdk.incubator.code.op.CoreOp;
9+
import oracle.code.onnx.compiler.OnnxTransformer;
10+
import org.junit.jupiter.api.Test;
11+
12+
import static org.junit.jupiter.api.Assertions.*;
13+
14+
public class SimpleTest {
15+
16+
// Java code model -> ONNX code model -> ONNX runtime instance -> execute via ORT
17+
// Run directly, each operation reflectively executes via ORT
18+
@CodeReflection
19+
public static Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
20+
return OnnxOperators.Add(a, b);
21+
}
22+
23+
@Test
24+
public void testAdd() throws Exception {
25+
var a = new Tensor(1f, 2, 3);
26+
var b = new Tensor(6f, 5, 4);
27+
assertEquals(
28+
add(a, b),
29+
new Tensor(OnnxRuntime.getInstance().runFunc(getOnnxModel("add", Tensor.class, Tensor.class),
30+
List.of(a.rtTensor, b.rtTensor)).getFirst()));
31+
}
32+
33+
@CodeReflection
34+
public static Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
35+
return OnnxOperators.Shape(OnnxOperators.Reshape(data, shape, Optional.empty()), Optional.empty(), Optional.empty());
36+
}
37+
38+
@Test
39+
public void testReshapeAndShape() throws Exception {
40+
var data = new Tensor(1f, 2, 3, 4, 5, 6, 7, 8);
41+
var shape = new Tensor(2l, 2, 2);
42+
assertEquals(shape, reshapeAndShape(data, shape));
43+
assertEquals(shape, new Tensor(OnnxRuntime.getInstance().runFunc(getOnnxModel("reshapeAndShape", Tensor.class, Tensor.class),
44+
List.of(data.rtTensor, shape.rtTensor)).getFirst()));
45+
}
46+
47+
private static CoreOp.FuncOp getOnnxModel(String name, Class... params) throws NoSuchMethodException {
48+
return OnnxTransformer.transform(MethodHandles.publicLookup(),
49+
Op.ofMethod(SimpleTest.class.getDeclaredMethod(name, params)).get());
50+
}
51+
52+
static void assertEquals(Tensor actual, Tensor expected) {
53+
var expectedTS = expected.rtTensor.getTensorTypeAndShape();
54+
var actualTS = actual.rtTensor.getTensorTypeAndShape();
55+
assertSame(expectedTS.getTensorElementType(), actualTS.getTensorElementType());
56+
57+
// @@@ assert equal shapes
58+
59+
switch (actualTS.getTensorElementType()) {
60+
case FLOAT ->
61+
RuntimeTest.assertEqualData(expected.rtTensor.asByteBuffer().asFloatBuffer(), actual.rtTensor.asByteBuffer().asFloatBuffer());
62+
case INT64 ->
63+
RuntimeTest.assertEqualData(expected.rtTensor.asByteBuffer().asLongBuffer(), actual.rtTensor.asByteBuffer().asLongBuffer());
64+
default ->
65+
throw new UnsupportedOperationException(); // @@@ ToDo
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)
Please sign in to comment.