未验证 提交 4cdeab7b 编写于 作者: Y Yuanle Liu 提交者: GitHub

fix get trt weight (#49197)

上级 7f0eb2e3
......@@ -582,8 +582,6 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
TensorRTEngine::Weight weight;
weight.SetCount(weight_tensor.numel());
weight.SetDataType(nvinfer1::DataType::kHALF);
// weight_tensor.dims().;
// if trt not support dtype, we need to cast to fp16.
if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
......@@ -593,13 +591,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
weight_tensor, platform::CPUPlace(), &bf16_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::FLOAT16);
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
platform::CPUPlace());
auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
fp16_data[i] = static_cast<float16>(bf16_data[i]);
}
weight.SetDataType(phi::DataType::FLOAT16);
weight.SetValues(fp16_data);
} else if (weight_tensor.dtype() == phi::DataType::FLOAT32) {
phi::DenseTensor fp32_tensor;
fp32_tensor.clear();
......@@ -607,18 +606,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
weight_tensor, platform::CPUPlace(), &fp32_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::FLOAT16);
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
platform::CPUPlace());
auto *fp32_data = fp32_tensor.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
fp16_data[i] = static_cast<float16>(fp32_data[i]);
}
weight.SetDataType(phi::DataType::FLOAT16);
weight.SetValues(fp16_data);
} else if (weight_tensor.dtype() == phi::DataType::INT64) {
phi::DenseTensor int64_tensor;
int64_tensor.clear();
paddle::framework::TensorCopySync(
weight_tensor, platform::CPUPlace(), &int64_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::INT32);
auto *int32_data = weight_map[name_with_suffix]->mutable_data<int32_t>(
platform::CPUPlace());
auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
int32_data[i] = int64_data[i];
}
weight.SetDataType(phi::DataType::INT32);
weight.SetValues(int32_data);
} else {
paddle::framework::TensorCopySync(
weight_tensor, cpu_place, weight_map[name_with_suffix].get());
}
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_map[name_with_suffix]->data());
}
name_suffix_counter += 1;
return weight;
}
......@@ -642,8 +658,6 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
TensorRTEngine::Weight weight;
weight.SetCount(weight_tensor.numel());
weight.SetDataType(nvinfer1::DataType::kFLOAT);
// weight_tensor.dims().;
// if trt not support dtype, we need to cast to fp32.
if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
......@@ -653,13 +667,14 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
weight_tensor, platform::CPUPlace(), &bf16_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::FLOAT32);
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
auto *fp32_data =
weight_map[name_with_suffix]->mutable_data<float>(platform::CPUPlace());
auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
fp32_data[i] = static_cast<float>(bf16_data[i]);
}
weight.SetDataType(phi::DataType::FLOAT32);
weight.SetValues(fp32_data);
} else if (weight_tensor.dtype() == phi::DataType::FLOAT16) {
phi::DenseTensor fp16_tensor;
fp16_tensor.clear();
......@@ -667,18 +682,35 @@ TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
weight_tensor, platform::CPUPlace(), &fp16_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::FLOAT32);
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
auto *fp32_data =
weight_map[name_with_suffix]->mutable_data<float>(platform::CPUPlace());
auto *fp16_data = fp16_tensor.mutable_data<float16>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
fp32_data[i] = static_cast<float>(fp16_data[i]);
}
weight.SetDataType(phi::DataType::FLOAT32);
weight.SetValues(fp32_data);
} else if (weight_tensor.dtype() == phi::DataType::INT64) {
phi::DenseTensor int64_tensor;
int64_tensor.clear();
paddle::framework::TensorCopySync(
weight_tensor, platform::CPUPlace(), &int64_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::INT32);
auto *int32_data = weight_map[name_with_suffix]->mutable_data<int32_t>(
platform::CPUPlace());
auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
int32_data[i] = int64_data[i];
}
weight.SetDataType(phi::DataType::INT32);
weight.SetValues(int32_data);
} else {
paddle::framework::TensorCopySync(
weight_tensor, cpu_place, weight_map[name_with_suffix].get());
}
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_map[name_with_suffix]->data());
}
name_suffix_counter += 1;
return weight;
}
......@@ -729,8 +761,8 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
weight_tensor, platform::CPUPlace(), &int64_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::INT32);
auto *int32_data =
weight_map[name_with_suffix]->mutable_data<int>(platform::CPUPlace());
auto *int32_data = weight_map[name_with_suffix]->mutable_data<int32_t>(
platform::CPUPlace());
auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
int32_data[i] = int64_data[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册