Skip to content

Commit b09a451

Browse files
committed Jun 13, 2024
8333840: C2 SuperWord: wrong result for MulAddS2I when inputs permuted
Reviewed-by: kvn, chagedorn
1 parent cff048c commit b09a451

File tree

3 files changed

+141
-32
lines changed

3 files changed

+141
-32
lines changed
 

‎src/hotspot/share/opto/superword.cpp

+54-5
Original file line number | Diff line number | Diff line change
@@ -2794,11 +2794,7 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const {
27942794
}
27952795

27962796
if (VectorNode::is_muladds2i(use)) {
2797-
// MulAddS2I takes shorts and produces ints.
2798-
if (u_pk->size() * 2 != d_pk->size()) {
2799-
return false;
2800-
}
2801-
return true;
2797+
return _packset.is_muladds2i_pack_with_pack_inputs(u_pk);
28022798
}
28032799

28042800
if (u_pk->size() != d_pk->size()) {
@@ -2815,6 +2811,59 @@ bool SuperWord::is_vector_use(Node* use, int u_idx) const {
28152811
return true;
28162812
}
28172813

2814+
// MulAddS2I takes 4 shorts and produces an int. We can reinterpret
2815+
// the 4 shorts as two ints: a = (a0, a1) and b = (b0, b1).
2816+
//
2817+
// Inputs: 1 2 3 4
2818+
// Offsets: 0 0 1 1
2819+
// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1
2820+
//
2821+
// But permutations are possible, because add and mul are commutative. For
2822+
// simplicity, the first input is always either a0 or a1. These are all
2823+
// the possible permutations:
2824+
//
2825+
// v = MulAddS2I(a, b) = a0 * b0 + a1 * b1 (case 1)
2826+
// v = MulAddS2I(a, b) = a0 * b0 + b1 * a1 (case 2)
2827+
// v = MulAddS2I(a, b) = a1 * b1 + a0 * b0 (case 3)
2828+
// v = MulAddS2I(a, b) = a1 * b1 + b0 * a0 (case 4)
2829+
//
2830+
// To vectorize, we expect (a0, a1) to be consecutive in one input pack,
2831+
// and (b0, b1) in the other input pack. Thus, both a and b are strided,
2832+
// with stride = 2. Further, a0 and b0 have offset 0, whereas a1 and b1
2833+
// have offset 1.
2834+
// Returns true iff every one of the four inputs (in(1)..in(4)) of the MulAddS2I
// pack is fed by an input pack in a stride-2 layout, and the four inputs map onto
// the input packs in one of the accepted pairings checked at the end.
// `pack` must be a pack of MulAddS2I nodes (asserted on its first element).
bool PackSet::is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const {
2835+
assert(VectorNode::is_muladds2i(pack->at(0)), "must be MulAddS2I");
2836+
2837+
// Probe input 1 at offset 0. If that lookup fails, offset 1 is used for input 1 instead.
bool pack1_has_offset_0 = (strided_pack_input_at_index_or_null(pack, 1, 2, 0) != nullptr);
2838+
// Inputs 1 and 2 are looked up with the same offset ...
Node_List* pack1 = strided_pack_input_at_index_or_null(pack, 1, 2, pack1_has_offset_0 ? 0 : 1);
2839+
Node_List* pack2 = strided_pack_input_at_index_or_null(pack, 2, 2, pack1_has_offset_0 ? 0 : 1);
2840+
// ... and inputs 3 and 4 with the opposite offset.
Node_List* pack3 = strided_pack_input_at_index_or_null(pack, 3, 2, pack1_has_offset_0 ? 1 : 0);
2841+
Node_List* pack4 = strided_pack_input_at_index_or_null(pack, 4, 2, pack1_has_offset_0 ? 1 : 0);
2842+
2843+
// All four lookups must succeed, and inputs {1, 2} must come from the same two
// packs as inputs {3, 4}, either in the same order or swapped.
return pack1 != nullptr &&
2844+
pack2 != nullptr &&
2845+
pack3 != nullptr &&
2846+
pack4 != nullptr &&
2847+
((pack1 == pack3 && pack2 == pack4) || // case 1 or 3
2848+
(pack1 == pack4 && pack2 == pack3)); // case 2 or 4
2849+
}
2850+
2851+
// Finds the pack that feeds input `index` of every node in `pack`, requiring a
// strided layout: node i must read element (i * stride + offset) of the input pack.
// Returns that input pack, or nullptr if input `index` of the first node is in no
// pack, if the input pack does not hold exactly pack->size() * stride nodes, or if
// any node (including the first one) reads from a different position.
Node_List* PackSet::strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const {
2852+
Node* def0 = pack->at(0)->in(index);
2853+
2854+
Node_List* pack_in = get_pack(def0);
2855+
if (pack_in == nullptr || pack->size() * stride != pack_in->size()) {
2856+
return nullptr; // no input pack, or size mismatch
2857+
}
2858+
2859+
for (uint i = 0; i < pack->size(); i++) { // start at 0: def0 being in pack_in does not prove it sits at `offset`
2860+
if (pack->at(i)->in(index) != pack_in->at(i * stride + offset)) {
2861+
return nullptr; // use-def mismatch
2862+
}
2863+
}
2864+
return pack_in;
2865+
}
2866+
28182867
// Check if the output type of def is compatible with the input type of use, i.e. if the
28192868
// types have the same size.
28202869
bool SuperWord::is_velt_basic_type_compatible_use_def(Node* use, Node* def) const {

‎src/hotspot/share/opto/superword.hpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -362,8 +362,9 @@ class PackSet : public StackObj {
362362
}
363363
}
364364

365+
Node_List* strided_pack_input_at_index_or_null(const Node_List* pack, const int index, const int stride, const int offset) const;
366+
bool is_muladds2i_pack_with_pack_inputs(const Node_List* pack) const;
365367
Node* same_inputs_at_index_or_null(const Node_List* pack, const int index) const;
366-
367368
VTransformBoolTest get_bool_test(const Node_List* bool_pack) const;
368369

369370
private:

‎test/hotspot/jtreg/compiler/loopopts/superword/TestMulAddS2I.java

+85-26
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,6 @@ public class TestMulAddS2I {
4141

4242
static short[] sArr1 = new short[RANGE];
4343
static short[] sArr2 = new short[RANGE];
44-
static int[] ioutArr = new int[RANGE];
4544
static final int[] GOLDEN_A;
4645
static final int[] GOLDEN_B;
4746
static final int[] GOLDEN_C;
@@ -50,6 +49,10 @@ public class TestMulAddS2I {
5049
static final int[] GOLDEN_F;
5150
static final int[] GOLDEN_G;
5251
static final int[] GOLDEN_H;
52+
static final int[] GOLDEN_I;
53+
static final int[] GOLDEN_J;
54+
static final int[] GOLDEN_K;
55+
static final int[] GOLDEN_L;
5356

5457
static {
5558
for (int i = 0; i < RANGE; i++) {
@@ -58,12 +61,16 @@ public class TestMulAddS2I {
5861
}
5962
GOLDEN_A = testa();
6063
GOLDEN_B = testb();
61-
GOLDEN_C = testc();
62-
GOLDEN_D = testd();
63-
GOLDEN_E = teste();
64-
GOLDEN_F = testf();
65-
GOLDEN_G = testg();
66-
GOLDEN_H = testh();
64+
GOLDEN_C = testc(new int[ITER]);
65+
GOLDEN_D = testd(new int[ITER]);
66+
GOLDEN_E = teste(new int[ITER]);
67+
GOLDEN_F = testf(new int[ITER]);
68+
GOLDEN_G = testg(new int[ITER]);
69+
GOLDEN_H = testh(new int[ITER]);
70+
GOLDEN_I = testi(new int[ITER]);
71+
GOLDEN_J = testj(new int[ITER]);
72+
GOLDEN_K = testk(new int[ITER]);
73+
GOLDEN_L = testl(new int[ITER]);
6774
}
6875

6976

@@ -72,17 +79,22 @@ public static void main(String[] args) {
7279
TestFramework.runWithFlags("-XX:-AlignVector");
7380
}
7481

75-
@Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh"})
82+
@Run(test = {"testa", "testb", "testc", "testd", "teste", "testf", "testg", "testh",
83+
"testi", "testj", "testk", "testl"})
7684
@Warmup(0)
7785
public static void run() {
7886
compare(testa(), GOLDEN_A, "testa");
7987
compare(testb(), GOLDEN_B, "testb");
80-
compare(testc(), GOLDEN_C, "testc");
81-
compare(testd(), GOLDEN_D, "testd");
82-
compare(teste(), GOLDEN_E, "teste");
83-
compare(testf(), GOLDEN_F, "testf");
84-
compare(testg(), GOLDEN_G, "testg");
85-
compare(testh(), GOLDEN_H, "testh");
88+
compare(testc(new int[ITER]), GOLDEN_C, "testc");
89+
compare(testd(new int[ITER]), GOLDEN_D, "testd");
90+
compare(teste(new int[ITER]), GOLDEN_E, "teste");
91+
compare(testf(new int[ITER]), GOLDEN_F, "testf");
92+
compare(testg(new int[ITER]), GOLDEN_G, "testg");
93+
compare(testh(new int[ITER]), GOLDEN_H, "testh");
94+
compare(testi(new int[ITER]), GOLDEN_I, "testi");
95+
compare(testj(new int[ITER]), GOLDEN_J, "testj");
96+
compare(testk(new int[ITER]), GOLDEN_K, "testk");
97+
compare(testl(new int[ITER]), GOLDEN_L, "testl");
8698
}
8799

88100
public static void compare(int[] out, int[] golden, String name) {
@@ -138,8 +150,7 @@ public static int[] testb() {
138150
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
139151
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
140152
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
141-
public static int[] testc() {
142-
int[] out = new int[ITER];
153+
public static int[] testc(int[] out) {
143154
for (int i = 0; i < ITER; i++) {
144155
out[i] += ((sArr1[2*i] * sArr2[2*i]) + (sArr1[2*i+1] * sArr2[2*i+1]));
145156
}
@@ -155,8 +166,7 @@ public static int[] testc() {
155166
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
156167
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
157168
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
158-
public static int[] testd() {
159-
int[] out = ioutArr;
169+
public static int[] testd(int[] out) {
160170
for (int i = 0; i < ITER-2; i+=2) {
161171
// Unrolled, with the same structure.
162172
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@@ -174,8 +184,7 @@ public static int[] testd() {
174184
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
175185
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
176186
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
177-
public static int[] teste() {
178-
int[] out = ioutArr;
187+
public static int[] teste(int[] out) {
179188
for (int i = 0; i < ITER-2; i+=2) {
180189
// Unrolled, with some swaps.
181190
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@@ -193,8 +202,7 @@ public static int[] teste() {
193202
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
194203
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
195204
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
196-
public static int[] testf() {
197-
int[] out = ioutArr;
205+
public static int[] testf(int[] out) {
198206
for (int i = 0; i < ITER-2; i+=2) {
199207
// Unrolled, with some swaps.
200208
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@@ -212,8 +220,7 @@ public static int[] testf() {
212220
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
213221
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
214222
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
215-
public static int[] testg() {
216-
int[] out = ioutArr;
223+
public static int[] testg(int[] out) {
217224
for (int i = 0; i < ITER-2; i+=2) {
218225
// Unrolled, with some swaps.
219226
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
@@ -231,13 +238,65 @@ public static int[] testg() {
231238
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI, "> 0"})
232239
@IR(applyIfCPUFeature = {"avx512_vnni", "true"},
233240
counts = {IRNode.MUL_ADD_S2I, "> 0", IRNode.MUL_ADD_VS2VI_VNNI, "> 0"})
234-
public static int[] testh() {
235-
int[] out = ioutArr;
241+
public static int[] testh(int[] out) {
236242
for (int i = 0; i < ITER-2; i+=2) {
237243
// Unrolled, with some swaps.
238244
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1]));
239245
out[i+1] += ((sArr2[2*i+3] * sArr1[2*i+3]) + (sArr2[2*i+2] * sArr1[2*i+2])); // swap(1 4), swap(2 3)
240246
}
241247
return out;
242248
}
249+
250+
// Negative test: the first statement pairs elements at equal indices, but the
// second multiplies sArr1[2*i+2] by sArr2[2*i+3] and sArr1[2*i+3] by sArr2[2*i+2],
// i.e. it pairs elements at different offsets. The IR rules expect scalar
// MulAddS2I nodes (on sse2 or asimd) and no MulAddVS2VI nodes.
// Accumulates into the caller-provided `out` array and returns it.
@Test
251+
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
252+
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
253+
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
254+
public static int[] testi(int[] out) {
255+
for (int i = 0; i < ITER-2; i+=2) {
256+
// Unrolled, with some swaps that prevent vectorization.
257+
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+0]) + (sArr1[2*i+1] * sArr2[2*i+1])); // ok
258+
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
259+
}
260+
return out;
261+
}
262+
263+
// Negative test: both statements pair elements at different offsets
// (sArr1[k] with sArr2[k+1] and sArr1[k+1] with sArr2[k]). The IR rules expect
// scalar MulAddS2I nodes (on sse2 or asimd) and no MulAddVS2VI nodes.
// Accumulates into the caller-provided `out` array and returns it.
@Test
264+
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
265+
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
266+
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
267+
public static int[] testj(int[] out) {
268+
for (int i = 0; i < ITER-2; i+=2) {
269+
// Unrolled, with some swaps that prevent vectorization.
270+
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad
271+
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
272+
}
273+
return out;
274+
}
275+
276+
// Negative test, mirror of testi: here the first statement pairs elements at
// different offsets and the second pairs elements at equal indices. The IR rules
// expect scalar MulAddS2I nodes (on sse2 or asimd) and no MulAddVS2VI nodes.
// Accumulates into the caller-provided `out` array and returns it.
@Test
277+
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
278+
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
279+
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
280+
public static int[] testk(int[] out) {
281+
for (int i = 0; i < ITER-2; i+=2) {
282+
// Unrolled, with some swaps that prevent vectorization.
283+
out[i+0] += ((sArr1[2*i+0] * sArr2[2*i+1]) + (sArr1[2*i+1] * sArr2[2*i+0])); // bad
284+
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // ok
285+
}
286+
return out;
287+
}
288+
289+
// Negative test: the first statement only swaps the order of the two products
// (the index-(2*i+1) product comes first) and still pairs equal indices; the
// second statement pairs elements at different offsets. The IR rules expect
// scalar MulAddS2I nodes (on sse2 or asimd) and no MulAddVS2VI nodes.
// Accumulates into the caller-provided `out` array and returns it.
@Test
290+
@IR(counts = {IRNode.MUL_ADD_S2I, "> 0"},
291+
applyIfCPUFeatureOr = {"sse2", "true", "asimd", "true"})
292+
@IR(counts = {IRNode.MUL_ADD_VS2VI, "= 0"})
293+
public static int[] testl(int[] out) {
294+
for (int i = 0; i < ITER-2; i+=2) {
295+
// Unrolled, with some swaps that prevent vectorization.
296+
out[i+0] += ((sArr1[2*i+1] * sArr2[2*i+1]) + (sArr1[2*i+0] * sArr2[2*i+0])); // ok
297+
out[i+1] += ((sArr1[2*i+2] * sArr2[2*i+3]) + (sArr1[2*i+3] * sArr2[2*i+2])); // bad
298+
}
299+
return out;
300+
}
301+
243302
}

0 commit comments

Comments
 (0)
Please sign in to comment.