From 4cdeab7b5e0cd177a0b7175a0640f0a84714acd6 Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Wed, 21 Dec 2022 14:29:53 +0800
Subject: [PATCH] fix get trt weight (#49197)

---
 paddle/fluid/inference/tensorrt/engine.cc | 60 +++++++++++++++++------
 1 file changed, 46 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 15ed7261bf..f480f791f9 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -582,10 +582,8 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
 
   TensorRTEngine::Weight weight;
   weight.SetCount(weight_tensor.numel());
-  weight.SetDataType(nvinfer1::DataType::kHALF);
-  // weight_tensor.dims().;
 
-  // if trt not support dtype, we need to cast to fp16.
+  // if trt not support dtype, we need to cast to fp16.
   if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
     phi::DenseTensor bf16_tensor;
     bf16_tensor.clear();
@@ -593,13 +591,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
         weight_tensor, platform::CPUPlace(), &bf16_tensor);
     weight_map[name_with_suffix]->set_type(
         paddle::experimental::DataType::FLOAT16);
-    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
     auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
         platform::CPUPlace());
     auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
     for (int i = 0; i < weight_tensor.numel(); i++) {
       fp16_data[i] = static_cast<float16>(bf16_data[i]);
     }
+    weight.SetDataType(phi::DataType::FLOAT16);
+    weight.SetValues(fp16_data);
   } else if (weight_tensor.dtype() == phi::DataType::FLOAT32) {
     phi::DenseTensor fp32_tensor;
     fp32_tensor.clear();
@@ -607,18 +606,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
         weight_tensor, platform::CPUPlace(), &fp32_tensor);
     weight_map[name_with_suffix]->set_type(
         paddle::experimental::DataType::FLOAT16);
-    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
     auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
         platform::CPUPlace());
     auto *fp32_data = fp32_tensor.mutable_data<float>(platform::CPUPlace());
     for (int i = 0; i < weight_tensor.numel(); i++) {
       fp16_data[i] = static_cast<float16>(fp32_data[i]);
     }
+    weight.SetDataType(phi::DataType::FLOAT16);
+    weight.SetValues(fp16_data);
+  } else if (weight_tensor.dtype() == phi::DataType::INT64) {
+    phi::DenseTensor int64_tensor;
+    int64_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &int64_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::INT32);
+    auto *int32_data = weight_map[name_with_suffix]->mutable_data<int>(
+        platform::CPUPlace());
+    auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      int32_data[i] = int64_data[i];
+    }
+    weight.SetDataType(phi::DataType::INT32);
+    weight.SetValues(int32_data);
   } else {
     paddle::framework::TensorCopySync(
         weight_tensor, cpu_place, weight_map[name_with_suffix].get());
+    weight.SetDataType(weight_tensor.dtype());
+    weight.SetValues(weight_map[name_with_suffix]->data());
   }
-  weight.SetValues(weight_map[name_with_suffix]->data());
   name_suffix_counter += 1;
   return weight;
 }
@@ -642,10 +658,8 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
 
   TensorRTEngine::Weight weight;
   weight.SetCount(weight_tensor.numel());
-  weight.SetDataType(nvinfer1::DataType::kFLOAT);
-  // weight_tensor.dims().;
 
-  // if trt not support dtype, we need to cast to fp32.
+  // if trt not support dtype, we need to cast to fp32.
   if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
     phi::DenseTensor bf16_tensor;
     bf16_tensor.clear();
@@ -653,13 +667,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
         weight_tensor, platform::CPUPlace(), &bf16_tensor);
     weight_map[name_with_suffix]->set_type(
         paddle::experimental::DataType::FLOAT32);
-    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
     auto *fp32_data =
         weight_map[name_with_suffix]->mutable_data<float>(platform::CPUPlace());
     auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
     for (int i = 0; i < weight_tensor.numel(); i++) {
       fp32_data[i] = static_cast<float>(bf16_data[i]);
     }
+    weight.SetDataType(phi::DataType::FLOAT32);
+    weight.SetValues(fp32_data);
   } else if (weight_tensor.dtype() == phi::DataType::FLOAT16) {
     phi::DenseTensor fp16_tensor;
     fp16_tensor.clear();
@@ -667,18 +682,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
         weight_tensor, platform::CPUPlace(), &fp16_tensor);
     weight_map[name_with_suffix]->set_type(
         paddle::experimental::DataType::FLOAT32);
-    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
     auto *fp32_data =
         weight_map[name_with_suffix]->mutable_data<float>(platform::CPUPlace());
     auto *fp16_data = fp16_tensor.mutable_data<float16>(platform::CPUPlace());
     for (int i = 0; i < weight_tensor.numel(); i++) {
       fp32_data[i] = static_cast<float>(fp16_data[i]);
     }
+    weight.SetDataType(phi::DataType::FLOAT32);
+    weight.SetValues(fp32_data);
+  } else if (weight_tensor.dtype() == phi::DataType::INT64) {
+    phi::DenseTensor int64_tensor;
+    int64_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &int64_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::INT32);
+    auto *int32_data = weight_map[name_with_suffix]->mutable_data<int>(
+        platform::CPUPlace());
+    auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      int32_data[i] = int64_data[i];
+    }
+    weight.SetDataType(phi::DataType::INT32);
+    weight.SetValues(int32_data);
   } else {
     paddle::framework::TensorCopySync(
         weight_tensor, cpu_place, weight_map[name_with_suffix].get());
+    weight.SetDataType(weight_tensor.dtype());
+    weight.SetValues(weight_map[name_with_suffix]->data());
   }
-  weight.SetValues(weight_map[name_with_suffix]->data());
   name_suffix_counter += 1;
   return weight;
 }
@@ -729,8 +761,8 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
         weight_tensor, platform::CPUPlace(), &int64_tensor);
     weight_map[name_with_suffix]->set_type(
         paddle::experimental::DataType::INT32);
-    auto *int32_data =
-        weight_map[name_with_suffix]->mutable_data<int>(platform::CPUPlace());
+    auto *int32_data = weight_map[name_with_suffix]->mutable_data<int>(
+        platform::CPUPlace());
     auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
     for (int i = 0; i < weight_tensor.numel(); i++) {
       int32_data[i] = int64_data[i];
-- 
GitLab
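
Note on the pattern the patch applies in every dtype branch: copy the weight tensor to CPU, convert it element by element into a type TensorRT accepts, keep the converted copy alive in the engine-owned weight_map, and only then record the data type and data pointer on the returned Weight, so the two always describe the same buffer (previously the dtype was fixed once at the top and the values pointer set once at the end, which could disagree in the fallback branch). Below is a minimal standalone sketch of that idea; WeightOwner, WeightRef, DType, and GetInt32Weight are illustrative names for this sketch only, not Paddle or TensorRT APIs.

// Standalone illustration of the convert-then-describe pattern: narrow an
// int64 weight to int32 (as the new INT64 branch does), keep the converted
// copy alive in an owning cache, and set dtype/values/count together from it.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

enum class DType { kInt32, kInt64 };

struct WeightRef {
  DType dtype;         // must always match the buffer `values` points at
  const void *values;  // non-owning pointer into the owner's cache
  std::size_t count;
};

class WeightOwner {
 public:
  // Converts src to int32, stores the copy so it outlives the returned ref,
  // and fills the ref entirely from the converted buffer.
  WeightRef GetInt32Weight(const std::string &name,
                           const std::vector<int64_t> &src) {
    std::vector<int32_t> &dst = cache_[name];
    dst.resize(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
      dst[i] = static_cast<int32_t>(src[i]);  // element-wise narrowing copy
    }
    return WeightRef{DType::kInt32, dst.data(), dst.size()};
  }

 private:
  std::map<std::string, std::vector<int32_t>> cache_;  // owns converted data
};

int main() {
  WeightOwner owner;
  WeightRef ref = owner.GetInt32Weight("fc.w_0", {1, 2, 3});
  return ref.count == 3 ? 0 : 1;
}

In the real code the analogous pieces are weight_map (the owning cache), TensorCopySync plus the cast loop (the conversion), and weight.SetDataType / SetValues / SetCount (the description of the converted buffer).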