Skip to content

Commit 416c84b

Browse files
committedFeb 18, 2025
OnnxRuntime and OnnxProtoBuilder continuation
1 parent 9933983 commit 416c84b

26 files changed

+406
-139
lines changed
 

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxProtoBuilder.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import java.io.ByteArrayOutputStream;
44
import java.nio.ByteBuffer;
5+
import java.nio.ByteOrder;
56
import java.nio.charset.StandardCharsets;
67
import java.util.IdentityHashMap;
78
import java.util.List;
@@ -306,6 +307,11 @@ static ByteBuffer buildFuncModel(FuncOp model) {
306307
String getName(Value v) {
307308
return computeIfAbsent(v, _ -> "#" + size());
308309
}
310+
String getName(Value v, int subIndex) {
311+
var name = getName(v);
312+
if (subIndex != 0) name += "." + subIndex;
313+
return name;
314+
}
309315
};
310316
var entryBlock = model.body().entryBlock();
311317
var bytes = new ModelProto()
@@ -319,16 +325,15 @@ String getName(Value v) {
319325
switch (op) {
320326
case OnnxOp onnxOp ->
321327
g.node(new NodeProto()
322-
.forEach(op.operands(), (n, p) -> n.input(indexer.getName(p)))
323-
.output(indexer.getName(op.result()))
328+
.forEach(op.operands(), (n, i) -> n.input(indexer.getName(i)))
329+
.forEach(onnxOp.onnxOutputs(), (n, o) -> n.output(indexer.getName(op.result(), o.ordinal())))
324330
.op_type(op.opName())
325331
.forEach(onnxOp.onnxAttributes().entrySet(), (n, ae) -> n.attribute(buildAttribute(ae.getKey(), ae.getValue()))));
326332
case CoreOp.ReturnOp _ -> {
327333
// skip
328334
}
329-
case CoreOp.TupleLoadOp _ -> {
330-
// @@@ hack to forward to the first from the tuple
331-
indexer.put(op.result(), indexer.getName(op.operands().getFirst()));
335+
case CoreOp.TupleLoadOp tlo -> {
336+
indexer.put(op.result(), indexer.getName(op.operands().getFirst(), tlo.index()));
332337
}
333338
default ->
334339
throw new UnsupportedOperationException(op.toText());
@@ -340,6 +345,7 @@ String getName(Value v) {
340345
.tensor_type(new Tensor().elem_type(((OnnxType.TensorType)model.body().yieldType()).eType().id())))))
341346
.opset_import(new OperatorSetIdProto().version(OPSET_VERSION))
342347
.buf.toByteArray();
348+
// OnnxProtoPrinter.printModel(ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN));
343349
return ByteBuffer.allocateDirect(bytes.length).put(bytes).asReadOnlyBuffer();
344350
}
345351

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
package oracle.code.onnx;
2+
3+
import java.io.RandomAccessFile;
4+
import java.nio.ByteBuffer;
5+
import java.nio.ByteOrder;
6+
import java.nio.channels.FileChannel;
7+
import java.util.Arrays;
8+
9+
public enum OnnxProtoPrinter {
10+
BYTE, BYTES, INT, LONG, FLOAT, DOUBLE, STRING,
11+
Attribute, ValueInfoProto, NodeProto, TrainingInfoProto, ModelProto, StringStringEntryProto, TensorAnnotation,
12+
GraphProto, TensorProto, Segment, SparseTensorProto, TensorShapeProto, Dimension, TypeProto, Tensor, Sequence,
13+
Map, Optional, SparseTensor, OperatorSetIdProto, FunctionProto;
14+
15+
static {
16+
init(Attribute,
17+
1, "name", STRING,
18+
2, "f", FLOAT,
19+
3, "i", LONG,
20+
4, "s", BYTES,
21+
5, "t", TensorProto,
22+
6, "g", GraphProto,
23+
7, "floats", FLOAT,
24+
8, "ints", LONG,
25+
9, "strings", BYTES,
26+
10, "tensors", TensorProto,
27+
11, "graphs", GraphProto,
28+
13, "doc_string", STRING,
29+
14, "tp", TypeProto,
30+
15, "type_protos", TypeProto,
31+
20, "type", INT,
32+
21, "ref_attr_name", STRING,
33+
22, "sparse_tensor", SparseTensorProto,
34+
23, "sparse_tensors", SparseTensorProto);
35+
init(ValueInfoProto,
36+
1, "name", STRING,
37+
2, "type", TypeProto,
38+
3, "doc_string", STRING,
39+
4, "metadata_props", StringStringEntryProto);
40+
init(NodeProto,
41+
1, "input", STRING,
42+
2, "output", STRING,
43+
3, "name", STRING,
44+
4, "op_type", STRING,
45+
5, "attribute", Attribute,
46+
6, "doc_string", STRING,
47+
7, "domain", STRING,
48+
8, "overload", STRING,
49+
9, "metadata_props", StringStringEntryProto);
50+
init(TrainingInfoProto,
51+
1, "initialization", GraphProto,
52+
2, "algorithm", GraphProto,
53+
3, "initialization_binding", StringStringEntryProto,
54+
4, "update_binding", StringStringEntryProto);
55+
init(ModelProto,
56+
1, "ir_version", LONG,
57+
2, "producer_name", STRING,
58+
3, "producer_version", STRING,
59+
4, "domain", STRING,
60+
5, "model_version", LONG,
61+
6, "doc_string", STRING,
62+
7, "graph", GraphProto,
63+
8, "opset_import", OperatorSetIdProto,
64+
14, "metadata_props", StringStringEntryProto,
65+
20, "training_info", TrainingInfoProto,
66+
25, "functions", FunctionProto);
67+
init(StringStringEntryProto,
68+
1, "key", STRING,
69+
2, "value", STRING);
70+
init(TensorAnnotation,
71+
1, "tensor_name", STRING,
72+
2, "quant_parameter_tensor_names", StringStringEntryProto);
73+
init(GraphProto,
74+
1, "node", NodeProto,
75+
2, "name", STRING,
76+
5, "initializer", TensorProto,
77+
10, "doc_string", STRING,
78+
11, "input", ValueInfoProto,
79+
12, "output", ValueInfoProto,
80+
13, "value_info", ValueInfoProto,
81+
14, "quantization_annotation", TensorAnnotation,
82+
15, "sparse_initializer", SparseTensorProto,
83+
16, "metadata_props", StringStringEntryProto);
84+
init(TensorProto,
85+
1, "dims", LONG,
86+
2, "data_type", INT,
87+
3, "segment", Segment,
88+
4, "float_data", FLOAT,
89+
5, "int32_data", INT,
90+
6, "string_data", BYTES,
91+
7, "int64_data", LONG,
92+
8, "name", STRING,
93+
9, "raw_data", BYTES,
94+
10, "double_data", DOUBLE,
95+
11, "uint64_data", LONG,
96+
12, "doc_string", STRING,
97+
13, "external_data", StringStringEntryProto,
98+
14, "data_location", INT,
99+
16, "metadata_props", StringStringEntryProto);
100+
init(Segment,
101+
1, "begin", LONG,
102+
2, "end", LONG);
103+
init(SparseTensorProto,
104+
1, "values", TensorProto,
105+
2, "indices", TensorProto,
106+
3, "dims", LONG);
107+
init(TensorShapeProto,
108+
1, "dim", Dimension);
109+
init(Dimension,
110+
1, "dim_value", LONG,
111+
2, "dim_param", STRING,
112+
3, "denotation", STRING);
113+
init(TypeProto,
114+
1, "tensor_type", Tensor,
115+
4, "sequence_type", Sequence,
116+
5, "map_type", Map,
117+
6, "denotation", STRING,
118+
8, "sparse_tensor_type", SparseTensor,
119+
9, "optional_type", Optional);
120+
init(Tensor,
121+
1, "elem_type", INT,
122+
2, "shape", TensorShapeProto);
123+
init(Sequence,
124+
1, "elem_type", TypeProto);
125+
init(Map,
126+
1, "key_type", INT,
127+
2, "value_type", TypeProto);
128+
init(Optional,
129+
1, "elem_type", TypeProto);
130+
init(SparseTensor,
131+
1, "elem_type", INT,
132+
2, "shape", TensorShapeProto);
133+
init(OperatorSetIdProto,
134+
1, "domain", STRING,
135+
2, "version", LONG);
136+
init(FunctionProto,
137+
1, "name", STRING,
138+
4, "input", STRING,
139+
5, "output", STRING,
140+
6, "attribute", STRING,
141+
7, "node", NodeProto,
142+
8, "doc_string", STRING,
143+
9, "opset_import", OperatorSetIdProto,
144+
10, "domain", STRING,
145+
11, "attribute_proto", Attribute,
146+
12, "value_info", ValueInfoProto,
147+
13, "overload", STRING,
148+
14, "metadata_props", StringStringEntryProto);
149+
}
150+
151+
private record Field(String name, OnnxProtoPrinter type) {}
152+
153+
private static void init(OnnxProtoPrinter proto, Object... fields) {
154+
proto.fields = new Field[(int)fields[fields.length - 3]];
155+
for (int i = 0; i < fields.length; i += 3) {
156+
proto.fields[(int)fields[i] - 1] = new Field((String)fields[i + 1], (OnnxProtoPrinter)fields[i + 2]);
157+
}
158+
}
159+
160+
private static long decodeVarint(ByteBuffer data) {
161+
int i, shift = 0;
162+
long value = 0;
163+
do {
164+
value |= ((i = data.get()) & 0x7f) << shift;
165+
shift += 7;
166+
} while ((i & 0x80) != 0);
167+
return value;
168+
}
169+
170+
private Field[] fields;
171+
172+
public void print(int indent, ByteBuffer data) {
173+
while (data.remaining() > 0) {
174+
long tag = decodeVarint(data);
175+
var f = fields[((int)tag >> 3) - 1];
176+
System.out.print(" ".repeat(indent) + f.type() + " " + f.name() + " ");
177+
switch (f.type) {
178+
case BYTE, INT, LONG ->
179+
System.out.println(decodeVarint(data));
180+
case FLOAT ->
181+
System.out.println(data.getFloat());
182+
case DOUBLE ->
183+
System.out.println(data.getDouble());
184+
case BYTES -> {
185+
var bytes = new byte[(int)decodeVarint(data)];
186+
data.get(bytes);
187+
System.out.println(Arrays.toString(bytes));
188+
}
189+
case STRING -> {
190+
var bytes = new byte[(int)decodeVarint(data)];
191+
data.get(bytes);
192+
System.out.println('"' + new String(bytes) + '"');
193+
}
194+
default -> {
195+
var size = decodeVarint(data);
196+
int limit = data.limit();
197+
System.out.println();
198+
f.type().print(indent + 1, data.limit(data.position() + (int)size));
199+
data.limit(limit);
200+
}
201+
}
202+
}
203+
}
204+
205+
public static void printModel(ByteBuffer model) {
206+
ModelProto.print(0, model);
207+
}
208+
209+
public static void main(String... args) throws Exception {
210+
for (var fName : args) {
211+
System.out.println(fName);
212+
try (var in = new RandomAccessFile(fName, "r")) {
213+
ModelProto.print(1, in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()).order(ByteOrder.LITTLE_ENDIAN));
214+
}
215+
}
216+
}
217+
}

