未验证 提交 495c1fc0 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Remove_trt_weight_map (#49159)

上级 579784e2
...@@ -697,8 +697,11 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight( ...@@ -697,8 +697,11 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
"twice in TRT OP converter.", "twice in TRT OP converter.",
name_with_suffix)); name_with_suffix));
weight_map[name_with_suffix].reset(new phi::DenseTensor()); if (weight_tensor.place() == PlaceType::kGPU ||
weight_map[name_with_suffix]->Resize(weight_tensor.dims()); weight_tensor.dtype() != phi::DataType::FLOAT32) {
weight_map[name_with_suffix].reset(new phi::DenseTensor());
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
}
TensorRTEngine::Weight weight; TensorRTEngine::Weight weight;
weight.SetCount(weight_tensor.numel()); weight.SetCount(weight_tensor.numel());
...@@ -735,10 +738,15 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight( ...@@ -735,10 +738,15 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
weight.SetDataType(phi::DataType::INT32); weight.SetDataType(phi::DataType::INT32);
weight.SetValues(int32_data); weight.SetValues(int32_data);
} else { } else {
paddle::framework::TensorCopySync( if (weight_tensor.place() == PlaceType::kGPU) {
weight_tensor, cpu_place, weight_map[name_with_suffix].get()); paddle::framework::TensorCopySync(
weight.SetDataType(weight_tensor.dtype()); weight_tensor, cpu_place, weight_map[name_with_suffix].get());
weight.SetValues(weight_map[name_with_suffix]->data()); weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_map[name_with_suffix]->data());
} else {
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_tensor.data());
}
} }
name_suffix_counter += 1; name_suffix_counter += 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册