diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu
index 701a802080f65ea32b95402682dc46362ccf0966..7586e8458a69274d2e3d06f6e9bc0be1bbb51ddc 100644
--- a/paddle/fluid/operators/math/prelu.cu
+++ b/paddle/fluid/operators/math/prelu.cu
@@ -18,108 +18,76 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-inline static int GET_NUM_BLOCKS(const int N) {
+#define CUDA_NUM_THREADS 1024
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+inline static int PADDLE_GET_BLOCKS(const int N) {
   return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
 
 template <typename T>
 __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
-                                       T *output, int channel,
-                                       size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T *out = output + offset;
-  T scale = alpha[blockIdx.x % channel];
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                       T *output, size_t channel_num,
+                                       size_t plane_size, size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t temp = index / plane_size;
+    size_t channel_index = temp % channel_num;
+    T scale = alpha[channel_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
 template <typename T>
 __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
-                                       T *output, size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  const T *scale = alpha + offset;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale[i] * x;
+                                       T *output, size_t spatial_size,
+                                       size_t numel) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    size_t element_index = index % spatial_size;
+    T scale = alpha[element_index];
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
 template <typename T>
 __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
-                                  size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  const T *in = input + offset;
-  T scale = *alpha;
-  T *out = output + offset;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T x = in[i];
-    out[i] = (x > 0) ? x : scale * x;
+                                  size_t numel) {
+  T scale = alpha[0];
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    T x = input[index];
+    output[index] = (x > 0) ? x : scale * x;
   }
 }
 
-template <typename T>
-static inline void PReluChannelWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
-}
-
-template <typename T>
-static inline void PReluElementWise(cudaStream_t stream, const T *input,
-                                    const T *alpha, T *output,
-                                    std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
-template <typename T>
-static inline void PReluScalar(cudaStream_t stream, const T *input,
-                               const T *alpha, T *output,
-                               std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
-}
-
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluChannelWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, input_shape[1], spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, input_shape[1],
+                                     plane_size, numel);
 }
 
 template <typename T>
 void PreluElementWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
     std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluElementWiseKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
+                           stream>>>(input, alpha, output, spatial_size, numel);
 }
 
 template <typename T>
@@ -127,11 +95,11 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                  const T *input,
                                                  const T *alpha, T *output,
                                                  std::vector<int> input_shape) {
-  size_t unroll = input_shape[0] * input_shape[1];
-  size_t spatial_size = input_shape[2] * input_shape[3];
-  CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-  PReluScalarKernel<<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-      input, alpha, output, spatial_size);
+  size_t plane_size = input_shape[2] * input_shape[3];
+  size_t spatial_size = input_shape[1] * plane_size;
+  size_t numel = input_shape[0] * spatial_size;
+  PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
+      input, alpha, output, numel);
 }
 
 template class PreluChannelWiseDirectCUDAFunctor<float>;
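
For reference, the channel index that the new grid-stride kernels derive from a flat NCHW offset, `(index / plane_size) % channel_num`, is the same thing as ordinary broadcasting over the channel axis. A minimal NumPy sketch of that equivalence (shapes and names below are illustrative, not part of the patch):

```python
import numpy as np

# Illustrative NCHW shapes; none of these values come from the patch.
n, c, h, w = 2, 3, 4, 5
x = np.random.randn(n, c, h, w).astype("float32")
alpha = np.random.rand(c).astype("float32")

plane_size = h * w              # input_shape[2] * input_shape[3]
numel = x.size                  # grid-stride loop bound

x_flat = x.ravel()              # row-major, the order the kernel indexes
out = np.empty_like(x_flat)
for index in range(numel):      # mimics CUDA_KERNEL_LOOP over all elements
    channel_index = (index // plane_size) % c
    scale = alpha[channel_index]
    v = x_flat[index]
    out[index] = v if v > 0 else scale * v

# Same result via broadcasting over the channel axis.
ref = np.where(x > 0, x, alpha.reshape(1, c, 1, 1) * x)
assert np.allclose(out.reshape(x.shape), ref)
```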
diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h
index 3237c6d4cbf956aafb4046ea2ffa42efe62e7b28..c57aa6caab0cdaac975eb58a51c6604a35ead903 100644
--- a/paddle/fluid/operators/math/prelu.h
+++ b/paddle/fluid/operators/math/prelu.h
@@ -42,6 +42,7 @@ class PreluScalarDirectCUDAFunctor {
   void operator()(cudaStream_t stream, const T *input, const T *alpha,
                   T *output, std::vector<int> input_shape);
 };
+
 #endif
 
 }  // namespace math
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index 5408e7bf0d3abe28ab7368c9b43a471685d792f2..0d63558d97f0027f2197cc28b65aa3c8178359a4 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -42,10 +42,21 @@ class PReluOp : public framework::OperatorWithKernel {
                     "equal to the number of channels, should be %d",
                     x_dim[1]);
     } else if (mode == "element") {
-      PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == product(x_dim),
-                     "For element-wise mode, size of weight Alpha must be "
-                     "equal to the number of input, should be %d",
-                     product(x_dim));
+      auto alpha_dim = ctx->GetInputDim("Alpha");
+      auto alpha_rank = alpha_dim.size();
+      auto x_rank = x_dim.size();
+      size_t x_product = 1;
+      size_t alpha_product = 1;
+      PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
+                        "For element-wise mode, rank of weight Alpha must be "
+                        "equal to the rank of input.");
+      for (int64_t i = x_rank - 1; i > 0; i--) {
+        x_product *= x_dim[i];
+        alpha_product *= alpha_dim[i];
+      }
+      PADDLE_ENFORCE_EQ(x_product, alpha_product,
+                        "For element-wise mode, size of weight Alpha must be "
+                        "equal to the number of input.");
     } else {
       PADDLE_THROW("Unkown mode %s", mode);
     }
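
The relaxed check above only requires that Alpha has the same rank as X and that the products of their non-batch dimensions agree, so Alpha is shared across the batch dimension. A small sketch of the rule with illustrative shapes:

```python
import numpy as np

# Illustrative shapes only; the check ignores dimension 0 (the batch size).
x_dim = (8, 3, 32, 32)
alpha_dim = (1, 3, 32, 32)

assert len(alpha_dim) == len(x_dim)                    # ranks must match
assert np.prod(x_dim[1:]) == np.prod(alpha_dim[1:])    # non-batch sizes match
```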
diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu
index 4a26c98af8814a500e35cb2168097a43b16cef44..6d721e797ed39d9d781713e5c64c7881cb686075 100644
--- a/paddle/fluid/operators/prelu_op.cu
+++ b/paddle/fluid/operators/prelu_op.cu
@@ -20,11 +20,19 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-static const int CUDA_NUM_THREADS = 1024;
-static const int CUDA_MAX_NUM_BLOCKS = 65535;
-
 using Tensor = framework::Tensor;
 
+#define CUDA_NUM_THREADS 1024
+
+// CUDA: grid stride looping
+#define CUDA_KERNEL_LOOP(i, n)                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+       i += blockDim.x * gridDim.x)
+
+inline static int PADDLE_GET_BLOCKS(const int N) {
+  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
 template <typename DeviceContext, typename T>
 class CUDAPReluKernel : public framework::OpKernel<T> {
  public:
@@ -59,71 +67,47 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
   }
 };
 
-namespace prelu {
-struct ElementWiseMode {};
-struct ChannelMode {};
-struct ScalarMode {};
-} /* namespace prelu */
-
-template <typename T, typename M>
-struct AlphaFunctor {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {}
-};
-
-template <typename T>
-struct AlphaFunctor<T, prelu::ElementWiseMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[blockIdx.x * spatial_size + idx];
-  }
-};
+enum PRELU_MODE { Element, Channel, Scalar };
 
 template <typename T>
-struct AlphaFunctor<T, prelu::ChannelMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[blockIdx.x % channel];
-  }
-};
-
-template <typename T>
-struct AlphaFunctor<T, prelu::ScalarMode> {
-  HOSTDEVICE inline T operator()(const T* alpha, size_t channel,
-                                 size_t spatial_size, size_t idx) const {
-    return alpha[0];
-  }
-};
-
-template <typename T, typename M>
-__global__ void PReluGradElementWiseKernel(const T* x_ptr, const T* y_ptr,
-                                           const T* alpha_ptr, const T* dy_ptr,
-                                           T* dx_ptr, T* dalpha_ptr,
-                                           size_t channel,
-                                           size_t spatial_size) {
-  size_t offset = blockIdx.x * spatial_size;
-  AlphaFunctor<T, M> alpha_func;
-
-  for (size_t i = threadIdx.x; i < spatial_size; i += blockDim.x) {
-    T y = y_ptr[offset + i];
-    T x = x_ptr[offset + i];
-    T dy = dy_ptr[offset + i];
-    T alpha = alpha_func(alpha_ptr, channel, spatial_size, i);
-    if (dx_ptr != nullptr) dx_ptr[offset + i] = (y > 0) ? dy : alpha * dy;
-    if (dalpha_ptr != nullptr) dalpha_ptr[offset + i] = (x > 0) ? 0 : x * dy;
+__global__ void PReluOpGradKernel(const T* x_ptr, const T* y_ptr,
+                                  const T* alpha_ptr, const T* dy_ptr,
+                                  T* dx_ptr, T* dalpha_ptr, size_t channel_num,
+                                  size_t plane_size, size_t spatial_size,
+                                  size_t numel, PRELU_MODE mode) {
+  size_t index;
+  CUDA_KERNEL_LOOP(index, numel) {
+    T scale;
+    if (mode == Element) {
+      size_t element_index = index % spatial_size;
+      scale = alpha_ptr[element_index];
+    } else if (mode == Channel) {
+      size_t temp = index / plane_size;
+      size_t channel_index = temp % channel_num;
+      scale = alpha_ptr[channel_index];
+    } else {
+      scale = alpha_ptr[0];
+    }
+    T x = x_ptr[index];
+    T dy = dy_ptr[index];
+    if (dx_ptr != nullptr) dx_ptr[index] = (x > 0) ? dy : scale * dy;
+    if (dalpha_ptr != nullptr) dalpha_ptr[index] = (x > 0) ? 0 : x * dy;
   }
 }
 
-template <typename T, typename M>
-class PreluGradElementwiseFunctor {
+template <typename T>
+class PreluOpGradFunctor {
  public:
   void operator()(cudaStream_t stream, const T* x, const T* y, const T* alpha,
-                  const T* dy, T* dx, T* dalpha, std::vector<int> input_shape) {
-    size_t unroll = input_shape[0] * input_shape[1];
-    size_t spatial_size = input_shape[2] * input_shape[3];
-    CHECK_LT(unroll, CUDA_MAX_NUM_BLOCKS);
-    PReluGradElementWiseKernel<T, M><<<unroll, CUDA_NUM_THREADS, 0, stream>>>(
-        x, y, alpha, dy, dx, dalpha, input_shape[1], spatial_size);
+                  const T* dy, T* dx, T* dalpha, std::vector<int> input_shape,
+                  PRELU_MODE mode) {
+    size_t plane_size = input_shape[2] * input_shape[3];
+    size_t spatial_size = plane_size * input_shape[1];
+    size_t numel = spatial_size * input_shape[0];
+    PReluOpGradKernel<
+        T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
+        x, y, alpha, dy, dx, dalpha, input_shape[1], plane_size, spatial_size,
+        numel, mode);
   }
 };
 
@@ -162,7 +146,7 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
     T* dalpha_tmp_ptr;
     Tensor dalpha_tmp;
-    if (mode == "element" || dalpha_ptr == nullptr) {
+    if (dalpha_ptr == nullptr) {
      dalpha_tmp_ptr = dalpha_ptr;
     } else {
       auto& dev_ctx = context.template device_context<DeviceContext>();
@@ -170,25 +154,24 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
       dalpha_tmp_ptr = dalpha_tmp.mutable_data<T>(context.GetPlace());
     }
 
+    PRELU_MODE m;
     if (mode == "element") {
-      PreluGradElementwiseFunctor<T, prelu::ElementWiseMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Element;
     } else if (mode == "channel") {
-      PreluGradElementwiseFunctor<T, prelu::ChannelMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Channel;
     } else {
-      PreluGradElementwiseFunctor<T, prelu::ScalarMode> prelu_grad;
-      prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr,
-                 dalpha_tmp_ptr, input_shape);
+      m = Scalar;
    }
+    PreluOpGradFunctor<T> prelu_grad;
+    prelu_grad(stream, x_ptr, y_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr,
+               input_shape, m);
 
-    if (mode == "element" || dalpha_tmp_ptr == nullptr) return;
+    if (dalpha_tmp_ptr == nullptr) return;
 
     std::vector<int> reduce_dims;
     for (size_t i = 0; i < input_shape.size(); i++) {
       if (mode == "channel" && i == 1) continue;
+      if (mode == "element" && i != 0) continue;
       reduce_dims.push_back(i);
     }
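
The fused PReluOpGradKernel stores a per-element dalpha contribution, and the reduce_dims selected afterwards collapse it to Alpha's shape: (N, H, W) for "channel", the batch dimension only for "element", and every dimension for the scalar mode. A NumPy sketch of the intended semantics for "element" mode (illustrative shapes, not the operator's actual launch code):

```python
import numpy as np

n, c, h, w = 2, 3, 4, 5
x = np.random.randn(n, c, h, w).astype("float32")
dy = np.random.randn(n, c, h, w).astype("float32")
alpha = np.random.rand(1, c, h, w).astype("float32")   # "element" mode weight

# Per-element gradients written by the kernel.
dx = np.where(x > 0, dy, alpha * dy)
dalpha_per_element = np.where(x > 0, 0.0, x * dy)

# For "element" mode the follow-up reduction keeps every dim except dim 0,
# i.e. dalpha is summed over the batch dimension only.
dalpha = dalpha_per_element.sum(axis=0, keepdims=True)
assert dalpha.shape == alpha.shape
```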
diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h
index 594f1cb3abe49c61ad7c490ebcd100a5c9ea6fb9..cfc0a2b6fb1128ee4460cbc669772c6257aad8ab 100644
--- a/paddle/fluid/operators/prelu_op.h
+++ b/paddle/fluid/operators/prelu_op.h
@@ -45,8 +45,10 @@ class PReluKernel : public framework::OpKernel<T> {
         o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
       }
     } else if (mode == "element") {
+      int temp = numel / dim[0];
       for (i = 0; i < numel; i++) {
-        o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[i] * x_ptr[i];
+        index = i % temp;
+        o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i];
       }
     } else {
       for (i = 0; i < numel; i++) {
@@ -64,12 +66,10 @@ class PReluGradKernel : public framework::OpKernel<T> {
     auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));
-    auto* out = context.Input<Tensor>("Out");
     auto* alpha = context.Input<Tensor>("Alpha");
     const T* alpha_ptr = alpha->data<T>();
     const T* x_ptr = x->data<T>();
     const T* dout_ptr = dout->data<T>();
-    const T* out_ptr = out->data<T>();
     std::string mode = context.Attr<std::string>("mode");
     int numel = x->numel();
     auto dim = x->dims();
@@ -83,15 +83,18 @@ class PReluGradKernel : public framework::OpKernel<T> {
           temp = numel / (dim[0] * dim[1]);
           index = (i / temp) % dim[1];
           dx_ptr[i] =
-              out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
+              x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
         }
       } else if (mode == "element") {
+        temp = numel / dim[0];
         for (i = 0; i < numel; i++) {
-          dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[i] * dout_ptr[i];
+          index = i % temp;
+          dx_ptr[i] =
+              x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i];
         }
       } else {
         for (i = 0; i < numel; i++) {
-          dx_ptr[i] = out_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
+          dx_ptr[i] = x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[0] * dout_ptr[i];
         }
       }
     }
@@ -105,15 +108,17 @@ class PReluGradKernel : public framework::OpKernel<T> {
         for (i = 0; i < numel; i++) {
           temp = numel / (dim[0] * dim[1]);
           index = (i / temp) % dim[1];
-          dalpha_ptr[index] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+          dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
         }
       } else if (mode == "element") {
+        temp = numel / dim[0];
         for (i = 0; i < numel; i++) {
-          dalpha_ptr[i] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+          index = i % temp;
+          dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
         }
       } else {
         for (i = 0; i < numel; i++) {
-          dalpha_ptr[0] += out_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
+          dalpha_ptr[0] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i];
         }
       }
     }
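
The CPU "element" path now indexes alpha with `i % (numel / dim[0])`, i.e. a single alpha tensor of shape dim[1:] shared by every sample in the batch. A quick NumPy check of that indexing (illustrative shapes):

```python
import numpy as np

n, c, h, w = 2, 3, 4, 5
x = np.random.randn(n, c, h, w).astype("float32")
alpha = np.random.rand(c, h, w).astype("float32")   # shared by every sample

temp = x.size // n                                  # numel / dim[0]
xf, af = x.ravel(), alpha.ravel()
out = np.empty_like(xf)
for i in range(x.size):
    index = i % temp                                # wraps once per sample
    out[i] = xf[i] if xf[i] > 0 else af[index] * xf[i]

ref = np.where(x > 0, x, alpha[None, ...] * x)      # broadcast over batch
assert np.allclose(out.reshape(x.shape), ref)
```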
diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h
index 876118245f1ab63de41f7d87db8d3ce4eeea57ba..7dc78270c21866866851a396c234bd6564064be4 100644
--- a/paddle/fluid/operators/reduce_ops/cub_reduce.h
+++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -66,6 +66,7 @@ __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
   Ty reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
     reduce_var = reducer(reduce_var, transformer(x[idx_x + idx_y]));
+  __syncthreads();
 
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
@@ -113,6 +114,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
     for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
     reduce_var = static_cast<Ty>(reducer(reduce_var, transformer(x[idx_x])));
   }
+  __syncthreads();
 
   reduce_var =
       cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index ebc60c27c45ab15f205a19db4e8fac2d723d9b53..dad3ec305543f8ac3496916a9d45edf28b3ab049 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -10034,14 +10034,14 @@ def prelu(x, mode, param_attr=None, name=None):
     if mode == 'channel':
         alpha_shape = [1, x.shape[1], 1, 1]
     elif mode == 'element':
-        alpha_shape = x.shape
+        alpha_shape = x.shape[1:]
     dtype = helper.input_dtype(input_param_name='x')
     alpha = helper.create_parameter(
         attr=helper.param_attr,
         shape=alpha_shape,
         dtype='float32',
         is_bias=False,
-        default_initializer=Constant(1.0))
+        default_initializer=Constant(0.25))
     out = helper.create_variable_for_type_inference(dtype)
     helper.append_op(
         type="prelu",
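
A minimal usage sketch of the updated Python layer, assuming the 1.x fluid graph-building API that this file belongs to; with mode='element' the alpha parameter now has shape x.shape[1:] (shared over the batch) and is initialized to 0.25:

```python
import paddle.fluid as fluid

# Illustrative network input; any NCHW-shaped variable works the same way.
x = fluid.layers.data(name="x", shape=[3, 32, 32], dtype="float32")

# "element": one alpha per (C, H, W) position, shared across the batch.
y_element = fluid.layers.prelu(x, mode="element")

# "channel": one alpha per channel; "all": a single scalar alpha.
y_channel = fluid.layers.prelu(x, mode="channel")
```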
diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py
index 48a6b0577b6787d2e1231fdcbe6d2c1bb46414ed..190fa0f42aef47f5ed67ecdf3a6553d6ba35334d 100644
--- a/python/paddle/fluid/tests/unittests/test_prelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py
@@ -37,7 +37,8 @@ class PReluTest(OpTest):
             alpha_np = np.random.rand(1, x_np.shape[1], 1, 1).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
         else:
-            alpha_np = np.random.rand(*x_np.shape).astype("float32")
+            alpha_np = np.random.rand(1, x_np.shape[1], x_np.shape[2], \
+                x_np.shape[3]).astype("float32")
             self.inputs = {'X': x_np, 'Alpha': alpha_np}
 
         out_np = np.maximum(self.inputs['X'], 0.)