diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index d3d8cf176de67bfbd4f8c95b780f6f0284903da7..fde05759a9040efcb7470fda0d6ecc59837cd090 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -80,6 +80,5 @@ class ClipOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker, clip_grad, ops::ClipOpGrad); -REGISTER_OP_CPU_KERNEL(clip, - ops::ClipKernel); +REGISTER_OP_CPU_KERNEL(clip, ops::ClipKernel); REGISTER_OP_CPU_KERNEL(clip_grad, ops::ClipGradKernel); diff --git a/paddle/operators/clip_op.cu b/paddle/operators/clip_op.cu index 7e9c6c23c2c46336a115ec36cdcba003d4667c24..3a61841a56b30729bbf3db7f8e0cd41e009581eb 100644 --- a/paddle/operators/clip_op.cu +++ b/paddle/operators/clip_op.cu @@ -14,60 +14,6 @@ #include "paddle/operators/clip_op.h" -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ - i += blockDim.x * gridDim.x) - -namespace paddle { -namespace operators { - -using framework::LoDTensor; - -template -__global__ void ClipGradientKernel(const int N, const T min, const T max, - const T* Y, const T* dY, T* dX) { - CUDA_1D_KERNEL_LOOP(i, N) { - if (Y[i] > min && Y[i] < max) { - dX[i] = dY[i]; - } else { - dX[i] = 0; - } - } -} - -template -class ClipGradientOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto max = context.Attr("max"); - auto min = context.Attr("min"); - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); - if (d_x != nullptr) { - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); - auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); - - int N = d_x->dims()[0]; - int D = d_x->dims()[1]; - int block = 512; - int grid = (N * D + block - 1) / block; - ClipGradientKernel<<< - grid, block, 0, reinterpret_cast( - context.device_context()) - .stream()>>>(count, min, max, x_data, d_out_data, - d_x_data); - } - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(clip, - ops::ClipKernel); -REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradientOpCUDAKernel); +REGISTER_OP_GPU_KERNEL(clip, ops::ClipKernel); +REGISTER_OP_GPU_KERNEL(clip_grad, ops::ClipGradKernel); diff --git a/paddle/operators/clip_op.h b/paddle/operators/clip_op.h index 47bfe1b7f8ef84e4896b6a5fef839d582adc9599..5d05959129eaf1e4fd44e2761e265242750650b2 100644 --- a/paddle/operators/clip_op.h +++ b/paddle/operators/clip_op.h @@ -16,57 +16,61 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" namespace paddle { namespace operators { -using framework::LoDTensor; +using framework::Tensor; +using platform::Transform; -template -using EigenTensor = framework::EigenTensor; +template +class ClipFunctor { + public: + explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x) const { + if (x < min_) + return min_; + else if (x > max_) + return max_; + else + return x; + } + + private: + T min_; + T max_; +}; + +template +class ClipGradFunctor { + public: + explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {} + HOSTDEVICE T operator()(const T& x, const T& y) const { + if (y > min_ && y < max_) + return x; + else + return 0; + } -template -void ClipFunction(const framework::ExecutionContext& context) { - auto max = context.op().Attr("max"); - auto min = context.op().Attr("min"); - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); - auto x_tensor = EigenTensor::From(*x); - auto out_tensor = EigenTensor::From(*out); - auto place = context.GetEigenDevice(); - out_tensor.device(place) = x_tensor.cwiseMin(max).cwiseMax(min); -} + private: + T min_; + T max_; +}; -template +template class ClipKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - int rank = context.Input("X")->dims().size(); - switch (rank) { - case 1: - ClipFunction(context); - break; - case 2: - ClipFunction(context); - break; - case 3: - ClipFunction(context); - break; - case 4: - ClipFunction(context); - break; - case 5: - ClipFunction(context); - break; - case 6: - ClipFunction(context); - break; - default: - PADDLE_THROW( - "PadOp only support tensors with no more than 6 dimensions."); - } + auto max = context.Attr("max"); + auto min = context.Attr("min"); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + T* out_data = out->mutable_data(context.GetPlace()); + const T* x_data = x->data(); + int numel = x->numel(); + Transform(context.device_context(), x_data, x_data + numel, out_data, + ClipFunctor(min, max)); } }; @@ -74,24 +78,18 @@ template class ClipGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto max = context.op().Attr("max"); - auto min = context.op().Attr("min"); - auto* d_out = context.Input(framework::GradVarName("Out")); - auto* d_x = context.Output(framework::GradVarName("X")); + auto max = context.Attr("max"); + auto min = context.Attr("min"); + auto* d_out = context.Input(framework::GradVarName("Out")); + auto* d_x = context.Output(framework::GradVarName("X")); if (d_x != nullptr) { - auto* x = context.Input("X"); - auto dims = d_x->dims(); - int64_t count = d_out->numel(); + auto* x = context.Input("X"); + int64_t numel = d_out->numel(); auto d_x_data = d_x->mutable_data(context.GetPlace()); - auto d_out_data = d_out->data(); - auto x_data = x->data(); - for (int i = 0; i < count; ++i) { - if (x_data[i] > min && x_data[i] < max) { - d_x_data[i] = d_out_data[i]; - } else { - d_x_data[i] = 0; - } - } + const T* d_out_data = d_out->data(); + const T* x_data = x->data(); + Transform(context.device_context(), d_out_data, d_out_data + numel, + x_data, d_x_data, ClipGradFunctor(min, max)); } } };