From 495c1fc067389c7eff533cdadfca15b74e343fc8 Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Tue, 20 Dec 2022 11:10:00 +0800 Subject: [PATCH] [Paddle Inference] Remove_trt_weight_map (#49159) --- paddle/fluid/inference/tensorrt/engine.cc | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 301136d353..15ed7261bf 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -697,8 +697,11 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight( "twice in TRT OP converter.", name_with_suffix)); - weight_map[name_with_suffix].reset(new phi::DenseTensor()); - weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + if (weight_tensor.place() == PlaceType::kGPU || + weight_tensor.dtype() != phi::DataType::FLOAT32) { + weight_map[name_with_suffix].reset(new phi::DenseTensor()); + weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + } TensorRTEngine::Weight weight; weight.SetCount(weight_tensor.numel()); @@ -735,10 +738,15 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight( weight.SetDataType(phi::DataType::INT32); weight.SetValues(int32_data); } else { - paddle::framework::TensorCopySync( - weight_tensor, cpu_place, weight_map[name_with_suffix].get()); - weight.SetDataType(weight_tensor.dtype()); - weight.SetValues(weight_map[name_with_suffix]->data()); + if (weight_tensor.place() == PlaceType::kGPU) { + paddle::framework::TensorCopySync( + weight_tensor, cpu_place, weight_map[name_with_suffix].get()); + weight.SetDataType(weight_tensor.dtype()); + weight.SetValues(weight_map[name_with_suffix]->data()); + } else { + weight.SetDataType(weight_tensor.dtype()); + weight.SetValues(weight_tensor.data()); + } } name_suffix_counter += 1; -- GitLab