Unverified commit 481511a6, authored by cyberslack_lee and committed by GitHub

【Hackathon4 No.61】Improve the FP16/BF16 unit tests for the remainder operator (#52920)

Parent b1333175
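For reference, the kernels and tests in this patch implement Python-style modulo, where a nonzero result takes the sign of the divisor rather than the dividend. A minimal NumPy sketch of that convention (the helper name remainder_ref is illustrative, not part of the patch):

import numpy as np

def remainder_ref(x, y):
    # Python-style modulo: shift a nonzero fmod result by y when its sign
    # disagrees with the divisor, mirroring the C++ functor below.
    r = np.fmod(x, y)
    return np.where((r != 0) & (np.sign(r) != np.sign(y)), r + y, r)

x = np.array([7.0, -7.0, 7.0, -7.0], dtype=np.float16)
y = np.array([3.0, 3.0, -3.0, -3.0], dtype=np.float16)
print(remainder_ref(x, y))               # [ 1.  2. -2. -1.]
print(np.fmod(y + np.fmod(x, y), y))     # same values; the form used by the tests below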
......@@ -563,6 +563,20 @@ struct RemainderFunctor<dtype::float16> {
}
};
template <>
struct RemainderFunctor<dtype::bfloat16> {
inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
const dtype::bfloat16 b) const {
float b_float = static_cast<float>(b);
float res = fmod(static_cast<float>(a), b_float);
// According to #PR26732: in dividend % divisor,
// the remainder shall have the same sign as the divisor.
if ((res != 0.0f) && ((res < 0.0f) != (b_float < 0.0f))) res += b_float;
return static_cast<dtype::bfloat16>(res);
}
};
template <typename T, typename Enable = void>
struct InverseRemainderFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
......
......@@ -117,7 +117,8 @@ PD_REGISTER_KERNEL(remainder,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow,
......
......@@ -157,7 +157,8 @@ PD_REGISTER_KERNEL(remainder_raw,
double,
int,
float16,
int64_t) {}
int64_t,
bfloat16) {}
PD_REGISTER_KERNEL(floor_divide_raw,
KPS,
ALL_LAYOUT,
......
......@@ -16,10 +16,15 @@ import random
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import (
OpTest,
convert_float_to_uint16,
convert_uint16_to_float,
)
import paddle
from paddle import fluid
from paddle.fluid import core
class TestElementwiseModOp(OpTest):
......@@ -106,14 +111,17 @@ class TestElementwiseModOpFloat(TestElementwiseModOp):
self.check_output()
class TestElementwiseModOpFp16(TestElementwiseModOp):
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestElementwiseModFP16Op(TestElementwiseModOp):
def init_dtype(self):
self.dtype = np.float16
def init_input_output(self):
self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
self.out = np.mod(self.x, self.y)
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
def test_check_output(self):
if self.attrs['axis'] == -1:
......@@ -122,6 +130,83 @@ class TestElementwiseModOpFp16(TestElementwiseModOp):
self.check_output()
class TestElementwiseModFP16Op_ZeroDim1(TestElementwiseModFP16Op):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, []).astype(np.float16)
self.y = np.random.uniform(0, 1000, []).astype(np.float16)
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
class TestElementwiseModFP16Op_ZeroDim2(TestElementwiseModFP16Op):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(np.float16)
self.y = np.random.uniform(0, 1000, []).astype(np.float16)
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
class TestElementwiseModFP16Op_ZeroDim3(TestElementwiseModFP16Op):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, []).astype(np.float16)
self.y = np.random.uniform(0, 1000, [10, 10]).astype(np.float16)
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestElementwiseModBF16Op(OpTest):
def init_kernel_type(self):
self.use_mkldnn = False
def init_input_output(self):
self.x = np.random.uniform(0, 10000, [10, 10]).astype(np.float32)
self.x = convert_uint16_to_float(convert_float_to_uint16(self.x))
self.y = np.random.uniform(0, 1000, [10, 10]).astype(np.float32)
self.y = convert_uint16_to_float(convert_float_to_uint16(self.y))
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
def setUp(self):
self.op_type = "elementwise_mod"
self.python_api = paddle.remainder
self.public_python_api = paddle.remainder
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.inputs = {
'X': convert_float_to_uint16(
OpTest.np_dtype_to_fluid_dtype(self.x)
),
'Y': convert_float_to_uint16(
OpTest.np_dtype_to_fluid_dtype(self.y)
),
}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def init_dtype(self):
self.dtype = np.uint16
def init_axis(self):
pass
class TestElementwiseModBF16Op_ZeroDim1(TestElementwiseModBF16Op):
def init_input_output(self):
self.x = np.random.uniform(0, 10000, []).astype("float32")
self.x = convert_uint16_to_float(convert_float_to_uint16(self.x))
self.y = np.random.uniform(0, 1000, []).astype("float32")
self.y = convert_uint16_to_float(convert_float_to_uint16(self.y))
self.out = np.fmod(self.y + np.fmod(self.x, self.y), self.y)
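The BF16 tests round-trip their float32 inputs through uint16 so the NumPy reference is computed on exactly the values the bfloat16 kernel sees. A rough, self-contained sketch of that round trip (simple truncation of the float32 bits; Paddle's convert_float_to_uint16 may round rather than truncate):

import numpy as np

def to_bf16_bits(x):
    # Keep only the upper 16 bits of the float32 representation (bfloat16 storage).
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def from_bf16_bits(u):
    # Re-expand the stored 16 bits back into a float32 value.
    return (u.astype(np.uint32) << 16).view(np.float32)

x = np.random.uniform(0, 10000, [10, 10]).astype(np.float32)
x_bf16 = from_bf16_bits(to_bf16_bits(x))  # bf16-representable values, like the test inputs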
class TestElementwiseModOpDouble(TestElementwiseModOpFloat):
def init_dtype(self):
self.dtype = np.float64
......