From a71d72d921fc861051553c6d44b32bc9037706bc Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Mon, 29 Mar 2021 20:30:37 +0800 Subject: [PATCH] relu forward and backward with vectortype (#31869) --- paddle/fluid/operators/activation_op.cu | 286 +++++++++++++++++++++++- 1 file changed, 285 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index 2033081af22..c6d2fbccd8e 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -10,8 +10,278 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" +#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/float16.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using float16 = paddle::platform::float16; + +template +struct CudaVecType { + using type = T; + static constexpr int vecsize = 1; +}; + +template <> +struct CudaVecType { + using type = __half2; + static constexpr int vecsize = 2; +}; + +template <> +struct CudaVecType { + using type = float4; + static constexpr int vecsize = 4; +}; + +template +class BaseGPUFunctor { + public: + using ELEMENT_TYPE = T; +}; + +/* ========================================================================== */ + +/* =========================== relu forward ============================ */ +template +class ReluGPUFunctor : public BaseGPUFunctor { + private: + T zero_; + + public: + ReluGPUFunctor() { zero_ = static_cast(0.0f); } + + // for relu forward when T is double + __device__ __forceinline__ typename CudaVecType::type Compute( + const typename CudaVecType::type* x); + + // when num % vecsize != 0 this func will be used + __device__ __forceinline__ T ComputeRemainder(const T x) { + return x > zero_ ? x : zero_; + } +}; + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGPUFunctor::Compute(const CudaVecType::type* x) { +// relu forward : out = max(x, 0) +#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 + return __ldg(x) > zero_ ? __ldg(x) : zero_; +#else + return (*x) > zero_ ? (*x) : zero_; +#endif +} + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGPUFunctor::Compute(const CudaVecType::type* xx) { + // relu forward : out = max(xx, 0) + return make_float4((xx->x > zero_) * (xx->x), (xx->y > zero_) * (xx->y), + (xx->z > zero_) * (xx->z), (xx->w > zero_) * (xx->w)); +} + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGPUFunctor::Compute(const CudaVecType::type* in) { +// relu forward : out = max(in, 0) +#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const half2 kzero = __float2half2_rn(0.0f); + return __hmul2(__hgt2(__ldg(in), kzero), __ldg(in)); +#else + const float2 xx = __half22float2(*in); + return __floats2half2_rn((xx.x > 0.0f) * static_cast(xx.x), + (xx.y > 0.0f) * static_cast(xx.y)); +#endif +} +/* ========================================================================== */ + +/* =========================== relu backward ============================ + */ + +template +class ReluGradGPUFunctor : public BaseGPUFunctor { + private: + T zero_; + + public: + ReluGradGPUFunctor() { zero_ = static_cast(0.0f); } + + // for relu backward when T is double + __device__ __forceinline__ typename CudaVecType::type Compute( + const typename CudaVecType::type* out, + const typename CudaVecType::type* dout); + + // when num % vecsize != 0 this func will be used + __device__ __forceinline__ T ComputeRemainder(const T out, const T dout) { + // relu backward : dx = out > 0 ? dout : 0 + return out > zero_ ? dout : zero_; + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } +}; + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGradGPUFunctor::Compute(const CudaVecType::type* out, + const CudaVecType::type* dout) { +// relu backward : dx = out > 0 ? dout : 0; +#ifdef __HIPCC__ || __CUDA_ARCH__ >= 350 + return __ldg(out) > zero_ ? __ldg(dout) : zero_; +#else + return (*out) > zero_ ? (*dout) : zero_; +#endif +} + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGradGPUFunctor::Compute(const CudaVecType::type* out, + const CudaVecType::type* dout) { + // relu backward : dx = out > 0 ? dout : 0; + return make_float4((out->x > zero_) * (dout->x), (out->y > zero_) * (dout->y), + (out->z > zero_) * (dout->z), + (out->w > zero_) * (dout->w)); +} + +template <> +__device__ __forceinline__ CudaVecType::type +ReluGradGPUFunctor::Compute(const CudaVecType::type* out, + const CudaVecType::type* dout) { +// relu backward : dx = out > 0 ? dout : 0; +#ifdef __HIPCC__ || CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) + const half2 kzero = __float2half2_rn(0.0f); + return __hmul2(__hgt2(__ldg(out), kzero), __ldg(dout)); +#else + const float2 xx = __half22float2(*out); + const float2 yy = __half22float2(*dout); + return __floats2half2_rn((xx.x > 0.0f) * static_cast(yy.x), + (xx.y > 0.0f) * static_cast(yy.y)); +#endif +} + +/* ========================================================================== */ + +template +__global__ void ActivationGradKernelVec(const T* forward_data, const T* dout, + T* dx, int num, Functor functor) { + using VecType = typename CudaVecType::type; + constexpr int vecsize = CudaVecType::vecsize; + int idx = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + int loop = num / vecsize; + int tail = num % vecsize; + const VecType* in_forward = reinterpret_cast(forward_data); + const VecType* in_dout = reinterpret_cast(dout); + VecType* out = reinterpret_cast(dx); + + for (int i = idx; i < loop; i += stride) { + out[i] = functor.Compute((in_forward + i), (in_dout + i)); + } + + while (idx == loop && tail) { + dx[num - tail] = + functor.ComputeRemainder(forward_data[num - tail], dout[num - tail]); + --tail; + } +} + +template +__global__ void ActivationkernelVec(const T* src, T* dst, int num, + Functor functor) { + constexpr int vecsize = CudaVecType::vecsize; + using VecType = typename CudaVecType::type; + int idx = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + int loop = num / vecsize; + int tail = num % vecsize; + const VecType* in = reinterpret_cast(src); + VecType* out = reinterpret_cast(dst); + + for (int i = idx; i < loop; i += stride) { + out[i] = functor.Compute((in + i)); + } + + while (idx == loop && tail) { + dst[num - tail] = functor.ComputeRemainder(src[num - tail]); + --tail; + } +} + +template +class ActivationGPUKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = nullptr; + framework::Tensor* out = nullptr; + ExtractActivationTensor(context, &in_x, &out); + auto& dev_ctx = context.template device_context(); + + int num = in_x->numel(); + const T* input_data = in_x->data(); + T* output_data = out->mutable_data(dev_ctx.GetPlace(), + static_cast(num * sizeof(T))); + + int block = 512; +#ifdef __HIPCC__ + block = 256; +#endif + Functor functor; + constexpr int vecsize = CudaVecType::vecsize; + int grid = max((num / vecsize + block - 1) / block, 1); + auto stream = context.cuda_device_context().stream(); + ActivationkernelVec<<>>( + input_data, output_data, num, functor); + } +}; + +template +class ActivationGradGPUKernel + : public framework::OpKernel { + public: + using T = typename Functor::ELEMENT_TYPE; + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor *x, *out, *d_out; + framework::Tensor* d_x = nullptr; + x = out = d_out = nullptr; + ExtractActivationGradTensor(context, &x, &out, &d_out, + &d_x); + int numel = d_out->numel(); + auto& dev_ctx = context.template device_context(); + auto* dx_data = d_x->mutable_data( + dev_ctx.GetPlace(), static_cast(numel * sizeof(T))); + auto* dout_data = d_out->data(); + + auto* forward_data = dout_data; + if (static_cast(Functor::FwdDeps()) == static_cast(kDepOut)) { + // Only need forward output Out + forward_data = out->data(); + } else if (static_cast(Functor::FwdDeps()) == + static_cast(kDepX)) { + // Only need forward input X + forward_data = x->data(); + } + + int block = 512; +#ifdef __HIPCC__ + block = 256; +#endif + Functor functor; + constexpr int vecsize = CudaVecType::vecsize; + int grid = max((numel / vecsize + block - 1) / block, 1); + auto stream = context.cuda_device_context().stream(); + ActivationGradKernelVec<<>>( + forward_data, dout_data, dx_data, numel, functor); + } +}; + +} // namespace operators +} // namespace paddle + namespace ops = paddle::operators; namespace plat = paddle::platform; @@ -60,7 +330,21 @@ REGISTER_OP_CUDA_KERNEL( /* ========================================================================== */ /* =========================== relu register ============================ */ -REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluCUDAFunctor, ReluGradFunctor); +REGISTER_OP_CUDA_KERNEL( + relu, ops::ActivationGPUKernel>, + ops::ActivationGPUKernel>, + ops::ActivationGPUKernel>); + +REGISTER_OP_CUDA_KERNEL( + relu_grad, ops::ActivationGradGPUKernel>, + ops::ActivationGradGPUKernel>, + ops::ActivationGradGPUKernel>); REGISTER_OP_CUDA_KERNEL( relu_grad_grad, -- GitLab