Skip to content

Commit f904480

Browse files
author
Ben Perez
committedDec 4, 2024
8345512: Remove wrapper functions for intrinsics in PQC algorithms
Reviewed-by: weijun
1 parent 8d19a56 commit f904480

File tree

3 files changed

+48
-427
lines changed

3 files changed

+48
-427
lines changed
 

‎src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java

+33-361
Original file line numberDiff line numberDiff line change
@@ -71,249 +71,6 @@ public final class ML_KEM {
7171
-1599, -709, -789, -1317, -57, 1049, -584
7272
};
7373

74-
private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_ARR = new short[]{
75-
// level 0
76-
-758, -758, -758, -758, -758, -758, -758, -758,
77-
-758, -758, -758, -758, -758, -758, -758, -758,
78-
-758, -758, -758, -758, -758, -758, -758, -758,
79-
-758, -758, -758, -758, -758, -758, -758, -758,
80-
-758, -758, -758, -758, -758, -758, -758, -758,
81-
-758, -758, -758, -758, -758, -758, -758, -758,
82-
-758, -758, -758, -758, -758, -758, -758, -758,
83-
-758, -758, -758, -758, -758, -758, -758, -758,
84-
-758, -758, -758, -758, -758, -758, -758, -758,
85-
-758, -758, -758, -758, -758, -758, -758, -758,
86-
-758, -758, -758, -758, -758, -758, -758, -758,
87-
-758, -758, -758, -758, -758, -758, -758, -758,
88-
-758, -758, -758, -758, -758, -758, -758, -758,
89-
-758, -758, -758, -758, -758, -758, -758, -758,
90-
-758, -758, -758, -758, -758, -758, -758, -758,
91-
-758, -758, -758, -758, -758, -758, -758, -758,
92-
// level 1
93-
-359, -359, -359, -359, -359, -359, -359, -359,
94-
-359, -359, -359, -359, -359, -359, -359, -359,
95-
-359, -359, -359, -359, -359, -359, -359, -359,
96-
-359, -359, -359, -359, -359, -359, -359, -359,
97-
-359, -359, -359, -359, -359, -359, -359, -359,
98-
-359, -359, -359, -359, -359, -359, -359, -359,
99-
-359, -359, -359, -359, -359, -359, -359, -359,
100-
-359, -359, -359, -359, -359, -359, -359, -359,
101-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
102-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
103-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
104-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
105-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
106-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
107-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
108-
-1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517,
109-
// level 2
110-
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
111-
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
112-
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
113-
1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493,
114-
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
115-
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
116-
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
117-
1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422,
118-
287, 287, 287, 287, 287, 287, 287, 287,
119-
287, 287, 287, 287, 287, 287, 287, 287,
120-
287, 287, 287, 287, 287, 287, 287, 287,
121-
287, 287, 287, 287, 287, 287, 287, 287,
122-
202, 202, 202, 202, 202, 202, 202, 202,
123-
202, 202, 202, 202, 202, 202, 202, 202,
124-
202, 202, 202, 202, 202, 202, 202, 202,
125-
202, 202, 202, 202, 202, 202, 202, 202,
126-
// level 3
127-
-171, -171, -171, -171, -171, -171, -171, -171,
128-
-171, -171, -171, -171, -171, -171, -171, -171,
129-
622, 622, 622, 622, 622, 622, 622, 622,
130-
622, 622, 622, 622, 622, 622, 622, 622,
131-
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
132-
1577, 1577, 1577, 1577, 1577, 1577, 1577, 1577,
133-
182, 182, 182, 182, 182, 182, 182, 182,
134-
182, 182, 182, 182, 182, 182, 182, 182,
135-
962, 962, 962, 962, 962, 962, 962, 962,
136-
962, 962, 962, 962, 962, 962, 962, 962,
137-
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
138-
-1202, -1202, -1202, -1202, -1202, -1202, -1202, -1202,
139-
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
140-
-1474, -1474, -1474, -1474, -1474, -1474, -1474, -1474,
141-
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
142-
1468, 1468, 1468, 1468, 1468, 1468, 1468, 1468,
143-
// level 4
144-
573, 573, 573, 573, 573, 573, 573, 573,
145-
-1325, -1325, -1325, -1325, -1325, -1325, -1325, -1325,
146-
264, 264, 264, 264, 264, 264, 264, 264,
147-
383, 383, 383, 383, 383, 383, 383, 383,
148-
-829, -829, -829, -829, -829, -829, -829, -829,
149-
1458, 1458, 1458, 1458, 1458, 1458, 1458, 1458,
150-
-1602, -1602, -1602, -1602, -1602, -1602, -1602, -1602,
151-
-130, -130, -130, -130, -130, -130, -130, -130,
152-
-681, -681, -681, -681, -681, -681, -681, -681,
153-
1017, 1017, 1017, 1017, 1017, 1017, 1017, 1017,
154-
732, 732, 732, 732, 732, 732, 732, 732,
155-
608, 608, 608, 608, 608, 608, 608, 608,
156-
-1542, -1542, -1542, -1542, -1542, -1542, -1542, -1542,
157-
411, 411, 411, 411, 411, 411, 411, 411,
158-
-205, -205, -205, -205, -205, -205, -205, -205,
159-
-1571, -1571, -1571, -1571, -1571, -1571, -1571, -1571,
160-
// level 5
161-
1223, 1223, 1223, 1223, 652, 652, 652, 652,
162-
-552, -552, -552, -552, 1015, 1015, 1015, 1015,
163-
-1293, -1293, -1293, -1293, 1491, 1491, 1491, 1491,
164-
-282, -282, -282, -282, -1544, -1544, -1544, -1544,
165-
516, 516, 516, 516, -8, -8, -8, -8,
166-
-320, -320, -320, -320, -666, -666, -666, -666,
167-
1711, 1711, 1711, 1711, -1162, -1162, -1162, -1162,
168-
126, 126, 126, 126, 1469, 1469, 1469, 1469,
169-
-853, -853, -853, -853, -90, -90, -90, -90,
170-
-271, -271, -271, -271, 830, 830, 830, 830,
171-
107, 107, 107, 107, -1421, -1421, -1421, -1421,
172-
-247, -247, -247, -247, -951, -951, -951, -951,
173-
-398, -398, -398, -398, 961, 961, 961, 961,
174-
-1508, -1508, -1508, -1508, -725, -725, -725, -725,
175-
448, 448, 448, 448, -1065, -1065, -1065, -1065,
176-
677, 677, 677, 677, -1275, -1275, -1275, -1275,
177-
// level 6
178-
-1103, -1103, 430, 430, 555, 555, 843, 843,
179-
-1251, -1251, 871, 871, 1550, 1550, 105, 105,
180-
422, 422, 587, 587, 177, 177, -235, -235,
181-
-291, -291, -460, -460, 1574, 1574, 1653, 1653,
182-
-246, -246, 778, 778, 1159, 1159, -147, -147,
183-
-777, -777, 1483, 1483, -602, -602, 1119, 1119,
184-
-1590, -1590, 644, 644, -872, -872, 349, 349,
185-
418, 418, 329, 329, -156, -156, -75, -75,
186-
817, 817, 1097, 1097, 603, 603, 610, 610,
187-
1322, 1322, -1285, -1285, -1465, -1465, 384, 384,
188-
-1215, -1215, -136, -136, 1218, 1218, -1335, -1335,
189-
-874, -874, 220, 220, -1187, -1187, 1670, 1670,
190-
-1185, -1185, -1530, -1530, -1278, -1278, 794, 794,
191-
-1510, -1510, -854, -854, -870, -870, 478, 478,
192-
-108, -108, -308, -308, 996, 996, 991, 991,
193-
958, 958, -1460, -1460, 1522, 1522, 1628, 1628
194-
};
195-
private static final short[] MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR = new short[]{
196-
// level 0
197-
-1628, -1628, -1522, -1522, 1460, 1460, -958, -958,
198-
-991, -991, -996, -996, 308, 308, 108, 108,
199-
-478, -478, 870, 870, 854, 854, 1510, 1510,
200-
-794, -794, 1278, 1278, 1530, 1530, 1185, 1185,
201-
1659, 1659, 1187, 1187, -220, -220, 874, 874,
202-
1335, 1335, -1218, -1218, 136, 136, 1215, 1215,
203-
-384, -384, 1465, 1465, 1285, 1285, -1322, -1322,
204-
-610, -610, -603, -603, -1097, -1097, -817, -817,
205-
75, 75, 156, 156, -329, -329, -418, -418,
206-
-349, -349, 872, 872, -644, -644, 1590, 1590,
207-
-1119, -1119, 602, 602, -1483, -1483, 777, 777,
208-
147, 147, -1159, -1159, -778, -778, 246, 246,
209-
-1653, -1653, -1574, -1574, 460, 460, 291, 291,
210-
235, 235, -177, -177, -587, -587, -422, -422,
211-
-105, -105, -1550, -1550, -871, -871, 1251, 1251,
212-
-843, -843, -555, -555, -430, -430, 1103, 1103,
213-
// level 1
214-
1275, 1275, 1275, 1275, -677, -677, -677, -677,
215-
1065, 1065, 1065, 1065, -448, -448, -448, -448,
216-
725, 725, 725, 725, 1508, 1508, 1508, 1508,
217-
-961, -961, -961, -961, 398, 398, 398, 398,
218-
951, 951, 951, 951, 247, 247, 247, 247,
219-
1421, 1421, 1421, 1421, -107, -107, -107, -107,
220-
-830, -830, -830, -830, 271, 271, 271, 271,
221-
90, 90, 90, 90, 853, 853, 853, 853,
222-
-1469, -1469, -1469, -1469, -126, -126, -126, -126,
223-
1162, 1162, 1162, 1162, 1618, 1618, 1618, 1618,
224-
666, 666, 666, 666, 320, 320, 320, 320,
225-
8, 8, 8, 8, -516, -516, -516, -516,
226-
1544, 1544, 1544, 1544, 282, 282, 282, 282,
227-
-1491, -1491, -1491, -1491, 1293, 1293, 1293, 1293,
228-
-1015, -1015, -1015, -1015, 552, 552, 552, 552,
229-
-652, -652, -652, -652, -1223, -1223, -1223, -1223,
230-
// level 2
231-
1571, 1571, 1571, 1571, 1571, 1571, 1571, 1571,
232-
205, 205, 205, 205, 205, 205, 205, 205,
233-
-411, -411, -411, -411, -411, -411, -411, -411,
234-
1542, 1542, 1542, 1542, 1542, 1542, 1542, 1542,
235-
-608, -608, -608, -608, -608, -608, -608, -608,
236-
-732, -732, -732, -732, -732, -732, -732, -732,
237-
-1017, -1017, -1017, -1017, -1017, -1017, -1017, -1017,
238-
681, 681, 681, 681, 681, 681, 681, 681,
239-
130, 130, 130, 130, 130, 130, 130, 130,
240-
1602, 1602, 1602, 1602, 1602, 1602, 1602, 1602,
241-
-1458, -1458, -1458, -1458, -1458, -1458, -1458, -1458,
242-
829, 829, 829, 829, 829, 829, 829, 829,
243-
-383, -383, -383, -383, -383, -383, -383, -383,
244-
-264, -264, -264, -264, -264, -264, -264, -264,
245-
1325, 1325, 1325, 1325, 1325, 1325, 1325, 1325,
246-
-573, -573, -573, -573, -573, -573, -573, -573,
247-
// level 3
248-
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
249-
-1468, -1468, -1468, -1468, -1468, -1468, -1468, -1468,
250-
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
251-
1474, 1474, 1474, 1474, 1474, 1474, 1474, 1474,
252-
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
253-
1202, 1202, 1202, 1202, 1202, 1202, 1202, 1202,
254-
-962, -962, -962, -962, -962, -962, -962, -962,
255-
-962, -962, -962, -962, -962, -962, -962, -962,
256-
-182, -182, -182, -182, -182, -182, -182, -182,
257-
-182, -182, -182, -182, -182, -182, -182, -182,
258-
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
259-
-1577, -1577, -1577, -1577, -1577, -1577, -1577, -1577,
260-
-622, -622, -622, -622, -622, -622, -622, -622,
261-
-622, -622, -622, -622, -622, -622, -622, -622,
262-
171, 171, 171, 171, 171, 171, 171, 171,
263-
171, 171, 171, 171, 171, 171, 171, 171,
264-
// level 4
265-
-202, -202, -202, -202, -202, -202, -202, -202,
266-
-202, -202, -202, -202, -202, -202, -202, -202,
267-
-202, -202, -202, -202, -202, -202, -202, -202,
268-
-202, -202, -202, -202, -202, -202, -202, -202,
269-
-287, -287, -287, -287, -287, -287, -287, -287,
270-
-287, -287, -287, -287, -287, -287, -287, -287,
271-
-287, -287, -287, -287, -287, -287, -287, -287,
272-
-287, -287, -287, -287, -287, -287, -287, -287,
273-
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
274-
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
275-
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
276-
-1422, -1422, -1422, -1422, -1422, -1422, -1422, -1422,
277-
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
278-
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
279-
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
280-
-1493, -1493, -1493, -1493, -1493, -1493, -1493, -1493,
281-
// level 5
282-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
283-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
284-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
285-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
286-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
287-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
288-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
289-
1517, 1517, 1517, 1517, 1517, 1517, 1517, 1517,
290-
359, 359, 359, 359, 359, 359, 359, 359,
291-
359, 359, 359, 359, 359, 359, 359, 359,
292-
359, 359, 359, 359, 359, 359, 359, 359,
293-
359, 359, 359, 359, 359, 359, 359, 359,
294-
359, 359, 359, 359, 359, 359, 359, 359,
295-
359, 359, 359, 359, 359, 359, 359, 359,
296-
359, 359, 359, 359, 359, 359, 359, 359,
297-
359, 359, 359, 359, 359, 359, 359, 359,
298-
// level 6
299-
758, 758, 758, 758, 758, 758, 758, 758,
300-
758, 758, 758, 758, 758, 758, 758, 758,
301-
758, 758, 758, 758, 758, 758, 758, 758,
302-
758, 758, 758, 758, 758, 758, 758, 758,
303-
758, 758, 758, 758, 758, 758, 758, 758,
304-
758, 758, 758, 758, 758, 758, 758, 758,
305-
758, 758, 758, 758, 758, 758, 758, 758,
306-
758, 758, 758, 758, 758, 758, 758, 758,
307-
758, 758, 758, 758, 758, 758, 758, 758,
308-
758, 758, 758, 758, 758, 758, 758, 758,
309-
758, 758, 758, 758, 758, 758, 758, 758,
310-
758, 758, 758, 758, 758, 758, 758, 758,
311-
758, 758, 758, 758, 758, 758, 758, 758,
312-
758, 758, 758, 758, 758, 758, 758, 758,
313-
758, 758, 758, 758, 758, 758, 758, 758,
314-
758, 758, 758, 758, 758, 758, 758, 758
315-
};
316-
31774
private static final int[] MONT_ZETAS_FOR_NTT_MULT = new int[]{
31875
-1003, 1003, 222, -222, -1107, 1107, 172, -172,
31976
-42, 42, 620, -620, 1497, -1497, -1649, 1649,
@@ -333,25 +90,6 @@ public final class ML_KEM {
33390
-1317, 1317, -57, 57, 1049, -1049, -584, 584
33491
};
33592

336-
private static final short[] MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR = new short[]{
337-
-1103, 1103, 430, -430, 555, -555, 843, -843,
338-
-1251, 1251, 871, -871, 1550, -1550, 105, -105,
339-
422, -422, 587, -587, 177, -177, -235, 235,
340-
-291, 291, -460, 460, 1574, -1574, 1653, -1653,
341-
-246, 246, 778, -778, 1159, -1159, -147, 147,
342-
-777, 777, 1483, -1483, -602, 602, 1119, -1119,
343-
-1590, 1590, 644, -644, -872, 872, 349, -349,
344-
418, -418, 329, -329, -156, 156, -75, 75,
345-
817, -817, 1097, -1097, 603, -603, 610, -610,
346-
1322, -1322, -1285, 1285, -1465, 1465, 384, -384,
347-
-1215, 1215, -136, 136, 1218, -1218, -1335, 1335,
348-
-874, 874, 220, -220, -1187, 1187, 1670, 1659,
349-
-1185, 1185, -1530, 1530, -1278, 1278, 794, -794,
350-
-1510, 1510, -854, 854, -870, 870, 478, -478,
351-
-108, 108, -308, 308, 996, -996, 991, -991,
352-
958, -958, -1460, 1460, 1522, -1522, 1628, -1628
353-
};
354-
35593
private final int mlKem_k;
35694
private final int mlKem_eta1;
35795
private final int mlKem_eta2;
@@ -499,7 +237,7 @@ protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
499237
System.arraycopy(kPkePrivateKey, 0, decapsKey, 0, kPkePrivateKey.length);
500238
Arrays.fill(kPkePrivateKey, (byte)0);
501239
System.arraycopy(encapsKey, 0, decapsKey,
502-
kPkePrivateKey.length, encapsKey.length);
240+
kPkePrivateKey.length, encapsKey.length);
503241

