Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ONNX FFM Runtime initial work #311

Closed
wants to merge 20 commits into from
Closed
Changes from 1 commit
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
import java.util.stream.Stream;
import oracle.code.onnx.ir.OnnxOp;
import oracle.code.onnx.Tensor.ElementType;
import oracle.code.onnx.ir.OnnxType;

// Generated from onnx.proto3
sealed class OnnxProtoBuilder<T extends OnnxProtoBuilder> {
@@ -263,8 +264,7 @@ <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) {
// @@@ tensors only
// order of building defines order inside protobufs
static ByteBuffer buildOpModel(OnnxOp.OnnxSchema schema, List<ElementType> inputElementTypes) {
// @@@ output element types inferred from the first input
int outputElementType = inputElementTypes.getFirst().id;
var fallback = inputElementTypes.getFirst();
var bytes = new ModelProto()
.ir_version(IR_VERSION)
.graph(new GraphProto()
@@ -276,13 +276,65 @@ static ByteBuffer buildOpModel(OnnxOp.OnnxSchema schema, List<ElementType> input
.name(i.name())
.type(new TypeProto()
// inputValues matching schema inputs by OnnxParameter::ordinal
.tensor_type(new Tensor().elem_type(inputElementTypes.get(i.ordinal()).id)))))
.tensor_type(new Tensor().elem_type(resolveElementType(i.type(), inputElementTypes.get(i.ordinal())))))))
.forEach(schema.outputs(), (g, o) -> g.output(new ValueInfoProto()
.name(o.name())
.type(new TypeProto()
.tensor_type(new Tensor().elem_type(outputElementType))))))
.tensor_type(new Tensor().elem_type(resolveElementType(o.type(), fallback)))))))
.opset_import(new OperatorSetIdProto().version(OPSET_VERSION))
.buf.toByteArray();
return ByteBuffer.allocateDirect(bytes.length).put(bytes).asReadOnlyBuffer();
}

private static int resolveElementType(OnnxType schemeType, ElementType fallback) {
if (schemeType == OnnxType.TENSOR_FLOAT32) {
return 1;
} else if(schemeType == OnnxType.TENSOR_UINT8) {
return 2;
} else if(schemeType == OnnxType.TENSOR_INT8) {
return 3;
} else if(schemeType == OnnxType.TENSOR_UINT16) {
return 4;
} else if(schemeType == OnnxType.TENSOR_INT16) {
return 5;
} else if(schemeType == OnnxType.TENSOR_INT32) {
return 6;
} else if(schemeType == OnnxType.TENSOR_INT64) {
return 7;
} else if(schemeType == OnnxType.TENSOR_STRING) {
return 8;
} else if(schemeType == OnnxType.TENSOR_BOOL) {
return 9;
} else if(schemeType == OnnxType.TENSOR_FLOAT16) {
return 10;
} else if(schemeType == OnnxType.TENSOR_FLOAT64) {
return 11;
} else if(schemeType == OnnxType.TENSOR_UINT32) {
return 12;
} else if(schemeType == OnnxType.TENSOR_UINT64) {
return 13;
} else if(schemeType == OnnxType.TENSOR_COMPLEX64) {
return 14;
} else if(schemeType == OnnxType.TENSOR_COMPLEX128) {
return 15;
} else if(schemeType == OnnxType.TENSOR_BFLOAT16) {
return 16;
} else if(schemeType == OnnxType.TENSOR_FLOAT8E4M3FN) {
return 17;
} else if(schemeType == OnnxType.TENSOR_FLOAT8E4M3FNUZ) {
return 18;
} else if(schemeType == OnnxType.TENSOR_FLOAT8E5M2) {
return 19;
} else if(schemeType == OnnxType.TENSOR_FLOAT8E5M2FNUZ) {
return 20;
} else if(schemeType == OnnxType.TENSOR_UINT4) {
return 21;
} else if(schemeType == OnnxType.TENSOR_INT4) {
return 22;
} else if(schemeType == OnnxType.TENSOR_FLOAT4E2M1) {
return 23;
} else {
return fallback.id;
}
}
}
Original file line number Diff line number Diff line change
@@ -25408,7 +25408,7 @@ public Quantifier quantifier() {
}

public enum OutputParameter implements OnnxParameter {
shape(TypeConstraint.T1.typeVariable(), Quantifier.REQUIRED),
shape(OnnxType.tensor(OnnxType.int64()), Quantifier.REQUIRED),
;

final OnnxType type;
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package oracle.code.onnx;

import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.List;
import java.util.Map;
import oracle.code.onnx.ir.OnnxOps;
import org.junit.jupiter.api.Test;

import static oracle.code.onnx.Tensor.ElementType.*;

import static org.junit.jupiter.api.Assertions.*;
@@ -71,4 +72,11 @@ static void assertEqualData(FloatBuffer expectedData, FloatBuffer actualData) {
assertEquals(expectedData.get(i), actualData.get(i), 1e-6f);
}
}

static void assertEqualData(LongBuffer expectedData, LongBuffer actualData) {
assertEquals(expectedData.capacity(), actualData.capacity());
for (int i = 0; i < expectedData.capacity(); i++) {
assertEquals(expectedData.get(i), actualData.get(i));
}
}
}
23 changes: 17 additions & 6 deletions cr-examples/onnx/src/test/java/oracle/code/onnx/SimpleTest.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oracle.code.onnx;

import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.Optional;
import jdk.incubator.code.CodeReflection;
import org.junit.jupiter.api.Test;
@@ -16,22 +17,32 @@ public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {

@Test
public void testAdd() {
assertEquals(add(new Tensor(1f, 2, 3), new Tensor(6f, 5, 4)), 7, 7, 7);
assertEquals(add(new Tensor(1f, 2, 3), new Tensor(6f, 5, 4)), 7f, 7, 7);
}

@CodeReflection
public Tensor<Float> reshape(Tensor<Float> a, Tensor<Long> b) {
return OnnxOperators.Reshape(a, b, Optional.empty());
public Tensor<Float> reshape(Tensor<Float> data, Tensor<Long> shape) {
return OnnxOperators.Reshape(data, shape, Optional.empty());
}

@CodeReflection
public Tensor<Long> shape(Tensor<Float> data) {
return OnnxOperators.Shape(data, Optional.empty(), Optional.empty());
}

@Test
public void testReshape() {
var reshaped = reshape(new Tensor(1f, 2, 3, 4, 5, 6, 7, 8), new Tensor(2, 2, 2));
public void testReshapeAndShape() {
var reshaped = reshape(new Tensor(1f, 2, 3, 4, 5, 6, 7, 8), new Tensor(2l, 2, 2));
assertEquals(reshaped, 1f, 2, 3, 4, 5, 6, 7, 8);

var shape = shape(reshaped);
assertEquals(shape, 2l, 2, 2);
}

static void assertEquals(Tensor actual, float... expected) {
RuntimeTest.assertEqualData(FloatBuffer.wrap(expected), actual.rtTensor.asByteBuffer().asFloatBuffer());
}

static void assertEquals(Tensor actual, long... expected) {
RuntimeTest.assertEqualData(LongBuffer.wrap(expected), actual.rtTensor.asByteBuffer().asLongBuffer());
}
}