7
7
import org .junit .jupiter .api .Assertions ;
8
8
import org .junit .jupiter .api .Test ;
9
9
10
+ import static java .util .Optional .empty ;
11
+ import static oracle .code .onnx .OnnxOperators .*;
12
+ import static oracle .code .onnx .OnnxRuntime .execute ;
13
+
10
14
public class SimpleTest {
11
15
12
16
@ CodeReflection
13
17
public Tensor <Float > add (Tensor <Float > a , Tensor <Float > b ) {
14
- return OnnxOperators . Add (a , b );
18
+ return Add (a , b );
15
19
}
16
20
17
21
@ Test
18
22
public void testAdd () throws Exception {
19
23
var a = Tensor .ofFlat (1f , 2 , 3 );
20
24
assertEquals (
21
25
add (a , a ),
22
- OnnxRuntime . execute (() -> add (a , a )));
26
+ execute (() -> add (a , a )));
23
27
}
24
28
25
29
@ CodeReflection
26
30
public Tensor <Float > sub (Tensor <Float > a , Tensor <Float > b ) {
27
- return OnnxOperators . Sub (a , b );
31
+ return Sub (a , b );
28
32
}
29
33
30
34
@ Test
@@ -33,64 +37,64 @@ public void testSub() throws Exception {
33
37
var a = Tensor .ofFlat (1f , 2 , 3 );
34
38
assertEquals (
35
39
sub (a , b ),
36
- OnnxRuntime . execute (() -> sub (a , b )));
40
+ execute (() -> sub (a , b )));
37
41
}
38
42
39
43
@ CodeReflection
40
44
public Tensor <Float > fconstant () {
41
- return OnnxOperators . Constant (-1f );
45
+ return Constant (-1f );
42
46
}
43
47
44
48
@ Test
45
49
public void testFconstant () throws Exception {
46
50
// tests the numbers are encoded correctly
47
51
var expected = Tensor .ofScalar (-1f );
48
52
assertEquals (expected , fconstant ());
49
- assertEquals (expected , OnnxRuntime . execute (() -> fconstant ()));
53
+ assertEquals (expected , execute (() -> fconstant ()));
50
54
}
51
55
52
56
@ CodeReflection
53
57
public Tensor <Float > fconstants () {
54
- return OnnxOperators . Constant (new float []{-1f , 0 , 1 , Float .MIN_VALUE , Float .MAX_VALUE });
58
+ return Constant (new float []{-1f , 0 , 1 , Float .MIN_VALUE , Float .MAX_VALUE });
55
59
}
56
60
57
61
@ Test
58
62
public void testFconstants () throws Exception {
59
63
// tests the numbers are encoded correctly
60
64
var expected = Tensor .ofFlat (-1f , 0 , 1 , Float .MIN_VALUE , Float .MAX_VALUE );
61
65
assertEquals (expected , fconstants ());
62
- assertEquals (expected , OnnxRuntime . execute (() -> fconstants ()));
66
+ assertEquals (expected , execute (() -> fconstants ()));
63
67
}
64
68
65
69
@ CodeReflection
66
70
public Tensor <Long > lconstant () {
67
- return OnnxOperators . Constant (-1l );
71
+ return Constant (-1l );
68
72
}
69
73
70
74
@ Test
71
75
public void testLconstant () throws Exception {
72
76
// tests the numbers are encoded correctly
73
77
var expected = Tensor .ofScalar (-1l );
74
78
assertEquals (expected , lconstant ());
75
- assertEquals (expected , OnnxRuntime . execute (() -> lconstant ()));
79
+ assertEquals (expected , execute (() -> lconstant ()));
76
80
}
77
81
78
82
@ CodeReflection
79
83
public Tensor <Long > lconstants () {
80
- return OnnxOperators . Constant (new long []{-1 , 0 , 1 , Long .MIN_VALUE , Long .MAX_VALUE });
84
+ return Constant (new long []{-1 , 0 , 1 , Long .MIN_VALUE , Long .MAX_VALUE });
81
85
}
82
86
83
87
@ Test
84
88
public void testLconstants () throws Exception {
85
89
// tests the numbers are encoded correctly
86
90
var expected = Tensor .ofFlat (-1l , 0 , 1 , Long .MIN_VALUE , Long .MAX_VALUE );
87
91
assertEquals (expected , lconstants ());
88
- assertEquals (expected , OnnxRuntime . execute (() -> lconstants ()));
92
+ assertEquals (expected , execute (() -> lconstants ()));
89
93
}
90
94
91
95
@ CodeReflection
92
96
public Tensor <Long > reshapeAndShape (Tensor <Float > data , Tensor <Long > shape ) {
93
- return OnnxOperators . Shape (OnnxOperators . Reshape (data , shape , Optional . empty ()), Optional . empty (), Optional . empty ());
97
+ return Shape (Reshape (data , shape , empty ()), empty (), empty ());
94
98
}
95
99
96
100
@ Test
@@ -99,26 +103,26 @@ public void testReshapeAndShape() throws Exception {
99
103
var shape = Tensor .ofFlat (2l , 2 , 2 );
100
104
assertEquals (
101
105
reshapeAndShape (data , shape ),
102
- OnnxRuntime . execute (() -> reshapeAndShape (data , shape )));
106
+ execute (() -> reshapeAndShape (data , shape )));
103
107
}
104
108
105
109
@ CodeReflection
106
110
public Tensor <Long > indicesOfMaxPool (Tensor <Float > x ) {
107
111
// testing secondary output
108
- return OnnxOperators . MaxPool (x , Optional . empty (), Optional . empty (), Optional . empty (), Optional . empty (), Optional . empty (), Optional . empty (), new long []{2 }).Indices ();
112
+ return MaxPool (x , empty (), empty (), empty (), empty (), empty (), empty (), new long []{2 }).Indices ();
109
113
}
110
114
111
115
@ Test
112
116
public void testIndicesOfMaxPool () throws Exception {
113
117
var x = Tensor .ofShape (new long []{2 , 2 , 2 }, 1f , 2 , 3 , 4 , 5 , 6 , 7 , 8 );
114
118
assertEquals (
115
119
indicesOfMaxPool (x ),
116
- OnnxRuntime . execute (() -> indicesOfMaxPool (x )));
120
+ execute (() -> indicesOfMaxPool (x )));
117
121
}
118
122
119
123
@ CodeReflection
120
124
public Tensor <Float > concat (Tensor <Float > input1 , Tensor <Float > input2 , long axis ) {
121
- return OnnxOperators . Concat (List .of (input1 , input2 ), axis );
125
+ return Concat (List .of (input1 , input2 ), axis );
122
126
}
123
127
124
128
@ Test
@@ -127,12 +131,12 @@ public void testConcat() throws Exception {
127
131
var input2 = Tensor .ofFlat (4f , 5 );
128
132
assertEquals (
129
133
concat (input1 , input2 , 0 ),
130
- OnnxRuntime . execute (()-> concat (input1 , input2 , 0 )));
134
+ execute (()-> concat (input1 , input2 , 0 )));
131
135
}
132
136
133
137
@ CodeReflection
134
138
public Tensor <Float > split (Tensor <Float > input , Tensor <Long > split ) {
135
- return OnnxOperators . Split (input , Optional .of (split ), Optional . empty (), Optional . empty ()).get (0 );
139
+ return Split (input , Optional .of (split ), empty (), empty ()).get (0 );
136
140
}
137
141
138
142
@ Test
@@ -141,12 +145,12 @@ public void testSplit() throws Exception {
141
145
var split = Tensor .ofFlat (5l );
142
146
assertEquals (
143
147
split (input , split ),
144
- OnnxRuntime . execute (()-> split (input , split )));
148
+ execute (()-> split (input , split )));
145
149
}
146
150
147
151
@ CodeReflection
148
152
public Tensor <Float > ifConst (Tensor <Boolean > cond ) {
149
- return OnnxOperators . If (cond , () -> List .of (OnnxOperators . Constant (- 1f )), () -> List .of (OnnxOperators . Constant (1f ))).get (0 );
153
+ return If (cond , () -> List .of (Constant (1f )), () -> List .of (Constant (- 1f ))).get (0 );
150
154
}
151
155
152
156
@ Test
@@ -157,16 +161,16 @@ public void testIfConst() throws Exception {
157
161
var expTrue = Tensor .ofScalar (1f );
158
162
159
163
assertEquals (expFalse , ifConst (condFalse ));
160
- assertEquals (expFalse , OnnxRuntime . execute (() -> ifConst (condFalse )));
164
+ assertEquals (expFalse , execute (() -> ifConst (condFalse )));
161
165
162
166
assertEquals (expTrue , ifConst (condTrue ));
163
- assertEquals (expTrue , OnnxRuntime . execute (() -> ifConst (condTrue )));
167
+ assertEquals (expTrue , execute (() -> ifConst (condTrue )));
164
168
}
165
169
166
170
@ CodeReflection
167
171
public Tensor <Float > ifCapture (Tensor <Boolean > cond , Tensor <Float > trueValue ) {
168
- var falseValue = OnnxOperators . Constant (-1f );
169
- return OnnxOperators . If (cond , () -> List . of ( OnnxOperators . Identity (falseValue )) , () -> List . of ( OnnxOperators . Identity (trueValue ))). get ( 0 );
172
+ var falseValue = Constant (-1f );
173
+ return If (cond , () -> Identity (trueValue ) , () -> Identity (falseValue ) );
170
174
}
171
175
172
176
@ Test
@@ -177,24 +181,24 @@ public void testIfCapture() throws Exception {
177
181
var expTrue = Tensor .ofScalar (1f );
178
182
179
183
assertEquals (expFalse , ifCapture (condFalse , expTrue ));
180
- assertEquals (expFalse , OnnxRuntime . execute (() -> ifCapture (condFalse , expTrue )));
184
+ assertEquals (expFalse , execute (() -> ifCapture (condFalse , expTrue )));
181
185
182
186
assertEquals (expTrue , ifCapture (condTrue , expTrue ));
183
- assertEquals (expTrue , OnnxRuntime . execute (() -> ifCapture (condTrue , expTrue )));
187
+ assertEquals (expTrue , execute (() -> ifCapture (condTrue , expTrue )));
184
188
}
185
189
186
190
final Tensor <Float > initialized = Tensor .ofFlat (42f );
187
191
188
192
@ CodeReflection
189
193
public Tensor <Float > initialized () {
190
- return OnnxOperators . Identity (initialized );
194
+ return Identity (initialized );
191
195
}
192
196
193
197
@ Test
194
198
public void testInitialized () throws Exception {
195
199
196
200
assertEquals (initialized (),
197
- OnnxRuntime . execute (() -> initialized ()));
201
+ execute (() -> initialized ()));
198
202
}
199
203
200
204
final Tensor <Float > initialized2 = Tensor .ofFlat (33f );
@@ -203,13 +207,13 @@ public void testInitialized() throws Exception {
203
207
204
208
@ CodeReflection
205
209
public Tensor <Float > ifInitialized (Tensor <Boolean > cond1 , Tensor <Boolean > cond2 ) {
206
- return OnnxOperators . If (cond1 ,
207
- () -> OnnxOperators . If (cond2 ,
208
- () -> List .of (OnnxOperators . Identity (initialized4 )),
209
- () -> List .of (OnnxOperators . Identity (initialized3 ))),
210
- () -> OnnxOperators . If (cond2 ,
211
- () -> List .of (OnnxOperators . Identity (initialized2 )),
212
- () -> List .of (OnnxOperators . Identity (initialized )))).get (0 );
210
+ return If (cond1 ,
211
+ () -> If (cond2 ,
212
+ () -> List .of (Identity (initialized )),
213
+ () -> List .of (Identity (initialized2 ))),
214
+ () -> If (cond2 ,
215
+ () -> List .of (Identity (initialized3 )),
216
+ () -> List .of (Identity (initialized4 )))).get (0 );
213
217
}
214
218
215
219
@ Test
@@ -218,33 +222,30 @@ public void testIfInitialized() throws Exception {
218
222
var condTrue = Tensor .ofScalar (true );
219
223
220
224
assertEquals (initialized , ifInitialized (condTrue , condTrue ));
221
- assertEquals (initialized , OnnxRuntime . execute (() -> ifInitialized (condTrue , condTrue )));
225
+ assertEquals (initialized , execute (() -> ifInitialized (condTrue , condTrue )));
222
226
assertEquals (initialized2 , ifInitialized (condTrue , condFalse ));
223
- assertEquals (initialized2 , OnnxRuntime . execute (() -> ifInitialized (condTrue , condFalse )));
227
+ assertEquals (initialized2 , execute (() -> ifInitialized (condTrue , condFalse )));
224
228
assertEquals (initialized3 , ifInitialized (condFalse , condTrue ));
225
- assertEquals (initialized3 , OnnxRuntime . execute (() -> ifInitialized (condFalse , condTrue )));
229
+ assertEquals (initialized3 , execute (() -> ifInitialized (condFalse , condTrue )));
226
230
assertEquals (initialized4 , ifInitialized (condFalse , condFalse ));
227
- assertEquals (initialized4 , OnnxRuntime . execute (() -> ifInitialized (condFalse , condFalse )));
231
+ assertEquals (initialized4 , execute (() -> ifInitialized (condFalse , condFalse )));
228
232
229
233
}
230
234
235
+ static final Tensor <Boolean > TRUE = Tensor .ofScalar (true );
236
+
231
237
@ CodeReflection
232
- public Tensor <Float > forLoopAdd (Tensor <Float > value , Tensor <Long > max , Tensor <Boolean > condition ) {
233
- return OnnxOperators .Loop (max , condition , List .of (value ),
234
- l -> {
235
- var v = l .userValues ().get (0 );
236
- return new ExplicitOnnxOperators .LoopLocals <>(l .i (), l .cond (), List .of (OnnxOperators .Add (v , v )));
237
- }).get (0 );
238
+ public Tensor <Float > forLoopAdd (Tensor <Long > max , Tensor <Float > initialValue ) {
239
+ return Loop (max , TRUE , initialValue , (i , cond , v ) -> LoopReturn (cond , Add (v , v )));
238
240
}
239
241
240
242
@ Test
241
243
public void testForLoopAdd () throws Exception {
242
244
var expected = Tensor .ofFlat (0f , 8 , 16 , 24 );
243
245
var value = Tensor .ofFlat (0f , 1 , 2 , 3 );
244
246
var max = Tensor .ofScalar (3l );
245
- var cond = Tensor .ofScalar (true );
246
- assertEquals (expected , forLoopAdd (value , max , cond ));
247
- // assertEquals(expected, OnnxRuntime.execute(() -> forLoopAdd(value, max, cond)));
247
+ assertEquals (expected , forLoopAdd (max , value ));
248
+ assertEquals (expected , execute (() -> forLoopAdd (max , value )));
248
249
}
249
250
250
251
static void assertEquals (Tensor expected , Tensor actual ) {
0 commit comments