Skip to content

Commit 602d7b2

Browse files
committedMar 28, 2025
Onnx loop op work continuation + WalkTheMazeTest
1 parent d73e8ca commit 602d7b2

File tree

8 files changed

+484
-123
lines changed

8 files changed

+484
-123
lines changed
 

‎cr-examples/onnx/src/main/java/oracle/code/onnx/ExplicitOnnxOperators.java

+37-13
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@
2626
package oracle.code.onnx;
2727

2828
import java.lang.foreign.ValueLayout;
29-
import java.util.ArrayList;
30-
import java.util.List;
3129
import java.util.Optional;
32-
import java.util.function.Function;
33-
import java.util.function.Supplier;
30+
import jdk.incubator.code.Quotable;
3431

3532
class ExplicitOnnxOperators {
3633

@@ -80,18 +77,45 @@ public static Tensor<Integer> Constant(
8077

8178
// @@@ Constants for value - TENSOR and sparse_value - SPARSE_TENSOR
8279

83-
public static <T> List<Tensor<T>> If(Tensor<Boolean> cond, Supplier<List<Tensor<T>>> elseBody, Supplier<List<Tensor<T>>> thenBody) {
84-
return cond.data().get(ValueLayout.JAVA_BOOLEAN, 0) ? thenBody.get() : elseBody.get();
80+
81+
public interface IfBody<T> extends Quotable {
82+
T invoke();
83+
}
84+
85+
public static <T> T If(Tensor<Boolean> cond, IfBody<T> thenBody, IfBody<T> elseBody) {
86+
return booleanValue(cond) ? thenBody.invoke() : elseBody.invoke();
87+
}
88+
89+
public record LoopReturn<T>(Tensor<Boolean> cond, T output) {}
90+
public static <T> LoopReturn<T> LoopReturn(Tensor<Boolean> cond, T output) {
91+
return new LoopReturn<>(cond, output);
92+
}
93+
94+
public interface LoopBody<T> extends Quotable {
95+
LoopReturn<T> invoke(Tensor<Long> i, Tensor<Boolean> cond, T input);
8596
}
8697

87-
public record LoopLocals<T>(Tensor<Long> i, Tensor<Boolean> cond, List<Tensor<T>> userValues) {}
88-
public static <T> List<Tensor<T>> Loop(Tensor<Long> max, Tensor<Boolean> cond, List<Tensor<T>> v_initial, Function<LoopLocals<T>, LoopLocals<T>> body) {
98+
public static <T> T Loop(Tensor<Long> max, Tensor<Boolean> cond, T values, LoopBody<T> loopBody) {
8999
long m = max.data().get(ValueLayout.JAVA_LONG, 0);
90-
LoopLocals<T> ll = new LoopLocals<>(Tensor.ofScalar(0), cond, v_initial);
91-
while (ll.i.data().get(ValueLayout.JAVA_LONG, 0) < m && ll.cond.data().get(ValueLayout.JAVA_BOOLEAN, 0)) {
92-
ll = body.apply(ll);
93-
ll.i.data().set(ValueLayout.JAVA_LONG, 0, ll.i.data().get(ValueLayout.JAVA_LONG, 0) + 1); // i++
100+
for (var i = Tensor.ofScalar(0l); longValue(i) < m && booleanValue(cond); set(i, longValue(i) + 1)) {
101+
LoopReturn<T> ret = loopBody.invoke(i, cond, values);
102+
cond = ret.cond();
103+
values = ret.output();
94104
}
95-
return ll.userValues();
105+
return values;
106+
}
107+
108+
// @@@ move to Tensor API
109+
110+
private static boolean booleanValue(Tensor<Boolean> t) {
111+
return t.data().get(ValueLayout.JAVA_BOOLEAN, 0);
112+
}
113+
114+
private static long longValue(Tensor<Long> t) {
115+
return t.data().get(ValueLayout.JAVA_LONG, 0);
116+
}
117+
118+
private static void set(Tensor<Long> t, long value) {
119+
t.data().set(ValueLayout.JAVA_LONG, 0, value);
96120
}
97121
}

‎cr-examples/onnx/src/main/java/oracle/code/onnx/OnnxProtoBuilder.java

+22-18
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ String getName(Value v, int subIndex) {
332332
// entry block only
333333
static byte[] build(Block block, List<oracle.code.onnx.Tensor> initializers) {
334334
var indexer = new Indexer();
335-
var model = build(graph(indexer, block, initializers));
335+
var model = build(graph(indexer, block, initializers, 0));
336336
// OnnxProtoPrinter.printModel(model);
337337
return model;
338338
}
@@ -349,25 +349,34 @@ static byte[] build(GraphProto graph) {
349349
.buf.toByteArray();
350350
}
351351

352-
static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers) {
352+
static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tensor> initializers, int scalarArgs) {
353353
var params = block.parameters();
354354
params.forEach(indexer::getName);
355355
int firstInitializer = params.size() - initializers.size();
356356
var args = params.subList(params.isEmpty() || params.getFirst().type() instanceof OnnxType.TensorType ? 0 : 1, firstInitializer);
357357
return graph(
358358
IntStream.range(0, initializers.size()).mapToObj(i -> tensorProto(indexer.getName(params.get(i + firstInitializer)), initializers.get(i))).toList(),
359-
args.stream().map(v ->
360-
tensorInfo(indexer.getName(v), ((OnnxType.TensorType)v.type()).eType().id())).toList(),
359+
IntStream.range(0, args.size()).mapToObj(i ->
360+
tensorInfo(indexer.getName(args.get(i)), ((OnnxType.TensorType)args.get(i).type()).eType().id(), i < scalarArgs)).toList(),
361361
block.ops().stream().<NodeProto>mapMulti((op, opNodes) -> {
362362
switch (op) {
363363
case OnnxOps.If ifOp ->
364364
opNodes.accept(node(
365365
ifOp.opName(),
366366
List.of(indexer.getName(ifOp.operands().getFirst())),
367367
List.of(indexer.getName(ifOp.result())),
368-
java.util.Map.of( // @@@ wrong mapping of captured inputs
369-
"else_branch", graph(indexer, ifOp.elseBranch().entryBlock(), List.of()),
370-
"then_branch", graph(indexer, ifOp.thenBranch().entryBlock(), List.of()))));
368+
java.util.Map.of(
369+
"then_branch", graph(indexer, ifOp.thenBranch().entryBlock(), List.of(), 0),
370+
"else_branch", graph(indexer, ifOp.elseBranch().entryBlock(), List.of(), 0))));
371+
case OnnxOps.LoopReturn _ -> {} // skip
372+
case OnnxOps.Loop loopOp -> {
373+
opNodes.accept(node(
374+
loopOp.opName(),
375+
loopOp.operands().stream().map(indexer::getName).toList(),
376+
List.of(indexer.getName(loopOp.result())),
377+
java.util.Map.of(
378+
"body", graph(indexer, loopOp.loopBody().entryBlock(), List.of(), 2))));
379+
}
371380
case OnnxOp onnxOp ->
372381
opNodes.accept(node(
373382
onnxOp.opName(),
@@ -394,7 +403,7 @@ static GraphProto graph(Indexer indexer, Block block, List<oracle.code.onnx.Tens
394403
}
395404
}
396405
}).toList(),
397-
List.of(indexer.getName(block.terminatingOp().operands().getFirst())));
406+
block.terminatingOp().operands().stream().map(indexer::getName).toList());
398407
}
399408

400409
static GraphProto graph(List<TensorProto> initializers, List<ValueInfoProto> inputs, List<NodeProto> ops, List<String> outputNames) {
@@ -414,20 +423,15 @@ static NodeProto node(String opName, List<String> inputNames, List<String> outpu
414423
}
415424

416425
static ValueInfoProto tensorInfo(String name, int tensorElementType) {
417-
return new ValueInfoProto()
418-
.name(name)
419-
.type(new TypeProto()
420-
.tensor_type(new Tensor()
421-
.elem_type(tensorElementType)));
426+
return tensorInfo(name, tensorElementType, false);
422427
}
423428

424-
static ValueInfoProto scalarInfo(String name, int tensorElementType) {
429+
static ValueInfoProto tensorInfo(String name, int tensorElementType, boolean addScalarShape) {
430+
var t = new Tensor().elem_type(tensorElementType);
431+
if (addScalarShape) t.shape(new TensorShapeProto());
425432
return new ValueInfoProto()
426433
.name(name)
427-
.type(new TypeProto()
428-
.tensor_type(new Tensor()
429-
.elem_type(tensorElementType)
430-
.shape(new TensorShapeProto())));
434+
.type(new TypeProto().tensor_type(t));
431435
}
432436

433437
static TensorProto tensorProto(String name, oracle.code.onnx.Tensor tensor) {

‎cr-examples/onnx/src/main/java/oracle/code/onnx/compiler/OnnxPartialEvaluator.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ Object interpretOp(MethodHandles.Lookup l, OpContext oc, Op o) {
303303
OnnxOp.OnnxSchema schema = schemaFromOnnxOpClass(opClass);
304304

305305
List<OnnxOp.OnnxParameter> inputs = schema.inputs();
306-
assert o.operands().subList(0, inputs.size()).stream().noneMatch(oc::isValueDefined);
306+
// assert o.operands().subList(0, inputs.size()).stream().noneMatch(oc::isValueDefined);
307307
List<OnnxOp.OnnxAttribute> attributes = schema.attributes();
308308

309309
if (opClass == OnnxOps.Constant.class && o.operands().size() == 1) {

‎cr-examples/onnx/src/main/java/oracle/code/onnx/compiler/OnnxTransformer.java

+60-23
Original file line numberDiff line numberDiff line change
@@ -45,31 +45,38 @@ public static OnnxTransformer ofLambda(MethodHandles.Lookup lookup, CoreOp.Lambd
4545
return new OnnxTransformer(lookup, flatLambdaFunc);
4646
}
4747

48-
public OnnxTransformer(MethodHandles.Lookup lookup, CoreOp.FuncOp func) {
49-
l = lookup;
50-
51-
var inlinedFunc = func.transform((bb, op) -> {
48+
final CoreOp.FuncOp inline(CoreOp.FuncOp func) {
49+
return func.transform((bb, op) -> {
5250
var cc = bb.context();
5351
switch (op) {
5452
case CoreOp.InvokeOp io when resolve(io) instanceof CoreOp.FuncOp inline ->
55-
bb.inline(inline, cc.getValues(io.operands()), (_, v) -> cc.mapValue(io.result(), v));
53+
bb.inline(inline(inline), cc.getValues(io.operands()), (_, v) -> cc.mapValue(io.result(), v));
5654
default ->
5755
bb.apply(op);
5856
}
5957
return bb;
6058
});
59+
}
60+
61+
public OnnxTransformer(MethodHandles.Lookup lookup, CoreOp.FuncOp func) {
62+
l = lookup;
63+
64+
var inlinedFunc = inline(func);
6165

6266
inits = new ArrayList<>();
67+
var initMap = new HashMap<FieldRef, Block.Parameter>();
6368
var top = new Block.Builder[1];
6469
// turning field loads into additiona arguments
6570
inputFunc = inlinedFunc.transform((bb, op) -> {
6671
if (top[0] == null) top[0] = bb;
6772
var cc = bb.context();
6873
switch (op) {
6974
case CoreOp.FieldAccessOp.FieldLoadOp flo when op.resultType() instanceof ClassType ct && ct.rawType().equals(TENSOR_CLASS) -> {
70-
inits.add(flo.fieldDescriptor());
7175
// initializers turn into top block parameters
72-
cc.mapValue(op.result(), top[0].parameter(op.resultType()));
76+
cc.mapValue(op.result(), initMap.computeIfAbsent(flo.fieldDescriptor(), fd -> {
77+
inits.add(fd);
78+
return top[0].parameter(op.resultType());
79+
}));
7380
}
7481
default -> bb.apply(op);
7582
}
@@ -90,7 +97,7 @@ CoreOp.FuncOp resolve(CoreOp.InvokeOp io) {
9097
public List<Tensor> initializers(Object receiver) {
9198
return inits.stream().map(i -> {
9299
try {
93-
return (Tensor)i.resolveToHandle(l).get(receiver);
100+
return (Tensor)(i.resolveToMember(l).accessFlags().contains(AccessFlag.STATIC) ? i.resolveToHandle(l).get() : i.resolveToHandle(l).get(receiver));
94101
} catch (ReflectiveOperationException ex) {
95102
throw new RuntimeException(ex);
96103
}
@@ -197,7 +204,8 @@ OpTransformer bodyTransformer(OnnxPartialEvaluator pe) {
197204
default -> throw new UnsupportedOperationException();
198205
}
199206
} else {
200-
throw new UnsupportedOperationException();
207+
// otherwise pass through a single value
208+
opArgs.add(bb.context().getValue(v));
201209
}
202210
}
203211
}
@@ -207,12 +215,12 @@ OpTransformer bodyTransformer(OnnxPartialEvaluator pe) {
207215
// Explicit transformation of nested bodies
208216
for (int i = 1; i < 3; i++) {
209217
var lambda = (CoreOp.LambdaOp)(((Op.Result)op.operands().get(i)).op());
210-
opArgs.add(lambda.body().transform(bb.context(), bodyTransformer(pe)));
218+
opArgs.add(transformBodyTranslateTypes(lambda.body(), bb.context(), bodyTransformer(pe)));
211219
}
212220
} else if (opClass == ExplicitOnnxOps.Loop.class) {
213221
// Explicit transformation of nested body
214222
var lambda = (CoreOp.LambdaOp)(((Op.Result)op.operands().get(3)).op());
215-
opArgs.add(lambda.body().transform(bb.context(), bodyTransformer(pe)));
223+
opArgs.add(transformBodyTranslateTypes(lambda.body(), bb.context(), bodyTransformer(pe)));
216224
}
217225
OnnxOp onnxOp;
218226
try {
@@ -241,13 +249,37 @@ OpTransformer bodyTransformer(OnnxPartialEvaluator pe) {
241249
// Skip nested lambdas
242250
case CoreOp.LambdaOp _ -> {
243251
}
252+
case Op.Terminating _ -> {
253+
try {
254+
bb.op(op); // @@@ how to test the terminating op has been already inserted?
255+
} catch (IllegalStateException _) {}
256+
}
244257
// Copy remaining operations, which may be removed later transformations
245258
default -> bb.op(op);
246259
}
247260
return bb;
248261
};
249262
}
250263

264+
// @@@ Ugly copy of Body::transform content to translate types
265+
static Body.Builder transformBodyTranslateTypes(Body body, CopyContext cc, OpTransformer ot) {
266+
// return body.transform(cc, ot);
267+
268+
Body ancestorBody = body.parentOp().parentBlock() instanceof Block parentBlock ? parentBlock.parentBody() : null;
269+
270+
Block.Builder ancestorBlockBuilder = ancestorBody != null
271+
? cc.getBlock(ancestorBody.entryBlock()) : null;
272+
Body.Builder ancestorBodyBuilder = ancestorBlockBuilder != null
273+
? ancestorBlockBuilder.parentBody() : null;
274+
275+
Body.Builder bb = Body.Builder.of(ancestorBodyBuilder, FunctionType.functionType(type(body.yieldType())), cc, ot); // translate types
276+
for (Block.Parameter p : body.entryBlock().parameters()) {
277+
bb.entryBlock().parameter(type(p.type())); // translate types
278+
}
279+
bb.entryBlock().transformBody(body, bb.entryBlock().parameters(), cc, ot);
280+
return bb;
281+
}
282+
251283
@SuppressWarnings({"rawtypes", "unchecked"})
252284
static Class<? extends OnnxOp> onnxOpClassFromName(String operatorName) {
253285
try {
@@ -330,22 +362,27 @@ static Integer recordComponentAccessToTupleIndex(MethodHandles.Lookup l, MethodR
330362
}
331363

332364
static final TypeElement TENSOR_RAW_CLASS = JavaType.type(Tensor.class);
365+
static final TypeElement LOOP_RETURN_RAW_CLASS = JavaType.type(ExplicitOnnxOps.LoopReturn.class);
333366

334367
// @@@ Map of Java tensor types to ONNX tensor types
335368
// @@@ Shape??
336369
static TypeElement type(TypeElement type) {
337-
if (type instanceof ClassType ct && ct.rawType().equals(TENSOR_RAW_CLASS)) {
338-
JavaType elementType = ct.typeArguments().getFirst();
339-
if (elementType.equals(JavaType.J_L_INTEGER)) {
340-
return OnnxType.TENSOR_INT32;
341-
} else if (elementType.equals(JavaType.J_L_FLOAT)) {
342-
return OnnxType.TENSOR_FLOAT32;
343-
} else if (elementType.equals(JavaType.J_L_LONG)) {
344-
return OnnxType.TENSOR_INT64;
345-
} else if (elementType.equals(JavaType.J_L_BYTE)) {
346-
return OnnxType.TENSOR_UINT8;
347-
} else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
348-
return OnnxType.TENSOR_BOOL;
370+
if (type instanceof ClassType ct) {
371+
if (ct.rawType().equals(TENSOR_RAW_CLASS)) {
372+
JavaType elementType = ct.typeArguments().getFirst();
373+
if (elementType.equals(JavaType.J_L_INTEGER)) {
374+
return OnnxType.TENSOR_INT32;
375+
} else if (elementType.equals(JavaType.J_L_FLOAT)) {
376+
return OnnxType.TENSOR_FLOAT32;
377+
} else if (elementType.equals(JavaType.J_L_LONG)) {
378+
return OnnxType.TENSOR_INT64;
379+
} else if (elementType.equals(JavaType.J_L_BYTE)) {
380+
return OnnxType.TENSOR_UINT8;
381+
} else if (elementType.equals(JavaType.J_L_BOOLEAN)) {
382+
return OnnxType.TENSOR_BOOL;
383+
}
384+
} else if (ct.rawType().equals(LOOP_RETURN_RAW_CLASS)) {
385+
return JavaType.VOID;
349386
}
350387
}
351388
return type;

‎cr-examples/onnx/src/main/java/oracle/code/onnx/ir/ExplicitOnnxOps.java

+138-18
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public sealed class ExplicitOnnxOps permits OnnxOps {
3737
public static final class If extends OnnxOp implements Nested {
3838
public static final String NAME = "If";
3939

40-
final Body elseBody, thenBody;
40+
final Body thenBody, elseBody;
4141

4242
// @@@ make or fake elseBody as "else_branch" attribute and thenBody as "then_branch" attribute
4343
public enum Attribute implements OnnxOp.OnnxAttribute.None { }
@@ -117,32 +117,32 @@ public OnnxOp.OnnxParameter.Quantifier quantifier() {
117117
public If(ExternalizableOp.ExternalizedOp def) {
118118
super(SCHEMA, def);
119119

120-
this.elseBody = def.bodyDefinitions().get(0).build(this);
121-
this.thenBody = def.bodyDefinitions().get(1).build(this);
120+
this.thenBody = def.bodyDefinitions().get(0).build(this);
121+
this.elseBody = def.bodyDefinitions().get(1).build(this);
122122
}
123123

124124
If(If that, CopyContext cc, OpTransformer ot) {
125125
super(that, cc);
126126

127-
this.elseBody = that.elseBody.transform(cc, ot).build(this);
128127
this.thenBody = that.thenBody.transform(cc, ot).build(this);
128+
this.elseBody = that.elseBody.transform(cc, ot).build(this);
129129
}
130130

131131
@Override
132132
public If transform(CopyContext cc, OpTransformer ot) {
133133
return new If(this, cc, ot);
134134
}
135135

136-
If(TypeElement resultType, Value cond, Body.Builder elseBranch, Body.Builder thenBranch) {
136+
If(TypeElement resultType, Value cond, Body.Builder thenBranch, Body.Builder elseBranch) {
137137
super(SCHEMA, resultType, Set.of(), List.of(cond), List.of());
138138

139-
this.elseBody = elseBranch.build(this);
140139
this.thenBody = thenBranch.build(this);
140+
this.elseBody = elseBranch.build(this);
141141
}
142142

143143
@Override
144144
public List<Body> bodies() {
145-
return List.of(elseBody, thenBody);
145+
return List.of(thenBody, elseBody);
146146
}
147147

148148
@Override
@@ -168,8 +168,130 @@ public Body thenBranch() {
168168
}
169169
}
170170

171-
public static If If(TypeElement resultType, Value cond, Body.Builder elseBody, Body.Builder thenBody) {
172-
return new If(resultType, cond, elseBody, thenBody);
171+
public static If If(TypeElement resultType, Value cond, Body.Builder thenBody, Body.Builder elseBody) {
172+
return new If(resultType, cond, thenBody, elseBody);
173+
}
174+
175+
@OpFactory.OpDeclaration(LoopReturn.NAME)
176+
public static final class LoopReturn extends OnnxOp implements Op.Terminating {
177+
public static final String NAME = "LoopReturn";
178+
179+
// @@@ make or fake body
180+
public enum Attribute implements OnnxOp.OnnxAttribute.None { }
181+
182+
public enum TypeConstraint implements OnnxOp.OnnxTypeConstraint {
183+
V(new OnnxType.TypeVariable("V", List.of(OnnxType.tensor(OnnxType.uint8()), OnnxType.tensor(OnnxType.uint16()), OnnxType.tensor(OnnxType.uint32()), OnnxType.tensor(OnnxType.uint64()), OnnxType.tensor(OnnxType.int8()), OnnxType.tensor(OnnxType.int16()), OnnxType.tensor(OnnxType.int32()), OnnxType.tensor(OnnxType.int64()), OnnxType.tensor(OnnxType.bfloat16()), OnnxType.tensor(OnnxType.float16()), OnnxType.tensor(OnnxType.float32()), OnnxType.tensor(OnnxType.float64()), OnnxType.tensor(OnnxType.bool())))),
184+
B(new OnnxType.TypeVariable("B", List.of(OnnxType.tensor(OnnxType.bool())))),
185+
;
186+
187+
final OnnxType.TypeVariable typeVariable;
188+
189+
TypeConstraint(OnnxType.TypeVariable typeVariable) {
190+
assert typeVariable.name().equals(name());
191+
this.typeVariable = typeVariable;
192+
}
193+
194+
@Override
195+
public OnnxType.TypeVariable typeVariable() {
196+
return typeVariable;
197+
}
198+
}
199+
200+
public enum InputParameter implements OnnxOp.OnnxParameter {
201+
// @@@ Onnx spec declares the input parameters as optional, however it is causing problems
202+
cond(TypeConstraint.B.typeVariable(), OnnxOp.OnnxParameter.Quantifier.REQUIRED),
203+
values(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
204+
;
205+
206+
final OnnxType type;
207+
final OnnxOp.OnnxParameter.Quantifier quantifier;
208+
209+
InputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
210+
this.type = type;
211+
this.quantifier = quantifier;
212+
}
213+
214+
@Override
215+
public OnnxType type() {
216+
return type;
217+
}
218+
219+
@Override
220+
public OnnxOp.OnnxParameter.Quantifier quantifier() {
221+
return quantifier;
222+
}
223+
}
224+
225+
public enum OutputParameter implements OnnxOp.OnnxParameter {
226+
outputs(TypeConstraint.V.typeVariable(), OnnxOp.OnnxParameter.Quantifier.VARIADIC),
227+
;
228+
229+
final OnnxType type;
230+
final OnnxOp.OnnxParameter.Quantifier quantifier;
231+
232+
OutputParameter(OnnxType type, OnnxOp.OnnxParameter.Quantifier quantifier) {
233+
this.type = type;
234+
this.quantifier = quantifier;
235+
}
236+
237+
@Override
238+
public OnnxType type() {
239+
return type;
240+
}
241+
242+
@Override
243+
public OnnxOp.OnnxParameter.Quantifier quantifier() {
244+
return quantifier;
245+
}
246+
}
247+
248+
public static final OnnxOp.OnnxSchema SCHEMA = new OnnxSchemaRecord(
249+
NAME,
250+
List.of(Attribute.values()),
251+
List.of(TypeConstraint.values()),
252+
List.of(InputParameter.values()),
253+
List.of(OutputParameter.values())
254+
);
255+
256+
public LoopReturn(ExternalizableOp.ExternalizedOp def) {
257+
super(SCHEMA, def);
258+
}
259+
260+
LoopReturn(ExplicitOnnxOps.LoopReturn that, CopyContext cc, OpTransformer ot) {
261+
super(that, cc);
262+
}
263+
264+
@Override
265+
public ExplicitOnnxOps.LoopReturn transform(CopyContext cc, OpTransformer ot) {
266+
return new ExplicitOnnxOps.LoopReturn(this, cc, ot);
267+
}
268+
269+
LoopReturn(TypeElement resultType, Value cond, Value v_initial) {
270+
super(SCHEMA, resultType, Set.of(), List.of(cond, v_initial), List.of());
271+
}
272+
273+
@Override
274+
public SequencedSet<OnnxOp.OnnxParameter> onnxOutputs() {
275+
return onnxOutputs(SCHEMA);
276+
}
277+
278+
@Override
279+
public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
280+
return onnxInputs(SCHEMA, List.of(cond()));
281+
}
282+
283+
284+
public Value cond() {
285+
return operands().get(0);
286+
}
287+
288+
public List<Value> values() {
289+
return operands().subList(1, operands().size());
290+
}
291+
}
292+
293+
public static LoopReturn LoopReturn(TypeElement resultType, Value cond, Value values) {
294+
return new LoopReturn(resultType, cond, values);
173295
}
174296

175297
@OpFactory.OpDeclaration(Loop.NAME)
@@ -274,7 +396,7 @@ public ExplicitOnnxOps.Loop transform(CopyContext cc, OpTransformer ot) {
274396
return new ExplicitOnnxOps.Loop(this, cc, ot);
275397
}
276398

277-
Loop(TypeElement resultType, Value m, Value cond, List<Value> v_initial, Body.Builder body) {
399+
Loop(TypeElement resultType, Value m, Value cond, Value v_initial, Body.Builder body) {
278400
super(SCHEMA, resultType, Set.of(), List.of(m, cond, v_initial), List.of());
279401

280402
this.body = body.build(this);
@@ -295,18 +417,16 @@ public SequencedMap<OnnxOp.OnnxParameter, Object> onnxInputs() {
295417
return onnxInputs(SCHEMA, List.of(cond()));
296418
}
297419

298-
public Optional<Value> m() {
299-
int i = optionalInputArguments.indexOf(InputParameter.M);
300-
return i != -1 ? Optional.of(operands().get(1 + i)) : Optional.empty();
420+
public Value max() {
421+
return operands().get(0);
301422
}
302423

303-
public Optional<Value> cond() {
304-
int i = optionalInputArguments.indexOf(InputParameter.cond);
305-
return i != -1 ? Optional.of(operands().get(1 + i)) : Optional.empty();
424+
public Value cond() {
425+
return operands().get(1);
306426
}
307427

308428
public List<Value> v_initial() {
309-
return operands().subList(1, operands().size());
429+
return operands().subList(2, operands().size());
310430
}
311431

312432
@Override
@@ -315,7 +435,7 @@ public Body loopBody() {
315435
}
316436
}
317437

318-
public static Loop Loop(TypeElement resultType, Value m, Value cond, List<Value> v_initial, Body.Builder body) {
438+
public static Loop Loop(TypeElement resultType, Value m, Value cond, Value v_initial, Body.Builder body) {
319439
return new Loop(resultType, m, cond, v_initial, body);
320440
}
321441
}

‎cr-examples/onnx/src/test/java/oracle/code/onnx/RuntimeTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public void testLoop() throws Exception {
9595
List.of(node("Loop", List.of("max", "cond", "a"), List.of("a_out"), Map.of(
9696
"body", graph(
9797
List.of(),
98-
List.of(scalarInfo("i", INT64.id), scalarInfo("cond_in", BOOL.id), tensorInfo("a_in", INT64.id)),
98+
List.of(tensorInfo("i", INT64.id, true), tensorInfo("cond_in", BOOL.id, true), tensorInfo("a_in", INT64.id)),
9999
List.of(node("Identity", List.of("cond_in"), List.of("cond_out"), Map.of()),
100100
node("Add", List.of("a_in", "a_in"), List.of("a_out"), Map.of())),
101101
List.of("cond_out", "a_out"))))),

‎cr-examples/onnx/src/test/java/oracle/code/onnx/SimpleTest.java

+50-49
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,28 @@
77
import org.junit.jupiter.api.Assertions;
88
import org.junit.jupiter.api.Test;
99

10+
import static java.util.Optional.empty;
11+
import static oracle.code.onnx.OnnxOperators.*;
12+
import static oracle.code.onnx.OnnxRuntime.execute;
13+
1014
public class SimpleTest {
1115

1216
@CodeReflection
1317
public Tensor<Float> add(Tensor<Float> a, Tensor<Float> b) {
14-
return OnnxOperators.Add(a, b);
18+
return Add(a, b);
1519
}
1620

1721
@Test
1822
public void testAdd() throws Exception {
1923
var a = Tensor.ofFlat(1f, 2, 3);
2024
assertEquals(
2125
add(a, a),
22-
OnnxRuntime.execute(() -> add(a, a)));
26+
execute(() -> add(a, a)));
2327
}
2428

2529
@CodeReflection
2630
public Tensor<Float> sub(Tensor<Float> a, Tensor<Float> b) {
27-
return OnnxOperators.Sub(a, b);
31+
return Sub(a, b);
2832
}
2933

3034
@Test
@@ -33,64 +37,64 @@ public void testSub() throws Exception {
3337
var a = Tensor.ofFlat(1f, 2, 3);
3438
assertEquals(
3539
sub(a, b),
36-
OnnxRuntime.execute(() -> sub(a, b)));
40+
execute(() -> sub(a, b)));
3741
}
3842

3943
@CodeReflection
4044
public Tensor<Float> fconstant() {
41-
return OnnxOperators.Constant(-1f);
45+
return Constant(-1f);
4246
}
4347

4448
@Test
4549
public void testFconstant() throws Exception {
4650
// tests the numbers are encoded correctly
4751
var expected = Tensor.ofScalar(-1f);
4852
assertEquals(expected, fconstant());
49-
assertEquals(expected, OnnxRuntime.execute(() -> fconstant()));
53+
assertEquals(expected, execute(() -> fconstant()));
5054
}
5155

5256
@CodeReflection
5357
public Tensor<Float> fconstants() {
54-
return OnnxOperators.Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
58+
return Constant(new float[]{-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE});
5559
}
5660

5761
@Test
5862
public void testFconstants() throws Exception {
5963
// tests the numbers are encoded correctly
6064
var expected = Tensor.ofFlat(-1f, 0, 1, Float.MIN_VALUE, Float.MAX_VALUE);
6165
assertEquals(expected, fconstants());
62-
assertEquals(expected, OnnxRuntime.execute(() -> fconstants()));
66+
assertEquals(expected, execute(() -> fconstants()));
6367
}
6468

6569
@CodeReflection
6670
public Tensor<Long> lconstant() {
67-
return OnnxOperators.Constant(-1l);
71+
return Constant(-1l);
6872
}
6973

7074
@Test
7175
public void testLconstant() throws Exception {
7276
// tests the numbers are encoded correctly
7377
var expected = Tensor.ofScalar(-1l);
7478
assertEquals(expected, lconstant());
75-
assertEquals(expected, OnnxRuntime.execute(() -> lconstant()));
79+
assertEquals(expected, execute(() -> lconstant()));
7680
}
7781

7882
@CodeReflection
7983
public Tensor<Long> lconstants() {
80-
return OnnxOperators.Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
84+
return Constant(new long[]{-1, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE});
8185
}
8286

8387
@Test
8488
public void testLconstants() throws Exception {
8589
// tests the numbers are encoded correctly
8690
var expected = Tensor.ofFlat(-1l, 0, 1, Long.MIN_VALUE, Long.MAX_VALUE);
8791
assertEquals(expected, lconstants());
88-
assertEquals(expected, OnnxRuntime.execute(() -> lconstants()));
92+
assertEquals(expected, execute(() -> lconstants()));
8993
}
9094

9195
@CodeReflection
9296
public Tensor<Long> reshapeAndShape(Tensor<Float> data, Tensor<Long> shape) {
93-
return OnnxOperators.Shape(OnnxOperators.Reshape(data, shape, Optional.empty()), Optional.empty(), Optional.empty());
97+
return Shape(Reshape(data, shape, empty()), empty(), empty());
9498
}
9599

96100
@Test
@@ -99,26 +103,26 @@ public void testReshapeAndShape() throws Exception {
99103
var shape = Tensor.ofFlat(2l, 2, 2);
100104
assertEquals(
101105
reshapeAndShape(data, shape),
102-
OnnxRuntime.execute(() -> reshapeAndShape(data, shape)));
106+
execute(() -> reshapeAndShape(data, shape)));
103107
}
104108

105109
@CodeReflection
106110
public Tensor<Long> indicesOfMaxPool(Tensor<Float> x) {
107111
// testing secondary output
108-
return OnnxOperators.MaxPool(x, Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), new long[]{2}).Indices();
112+
return MaxPool(x, empty(), empty(), empty(), empty(), empty(), empty(), new long[]{2}).Indices();
109113
}
110114

111115
@Test
112116
public void testIndicesOfMaxPool() throws Exception {
113117
var x = Tensor.ofShape(new long[]{2, 2, 2}, 1f, 2, 3, 4, 5, 6, 7, 8);
114118
assertEquals(
115119
indicesOfMaxPool(x),
116-
OnnxRuntime.execute(() -> indicesOfMaxPool(x)));
120+
execute(() -> indicesOfMaxPool(x)));
117121
}
118122

119123
@CodeReflection
120124
public Tensor<Float> concat(Tensor<Float> input1, Tensor<Float> input2, long axis) {
121-
return OnnxOperators.Concat(List.of(input1, input2), axis);
125+
return Concat(List.of(input1, input2), axis);
122126
}
123127

124128
@Test
@@ -127,12 +131,12 @@ public void testConcat() throws Exception {
127131
var input2 = Tensor.ofFlat(4f, 5);
128132
assertEquals(
129133
concat(input1, input2, 0),
130-
OnnxRuntime.execute(()-> concat(input1, input2, 0)));
134+
execute(()-> concat(input1, input2, 0)));
131135
}
132136

133137
@CodeReflection
134138
public Tensor<Float> split(Tensor<Float> input, Tensor<Long> split) {
135-
return OnnxOperators.Split(input, Optional.of(split), Optional.empty(), Optional.empty()).get(0);
139+
return Split(input, Optional.of(split), empty(), empty()).get(0);
136140
}
137141

138142
@Test
@@ -141,12 +145,12 @@ public void testSplit() throws Exception {
141145
var split = Tensor.ofFlat(5l);
142146
assertEquals(
143147
split(input, split),
144-
OnnxRuntime.execute(()-> split(input, split)));
148+
execute(()-> split(input, split)));
145149
}
146150

147151
@CodeReflection
148152
public Tensor<Float> ifConst(Tensor<Boolean> cond) {
149-
return OnnxOperators.If(cond, () -> List.of(OnnxOperators.Constant(-1f)), () -> List.of(OnnxOperators.Constant(1f))).get(0);
153+
return If(cond, () -> List.of(Constant(1f)), () -> List.of(Constant(-1f))).get(0);
150154
}
151155

152156
@Test
@@ -157,16 +161,16 @@ public void testIfConst() throws Exception {
157161
var expTrue = Tensor.ofScalar(1f);
158162

159163
assertEquals(expFalse, ifConst(condFalse));
160-
assertEquals(expFalse, OnnxRuntime.execute(() -> ifConst(condFalse)));
164+
assertEquals(expFalse, execute(() -> ifConst(condFalse)));
161165

162166
assertEquals(expTrue, ifConst(condTrue));
163-
assertEquals(expTrue, OnnxRuntime.execute(() -> ifConst(condTrue)));
167+
assertEquals(expTrue, execute(() -> ifConst(condTrue)));
164168
}
165169

166170
@CodeReflection
167171
public Tensor<Float> ifCapture(Tensor<Boolean> cond, Tensor<Float> trueValue) {
168-
var falseValue = OnnxOperators.Constant(-1f);
169-
return OnnxOperators.If(cond, () -> List.of(OnnxOperators.Identity(falseValue)), () -> List.of(OnnxOperators.Identity(trueValue))).get(0);
172+
var falseValue = Constant(-1f);
173+
return If(cond, () -> Identity(trueValue), () -> Identity(falseValue));
170174
}
171175

172176
@Test
@@ -177,24 +181,24 @@ public void testIfCapture() throws Exception {
177181
var expTrue = Tensor.ofScalar(1f);
178182

179183
assertEquals(expFalse, ifCapture(condFalse, expTrue));
180-
assertEquals(expFalse, OnnxRuntime.execute(() -> ifCapture(condFalse, expTrue)));
184+
assertEquals(expFalse, execute(() -> ifCapture(condFalse, expTrue)));
181185

182186
assertEquals(expTrue, ifCapture(condTrue, expTrue));
183-
assertEquals(expTrue, OnnxRuntime.execute(() -> ifCapture(condTrue, expTrue)));
187+
assertEquals(expTrue, execute(() -> ifCapture(condTrue, expTrue)));
184188
}
185189

186190
final Tensor<Float> initialized = Tensor.ofFlat(42f);
187191

188192
@CodeReflection
189193
public Tensor<Float> initialized() {
190-
return OnnxOperators.Identity(initialized);
194+
return Identity(initialized);
191195
}
192196

193197
@Test
194198
public void testInitialized() throws Exception {
195199

196200
assertEquals(initialized(),
197-
OnnxRuntime.execute(() -> initialized()));
201+
execute(() -> initialized()));
198202
}
199203

200204
final Tensor<Float> initialized2 = Tensor.ofFlat(33f);
@@ -203,13 +207,13 @@ public void testInitialized() throws Exception {
203207

204208
@CodeReflection
205209
public Tensor<Float> ifInitialized(Tensor<Boolean> cond1, Tensor<Boolean> cond2) {
206-
return OnnxOperators.If(cond1,
207-
() -> OnnxOperators.If(cond2,
208-
() -> List.of(OnnxOperators.Identity(initialized4)),
209-
() -> List.of(OnnxOperators.Identity(initialized3))),
210-
() -> OnnxOperators.If(cond2,
211-
() -> List.of(OnnxOperators.Identity(initialized2)),
212-
() -> List.of(OnnxOperators.Identity(initialized)))).get(0);
210+
return If(cond1,
211+
() -> If(cond2,
212+
() -> List.of(Identity(initialized)),
213+
() -> List.of(Identity(initialized2))),
214+
() -> If(cond2,
215+
() -> List.of(Identity(initialized3)),
216+
() -> List.of(Identity(initialized4)))).get(0);
213217
}
214218

215219
@Test
@@ -218,33 +222,30 @@ public void testIfInitialized() throws Exception {
218222
var condTrue = Tensor.ofScalar(true);
219223

220224
assertEquals(initialized, ifInitialized(condTrue, condTrue));
221-
assertEquals(initialized, OnnxRuntime.execute(() -> ifInitialized(condTrue, condTrue)));
225+
assertEquals(initialized, execute(() -> ifInitialized(condTrue, condTrue)));
222226
assertEquals(initialized2, ifInitialized(condTrue, condFalse));
223-
assertEquals(initialized2, OnnxRuntime.execute(() -> ifInitialized(condTrue, condFalse)));
227+
assertEquals(initialized2, execute(() -> ifInitialized(condTrue, condFalse)));
224228
assertEquals(initialized3, ifInitialized(condFalse, condTrue));
225-
assertEquals(initialized3, OnnxRuntime.execute(() -> ifInitialized(condFalse, condTrue)));
229+
assertEquals(initialized3, execute(() -> ifInitialized(condFalse, condTrue)));
226230
assertEquals(initialized4, ifInitialized(condFalse, condFalse));
227-
assertEquals(initialized4, OnnxRuntime.execute(() -> ifInitialized(condFalse, condFalse)));
231+
assertEquals(initialized4, execute(() -> ifInitialized(condFalse, condFalse)));
228232

229233
}
230234

235+
static final Tensor<Boolean> TRUE = Tensor.ofScalar(true);
236+
231237
@CodeReflection
232-
public Tensor<Float> forLoopAdd(Tensor<Float> value, Tensor<Long> max, Tensor<Boolean> condition) {
233-
return OnnxOperators.Loop(max, condition, List.of(value),
234-
l -> {
235-
var v = l.userValues().get(0);
236-
return new ExplicitOnnxOperators.LoopLocals<>(l.i(), l.cond(), List.of(OnnxOperators.Add(v, v)));
237-
}).get(0);
238+
public Tensor<Float> forLoopAdd(Tensor<Long> max, Tensor<Float> initialValue) {
239+
return Loop(max, TRUE, initialValue, (i, cond, v) -> LoopReturn(cond, Add(v, v)));
238240
}
239241

240242
@Test
241243
public void testForLoopAdd() throws Exception {
242244
var expected = Tensor.ofFlat(0f, 8, 16, 24);
243245
var value = Tensor.ofFlat(0f, 1, 2, 3);
244246
var max = Tensor.ofScalar(3l);
245-
var cond = Tensor.ofScalar(true);
246-
assertEquals(expected, forLoopAdd(value, max, cond));
247-
// assertEquals(expected, OnnxRuntime.execute(() -> forLoopAdd(value, max, cond)));
247+
assertEquals(expected, forLoopAdd(max, value));
248+
assertEquals(expected, execute(() -> forLoopAdd(max, value)));
248249
}
249250

250251
static void assertEquals(Tensor expected, Tensor actual) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package oracle.code.onnx;
2+
3+
import java.lang.foreign.Arena;
4+
import java.lang.foreign.ValueLayout;
5+
import java.lang.invoke.MethodHandles;
6+
import java.util.List;
7+
import java.util.Optional;
8+
import jdk.incubator.code.CodeReflection;
9+
import org.junit.jupiter.api.Assertions;
10+
import org.junit.jupiter.api.Test;
11+
12+
import static java.util.Optional.empty;
13+
import static oracle.code.onnx.OnnxOperators.*;
14+
import static oracle.code.onnx.OnnxRuntime.execute;
15+
16+
public class WalkTheMazeTest {
17+
18+
final String expectedPath;
19+
20+
// initializers
21+
final Tensor<Byte> maze;
22+
final Tensor<Boolean> _true;
23+
final Tensor<Long> homePos, directionNorth, directionSouth, directionEast, directionWest,
24+
oneOne, zero, two, three, mOne, mThree, max, limit,
25+
stepSouth, stepNorth, stepEast, stepWest, scalarShape, wall;
26+
27+
public WalkTheMazeTest() {
28+
expectedPath = ">>^^>>vv>>>>>><<^^^^^<<<<<<^^>>>>>>>>>>vv<<vvv>>vv>>>>^^>>^<<^^>><<^^>><<<<^^>>>>^^>>vvvvvvvvv>>vv<<<<>>>>^^^<<^^>><<^^"
29+
+ ">>>>vv>>>>>><<^^>>>><<<<vv<<^^^^>>>>>>>>vvvv<<vv<<<<<<<<vvv>>>>>><<<<<<^^>>>>>>>>vv>>>>>>>>^^^<<<<^^>>>>>>>>^^^^^^<vv<<"
30+
+ "<<<<<<<^^>>>>>>><<<^^<<^^^^>>>>vv<<>>vv>>>>^^<<<<^^>>>><<^^>><<<<vv<<<<vv<<<<<<^^>>>>^^<<<<<<vvvv<<^^<<<<^^>>>><<<<vvvv"
31+
+ "<<^^<<<<^^>>>><<<<vvvvvv<<^^vvvv>>>>vv<<<<vvvv^^>>vvvv<<v>>vv<<<<^^^<<^^>>^^<<<<^^^^^^<<<<>>vv<<<<^^<<^^>>^^^^>>vv>>>>>"
32+
+ ">^^<<<<>>>>vv<<<<<<^^<<<<vvvvvv>>vv<<vv>>>>>>>>vv<<<<<<vv>>>>vvvvv<<^^^<<<<^^^^vvvvv>>vv<<>";
33+
34+
var arena = Arena.ofAuto();
35+
36+
maze = Tensor.ofShape(arena, new long[]{22, 48},
37+
"""
38+
###############################################
39+
# # # # # # # # #
40+
# # # ##### # ### ##### ##### ##### ##### # ###
41+
# # # # # # # # # #
42+
# # ######### ### ### # ### # # ##### ### #####
43+
# # # # # # # # #
44+
# ########### # # # # ####### ### ### ### ## #
45+
# # # # # # # # #
46+
### ### # # ### ############### # ##### #######
47+
# # # # # # # # #
48+
# ####### ### ##### # # ##### ##### ######## #
49+
# # # # # # # #
50+
######### ##### ##### ### # ####### ######## #
51+
# # # # # # # # #
52+
# # ######### # # ### ### # # ##### # ###### #
53+
# # # # # # # # # #
54+
# ##### # # ##### ### ########### ### #########
55+
# # # # # # # # # #
56+
# # # # # # # # # #
57+
### # # # ### ### ##### # ####### ### ### # #
58+
# # # # # # # #
59+
###############################################
60+
""".getBytes());
61+
62+
_true = Tensor.ofScalar(arena, true);
63+
homePos = Tensor.ofFlat(arena, 20l, 1); // bottom left corner
64+
directionNorth = Tensor.ofFlat(arena, '^');
65+
directionSouth = Tensor.ofFlat(arena, 'v');
66+
directionEast = Tensor.ofFlat(arena, '>');
67+
directionWest = Tensor.ofFlat(arena, '<');
68+
oneOne = Tensor.ofFlat(arena, 1l, 1);
69+
zero = Tensor.ofFlat(arena, 0l);
70+
two = Tensor.ofFlat(arena, 2l);
71+
three = Tensor.ofFlat(arena, 3l);
72+
mOne = Tensor.ofFlat(arena, -1l);
73+
mThree = Tensor.ofFlat(arena, -3l);
74+
max = Tensor.ofFlat(arena, Long.MAX_VALUE);
75+
limit = Tensor.ofFlat(arena, 1000l);
76+
stepSouth = Tensor.ofFlat(arena, 1l, 0);
77+
stepNorth = Tensor.ofFlat(arena, -1l, 0);
78+
stepEast = Tensor.ofFlat(arena, 0l, 1);
79+
stepWest = Tensor.ofFlat(arena, 0l, -1);
80+
scalarShape = Tensor.ofFlat(arena, new long[0]);
81+
wall = Tensor.ofScalar(arena, '#');
82+
}
83+
84+
@CodeReflection
85+
public Tensor<Long> turnLeft(Tensor<Long> direction) {
86+
return If(Equal(direction, directionEast),
87+
() -> Identity(directionNorth),
88+
() -> If(Equal(direction, directionNorth),
89+
() -> Identity(directionWest),
90+
() -> If(Equal(direction, directionWest),
91+
() -> Identity(directionSouth),
92+
() -> Identity(directionEast))));
93+
}
94+
95+
@CodeReflection
96+
public Tensor<Long> turnRight(Tensor<Long> direction) {
97+
return Loop(three, _true, direction, (i, cond, d)
98+
-> LoopReturn(cond, turnLeft(d)));
99+
}
100+
101+
@CodeReflection
102+
public Tensor<Boolean> isWallAt(Tensor<Long> pos) {
103+
return Equal(CastLike(Slice(maze, pos, Add(pos, oneOne), empty(), empty()), wall, empty()), wall);
104+
}
105+
106+
@CodeReflection
107+
public Tensor<Long> posInFrontOfMe(Tensor<Long> myPos, Tensor<Long> myDirection) {
108+
return If(Equal(myDirection, directionEast),
109+
() -> Add(myPos, stepEast),
110+
() -> If(Equal(myDirection, directionNorth),
111+
() -> Add(myPos, stepNorth),
112+
() -> If(Equal(myDirection, directionWest),
113+
() -> Add(myPos, stepWest),
114+
() -> Add(myPos, stepSouth))));
115+
}
116+
117+
@CodeReflection
118+
public Tensor<Boolean> atHome(Tensor<Long> pos) {
119+
return ReduceMin(Equal(pos, homePos), empty(), empty(), empty());
120+
}
121+
122+
@CodeReflection
123+
public Tensor<Long> lastPos(Tensor<Long> pathLog) {
124+
return Slice(pathLog, mThree, mOne, empty(), empty());
125+
}
126+
127+
@CodeReflection
128+
public Tensor<Long> lastDirection(Tensor<Long> pathLog) {
129+
return Slice(pathLog, mOne, max, empty(), empty());
130+
}
131+
132+
@CodeReflection
133+
public Tensor<Long> addToLog(Tensor<Long> pathLog, Tensor<Long> pos, Tensor<Long> direction) {
134+
return Concat(List.of(pathLog, pos, direction), 0);
135+
}
136+
137+
@CodeReflection
138+
public Tensor<Byte> extractDirections(Tensor<Long> pathLog) {
139+
return Cast(Slice(pathLog, two, max, Optional.of(zero), Optional.of(three)), empty(), 3);
140+
}
141+
142+
@CodeReflection
143+
public Tensor<Long> turnLeftWhileWall(Tensor<Long> pos, Tensor<Long> direction) {
144+
var initialCond = Reshape(isWallAt(posInFrontOfMe(pos, direction)), scalarShape, empty());
145+
return Loop(limit, initialCond, direction, (_, _, dir) -> {
146+
dir = turnLeft(dir);
147+
return LoopReturn(isWallAt(posInFrontOfMe(pos, dir)), dir);
148+
});
149+
}
150+
151+
@CodeReflection
152+
public Tensor<Long> walkAroundTheMaze() {
153+
var start = Concat(List.of(homePos, directionEast), 0);
154+
var pathLog = Loop(limit, _true, start, (_, _, log) -> {
155+
var pos = lastPos(log);
156+
var direction = lastDirection(log);
157+
158+
// walk along the right wall
159+
pos = posInFrontOfMe(pos, direction);
160+
direction = turnRight(direction);
161+
direction = turnLeftWhileWall(pos, direction);
162+
163+
return LoopReturn(Not(atHome(pos)), addToLog(log, pos, direction));
164+
});
165+
return pathLog;
166+
}
167+
168+
@Test
169+
public void testWalkAroundTheMaze() throws Exception {
170+
try (var arena = Arena.ofConfined()) {
171+
var directions = execute(arena, MethodHandles.lookup(), () -> extractDirections(walkAroundTheMaze()));
172+
Assertions.assertEquals(expectedPath, new String(directions.data().toArray(ValueLayout.JAVA_BYTE)));
173+
}
174+
}
175+
}

0 commit comments

Comments
 (0)
Please sign in to comment.