‎cr-examples/onnx/src/main/java/oracle/code/onnx/Tensor.java

+26-11
Original file line numberDiff line numberDiff line change
@@ -66,40 +66,55 @@ public class Tensor<T> extends OnnxNumber {
6666

6767
public static final long[] SCALAR_SHAPE = new long[0];
6868

69+
public static Tensor<Byte> ofScalar(byte b) {
70+
return ofShape(SCALAR_SHAPE, b);
71+
}
72+
6973
public static Tensor<Long> ofScalar(long l) {
70-
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_LONG, l);
71-
return new Tensor<>(data, ElementType.INT64, SCALAR_SHAPE);
74+
return ofShape(SCALAR_SHAPE, l);
7275
}
7376

7477
public static Tensor<Float> ofScalar(float f) {
75-
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_FLOAT, f);
76-
return new Tensor(data, ElementType.FLOAT, SCALAR_SHAPE);
78+
return ofShape(SCALAR_SHAPE, f);
7779
}
7880

81+
7982
public static Tensor<Byte> ofFlat(byte... values) {
80-
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_BYTE, values);
81-
return new Tensor(data, ElementType.UINT8, new long[]{values.length});
83+
return ofShape(new long[]{values.length}, values);
8284
}
8385

8486
public static Tensor<Long> ofFlat(long... values) {
85-
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_LONG, values);
86-
return new Tensor(data, ElementType.INT64, new long[]{values.length});
87+
return ofShape(new long[]{values.length}, values);
8788
}
8889

