From 1aa6adb1fb4c6bfb65f3a8af2bca4a675743f8b3 Mon Sep 17 00:00:00 2001 From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com> Date: Fri, 19 Aug 2022 16:55:14 +0800 Subject: [PATCH] Trt groupnorm dynamic plugin (#44911) * add group_norm dynamic plugin --- .../tensorrt/convert/group_norm_op.cc | 86 +++--- paddle/fluid/inference/tensorrt/op_teller.cc | 26 +- .../inference/tensorrt/plugin/CMakeLists.txt | 1 + .../tensorrt/plugin/group_norm_op_plugin.cu | 263 ++++++++++++++++++ .../tensorrt/plugin/group_norm_op_plugin.h | 255 +++++++++++++++++ paddle/phi/kernels/gpu/group_norm_kernel.cu | 92 ++++++ paddle/phi/kernels/group_norm_kernel.h | 21 ++ .../inference/test_trt_convert_group_norm.py | 128 ++++----- .../ir/inference/test_trt_group_norm_op.py | 68 ----- 9 files changed, 749 insertions(+), 191 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h delete mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_group_norm_op.py diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc index 275837ea6a7..fd6ce2658c3 100644 --- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc @@ -9,11 +9,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include +#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include + namespace paddle { namespace framework { class Scope; @@ -59,52 +61,44 @@ class GroupNormOpConverter : public OpConverter { framework::DDim bias_dims; auto scale_weights = GetWeight(scale_name, &scale_dims); auto bias_weights = GetWeight(bias_name, &bias_dims); - - nvinfer1::Dims scale_nv_dims; - nvinfer1::Dims bias_nv_dims; - scale_nv_dims.nbDims = scale_dims.size(); - bias_nv_dims.nbDims = bias_dims.size(); - for (int i = 0; i < scale_dims.size(); i++) { - scale_nv_dims.d[i] = scale_dims.at(i); + if (engine_->with_dynamic_shape()) { + int gn_num = groups; + std::vector mean_shape({gn_num}); + std::vector variance_shape({gn_num}); + plugin::GroupNormPluginDynamic* plugin = + new plugin::GroupNormPluginDynamic( + static_cast(scale_weights.get().values), + scale_weights.get().count, + static_cast(bias_weights.get().values), + bias_weights.get().count, + epsilon, + groups, + mean_shape, + variance_shape); + nvinfer1::ILayer* groupnorm_layer = + engine_->AddDynamicPlugin(&input_itensor, 1, plugin); + auto output_name = op_desc.Output("Y")[0]; + RreplenishLayerAndOutput( + groupnorm_layer, "group_norm", {output_name}, test_mode); + } else { + int gn_num = input_itensor->getDimensions().d[0] * groups; + std::vector mean_shape({gn_num}); + std::vector variance_shape({gn_num}); + plugin::GroupNormPlugin* plugin = new plugin::GroupNormPlugin( + static_cast(scale_weights.get().values), + scale_weights.get().count, + static_cast(bias_weights.get().values), + bias_weights.get().count, + epsilon, + groups, + mean_shape, + variance_shape); + nvinfer1::ILayer* groupnorm_layer = + engine_->AddPlugin(&input_itensor, 1, plugin); + auto output_name = op_desc.Output("Y")[0]; + RreplenishLayerAndOutput( + groupnorm_layer, "group_norm", {output_name}, test_mode); } - for 
(int i = 0; i < bias_dims.size(); i++) { - bias_nv_dims.d[i] = bias_dims.at(i); - } - - auto* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Constant, scale_nv_dims, scale_weights.get()); - auto* bias_layer = TRT_ENGINE_ADD_LAYER( - engine_, Constant, bias_nv_dims, bias_weights.get()); - - std::vector plugin_inputs; - plugin_inputs.emplace_back(input_itensor); - plugin_inputs.emplace_back(scale_layer->getOutput(0)); - plugin_inputs.emplace_back(bias_layer->getOutput(0)); - - const std::vector fields{ - {"eps", &epsilon, nvinfer1::PluginFieldType::kFLOAT32, 1}, - {"num_groups", &groups, nvinfer1::PluginFieldType::kINT32, 1}, - }; - - nvinfer1::PluginFieldCollection* plugin_collections = - static_cast( - malloc(sizeof(*plugin_collections) + - fields.size() * sizeof(nvinfer1::PluginField))); - plugin_collections->nbFields = static_cast(fields.size()); - plugin_collections->fields = fields.data(); - - auto creator = - GetPluginRegistry()->getPluginCreator("GroupNormalizationPlugin", "1"); - auto group_norm_plugin = - creator->createPlugin("GroupNormalizationPlugin", plugin_collections); - free(plugin_collections); - - auto group_norm_plugin_layer = engine_->network()->addPluginV2( - plugin_inputs.data(), plugin_inputs.size(), *group_norm_plugin); - - auto output_name = op_desc.Output("Y")[0]; - RreplenishLayerAndOutput( - group_norm_plugin_layer, "group_norm", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 910d0393167..36e4f2cbc9d 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -32,11 +32,9 @@ namespace tensorrt { // Just tell by the op_types. struct SimpleOpTypeSetTeller : public Teller { SimpleOpTypeSetTeller() { -// TODO(baoachun) The group_norm trt plugin will check input's dim -// not -1 failed when dynamic shape mode. 
-// #if IS_TRT_VERSION_GE(7130) -// teller_set.insert("group_norm"); -// #endif +#if IS_TRT_VERSION_GE(7130) + teller_set.insert("group_norm"); +#endif #if IS_TRT_VERSION_GE(7000) teller_set.insert("tile"); teller_set.insert("flatten_contiguous_range"); @@ -583,12 +581,26 @@ bool OpTeller::Tell(const framework::ir::Node* node, const auto x_shape = x_var_desc->GetShape(); } if (op_type == "group_norm") { - if (!with_dynamic_shape) return false; bool has_attrs = (desc.HasAttr("epsilon") && desc.HasAttr("groups")); if (has_attrs == false) return false; - auto registry = GetPluginRegistry(); if (registry == nullptr) return false; + std::string layout_str = + PADDLE_GET_CONST(std::string, desc.GetAttr("data_layout")); + if (layout_str != "NCHW") { + VLOG(3) << "Group norm trt plugin only support NCHW layout, but got " + << layout_str; + return false; + } + auto* block = desc.Block(); + if (block == nullptr) return false; + auto x_var_name = desc.Input("X")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + auto dtype = x_var_desc->GetDataType(); + if (dtype != 5) { + VLOG(3) << "Group norm trt plugin only support float32"; + return false; + } } if (op_type == "concat") { if (!desc.HasAttr("axis")) { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 90344fc0ada..b41823d9186 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -8,6 +8,7 @@ list( gelu_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu + group_norm_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu new file mode 100644 index 00000000000..294677e6ac5 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.cu @@ -0,0 +1,263 @@ +/* Copyright (c) 
2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h" +#include "paddle/phi/kernels/group_norm_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/layout.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +using DataLayout = framework::DataLayout; + +int GroupNormPlugin::initialize() TRT_NOEXCEPT { return 0; } + +nvinfer1::Dims GroupNormPlugin::getOutputDimensions( + int index, const nvinfer1::Dims *inputDims, int nbInputs) TRT_NOEXCEPT { + return inputDims[0]; +} + +int GroupNormPlugin::enqueue(int batch_size, + const void *const *inputs, +#if IS_TRT_VERSION_LT(8000) + void **outputs, + void *workspace, +#else + void *const *outputs, + void *workspace, +#endif + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = this->getInputDims(0); + int groups = groups_; + float eps = eps_; + std::vector input_shape; + input_shape.push_back(batch_size); + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + const auto input_ddim = phi::make_ddim(input_shape); + + int C = input_shape[1]; + + PADDLE_ENFORCE_EQ( + C, + scale_.size(), + platform::errors::InvalidArgument( + "scale's size should be equal to the channel number in groupnorm," + "but got channel number:%d, scale's size:%d.", + C, + 
scale_.size())); + PADDLE_ENFORCE_EQ( + C, + bias_.size(), + platform::errors::InvalidArgument( + "bias's size should be equal to the channel number in groupnorm," + "but got channel number:%d, bias's size:%d.", + C, + bias_.size())); + + int device_id; + cudaGetDevice(&device_id); + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + + scale_t.Resize(phi::make_ddim({C})); + bias_t.Resize(phi::make_ddim({C})); + + mean_t.Resize(phi::make_ddim(mean_shape_)); + variance_t.Resize(phi::make_ddim(variance_shape_)); + float *scale_d = scale_t.mutable_data(platform::CUDAPlace(device_id)); + float *bias_d = bias_t.mutable_data(platform::CUDAPlace(device_id)); + float *mean_d = mean_t.mutable_data(platform::CUDAPlace(device_id)); + float *variance_d = + variance_t.mutable_data(platform::CUDAPlace(device_id)); + + framework::Tensor temp_variance_t; + temp_variance_t.Resize(phi::make_ddim(variance_shape_)); + float *temp_variance_d = + temp_variance_t.mutable_data(platform::CUDAPlace(device_id)); + cudaMemcpyAsync(scale_d, + scale_.data(), + sizeof(float) * C, + cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync( + bias_d, bias_.data(), sizeof(float) * C, cudaMemcpyHostToDevice, stream); + phi::GroupNormDirectCUDAFunctor group_norm; + group_norm(stream, + input, + input_shape, + bias_d, + scale_d, + mean_d, + temp_variance_d, + groups_, + eps_, + output, + mean_d, + variance_d, + DataLayout::kNCHW); + return cudaGetLastError() != cudaSuccess; +} +nvinfer1::DimsExprs GroupNormPluginDynamic::getOutputDimensions( + int output_index, + const nvinfer1::DimsExprs *inputDims, + int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + return inputDims[0]; +} + +bool GroupNormPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc *in_out, + int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, + platform::errors::InvalidArgument( + "The input of groupnorm plugin 
shoule not be nullptr.")); + PADDLE_ENFORCE_LT( + pos, + nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, + nb_inputs + nb_outputs)); + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType GroupNormPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, + 0, + platform::errors::InvalidArgument( + "The groupnorm Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +int GroupNormPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, + const void *const *inputs, + void *const *outputs, + void *workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto &input_dims = input_desc[0].dims; + int groups = groups_; + float eps = eps_; + + std::vector input_shape; + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + + const auto input_ddim = phi::make_ddim(input_shape); + + int C = input_shape[1]; + int image_size = input_shape[2] * input_shape[3]; + int batchSize = input_shape[0]; + std::vector batched_mean_shape = {batchSize * mean_shape_[0]}; + std::vector batched_variance_shape = {batchSize * + variance_shape_[0]}; + PADDLE_ENFORCE_EQ( + C, + scale_.size(), + platform::errors::InvalidArgument( + "scale's size should be equal to the channel number in groupnorm," + "but got feature_size:%d, scale's size:%d.", + C, + scale_.size())); + PADDLE_ENFORCE_EQ( + C, + bias_.size(), + platform::errors::InvalidArgument( + "bias's size should 
be equal to the channel number in groupnorm," + "but got feature_size:%d, bias's size:%d.", + C, + bias_.size())); + + int device_id; + cudaGetDevice(&device_id); + auto input_type = input_desc[0].type; + if (input_type == nvinfer1::DataType::kFLOAT) { + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + scale_t.Resize(phi::make_ddim({C})); + bias_t.Resize(phi::make_ddim({C})); + + mean_t.Resize(phi::make_ddim(batched_mean_shape)); + variance_t.Resize(phi::make_ddim(batched_variance_shape)); + float *scale_d = + scale_t.mutable_data(platform::CUDAPlace(device_id)); + float *bias_d = bias_t.mutable_data(platform::CUDAPlace(device_id)); + float *mean_d = mean_t.mutable_data(platform::CUDAPlace(device_id)); + float *variance_d = + variance_t.mutable_data(platform::CUDAPlace(device_id)); + + framework::Tensor temp_variance_t; + temp_variance_t.Resize(phi::make_ddim(batched_variance_shape)); + float *temp_variance_d = + temp_variance_t.mutable_data(platform::CUDAPlace(device_id)); + cudaMemcpyAsync(scale_d, + scale_.data(), + sizeof(float) * C, + cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(bias_d, + bias_.data(), + sizeof(float) * C, + cudaMemcpyHostToDevice, + stream); + + phi::GroupNormDirectCUDAFunctor group_norm; + group_norm(stream, + input, + input_shape, + bias_d, + scale_d, + mean_d, + temp_variance_d, + groups, + eps, + output, + mean_d, + variance_d, + DataLayout::kNCHW); + } else { + // input not float + PADDLE_THROW(platform::errors::Fatal( + "The Groupnorm TRT Plugin's only support fp32 input")); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h new file mode 100644 index 00000000000..fdcb93e29f0 --- /dev/null +++ 
b/paddle/fluid/inference/tensorrt/plugin/group_norm_op_plugin.h @@ -0,0 +1,255 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +class GroupNormPlugin : public PluginTensorRT { + public: + size_t getSerializationSize() const TRT_NOEXCEPT override { + return getBaseSerializationSize() + SerializedSize(scale_) + + SerializedSize(bias_) + SerializedSize(eps_) + + SerializedSize(groups_) + SerializedSize(mean_shape_) + + SerializedSize(variance_shape_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + serializeBase(buffer); + SerializeValue(&buffer, scale_); + SerializeValue(&buffer, bias_); + SerializeValue(&buffer, eps_); + SerializeValue(&buffer, groups_); + SerializeValue(&buffer, mean_shape_); + SerializeValue(&buffer, variance_shape_); + } + + GroupNormPlugin(const float* scale, + const int scale_num, + const float* bias, + const int bias_num, + float eps, + int groups, + std::vector mean_shape, + std::vector variance_shape) + : groups_(groups), + eps_(eps), + mean_shape_(mean_shape), + variance_shape_(variance_shape) { + scale_.resize(scale_num); + 
bias_.resize(bias_num); + std::copy(scale, scale + scale_num, scale_.data()); + std::copy(bias, bias + bias_num, bias_.data()); + } + GroupNormPlugin(void const* serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &groups_); + DeserializeValue(&serialData, &serialLength, &mean_shape_); + DeserializeValue(&serialData, &serialLength, &variance_shape_); + } + ~GroupNormPlugin() {} + int initialize() TRT_NOEXCEPT override; + GroupNormPlugin* clone() const TRT_NOEXCEPT override { + return new GroupNormPlugin(scale_.data(), + scale_.size(), + bias_.data(), + bias_.size(), + eps_, + groups_, + mean_shape_, + variance_shape_); + } + const char* getPluginType() const TRT_NOEXCEPT override { + return "groupnorm_plugin"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, + const nvinfer1::Dims* inputs, + int nbInputDims) TRT_NOEXCEPT override; + +#if IS_TRT_VERSION_LT(8000) + int enqueue(int batchSize, + const void* const* inputs, + void** outputs, +#else + int enqueue(int batchSize, + const void* const* inputs, + void* const* outputs, +#endif + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + private: + std::vector scale_; + std::vector bias_; + framework::Tensor scale_t; + framework::Tensor bias_t; + framework::Tensor mean_t; + framework::Tensor variance_t; + int groups_; + float eps_; + std::vector mean_shape_; + std::vector variance_shape_; +}; +class GroupNormPluginCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "groupnorm_plugin"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const 
void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new GroupNormPlugin(serial_data, serial_length); + } +}; +REGISTER_TRT_PLUGIN_V2(GroupNormPluginCreator); + +class GroupNormPluginDynamic : public DynamicPluginTensorRT { + public: + GroupNormPluginDynamic(const float* scale, + const int scale_num, + const float* bias, + const int bias_num, + float eps, + int groups, + std::vector mean_shape, + std::vector variance_shape) + : groups_(groups), + eps_(eps), + mean_shape_(mean_shape), + variance_shape_(variance_shape) { + scale_.resize(scale_num); + bias_.resize(bias_num); + std::copy(scale, scale + scale_num, scale_.data()); + std::copy(bias, bias + bias_num, bias_.data()); + } + + GroupNormPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &groups_); + DeserializeValue(&serialData, &serialLength, &mean_shape_); + DeserializeValue(&serialData, &serialLength, &variance_shape_); + } + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new GroupNormPluginDynamic(scale_.data(), + scale_.size(), + bias_.data(), + bias_.size(), + eps_, + groups_, + mean_shape_, + variance_shape_); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "groupnorm_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(scale_) + SerializedSize(bias_) + + SerializedSize(eps_) + SerializedSize(groups_) + + SerializedSize(mean_shape_) + SerializedSize(variance_shape_); + } + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, scale_); + SerializeValue(&buffer, bias_); + SerializeValue(&buffer, eps_); 
+ SerializeValue(&buffer, groups_); + SerializeValue(&buffer, mean_shape_); + SerializeValue(&buffer, variance_shape_); + } + nvinfer1::DimsExprs getOutputDimensions(int output_index, + const nvinfer1::DimsExprs* inputs, + int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) + TRT_NOEXCEPT override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const + TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + // void terminate() TRT_NOEXCEPT override; + + private: + std::vector scale_; + std::vector bias_; + framework::Tensor scale_t; + framework::Tensor bias_t; + framework::Tensor mean_t; + framework::Tensor variance_t; + int groups_; + float eps_; + std::vector mean_shape_; + std::vector variance_shape_; +}; +class GroupNormPluginDynamicCreator : public TensorRTPluginCreator { + public: + const char* getPluginName() const TRT_NOEXCEPT override { + return "groupnorm_plugin_dynamic"; + } + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) + TRT_NOEXCEPT override { + return new GroupNormPluginDynamic(serial_data, 
serial_length); + } +}; +REGISTER_TRT_PLUGIN_V2(GroupNormPluginDynamicCreator); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/phi/kernels/gpu/group_norm_kernel.cu b/paddle/phi/kernels/gpu/group_norm_kernel.cu index 012f224f044..7c518b72eb3 100644 --- a/paddle/phi/kernels/gpu/group_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/group_norm_kernel.cu @@ -228,6 +228,98 @@ void GroupNormKernel(const Context& dev_ctx, data_layout); } +template +void GroupNormDirectCUDAFunctor::operator()(gpuStream_t stream, + const T* input, + std::vector input_shape, + const T* bias, + const T* scale, + T* temp_mean, + T* temp_variance, + int groups, + float eps, + T* output, + T* mean, + T* variance, + const DataLayout data_layout) { + const auto input_ddim = phi::make_ddim(input_shape); + const int C = + (data_layout == DataLayout::kNCHW ? input_ddim[1] + : input_ddim[input_ddim.size() - 1]); + const int group_size = C / groups; + const int W = + (data_layout == DataLayout::kNCHW ? 
input_ddim[input_ddim.size() - 1] + : input_ddim[input_ddim.size() - 2]); + + int image_size = 1; + if (data_layout == DataLayout::kNCHW) { + for (int i = 2; i < input_ddim.size(); ++i) { + image_size *= input_ddim[i]; + } + } else { + for (int i = 1; i < input_ddim.size() - 1; ++i) { + image_size *= input_ddim[i]; + } + } +#ifdef __HIPCC__ + int block_size = std::max(std::min(256, image_size), 64); +#else + int block_size = std::min(1024, image_size); +#endif + dim3 grid(group_size, groups, input_ddim[0]); + dim3 threads(block_size, 1, 1); + if (data_layout == DataLayout::kNCHW) { + using AccT = typename phi::kps::details::MPTypeTrait::Type; + constexpr int vec_size = sizeof(float4) / sizeof(float); + int size = group_size * image_size; // group element size + const int max_num_threads = 1024; + int max_block_size = std::min(size / vec_size, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + + block_size_nchw = std::max(block_size_nchw, phi::kps::details::kWarpSize); + dim3 grids(input_ddim[0] * groups); + dim3 blocks(block_size_nchw); + + if (size < vec_size * block_size_nchw) { + phi::ScalarGetMeanAndVarNCHW + <<>>(input, temp_mean, temp_variance, size); + } else { + phi::VectorizedGetMeanAndVarNCHW + <<>>(input, temp_mean, temp_variance, size); + } + } else { + phi::GroupNormForwardGetMeanAndVar + <<>>(input, + input_ddim[0], + C, + W, + image_size, + groups, + group_size, + temp_mean, + temp_variance); + } + GroupNormForward<<>>( + input, + temp_mean, + temp_variance, + scale, + bias, + input_ddim[0], + C, + W, + image_size, + groups, + group_size, + eps, + output, + variance, + data_layout); // for now, we only support nchw for group norm +} +template class GroupNormDirectCUDAFunctor; } // namespace phi PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/group_norm_kernel.h b/paddle/phi/kernels/group_norm_kernel.h index 36bf7125ec1..8a8812d2a17 100644 --- 
a/paddle/phi/kernels/group_norm_kernel.h +++ b/paddle/phi/kernels/group_norm_kernel.h @@ -16,6 +16,7 @@ #include +#include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -32,4 +33,24 @@ void GroupNormKernel(const Context& dev_ctx, DenseTensor* mean, DenseTensor* variance); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +class GroupNormDirectCUDAFunctor { + public: + void operator()(gpuStream_t stream, + const T* input, + std::vector input_shape, + const T* bias, + const T* scale, + T* temp_mean, + T* temp_variance, + int groups, + float eps, + T* output, + T* mean, + T* variance, + const DataLayout data_layout); +}; +#endif + } // namespace phi diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py index da65c3d2198..6115ae60eff 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_group_norm.py @@ -24,6 +24,15 @@ import unittest class TrtConvertGroupNormTest(TrtLayerAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if attrs[0]['epsilon'] < 0 or attrs[0]['epsilon'] > 0.001: + return False + if attrs[0]['groups'] <= 0: + return False return True def sample_program_configs(self): @@ -41,62 +50,56 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): return np.random.randn(32).astype(np.float32) for batch in [1, 2, 4]: - for group in [1, 4, 32]: - for epsilon in [0.1, 0.7]: - for data_layout in ['NCHW', 'NHWC']: - for i in [0, 1]: - dics = [{ - "epsilon": epsilon, - "groups": group, - "data_layout": data_layout - }, { - "groups": group, - "data_layout": data_layout - }] - ops_config 
= [{ - "op_type": "group_norm", - "op_inputs": { - "X": ["input_data"], - "Scale": ["scale_weight"], - "Bias": ["bias_weight"] - }, - "op_outputs": { - "Y": ["y_output"], - "Mean": ["mean_output"], - "Variance": ["variance_output"] - }, - "op_attrs": dics[i] - }] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={ - "scale_weight": - TensorConfig( - data_gen=partial(generate_scale)), - "bias_weight": - TensorConfig( - data_gen=partial(generate_bias)) - }, - inputs={ - "input_data": - TensorConfig(data_gen=partial( - generate_input, dics, batch)) - }, - outputs=["y_output"]) - - yield program_config + for group in [1, 4, 32, -1]: + for epsilon in [0.0001, 0.0007, -1, 1]: + for data_layout in ['NCHW']: + dics = [{ + "epsilon": epsilon, + "groups": group, + "data_layout": data_layout + }] + ops_config = [{ + "op_type": "group_norm", + "op_inputs": { + "X": ["input_data"], + "Scale": ["scale_weight"], + "Bias": ["bias_weight"] + }, + "op_outputs": { + "Y": ["y_output"], + "Mean": ["mean_output"], + "Variance": ["variance_output"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "scale_weight": + TensorConfig(data_gen=partial(generate_scale)), + "bias_weight": + TensorConfig(data_gen=partial(generate_bias)) + }, + inputs={ + "input_data": + TensorConfig(data_gen=partial( + generate_input, dics, batch)) + }, + outputs=["y_output"]) + + yield program_config def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): - self.dynamic_shape.min_input_shape = {"input_data": [1, 16, 32, 32]} + self.dynamic_shape.min_input_shape = {"input_data": [1, 16, 16, 16]} self.dynamic_shape.max_input_shape = { - "input_data": [4, 64, 128, 64] + "input_data": [4, 64, 128, 128] } - self.dynamic_shape.opt_input_shape = {"input_data": [2, 32, 64, 64]} + 
self.dynamic_shape.opt_input_shape = {"input_data": [1, 32, 64, 64]} def clear_dynamic_shape(): self.dynamic_shape.max_input_shape = {} @@ -104,13 +107,7 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - if len(attrs[0]) == 3: - if dynamic_shape: - return 1, 2 - else: - return 0, 3 - else: - return 0, 3 + return 1, 2 attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) @@ -120,31 +117,22 @@ class TrtConvertGroupNormTest(TrtLayerAutoScanTest): clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False), (1e-5, 1e-5) + attrs, False), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( - attrs, False), (1e-5, 1e-5) + attrs, False), 1e-5 # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True), (1e-5, 1e-5) + attrs, True), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True), (1e-5, 1e-5) + attrs, True), 1e-5 def add_skip_trt_case(self): - - def teller1(program_config, predictor_config): - if len(self.dynamic_shape.min_input_shape) != 0: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "The goup_norm plugin will check dim not -1 failed when dynamic fp16 mode." 
- ) + pass def test(self): self.add_skip_trt_case() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_group_norm_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_group_norm_op.py deleted file mode 100644 index de59753d976..00000000000 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_group_norm_op.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import print_function - -import unittest -import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.core import PassVersionChecker -from paddle.fluid.core import AnalysisConfig - - -class TRTGroupNormTest(InferencePassTest): - - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data(name="data", - shape=[-1, 512, 12, 12], - dtype="float32") - out = self.append_group_norm(data) - - self.feeds = { - "data": np.random.random([1, 512, 12, 12]).astype("float32"), - } - self.enable_trt = True - self.trt_parameters = TRTGroupNormTest.TensorRTParam( - 1 << 30, 1, 1, AnalysisConfig.Precision.Float32, False, False) - self.dynamic_shape_params = TRTGroupNormTest.DynamicShapeParam( - {'data': [1, 512, 12, 12]}, {'data': [1, 512, 12, 12]}, - {'data': [1, 512, 12, 12]}, False) - self.fetch_list = [out] - - def append_group_norm(self, data): - param_attr = fluid.ParamAttr( - name='group_norm_scale', - initializer=fluid.initializer.Constant(value=1.0)) - bias_attr = fluid.ParamAttr( - name='group_norm_bias', - initializer=fluid.initializer.Constant(value=0.0)) - return fluid.layers.group_norm(data, - groups=32, - epsilon=0.000009999999747378752, - param_attr=param_attr, - bias_attr=bias_attr) - - def test_check_output(self): - if core.is_compiled_with_cuda(): - use_gpu = True - self.check_output_with_option(use_gpu) - self.assertTrue( - PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) - - -if __name__ == "__main__": - unittest.main() -- GitLab