diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
index 9b0fa45c337733028367bfb1b155b769a5702029..48b9d3229c38c0599124134af221b0dcc41d6d1b 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -177,9 +177,20 @@ int LayerNormPluginDynamic::enqueue(
   }
   // in dynamic shape
   // the batch num should be involved in mean/variance shape
-  mean_shape_[0] *= input_dims.d[0];
-  variance_shape_[0] *= input_dims.d[0];
-
+  PADDLE_ENFORCE_EQ(1,
+                    mean_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of mean_shape vector should be equal to 1, "
+                        "but got size of mean_shape vector: %d",
+                        mean_shape_.size()));
+  PADDLE_ENFORCE_EQ(1,
+                    variance_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of variance_shape vector should be equal to 1, "
+                        "but got size of variance_shape vector: %d",
+                        variance_shape_.size()));
+  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
+  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
   const auto input_ddim = phi::make_ddim(input_shape);
   auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
   int feature_size = static_cast<int>(matrix_dim[1]);
@@ -206,8 +217,8 @@ int LayerNormPluginDynamic::enqueue(
     float *output = static_cast<float *>(outputs[0]);
     scale_t.Resize(phi::make_ddim({feature_size}));
     bias_t.Resize(phi::make_ddim({feature_size}));
-    mean_t.Resize(phi::make_ddim(mean_shape_));
-    variance_t.Resize(phi::make_ddim(variance_shape_));
+    mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+    variance_t.Resize(phi::make_ddim({batched_variance_shape}));
     float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
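Note on the shape logic: the old code updated the member vectors in place with `*=`, so each call to `enqueue` would compound the batch dimension into `mean_shape_[0]` and `variance_shape_[0]`; the new code derives a fresh local `int64_t` per call and leaves the members untouched. A minimal standalone sketch of this difference, using hypothetical names (`PluginSketch`, `enqueue_old`, `enqueue_new`) that do not appear in the diff:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the plugin; only the shape bookkeeping is modeled.
struct PluginSketch {
  std::vector<int64_t> mean_shape_{1};  // per-instance shape, batch excluded

  // Old behavior: the member is mutated, so repeated enqueues compound it.
  int64_t enqueue_old(int64_t batch) {
    mean_shape_[0] *= batch;
    return mean_shape_[0];
  }

  // New behavior: a local value is derived per call; the member is untouched.
  int64_t enqueue_new(int64_t batch) const { return mean_shape_[0] * batch; }
};

int main() {
  PluginSketch old_style;
  std::cout << old_style.enqueue_old(8) << "\n";  // 8
  std::cout << old_style.enqueue_old(8) << "\n";  // 64 -- shape has drifted

  PluginSketch new_style;
  std::cout << new_style.enqueue_new(8) << "\n";  // 8
  std::cout << new_style.enqueue_new(8) << "\n";  // 8 -- stable across calls
}
```

Since `enqueue` may run many times with varying dynamic batch sizes, deriving the batched shape locally keeps the mean/variance tensors sized correctly on every invocation.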