Unverified · Commit f293e36c · Authored by: Z zhupengyang · Committed by: GitHub

[XPU] subgraph support device param copy (#51876)

Parent: 535ddd3d
@@ -227,19 +227,23 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) {
   platform::CPUPlace cpu_place;
   platform::Place xpu_place = platform::XPUPlace(argument->xpu_device_id());
   auto *scope = argument->scope_ptr();
-  framework::ir::Graph &graph = argument->main_graph();
-  for (auto *node : graph.Nodes()) {
-    if (!node->IsVar() || !node->Var()->Persistable()) continue;
-    auto *var = scope->FindVar(node->Name());
-    if (!var->IsType<phi::DenseTensor>()) continue;
-    auto *tensor = var->GetMutable<phi::DenseTensor>();
-
-    phi::DenseTensor temp_tensor;
-    temp_tensor.Resize(tensor->dims());
-    paddle::framework::TensorCopySync(*tensor, cpu_place, &temp_tensor);
-    tensor->clear();
-    paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor);
+  framework::ir::Graph &main_graph = argument->main_graph();
+  for (size_t i = 0; i < main_graph.SubGraphsSize(); i++) {
+    auto *graph = main_graph.GetSubGraph(i);
+    for (auto *node : graph->Nodes()) {
+      if (!node->IsVar() || !node->Var()->Persistable()) continue;
+      auto *var = scope->FindVar(node->Name());
+      if (!var->IsType<phi::DenseTensor>()) continue;
+      auto *tensor = var->GetMutable<phi::DenseTensor>();
+      if (tensor->place().GetType() == phi::AllocationType::XPU) continue;
+
+      phi::DenseTensor temp_tensor;
+      temp_tensor.Resize(tensor->dims());
+      paddle::framework::TensorCopySync(*tensor, cpu_place, &temp_tensor);
+      tensor->clear();
+      paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor);
+    }
   }
 }
 #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册