Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8308363: Initial compiler support for FP16 scalar operations. #848

Closed
wants to merge 9 commits into from
9 changes: 8 additions & 1 deletion src/hotspot/share/classfile/classFileParser.cpp
Original file line number Diff line number Diff line change
@@ -4704,6 +4704,13 @@ bool ClassFileParser::is_jdk_internal_class(const Symbol* class_name) const {
return false;
}

bool ClassFileParser::is_jdk_internal_class_sig(const char* sig) const {
if (strstr(sig, vmSymbols::java_lang_Float16_signature()->as_C_string())) {
return true;
}
return false;
}

// utility methods for format checking

void ClassFileParser::verify_legal_class_modifiers(jint flags, const char* name, bool is_Object, TRAPS) const {
@@ -5166,7 +5173,7 @@ const char* ClassFileParser::skip_over_field_signature(const char* signature,
case JVM_SIGNATURE_PRIMITIVE_OBJECT:
// Can't enable this check fully until JDK upgrades the bytecode generators (TODO: JDK-8270852).
// For now, compare to class file version 51 so old verifier doesn't see Q signatures.
if ( (_major_version < 51 /* CONSTANT_CLASS_DESCRIPTORS */ ) || (!EnablePrimitiveClasses)) {
if ( (_major_version < 51 /* CONSTANT_CLASS_DESCRIPTORS */ ) || (!EnablePrimitiveClasses && !is_jdk_internal_class_sig(signature))) {
Comment on lines 5175 to +5176

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last condition (!EnablePrimitiveClasses && !is_jdk_internal_class_sig(signature)) seems not right to me. Assume there is a jdk internal primitive class defined , and EnablePrimitiveClasses is disabled, then result of (!EnablePrimitiveClasses && !is_jdk_internal_class_sig(signature)) is false (i.e. true && false), then the followed error is not printed if the jdk version matches >= 51. But it should report the error since EnablePrimitiveClasses is closed. So it should be:

if ( (_major_version < 51 /* CONSTANT_CLASS_DESCRIPTORS */ ) || !EnablePrimitiveClasses || !is_jdk_internal_class_sig(signature))

right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idea here is to relax the need to use an explicit JVM flag -XX:+EnablePrimtiiveClasses for primitive classes known to VM.

classfile_parse_error("Class name contains illegal Q-signature "
"in descriptor in class file %s, requires option -XX:+EnablePrimitiveClasses",
CHECK_0);
2 changes: 2 additions & 0 deletions src/hotspot/share/classfile/classFileParser.hpp
Original file line number Diff line number Diff line change
@@ -219,6 +219,8 @@ class ClassFileParser {

bool is_jdk_internal_class(const Symbol* class_name) const;

bool is_jdk_internal_class_sig(const char* sig) const;

void parse_stream(const ClassFileStream* const stream, TRAPS);

void mangle_hidden_class_name(InstanceKlass* const ik);
1 change: 1 addition & 0 deletions src/hotspot/share/classfile/vmSymbols.hpp
Original file line number Diff line number Diff line change
@@ -82,6 +82,7 @@ class SerializeClosure;
template(java_lang_CharacterDataLatin1, "java/lang/CharacterDataLatin1") \
template(java_lang_Float, "java/lang/Float") \
template(java_lang_Float16, "java/lang/Float16") \
template(java_lang_Float16_signature, "Qjava/lang/Float16;") \
template(java_lang_Double, "java/lang/Double") \
template(java_lang_Byte, "java/lang/Byte") \
template(java_lang_Byte_ByteCache, "java/lang/Byte$ByteCache") \
8 changes: 8 additions & 0 deletions src/hotspot/share/opto/convertnode.cpp
Original file line number Diff line number Diff line change
@@ -867,6 +867,14 @@ const Type* ReinterpretS2HFNode::Value(PhaseGVN* phase) const {
return Type::FLOAT;
}

Node* ReinterpretS2HFNode::Identity(PhaseGVN* phase) {
if (in(1)->Opcode() == Op_ReinterpretHF2S) {
assert(in(1)->in(1)->bottom_type()->isa_float(), "");
return in(1)->in(1);
}
return this;
}

const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const {
const Type* type = phase->type( in(1) );
// Convert Float constant value to FP16 constant value.
1 change: 1 addition & 0 deletions src/hotspot/share/opto/convertnode.hpp
Original file line number Diff line number Diff line change
@@ -179,6 +179,7 @@ class ReinterpretS2HFNode : public Node {
virtual int Opcode() const;
virtual const Type *bottom_type() const { return Type::FLOAT; }
virtual const Type* Value(PhaseGVN* phase) const;
virtual Node* Identity(PhaseGVN* phase);
virtual uint ideal_reg() const { return Op_RegF; }
};

2 changes: 2 additions & 0 deletions src/java.base/share/classes/java/lang/Float16.java
Original file line number Diff line number Diff line change
@@ -52,6 +52,8 @@
* @since 20.00
*/

// Currently Float16 is a primitive class but in future will be aligned with
// Enhanced Primitive Boxes described by JEP-402 (https://openjdk.org/jeps/402)
public primitive class Float16 {
private final short value;

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/**
* @test
* @bug 8308363
* @summary Validate compiler IR for FP16 scalar operations.
* @requires vm.compiler2.enabled
* @library /test/lib /
* @compile -XDenablePrimitiveClasses TestFP16ScalarAdd.java
* @run driver compiler.vectorization.TestFP16ScalarAdd
*/

package compiler.vectorization;
import compiler.lib.ir_framework.*;
import java.util.Random;

public class TestFP16ScalarAdd {
private static final int count = 1024;

private short[] src;
private short[] dst;
private short res;

public static void main(String args[]) {
TestFramework.run(TestFP16ScalarAdd.class);
}

public TestFP16ScalarAdd() {
src = new short[count];
dst = new short[count];
for (int i = 0; i < count; i++) {
src[i] = Float.floatToFloat16(i);
}
}

@Test
@IR(applyIfCPUFeature = {"avx512_fp16", "true"}, counts = {IRNode.ADD_HF, "> 0", IRNode.REINTERPRET_S2HF, "> 0", IRNode.REINTERPRET_HF2S, "> 0"})
public void test1() {
Float16 res = new Float16((short)0);
for (int i = 0; i < count; i++) {
res = res.add(Float16.valueOf(src[i]));
dst[i] = res.float16ToRawShortBits();
}
}

@Test
@IR(applyIfCPUFeature = {"avx512_fp16", "true"}, failOn = {IRNode.ADD_HF, IRNode.REINTERPRET_S2HF, IRNode.REINTERPRET_HF2S})
public void test2() {
Float16 hf0 = Float16.valueOf((short)0);
Float16 hf1 = Float16.valueOf((short)15360);
Float16 hf2 = Float16.valueOf((short)16384);
Float16 hf3 = Float16.valueOf((short)16896);
Float16 hf4 = Float16.valueOf((short)17408);
res = hf0.add(hf1).add(hf2).add(hf3).add(hf4).float16ToRawShortBits();
}
}
15 changes: 15 additions & 0 deletions test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java
Original file line number Diff line number Diff line change
@@ -151,6 +151,11 @@ public class IRNode {
beforeMatchingNameRegex(ADD_L, "AddL");
}

public static final String ADD_HF = PREFIX + "ADD_HF" + POSTFIX;
static {
beforeMatchingNameRegex(ADD_HF, "AddHF");
}

public static final String ADD_V = PREFIX + "ADD_V" + POSTFIX;
static {
beforeMatchingNameRegex(ADD_V, "AddV(B|S|I|L|F|D)");
@@ -893,6 +898,16 @@ public class IRNode {
trapNodes(RANGE_CHECK_TRAP,"range_check");
}

public static final String REINTERPRET_S2HF = PREFIX + "REINTERPRET_S2HF" + POSTFIX;
static {
beforeMatchingNameRegex(REINTERPRET_S2HF, "ReinterpretS2HF");
}

public static final String REINTERPRET_HF2S = PREFIX + "REINTERPRET_HF2S" + POSTFIX;
static {
beforeMatchingNameRegex(REINTERPRET_HF2S, "ReinterpretHF2S");
}

public static final String REPLICATE_B = PREFIX + "REPLICATE_B" + POSTFIX;
static {
String regex = START + "ReplicateB" + MID + END;
Original file line number Diff line number Diff line change
@@ -75,6 +75,7 @@ public class IREncodingPrinter {
"avx512dq",
"avx512vl",
"avx512f",
"avx512_fp16",
// AArch64
"sha3",
"asimd",
131 changes: 131 additions & 0 deletions test/jdk/java/lang/Float16/FP16ReductionOperations.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/

/*
* @test
* @bug 8308363
* @summary Test FP16 reduction operations.
* @compile -XDenablePrimitiveClasses FP16ReductionOperations.java
* @run main/othervm -XX:+EnablePrimitiveClasses -XX:-TieredCompilation -Xbatch FP16ReductionOperations
*/

import java.util.Random;

public class FP16ReductionOperations {

public static Random r = new Random(1024);

public static short test_reduction_add_constants() {
Float16 hf0 = Float16.valueOf((short)0);
Float16 hf1 = Float16.valueOf((short)15360);
Float16 hf2 = Float16.valueOf((short)16384);
Float16 hf3 = Float16.valueOf((short)16896);
Float16 hf4 = Float16.valueOf((short)17408);
return hf0.add(hf1).add(hf2).add(hf3).add(hf4).float16ToRawShortBits();
}

public static short expected_reduction_add_constants() {
Float16 hf0 = Float16.valueOf((short)0);
Float16 hf1 = Float16.valueOf((short)15360);
Float16 hf2 = Float16.valueOf((short)16384);
Float16 hf3 = Float16.valueOf((short)16896);
Float16 hf4 = Float16.valueOf((short)17408);
return Float.floatToFloat16(Float.float16ToFloat(hf0.float16ToRawShortBits()) +
Float.float16ToFloat(hf1.float16ToRawShortBits()) +
Float.float16ToFloat(hf2.float16ToRawShortBits()) +
Float.float16ToFloat(hf3.float16ToRawShortBits()) +
Float.float16ToFloat(hf4.float16ToRawShortBits()));
}

public static void test_reduction_constants(char oper) {
short actual = 0;
short expected = 0;
switch(oper) {
case '+' -> {
actual = test_reduction_add_constants();
expected = expected_reduction_add_constants();
}
default -> throw new AssertionError("Unsupported Operation.");
}
if (actual != expected) {
throw new AssertionError("Result mismatch!, expected = " + expected + " actual = " + actual);
}
}

public static short test_reduction_add(short [] arr) {
Float16 res = Float16.valueOf((short)0);
for (int i = 0; i < arr.length; i++) {
res = res.add(Float16.valueOf(arr[i]));
}
return res.float16ToRawShortBits();
}

public static short expected_reduction_add(short [] arr) {
short res = 0;
for (int i = 0; i < arr.length; i++) {
res = Float.floatToFloat16(Float.float16ToFloat(res) + Float.float16ToFloat(arr[i]));
}
return res;
}

public static void test_reduction(char oper, short [] arr) {
short actual = 0;
short expected = 0;
switch(oper) {
case '+' -> {
actual = test_reduction_add(arr);
expected = expected_reduction_add(arr);
}
default -> throw new AssertionError("Unsupported Operation.");
}
if (actual != expected) {
throw new AssertionError("Result mismatch!, expected = " + expected + " actual = " + actual);
}
}

public static short [] get_fp16_array(int size) {
short [] arr = new short[size];
for (int i = 0; i < arr.length; i++) {
arr[i] = Float.floatToFloat16(r.nextFloat());
}
return arr;
}

public static void main(String [] args) {
int res = 0;
short [] input = get_fp16_array(1024);
short [] special_values = {
32256, // NAN
31744, // +Inf
(short)-1024, // -Inf
0, // +0.0
(short)-32768, // -0.0
};
for (int i = 0; i < 1000; i++) {
test_reduction('+', input);
test_reduction('+', special_values);
test_reduction_constants('+');
}
System.out.println("PASS");
}
}
2 changes: 1 addition & 1 deletion test/jdk/java/lang/Float16/FP16ScalarOperations.java
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@

public class FP16ScalarOperations {

public static Random r = new Random(1024);
public static Random r = new Random(1024);

public static short actual_value(char oper, short val1, short val2) {
Float16 obj1 = new Float16((short)val1);