From 3f2a665a0a23a0b0e0d472ef0699322e5e80a9a8 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 30 Nov 2021 13:15:43 +0800 Subject: [PATCH] support data_format='NHWC' for prelu channel mode (#37019) * support data_format='NHWC' for prelu channel mode --- .../inference/tensorrt/convert/prelu_op.cc | 11 +- .../tensorrt/plugin/prelu_op_plugin.cu | 6 +- .../tensorrt/plugin/prelu_op_plugin.h | 22 ++- paddle/fluid/operators/math/prelu.cu | 33 +++- paddle/fluid/operators/math/prelu.h | 3 +- .../fluid/operators/mkldnn/prelu_mkldnn_op.cc | 19 ++- paddle/fluid/operators/prelu_op.cc | 36 ++++- paddle/fluid/operators/prelu_op.cu | 33 ++-- paddle/fluid/operators/prelu_op.h | 68 +++++--- python/paddle/fluid/layers/nn.py | 29 +++- .../ir/inference/test_mkldnn_prelu_op.py | 14 +- .../ir/inference/test_trt_convert_prelu.py | 83 ++++++---- .../tests/unittests/test_imperative_layers.py | 5 +- .../fluid/tests/unittests/test_prelu_op.py | 145 ++++++++++++++++-- python/paddle/nn/functional/activation.py | 32 +++- python/paddle/nn/layer/activation.py | 16 +- 16 files changed, 425 insertions(+), 130 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index 94f5708e03..a883d2b5bb 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -34,6 +34,11 @@ class PReluOpConverter : public OpConverter { auto* input = engine_->GetITensor(op_desc.Input("X")[0]); // Get attrs std::string mode = BOOST_GET_CONST(std::string, op_desc.GetAttr("mode")); + std::string data_format = "NCHW"; + if (op_desc.HasAttr("data_format")) { + data_format = + BOOST_GET_CONST(std::string, op_desc.GetAttr("data_format")); + } auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]); auto* alpha_tensor = alpha_var->GetMutable(); @@ -47,7 +52,7 @@ class PReluOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic( - alpha_data, alpha_tensor_temp->numel(), mode); + alpha_data, alpha_tensor_temp->numel(), mode, data_format); layer = engine_->AddDynamicPlugin(&input, input_num, plugin); } else { #if IS_TRT_VERSION_GE(7000) @@ -74,8 +79,8 @@ class PReluOpConverter : public OpConverter { layer = TRT_ENGINE_ADD_LAYER(engine_, ParametricReLU, *input, *alpha_layer_output); #else - plugin::PReluPlugin* plugin = - new plugin::PReluPlugin(alpha_data, alpha_tensor_temp->numel(), mode); + plugin::PReluPlugin* plugin = new plugin::PReluPlugin( + alpha_data, alpha_tensor_temp->numel(), mode, data_format); layer = engine_->AddPlugin(&input, input_num, plugin); #endif } diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index 5533fb0af3..1ea2b8b5f6 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -69,10 +69,11 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, } if (mode_ == "channel") { + bool channel_last = data_format_ == "NHWC"; operators::math::PreluChannelWiseDirectCUDAFunctor prelu_channel_wise; prelu_channel_wise(stream, input, alpha, output, input_dims.d[0], - input_dims.d[1], numel); + input_dims.d[1], channel_last, numel); } else if (mode_ == "element") { operators::math::PreluElementWiseDirectCUDAFunctor prelu_element_wise; @@ -168,10 +169,11 @@ int PReluPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, } if (mode_ == "channel") { + bool channel_last = data_format_ == "NHWC"; operators::math::PreluChannelWiseDirectCUDAFunctor prelu_channel_wise; prelu_channel_wise(stream, input, alpha, output, input_dims.d[0], - input_dims.d[1], numel); + input_dims.d[1], channel_last, numel); } else if (mode_ == "element") { operators::math::PreluElementWiseDirectCUDAFunctor prelu_element_wise; diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index c61b07e22d..e0a77de6f5 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -32,11 +32,12 @@ class PReluPlugin : public PluginTensorRT { std::vector weight_; float* p_gpu_weight_; std::string mode_; + std::string data_format_; public: size_t getSerializationSize() const TRT_NOEXCEPT override { return getBaseSerializationSize() + SerializedSize(mode_.c_str()) + - SerializedSize(weight_); + SerializedSize(data_format_.c_str()) + SerializedSize(weight_); } // TRT will call this func when we need to serialize the configuration of @@ -46,11 +47,12 @@ class PReluPlugin : public PluginTensorRT { serializeBase(buffer); SerializeValue(&buffer, weight_); SerializeValue(&buffer, mode_.c_str()); + SerializeValue(&buffer, data_format_.c_str()); } PReluPlugin(const float* weight, const int weight_num, - std::string const& mode) - : mode_(mode) { + std::string const& mode, std::string const& data_format) + : mode_(mode), data_format_(data_format) { weight_.resize(weight_num); std::copy(weight, weight + weight_num, weight_.data()); } @@ -63,13 +65,17 @@ class PReluPlugin : public PluginTensorRT { const char* prelu_mode; DeserializeValue(&serialData, &serialLength, &prelu_mode); mode_ = std::string(prelu_mode); + const char* prelu_data_format; + DeserializeValue(&serialData, &serialLength, &prelu_data_format); + data_format_ = std::string(prelu_data_format); } ~PReluPlugin() {} int initialize() TRT_NOEXCEPT override; void terminate() TRT_NOEXCEPT override; PReluPlugin* clone() const TRT_NOEXCEPT override { - auto* ptr = new PReluPlugin(weight_.data(), weight_.size(), mode_); + auto* ptr = + new PReluPlugin(weight_.data(), weight_.size(), mode_, data_format_); ptr->p_gpu_weight_ = p_gpu_weight_; return ptr; } @@ -108,8 +114,8 @@ REGISTER_TRT_PLUGIN_V2(PReluPluginCreator); class PReluPluginDynamic : public DynamicPluginTensorRT { public: PReluPluginDynamic(const float* weight, const int weight_num, - std::string const& mode) - : mode_(mode) { + std::string const& mode, std::string const& data_format) + : mode_(mode), data_format_(data_format) { weight_.resize(weight_num); std::copy(weight, weight + weight_num, weight_.data()); } @@ -117,7 +123,8 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { PReluPluginDynamic(void const* serialData, size_t serialLength); ~PReluPluginDynamic() {} nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { - auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_, + data_format_); ptr->p_gpu_weight_ = p_gpu_weight_; return ptr; } @@ -167,6 +174,7 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { std::vector weight_; float* p_gpu_weight_; std::string mode_; + std::string data_format_; }; #endif diff --git a/paddle/fluid/operators/math/prelu.cu b/paddle/fluid/operators/math/prelu.cu index 7c93d1725e..d06490ee57 100644 --- a/paddle/fluid/operators/math/prelu.cu +++ b/paddle/fluid/operators/math/prelu.cu @@ -25,9 +25,9 @@ inline static int PADDLE_GET_BLOCKS(const int N) { } template -__global__ void PReluChannelWiseKernel(const T *input, const T *alpha, - T *output, size_t channel_num, - size_t plane_size, size_t numel) { +__global__ void PReluChannelFirstWiseKernel(const T *input, const T *alpha, + T *output, size_t channel_num, + size_t plane_size, size_t numel) { CUDA_KERNEL_LOOP(index, numel) { size_t temp = index / plane_size; size_t channel_index = temp % channel_num; @@ -38,6 +38,19 @@ __global__ void PReluChannelWiseKernel(const T *input, const T *alpha, } } +template +__global__ void PReluChannelLastWiseKernel(const T *input, const T *alpha, + T *output, size_t channel_num, + size_t numel) { + CUDA_KERNEL_LOOP(index, numel) { + size_t channel_index = index % channel_num; + T scale = alpha[channel_index]; + T x = input[index]; + T zero = static_cast(0); + output[index] = (x > zero) ? x : scale * x; + } +} + template __global__ void PReluElementWiseKernel(const T *input, const T *alpha, T *output, size_t spatial_size, @@ -65,10 +78,16 @@ __global__ void PReluScalarKernel(const T *input, const T *alpha, T *output, template void PreluChannelWiseDirectCUDAFunctor::operator()( gpuStream_t stream, const T *input, const T *alpha, T *output, - size_t batch_size, size_t channel, size_t numel) { - PReluChannelWiseKernel<<>>(input, alpha, output, channel, - numel / batch_size / channel, numel); + size_t batch_size, size_t channel, bool channel_last, size_t numel) { + if (channel_last) { + PReluChannelLastWiseKernel<<>>(input, alpha, output, channel, + numel); + } else { + PReluChannelFirstWiseKernel<<>>( + input, alpha, output, channel, numel / batch_size / channel, numel); + } } template diff --git a/paddle/fluid/operators/math/prelu.h b/paddle/fluid/operators/math/prelu.h index efa493a06c..dc1e3c1c3d 100644 --- a/paddle/fluid/operators/math/prelu.h +++ b/paddle/fluid/operators/math/prelu.h @@ -31,7 +31,8 @@ template class PreluChannelWiseDirectCUDAFunctor { public: void operator()(gpuStream_t stream, const T *input, const T *alpha, T *output, - size_t batch_size, size_t channel, size_t numel); + size_t batch_size, size_t channel, bool channel_last, + size_t numel); }; template diff --git a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc index 8296b4739d..8c7113d963 100644 --- a/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/prelu_mkldnn_op.cc @@ -34,7 +34,7 @@ class PReluMKLDNNHandler const dnnl::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* weights, const std::string& uniq_name, const std::string& mode, - bool is_test = false) + const std::string& data_format, bool is_test = false) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -49,8 +49,13 @@ class PReluMKLDNNHandler if (weights->dims().size() != x->dims().size()) { auto new_weights_dims = std::vector(x->dims().size(), 1); if (mode == "channel") { - new_weights_dims[1] = - *std::max_element(weights_dims.begin(), weights_dims.end()); + if (data_format == "NHWC") { + new_weights_dims[x->dims().size() - 1] = + *std::max_element(weights_dims.begin(), weights_dims.end()); + } else { + new_weights_dims[1] = + *std::max_element(weights_dims.begin(), weights_dims.end()); + } } weights_dims = std::move(new_weights_dims); } @@ -110,9 +115,11 @@ class PReluMKLDNNKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); const bool is_test = ctx.Attr("is_test"); const auto mode = ctx.Attr("mode"); + const auto data_format = ctx.Attr("data_format"); PReluMKLDNNHandler handler(dev_ctx, onednn_engine, ctx.GetPlace(), x, - alpha, ctx.InputName("X"), mode, is_test); + alpha, ctx.InputName("X"), mode, data_format, + is_test); auto src_memory_p = handler.AcquireSrcMemory(x); auto weights_memory_p = @@ -149,9 +156,11 @@ class PReluGradMKLDNNKernel : public framework::OpKernel { auto* alpha = ctx.Input("Alpha"); const bool is_test = ctx.Attr("is_test"); const auto mode = ctx.Attr("mode"); + const auto data_format = ctx.Attr("data_format"); PReluMKLDNNHandler handler(dev_ctx, onednn_engine, ctx.GetPlace(), x, - alpha, framework::GradVarName("X"), mode); + alpha, framework::GradVarName("X"), mode, + data_format); auto src_memory_p = handler.AcquireSrcMemory(x); auto weights_memory_p = diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 63d12f790f..bf8651cd5f 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -38,12 +38,6 @@ class PReluOp : public framework::OperatorWithKernel { "But recevied alpha's size: %d.", product(ctx->GetInputDim("Alpha")))); } else if (mode == "channel") { - PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1], - platform::errors::InvalidArgument( - "For mode 'channel', size of weight Alpha must be " - "equal to the number of channels of input(x). But " - "recevied alpha's size: %d, x_dim[1]: %d", - product(ctx->GetInputDim("Alpha")), x_dim[1])); auto x_rank = x_dim.size(); PADDLE_ENFORCE_GE(x_rank, 2, platform::errors::InvalidArgument( @@ -51,6 +45,33 @@ class PReluOp : public framework::OperatorWithKernel { "equal or larger than 2. But recevied X's " "rank: %d", x_rank)); + const std::string data_format_str = + ctx->Attrs().Get("data_format"); + PADDLE_ENFORCE_EQ(data_format_str == "NCHW" || data_format_str == "NHWC", + true, + platform::errors::InvalidArgument( + "For mode 'channel', data_format must be one of " + "NCHW and NHWC. But recevied data_format: %s", + data_format_str)); + if (data_format_str == "NCHW") { + PADDLE_ENFORCE_EQ( + product(ctx->GetInputDim("Alpha")) == x_dim[1], true, + platform::errors::InvalidArgument( + "For mode 'channel', size of weight Alpha must be " + "equal to the number of channels of input(x). But " + "recevied alpha's size: %d, x_dim[1]: %d", + product(ctx->GetInputDim("Alpha")), x_dim[1])); + } else { + PADDLE_ENFORCE_EQ( + product(ctx->GetInputDim("Alpha")) == x_dim[x_rank - 1], true, + platform::errors::InvalidArgument( + "For mode 'channel', size of weight Alpha must be " + "equal to the number of channels of input(x). But " + "recevied alpha's size: %d, x_dim[%d]: %d", + product(ctx->GetInputDim("Alpha")), x_rank - 1, + x_dim[x_rank - 1])); + } + } else if (mode == "element") { auto alpha_dim = ctx->GetInputDim("Alpha"); auto alpha_rank = alpha_dim.size(); @@ -134,6 +155,9 @@ There are modes: )DOC"); AddAttr("mode", "The mode for inputs to share weights.") .SetDefault("all"); + AddAttr("data_format", + "Data format that specifies the layout of input") + .SetDefault("NCHW"); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false) diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu index 049217f2a9..ce3f5969ce 100644 --- a/paddle/fluid/operators/prelu_op.cu +++ b/paddle/fluid/operators/prelu_op.cu @@ -42,17 +42,22 @@ class CUDAPReluKernel : public framework::OpKernel { const T* alpha_ptr = alpha->data(); auto& mode = context.Attr("mode"); + auto& data_format = context.Attr("data_format"); int numel = x->numel(); auto dim = x->dims(); + auto x_rank = dim.size(); - VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] - << ", numel:" << numel; + VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim[" + << x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel; if (mode == "channel") { + bool channel_last = data_format == "NHWC"; + size_t channel = channel_last ? dim[x_rank - 1] : dim[1]; math::PreluChannelWiseDirectCUDAFunctor prelu_channel_wise; prelu_channel_wise(context.cuda_device_context().stream(), x_ptr, - alpha_ptr, o_ptr, dim[0], dim[1], numel); + alpha_ptr, o_ptr, dim[0], channel, channel_last, + numel); } else if (mode == "element") { math::PreluElementWiseDirectCUDAFunctor prelu_element_wise; prelu_element_wise(context.cuda_device_context().stream(), x_ptr, @@ -65,7 +70,7 @@ class CUDAPReluKernel : public framework::OpKernel { } }; -enum PRELU_MODE { Element, Channel, Scalar }; +enum PRELU_MODE { Element, ChannelFirst, ChannelLast, Scalar }; template __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, @@ -78,10 +83,13 @@ __global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr, if (mode == Element) { size_t element_index = index % spatial_size; scale = alpha_ptr[element_index]; - } else if (mode == Channel) { + } else if (mode == ChannelFirst) { size_t temp = index / plane_size; size_t channel_index = temp % channel_num; scale = alpha_ptr[channel_index]; + } else if (mode == ChannelLast) { + size_t channel_index = index % channel_num; + scale = alpha_ptr[channel_index]; } else { scale = alpha_ptr[0]; } @@ -105,11 +113,13 @@ class PreluOpGradFunctor { } size_t plane_size = numel / input_dims[0] / input_dims[1]; size_t spatial_size = numel / input_dims[0]; + size_t channel = + mode == ChannelLast ? input_dims[input_dims.size() - 1] : input_dims[1]; PReluOpGradKernel< T><<>>( - x, alpha, dy, dx, dalpha, input_dims[1], plane_size, spatial_size, - numel, mode); + x, alpha, dy, dx, dalpha, channel, plane_size, spatial_size, numel, + mode); } }; @@ -140,9 +150,11 @@ class CUDAPReluGradKernel : public framework::OpKernel { if (!dx && !dalpha) return; auto& mode = context.Attr("mode"); + auto& data_format = context.Attr("data_format"); int numel = x->numel(); auto dim = x->dims(); + auto x_rank = dim.size(); std::vector input_shape = framework::vectorize(dim); auto stream = context.cuda_device_context().stream(); @@ -157,10 +169,12 @@ class CUDAPReluGradKernel : public framework::OpKernel { } PRELU_MODE m; + bool channel_last = false; if (mode == "element") { m = Element; } else if (mode == "channel") { - m = Channel; + channel_last = data_format == "NHWC"; + m = channel_last ? ChannelLast : ChannelFirst; } else { m = Scalar; } @@ -172,7 +186,8 @@ class CUDAPReluGradKernel : public framework::OpKernel { std::vector reduce_dims; for (size_t i = 0; i < dim.size(); i++) { - if (mode == "channel" && i == 1) continue; + if (mode == "channel" && !channel_last && i == 1) continue; + if (mode == "channel" && channel_last && i == dim.size() - 1) continue; if (mode == "element" && i != 0) continue; reduce_dims.push_back(i); } diff --git a/paddle/fluid/operators/prelu_op.h b/paddle/fluid/operators/prelu_op.h index 60fd75ce3c..384994eb37 100644 --- a/paddle/fluid/operators/prelu_op.h +++ b/paddle/fluid/operators/prelu_op.h @@ -33,19 +33,27 @@ class PReluKernel : public framework::OpKernel { const T* alpha_ptr = alpha->data(); auto& mode = context.Attr("mode"); + auto& data_format = context.Attr("data_format"); int numel = x->numel(); auto dim = x->dims(); int index = 0; int i = 0; if (mode == "channel") { - int temp = 1; - for (int j = 2; j < dim.size(); j++) { - temp *= dim[j]; - } - for (i = 0; i < numel; i++) { - index = (i / temp) % dim[1]; - o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + if (data_format == "NCHW") { + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } + for (i = 0; i < numel; i++) { + index = (i / temp) % dim[1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + index = i % dim[dim.size() - 1]; + o_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : alpha_ptr[index] * x_ptr[i]; + } } } else if (mode == "element") { int temp = 1; @@ -77,6 +85,7 @@ class PReluGradKernel : public framework::OpKernel { const T* x_ptr = x->data(); const T* dout_ptr = dout->data(); std::string mode = context.Attr("mode"); + auto& data_format = context.Attr("data_format"); int numel = x->numel(); auto dim = x->dims(); int index = 0; @@ -84,14 +93,22 @@ class PReluGradKernel : public framework::OpKernel { if (dx) { T* dx_ptr = dx->mutable_data(context.GetPlace()); if (mode == "channel") { - int temp = 1; - for (int j = 2; j < dim.size(); j++) { - temp *= dim[j]; - } - for (i = 0; i < numel; i++) { - index = (i / temp) % dim[1]; - dx_ptr[i] = - x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + if (data_format == "NCHW") { + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } + for (i = 0; i < numel; i++) { + index = (i / temp) % dim[1]; + dx_ptr[i] = + x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + index = i % dim[dim.size() - 1]; + dx_ptr[i] = + x_ptr[i] > 0 ? dout_ptr[i] : alpha_ptr[index] * dout_ptr[i]; + } } } else if (mode == "element") { int temp = 1; @@ -116,13 +133,20 @@ class PReluGradKernel : public framework::OpKernel { memset(dalpha_ptr, 0, sizeof(T) * dalpha->numel()); if (mode == "channel") { - int temp = 1; - for (int j = 2; j < dim.size(); j++) { - temp *= dim[j]; - } - for (i = 0; i < numel; i++) { - index = (i / temp) % dim[1]; - dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + if (data_format == "NCHW") { + int temp = 1; + for (int j = 2; j < dim.size(); j++) { + temp *= dim[j]; + } + for (i = 0; i < numel; i++) { + index = (i / temp) % dim[1]; + dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } + } else { + for (i = 0; i < numel; i++) { + index = i % dim[dim.size() - 1]; + dalpha_ptr[index] += x_ptr[i] > 0 ? 0 : x_ptr[i] * dout_ptr[i]; + } } } else if (mode == "element") { int temp = 1; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 366d873504..c7fb75387a 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9791,7 +9791,7 @@ def swish(x, beta=1.0, name=None): @deprecated(since="2.0.0", update_to="paddle.static.nn.prelu") -def prelu(x, mode, param_attr=None, name=None): +def prelu(x, mode, param_attr=None, data_format="NCHW", name=None): r""" prelu activation. @@ -9818,6 +9818,9 @@ def prelu(x, mode, param_attr=None, name=None): name (str, optional): Name for the operation (optional, default is None). \ For more information, please refer to :ref:`api_guide_Name`. + + data_format(str, optional): Data format that specifies the layout of input. + It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". Returns: Tensor: A tensor with the same shape and data type as x. @@ -9839,17 +9842,32 @@ def prelu(x, mode, param_attr=None, name=None): helper = LayerHelper('prelu', **locals()) if mode not in ['all', 'channel', 'element']: raise ValueError('mode should be one of all, channel, element.') + alpha_shape = [1] - # NOTE(): The input of this API should be ``N,C,...`` format, - # which means x.shape[0] is batch_size and x.shape[0] is channel. if mode == 'channel': + + true_data_format = [ + 'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC' + ] + if data_format not in true_data_format: + raise ValueError( + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', " + "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format)) + + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' + assert len( x.shape ) >= 2, "The size of input shape should be equal or larger than 2 in prelu() when mode is 'channel'" #NOTE(zhiqiu): The alpha_shape should be [1, channel] + [1] * len(x.shape[2:]). # To be consistent with Prelu, it is simplified. #NOTE(zhiqiu): Revert shape to [1, channel, 1, 1] for compatibility with saved model of old version. - alpha_shape = [1, x.shape[1], 1, 1] + #NOTE(GuoxiaWang): support NHWC data format + if data_format == 'NHWC': + alpha_shape = [1, 1, 1, x.shape[1]] + else: + alpha_shape = [1, x.shape[1], 1, 1] + elif mode == 'element': assert len( x.shape @@ -9867,7 +9885,8 @@ def prelu(x, mode, param_attr=None, name=None): type="prelu", inputs={"X": x, 'Alpha': alpha}, - attrs={"mode": mode}, + attrs={"mode": mode, + "data_format": data_format}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py index 98c34a669c..3839c22ca2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_prelu_op.py @@ -44,8 +44,12 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): if len(kwargs['in_shape']) <= 1: # not valid case, just return 0 return np.zeros((1)).astype(np.float32) - return np.random.random(kwargs['in_shape'][1]).astype( - np.float32) + if kwargs['data_format'] == 'NCHW': + return np.random.random(kwargs['in_shape'][1]).astype( + np.float32) + else: + return np.random.random(kwargs['in_shape'][-1]).astype( + np.float32) else: if len(kwargs['in_shape']) <= 1: # not valid case, just return 0 @@ -57,7 +61,10 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): inputs={"X": ["input_data"], "Alpha": ["alpha_weight"]}, outputs={"Out": ["output_data"]}, - attrs={"mode": kwargs['mode']}) + attrs={ + "mode": kwargs['mode'], + "data_format": kwargs['data_format'] + }) program_config = ProgramConfig( ops=[prelu_op], @@ -82,6 +89,7 @@ class TestMkldnnPreluOp(MkldnnAutoScanTest): @given( mode=st.sampled_from(['all', 'channel', 'element']), + data_format=st.sampled_from(['NCHW', 'NHWC']), in_shape=st.lists( st.integers( min_value=1, max_value=32), min_size=1, max_size=4)) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py index 0bcbffb367..5153476ae1 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py @@ -39,7 +39,8 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3): if attrs[0]["mode"] == "all": return np.random.random(size=(1)).astype(np.float32) - elif attrs[0]["mode"] == "channel": + elif attrs[0]["mode"] == "channel" and attrs[0][ + "data_format"] == "NCHW": shape = [1] if dim1 != 0: shape.append(dim1) @@ -48,6 +49,16 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): if dim3 != 0: shape.append(1) return np.random.random(size=shape).astype(np.float32) + elif attrs[0]["mode"] == "channel" and attrs[0][ + "data_format"] == "NHWC": + shape = [1] + if dim1 != 0: + shape.append(1) + if dim2 != 0: + shape.append(1) + if dim3 != 0: + shape.append(dim3) + return np.random.random(size=shape).astype(np.float32) elif attrs[0]["mode"] == "element": shape = [1] if dim1 != 0: @@ -72,37 +83,45 @@ class TrtConvertPreluTest(TrtLayerAutoScanTest): continue for mode in ["all", "channel", "element"]: - if mode == "channel" and dim1 == 0: - continue - dics = [{"mode": mode}] - ops_config = [{ - "op_type": "prelu", - "op_inputs": { - "X": ["input_data"], - "Alpha": ["alpha_weight"] - }, - "op_outputs": { - "Out": ["output_data"] - }, - "op_attrs": dics[0] - }] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={ - "alpha_weight": TensorConfig( - data_gen=partial(generate_alpha, dics, - dim1, dim2, dim3)) - }, - inputs={ - "input_data": TensorConfig( - data_gen=partial(generate_input, batch, - dim1, dim2, dim3)), - }, - outputs=["output_data"]) - - yield program_config + for data_format in ['NCHW', 'NHWC']: + if mode == "channel" and dim1 == 0 and data_format == "NCHW": + continue + if mode == "channel" and dim3 == 0 and data_format == "NHWC": + continue + dics = [{ + "mode": mode, + "data_format": data_format + }] + ops_config = [{ + "op_type": "prelu", + "op_inputs": { + "X": ["input_data"], + "Alpha": ["alpha_weight"] + }, + "op_outputs": { + "Out": ["output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "alpha_weight": TensorConfig( + data_gen=partial(generate_alpha, + dics, dim1, dim2, + dim3)) + }, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input, + batch, dim1, dim2, + dim3)), + }, + outputs=["output_data"]) + + yield program_config def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_layers.py b/python/paddle/fluid/tests/unittests/test_imperative_layers.py index 8bb87198d8..f69ed7a817 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_layers.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_layers.py @@ -41,10 +41,11 @@ class TestLayerPrint(unittest.TestCase): self.assertEqual( str(module), 'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)') - module = nn.PReLU(1, 0.25, name="PReLU") + module = nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW") self.assertEqual( str(module), - 'PReLU(num_parameters=1, init=0.25, dtype=float32, name=PReLU)') + 'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)' + ) module = nn.ReLU() self.assertEqual(str(module), 'ReLU()') diff --git a/python/paddle/fluid/tests/unittests/test_prelu_op.py b/python/paddle/fluid/tests/unittests/test_prelu_op.py index 04862eba8a..6afc462322 100644 --- a/python/paddle/fluid/tests/unittests/test_prelu_op.py +++ b/python/paddle/fluid/tests/unittests/test_prelu_op.py @@ -163,10 +163,18 @@ class PReluTest(OpTest): # zero. x_np[np.abs(x_np) < 0.005] = 0.02 - if self.attrs == {'mode': "all"}: + if self.attrs == { + 'mode': "all", + "data_format": "NCHW" + } or self.attrs == { + 'mode': "all", + "data_format": "NHWC" + }: alpha_np = np.random.uniform(-1, -0.5, (1)) - elif self.attrs == {'mode': "channel"}: + elif self.attrs == {'mode': "channel", "data_format": "NCHW"}: alpha_np = np.random.uniform(-1, -0.5, [1, self.x_shape[1], 1, 1]) + elif self.attrs == {'mode': "channel", "data_format": "NHWC"}: + alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]]) else: alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:]) alpha_np = alpha_np.astype(self.dtype) @@ -176,11 +184,14 @@ class PReluTest(OpTest): # NOTE(zhiqu): reshape inputs['Alpha'] from [1, 100, 1, 1] to [1, 100] + [1]*len(x.shape[2:]) # since np operands could not be broadcast together with shapes (1,100,2,2,2,3) (1,100,1,1) reshaped_alpha = self.inputs['Alpha'] - if self.attrs == {'mode': "channel"}: + if self.attrs == {'mode': "channel", "data_format": "NCHW"}: reshaped_alpha = np.reshape( self.inputs['Alpha'], [1, self.x_shape[1]] + [1] * len(self.x_shape[2:])) - + elif self.attrs == {'mode': "channel", "data_format": "NHWC"}: + reshaped_alpha = np.reshape( + self.inputs['Alpha'], + [1] + [1] * len(self.x_shape[1:-1]) + [self.x_shape[-1]]) out_np = np.maximum(self.inputs['X'], 0.) out_np = out_np + np.minimum(self.inputs['X'], 0.) * reshaped_alpha assert out_np is not self.inputs['X'] @@ -193,7 +204,7 @@ class PReluTest(OpTest): self.x_shape = [2, 100, 3, 4] def init_attr(self): - self.attrs = {'mode': "channel"} + self.attrs = {'mode': "channel", "data_format": "NCHW"} def test_check_output(self): self.check_output() @@ -210,7 +221,18 @@ class TestModeAll(PReluTest): self.x_shape = [2, 3, 4, 5] def init_attr(self): - self.attrs = {'mode': "all"} + self.attrs = {'mode': "all", "data_format": "NCHW"} + + +@skip_check_grad_ci( + reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode" +) +class TestModeAllNHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [2, 3, 4, 50] + + def init_attr(self): + self.attrs = {'mode': "all", "data_format": "NHWC"} class TestModeElt(PReluTest): @@ -218,7 +240,15 @@ class TestModeElt(PReluTest): self.x_shape = [3, 2, 5, 10] def init_attr(self): - self.attrs = {'mode': "element"} + self.attrs = {'mode': "element", "data_format": "NCHW"} + + +class TestModeEltNHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [3, 2, 5, 10] + + def init_attr(self): + self.attrs = {'mode': "element", "data_format": "NHWC"} @skip_check_grad_ci( @@ -229,7 +259,18 @@ class TestModeAllRank3(PReluTest): self.x_shape = [1, 200, 3] def init_attr(self): - self.attrs = {'mode': "all"} + self.attrs = {'mode': "all", "data_format": "NCHW"} + + +@skip_check_grad_ci( + reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode" +) +class TestModeAllRank3NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [1, 200, 3] + + def init_attr(self): + self.attrs = {'mode': "all", "data_format": "NHWC"} @skip_check_grad_ci( @@ -240,7 +281,18 @@ class TestModeAllRank6(PReluTest): self.x_shape = [1, 2, 3, 4, 5, 6] def init_attr(self): - self.attrs = {'mode': "all"} + self.attrs = {'mode': "all", "data_format": "NCHW"} + + +@skip_check_grad_ci( + reason="[skip shape check] Input(Alpha) must be 1-D and only has one data in 'all' mode" +) +class TestModeAllRank6NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [1, 2, 3, 4, 5, 6] + + def init_attr(self): + self.attrs = {'mode': "all", "data_format": "NHWC"} class TestModeChannelRank3(PReluTest): @@ -248,7 +300,15 @@ class TestModeChannelRank3(PReluTest): self.x_shape = [1, 200, 3] def init_attr(self): - self.attrs = {'mode': "channel"} + self.attrs = {'mode': "channel", "data_format": "NCHW"} + + +class TestModeChannelRank3NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [1, 3, 100] + + def init_attr(self): + self.attrs = {'mode': "channel", "data_format": "NHWC"} class TestModeChannelRank6(PReluTest): @@ -256,7 +316,15 @@ class TestModeChannelRank6(PReluTest): self.x_shape = [1, 100, 2, 2, 2, 2] def init_attr(self): - self.attrs = {'mode': "channel"} + self.attrs = {'mode': "channel", "data_format": "NCHW"} + + +class TestModeChannelRank6NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [1, 2, 2, 2, 2, 100] + + def init_attr(self): + self.attrs = {'mode': "channel", "data_format": "NHWC"} class TestModeElementRank3(PReluTest): @@ -264,7 +332,15 @@ class TestModeElementRank3(PReluTest): self.x_shape = [3, 10, 10] def init_attr(self): - self.attrs = {'mode': "element"} + self.attrs = {'mode': "element", "data_format": "NCHW"} + + +class TestModeElementRank3NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [3, 10, 10] + + def init_attr(self): + self.attrs = {'mode': "element", "data_format": "NHWC"} class TestModeElementRank6(PReluTest): @@ -272,7 +348,15 @@ class TestModeElementRank6(PReluTest): self.x_shape = [3, 2, 2, 4, 5, 2] def init_attr(self): - self.attrs = {'mode': "element"} + self.attrs = {'mode': "element", "data_format": "NCHW"} + + +class TestModeElementRank6NHWC(PReluTest): + def init_input_shape(self): + self.x_shape = [3, 2, 2, 4, 5, 2] + + def init_attr(self): + self.attrs = {'mode': "element", "data_format": "NHWC"} def create_test_fp16_class(parent, @@ -311,9 +395,16 @@ create_test_fp16_class(TestModeChannelRank3) create_test_fp16_class(TestModeChannelRank6) create_test_fp16_class(TestModeElementRank3) create_test_fp16_class(TestModeElementRank6) +create_test_fp16_class(TestModeEltNHWC) +create_test_fp16_class(TestModeAllRank3NHWC) +create_test_fp16_class(TestModeAllRank6NHWC) +create_test_fp16_class(TestModeChannelRank3NHWC) +create_test_fp16_class(TestModeChannelRank6NHWC) +create_test_fp16_class(TestModeElementRank3NHWC) +create_test_fp16_class(TestModeElementRank6NHWC) -def prelu_t(x, mode, param_attr=None, name=None): +def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'): helper = fluid.layer_helper.LayerHelper('prelu', **locals()) alpha_shape = [1, x.shape[1], 1, 1] dtype = helper.input_dtype(input_param_name='x') @@ -328,13 +419,19 @@ def prelu_t(x, mode, param_attr=None, name=None): type="prelu", inputs={"X": x, 'Alpha': alpha}, - attrs={"mode": mode}, + attrs={"mode": mode, + 'data_format': data_format}, outputs={"Out": out}) return out # error message test if mode is not one of 'all', 'channel', 'element' class TestModeError(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else paddle.CPUPlace() + self.x_np = np.ones([1, 2, 3, 4]).astype('float32') + def test_mode_error(self): main_program = Program() with fluid.program_guard(main_program, Program()): @@ -344,6 +441,24 @@ class TestModeError(unittest.TestCase): except Exception as e: assert (e.args[0].find('InvalidArgument') != -1) + def test_data_format_error1(self): + main_program = Program() + with fluid.program_guard(main_program, Program()): + x = fluid.data(name='x', shape=[2, 3, 4, 5]) + try: + y = prelu_t(x, 'channel', data_format='N') + except Exception as e: + assert (e.args[0].find('InvalidArgument') != -1) + + def test_data_format_error2(self): + main_program = Program() + with fluid.program_guard(main_program, Program()): + x = fluid.data(name='x', shape=[2, 3, 4, 5]) + try: + y = paddle.static.nn.prelu(x, 'channel', data_format='N') + except ValueError as e: + pass + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 6b4d171d39..4a071c2fe7 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -442,7 +442,7 @@ def leaky_relu(x, negative_slope=0.01, name=None): return out -def prelu(x, weight, name=None): +def prelu(x, weight, data_format="NCHW", name=None): """ prelu activation. @@ -456,6 +456,8 @@ def prelu(x, weight, name=None): The weight shape is [1] or [in], where `in` is the input channel of ``x``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + data_format(str, optional): Data format that specifies the layout of input. + It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". Returns: A Tensor with the same data type and shape as ``x`` . @@ -490,19 +492,34 @@ def prelu(x, weight, name=None): assert len(weight.shape ) == 1, "The dim count of weight shape should be 1 in prelu()." - # NOTE(): The input of this API should be ``N,C,...`` format, - # which means x.shape[0] is batch_size and x.shape[0] is channel. mode = 'all' if weight.shape[0] > 1: + + true_data_format = [ + 'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC' + ] + if data_format not in true_data_format: + raise ValueError( + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', " + "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format)) + + data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC' + assert len( x.shape ) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]." - assert weight.shape[0] == x.shape[ - 1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." + + #NOTE(GuoxiaWang): support NHWC data format + if data_format == 'NHWC': + assert weight.shape[0] == x.shape[ + -1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." + else: + assert weight.shape[0] == x.shape[ + 1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]." mode = 'channel' if in_dygraph_mode(): - return _C_ops.prelu(x, weight, 'mode', mode) + return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format) helper = LayerHelper('prelu', **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -511,7 +528,8 @@ def prelu(x, weight, name=None): inputs={"X": x, "Alpha": weight}, outputs={"Out": out}, - attrs={"mode": mode}) + attrs={"mode": mode, + "data_format": data_format}) return out diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 6d31389103..45308f15f4 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -376,6 +376,8 @@ class PReLU(Layer): Default is None. For more information, please refer to :ref:`api_paddle_ParamAttr`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + data_format(str, optional): Data format that specifies the layout of input. + It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". Shape: - input: Tensor with any shape. Default dtype is float32. @@ -406,13 +408,18 @@ class PReLU(Layer): # [ 6. , 7. , 8. , 9. ]]]] """ - def __init__(self, num_parameters=1, init=0.25, weight_attr=None, + def __init__(self, + num_parameters=1, + init=0.25, + weight_attr=None, + data_format="NCHW", name=None): super(PReLU, self).__init__() self._num_parameters = num_parameters self._init = init self._weight_attr = weight_attr self._name = name + self._data_format = data_format self._weight = self.create_parameter( attr=self._weight_attr, @@ -422,12 +429,13 @@ class PReLU(Layer): default_initializer=Constant(self._init)) def forward(self, x): - return F.prelu(x, self._weight) + return F.prelu(x, self._weight, data_format=self._data_format) def extra_repr(self): name_str = ', name={}'.format(self._name) if self._name else '' - return 'num_parameters={}, init={}, dtype={}{}'.format( - self._num_parameters, self._init, self._dtype, name_str) + return 'num_parameters={}, data_format={}, init={}, dtype={}{}'.format( + self._num_parameters, self._data_format, self._init, self._dtype, + name_str) class ReLU(Layer): -- GitLab