未验证 提交 495c1fc0 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Remove_trt_weight_map (#49159)

上级 579784e2
......@@ -697,8 +697,11 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
"twice in TRT OP converter.",
name_with_suffix));
weight_map[name_with_suffix].reset(new phi::DenseTensor());
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
if (weight_tensor.place() == PlaceType::kGPU ||
weight_tensor.dtype() != phi::DataType::FLOAT32) {
weight_map[name_with_suffix].reset(new phi::DenseTensor());
weight_map[name_with_suffix]->Resize(weight_tensor.dims());
}
TensorRTEngine::Weight weight;
weight.SetCount(weight_tensor.numel());
......@@ -735,10 +738,15 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
weight.SetDataType(phi::DataType::INT32);
weight.SetValues(int32_data);
} else {
paddle::framework::TensorCopySync(
weight_tensor, cpu_place, weight_map[name_with_suffix].get());
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_map[name_with_suffix]->data());
if (weight_tensor.place() == PlaceType::kGPU) {
paddle::framework::TensorCopySync(
weight_tensor, cpu_place, weight_map[name_with_suffix].get());
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_map[name_with_suffix]->data());
} else {
weight.SetDataType(weight_tensor.dtype());
weight.SetValues(weight_tensor.data());
}
}
name_suffix_counter += 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册