29
29
import java .io .*;
30
30
import jdk .incubator .code .CodeReflection ;
31
31
import java .lang .foreign .MemorySegment ;
32
+ import java .lang .foreign .ValueLayout ;
32
33
import java .lang .invoke .MethodHandles ;
33
34
import java .nio .ByteBuffer ;
34
35
import java .nio .ByteOrder ;
47
48
import static oracle .code .onnx .Tensor .ElementType .*;
48
49
49
50
public class MNISTDemo {
50
-
51
51
private static float [] loadConstant (String resource ) throws IOException {
52
- var bb = ByteBuffer . wrap (MNISTDemo .class .getResourceAsStream (resource ).readAllBytes ()). order ( ByteOrder . LITTLE_ENDIAN );
53
- return FloatBuffer . allocate ( bb . capacity () / 4 ). put ( bb . asFloatBuffer ()). array ( );
52
+ return MemorySegment . ofArray (MNISTDemo .class .getResourceAsStream (resource ).readAllBytes ())
53
+ . toArray ( ValueLayout . JAVA_FLOAT_UNALIGNED );
54
54
}
55
55
56
56
@ CodeReflection
@@ -110,11 +110,13 @@ public static Tensor<Float> cnn(Tensor<Float> inputImage) throws IOException {
110
110
static final int IMAGE_SIZE = 28 ;
111
111
static final int DRAW_AREA_SIZE = 600 ;
112
112
static final int PEN_SIZE = 20 ;
113
+ static final String [] COLORS = {"1034a6" , "412f88" , "722b6a" , "a2264b" , "d3212d" , "f62d2d" };
113
114
114
115
public static void main (String [] args ) throws Exception {
115
116
var frame = new JFrame ("CNN MNIST Demo - Handwritten Digit Classification" );
116
117
var drawPane = new JPanel (false );
117
118
var statusBar = new JLabel (" Hold SHIFT key to draw with trackpad or mouse, click ENTER to run digit classification." );
119
+ var results = new JLabel ();
118
120
var cleanFlag = new AtomicBoolean (true );
119
121
var modelRuntimeSession = OnnxRuntime .getInstance ().createSession (
120
122
OnnxProtoBuilder .buildFuncModel (
@@ -128,6 +130,8 @@ public static void main(String[] args) throws Exception {
128
130
var inputArguments = List .of (new Tensor (MemorySegment .ofBuffer (scaledImageDataBuffer ), FLOAT , 1 , 1 , IMAGE_SIZE , IMAGE_SIZE ).tensorAddr );
129
131
var sampleArray = new float [IMAGE_SIZE * IMAGE_SIZE ];
130
132
133
+ results .setPreferredSize (new Dimension (100 , 0 ));
134
+
131
135
drawPane .setPreferredSize (new Dimension (DRAW_AREA_SIZE , DRAW_AREA_SIZE ));
132
136
drawPane .addMouseMotionListener (new MouseAdapter () {
133
137
@ Override
@@ -145,6 +149,7 @@ public void mouseMoved(MouseEvent e) {
145
149
146
150
frame .setLayout (new BorderLayout ());
147
151
frame .add (drawPane , BorderLayout .CENTER );
152
+ frame .add (results , BorderLayout .EAST );
148
153
frame .add (statusBar , BorderLayout .SOUTH );
149
154
frame .pack ();
150
155
frame .setResizable (false );
@@ -155,17 +160,13 @@ public void keyPressed(KeyEvent e) {
155
160
scaledGraphics .drawImage (drawAreaImage .getScaledInstance (IMAGE_SIZE , IMAGE_SIZE , Image .SCALE_SMOOTH ), 0 , 0 , null );
156
161
scaledImageDataBuffer .put (0 , scaledImage .getData ().getSamples (0 , 0 , IMAGE_SIZE , IMAGE_SIZE , 0 , sampleArray ));
157
162
FloatBuffer result = OnnxRuntime .getInstance ().tensorBuffer (modelRuntimeSession .run (inputArguments ).getFirst ()).asFloatBuffer ();
158
- int max = 0 ;
159
- for (int i = 1 ; i < 10 ; i ++) {
160
- if (result .get (i ) > result .get (max )) max = i ;
161
- }
162
- var msg = new StringBuilder ("<html> " );
163
+ var msg = new StringBuilder ("<html>" );
163
164
for (int i = 0 ; i < 10 ; i ++) {
164
- msg . append (( max == i ? " <b><u>%d: %.1f%%</u></b>"
165
- : " %d: %.1f%%" )
166
- .formatted (i , 100 * result . get ( i ) ));
165
+ var w = result . get ( i );
166
+ msg . append ( " <font size= \" %d \" color= \" #%s \" >%d</font> ( %.1f%%) <br><br><br>"
167
+ .formatted (( int )( 20 * w ) + 3 , COLORS [( int )( 5.99 * w )], i , 100 * w ));
167
168
}
168
- statusBar .setText (msg .toString ());
169
+ results .setText (msg .toString ());
169
170
cleanFlag .set (true );
170
171
}
171
172
}
0 commit comments