8990
public static Tensor<Float> ofFlat(float... values) {
91+
return ofShape(new long[]{values.length}, values);
92+
}
93+
94+
public static Tensor<Byte> ofShape(long[] shape, byte... values) {
95+
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_BYTE, values);
96+
return new Tensor(data, ElementType.UINT8, shape);
97+
}
98+
99+
public static Tensor<Long> ofShape(long[] shape, long... values) {
100+
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_LONG, values);
101+
return new Tensor(data, ElementType.INT64, shape);
102+
}
103+
104+
public static Tensor<Float> ofShape(long[] shape, float... values) {
90105
var data = Arena.ofAuto().allocateFrom(ValueLayout.JAVA_FLOAT, values);
91-
return new Tensor(data, ElementType.FLOAT, new long[]{values.length});
106+
return new Tensor(data, ElementType.FLOAT, shape);
92107
}
93108

94109
// Mandatory reference to dataAddr to avoid its garbage colletion
95110
private final MemorySegment dataAddr;
96111
final MemorySegment tensorAddr;
97112

98-
Tensor(MemorySegment dataAddr, ElementType type, long... shape) {
113+
public Tensor(MemorySegment dataAddr, ElementType type, long... shape) {
99114
this(dataAddr, OnnxRuntime.getInstance().createTensor(dataAddr, type, shape));
100115
}
101116

102-
Tensor(MemorySegment tensorAddr) {
117+
public Tensor(MemorySegment tensorAddr) {
103118
this(null, tensorAddr);
104119
}
105120

0 commit comments

Comments
 (0)
Please sign in to comment.