Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

8339473: Add support for FP16 isFinite, isInfinite and isNaN #1239

Closed
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/java.base/share/classes/java/lang/Float16.java
Original file line number Diff line number Diff line change
@@ -444,7 +444,9 @@ public static Float16 valueOf(BigDecimal v) {
* @see Double#isNaN(double)
*/
public static boolean isNaN(Float16 f16) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With changing the implementation if isNaN, isFinite, etc. to bit-based, the regression tests should be updated to cover more cases. In particular, different NaN bit patterns should be tested.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jddarcy , thank you for the review. I will add more input patterns to be tested.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update; looks fine.

return ((float16ToRawShortBits(f16) & 0x7e00) == 0x7e00);
final short bits = float16ToRawShortBits(f16);
// A NaN value has all ones in its exponent and a non-zero significand
return ((bits & 0x7c00) == 0x7c00 && (bits & 0x03ff) != 0);
}

/**
202 changes: 120 additions & 82 deletions test/jdk/java/lang/Float16/FP16ScalarOperations.java
Original file line number Diff line number Diff line change
@@ -24,120 +24,158 @@

/*
* @test
* @bug 8308363 8336406
* @summary Verify binary FP16 scalar operations
* @bug 8308363 8336406 8339473
* @summary Verify FP16 unary, binary and ternary operations
* @compile FP16ScalarOperations.java
* @run main/othervm --enable-preview -XX:-TieredCompilation -Xbatch FP16ScalarOperations
*/

import java.util.Random;

import java.util.stream.IntStream;
import static java.lang.Float16.*;

