From 5970871a647035289eeec12c1abcfdbf6835255a Mon Sep 17 00:00:00 2001
From: Zhaolong Xing
Date: Wed, 5 Aug 2020 13:49:09 +0800
Subject: [PATCH] add eltwise clip cuda impl. (#25689)

test=develop
---
 paddle/fluid/operators/clip_op.h              | 35 +++++++++++----
 .../elementwise/elementwise_op_function.h     | 44 +++++++++++++++++++
 2 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 56e65b6f66..a8485a148b 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -25,17 +25,23 @@ namespace operators {
 using framework::Tensor;
 using platform::Transform;
 
+#ifdef __NVCC__
+template <typename T, typename UnaryOperation>
+__global__ void ClipCudaKernel(const T* input, T* out, int num,
+                               UnaryOperation op) {
+  int idx = threadIdx.x + blockDim.x * blockIdx.x;
+  if (idx < num) {
+    out[idx] = op(input[idx]);
+  }
+}
+#endif
+
 template <typename T>
 class ClipFunctor {
  public:
   explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
   HOSTDEVICE T operator()(const T& x) const {
-    if (x < min_)
-      return min_;
-    else if (x > max_)
-      return max_;
-    else
-      return x;
+    return x < min_ ? min_ : x > max_ ? max_ : x;
   }
 
  private:
@@ -97,9 +103,20 @@ class ClipKernel : public framework::OpKernel<T> {
       T* out_data = out->mutable_data<T>(context.GetPlace());
       const T* x_data = x->data<T>();
       int64_t numel = x->numel();
-      Transform<DeviceContext> trans;
-      trans(context.template device_context<DeviceContext>(), x_data,
-            x_data + numel, out_data, ClipFunctor<T>(min, max));
+      if (platform::is_gpu_place(context.GetPlace())) {
+#ifdef __NVCC__
+        int threads = 256;
+        int blocks = (numel + threads - 1) / threads;
+        ClipCudaKernel<T, ClipFunctor<T>><<<
+            blocks, threads, 0,
+            context.template device_context<platform::CUDADeviceContext>()
+                .stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
+#endif
+      } else {
+        Transform<DeviceContext> trans;
+        trans(context.template device_context<DeviceContext>(), x_data,
+              x_data + numel, out_data, ClipFunctor<T>(min, max));
+      }
     } else if (x_var->IsType<framework::SelectedRows>()) {
       auto* x = context.Input<framework::SelectedRows>("X");
       auto* out = context.Output<framework::SelectedRows>("Out");
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 364fe773c7..206eeea87f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -197,6 +197,40 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
 }
 
 #ifdef __NVCC__
+template <typename Functor, typename T, typename OutType>
+__global__ void ElementwiseKernel(const T *x, const T *y, OutType *out, int pre,
+                                  int n, int post, int total, Functor func) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  int idx = tid / post % n;
+  if (tid < total) {
+    out[tid] = func(x[tid], y[idx]);
+  }
+}
+
+template <typename Functor, typename T, typename OutType>
+void ComputeElementwiseCUDA(const framework::Tensor *x,
+                            const framework::Tensor *y, framework::Tensor *z,
+                            int pre, int n, int post,
+                            const platform::CUDADeviceContext &ctx,
+                            Functor func, const bool is_xsize_larger = true) {
+  const T *x_data = x->data<T>();
+  const T *y_data = y->data<T>();
+  OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
+
+  int numel = pre * n * post;
+  int threads = 256;
+  int blocks = (numel + threads - 1) / threads;
+  if (is_xsize_larger) {
+    ElementwiseKernel<Functor, T,
+                      OutType><<<blocks, threads, 0, ctx.stream()>>>(
+        x_data, y_data, out_data, pre, n, post, numel, func);
+  } else {
+    ElementwiseKernel<Functor, T,
+                      OutType><<<blocks, threads, 0, ctx.stream()>>>(
+        y_data, x_data, out_data, pre, n, post, numel, func);
+  }
+}
+
 template <typename Functor, typename T, typename OutType = T>
 __global__ void CommonForwardBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
@@ -1908,6 +1942,16 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
         ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
     return;
   }
+
+  if (platform::is_gpu_place(ctx.GetPlace())) {
+#ifdef __NVCC__
+    ComputeElementwiseCUDA<Functor, T, OutType>(
+        x, y, z, pre, n, post,
+        ctx.template device_context<platform::CUDADeviceContext>(), func,
+        is_xsize_larger);
+#endif
+    return;
+  }
   if (post == 1) {
     functor.RunRowWise(n, pre);
     return;
-- 
GitLab
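
A note on the two kernels above: ClipCudaKernel assigns one thread per element with no grid-stride loop, so the launch must cover all numel elements, which is what the ceiling division (numel + threads - 1) / threads computes. In ElementwiseKernel, idx = tid / post % n recovers the coordinate of the broadcast operand: for an output flattened over shape (pre, n, post), tid = p * n * post + i * post + q, hence tid / post % n == i.

The following standalone sketch shows the same clip launch pattern outside the Paddle framework. It is an illustration, not part of the patch: the file name, host-side test data, and the aggregate-initialized functor are invented for this example, while the kernel body and grid math mirror the diff. It compiles with plain nvcc.

// clip_sketch.cu -- hypothetical standalone illustration, not part of the patch.
// Build: nvcc clip_sketch.cu -o clip_sketch
#include <cstdio>
#include <vector>

// Same shape as the patch's ClipFunctor, with __host__ __device__ standing in
// for Paddle's HOSTDEVICE macro.
template <typename T>
struct ClipFunctor {
  T min_, max_;
  __host__ __device__ T operator()(const T& x) const {
    return x < min_ ? min_ : x > max_ ? max_ : x;
  }
};

// One thread per element, as in the patch's ClipCudaKernel: no grid-stride
// loop, so the grid must cover all `num` elements.
template <typename T, typename UnaryOperation>
__global__ void ClipCudaKernel(const T* input, T* out, int num,
                               UnaryOperation op) {
  int idx = threadIdx.x + blockDim.x * blockIdx.x;
  if (idx < num) {
    out[idx] = op(input[idx]);
  }
}

int main() {
  const int num = 8;
  std::vector<float> h_in = {-3.f, -1.f, 0.f, 0.5f, 1.f, 2.f, 7.f, 9.f};
  std::vector<float> h_out(num);

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, num * sizeof(float));
  cudaMalloc(&d_out, num * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), num * sizeof(float), cudaMemcpyHostToDevice);

  // Same grid math as the patch: ceil(num / threads) blocks.
  int threads = 256;
  int blocks = (num + threads - 1) / threads;
  ClipCudaKernel<float, ClipFunctor<float>><<<blocks, threads>>>(
      d_in, d_out, num, ClipFunctor<float>{0.f, 5.f});

  cudaMemcpy(h_out.data(), d_out, num * sizeof(float), cudaMemcpyDeviceToHost);
  for (float v : h_out) printf("%g ", v);  // expect: 0 0 0 0.5 1 2 5 5
  printf("\n");

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Run on a CUDA-capable machine, the sketch prints 0 0 0 0.5 1 2 5 5, i.e. the input clipped to [0, 5], matching what the ternary ClipFunctor in the patch computes.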