diff --git a/paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu
index 467c9737744dbf4d0aae7dd9cd54927ec3ccdaad..a9177ee2d8f6aef086645000bafac9f36be2b10e 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/trans_layernorm_op_plugin.cu
@@ -363,11 +363,6 @@ int TransLayerNormPluginDynamic::enqueue(
                         "but got:%d",
                         device_id));
-  mean_t.Resize(phi::make_ddim(mean_shape_));
-  variance_t.Resize(phi::make_ddim(variance_shape_));
-  float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  float *variance_d =
-      variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
   auto input_type = input_desc[0].type;

   paddle::platform::DeviceContextPool &pool =
@@ -376,6 +371,13 @@ int TransLayerNormPluginDynamic::enqueue(
   auto *device_context = static_cast<phi::GPUContext *>(pool.Get(place));
   const phi::GPUContext &dev_ctx = *device_context;

+  mean_t.Resize(phi::make_ddim(mean_shape_));
+  variance_t.Resize(phi::make_ddim(variance_shape_));
+  float *mean_d =
+      dev_ctx.template Alloc<float>(&mean_t, mean_shape_[0] * sizeof(float));
+  float *variance_d = dev_ctx.template Alloc<float>(
+      &variance_t, variance_shape_[0] * sizeof(float));
+
   if (input_type == nvinfer1::DataType::kFLOAT) {
     VLOG(1) << "TRT Plugin DataType selected. trans_layernorm-->fp32";
     VLOG(1) << "TRT Plugin format selected. trans_layernorm-->kLINEAR";