diff --git a/paddle/fluid/inference/analysis/passes/CMakeLists.txt b/paddle/fluid/inference/analysis/passes/CMakeLists.txt index 98334760a694fab995a9322f1b725caa7307c28d..d3ea511d8f4d8cbec1be57633391f00e29a3e6e9 100644 --- a/paddle/fluid/inference/analysis/passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/passes/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager) cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager) -cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager analysis_helper) +cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager) cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass) set(analysis_deps ${analysis_deps} diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index e42f1350525962b1b7509a9feb029f571ca05e26..8be2d3ac0b105e50fe619a720929dedaacb75537 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -19,16 +19,6 @@ #include "paddle/fluid/platform/enforce.h" namespace paddle { -namespace { -bool IsPersistable(const framework::VarDesc *var) { - if (var->Persistable() && - var->GetType() != framework::proto::VarType::FEED_MINIBATCH && - var->GetType() != framework::proto::VarType::FETCH_LIST) { - return true; - } - return false; -} -} // namespace namespace inference { namespace analysis { @@ -47,32 +37,30 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { place = platform::CUDAPlace(argument->gpu_device_id()); auto *scope = argument->scope_ptr(); - // Get the program which has been processed by several passes. - analysis_program_.reset( - new framework::ProgramDesc(argument->ir_analyzed_program())); - - const auto &global_block = analysis_program_->Block(0); + std::vector all_vars = scope->LocalVarNames(); - // sync the params from cpu to gpu. - for (auto &var : global_block.AllVars()) { - if (IsPersistable(var)) { - std::string var_name = var->Name(); - LOG(INFO) << var_name; - auto &t = inference::analysis::GetFromScope( - *scope, var_name); + // We get all the vars from local_scope instead of the ProgramDesc. + // Because there exists the case that new parameter variables are not added to + // the program in the analysis pass. + for (auto &var_name : all_vars) { + auto *var = scope->FindLocalVar(var_name); + PADDLE_ENFORCE(var != nullptr); + if (var->IsType() || + var->IsType()) { + auto *t = var->GetMutable(); platform::CPUPlace cpu_place; framework::LoDTensor temp_tensor; - temp_tensor.Resize(t.dims()); + temp_tensor.Resize(t->dims()); temp_tensor.mutable_data(cpu_place); // Copy the parameter data to a tmp tensor. - TensorCopySync(t, cpu_place, &temp_tensor); + TensorCopySync(*t, cpu_place, &temp_tensor); // Reallocation the space on GPU - t.mutable_data(place); + t->mutable_data(place); // Copy parameter data to newly allocated GPU space. - TensorCopySync(temp_tensor, place, &t); + TensorCopySync(temp_tensor, place, t); } } } diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h index 6818887b96c246f1c05962531cb639de1cf7a1b1..a95f460df6f9636fc17a5cf76920f5f459385120 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h @@ -15,10 +15,10 @@ #pragma once #include +#include #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/analysis_pass.h" -#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/platform/place.h" namespace paddle { @@ -32,9 +32,6 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass { public: void RunImpl(Argument *argument) override; std::string repr() const override; - - private: - std::unique_ptr analysis_program_; }; } // namespace analysis