未验证 提交 4e3f0b95 编写于 作者: W Wang Bojun 提交者: GitHub

fix mean/variance shape infer bug during loop call of dynamic trt enqueue (#45387)

* fix bug fix
上级 73e41c89
......@@ -177,9 +177,20 @@ int LayerNormPluginDynamic::enqueue(
}
// in dynamic shape
// the batch num should be involved in mean/variance shape
mean_shape_[0] *= input_dims.d[0];
variance_shape_[0] *= input_dims.d[0];
PADDLE_ENFORCE_EQ(1,
mean_shape_.size(),
platform::errors::InvalidArgument(
"Size of mean_shape vector should be equal to 1,"
"but got Size of mean_shape vector:%d",
mean_shape_.size()));
PADDLE_ENFORCE_EQ(1,
variance_shape_.size(),
platform::errors::InvalidArgument(
"Size of variance_shape vector should be equal to 1,"
"but got Size of mean_shape vector:%d",
mean_shape_.size()));
int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
const auto input_ddim = phi::make_ddim(input_shape);
auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
int feature_size = static_cast<int>(matrix_dim[1]);
......@@ -206,8 +217,8 @@ int LayerNormPluginDynamic::enqueue(
float *output = static_cast<float *>(outputs[0]);
scale_t.Resize(phi::make_ddim({feature_size}));
bias_t.Resize(phi::make_ddim({feature_size}));
mean_t.Resize(phi::make_ddim(mean_shape_));
variance_t.Resize(phi::make_ddim(variance_shape_));
mean_t.Resize(phi::make_ddim({batched_mean_shape}));
variance_t.Resize(phi::make_ddim({batched_variance_shape}));
float *scale_d =
scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册