未验证 提交 f293e36c 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] subgraph support device param copy (#51876)

上级 535ddd3d
......@@ -227,13 +227,16 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) {
platform::CPUPlace cpu_place;
platform::Place xpu_place = platform::XPUPlace(argument->xpu_device_id());
auto *scope = argument->scope_ptr();
framework::ir::Graph &graph = argument->main_graph();
framework::ir::Graph &main_graph = argument->main_graph();
for (auto *node : graph.Nodes()) {
for (size_t i = 0; i < main_graph.SubGraphsSize(); i++) {
auto *graph = main_graph.GetSubGraph(i);
for (auto *node : graph->Nodes()) {
if (!node->IsVar() || !node->Var()->Persistable()) continue;
auto *var = scope->FindVar(node->Name());
if (!var->IsType<phi::DenseTensor>()) continue;
auto *tensor = var->GetMutable<phi::DenseTensor>();
if (tensor->place().GetType() == phi::AllocationType::XPU) continue;
phi::DenseTensor temp_tensor;
temp_tensor.Resize(tensor->dims());
......@@ -241,6 +244,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) {
tensor->clear();
paddle::framework::TensorCopySync(temp_tensor, xpu_place, tensor);
}
}
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册