Unverified commit 8163faaa authored by superwinner1, committed by GitHub

【Hackathon No.55】add fmin BF16 test (#53100)

* 'fmin'

* 'fix'

* 'fix'
Parent 4001f7ae
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
@@ -190,6 +191,17 @@ struct FMinFunctor<dtype::float16> {
  }
};
template <>
struct FMinFunctor<dtype::bfloat16> {
  inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
                                               const dtype::bfloat16 b) const {
    float float_a = static_cast<float>(a);
    float float_b = static_cast<float>(b);
    auto result = std::fmin(float_a, float_b);
    return static_cast<dtype::bfloat16>(result);
  }
};

template <>
struct FMinFunctor<int> {
  inline HOSTDEVICE int operator()(const int a, const int b) const {
...
@@ -108,6 +108,7 @@ PD_REGISTER_KERNEL(fmin_grad,
                   double,
                   int,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   int64_t) {}
PD_REGISTER_KERNEL(maximum_grad,
...
@@ -166,6 +166,7 @@ PD_REGISTER_KERNEL(fmin,
                   double,
                   int,
                   float16,
                   bfloat16,
                   int64_t) {}
PD_REGISTER_KERNEL(heaviside,
...
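With bfloat16 added to both the fmin and fmin_grad registrations above, paddle.fmin should dispatch a bfloat16 kernel on GPU. A minimal usage sketch, assuming a CUDA build with bfloat16 support; the tensor values and the paddle.cast calls are illustrative assumptions, not part of this commit:

import numpy as np
import paddle

x = paddle.to_tensor(np.array([1.5, -2.0, 3.25], dtype="float32"))
y = paddle.to_tensor(np.array([0.5, -1.0, 4.00], dtype="float32"))

# Cast inputs down so the bfloat16 kernel registered above is exercised.
x_bf16 = paddle.cast(x, "bfloat16")
y_bf16 = paddle.cast(y, "bfloat16")

out = paddle.fmin(x_bf16, y_bf16)

# Cast back up to compare against the float32 reference.
print(paddle.cast(out, "float32").numpy())  # expected: [0.5, -2.0, 3.25]
print(np.fmin(x.numpy(), y.numpy()))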
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core
@@ -243,6 +243,35 @@ class TestElementwiseFmin3Op(OpTest):
        self.check_grad(['X', 'Y'], 'Out')
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA and not support the bfloat16",
)
class TestFminBF16OP(OpTest):
    def setUp(self):
        self.op_type = "elementwise_fmin"
        self.python_api = paddle.fmin
        self.dtype = np.uint16
        x = np.random.uniform(1, 1, [13, 17]).astype("float32")
        sgn = np.random.choice([-1, 1], [13, 17]).astype("float32")
        y = x + sgn * np.random.uniform(1, 1, [13, 17]).astype("float32")
        out = np.fmin(x, y)
        self.inputs = {
            'X': convert_float_to_uint16(x),
            'Y': convert_float_to_uint16(y),
        }
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
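The test above stores bfloat16 data as np.uint16 via convert_float_to_uint16 from eager_op_test. A rough NumPy sketch of that kind of conversion, assuming the usual bfloat16 bit layout; the helper names below are hypothetical, only convert_float_to_uint16 itself comes from the diff:

import numpy as np

def float32_to_bf16_bits(arr):
    # Reinterpret each float32 as raw bits and keep the upper 16 bits,
    # which is the bfloat16 pattern the test stores as uint16.
    arr = np.ascontiguousarray(arr, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    # Reverse direction: pad the low 16 bits with zeros and reinterpret.
    bits = np.asarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.array([1.0, 0.3333, -2.5], dtype=np.float32)
print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # bfloat16-truncated values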