diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
index 54017666a77d2faf34781f07690e93ed2bedba74..0eed1a4f5e71f690c2a426164667422b0628bcae 100644
--- a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
@@ -56,14 +56,10 @@ class LayerNormOpConverter : public OpConverter {
     nvinfer1::ILayer* layernorm_layer = nullptr;
     if (engine_->with_dynamic_shape()) {
-      int statis_num = 1;
       // For dynamic shape,
-      // the batch num will be taken into account in plugin runtime.
-      for (int i = 1; i < begin_norm_axis; i++) {
-        statis_num *= X->getDimensions().d[i];
-      }
-      std::vector<int64_t> mean_shape{statis_num};
-      std::vector<int64_t> variance_shape{statis_num};
+      // the shape of mean and variance will be determined in configurePlugin.
+      std::vector<int64_t> mean_shape{1};
+      std::vector<int64_t> variance_shape{1};
       plugin::LayerNormPluginDynamic* plugin =
           new plugin::LayerNormPluginDynamic(
               static_cast<const float*>(bias_weight.get().values),
@@ -77,7 +73,7 @@ class LayerNormOpConverter : public OpConverter {
       layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
     } else {
       int statis_num = 1;
-      for (int i = 0; i < begin_norm_axis; i++) {
+      for (int i = 1; i < begin_norm_axis; i++) {
         statis_num *= X->getDimensions().d[i];
       }
       std::vector<int64_t> mean_shape{statis_num};
diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
index 48b9d3229c38c0599124134af221b0dcc41d6d1b..da4ebdc6cb6e81c6df97918fdc987db7ee05a769 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -53,6 +53,22 @@ int LayerNormPlugin::enqueue(int batch_size,
   int begin_norm_axis = begin_norm_axis_;
   float eps = eps_;
 
+  PADDLE_ENFORCE_EQ(1,
+                    mean_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of mean_shape vector should be equal to 1, "
+                        "but got size of mean_shape vector: %d",
+                        mean_shape_.size()));
+  PADDLE_ENFORCE_EQ(1,
+                    variance_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of variance_shape vector should be equal to 1, "
+                        "but got size of variance_shape vector: %d",
+                        variance_shape_.size()));
+
+  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
+  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
+
   std::vector<int> input_shape;
   input_shape.push_back(batch_size);
   for (int i = 0; i < input_dims.nbDims; i++) {
@@ -78,8 +94,8 @@ int LayerNormPlugin::enqueue(int batch_size,
   scale_t.Resize(phi::make_ddim({feature_size}));
   bias_t.Resize(phi::make_ddim({feature_size}));
-  mean_t.Resize(phi::make_ddim(mean_shape_));
-  variance_t.Resize(phi::make_ddim(variance_shape_));
+  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
   int device_id;
   cudaGetDevice(&device_id);
   float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
@@ -147,6 +163,20 @@ bool LayerNormPluginDynamic::supportsFormatCombination(
   return in.type == prev.type && in.format == prev.format;
 }
 
+void LayerNormPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nbOutputs) TRT_NOEXCEPT {
+  const auto &input_dims = in[0].desc.dims;
+  int statis_num = 1;
+  for (int i = 0; i < begin_norm_axis_; i++) {
+    statis_num *= input_dims.d[i];
+  }
+  mean_shape_[0] = statis_num;
+  variance_shape_[0] = statis_num;
+}
+
 nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
     int index,
     const nvinfer1::DataType *input_types,
@@ -189,8 +219,19 @@ int LayerNormPluginDynamic::enqueue(
                         "Size of variance_shape vector should be equal to 1,"
                         "but got Size of mean_shape vector:%d",
                         mean_shape_.size()));
-  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
-  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
+  PADDLE_ENFORCE_GE(mean_shape_[0],
+                    0,
+                    platform::errors::InvalidArgument(
+                        "The size of mean vector should be non-negative, "
+                        "but got: %d",
+                        mean_shape_[0]));
+  PADDLE_ENFORCE_GE(variance_shape_[0],
+                    0,
+                    platform::errors::InvalidArgument(
+                        "The size of variance vector should be non-negative, "
+                        "but got: %d",
+                        variance_shape_[0]));
+
   const auto input_ddim = phi::make_ddim(input_shape);
   auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
   int feature_size = static_cast<int>(matrix_dim[1]);
@@ -217,8 +258,8 @@ int LayerNormPluginDynamic::enqueue(
   float *output = static_cast<float *>(outputs[0]);
   scale_t.Resize(phi::make_ddim({feature_size}));
   bias_t.Resize(phi::make_ddim({feature_size}));
-  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
-  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
+  mean_t.Resize(phi::make_ddim(mean_shape_));
+  variance_t.Resize(phi::make_ddim(variance_shape_));
 
   float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
index 64cfde8e4a76b696d6c00222e9862d2e753c4385..a8ccabb3cff597f37f49237b3043b0b9273318b2 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
@@ -215,7 +215,7 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                        int nbInputs,
                        const nvinfer1::DynamicPluginTensorDesc* out,
-                       int nbOutputs) TRT_NOEXCEPT override {}
+                       int nbOutputs) TRT_NOEXCEPT override;
 
   size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                           int nbInputs,
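
For reviewers: below is a minimal, self-contained sketch (not part of the patch) of the statistics-shape computation that this change moves into LayerNormPluginDynamic::configurePlugin. The helper name StatisticsCount and the example dimensions are hypothetical. The key point: in explicit-batch (dynamic shape) mode, in[0].desc.dims already includes the batch dimension, so the product over axes [0, begin_norm_axis) directly yields the batched mean/variance element count, and enqueue no longer needs to multiply by input_dims.d[0].

```cpp
// Standalone sketch of the configurePlugin statistics-shape logic.
// "dims" stands in for in[0].desc.dims in explicit-batch mode.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t StatisticsCount(const std::vector<int64_t>& dims, int begin_norm_axis) {
  int64_t statis_num = 1;
  // Mirrors the loop in configurePlugin: batch dim (i == 0) is included,
  // so the result is already the batched mean/variance size.
  for (int i = 0; i < begin_norm_axis; ++i) {
    statis_num *= dims[i];
  }
  return statis_num;
}

int main() {
  // Hypothetical input [batch=4, seq_len=128, hidden=768], normalizing the
  // last axis (begin_norm_axis = 2): 4 * 128 = 512 mean/variance entries.
  std::cout << StatisticsCount({4, 128, 768}, 2) << "\n";  // prints 512
}
```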