From 4e3f0b9560938753d266a1185589532e2b3b72cf Mon Sep 17 00:00:00 2001
From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com>
Date: Wed, 24 Aug 2022 23:28:09 +0800
Subject: [PATCH] fix mean/variance shape infer bug during loop call of dynamic trt enqueue (#45387)

* fix bug fix

enqueue() scaled mean_shape_[0] and variance_shape_[0] in place by the
batch dimension (input_dims.d[0]), so the scaling was applied again to
the stored values on every subsequent enqueue call.

Compute the batched sizes into local variables instead and resize
mean_t / variance_t from those locals. Before doing so, enforce that
mean_shape_ and variance_shape_ each hold exactly one element, since
only element 0 is used.
---
Review note: the variance_shape_ size check now names and prints
variance_shape_ in its error message (it previously repeated the
mean_shape_ text and value). The blob hashes on the "index" line below
are unchanged from the original posting.

 .../tensorrt/plugin/layer_norm_op_plugin.cu   | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
index 9b0fa45c337..48b9d3229c3 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -177,9 +177,20 @@ int LayerNormPluginDynamic::enqueue(
   }
   // in dynamic shape
   // the batch num should be involved in mean/variance shape
-  mean_shape_[0] *= input_dims.d[0];
-  variance_shape_[0] *= input_dims.d[0];
-
+  PADDLE_ENFORCE_EQ(1,
+                    mean_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of mean_shape vector should be equal to 1,"
+                        "but got Size of mean_shape vector:%d",
+                        mean_shape_.size()));
+  PADDLE_ENFORCE_EQ(1,
+                    variance_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of variance_shape vector should be equal to 1,"
+                        "but got Size of variance_shape vector:%d",
+                        variance_shape_.size()));
+  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
+  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
   const auto input_ddim = phi::make_ddim(input_shape);
   auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
   int feature_size = static_cast<int>(matrix_dim[1]);
@@ -206,8 +217,8 @@ int LayerNormPluginDynamic::enqueue(
     float *output = static_cast<float *>(outputs[0]);
     scale_t.Resize(phi::make_ddim({feature_size}));
     bias_t.Resize(phi::make_ddim({feature_size}));
-    mean_t.Resize(phi::make_ddim(mean_shape_));
-    variance_t.Resize(phi::make_ddim(variance_shape_));
+    mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+    variance_t.Resize(phi::make_ddim({batched_variance_shape}));
 
     float *scale_d =
         scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
-- 
GitLab