From f6d9ec2701ebda3917cd5a61ecdac61cd4332fb0 Mon Sep 17 00:00:00 2001
From: xiaohemaikoo <35183089+xiaohemaikoo@users.noreply.github.com>
Date: Tue, 6 Sep 2022 06:00:52 +0200
Subject: [PATCH] elementwise op support fp16 (#45496)

---
 paddle/phi/kernels/elementwise_kernel.cc      |  9 +++--
 .../phi/kernels/funcs/elementwise_functor.h   | 26 +++++++++++++-
 .../kernels/gpu/elementwise_grad_kernel.cu    |  5 +++
 .../impl/elementwise_grad_kernel_impl.h       | 30 ++++++++++++++++
 paddle/phi/kernels/kps/elementwise_kernel.cu  | 26 +++++++++++---
 .../test_elementwise_heaviside_op.py          | 31 +++++++++++++++++
 .../unittests/test_elementwise_mod_op.py      | 17 ++++++++++
 .../unittests/test_elementwise_pow_op.py      | 34 +++++++++++++++++--
 .../fluid/tests/unittests/test_fmax_op.py     | 30 ++++++++++++++++
 .../fluid/tests/unittests/test_fmin_op.py     | 26 ++++++++++++++
 python/paddle/tensor/math.py                  | 18 +++++-----
 11 files changed, 233 insertions(+), 19 deletions(-)

diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc
index 0c208d2db0..ba58bae003 100644
--- a/paddle/phi/kernels/elementwise_kernel.cc
+++ b/paddle/phi/kernels/elementwise_kernel.cc
@@ -237,7 +237,8 @@ PD_REGISTER_KERNEL(remainder,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16) {}
 PD_REGISTER_KERNEL(
     floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
 PD_REGISTER_KERNEL(elementwise_heaviside,
@@ -247,7 +248,8 @@ PD_REGISTER_KERNEL(elementwise_heaviside,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16) {}
 PD_REGISTER_KERNEL(elementwise_pow,
                    KPS,
                    ALL_LAYOUT,
@@ -255,7 +257,8 @@ PD_REGISTER_KERNEL(elementwise_pow,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h
index bfbfd28abb..a4636565cf 100644
--- a/paddle/phi/kernels/funcs/elementwise_functor.h
+++ b/paddle/phi/kernels/funcs/elementwise_functor.h
@@ -524,6 +524,19 @@ struct RemainderFunctor<
   }
 };
 
+template <>
+struct RemainderFunctor<dtype::float16> {
+  inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
+                                              const dtype::float16 b) const {
+    float b_float = static_cast<float>(b);
+    float res = fmod(static_cast<float>(a), b_float);
+    // According to #PR26732: in dividend % divisor,
+    // the remainder shall have the same sign as the divisor.
+    if ((res != 0.0f) && ((res < 0.0f) != (b_float < 0.0f))) res += b_float;
+    return static_cast<dtype::float16>(res);
+  }
+};
+
 template <typename T>
 struct InverseRemainderFunctor {
   inline HOSTDEVICE T operator()(const T a, const T b) const {
@@ -547,7 +560,7 @@ struct InverseRemainderFunctor<
 template <typename T>
 struct ElementwiseHeavisideFunctor {
   inline HOSTDEVICE T operator()(const T a, const T b) const {
-    return a == static_cast<T>(0) ? b : static_cast<T>(a > 0);
+    return a == static_cast<T>(0) ? b : static_cast<T>(a > static_cast<T>(0));
   }
 };
@@ -592,5 +605,16 @@ struct ElementwisePowFunctor {
     return std::pow(a, b);
   }
 };
+
+template <>
+struct ElementwisePowFunctor<dtype::float16> {
+  inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
+                                              const dtype::float16 b) const {
+    float f_a = static_cast<float>(a);
+    float f_b = static_cast<float>(b);
+    return static_cast<dtype::float16>(std::pow(f_a, f_b));
+  }
+};
+
 }  // namespace funcs
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index 4921cf884c..a802fe12c6 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -35,6 +35,7 @@ void MaximumGradKernel(const Context& dev_ctx,
                        DenseTensor* dx,
                        DenseTensor* dy) {
   const auto place = dev_ctx.GetPlace();
+
   if (dx != nullptr && dy != nullptr) {
     std::vector<const DenseTensor*> ins = {&x, &y, &dout};
     GetGradXAndYOut<ElementwiseType::kTernary, T>(
@@ -96,6 +97,7 @@ PD_REGISTER_KERNEL(fmax_grad,
                    float,
                    double,
                    int,
+                   phi::dtype::float16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(fmin_grad,
@@ -105,6 +107,7 @@ PD_REGISTER_KERNEL(fmin_grad,
                    float,
                    double,
                    int,
+                   phi::dtype::float16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(maximum_grad,
@@ -136,6 +139,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_grad,
                    float,
                    double,
                    int,
+                   phi::dtype::float16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(elementwise_pow_grad,
@@ -145,4 +149,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
                    float,
                    double,
                    int,
+                   phi::dtype::float16,
                    int64_t) {}
diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
index da74280b26..7759de509a 100644
--- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/common/complex.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
@@ -753,6 +754,20 @@ struct PowGradDX {
   }
 };
 
+template <>
+struct PowGradDX<dtype::float16> {
+  HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
+                                       dtype::float16 y,
+                                       dtype::float16 out,
+                                       dtype::float16 dout) const {
+    float tmp_y = static_cast<float>(y);
+    float tmp_dout = static_cast<float>(dout);
+    float tmp_x = static_cast<float>(x);
+    float result = tmp_dout * tmp_y * std::pow(tmp_x, tmp_y - 1.0f);
+    return static_cast<dtype::float16>(result);
+  }
+};
+
 template <typename T, typename Enable = void>
 struct PowGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -766,6 +781,21 @@ struct PowGradDY {
   }
 };
 
+template <>
+struct PowGradDY<dtype::float16> {
+  HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
+                                       dtype::float16 y,
+                                       dtype::float16 out,
+                                       dtype::float16 dout) const {
+    float tmp_y = static_cast<float>(y);
+    float tmp_dout = static_cast<float>(dout);
+    float tmp_x = static_cast<float>(x);
+    float tmp_pow = std::pow(tmp_x, tmp_y);
+    float result = tmp_pow * tmp_dout * std::log(tmp_x);
+    return static_cast<dtype::float16>(result);
+  }
+};
+
 template <typename T, typename Context>
 void ElementwisePowGradKernel(const Context& dev_ctx,
                               const DenseTensor& x,
diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu
index 9da65c590b..346c836814 100644
--- a/paddle/phi/kernels/kps/elementwise_kernel.cu
+++ b/paddle/phi/kernels/kps/elementwise_kernel.cu
@@ -32,6 +32,7 @@ void MaximumKernel(const Context& dev_ctx,
   int axis = -1;
   MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
 }
+
 // Create the definition of Minimum
 DEFINE_CUDA_ELEMENTWISE_OP(Minimum)
 template <typename T, typename Context>
@@ -92,11 +93,25 @@ using bfloat16 = phi::dtype::bfloat16;
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
-PD_REGISTER_KERNEL(
-    fmax, KPS, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(fmax,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::FMaxKernel,
+                   float,
+                   double,
+                   int,
+                   float16,
+                   int64_t) {}
 
-PD_REGISTER_KERNEL(
-    fmin, KPS, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
+PD_REGISTER_KERNEL(fmin,
+                   KPS,
+                   ALL_LAYOUT,
+                   phi::FMinKernel,
+                   float,
+                   double,
+                   int,
+                   float16,
+                   int64_t) {}
 
 PD_REGISTER_KERNEL(maximum_raw,
                    KPS,
@@ -125,6 +140,7 @@ PD_REGISTER_KERNEL(remainder_raw,
                    float,
                    double,
                    int,
+                   float16,
                    int64_t) {}
 PD_REGISTER_KERNEL(floor_divide_raw,
                    KPS,
@@ -139,6 +155,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_raw,
                    float,
                    double,
                    int,
+                   float16,
                    int64_t) {}
 PD_REGISTER_KERNEL(elementwise_pow_raw,
                    KPS,
@@ -147,5 +164,6 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
                    float,
                    double,
                    int,
+                   float16,
                    int64_t) {}
 #endif
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py
index 7d5dd02b67..7789f872d4 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_heaviside_op.py
@@ -18,6 +18,13 @@ from op_test import OpTest
 import paddle
 
 
+def Heaviside_grad(x, y, dout):
+    tmp = np.zeros(x.shape).astype("float16")
+    dx = np.multiply(tmp, dout)
+    dy = np.multiply(np.equal(x, 0), dout).astype("float16")
+    return dx, dy
+
+
 class TestElementwiseOp(OpTest):
 
     def setUp(self):
@@ -152,6 +159,30 @@ class TestHeavisideAPI_int32(TestHeavisideAPI_float64):
         self.dtype = "int32"
 
 
+class TestHeavisideAPI_float16(OpTest):
+
+    def setUp(self):
+        self.dtype = np.float16
self.op_type = "elementwise_heaviside" + self.python_api = paddle.heaviside + self.inputs = { + 'X': np.random.uniform(1, 2, [20, 5]).astype("float16"), + 'Y': np.random.uniform(1, 2, [20, 5]).astype("float16") + } + self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'Y'], + 'Out', + user_defined_grads=Heaviside_grad( + self.inputs['X'], self.inputs['Y'], + 1 / self.inputs['X'].size), + check_eager=True) + + class TestHeavisideError(unittest.TestCase): def test_input(self): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py index eeccd8b976..491da7ad99 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py @@ -89,6 +89,23 @@ class TestElementwiseModOpFloat(TestElementwiseModOp): self.check_output(check_eager=False) +class TestElementwiseModOpFp16(TestElementwiseModOp): + + def init_dtype(self): + self.dtype = np.float16 + + def init_input_output(self): + self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype) + self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype) + self.out = np.mod(self.x, self.y) + + def test_check_output(self): + if self.attrs['axis'] == -1: + self.check_output(check_eager=True) + else: + self.check_output(check_eager=False) + + class TestElementwiseModOpDouble(TestElementwiseModOpFloat): def init_dtype(self): diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py index 904b9fe06d..921bbd93ec 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py @@ -20,6 +20,12 @@ import paddle.fluid as fluid import paddle +def pow_grad(x, y, dout): + dx = dout * y * np.power(x, (y - 1)) + dy = dout * np.log(x) * np.power(x, y) + return dx, dy + + class TestElementwisePowOp(OpTest): def setUp(self): @@ -194,7 +200,6 @@ class TestElementwisePowGradOpInt(unittest.TestCase): # dy = dout * log(x) * pow(x, y) self.grad_y = (self.grad_res * np.log(self.x) * (self.x**self.y)).astype("int") - print(self.grad_res, self.grad_x, self.grad_y) def test_grad(self): fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True}) @@ -205,7 +210,6 @@ class TestElementwisePowGradOpInt(unittest.TestCase): with fluid.dygraph.guard(place): x = fluid.dygraph.to_variable(self.x, zero_copy=False) y = fluid.dygraph.to_variable(self.y, zero_copy=False) - print(x, y) x.stop_gradient = False y.stop_gradient = False res = x**y @@ -216,5 +220,31 @@ class TestElementwisePowGradOpInt(unittest.TestCase): fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False}) +class TestElementwisePowOpFP16(OpTest): + + def setUp(self): + self.op_type = "elementwise_pow" + self.python_api = paddle.pow + self.inputs = { + 'X': np.random.uniform(1, 2, [20, 5]).astype("float16"), + 'Y': np.random.uniform(1, 2, [20, 5]).astype("float16") + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + if hasattr(self, 'attrs'): + self.check_output(check_eager=False) + else: + self.check_output(check_eager=True) + + def test_check_grad(self): + self.check_grad(['X', 'Y'], + 'Out', + user_defined_grads=pow_grad(self.inputs['X'], + self.inputs['Y'], + 1 / self.inputs['X'].size), + 
+                        check_eager=True)
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py
index 593f3dd257..7bc0a96321 100644
--- a/python/paddle/fluid/tests/unittests/test_fmax_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py
@@ -209,3 +209,33 @@ class TestElementwiseFmax2Op(OpTest):
                         max_relative_error=0.005,
                         no_grad_set=set('Y'),
                         check_eager=True)
+
+
+class TestElementwiseFmax3Op(OpTest):
+    """TestElementwiseFmax3Op"""
+
+    def setUp(self):
+        """setUp"""
+        self.op_type = "elementwise_fmax"
+        self.python_api = paddle.fmax
+        # If x and y have the same value, the max() is not differentiable.
+        # So we generate test data by the following method
+        # to avoid them being too close to each other.
+        x = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
+        sgn = np.random.choice([-1, 1], [13, 17]).astype("float16")
+        y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float16")
+
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}
+
+    def test_check_output(self):
+        """test_check_output"""
+        self.check_output(check_eager=True)
+
+    def test_check_grad_normal(self):
+        """test_check_grad_normal"""
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py
index 888ff2c8af..dc99838f23 100644
--- a/python/paddle/fluid/tests/unittests/test_fmin_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py
@@ -213,6 +213,32 @@ class TestElementwiseFmin2Op(OpTest):
                         check_eager=True)
 
 
+class TestElementwiseFmin3Op(OpTest):
+    """TestElementwiseFmin3Op"""
+
+    def setUp(self):
+        """setUp"""
+        self.op_type = "elementwise_fmin"
+        self.python_api = paddle.fmin
+        # If x and y have the same value, the min() is not differentiable.
+        # So we generate test data by the following method
+        # to avoid them being too close to each other.
+        x = np.random.uniform(1, 1, [13, 17]).astype("float16")
+        sgn = np.random.choice([-1, 1], [13, 17]).astype("float16")
+        y = x + sgn * np.random.uniform(1, 1, [13, 17]).astype("float16")
+
+        self.inputs = {'X': x, 'Y': y}
+        self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])}
+
+    def test_check_output(self):
+        """test_check_output"""
+        self.check_output(check_eager=True)
+
+    def test_check_grad_normal(self):
+        """test_check_grad_normal"""
+        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index e91c5ac3ae..c5b995454a 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -352,7 +352,7 @@ def pow(x, y, name=None):
 
     Args:
-        x (Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64.
+        x (Tensor): An N-D Tensor, the data type is float16, float32, float64, int32 or int64.
         y (float|int|Tensor): If it is an N-D Tensor, its data type should be the same as `x`.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -762,8 +762,8 @@ def remainder(x, y, name=None):
     ``paddle.remainder`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
 
     Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -1003,8 +1003,8 @@ def fmax(x, y, name=None):
     ``paddle.fmax`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
 
     Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -1066,8 +1066,8 @@ def fmin(x, y, name=None):
     ``paddle.fmin`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
 
     Args:
-        x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
-        y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
+        x (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
+        y (Tensor): the input tensor, it's data type should be float16, float32, float64, int32, int64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -4696,8 +4696,8 @@ def heaviside(x, y, name=None):
     ``paddle.heaviside`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
 
     Args:
-        x (Tensor): The input tensor of Heaviside step function, it's data type should be float32, float64, int32 or int64.
-        y (Tensor): The tensor that determines a Heaviside step function, it's data type should be float32, float64, int32 or int64.
+        x (Tensor): The input tensor of Heaviside step function, it's data type should be float16, float32, float64, int32 or int64.
+        y (Tensor): The tensor that determines a Heaviside step function, it's data type should be float16, float32, float64, int32 or int64.
         name (str, optional): Name for the operation (optional, default is None). Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
-- 
GitLab
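
Editorial note, not part of the patch: the registrations above make the float16 path reachable from the public Python APIs named in the tests (paddle.remainder, paddle.pow, paddle.heaviside, paddle.fmax, paddle.fmin). The snippet below is a minimal sketch of exercising that path; it assumes a CUDA build of Paddle in which these GPU/KPS kernels are compiled in, and the shapes and value ranges are arbitrary.

```python
# Minimal sketch (assumption: CUDA build of Paddle with the fp16 kernels above).
import numpy as np
import paddle

paddle.set_device("gpu")  # the fp16 elementwise kernels registered here are GPU-side

x = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]).astype("float16"))
y = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]).astype("float16"))

print(paddle.remainder(x, y))  # elementwise mod, now registered for float16
print(paddle.pow(x, y))        # elementwise_pow, fp16 forward and backward
print(paddle.heaviside(x, y))  # elementwise_heaviside with fp16 inputs
print(paddle.fmax(x, y))       # fmax/fmin fp16 forward and grad kernels
print(paddle.fmin(x, y))
```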