From 05163e1d325ae4c3791b80233fceaa7453c79bd9 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Wed, 1 Jul 2020 10:40:02 +0800
Subject: [PATCH] fix bug of prelu when rank not equal 4, test=develop (#25067)
 (#25235)

* fix bug of prelu when rank not equal 4, test=develop

* fix prelu inference, test=develop

* fix api, test=develop

* fix shape when mode is channel, test=develop

* remove debug code, test=develop

* add unittest, test=develop
---
 .../tensorrt/plugin/prelu_op_plugin.cu         | 27 +++----
 paddle/fluid/operators/math/prelu.cu           | 36 ++++------
 paddle/fluid/operators/math/prelu.h            |  7 +-
 paddle/fluid/operators/prelu_op.cc             | 14 ++++
 paddle/fluid/operators/prelu_op.cu             | 32 +++++----
 python/paddle/fluid/dygraph/nn.py              |  5 +-
 python/paddle/fluid/layers/nn.py               | 14 +++-
 .../fluid/tests/unittests/test_prelu_op.py     | 72 +++++++++++++++++--
 8 files changed, 146 insertions(+), 61 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
index 1bde3c16d0..f1e11b6fba 100644
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu
@@ -13,8 +13,10 @@
 // limitations under the License.
 
 #include <stdio.h>
+
 #include <cassert>
 #include <vector>
+
 #include "glog/logging.h"
 #include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
@@ -55,24 +57,23 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs,
   // const float *alpha = reinterpret_cast<const float *>(alpha_.get().values);
   const float *alpha = p_gpu_weight_;
   float *output = reinterpret_cast<float **>(outputs)[0];
-
-  std::vector<int> input_shape;
-  input_shape.push_back(batch_size);
+  int numel = 1;
   for (int i = 0; i < input_dims.nbDims; i++) {
-    input_shape.push_back(input_dims.d[i]);
+    numel *= input_dims.d[i];
   }
 
   if (mode_ == "channel") {
     operators::math::PreluChannelWiseDirectCUDAFunctor<float>
         prelu_channel_wise;
-    prelu_channel_wise(stream, input, alpha, output, input_shape);
+    prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
+                       input_dims.d[1], numel);
   } else if (mode_ == "element") {
     operators::math::PreluElementWiseDirectCUDAFunctor<float>
         prelu_element_wise;
-    prelu_element_wise(stream, input, alpha, output, input_shape);
+    prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
   } else {
     operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
-    prelu_scalar(stream, input, alpha, output, input_shape);
+    prelu_scalar(stream, input, alpha, output, numel);
   }
   return cudaGetLastError() != cudaSuccess;
 }
@@ -133,23 +134,23 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
   const float *alpha = p_gpu_weight_;
   const float *input = static_cast<const float *>(inputs[0]);
   float *output = static_cast<float *>(outputs[0]);
-
-  std::vector<int> input_shape;
+  int numel = 1;
   for (int i = 0; i < input_dims.nbDims; i++) {
-    input_shape.push_back(input_dims.d[i]);
+    numel *= input_dims.d[i];
   }
 
   if (mode_ == "channel") {
     operators::math::PreluChannelWiseDirectCUDAFunctor<float>
         prelu_channel_wise;
-    prelu_channel_wise(stream, input, alpha, output, input_shape);
+    prelu_channel_wise(stream, input, alpha, output, input_dims.d[0],
+                       input_dims.d[1], numel);
   } else if (mode_ == "element") {
     operators::math::PreluElementWiseDirectCUDAFunctor<float>
         prelu_element_wise;
-    prelu_element_wise(stream, input, alpha, output, input_shape);
+    prelu_element_wise(stream, input, alpha, output, input_dims.d[0], numel);
   } else {
     operators::math::PreluScalarDirectCUDAFunctor<float> prelu_scalar;
-    prelu_scalar(stream, input, alpha, output, input_shape);
+    prelu_scalar(stream, input, alpha, output, numel);
   }
   return cudaGetLastError() != cudaSuccess;
 }
diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu
index 7586e8458a..af2996a4ac 100644
--- a/paddle/fluid/operators/math/prelu.cu
+++ b/paddle/fluid/operators/math/prelu.cu
@@ -21,8 +21,8 @@ namespace math {
 #define CUDA_NUM_THREADS 1024
 
 // CUDA: grid stride looping
-#define CUDA_KERNEL_LOOP(i, n)                                 \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+#define CUDA_KERNEL_LOOP(i, n)                                    \
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
        i += blockDim.x * gridDim.x)
 
 inline static int PADDLE_GET_BLOCKS(const int N) {
@@ -33,7 +33,6 @@ template <typename T>
 __global__ void PReluChannelWiseKernel(const T *input, const T *alpha,
                                        T *output, size_t channel_num,
                                        size_t plane_size, size_t numel) {
-  size_t index;
   CUDA_KERNEL_LOOP(index, numel) {
     size_t temp = index / plane_size;
     size_t channel_index = temp % channel_num;
@@ -47,7 +46,6 @@ template <typename T>
 __global__ void PReluElementWiseKernel(const T *input, const T *alpha,
                                        T *output, size_t spatial_size,
                                        size_t numel) {
-  size_t index;
   CUDA_KERNEL_LOOP(index, numel) {
     size_t element_index = index % spatial_size;
     T scale = alpha[element_index];
@@ -60,7 +58,6 @@ template <typename T>
 __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
                                   size_t numel) {
   T scale = alpha[0];
-  size_t index;
   CUDA_KERNEL_LOOP(index, numel) {
     T x = input[index];
     output[index] = (x > 0) ? x : scale * x;
@@ -70,34 +67,27 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output,
 template <typename T>
 void PreluChannelWiseDirectCUDAFunctor<T>::operator()(
     cudaStream_t stream, const T *input, const T *alpha, T *output,
-    std::vector<int> input_shape) {
-  size_t plane_size = input_shape[2] * input_shape[3];
-  size_t spatial_size = input_shape[1] * plane_size;
-  size_t numel = input_shape[0] * spatial_size;
+    size_t batch_size, size_t channel, size_t numel) {
   PReluChannelWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
-                           stream>>>(input, alpha, output, input_shape[1],
-                                     plane_size, numel);
+                           stream>>>(input, alpha, output, channel,
+                                     numel / batch_size / channel, numel);
 }
 
 template <typename T>
-void PreluElementWiseDirectCUDAFunctor<T>::operator()(
-    cudaStream_t stream, const T *input, const T *alpha, T *output,
-    std::vector<int> input_shape) {
-  size_t plane_size = input_shape[2] * input_shape[3];
-  size_t spatial_size = input_shape[1] * plane_size;
-  size_t numel = input_shape[0] * spatial_size;
+void PreluElementWiseDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
+                                                      const T *input,
+                                                      const T *alpha, T *output,
+                                                      size_t batch_size,
+                                                      size_t numel) {
   PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
-                           stream>>>(input, alpha, output, spatial_size, numel);
+                           stream>>>(input, alpha, output, numel / batch_size,
+                                     numel);
 }
 
 template <typename T>
 void PreluScalarDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
                                                  const T *input, const T *alpha,
-                                                 T *output,
-                                                 std::vector<int> input_shape) {
-  size_t plane_size = input_shape[2] * input_shape[3];
-  size_t spatial_size = input_shape[1] * plane_size;
-  size_t numel = input_shape[0] * spatial_size;
+                                                 T *output, size_t numel) {
   PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
       input, alpha, output, numel);
 }
diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h
index c57aa6caab..93c7035d44 100644
--- a/paddle/fluid/operators/math/prelu.h
+++ b/paddle/fluid/operators/math/prelu.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <vector>
+
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 
@@ -26,21 +27,21 @@ template <typename T>
 class PreluChannelWiseDirectCUDAFunctor {
  public:
   void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, std::vector<int> input_shape);
+                  T *output, size_t batch_size, size_t channel, size_t numel);
 };
 
 template <typename T>
 class PreluElementWiseDirectCUDAFunctor {
  public:
   void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, std::vector<int> input_shape);
+                  T *output, size_t batch_size, size_t numel);
 };
 
 template <typename T>
 class PreluScalarDirectCUDAFunctor {
  public:
   void operator()(cudaStream_t stream, const T *input, const T *alpha,
-                  T *output, std::vector<int> input_shape);
+                  T *output, size_t numel);
 };
 
 #endif
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index c822c4b789..8a18843a97 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/prelu_op.h"
+
 #include <memory>
 #include <string>
 
@@ -43,10 +44,23 @@ class PReluOp : public framework::OperatorWithKernel {
                             "equal to the number of channels of input(x). But "
                             "recevied alpha's size: %d, x_dim[1]: %d",
                             product(ctx->GetInputDim("Alpha")), x_dim[1]));
+      auto x_rank = x_dim.size();
+      PADDLE_ENFORCE_GE(x_rank, 2,
+                        platform::errors::InvalidArgument(
+                            "For mode 'channel', rank of input X must be "
+                            "equal or larger than 2. But received X's "
+                            "rank: %d",
+                            x_rank));
     } else if (mode == "element") {
       auto alpha_dim = ctx->GetInputDim("Alpha");
       auto alpha_rank = alpha_dim.size();
       auto x_rank = x_dim.size();
+      PADDLE_ENFORCE_GE(x_rank, 1,
+                        platform::errors::InvalidArgument(
+                            "For mode 'element', rank of input X must be "
+                            "equal or larger than 1. But received X's "
+                            "rank: %d",
+                            x_rank));
       PADDLE_ENFORCE_EQ(
           alpha_rank, x_rank,
           platform::errors::InvalidArgument(
diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu
index 7e46a23483..2e51b00b98 100644
--- a/paddle/fluid/operators/prelu_op.cu
+++ b/paddle/fluid/operators/prelu_op.cu
@@ -11,6 +11,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
+
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/prelu.h"
 #include "paddle/fluid/operators/prelu_op.h"
@@ -49,20 +50,22 @@ class CUDAPReluKernel : public framework::OpKernel<T> {
 
     int numel = x->numel();
     auto dim = x->dims();
-    std::vector<int> input_shape = framework::vectorize<int>(dim);
+
+    VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1]
+            << ", numel:" << numel;
 
     if (mode == "channel") {
       math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
       prelu_channel_wise(context.cuda_device_context().stream(), x_ptr,
-                         alpha_ptr, o_ptr, input_shape);
+                         alpha_ptr, o_ptr, dim[0], dim[1], numel);
     } else if (mode == "element") {
       math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
       prelu_element_wise(context.cuda_device_context().stream(), x_ptr,
-                         alpha_ptr, o_ptr, input_shape);
+                         alpha_ptr, o_ptr, dim[0], numel);
     } else {
       math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
       prelu_scalar(context.cuda_device_context().stream(), x_ptr, alpha_ptr,
-                   o_ptr, input_shape);
+                   o_ptr, numel);
     }
   }
 };
@@ -75,7 +78,6 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
                                   size_t channel_num, size_t plane_size,
                                   size_t spatial_size, size_t numel,
                                   PRELU_MODE mode) {
-  size_t index;
   CUDA_KERNEL_LOOP(index, numel) {
     T scale;
     if (mode == Element) {
@@ -99,14 +101,18 @@ template <typename T>
 class PreluOpGradFunctor {
  public:
   void operator()(cudaStream_t stream, const T* x, const T* alpha, const T* dy,
-                  T* dx, T* dalpha, std::vector<int> input_shape,
+                  T* dx, T* dalpha, const framework::DDim& input_dims,
                   PRELU_MODE mode) {
-    size_t plane_size = input_shape[2] * input_shape[3];
-    size_t spatial_size = plane_size * input_shape[1];
-    size_t numel = spatial_size * input_shape[0];
+    size_t numel = 1;
+    for (size_t i = 0; i < input_dims.size(); ++i) {
+      numel *= input_dims[i];
+    }
+    size_t plane_size = numel / input_dims[0] / input_dims[1];
+    size_t spatial_size = numel / input_dims[0];
+
     PReluOpGradKernel<
         T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
-        x, alpha, dy, dx, dalpha, input_shape[1], plane_size, spatial_size,
+        x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size,
         numel, mode);
   }
 };
@@ -161,13 +167,13 @@ class CUDAPReluGradKernel : public framework::OpKernel<T> {
       m = Scalar;
     }
     PreluOpGradFunctor<T> prelu_grad;
-    prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr,
-               input_shape, m);
+    prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr, dim,
+               m);
     if (dalpha_tmp_ptr == nullptr) return;
 
     std::vector<int> reduce_dims;
-    for (size_t i = 0; i < input_shape.size(); i++) {
+    for (size_t i = 0; i < dim.size(); i++) {
       if (mode == "channel" && i == 1) continue;
       if (mode == "element" && i != 0) continue;
       reduce_dims.push_back(i);
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 6b79b1564d..a705d2a4b2 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -2262,7 +2262,10 @@ class PRelu(layers.Layer):
             assert isinstance(
                 channel,
                 int), "channel argument is required when mode is 'channel'."
-            self._alpha_shape = [1, channel, 1, 1]
+            # NOTE(zhiqiu): The _alpha_shape should be [1, channel] + [1] * len(input_shape[2:]),
+            # not [1, channel, 1, 1]. The trailing 1s are useless, since the kernel treats the
+            # tensor as a one-dimensional array, and input_shape is not required in 'channel' mode.
+            self._alpha_shape = [1, channel]
         elif mode == 'element':
             assert isinstance(input_shape, (
                 list, tuple
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 029dc16099..05cdacb0cd 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -10661,10 +10661,20 @@ def prelu(x, mode, param_attr=None, name=None):
     if mode not in ['all', 'channel', 'element']:
         raise ValueError('mode should be one of all, channel, element.')
     alpha_shape = [1]
+    # NOTE(): The input of this API should be in ``N, C, ...`` format,
+    # which means x.shape[0] is batch_size and x.shape[1] is channel.
     if mode == 'channel':
-        alpha_shape = [1, x.shape[1], 1, 1]
+        assert len(
+            x.shape
+        ) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'"
+        # NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]).
+        # To be consistent with PRelu, it is simplified to [1, channel].
+        alpha_shape = [1, x.shape[1]]
     elif mode == 'element':
-        alpha_shape = [1, x.shape[1], x.shape[2], x.shape[3]]
+        assert len(
+            x.shape
+        ) >= 1, "The size of input shape should be equal or larger than 1 in prelu() when mode is 'element'"
+        alpha_shape = [1] + list(x.shape)[1:]
     dtype = helper.input_dtype(input_param_name='x')
     alpha = helper.create_parameter(
         attr=helper.param_attr,
diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py
index 2f44fc44b5..398ad9aa69 100644
--- a/python/paddle/fluid/tests/unittests/test_prelu_op.py
+++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py
@@ -51,12 +51,18 @@ class PReluTest(OpTest):
         if self.attrs == {'mode': "all"}:
             alpha_np = np.random.uniform(-1, -0.5, (1))
         elif self.attrs == {'mode': "channel"}:
-            alpha_np = np.random.uniform(-1, -0.5, (1, x_np.shape[1], 1, 1))
+            alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1]])
         else:
-            alpha_np = np.random.uniform(-1, -0.5, \
-                (1, x_np.shape[1], x_np.shape[2], x_np.shape[3]))
+            alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
+
         self.inputs = {'X': x_np, 'Alpha': alpha_np}
 
+        # NOTE(zhiqiu): Reshape inputs['Alpha'] from [1, 100] to [1, 100, 1, 1] since np
+        # operands could not be broadcast together with shapes (2, 100, 3, 4) and (1, 100).
+        if self.attrs == {'mode': "channel"}:
+            self.inputs['Alpha'] = np.reshape(
+                self.inputs['Alpha'],
+                [1, self.x_shape[1]] + [1] * len(self.x_shape[2:]))
+
         out_np = np.maximum(self.inputs['X'], 0.)
         out_np = out_np + np.minimum(self.inputs['X'],
                                      0.) * self.inputs['Alpha']
@@ -64,7 +70,7 @@
         self.outputs = {'Out': out_np}
 
     def init_input_shape(self):
-        self.x_shape = (2, 100, 3, 4)
+        self.x_shape = [2, 100, 3, 4]
 
     def init_attr(self):
         self.attrs = {'mode': "channel"}
@@ -81,7 +87,7 @@
 )
 class TestModeAll(PReluTest):
     def init_input_shape(self):
-        self.x_shape = (2, 3, 4, 5)
+        self.x_shape = [2, 3, 4, 5]
 
     def init_attr(self):
         self.attrs = {'mode': "all"}
@@ -89,7 +95,61 @@ class TestModeAll(PReluTest):
 
 class TestModeElt(PReluTest):
     def init_input_shape(self):
-        self.x_shape = (3, 2, 5, 10)
+        self.x_shape = [3, 2, 5, 10]
+
+    def init_attr(self):
+        self.attrs = {'mode': "element"}
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+)
+class TestModeAllRank3(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 200, 3]
+
+    def init_attr(self):
+        self.attrs = {'mode': "all"}
+
+
+@skip_check_grad_ci(
+    reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode"
+)
+class TestModeAllRank6(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 2, 3, 4, 5, 6]
+
+    def init_attr(self):
+        self.attrs = {'mode': "all"}
+
+
+class TestModeChannelRank3(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 200, 3]
+
+    def init_attr(self):
+        self.attrs = {'mode': "channel"}
+
+
+class TestModeChannelRank6(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [1, 100, 2, 2, 2, 2]
+
+    def init_attr(self):
+        self.attrs = {'mode': "channel"}
+
+
+class TestModeElementRank3(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [3, 10, 10]
+
+    def init_attr(self):
+        self.attrs = {'mode': "element"}
+
+
+class TestModeElementRank6(PReluTest):
+    def init_input_shape(self):
+        self.x_shape = [3, 2, 2, 4, 5, 2]
 
     def init_attr(self):
         self.attrs = {'mode': "element"}
-- 
GitLab
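
A minimal usage sketch of the behavior this patch enables (illustrative only, not part of
the patch; program and variable names are arbitrary): with the fix, 'channel'-mode prelu
accepts inputs whose rank is not 4, e.g. a rank-3 [batch, channel, length] tensor, and the
alpha parameter is created with shape [1, channel] instead of [1, channel, 1, 1].

    import numpy as np
    import paddle.fluid as fluid

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # Rank-3 input in N, C, ... format; before this patch, 'channel' mode
        # assumed a rank-4 NCHW input when building the alpha parameter.
        x = fluid.data(name='x', shape=[1, 200, 3], dtype='float32')
        out = fluid.layers.prelu(x, mode='channel')  # alpha shape: [1, 200]

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    x_np = np.random.uniform(-1, 1, [1, 200, 3]).astype('float32')
    out_np = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])[0]
    print(out_np.shape)  # (1, 200, 3)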