diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h
index af0ae3b6447be5bfa8b261bfbc2b6baafd7de65d..a291c1542d3e2b986a757d359e6f3854956b2dcb 100644
--- a/paddle/phi/kernels/funcs/elementwise_functor.h
+++ b/paddle/phi/kernels/funcs/elementwise_functor.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/enforce.h"
@@ -190,6 +191,17 @@ struct FMinFunctor<dtype::float16> {
   }
 };
 
+template <>
+struct FMinFunctor<dtype::bfloat16> {
+  inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
+                                               const dtype::bfloat16 b) const {
+    float float_a = static_cast<float>(a);
+    float float_b = static_cast<float>(b);
+    auto result = std::fmin(float_a, float_b);
+    return static_cast<dtype::bfloat16>(result);
+  }
+};
+
 template <>
 struct FMinFunctor<int> {
   inline HOSTDEVICE int operator()(const int a, const int b) const {
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index 30e222663da1000e2be0573dca16ff8526dc8458..b69434c82da8c362c9fdb74edd19478cb35d5385 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -108,6 +108,7 @@ PD_REGISTER_KERNEL(fmin_grad,
                    double,
                    int,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(maximum_grad,
diff --git a/paddle/phi/kernels/kps/elementwise_kernel.cu b/paddle/phi/kernels/kps/elementwise_kernel.cu
index cbf811a9830b0c21946096b5a2d37b737657e3ef..550e579ac1159639b5ae128356310414a4e93a85 100644
--- a/paddle/phi/kernels/kps/elementwise_kernel.cu
+++ b/paddle/phi/kernels/kps/elementwise_kernel.cu
@@ -166,6 +166,7 @@ PD_REGISTER_KERNEL(fmin,
                    double,
                    int,
                    float16,
+                   bfloat16,
                    int64_t) {}
 
 PD_REGISTER_KERNEL(heaviside,
diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py
index 19e43fa3d1d6c5acf508e6a72cfec8810db2f41b..1956d5b4fc433ce2da7e1f690f4003710581677e 100644
--- a/python/paddle/fluid/tests/unittests/test_fmin_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -243,6 +243,35 @@ class TestElementwiseFmin3Op(OpTest):
         self.check_grad(['X', 'Y'], 'Out')
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or the place does not support bfloat16",
+)
+class TestFminBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "elementwise_fmin"
+        self.python_api = paddle.fmin
+        self.dtype = np.uint16
+        x = np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+        sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
+        y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float32")
+        out = np.fmin(x, y)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(out)}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
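
Not part of the diff: a minimal usage sketch of the newly registered bfloat16 path, assuming a CUDA build on a GPU that supports bfloat16. The tensor values are illustrative only; only public Paddle APIs (`paddle.fmin`, `paddle.to_tensor`, `Tensor.astype`) are used.

```python
import paddle

paddle.set_device("gpu")

# Cast float32 inputs down to bfloat16 so the bf16 fmin kernel is dispatched.
x = paddle.to_tensor([1.0, float("nan"), 3.0]).astype("bfloat16")
y = paddle.to_tensor([2.0, 2.0, float("nan")]).astype("bfloat16")

# Like np.fmin, fmin ignores a NaN operand when the other value is a number.
out = paddle.fmin(x, y)
print(out.astype("float32").numpy())  # expected: [1. 2. 3.]
```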