From 4490e8af4e29cc2cf8933226e5b2c2b577d155b9 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Fri, 2 Apr 2021 12:13:04 +0800 Subject: [PATCH] add leaky_relu forward and backward in activation_op.cu (#31841) * add leaky_relu forward and backward in activation_op.cu --- paddle/fluid/operators/activation_op.cu | 250 +++++++++++++++++------- 1 file changed, 181 insertions(+), 69 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index c6d2fbccd8..04f329088f 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -42,6 +42,10 @@ template class BaseGPUFunctor { public: using ELEMENT_TYPE = T; + + using AttrPair = std::vector>; + + AttrPair GetAttrs() { return AttrPair(); } }; /* ========================================================================== */ @@ -57,42 +61,35 @@ class ReluGPUFunctor : public BaseGPUFunctor { // for relu forward when T is double __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type* x); + const typename CudaVecType::type in) { + // relu forward : out = max(x, 0) + return in > zero_ ? in : zero_; + } // when num % vecsize != 0 this func will be used - __device__ __forceinline__ T ComputeRemainder(const T x) { - return x > zero_ ? x : zero_; + __device__ __forceinline__ T ComputeRemainder(const T in) { + // relu forward : out = max(x, 0) + return in > zero_ ? in : zero_; } }; -template <> -__device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type* x) { -// relu forward : out = max(x, 0) -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - return __ldg(x) > zero_ ? __ldg(x) : zero_; -#else - return (*x) > zero_ ? (*x) : zero_; -#endif -} - template <> __device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type* xx) { - // relu forward : out = max(xx, 0) - return make_float4((xx->x > zero_) * (xx->x), (xx->y > zero_) * (xx->y), - (xx->z > zero_) * (xx->z), (xx->w > zero_) * (xx->w)); +ReluGPUFunctor::Compute(const CudaVecType::type in) { + // relu forward : out = max(in, 0) + return make_float4((in.x > zero_) * (in.x), (in.y > zero_) * (in.y), + (in.z > zero_) * (in.z), (in.w > zero_) * (in.w)); } template <> __device__ __forceinline__ CudaVecType::type -ReluGPUFunctor::Compute(const CudaVecType::type* in) { +ReluGPUFunctor::Compute(const CudaVecType::type in) { // relu forward : out = max(in, 0) #ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(__ldg(in), kzero), __ldg(in)); + return __hmul2(__hgt2(in, kzero), in); #else - const float2 xx = __half22float2(*in); + const float2 xx = __half22float2(in); return __floats2half2_rn((xx.x > 0.0f) * static_cast(xx.x), (xx.y > 0.0f) * static_cast(xx.y)); #endif @@ -112,8 +109,10 @@ class ReluGradGPUFunctor : public BaseGPUFunctor { // for relu backward when T is double __device__ __forceinline__ typename CudaVecType::type Compute( - const typename CudaVecType::type* out, - const typename CudaVecType::type* dout); + const typename CudaVecType::type out, + const typename CudaVecType::type dout) { + return out > zero_ ? dout : zero_; + } // when num % vecsize != 0 this func will be used __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) { @@ -124,44 +123,132 @@ class ReluGradGPUFunctor : public BaseGPUFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; -template <> -__device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type* out, - const CudaVecType::type* dout) { -// relu backward : dx = out > 0 ? dout : 0; -#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 - return __ldg(out) > zero_ ? __ldg(dout) : zero_; -#else - return (*out) > zero_ ? (*dout) : zero_; -#endif -} - template <> __device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type* out, - const CudaVecType::type* dout) { +ReluGradGPUFunctor::Compute(const CudaVecType::type out, + const CudaVecType::type dout) { // relu backward : dx = out > 0 ? dout : 0; - return make_float4((out->x > zero_) * (dout->x), (out->y > zero_) * (dout->y), - (out->z > zero_) * (dout->z), - (out->w > zero_) * (dout->w)); + return make_float4((out.x > zero_) * (dout.x), (out.y > zero_) * (dout.y), + (out.z > zero_) * (dout.z), (out.w > zero_) * (dout.w)); } template <> __device__ __forceinline__ CudaVecType::type -ReluGradGPUFunctor::Compute(const CudaVecType::type* out, - const CudaVecType::type* dout) { +ReluGradGPUFunctor::Compute(const CudaVecType::type out, + const CudaVecType::type dout) { // relu backward : dx = out > 0 ? dout : 0; #ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) const half2 kzero = __float2half2_rn(0.0f); - return __hmul2(__hgt2(__ldg(out), kzero), __ldg(dout)); + return __hmul2(__hgt2(out, kzero), dout); #else - const float2 xx = __half22float2(*out); - const float2 yy = __half22float2(*dout); + const float2 xx = __half22float2(out); + const float2 yy = __half22float2(dout); return __floats2half2_rn((xx.x > 0.0f) * static_cast(yy.x), (xx.y > 0.0f) * static_cast(yy.y)); #endif } +/* ========================================================================== */ +/* ======================== leaky relu forward ======================== + */ +template +class LeakyReluGPUFunctor : public BaseGPUFunctor { + private: + T zero_; + float alpha_; + + public: + LeakyReluGPUFunctor() { zero_ = static_cast(0.0f); } + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha_}}; + } + // leakyrelu forward : out = x > 0 ? x : x * alpha + __device__ __forceinline__ typename CudaVecType::type Compute( + const typename CudaVecType::type in) { + return in > zero_ ? in : static_cast(alpha_) * in; + } + + __device__ __forceinline__ T ComputeRemainder(const T in) { + // leakyrelu forward : out = x > 0 ? x : x * alpha + return in > zero_ ? in : static_cast(alpha_) * in; + } +}; + +template <> +__device__ __forceinline__ CudaVecType::type +LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { + // leakyrelu forward : out = x > 0 ? x : x * alpha + return make_float4((in.x > zero_) ? (in.x) : (in.x) * alpha_, + (in.y > zero_) ? (in.y) : (in.y) * alpha_, + (in.z > zero_) ? (in.z) : (in.z) * alpha_, + (in.w > zero_) ? (in.w) : (in.w) * alpha_); +} + +template <> +__device__ __forceinline__ CudaVecType::type +LeakyReluGPUFunctor::Compute(const CudaVecType::type in) { + // leakyrelu forward : out = x > 0 ? x : x * alpha + const float2 xx = __half22float2(in); + return __floats2half2_rn((xx.x > 0.0f) ? xx.x : xx.x * alpha_, + (xx.y > 0.0f) ? xx.y : xx.y * alpha_); +} +/* ========================================================================== */ + +/* =========================== leaky relu backward ======================= + */ +template +class LeakyReluGradGPUFunctor : public BaseGPUFunctor { + private: + T zero_; + float alpha_; + + public: + LeakyReluGradGPUFunctor() { zero_ = static_cast(0.0f); } + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha_}}; + } + + // for leaky relu backward when T is double + __device__ __forceinline__ typename CudaVecType::type Compute( + const typename CudaVecType::type in, + const typename CudaVecType::type dout) { + // leakyrelu backward : dx = x > 0 ? dout : alpha * dout + return in > zero_ ? dout : static_cast(alpha_) * dout; + } + + // when num % vecsize != 0 this func will be used + __device__ __forceinline__ T ComputeRemainder(const T in, const T dout) { + // leakyrelu backward : dx = x > 0 ? dout : alpha * dout + return in > zero_ ? dout : static_cast(alpha_) * dout; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } +}; + +template <> +__device__ __forceinline__ CudaVecType::type +LeakyReluGradGPUFunctor::Compute(const CudaVecType::type in, + const CudaVecType::type dout) { + // leakyrelu backward : dx = x > 0 ? dout : alpha * dout + return make_float4((in.x > zero_) ? (dout.x) : alpha_ * (dout.x), + (in.y > zero_) ? (dout.y) : alpha_ * (dout.y), + (in.z > zero_) ? (dout.z) : alpha_ * (dout.z), + (in.w > zero_) ? (dout.w) : alpha_ * (dout.w)); +} + +template <> +__device__ __forceinline__ CudaVecType::type LeakyReluGradGPUFunctor< + float16>::Compute(const CudaVecType::type in, + const CudaVecType::type dout) { + // leakyrelu backward : dx = x > 0 ? dout : alpha * dout + const float2 xx = __half22float2(in); + const float2 yy = __half22float2(dout); + return __floats2half2_rn((xx.x > 0.0f) ? yy.x : alpha_ * yy.x, + (xx.y > 0.0f) ? yy.y : alpha_ * yy.y); +} + /* ========================================================================== */ template @@ -176,14 +263,23 @@ __global__ void ActivationGradKernelVec(const T* forward_data, const T* dout, const VecType* in_forward = reinterpret_cast(forward_data); const VecType* in_dout = reinterpret_cast(dout); VecType* out = reinterpret_cast(dx); - + VecType forward_vec, dout_vec; + T in_data, dout_data; for (int i = idx; i < loop; i += stride) { - out[i] = functor.Compute((in_forward + i), (in_dout + i)); +#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 + forward_vec = __ldg(in_forward + i); + dout_vec = __ldg(in_dout + i); +#else + forward_vec = in_forward[i]; + dout_vec = in_dout[i]; +#endif + out[i] = functor.Compute(forward_vec, dout_vec); } while (idx == loop && tail) { - dx[num - tail] = - functor.ComputeRemainder(forward_data[num - tail], dout[num - tail]); + in_data = forward_data[num - tail]; + dout_data = dout[num - tail]; + dx[num - tail] = functor.ComputeRemainder(in_data, dout_data); --tail; } } @@ -199,9 +295,14 @@ __global__ void ActivationkernelVec(const T* src, T* dst, int num, int tail = num % vecsize; const VecType* in = reinterpret_cast(src); VecType* out = reinterpret_cast(dst); - + VecType x_vec; for (int i = idx; i < loop; i += stride) { - out[i] = functor.Compute((in + i)); +#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 + x_vec = __ldg(in + i); +#else + x_vec = in[i]; +#endif + out[i] = functor.Compute(x_vec); } while (idx == loop && tail) { @@ -231,6 +332,10 @@ class ActivationGPUKernel block = 256; #endif Functor functor; + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = context.Attr(attr.first); + } constexpr int vecsize = CudaVecType::vecsize; int grid = max((num / vecsize + block - 1) / block, 1); auto stream = context.cuda_device_context().stream(); @@ -270,7 +375,12 @@ class ActivationGradGPUKernel #ifdef __HIPCC__ block = 256; #endif + Functor functor; + auto attrs = functor.GetAttrs(); + for (auto& attr : attrs) { + *attr.second = context.Attr(attr.first); + } constexpr int vecsize = CudaVecType::vecsize; int grid = max((numel / vecsize + block - 1) / block, 1); auto stream = context.cuda_device_context().stream(); @@ -300,12 +410,28 @@ namespace plat = paddle::platform; ops::grad_functor>, \ ops::ActivationGradKernel>); - FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); +#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, op_name, functor, \ + grad_functor) \ + REGISTER_OP_CUDA_KERNEL( \ + act_type, ops::ActivationGPUKernel>, \ + ops::ActivationGPUKernel>, \ + ops::ActivationGPUKernel>); \ + REGISTER_OP_CUDA_KERNEL( \ + act_type##_grad, ops::ActivationGradGPUKernel>, \ + ops::ActivationGradGPUKernel>, \ + ops::ActivationGradGPUKernel>); + /* ======================== leaky relu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor, - LeakyReluGradFunctor); +REGISTER_ACTIVATION_GPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluGPUFunctor, + LeakyReluGradGPUFunctor); REGISTER_OP_CUDA_KERNEL( leaky_relu_grad_grad, @@ -330,21 +456,7 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== relu register ============================ */ -REGISTER_OP_CUDA_KERNEL( - relu, ops::ActivationGPUKernel>, - ops::ActivationGPUKernel>, - ops::ActivationGPUKernel>); - -REGISTER_OP_CUDA_KERNEL( - relu_grad, ops::ActivationGradGPUKernel>, - ops::ActivationGradGPUKernel>, - ops::ActivationGradGPUKernel>); +REGISTER_ACTIVATION_GPU_KERNEL(relu, Relu, ReluGPUFunctor, ReluGradGPUFunctor); REGISTER_OP_CUDA_KERNEL( relu_grad_grad, -- GitLab