未验证 提交 20d38664 编写于 作者: W Wilber 提交者: GitHub

fix params sync multi times problem (#45406)

上级 9ac27ac3
...@@ -368,6 +368,7 @@ void ProcessInputNode( ...@@ -368,6 +368,7 @@ void ProcessInputNode(
in_var_type == framework::proto::VarType::FP32) { in_var_type == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) return; if (WeightsShouldNotConvert(in_node)) return;
in_var->SetDataType(to_type); in_var->SetDataType(to_type);
in_var_type = to_type;
} else if (!in_var->Persistable() && IsFloatVarType(in_var_type) && } else if (!in_var->Persistable() && IsFloatVarType(in_var_type) &&
in_var_type != to_type) { in_var_type != to_type) {
AddCastOp(graph, AddCastOp(graph,
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include <string>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
...@@ -113,6 +114,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { ...@@ -113,6 +114,7 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
reserve_cpu_weights = true; reserve_cpu_weights = true;
} }
std::unordered_set<std::string> visited;
for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) { for (auto *node : paddle::framework::ir::TopologySortOperations(graph)) {
if (!node->IsOp()) continue; if (!node->IsOp()) continue;
if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue; if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") continue;
...@@ -126,6 +128,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) { ...@@ -126,6 +128,8 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
} }
continue; continue;
} }
if (visited.count(var_name)) continue;
visited.insert(var_name);
auto *var = scope->FindLocalVar(var_name); auto *var = scope->FindLocalVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var, PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册