504242
mlKemH.update(encapsKey);
505243
try {
@@ -534,7 +272,7 @@ protected ML_KEM_EncapsulateResult encapsulate(
534272
var kHatAndRandomCoins = mlKemG.digest();
535273
var randomCoins = Arrays.copyOfRange(kHatAndRandomCoins, 32, 64);
536274
var cipherText = kPkeEncrypt(new K_PKE_EncryptionKey(encapsulationKey.keyBytes),
537-
randomMessage, randomCoins);
275+
randomMessage, randomCoins);
538276
Arrays.fill(randomCoins, (byte) 0);
539277
byte[] sharedSecret = Arrays.copyOfRange(kHatAndRandomCoins, 0, 32);
540278
Arrays.fill(kHatAndRandomCoins, (byte) 0);
@@ -564,7 +302,7 @@ protected byte[] decapsulate(ML_KEM_DecapsulationKey decapsulationKey,
564302

565303
byte[] kPkePrivateKeyBytes = new byte[mlKem_k * encode12PolyLen];
566304
System.arraycopy(decapsKeyBytes, 0, kPkePrivateKeyBytes, 0,
567-
kPkePrivateKeyBytes.length);
305+
kPkePrivateKeyBytes.length);
568306

569307
byte[] encapsKeyBytes = new byte[mlKem_k * encode12PolyLen + 32];
570308
System.arraycopy(decapsKeyBytes, mlKem_k * encode12PolyLen,
@@ -678,8 +416,8 @@ private K_PKE_KeyPair generateK_PkeKeyPair(byte[] seed) {
678416
pkEncoded, (mlKem_k * ML_KEM_N * 12) / 8, rho.length);
679417

680418
return new K_PKE_KeyPair(
681-
new K_PKE_EncryptionKey(pkEncoded),
682-
new K_PKE_DecryptionKey(skEncoded));
419+
new K_PKE_EncryptionKey(pkEncoded),
420+
new K_PKE_DecryptionKey(skEncoded));
683421
}
684422

