未验证 提交 c47f11f5 编写于 作者: W Wang Bojun 提交者: GitHub

fix mutable_data() (#50396)

上级 996c2b70
......@@ -363,11 +363,6 @@ int TransLayerNormPluginDynamic::enqueue(
"but got:%d",
device_id));
mean_t.Resize(phi::make_ddim(mean_shape_));
variance_t.Resize(phi::make_ddim(variance_shape_));
float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *variance_d =
variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
auto input_type = input_desc[0].type;
paddle::platform::DeviceContextPool &pool =
......@@ -376,6 +371,13 @@ int TransLayerNormPluginDynamic::enqueue(
auto *device_context = static_cast<phi::GPUContext *>(pool.Get(place));
const phi::GPUContext &dev_ctx = *device_context;
mean_t.Resize(phi::make_ddim(mean_shape_));
variance_t.Resize(phi::make_ddim(variance_shape_));
float *mean_d =
dev_ctx.template Alloc<float>(&mean_t, mean_shape_[0] * sizeof(float));
float *variance_d = dev_ctx.template Alloc<float>(
&variance_t, variance_shape_[0] * sizeof(float));
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. trans_layernorm-->fp32";
VLOG(1) << "TRT Plugin format selected. trans_layernorm-->kLINEAR";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册