Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8303238: Create generalizations for existing LShift ideal transforms #12734

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
158 changes: 132 additions & 26 deletions src/hotspot/share/opto/mulnode.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 1997, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 1997, 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -847,21 +847,74 @@ Node *LShiftINode::Ideal(PhaseGVN *phase, bool can_reshape) {
}
}

// Check for "(x>>c0)<<c0" which just masks off low bits
if( (add1_op == Op_RShiftI || add1_op == Op_URShiftI ) &&
add1->in(2) == in(2) )
// Convert to "(x & -(1<<c0))"
return new AndINode(add1->in(1),phase->intcon( -(1<<con)));
// Check for "(x >> C1) << C2"
if (add1_op == Op_RShiftI || add1_op == Op_URShiftI) {
// Special case C1 == C2, which just masks off low bits
if (add1->in(2) == in(2)) {
// Convert to "(x & -(1 << C2))"
return new AndINode(add1->in(1), phase->intcon(-(1 << con)));
} else {
int add1Con = 0;
const_shift_count(phase, add1, &add1Con);

// Wait until the right shift has been sharpened to the correct count
if (add1Con > 0 && add1Con < BitsPerJavaInteger) {
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
if (phase->is_IterGVN()) {
if (con > add1Con) {
// Creates "(x << (C2 - C1)) & -(1 << C2)"
Node* lshift = phase->transform(new LShiftINode(add1->in(1), phase->intcon(con - add1Con)));
return new AndINode(lshift, phase->intcon(-(1 << con)));
} else {
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
// Creates "(x >> (C1 - C2)) & -(1 << C2)"

// Handle logical and arithmetic shifts
Node* rshift;
if (add1_op == Op_RShiftI) {
rshift = phase->transform(new RShiftINode(add1->in(1), phase->intcon(add1Con - con)));
} else {
rshift = phase->transform(new URShiftINode(add1->in(1), phase->intcon(add1Con - con)));
}

return new AndINode(rshift, phase->intcon(-(1 << con)));
}
} else {
phase->record_for_igvn(this);
}
}
}
}

// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
if( add1_op == Op_AndI ) {
// Check for "((x >> C1) & Y) << C2"
if (add1_op == Op_AndI) {
Node *add2 = add1->in(1);
int add2_op = add2->Opcode();
if( (add2_op == Op_RShiftI || add2_op == Op_URShiftI ) &&
add2->in(2) == in(2) ) {
// Convert to "(x & (Y<<c0))"
Node *y_sh = phase->transform( new LShiftINode( add1->in(2), in(2) ) );
return new AndINode( add2->in(1), y_sh );
if (add2_op == Op_RShiftI || add2_op == Op_URShiftI) {
// Special case C1 == C2, which just masks off low bits
if (add2->in(2) == in(2)) {
// Convert to "(x & (Y << C2))"
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));
return new AndINode(add2->in(1), y_sh);
}

int add2Con = 0;
const_shift_count(phase, add2, &add2Con);
if (add2Con > 0 && add2Con < BitsPerJavaInteger) {
if (phase->is_IterGVN()) {
// Convert to "((x >> C1) << C2) & (Y << C2)"

// Make "(x >> C1) << C2", which will get folded away by the rule above
Node* x_sh = phase->transform(new LShiftINode(add2, phase->intcon(con)));
// Make "Y << C2", which will simplify when Y is a constant
Node* y_sh = phase->transform(new LShiftINode(add1->in(2), phase->intcon(con)));

return new AndINode(x_sh, y_sh);
} else {
phase->record_for_igvn(this);
}
}
}
}

Expand Down Expand Up @@ -970,21 +1023,74 @@ Node *LShiftLNode::Ideal(PhaseGVN *phase, bool can_reshape) {
}
}

