Skip to content

Commit 5a30fac

Browse files
committedFeb 25, 2025
MNISTDemo UI tweak
Reviewed-by: psandoz
1 parent 4752b37 commit 5a30fac

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed
 

‎cr-examples/onnx/src/test/java/oracle/code/onnx/MNISTDemo.java

+13-12
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.io.*;
3030
import jdk.incubator.code.CodeReflection;
3131
import java.lang.foreign.MemorySegment;
32+
import java.lang.foreign.ValueLayout;
3233
import java.lang.invoke.MethodHandles;
3334
import java.nio.ByteBuffer;
3435
import java.nio.ByteOrder;
@@ -47,10 +48,9 @@
4748
import static oracle.code.onnx.Tensor.ElementType.*;
4849

4950
public class MNISTDemo {
50-
5151
private static float[] loadConstant(String resource) throws IOException {
52-
var bb = ByteBuffer.wrap(MNISTDemo.class.getResourceAsStream(resource).readAllBytes()).order(ByteOrder.LITTLE_ENDIAN);
53-
return FloatBuffer.allocate(bb.capacity() / 4).put(bb.asFloatBuffer()).array();
52+
return MemorySegment.ofArray(MNISTDemo.class.getResourceAsStream(resource).readAllBytes())
53+
.toArray(ValueLayout.JAVA_FLOAT_UNALIGNED);
5454
}
5555

5656
@CodeReflection
@@ -110,11 +110,13 @@ public static Tensor<Float> cnn(Tensor<Float> inputImage) throws IOException {
110110
static final int IMAGE_SIZE = 28;
111111
static final int DRAW_AREA_SIZE = 600;
112112
static final int PEN_SIZE = 20;
113+
static final String[] COLORS = {"1034a6", "412f88", "722b6a", "a2264b", "d3212d", "f62d2d"};
113114

114115
public static void main(String[] args) throws Exception {
115116
var frame = new JFrame("CNN MNIST Demo - Handwritten Digit Classification");
116117
var drawPane = new JPanel(false);
117118
var statusBar = new JLabel(" Hold SHIFT key to draw with trackpad or mouse, click ENTER to run digit classification.");
119+
var results = new JLabel();
118120
var cleanFlag = new AtomicBoolean(true);
119121
var modelRuntimeSession = OnnxRuntime.getInstance().createSession(
120122
OnnxProtoBuilder.buildFuncModel(
@@ -128,6 +130,8 @@ public static void main(String[] args) throws Exception {
128130
var inputArguments = List.of(new Tensor(MemorySegment.ofBuffer(scaledImageDataBuffer), FLOAT, 1, 1, IMAGE_SIZE, IMAGE_SIZE).tensorAddr);
129131
var sampleArray = new float[IMAGE_SIZE * IMAGE_SIZE];
130132

133+
results.setPreferredSize(new Dimension(100, 0));
134+
131135
drawPane.setPreferredSize(new Dimension(DRAW_AREA_SIZE, DRAW_AREA_SIZE));
132136
drawPane.addMouseMotionListener(new MouseAdapter() {
133137
@Override
@@ -145,6 +149,7 @@ public void mouseMoved(MouseEvent e) {
145149

146150
frame.setLayout(new BorderLayout());
147151
frame.add(drawPane, BorderLayout.CENTER);
152+
frame.add(results, BorderLayout.EAST);
148153
frame.add(statusBar, BorderLayout.SOUTH);
149154
frame.pack();
150155
frame.setResizable(false);
@@ -155,17 +160,13 @@ public void keyPressed(KeyEvent e) {
155160
scaledGraphics.drawImage(drawAreaImage.getScaledInstance(IMAGE_SIZE, IMAGE_SIZE, Image.SCALE_SMOOTH), 0, 0, null);
156161
scaledImageDataBuffer.put(0, scaledImage.getData().getSamples(0, 0, IMAGE_SIZE, IMAGE_SIZE, 0, sampleArray));
157162
FloatBuffer result = OnnxRuntime.getInstance().tensorBuffer(modelRuntimeSession.run(inputArguments).getFirst()).asFloatBuffer();
158-
int max = 0;
159-
for (int i = 1; i < 10; i++) {
160-
if (result.get(i) > result.get(max)) max = i;
161-
}
162-
var msg = new StringBuilder("<html>&nbsp;");
163+
var msg = new StringBuilder("<html>");
163164
for (int i = 0; i < 10; i++) {
164-
msg.append((max == i ? "&nbsp;&nbsp;<b><u>%d:&nbsp;%.1f%%</u></b>"
165-
: "&nbsp;&nbsp;%d:&nbsp;%.1f%%")
166-
.formatted(i, 100 * result.get(i)));
165+
var w = result.get(i);
166+
msg.append("&nbsp;<font size=\"%d\" color=\"#%s\">%d</font>&nbsp;(%.1f%%)&nbsp;<br><br><br>"
167+
.formatted((int)(20 * w) + 3, COLORS[(int)(5.99 * w)], i, 100 * w));
167168
}
168-
statusBar.setText(msg.toString());
169+
results.setText(msg.toString());
169170
cleanFlag.set(true);
170171
}
171172
}

0 commit comments

Comments
 (0)