Unverified commit 98c427e2, authored by zhangbo9674, committed by GitHub

[bf16] add bf16 kernel: sigmoid & sqrt & softplus & square (#40004)

* add activ

* refine unittest

* refine unittest

* refine unittest

* refine unittest

* refine code
Parent b4eb413e
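This change registers bfloat16 (bf16) CUDA kernels for the sigmoid, sqrt, softplus and square activations, together with their grad and double/triple-grad variants, and adds bf16 unit tests that store tensors as uint16 bit patterns. As orientation, the sketch below (a hedged illustration, not code from this commit) emulates the bf16 representation in NumPy by keeping only the top 16 bits of each float32; the tests' convert_float_to_uint16 helper plays a similar role, though it may round to nearest rather than truncate.

```python
import numpy as np

def float32_to_bf16_bits(x):
    """Keep only the upper 16 bits of each float32 value (sign, exponent,
    top 7 mantissa bits): a truncating emulation of bfloat16."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(bits):
    """Expand bf16 bit patterns back to float32 by zero-filling the low bits."""
    bits = np.ascontiguousarray(bits, dtype=np.uint16)
    return (bits.astype(np.uint32) << 16).view(np.float32)

x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
x_rounded = bf16_bits_to_float32(float32_to_bf16_bits(x))
# bf16 keeps only ~2-3 decimal digits, which is why the bf16 tests below
# loosen gradient tolerances (numeric_grad_delta) for square and softplus.
print(np.max(np.abs(x - x_rounded)))
```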
@@ -1509,7 +1509,9 @@ namespace plat = paddle::platform;
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<double>>, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
ops::functor<plat::float16>>, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::bfloat16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
@@ -1517,7 +1519,9 @@ namespace plat = paddle::platform;
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<double>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
ops::grad_functor<plat::float16>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::bfloat16>>);
#define REGISTER_ACTIVATION_CUDA_KERNEL_INT(act_type, op_name, functor, \
grad_functor) \
@@ -1531,7 +1535,9 @@ namespace plat = paddle::platform;
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext, \
ops::functor<int64_t>>, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::float16>>); \
ops::functor<plat::float16>>, \
ops::ActivationCudaKernel<plat::CUDADeviceContext, \
ops::functor<plat::bfloat16>>); \
REGISTER_OP_CUDA_KERNEL( \
act_type##_grad, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
@@ -1543,7 +1549,9 @@ namespace plat = paddle::platform;
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<int64_t>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
ops::grad_functor<plat::float16>>, \
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::bfloat16>>);
/* ======================== leaky relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, CudaLeakyReluFunctor,
@@ -1650,7 +1658,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::SigmoidDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidGradGradFunctor<double>>,
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
ops::SigmoidGradGradFunctor<plat::float16>>);
ops::SigmoidGradGradFunctor<plat::float16>>,
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
ops::SigmoidGradGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
sigmoid_triple_grad,
@@ -1659,7 +1669,10 @@ REGISTER_OP_CUDA_KERNEL(
ops::SigmoidTripleGradKernel<paddle::platform::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<double>>,
ops::SigmoidTripleGradKernel<plat::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<plat::float16>>);
ops::SigmoidTripleGradFunctor<plat::float16>>,
ops::SigmoidTripleGradKernel<
plat::CUDADeviceContext,
ops::SigmoidTripleGradFunctor<plat::bfloat16>>);
/* ========================================================================== */
/* =========================== tanh register ============================ */
@@ -1696,7 +1709,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
ops::SqrtGradGradFunctor<plat::float16>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::bfloat16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
@@ -1726,6 +1741,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::float16>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::bfloat16>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<int>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
......
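The registration macros above now instantiate each of these activation kernels (and their grad, double-grad and triple-grad variants) for plat::bfloat16 alongside float, double and plat::float16 (plus int/int64_t for the INT variant). These registrations rely on the functors doing their math in a wider type; the MPTypeTrait specialization added further down maps bfloat16 to float for exactly that purpose. The Python sketch below is a hedged illustration of that compute-in-float pattern for sigmoid, using a truncating bf16 emulation; it is not the CUDA code itself.

```python
import numpy as np

def to_bf16_bits(x):
    """Truncating float32 -> bfloat16 emulation (keep only the top 16 bits)."""
    return (np.ascontiguousarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

def to_float32(bits):
    """bfloat16 bit pattern -> float32 (zero-fill the low mantissa bits)."""
    return (np.ascontiguousarray(bits, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)

def sigmoid_bf16(x_bits):
    """Mirror of the kernel pattern: upcast bf16 to the compute type (float32),
    evaluate 1 / (1 + exp(-x)) there, and round back to bf16 only on store."""
    x_mp = to_float32(x_bits)
    out_mp = 1.0 / (1.0 + np.exp(-x_mp))
    return to_bf16_bits(out_mp)

x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
out = to_float32(sigmoid_bf16(to_bf16_bits(x)))
print(np.max(np.abs(out - 1.0 / (1.0 + np.exp(-x)))))  # small bf16 rounding error
```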
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
@@ -32,6 +33,12 @@ class MPTypeTrait<platform::float16> {
using Type = float;
};
template <>
class MPTypeTrait<platform::bfloat16> {
public:
using Type = float;
};
} // namespace details
} // namespace operators
} // namespace paddle
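The new specialization makes MPTypeTrait<platform::bfloat16>::Type resolve to float, just like the existing float16 specialization, so any kernel written against the trait automatically does its intermediate math in float32 for bf16 inputs. Below is a loose Python analogue of that dispatch; the names are illustrative only, not Paddle APIs.

```python
import numpy as np

# Rough analogue of MPTypeTrait: low-precision storage dtypes compute in float32,
# wider dtypes compute in themselves. "bfloat16" is only a label here; NumPy has
# no native bf16 dtype.
MP_TYPE = {
    "float16": np.float32,
    "bfloat16": np.float32,
    "float32": np.float32,
    "float64": np.float64,
}

def apply_in_mp_type(x, storage_dtype, fn):
    """Upcast to the multi-precision compute type and apply fn there; the
    caller decides when to round the result back to the storage dtype."""
    return fn(np.asarray(x, dtype=MP_TYPE[storage_dtype]))

y = apply_in_mp_type([0.25, 4.0], "bfloat16", np.sqrt)  # computed in float32
```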
@@ -266,7 +266,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
cudaMemcpyDeviceToDevice, stream));
#endif
} else {
T factor = static_cast<T>(1.0f - dropout_prob);
using MT = typename details::MPTypeTrait<T>::Type;
MT factor = static_cast<MT>(1.0f - dropout_prob);
std::vector<const framework::Tensor*> ins = {&x};
std::vector<framework::Tensor*> outs = {y};
auto functor = phi::funcs::ScaleFunctor<T>(factor);
......
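In the dropout scale-only branch, the factor 1 - dropout_prob is now held in the MPType (float for bf16/fp16) instead of being declared as T, so the scalar is not rounded to the low-precision storage type up front. The sketch below is only a loose illustration of why a wider type on that scalar path can matter, using float16 as a stand-in low-precision type (NumPy has no bfloat16); it is not the actual kernel path.

```python
import numpy as np

dropout_prob = 0.3
x = np.random.uniform(-1, 1, [1000]).astype(np.float16)

# Factor rounded to the storage type first; the scaling sees a low-precision scalar.
factor_lowp = np.float16(1.0 - dropout_prob)
y_lowp = x * factor_lowp

# Factor kept at float32 (the MPType); only the final result is rounded.
factor_mp = np.float32(1.0 - dropout_prob)
y_mp = (x.astype(np.float32) * factor_mp).astype(np.float16)

# The two paths can disagree in the last low-precision bit for some elements.
print(np.max(np.abs(y_mp.astype(np.float32) - y_lowp.astype(np.float32))))
```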
@@ -310,6 +310,10 @@ HOSTDEVICE inline bool(isfinite)(const bfloat16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
HOSTDEVICE inline bfloat16(abs)(const bfloat16& a) {
return bfloat16(std::abs(static_cast<float>(a)));
}
inline std::ostream& operator<<(std::ostream& os, const bfloat16& a) {
os << static_cast<float>(a);
return os;
......
@@ -183,6 +183,34 @@ class TestSigmoid(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSigmoidBF16(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
out = 1 / (1 + np.exp(-x))
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
}
self.outputs = {'Out': convert_float_to_uint16(out)}
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestSilu(TestActivation):
def setUp(self):
self.op_type = "silu"
@@ -945,6 +973,34 @@ class TestSqrt(TestActivation, TestParameter):
self.check_grad(['X'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSqrtBF16(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.init_dtype()
np.random.seed(1023)
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
out = np.sqrt(x)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
}
self.outputs = {'Out': convert_float_to_uint16(out)}
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
@@ -2195,6 +2251,34 @@ class TestSquare(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.007)
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSquareBF16(OpTest):
def setUp(self):
self.op_type = "square"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(0.1, 1, [11, 17]).astype(np.float32)
out = np.square(x)
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(convert_float_to_uint16(x))
}
self.outputs = {'Out': convert_float_to_uint16(out)}
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.5)
class TestPow(TestActivation):
def setUp(self):
self.op_type = "pow"
@@ -2433,6 +2517,35 @@ class TestSoftplus(TestActivation):
self.check_grad(['X'], 'Out')
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestSoftplusBF16(OpTest):
def setUp(self):
self.op_type = "softplus"
self.init_dtype()
beta = 2
threshold = 15
np.random.seed(1024)
x = np.random.uniform(-1, 1, [10, 12]).astype(np.float32)
out = ref_softplus(x, beta, threshold)
self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {'beta': beta, "threshold": threshold}
self.outputs = {'Out': convert_float_to_uint16(out)}
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', numeric_grad_delta=0.05)
class TestSoftplusAPI(unittest.TestCase):
# test paddle.nn.Softplus, paddle.nn.functional.softplus
def setUp(self):
......
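All four new BF16 test classes follow the same recipe: compute the reference output in float32, store 'X' and 'Out' as bf16 bit patterns via convert_float_to_uint16 (with self.dtype = np.uint16), and check outputs and gradients on a CUDAPlace; TestSquareBF16 and TestSoftplusBF16 pass a larger numeric_grad_delta since finite differences are coarser in bf16. For reference, here is a sketch of the softplus formula that the ref_softplus helper is expected to implement; this is a reconstruction using the usual beta/threshold form, not the file's actual helper.

```python
import numpy as np

def softplus_reference(x, beta=1.0, threshold=20.0):
    """softplus(x) = (1/beta) * log(1 + exp(beta * x)); once beta * x exceeds
    `threshold`, fall back to the identity to avoid exp() overflow."""
    x = np.asarray(x, dtype=np.float32)
    x_beta = beta * x
    safe = np.log1p(np.exp(np.minimum(x_beta, threshold))) / beta
    return np.where(x_beta > threshold, x, safe)

# Same configuration as TestSoftplusBF16 above.
x = np.random.uniform(-1, 1, [10, 12]).astype(np.float32)
out = softplus_reference(x, beta=2, threshold=15)
```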