// Check for "(x>>c0)<<c0" which just masks off low bits
if( (add1_op == Op_RShiftL || add1_op == Op_URShiftL ) &&
add1->in(2) == in(2) )
// Convert to "(x & -(1<<c0))"
return new AndLNode(add1->in(1),phase->longcon( -(CONST64(1)<<con)));
// Check for "(x >> C1) << C2"
if (add1_op == Op_RShiftL || add1_op == Op_URShiftL) {
// Special case C1 == C2, which just masks off low bits
if (add1->in(2) == in(2)) {
// Convert to "(x & -(1 << C2))"
return new AndLNode(add1->in(1), phase->longcon(-(CONST64(1) << con)));
} else {
int add1Con = 0;
const_shift_count(phase, add1, &add1Con);

// Wait until the right shift has been sharpened to the correct count
if (add1Con > 0 && add1Con < BitsPerJavaLong) {
// As loop parsing can produce LShiftI nodes, we should wait until the graph is fully formed
// to apply optimizations, otherwise we can inadvertently stop vectorization opportunities.
if (phase->is_IterGVN()) {
if (con > add1Con) {
// Creates "(x << (C2 - C1)) & -(1 << C2)"
Node* lshift = phase->transform(new LShiftLNode(add1->in(1), phase->intcon(con - add1Con)));
return new AndLNode(lshift, phase->longcon(-(CONST64(1) << con)));
} else {
assert(con < add1Con, "must be (%d < %d)", con, add1Con);
// Creates "(x >> (C1 - C2)) & -(1 << C2)"

// Handle logical and arithmetic shifts
Node* rshift;
if (add1_op == Op_RShiftL) {
rshift = phase->transform(new RShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
} else {
rshift = phase->transform(new URShiftLNode(add1->in(1), phase->intcon(add1Con - con)));
}

return new AndLNode(rshift, phase->longcon(-(CONST64(1) << con)));
}
} else {
phase->record_for_igvn(this);
}
}
}
}

// Check for "((x>>c0) & Y)<<c0" which just masks off more low bits
if( add1_op == Op_AndL ) {
Node *add2 = add1->in(1);
// Check for "((x >> C1) & Y) << C2"
if (add1_op == Op_AndL) {
Node* add2 = add1->in(1);
int add2_op = add2->Opcode();
if( (add2_op == Op_RShiftL || add2_op == Op_URShiftL ) &&
add2->in(2) == in(2) ) {
// Convert to "(x & (Y<<c0))"
Node *y_sh = phase->transform( new LShiftLNode( add1->in(2), in(2) ) );
return new AndLNode( add2->in(1), y_sh );
if (add2_op == Op_RShiftL || add2_op == Op_URShiftL) {
// Special case C1 == C2, which just masks off low bits
if (add2->in(2) == in(2)) {
// Convert to "(x & (Y << C2))"
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));
return new AndLNode(add2->in(1), y_sh);
}

int add2Con = 0;
const_shift_count(phase, add2, &add2Con);
if (add2Con > 0 && add2Con < BitsPerJavaLong) {
if (phase->is_IterGVN()) {
// Convert to "((x >> C1) << C2) & (Y << C2)"

// Make "(x >> C1) << C2", which will get folded away by the rule above
Node* x_sh = phase->transform(new LShiftLNode(add2, phase->intcon(con)));
// Make "Y << C2", which will simplify when Y is a constant
Node* y_sh = phase->transform(new LShiftLNode(add1->in(2), phase->intcon(con)));

return new AndLNode(x_sh, y_sh);
} else {
phase->record_for_igvn(this);
}
}
}
}

Expand Down
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -27,7 +27,7 @@

/*
* @test
* @bug 8297384
* @bug 8297384 8303238
* @summary Test that Ideal transformations of LShiftINode* are being performed as expected.
* @library /test/lib /
* @run driver compiler.c2.irTests.LShiftINodeIdealizationTests
Expand All @@ -37,15 +37,21 @@ public static void main(String[] args) {
TestFramework.run();
}

@Run(test = { "test1", "test2" })
@Run(test = { "test1", "test2", "test3", "test4", "test5", "test6", "test7", "test8" })
public void runMethod() {
int a = RunInfo.getRandom().nextInt();
int b = RunInfo.getRandom().nextInt();
int c = RunInfo.getRandom().nextInt();
int d = RunInfo.getRandom().nextInt();

int min = Integer.MIN_VALUE;
int max = Integer.MAX_VALUE;

assertResult(0);
assertResult(a);
assertResult(b);
assertResult(c);
assertResult(d);
assertResult(min);
assertResult(max);
}
Expand All @@ -54,6 +60,12 @@ public void runMethod() {
public void assertResult(int a) {
Asserts.assertEQ((a >> 2022) << 2022, test1(a));
Asserts.assertEQ((a >>> 2022) << 2022, test2(a));
Asserts.assertEQ((a >> 4) << 8, test3(a));
Asserts.assertEQ((a >>> 4) << 8, test4(a));
Asserts.assertEQ((a >> 8) << 4, test5(a));
Asserts.assertEQ((a >>> 8) << 4, test6(a));
Asserts.assertEQ(((a >> 4) & 0xFF) << 8, test7(a));
Asserts.assertEQ(((a >>> 4) & 0xFF) << 8, test8(a));
}

@Test
Expand All @@ -71,4 +83,52 @@ public int test1(int x) {
public int test2(int x) {
return (x >>> 2022) << 2022;
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >> 4) << 8 => (x << 4) & -16
public int test3(int x) {
return (x >> 4) << 8;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks (x >>> 4) << 8 => (x << 4) & -16
public int test4(int x) {
return (x >>> 4) << 8;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.RSHIFT, "1" })
// Checks (x >> 8) << 4 => (x >> 4) & -16
public int test5(int x) {
return (x >> 8) << 4;
}

@Test
@IR(failOn = { IRNode.LSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.URSHIFT, "1" })
// Checks (x >>> 8) << 4 => (x >>> 4) & -16
public int test6(int x) {
return (x >>> 8) << 4;
}

@Test
@IR(failOn = { IRNode.RSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public int test7(int x) {
return ((x >> 4) & 0xFF) << 8;
}

@Test
@IR(failOn = { IRNode.URSHIFT })
@IR(counts = { IRNode.AND, "1", IRNode.LSHIFT, "1" })
// Checks ((x >>> 4) & 0xFF) << 8 => (x << 4) & 0xFF00
public int test8(int x) {
return ((x >>> 4) & 0xFF) << 8;
}
}