From 4e62af80d57a4c7937a047629ccb130fbc070179 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Wed, 8 Sep 2021 11:42:52 +0800 Subject: [PATCH] Add FP16 PRelu (#35532) --- paddle/fluid/operators/math/prelu.cu | 12 +++-- paddle/fluid/operators/prelu_op.cu | 19 +++++--- .../fluid/tests/unittests/test_prelu_op.py | 45 ++++++++++++++++++- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu index 42c4c799c57..7c93d1725e9 100644 --- a/paddle/fluid/operators/math/prelu.cu +++ b/paddle/fluid/operators/math/prelu.cu @@ -33,7 +33,8 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha, size_t channel_index = temp % channel_num; T scale = alpha[channel_index]; T x = input[index]; - output[index] = (x > 0) ? x : scale * x; + T zero = static_cast(0); + output[index] = (x > zero) ? x : scale * x; } } @@ -45,7 +46,8 @@ __global__ void PReluElementWiseKernel(const T *input, const T *alpha, size_t element_index = index % spatial_size; T scale = alpha[element_index]; T x = input[index]; - output[index] = (x > 0) ? x : scale * x; + T zero = static_cast(0); + output[index] = (x > zero) ? x : scale * x; } } @@ -55,7 +57,8 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, T scale = alpha[0]; CUDA_KERNEL_LOOP(index, numel) { T x = input[index]; - output[index] = (x > 0) ? x : scale * x; + T zero = static_cast(0); + output[index] = (x > zero) ? x : scale * x; } } @@ -88,12 +91,15 @@ void PreluScalarDirectCUDAFunctor::operator()(gpuStream_t stream, } template class PreluChannelWiseDirectCUDAFunctor; +template class PreluChannelWiseDirectCUDAFunctor; template class PreluChannelWiseDirectCUDAFunctor; template class PreluElementWiseDirectCUDAFunctor; +template class PreluElementWiseDirectCUDAFunctor; template class PreluElementWiseDirectCUDAFunctor; template class PreluScalarDirectCUDAFunctor; +template class PreluScalarDirectCUDAFunctor; template class PreluScalarDirectCUDAFunctor; } // namespace math diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu index ca01487549f..049217f2a9f 100644 --- a/paddle/fluid/operators/prelu_op.cu +++ b/paddle/fluid/operators/prelu_op.cu @@ -87,8 +87,9 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, } T x = x_ptr[index]; T dy = dy_ptr[index]; - if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy; - if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy; + T zero = static_cast(0); + if (dx_ptr != nullptr) dx_ptr[index] = (x > zero) ? dy : scale * dy; + if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > zero) ? zero : x * dy; } } @@ -112,9 +113,11 @@ class PreluOpGradFunctor { } }; -template struct IdentityFunctor { - HOSTDEVICE inline T operator()(const T& x) const { return x; } + template + HOSTDEVICE inline T operator()(const T& x) const { + return x; + } }; template @@ -174,9 +177,9 @@ class CUDAPReluGradKernel : public framework::OpKernel { reduce_dims.push_back(i); } - TensorReduce>( + TensorReduce( dalpha_tmp, dalpha, reduce_dims, static_cast(0), cub::Sum(), - IdentityFunctor(), stream); + IdentityFunctor(), stream); } }; @@ -184,10 +187,14 @@ class CUDAPReluGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( prelu, ops::CUDAPReluKernel, + ops::CUDAPReluKernel, ops::CUDAPReluKernel); REGISTER_OP_CUDA_KERNEL( prelu_grad, ops::CUDAPReluGradKernel, + ops::CUDAPReluGradKernel, ops::CUDAPReluGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index e0db1bab3ad..04862eba8a9 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -153,11 +153,12 @@ class TestNNPReluAPI(unittest.TestCase): class PReluTest(OpTest): def setUp(self): + self.init_dtype() self.init_input_shape() self.init_attr() self.op_type = "prelu" - x_np = np.random.uniform(-1, 1, self.x_shape) + x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype) # Since zero point in prelu is not differentiable, avoid randomize # zero. x_np[np.abs(x_np) < 0.005] = 0.02 @@ -168,6 +169,7 @@ class PReluTest(OpTest): alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1]) else: alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:]) + alpha_np = alpha_np.astype(self.dtype) self.inputs = {'X': x_np, 'Alpha': alpha_np} @@ -184,6 +186,9 @@ class PReluTest(OpTest): assert out_np is not self.inputs['X'] self.outputs = {'Out': out_np} + def init_dtype(self): + self.dtype = np.float64 + def init_input_shape(self): self.x_shape = [2, 100, 3, 4] @@ -270,6 +275,44 @@ class TestModeElementRank6(PReluTest): self.attrs = {'mode': "element"} +def create_test_fp16_class(parent, + check_grad=True, + atol=1e-3, + max_relative_error=0.05): + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + class TestPReluFp16Case(parent): + def init_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + self.check_output_with_place(place, atol=atol) + + def test_check_grad(self): + place = core.CUDAPlace(0) + if core.is_float16_supported(place) and check_grad: + self.check_grad_with_place( + place, ['X', 'Alpha'], + 'Out', + max_relative_error=max_relative_error) + + cls_name = "{0}_{1}".format(parent.__name__, "Fp16Op") + TestPReluFp16Case.__name__ = cls_name + globals()[cls_name] = TestPReluFp16Case + + +create_test_fp16_class(TestModeElt) +create_test_fp16_class(TestModeAllRank3) +create_test_fp16_class(TestModeAllRank6) +create_test_fp16_class(TestModeChannelRank3) +create_test_fp16_class(TestModeChannelRank6) +create_test_fp16_class(TestModeElementRank3) +create_test_fp16_class(TestModeElementRank6) + + def prelu_t(x, mode, param_attr=None, name=None): helper = fluid.layer_helper.LayerHelper('prelu', **locals()) alpha_shape = [1, x.shape[1], 1, 1] -- GitLab