diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 62c55c4f5578ac6e620c0a4ac7846a14209dd2a1..ccb08b245a4696865b46f555b1ef2500bd39aadd 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -79,10 +79,10 @@ x, \qquad \text{if} \ x >= 0 $$ The input `X` can carry the LoD (Level of Details) information, or not. And the output shares the LoD information with input `X`. -There are modes: +There are modes: all: all elements share same weight channel: elements in a channel share same weight - element: each element has a weight + element: each element has a weight )DOC"); AddAttr("mode", "The mode for inputs to share weights.") .SetDefault("all"); @@ -113,7 +113,7 @@ class PReluGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType(ctx.Input("X")->type(), - platform::CPUPlace()); + ctx.device_context()); } }; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu index 36b5259ae5106914f5668625cad535ebc8aa72ec..998768db0c0bbb41e5f7871c21376ec9680dc8d2 100644 --- a/paddle/fluid/operators/prelu_op.cu +++ b/paddle/fluid/operators/prelu_op.cu @@ -14,11 +14,15 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/prelu.h" #include "paddle/fluid/operators/prelu_op.h" +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/platform/cuda_primitives.h" namespace paddle { namespace operators { +static const int CUDA_NUM_THREADS = 1024; +static const int CUDA_MAX_NUM_BLOCKS = 65535; + using Tensor = framework::Tensor; template @@ -55,6 +59,145 @@ class CUDAPReluKernel : public framework::OpKernel { } }; +namespace prelu { +struct ElementWiseMode {}; +struct ChannelMode {}; +struct ScalarMode {}; +} /* namespace prelu */ + +template +struct AlphaFunctor { + HOSTDEVICE inline T operator()(const T* alpha, size_t channel, + size_t spatial_size, size_t idx) const {} +}; + +template +struct AlphaFunctor { + HOSTDEVICE inline T operator()(const T* alpha, size_t channel, + size_t spatial_size, size_t idx) const { + return alpha[blockIdx.x * spatial_size + idx]; + } +}; + +template +struct AlphaFunctor { + HOSTDEVICE inline T operator()(const T* alpha, size_t channel, + size_t spatial_size, size_t idx) const { + return alpha[blockIdx.x % channel]; + } +}; + +template +struct AlphaFunctor { + HOSTDEVICE inline T operator()(const T* alpha, size_t channel, + size_t spatial_size, size_t idx) const { + return alpha[0]; + } +}; + +template +__global__ void PReluGradElementWiseKernel(const T* x_ptr, const T* y_ptr, + const T* alpha_ptr, const T* dy_ptr, + T* dx_ptr, T* dalpha_ptr, + size_t channel, + size_t spatial_size) { + size_t offset = blockIdx.x * spatial_size; + AlphaFunctor alpha_func; + + for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) { + T y = y_ptr[offset + i]; + T x = x_ptr[offset + i]; + T dy = dy_ptr[offset + i]; + T alpha = alpha_func(alpha_ptr, channel, spatial_size, i); + if (dx_ptr != nullptr) dx_ptr[offset + i] = (y > 0) ? dy : alpha * dy; + if (dalpha_ptr != nullptr) dalpha_ptr[offset + i] = (x > 0) ? 0 : x * dy; + } +} + +template +class PreluGradElementwiseFunctor { + public: + void operator()(cudaStream_t stream, const T* x, const T* y, const T* alpha, + const T* dy, T* dx, T* dalpha, std::vector input_shape) { + size_t unroll = input_shape[0] * input_shape[1]; + size_t spatial_size = input_shape[2] * input_shape[3]; + CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS); + PReluGradElementWiseKernel<<>>( + x, y, alpha, dy, dx, dalpha, input_shape[1], spatial_size); + } +}; + +template +struct IdentityFunctor { + HOSTDEVICE inline T operator()(const T& x) const { return x; } +}; + +template +class CUDAPReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Input("Out"); + auto* alpha = context.Input("Alpha"); + auto* dx = context.Output(framework::GradVarName("X")); + auto* dy = context.Input(framework::GradVarName("Out")); + auto* dalpha = context.Output(framework::GradVarName("Alpha")); + + const T* x_ptr = x->data(); + const T* y_ptr = y->data(); + const T* alpha_ptr = alpha->data(); + const T* dy_ptr = dy->data(); + T* dx_ptr = dx ? dx->mutable_data(context.GetPlace()) : nullptr; + T* dalpha_ptr = + dalpha ? dalpha->mutable_data(context.GetPlace()) : nullptr; + + if (!dx && !dalpha) return; + + auto& mode = context.Attr("mode"); + + int numel = x->numel(); + auto dim = x->dims(); + std::vector input_shape = framework::vectorize2int(dim); + auto stream = context.cuda_device_context().stream(); + + T* dalpha_tmp_ptr; + Tensor dalpha_tmp; + if (mode == "element" || dalpha_ptr == nullptr) { + dalpha_tmp_ptr = dalpha_ptr; + } else { + auto& dev_ctx = context.template device_context(); + dalpha_tmp = context.AllocateTmpTensor(dim, dev_ctx); + dalpha_tmp_ptr = dalpha_tmp.mutable_data(context.GetPlace()); + } + + if (mode == "element") { + PreluGradElementwiseFunctor prelu_grad; + prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, + dalpha_tmp_ptr, input_shape); + } else if (mode == "channel") { + PreluGradElementwiseFunctor prelu_grad; + prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, + dalpha_tmp_ptr, input_shape); + } else { + PreluGradElementwiseFunctor prelu_grad; + prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, + dalpha_tmp_ptr, input_shape); + } + + if (mode == "element" || dalpha_tmp_ptr == nullptr) return; + + std::vector reduce_dims; + for (size_t i = 0; i < input_shape.size(); i++) { + if (mode == "channel" && i == 1) continue; + reduce_dims.push_back(i); + } + + TensorReduce>( + dalpha_tmp, dalpha, reduce_dims, static_cast(0), cub::Sum(), + IdentityFunctor(), stream); + } +}; + } // namespace operators } // namespace paddle @@ -62,3 +205,7 @@ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( prelu, ops::CUDAPReluKernel, ops::CUDAPReluKernel); +REGISTER_OP_CUDA_KERNEL( + prelu_grad, + ops::CUDAPReluGradKernel, + ops::CUDAPReluGradKernel);