public class FP16ScalarOperations {

public static Random r = new Random(1024);
public static final int SIZE = 65504;
public static Random r = new Random(SIZE);
public static final Float16 ONE = Float16.valueOf(1.0);
public static final Float16 ZERO = Float16.valueOf(0.0);
public static final int EXP = 0x7c00; // Mask for Float16 Exponent in a NaN (which is all ones)
public static final int SIGN_BIT = 0x8000; // Mask for sign bit for Float16

public static short actual_value(String oper, short... val) {
public static Float16 actual_value(String oper, Float16... val) {
switch (oper) {
case "abs" : return float16ToRawShortBits(Float16.abs(shortBitsToFloat16(val[0])));
case "neg" : return float16ToRawShortBits(Float16.negate(shortBitsToFloat16(val[0])));
case "sqrt" : return float16ToRawShortBits(Float16.sqrt(shortBitsToFloat16(val[0])));
case "isInfinite" : return (short)(Float16.isInfinite(shortBitsToFloat16(val[0])) ? 1 : 0);
case "isFinite" : return (short)(Float16.isFinite(shortBitsToFloat16(val[0])) ? 1 : 0);
case "isNaN" : return (short)(Float16.isNaN(shortBitsToFloat16(val[0])) ? 1 : 0);
case "+" : return float16ToRawShortBits(Float16.add(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "-" : return float16ToRawShortBits(Float16.subtract(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "*" : return float16ToRawShortBits(Float16.multiply(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "/" : return float16ToRawShortBits(Float16.divide(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "min" : return float16ToRawShortBits(Float16.min(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "max" : return float16ToRawShortBits(Float16.max(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1])));
case "fma" : return float16ToRawShortBits(Float16.fma(shortBitsToFloat16(val[0]), shortBitsToFloat16(val[1]), shortBitsToFloat16(val[2])));
case "abs" : return Float16.abs(val[0]);
case "neg" : return Float16.negate(val[0]);
case "sqrt" : return Float16.sqrt(val[0]);
case "isInfinite" : return Float16.isInfinite(val[0]) ? ONE : ZERO;
case "isFinite" : return Float16.isFinite(val[0]) ? ONE : ZERO;
case "isNaN" : return Float16.isNaN(val[0]) ? ONE : ZERO;
case "+" : return Float16.add(val[0], val[1]);
case "-" : return Float16.subtract(val[0], val[1]);
case "*" : return Float16.multiply(val[0], val[1]);
case "/" : return Float16.divide(val[0], val[1]);
case "min" : return Float16.min(val[0], val[1]);
case "max" : return Float16.max(val[0], val[1]);
case "fma" : return Float16.fma(val[0], val[1], val[2]);
default : throw new AssertionError("Unsupported Operation!");
}
}

public static void test_operations(short [] arr1, short arr2[], short arr3[]) {
for (int i = 0; i < arr1.length; i++) {
validate("abs", arr1[i]);
validate("neg", arr1[i]);
validate("sqrt", arr1[i]);
validate("isInfinite", arr1[i]);
validate("isFinite", arr1[i]);
validate("isNaN", arr1[i]);
validate("+", arr1[i], arr2[i]);
validate("-", arr1[i], arr2[i]);
validate("*", arr1[i], arr2[i]);
validate("/", arr1[i], arr2[i]);
validate("min", arr1[i], arr2[i]);
validate("max", arr1[i], arr2[i]);
validate("fma", arr1[i], arr2[i], arr3[i]);
public static Float16 expected_value(String oper, Float16... val) {
switch (oper) {
case "abs" : return Float16.valueOf(Math.abs(val[0].floatValue()));
case "neg" : return Float16.shortBitsToFloat16((short)(Float16.float16ToRawShortBits(val[0]) ^ (short)0x0000_8000));
case "sqrt" : return Float16.valueOf(Math.sqrt(val[0].floatValue()));
case "isInfinite" : return Float.isInfinite(val[0].floatValue()) ? ONE : ZERO;
case "isFinite" : return Float.isFinite(val[0].floatValue()) ? ONE : ZERO;
case "isNaN" : return Float.isNaN(val[0].floatValue()) ? ONE : ZERO;
case "+" : return Float16.valueOf(val[0].floatValue() + val[1].floatValue());
case "-" : return Float16.valueOf(val[0].floatValue() - val[1].floatValue());
case "*" : return Float16.valueOf(val[0].floatValue() * val[1].floatValue());
case "/" : return Float16.valueOf(val[0].floatValue() / val[1].floatValue());
case "min" : return Float16.valueOf(Float.min(val[0].floatValue(), val[1].floatValue()));
case "max" : return Float16.valueOf(Float.max(val[0].floatValue(), val[1].floatValue()));
case "fma" : return Float16.valueOf(val[0].floatValue() * val[1].floatValue() + val[2].floatValue());
default : throw new AssertionError("Unsupported Operation!");
}
}

public static short expected_value(String oper, short... input) {
switch(oper) {
case "abs" : return Float.floatToFloat16(Math.abs(Float.float16ToFloat(input[0])));
case "neg" : return (short)(input[0] ^ (short)0x0000_8000);
case "sqrt" : return Float.floatToFloat16((float)Math.sqrt((double)Float.float16ToFloat(input[0])));
case "isInfinite" : return (short)(Float.isInfinite(Float.float16ToFloat(input[0])) ? 1 : 0);
case "isFinite" : return (short)(Float.isFinite(Float.float16ToFloat(input[0])) ? 1 : 0);
case "isNaN" : return (short)(Float.isNaN(Float.float16ToFloat(input[0])) ? 1 : 0);
case "+" : return Float.floatToFloat16(Float.float16ToFloat(input[0]) + Float.float16ToFloat(input[1]));
case "-" : return Float.floatToFloat16(Float.float16ToFloat(input[0]) - Float.float16ToFloat(input[1]));
case "*" : return Float.floatToFloat16(Float.float16ToFloat(input[0]) * Float.float16ToFloat(input[1]));
case "/" : return Float.floatToFloat16(Float.float16ToFloat(input[0]) / Float.float16ToFloat(input[1]));
case "min" : return Float.floatToFloat16(Float.min(Float.float16ToFloat(input[0]), Float.float16ToFloat(input[1])));
case "max" : return Float.floatToFloat16(Float.max(Float.float16ToFloat(input[0]), Float.float16ToFloat(input[1])));
case "fma" : return Float.floatToFloat16(Float.float16ToFloat(input[0]) * Float.float16ToFloat(input[1]) + Float.float16ToFloat(input[2]));
default : throw new AssertionError("Unsupported Operation!");
public static void validate(String oper, Float16... input) {
int arity = input.length;

short actual = Float16.float16ToRawShortBits(actual_value(oper, input));
short expected = Float16.float16ToRawShortBits(expected_value(oper, input));

if (actual != expected) {
switch (arity) {
case 1:
throw new AssertionError("Test Failed: " + oper + "(" + Float16.float16ToRawShortBits(input[0]) + ") : " + actual + " != " + expected);
case 2:
throw new AssertionError("Test Failed: " + oper + "(" + Float16.float16ToRawShortBits(input[0]) + ", " + Float16.float16ToRawShortBits(input[1]) + ") : " + actual + " != " + expected);
case 3:
throw new AssertionError("Test failed: " + oper + "(" + Float16.float16ToRawShortBits(input[0]) + ", " + Float16.float16ToRawShortBits(input[1]) + ", " + Float16.float16ToRawShortBits(input[2]) + ") : " + actual + " != " + expected);
default:
throw new AssertionError("Incorrect operation (" + oper + ") arity = " + arity);
}
}
}

public static boolean compare(short actual, short expected) {
return !((0xFFFF & actual) == (0xFFFF & expected));
public static void test_unary_operations(Float16 [] inp) {
for (int i = 0; i < inp.length; i++) {
validate("abs", inp[i]);
validate("neg", inp[i]);
validate("sqrt", inp[i]);
validate("isInfinite", inp[i]);
validate("isFinite", inp[i]);
validate("isNaN", inp[i]);
}
}

public static void validate(String oper, short... input) {
short actual = actual_value(oper, input);
short expected = expected_value(oper, input);
if (compare(actual, expected)) {
if (input.length == 1) {
throw new AssertionError("Test Failed: " + oper + "(" + input[0] + ") : " + actual + " != " + expected);
}
if (input.length == 2) {
throw new AssertionError("Test Failed: " + oper + "(" + input[0] + ", " + input[1] + ") : " + actual + " != " + expected);
}
if (input.length == 3) {
throw new AssertionError("Test failed: " + oper + "(" + input[0] + ", " + input[1] + ", " + input[2] + ") : " + actual + " != " + expected);
}
public static void test_binary_operations(Float16 [] inp1, Float16 inp2[]) {
for (int i = 0; i < inp1.length; i++) {
validate("+", inp1[i], inp2[i]);
validate("-", inp1[i], inp2[i]);
validate("*", inp1[i], inp2[i]);
validate("/", inp1[i], inp2[i]);
}
}

public static short [] get_fp16_array(int size) {
short [] arr = new short[size];
for (int i = 0; i < arr.length; i++) {
arr[i] = Float.floatToFloat16(r.nextFloat());
public static void test_ternary_operations(Float16 [] inp1, Float16 inp2[], Float16 inp3[]) {
for (int i = 0; i < inp1.length; i++) {
validate("fma", inp1[i], inp2[i], inp3[i]);
}
}

public static void test_fin_inf_nan() {
Float16 pos_nan, neg_nan;
// Starting from 1 as the significand in a NaN value is always non-zero
for (int i = 1; i < 0x03ff; i++) {
pos_nan = Float16.shortBitsToFloat16((short)(EXP | i));
neg_nan = Float16.shortBitsToFloat16((short)(Float16.float16ToRawShortBits(pos_nan) | SIGN_BIT));

// Test isFinite, isInfinite, isNaN for all positive NaN values
validate("isInfinite", pos_nan);
validate("isFinite", pos_nan);
validate("isNaN", pos_nan);

// Test isFinite, isinfinite, isNaN for all negative NaN values
validate("isInfinite", neg_nan);
validate("isFinite", neg_nan);
validate("isNaN", neg_nan);
}
return arr;
}

public static void main(String [] args) {
int res = 0;
short [] input1 = get_fp16_array(1024);
short [] input2 = get_fp16_array(1024);
short [] input3 = get_fp16_array(1024);

short [] special_values = {
32256, // NAN
31744, // +Inf
(short)-1024, // -Inf
0, // +0.0
(short)-32768, // -0.0
Float16 [] input1 = new Float16[SIZE];
Float16 [] input2 = new Float16[SIZE];
Float16 [] input3 = new Float16[SIZE];

// input1, input2, input3 contain the entire value range for FP16
IntStream.range(0, input1.length).forEach(i -> {input1[i] = Float16.valueOf((float)i);});
IntStream.range(0, input2.length).forEach(i -> {input2[i] = Float16.valueOf((float)i);});
IntStream.range(0, input2.length).forEach(i -> {input3[i] = Float16.valueOf((float)i);});

Float16 [] special_values = {
Float16.NaN, // NAN
Float16.POSITIVE_INFINITY, // +Inf
Float16.NEGATIVE_INFINITY, // -Inf
Float16.valueOf(0.0), // +0.0
Float16.valueOf(-0.0), // -0.0
};

for (int i = 0; i < 1000; i++) {
test_operations(input1, input2, input3);
test_operations(special_values, special_values, special_values);
test_unary_operations(input1);
test_binary_operations(input1, input2);
test_ternary_operations(input1, input2, input3);

test_unary_operations(special_values);
test_binary_operations(special_values, input1);
test_ternary_operations(special_values, input1, input2);

// The above functions test isFinite, isInfinite and isNaN for all possible finite FP16 values
// and infinite values. The below method tests these functions for all possible NaN values as well.
test_fin_inf_nan();
}
System.out.println("PASS");
}