2
2
3
3
import java .io .ByteArrayOutputStream ;
4
4
import java .nio .ByteBuffer ;
5
- import java .nio .ByteOrder ;
6
5
import java .nio .charset .StandardCharsets ;
7
6
import java .util .IdentityHashMap ;
8
7
import java .util .List ;
9
8
import java .util .function .BiConsumer ;
9
+ import java .util .stream .IntStream ;
10
10
import jdk .incubator .code .Value ;
11
11
import jdk .incubator .code .op .CoreOp ;
12
12
import jdk .incubator .code .op .CoreOp .FuncOp ;
13
13
import oracle .code .onnx .ir .OnnxOp ;
14
- import oracle .code .onnx .Tensor .ElementType ;
15
14
import oracle .code .onnx .ir .OnnxType ;
16
15
17
16
// Generated from onnx.proto3
@@ -262,42 +261,6 @@ <P> T forEach(Iterable<P> sup, BiConsumer<T, ? super P> cons) {
262
261
static final int IR_VERSION = 10 ;
263
262
static final int OPSET_VERSION = 21 ;
264
263
265
- // @@@ tensors only
266
- static ByteBuffer buildOpModel (OnnxOp .OnnxSchema schema , List <java .util .Optional <ElementType >> inputElementTypes , List <Object > attributes ) {
267
- var bytes = new ModelProto ()
268
- .ir_version (IR_VERSION )
269
- .graph (new GraphProto ()
270
- .forEach (schema .inputs (), (g , i ) -> {
271
- if (inputElementTypes .get (i .ordinal ()).isPresent ()) {
272
- g .input (new ValueInfoProto ()
273
- .name (i .name ())
274
- .type (new TypeProto ()
275
- // inputValues match schema inputs by OnnxParameter::ordinal
276
- .tensor_type (new Tensor ().elem_type (inputElementTypes .get (i .ordinal ()).get ().id ))));
277
- }
278
- })
279
- .node (new NodeProto ()
280
- .forEach (schema .inputs (), (n , i ) -> n .input (i .name ()))
281
- .forEach (schema .outputs (), (n , o ) -> n .output (o .name ()))
282
- .op_type (schema .name ())
283
- .forEach (schema .attributes (), (n , a ) -> {
284
- // attributes match schema by OnnxAttribute::ordinal
285
- var attrValue = attributes .get (a .ordinal ());
286
- if (a .isOptional ()) {
287
- if (attrValue instanceof java .util .Optional o && o .isPresent ()) {
288
- n .attribute (buildAttribute (a .name (), o .get ()));
289
- }
290
- } else {
291
- n .attribute (buildAttribute (a .name (), attrValue ));
292
- }
293
- }))
294
- .forEach (schema .outputs (), (g , o ) -> g .output (new ValueInfoProto ()
295
- .name (o .name ()))))
296
- .opset_import (new OperatorSetIdProto ().version (OPSET_VERSION ))
297
- .buf .toByteArray ();
298
- return ByteBuffer .allocateDirect (bytes .length ).put (bytes ).asReadOnlyBuffer ();
299
- }
300
-
301
264
// @@@ unchecked constraints:
302
265
// tensor FuncOp parameters and single tensor return type
303
266
// OnnxOps (with tensor operands and single tensor return value) and ReturnOp (returning single tensor)
@@ -313,36 +276,43 @@ String getName(Value v, int subIndex) {
313
276
return name ;
314
277
}
315
278
};
316
- var entryBlock = model .body ().entryBlock ();
279
+ return buildModel (
280
+ model .body ().entryBlock ().parameters ().stream ().map (v -> new Input (indexer .getName (v ), ((OnnxType .TensorType )v .type ()).eType ().id ())).toList (),
281
+ model .body ().entryBlock ().ops ().stream ().<OpNode >mapMulti ((op , opNodes ) -> {
282
+ switch (op ) {
283
+ case OnnxOp onnxOp ->
284
+ opNodes .accept (new OpNode (
285
+ onnxOp .opName (),
286
+ onnxOp .operands ().stream ().map (v -> indexer .getName (v )).toList (),
287
+ IntStream .range (0 , onnxOp .onnxOutputs ().size ()).mapToObj (o -> indexer .getName (onnxOp .result (), o )).toList (),
288
+ onnxOp .onnxAttributes ()));
289
+ case CoreOp .ReturnOp _ -> { // skip
290
+ }
291
+ case CoreOp .TupleLoadOp tlo ->
292
+ indexer .put (tlo .result (), indexer .getName (tlo .operands ().getFirst (), tlo .index ()));
293
+ default ->
294
+ throw new UnsupportedOperationException (op .toText ());
295
+ }
296
+ }).toList (),
297
+ List .of (indexer .getName (model .body ().entryBlock ().terminatingOp ().operands ().getFirst ())));
298
+ }
299
+
300
+ record Input (String name , int tensorElementType ) {}
301
+ record OpNode (String opName , List <String > inputNames , List <String > outputNames , java .util .Map <String , Object > attributes ) {}
302
+
303
+ static ByteBuffer buildModel (List <Input > inputs , List <OpNode > ops , List <String > outputNames ) {
317
304
var bytes = new ModelProto ()
318
305
.ir_version (IR_VERSION )
319
306
.graph (new GraphProto ()
320
- .forEach (entryBlock .parameters (), (g , p ) -> g .input (new ValueInfoProto ()
321
- .name (indexer .getName (p ))
322
- .type (new TypeProto ()
323
- .tensor_type (new Tensor ().elem_type (((OnnxType .TensorType )p .type ()).eType ().id ())))))
324
- .forEach (entryBlock .ops (), (g , op ) -> {
325
- switch (op ) {
326
- case OnnxOp onnxOp ->
327
- g .node (new NodeProto ()
328
- .forEach (op .operands (), (n , i ) -> n .input (indexer .getName (i )))
329
- .forEach (onnxOp .onnxOutputs (), (n , o ) -> n .output (indexer .getName (op .result (), o .ordinal ())))
330
- .op_type (op .opName ())
331
- .forEach (onnxOp .onnxAttributes ().entrySet (), (n , ae ) -> n .attribute (buildAttribute (ae .getKey (), ae .getValue ()))));
332
- case CoreOp .ReturnOp _ -> {
333
- // skip
334
- }
335
- case CoreOp .TupleLoadOp tlo -> {
336
- indexer .put (op .result (), indexer .getName (op .operands ().getFirst (), tlo .index ()));
337
- }
338
- default ->
339
- throw new UnsupportedOperationException (op .toText ());
340
- }
341
- })
342
- .output (new ValueInfoProto ()
343
- .name (indexer .getName (entryBlock .terminatingOp ().operands ().getFirst ()))
344
- .type (new TypeProto ()
345
- .tensor_type (new Tensor ().elem_type (((OnnxType .TensorType )model .body ().yieldType ()).eType ().id ())))))
307
+ .forEach (inputs , (g , input ) -> g
308
+ .input (new ValueInfoProto ().name (input .name ())
309
+ .type (new TypeProto ().tensor_type (new Tensor ().elem_type (input .tensorElementType ())))))
310
+ .forEach (ops , (g , op ) -> g .node (new NodeProto ()
311
+ .forEach (op .inputNames (), (n , iName ) -> n .input (iName ))
312
+ .forEach (op .outputNames (), (n , oName ) -> n .output (oName ))
313
+ .op_type (op .opName ())
314
+ .forEach (op .attributes ().entrySet (), (n , ae ) -> n .attribute (buildAttribute (ae .getKey (), ae .getValue ())))))
315
+ .forEach (outputNames , (g , oName ) -> g .output (new ValueInfoProto ().name (oName ))))
346
316
.opset_import (new OperatorSetIdProto ().version (OPSET_VERSION ))
347
317
.buf .toByteArray ();
348
318
// OnnxProtoPrinter.printModel(ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN));
0 commit comments