685423
private K_PKE_CipherText kPkeEncrypt(
@@ -969,11 +707,9 @@ private short[][] mlKemVectorInverseNTT(short[][] vector) {
969707
return vector;
970708
}
971709

972-
static void implMlKemNtt(short[] poly, short[] ntt_zetas) {
973-
implMlKemNttJava(poly);
974-
}
975-
976-
private static void implMlKemNttJava(short[] poly) {
710+
// The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
711+
// The elements of poly at return will be in the range of [0, ML_KEM_Q]
712+
private void mlKemNTT(short[] poly) {
977713
int[] coeffs = new int[ML_KEM_N];
978714
for (int m = 0; m < ML_KEM_N; m++) {
979715
coeffs[m] = poly[m];
@@ -982,20 +718,12 @@ private static void implMlKemNttJava(short[] poly) {
982718
for (int m = 0; m < ML_KEM_N; m++) {
983719
poly[m] = (short) coeffs[m];
984720
}
985-
}
986-
987-
// The elements of poly should be in the range [-ML_KEM_Q, ML_KEM_Q]
988-
// The elements of poly at return will be in the range of [0, ML_KEM_Q]
989-
private void mlKemNTT(short[] poly) {
990-
implMlKemNtt(poly, MONT_ZETAS_FOR_VECTOR_NTT_ARR);
991721
mlKemBarrettReduce(poly);
992722
}
993723

994-
static void implMlKemInverseNtt(short[] poly, short[] zetas) {
995-
implMlKemInverseNttJava(poly);
996-
}
997-
998-
private static void implMlKemInverseNttJava(short[] poly) {
724+
// Works in place, but also returns its (modified) input so that it can
725+
// be used in expressions
726+
private short[] mlKemInverseNTT(short[] poly) {
999727
int[] coeffs = new int[ML_KEM_N];
1000728
for (int m = 0; m < ML_KEM_N; m++) {
1001729
coeffs[m] = poly[m];
@@ -1004,12 +732,6 @@ private static void implMlKemInverseNttJava(short[] poly) {
1004732
for (int m = 0; m < ML_KEM_N; m++) {
1005733
poly[m] = (short) coeffs[m];
1006734
}
1007-
}
1008-
1009-
// Works in place, but also returns its (modified) input so that it can
1010-
// be used in expressions
1011-
private short[] mlKemInverseNTT(short[] poly) {
1012-
implMlKemInverseNtt(poly, MONT_ZETAS_FOR_VECTOR_INVERSE_NTT_ARR);
1013735
return poly;
1014736
}
1015737

@@ -1100,14 +822,10 @@ private short[] mlKemVectorScalarMult(short[][] a, short[][] b) {
1100822
return result;
1101823
}
1102824

1103-
static void implMlKemNttMult(short[] result, short[] ntta, short[] nttb,
1104-
short[] zetas) {
1105-
implMlKemNttMultJava(result, ntta, nttb);
1106-
}
1107-
1108-
private static void implMlKemNttMultJava(short[] result,
1109-
short[] ntta, short[] nttb) {
1110-
825+
// Multiplies two polynomials represented in the NTT domain.
826+
// The result is a representation of the product still in the NTT domain.
827+
// The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
828+
private void nttMult(short[] result, short[] ntta, short[] nttb) {
1111829
for (int m = 0; m < ML_KEM_N / 2; m++) {
1112830
int a0 = ntta[2 * m];
1113831
int a1 = ntta[2 * m + 1];
@@ -1121,13 +839,6 @@ private static void implMlKemNttMultJava(short[] result,
1121839
}
1122840
}
1123841

1124-
// Multiplies two polynomials represented in the NTT domain.
1125-
// The result is a representation of the product still in the NTT domain.
1126-
// The coefficients in the result are in the range (-ML_KEM_Q, ML_KEM_Q).
1127-
private void nttMult(short[] result, short[] ntta, short[] nttb) {
1128-
implMlKemNttMult(result, ntta, nttb, MONT_ZETAS_FOR_VECTOR_NTT_MULT_ARR);
1129-
}
1130-
1131842
// Adds the vector of polynomials b to a in place, i.e. a will hold
1132843
// the result. It also returns (the modified) a so that it can be used
1133844
// in an expression.
@@ -1142,36 +853,15 @@ private short[][] mlKemAddVec(short[][] a, short[][] b) {
1142853
return a;
1143854
}
1144855

1145-
static void implMlKemAddPoly(short[] result, short[] a, short[] b) {
1146-
implMlKemAddPolyJava(result, a, b);
1147-
}
1148-
1149-
private static void implMlKemAddPolyJava(short[] result, short[] a, short[] b) {
1150-
for (int m = 0; m < ML_KEM_N; m++) {
1151-
int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
1152-
result[m] = (short) r;
1153-
}
1154-
}
1155-
1156856
// Adds the polynomial b to a in place, i.e. (the modified) a will hold
1157857
// the result.
1158858
// The coefficients are supposed be greater than -ML_KEM_Q in a and
1159859
// greater than -ML_KEM_Q and less than ML_KEM_Q in b.
1160860
// The coefficients in the result are greater than -ML_KEM_Q.
1161861
private void mlKemAddPoly(short[] a, short[] b) {
1162-
implMlKemAddPoly(a, a, b);
1163-
}
1164-
1165-
static void implMlKemAddPoly(short[] result, short[] a, short[] b, short[] c) {
1166-
implMlKemAddPolyJava(result, a, b, c);
1167-
}
1168-
1169-
private static void implMlKemAddPolyJava(short[] result, short[] a,
1170-
short[] b, short[] c) {
1171-
1172862
for (int m = 0; m < ML_KEM_N; m++) {
1173-
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
1174-
result[m] = (short) r;
863+
int r = a[m] + b[m] + ML_KEM_Q; // This makes r > -ML_KEM_Q
864+
a[m] = (short) r;
1175865
}
1176866
}
1177867

@@ -1181,7 +871,10 @@ private static void implMlKemAddPolyJava(short[] result, short[] a,
1181871
// greater than -ML_KEM_Q and less than ML_KEM_Q.
1182872
// The coefficients in the result are nonnegative and less than ML_KEM_Q.
1183873
private short[] mlKemAddPoly(short[] a, short[] b, short[] c) {
1184-
implMlKemAddPoly(a, a, b, c);
874+
for (int m = 0; m < ML_KEM_N; m++) {
875+
int r = a[m] + b[m] + c[m] + 2 * ML_KEM_Q; // This makes r > - ML_KEM_Q
876+
a[m] = (short) r;
877+
}
1185878
mlKemBarrettReduce(a);
1186879
return a;
1187880
}
@@ -1304,23 +997,6 @@ private short[][] decodeVector(int l, byte[] encodedVector) {
1304997
return result;
1305998
}
1306999

1307-
private static void implMlKem12To16(byte[] condensed, int index,
1308-
short[] parsed, int parsedLength) {
1309-
1310-
implMlKem12To16Java(condensed, index, parsed, parsedLength);
1311-
}
1312-
1313-
private static void implMlKem12To16Java(byte[] condensed, int index,
1314-
short[] parsed, int parsedLength) {
1315-
1316-
for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
1317-
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
1318-
256 * (condensed[i + index + 1] & 0xf));
1319-
parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) +
1320-
16 * (condensed[i + index + 2] & 0xff));
1321-
}
1322-
}
1323-
13241000
// The intrinsic implementations assume that the input and output buffers
13251001
// are such that condensed can be read in 192-byte chunks and
13261002
// parsed can be written in 128 shorts chunks. In other words,
@@ -1330,7 +1006,12 @@ private static void implMlKem12To16Java(byte[] condensed, int index,
13301006
private void twelve2Sixteen(byte[] condensed, int index,
13311007
short[] parsed, int parsedLength) {
13321008

1333-
implMlKem12To16(condensed, index, parsed, parsedLength);
1009+
for (int i = 0; i < parsedLength * 3 / 2; i += 3) {
1010+
parsed[(i / 3) * 2] = (short) ((condensed[i + index] & 0xff) +
1011+
256 * (condensed[i + index + 1] & 0xf));
1012+
parsed[(i / 3) * 2 + 1] = (short) (((condensed[i + index + 1] >>> 4) & 0xf) +
1013+
16 * (condensed[i + index + 2] & 0xff));
1014+
}
13341015
}
13351016

13361017
private static void decodePoly5(byte[] condensed, int index, short[] parsed) {
@@ -1471,18 +1152,6 @@ private static short[] decompressDecode(byte[] input) {
14711152
return result;
14721153
}
14731154

1474-
static void implMlKemBarrettReduce(short[] coeffs) {
1475-
implMlKemBarrettReduceJava(coeffs);
1476-
}
1477-
1478-
private static void implMlKemBarrettReduceJava(short[] coeffs) {
1479-
for (int m = 0; m < ML_KEM_N; m++) {
1480-
int tmp = ((int) coeffs[m] * BARRETT_MULTIPLIER) >>
1481-
BARRETT_SHIFT;
1482-
coeffs[m] = (short) (coeffs[m] - tmp * ML_KEM_Q);
1483-
}
1484-
}
1485-
14861155
// The input elements can have any short value.
14871156
// Modifies poly such that upon return poly[i] will be
14881157
// in the range [0, ML_KEM_Q] and will be congruent with the original
@@ -1493,7 +1162,10 @@ private static void implMlKemBarrettReduceJava(short[] coeffs) {
14931162
// will be in the range [0, ML_KEM_Q), i.e. it will be the canonical
14941163
// representative of its residue class.
14951164
private void mlKemBarrettReduce(short[] poly) {
1496-
implMlKemBarrettReduce(poly);
1165+
for (int m = 0; m < ML_KEM_N; m++) {
1166+
int tmp = ((int) poly[m] * BARRETT_MULTIPLIER) >> BARRETT_SHIFT;
1167+
poly[m] = (short) (poly[m] - tmp * ML_KEM_Q);
1168+
}
14971169
}
14981170

14991171
// Precondition: -(2^MONT_R_BITS -1) * MONT_Q <= b * c < (2^MONT_R_BITS - 1) * MONT_Q
@@ -1503,8 +1175,8 @@ private static int montMul(int b, int c) {
15031175
int a = b * c;
15041176
int aHigh = a >> MONT_R_BITS;
15051177
int aLow = a & ((1 << MONT_R_BITS) - 1);
1506-
int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >>
1507-
(32 - MONT_R_BITS); // signed low product
1178+
// signed low product
1179+
int m = ((MONT_Q_INV_MOD_R * aLow) << (32 - MONT_R_BITS)) >> (32 - MONT_R_BITS);
15081180

15091181
return (aHigh - ((m * MONT_Q) >> MONT_R_BITS)); // subtract signed high product
15101182
}

‎src/java.base/share/classes/com/sun/crypto/provider/ML_KEM_Impls.java

+2-24
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,6 @@
3737

3838
public final class ML_KEM_Impls {
3939

40-
static int name2int(String name) {
41-
if (name.endsWith("512")) {
42-
return 512;
43-
} else if (name.endsWith("768")) {
44-
return 768;
45-
} else if (name.endsWith("1024")) {
46-
return 1024;
47-
} else {
48-
// should not happen
49-
throw new ProviderException("Unknown name " + name);
50-
}
51-
}
52-
5340
public sealed static class KPG
5441
extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 {
5542

@@ -164,17 +151,8 @@ protected byte[] implDecapsulate(String name, byte[] decapsulationKey,
164151

165152
ML_KEM mlKem = new ML_KEM(name);
166153
var kpkeCipherText = new ML_KEM.K_PKE_CipherText(cipherText);
167-
168-
byte[] decapsulateResult;
169-
try {
170-
decapsulateResult = mlKem.decapsulate(
171-
new ML_KEM.ML_KEM_DecapsulationKey(
172-
decapsulationKey), kpkeCipherText);
173-
} catch (DecapsulateException e) {
174-
throw new DecapsulateException("Decapsulate error", e) ;
175-
}
176-
177-
return decapsulateResult;
154+
return mlKem.decapsulate(new ML_KEM.ML_KEM_DecapsulationKey(
155+
decapsulationKey), kpkeCipherText);
178156
}
179157

180158
@Override

‎src/java.base/share/classes/sun/security/provider/ML_DSA.java

+13-42
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
package sun.security.provider;
2727

28-
import jdk.internal.vm.annotation.IntrinsicCandidate;
2928
import sun.security.provider.SHA3.SHAKE128;
3029
import sun.security.provider.SHA3.SHAKE256;
3130

@@ -527,13 +526,13 @@ public int[][] t1Unpack(byte[] v) {
527526
int tOffset = j*4;
528527
int vOffset = (i*320) + (j*5);
529528
t1[i][tOffset] = (v[vOffset] & 0xFF) +
530-
((v[vOffset+1] << 8) & 0x3FF);
529+
((v[vOffset+1] << 8) & 0x3FF);
531530
t1[i][tOffset+1] = ((v[vOffset+1] >> 2) & 0x3F) +
532-
((v[vOffset+2] << 6) & 0x3FF);
531+
((v[vOffset+2] << 6) & 0x3FF);
533532
t1[i][tOffset+2] = ((v[vOffset+2] >> 4) & 0xF) +
534-
((v[vOffset+3] << 4) & 0x3FF);
533+
((v[vOffset+3] << 4) & 0x3FF);
535534
t1[i][tOffset+3] = ((v[vOffset+3] >> 6) & 0x3) +
536-
((v[vOffset+4] << 2) & 0x3FF);
535+
((v[vOffset+4] << 2) & 0x3FF);
537536
}
538537
}
539538
return t1;
@@ -875,8 +874,8 @@ int[][][] generateA(byte[] seed) {
875874
rawOfs = 0;
876875
}
877876
tmp = (rawAij[rawOfs] & 0xFF) +
878-
((rawAij[rawOfs + 1] & 0xFF) << 8) +
879-
((rawAij[rawOfs + 2] & 0x7F) << 16);
877+
((rawAij[rawOfs + 1] & 0xFF) << 8) +
878+
((rawAij[rawOfs + 2] & 0x7F) << 16);
880879
rawOfs += 3;
881880
if (tmp < ML_DSA_Q) {
882881
aij[ofs] = tmp;
@@ -981,7 +980,7 @@ private void decompose(int[][] input, int[][] lowPart, int[][] highPart) {
981980
int multiplier = (gamma2 == 95232 ? 22 : 8);
982981
for (int i = 0; i < mlDsa_k; i++) {
983982
ML_DSA.mlDsaDecomposePoly(input[i], lowPart[i],
984-
highPart[i], gamma2 * 2, multiplier);
983+
highPart[i], gamma2 * 2, multiplier);
985984
}
986985
}
987986

@@ -1032,12 +1031,6 @@ private int[][] useHint(boolean[][] h, int[][] r) {
10321031
*/
10331032

10341033
public static int[] mlDsaNtt(int[] coeffs) {
1035-
implMlDsaAlmostNttJava(coeffs);
1036-
implMlDsaMontMulByConstantJava(coeffs, MONT_R_MOD_Q);
1037-
return coeffs;
1038-
}
1039-
1040-
static void implMlDsaAlmostNttJava(int[] coeffs) {
10411034
int dimension = ML_DSA_N;
10421035
int m = 0;
10431036
for (int l = dimension / 2; l > 0; l /= 2) {
@@ -1050,15 +1043,11 @@ static void implMlDsaAlmostNttJava(int[] coeffs) {
10501043
m++;
10511044
}
10521045
}
1053-
}
1054-
1055-
public static int[] mlDsaInverseNtt(int[] coeffs) {
1056-
implMlDsaAlmostInverseNttJava(coeffs);
1057-
implMlDsaMontMulByConstantJava(coeffs, MONT_DIM_INVERSE);
1046+
montMulByConstant(coeffs, MONT_R_MOD_Q);
10581047
return coeffs;
10591048
}
10601049

1061-
static void implMlDsaAlmostInverseNttJava(int[] coeffs) {
1050+
public static int[] mlDsaInverseNtt(int[] coeffs) {
10621051
int dimension = ML_DSA_N;
10631052
int m = 0;
10641053
for (int l = 1; l < dimension; l *= 2) {
@@ -1067,11 +1056,13 @@ static void implMlDsaAlmostInverseNttJava(int[] coeffs) {
10671056
int tmp = coeffs[j];
10681057
coeffs[j] = (tmp + coeffs[j + l]);
10691058
coeffs[j + l] = montMul(tmp - coeffs[j + l],
1070-
MONT_ZETAS_FOR_INVERSE_NTT[m]);
1059+
MONT_ZETAS_FOR_INVERSE_NTT[m]);
10711060
}
10721061
m++;
10731062
}
10741063
}
1064+
montMulByConstant(coeffs, MONT_DIM_INVERSE);
1065+
return coeffs;
10751066
}
10761067

10771068
void mlDsaVectorNtt(int[][] vector) {
@@ -1086,40 +1077,20 @@ void mlDsaVectorInverseNtt(int[][] vector) {
10861077
}
10871078
}
10881079

1089-
//Todo
1090-
public static void mlDsaNttMultiply(int[] res, int[] coeffs1, int[] coeffs2) {
1091-
implMlDsaNttMultJava(res, coeffs1, coeffs2);
1092-
}
1093-
1094-
static void implMlDsaNttMultJava(int[] product, int[] coeffs1, int[] coeffs2) {
1080+
public static void mlDsaNttMultiply(int[] product, int[] coeffs1, int[] coeffs2) {
10951081
for (int i = 0; i < ML_DSA_N; i++) {
10961082
product[i] = montMul(coeffs1[i], toMont(coeffs2[i]));
10971083
}
10981084
}
10991085

11001086
public static void montMulByConstant(int[] coeffs, int constant) {
1101-
implMlDsaMontMulByConstantJava(coeffs, constant);
1102-
}
1103-
1104-
static void implMlDsaMontMulByConstantJava(int[] coeffs, int constant) {
11051087
for (int i = 0; i < ML_DSA_N; i++) {
11061088
coeffs[i] = montMul((coeffs[i]), constant);
11071089
}
11081090
}
11091091

11101092
public static void mlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
11111093
int twoGamma2, int multiplier) {
1112-
implMlDsaDecomposePoly(input, lowPart, highPart, twoGamma2, multiplier);
1113-
}
1114-
1115-
@IntrinsicCandidate
1116-
static void implMlDsaDecomposePoly(int[] input, int[] lowPart, int[] highPart,
1117-
int twoGamma2, int multiplier) {
1118-
decomposePolyJava(input, lowPart, highPart, twoGamma2, multiplier);
1119-
}
1120-
1121-
static void decomposePolyJava(int[] input, int[] lowPart, int[] highPart,
1122-
int twoGamma2, int multiplier) {
11231094
for (int m = 0; m < ML_DSA_N; m++) {
11241095
int rplus = input[m];
11251096
rplus = rplus - ((rplus + 5373807) >> 23) * ML_DSA_Q;

0 commit comments

Comments
 (0)
Please sign in to comment.