|
| 1 | +package oracle.code.onnx; |
| 2 | + |
| 3 | +import java.io.RandomAccessFile; |
| 4 | +import java.nio.ByteBuffer; |
| 5 | +import java.nio.ByteOrder; |
| 6 | +import java.nio.channels.FileChannel; |
| 7 | +import java.util.Arrays; |
| 8 | + |
| 9 | +public enum OnnxProtoPrinter { |
| 10 | + BYTE, BYTES, INT, LONG, FLOAT, DOUBLE, STRING, |
| 11 | + Attribute, ValueInfoProto, NodeProto, TrainingInfoProto, ModelProto, StringStringEntryProto, TensorAnnotation, |
| 12 | + GraphProto, TensorProto, Segment, SparseTensorProto, TensorShapeProto, Dimension, TypeProto, Tensor, Sequence, |
| 13 | + Map, Optional, SparseTensor, OperatorSetIdProto, FunctionProto; |
| 14 | + |
| 15 | + static { |
| 16 | + init(Attribute, |
| 17 | + 1, "name", STRING, |
| 18 | + 2, "f", FLOAT, |
| 19 | + 3, "i", LONG, |
| 20 | + 4, "s", BYTES, |
| 21 | + 5, "t", TensorProto, |
| 22 | + 6, "g", GraphProto, |
| 23 | + 7, "floats", FLOAT, |
| 24 | + 8, "ints", LONG, |
| 25 | + 9, "strings", BYTES, |
| 26 | + 10, "tensors", TensorProto, |
| 27 | + 11, "graphs", GraphProto, |
| 28 | + 13, "doc_string", STRING, |
| 29 | + 14, "tp", TypeProto, |
| 30 | + 15, "type_protos", TypeProto, |
| 31 | + 20, "type", INT, |
| 32 | + 21, "ref_attr_name", STRING, |
| 33 | + 22, "sparse_tensor", SparseTensorProto, |
| 34 | + 23, "sparse_tensors", SparseTensorProto); |
| 35 | + init(ValueInfoProto, |
| 36 | + 1, "name", STRING, |
| 37 | + 2, "type", TypeProto, |
| 38 | + 3, "doc_string", STRING, |
| 39 | + 4, "metadata_props", StringStringEntryProto); |
| 40 | + init(NodeProto, |
| 41 | + 1, "input", STRING, |
| 42 | + 2, "output", STRING, |
| 43 | + 3, "name", STRING, |
| 44 | + 4, "op_type", STRING, |
| 45 | + 5, "attribute", Attribute, |
| 46 | + 6, "doc_string", STRING, |
| 47 | + 7, "domain", STRING, |
| 48 | + 8, "overload", STRING, |
| 49 | + 9, "metadata_props", StringStringEntryProto); |
| 50 | + init(TrainingInfoProto, |
| 51 | + 1, "initialization", GraphProto, |
| 52 | + 2, "algorithm", GraphProto, |
| 53 | + 3, "initialization_binding", StringStringEntryProto, |
| 54 | + 4, "update_binding", StringStringEntryProto); |
| 55 | + init(ModelProto, |
| 56 | + 1, "ir_version", LONG, |
| 57 | + 2, "producer_name", STRING, |
| 58 | + 3, "producer_version", STRING, |
| 59 | + 4, "domain", STRING, |
| 60 | + 5, "model_version", LONG, |
| 61 | + 6, "doc_string", STRING, |
| 62 | + 7, "graph", GraphProto, |
| 63 | + 8, "opset_import", OperatorSetIdProto, |
| 64 | + 14, "metadata_props", StringStringEntryProto, |
| 65 | + 20, "training_info", TrainingInfoProto, |
| 66 | + 25, "functions", FunctionProto); |
| 67 | + init(StringStringEntryProto, |
| 68 | + 1, "key", STRING, |
| 69 | + 2, "value", STRING); |
| 70 | + init(TensorAnnotation, |
| 71 | + 1, "tensor_name", STRING, |
| 72 | + 2, "quant_parameter_tensor_names", StringStringEntryProto); |
| 73 | + init(GraphProto, |
| 74 | + 1, "node", NodeProto, |
| 75 | + 2, "name", STRING, |
| 76 | + 5, "initializer", TensorProto, |
| 77 | + 10, "doc_string", STRING, |
| 78 | + 11, "input", ValueInfoProto, |
| 79 | + 12, "output", ValueInfoProto, |
| 80 | + 13, "value_info", ValueInfoProto, |
| 81 | + 14, "quantization_annotation", TensorAnnotation, |
| 82 | + 15, "sparse_initializer", SparseTensorProto, |
| 83 | + 16, "metadata_props", StringStringEntryProto); |
| 84 | + init(TensorProto, |
| 85 | + 1, "dims", LONG, |
| 86 | + 2, "data_type", INT, |
| 87 | + 3, "segment", Segment, |
| 88 | + 4, "float_data", FLOAT, |
| 89 | + 5, "int32_data", INT, |
| 90 | + 6, "string_data", BYTES, |
| 91 | + 7, "int64_data", LONG, |
| 92 | + 8, "name", STRING, |
| 93 | + 9, "raw_data", BYTES, |
| 94 | + 10, "double_data", DOUBLE, |
| 95 | + 11, "uint64_data", LONG, |
| 96 | + 12, "doc_string", STRING, |
| 97 | + 13, "external_data", StringStringEntryProto, |
| 98 | + 14, "data_location", INT, |
| 99 | + 16, "metadata_props", StringStringEntryProto); |
| 100 | + init(Segment, |
| 101 | + 1, "begin", LONG, |
| 102 | + 2, "end", LONG); |
| 103 | + init(SparseTensorProto, |
| 104 | + 1, "values", TensorProto, |
| 105 | + 2, "indices", TensorProto, |
| 106 | + 3, "dims", LONG); |
| 107 | + init(TensorShapeProto, |
| 108 | + 1, "dim", Dimension); |
| 109 | + init(Dimension, |
| 110 | + 1, "dim_value", LONG, |
| 111 | + 2, "dim_param", STRING, |
| 112 | + 3, "denotation", STRING); |
| 113 | + init(TypeProto, |
| 114 | + 1, "tensor_type", Tensor, |
| 115 | + 4, "sequence_type", Sequence, |
| 116 | + 5, "map_type", Map, |
| 117 | + 6, "denotation", STRING, |
| 118 | + 8, "sparse_tensor_type", SparseTensor, |
| 119 | + 9, "optional_type", Optional); |
| 120 | + init(Tensor, |
| 121 | + 1, "elem_type", INT, |
| 122 | + 2, "shape", TensorShapeProto); |
| 123 | + init(Sequence, |
| 124 | + 1, "elem_type", TypeProto); |
| 125 | + init(Map, |
| 126 | + 1, "key_type", INT, |
| 127 | + 2, "value_type", TypeProto); |
| 128 | + init(Optional, |
| 129 | + 1, "elem_type", TypeProto); |
| 130 | + init(SparseTensor, |
| 131 | + 1, "elem_type", INT, |
| 132 | + 2, "shape", TensorShapeProto); |
| 133 | + init(OperatorSetIdProto, |
| 134 | + 1, "domain", STRING, |
| 135 | + 2, "version", LONG); |
| 136 | + init(FunctionProto, |
| 137 | + 1, "name", STRING, |
| 138 | + 4, "input", STRING, |
| 139 | + 5, "output", STRING, |
| 140 | + 6, "attribute", STRING, |
| 141 | + 7, "node", NodeProto, |
| 142 | + 8, "doc_string", STRING, |
| 143 | + 9, "opset_import", OperatorSetIdProto, |
| 144 | + 10, "domain", STRING, |
| 145 | + 11, "attribute_proto", Attribute, |
| 146 | + 12, "value_info", ValueInfoProto, |
| 147 | + 13, "overload", STRING, |
| 148 | + 14, "metadata_props", StringStringEntryProto); |
| 149 | + } |
| 150 | + |
| 151 | + private record Field(String name, OnnxProtoPrinter type) {} |
| 152 | + |
| 153 | + private static void init(OnnxProtoPrinter proto, Object... fields) { |
| 154 | + proto.fields = new Field[(int)fields[fields.length - 3]]; |
| 155 | + for (int i = 0; i < fields.length; i += 3) { |
| 156 | + proto.fields[(int)fields[i] - 1] = new Field((String)fields[i + 1], (OnnxProtoPrinter)fields[i + 2]); |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + private static long decodeVarint(ByteBuffer data) { |
| 161 | + int i, shift = 0; |
| 162 | + long value = 0; |
| 163 | + do { |
| 164 | + value |= ((i = data.get()) & 0x7f) << shift; |
| 165 | + shift += 7; |
| 166 | + } while ((i & 0x80) != 0); |
| 167 | + return value; |
| 168 | + } |
| 169 | + |
| 170 | + private Field[] fields; |
| 171 | + |
| 172 | + public void print(int indent, ByteBuffer data) { |
| 173 | + while (data.remaining() > 0) { |
| 174 | + long tag = decodeVarint(data); |
| 175 | + var f = fields[((int)tag >> 3) - 1]; |
| 176 | + System.out.print(" ".repeat(indent) + f.type() + " " + f.name() + " "); |
| 177 | + switch (f.type) { |
| 178 | + case BYTE, INT, LONG -> |
| 179 | + System.out.println(decodeVarint(data)); |
| 180 | + case FLOAT -> |
| 181 | + System.out.println(data.getFloat()); |
| 182 | + case DOUBLE -> |
| 183 | + System.out.println(data.getDouble()); |
| 184 | + case BYTES -> { |
| 185 | + var bytes = new byte[(int)decodeVarint(data)]; |
| 186 | + data.get(bytes); |
| 187 | + System.out.println(Arrays.toString(bytes)); |
| 188 | + } |
| 189 | + case STRING -> { |
| 190 | + var bytes = new byte[(int)decodeVarint(data)]; |
| 191 | + data.get(bytes); |
| 192 | + System.out.println('"' + new String(bytes) + '"'); |
| 193 | + } |
| 194 | + default -> { |
| 195 | + var size = decodeVarint(data); |
| 196 | + int limit = data.limit(); |
| 197 | + System.out.println(); |
| 198 | + f.type().print(indent + 1, data.limit(data.position() + (int)size)); |
| 199 | + data.limit(limit); |
| 200 | + } |
| 201 | + } |
| 202 | + } |
| 203 | + } |
| 204 | + |
| 205 | + public static void printModel(ByteBuffer model) { |
| 206 | + ModelProto.print(0, model); |
| 207 | + } |
| 208 | + |
| 209 | + public static void main(String... args) throws Exception { |
| 210 | + for (var fName : args) { |
| 211 | + System.out.println(fName); |
| 212 | + try (var in = new RandomAccessFile(fName, "r")) { |
| 213 | + ModelProto.print(1, in.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, in.length()).order(ByteOrder.LITTLE_ENDIAN)); |
| 214 | + } |
| 215 | + } |
| 216 | + } |
| 217 | +} |
0 commit comments