提交 3c83a2f7 编写于 作者: N nhzlx

fix comments

上级 d3e140a5
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager) cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager) cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager analysis_helper) cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass) cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass)
set(analysis_deps ${analysis_deps} set(analysis_deps ${analysis_deps}
......
...@@ -19,16 +19,6 @@ ...@@ -19,16 +19,6 @@
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace {
// Returns true when `var` is a model parameter that should be synchronized
// among devices: a persistable variable that is not one of the feed/fetch
// bookkeeping variables (those are persistable too, but hold I/O plumbing,
// not weights).
//
// NOTE(review): `var` is assumed non-null by the caller — TODO confirm, or
// guard at the call site.
bool IsPersistable(const framework::VarDesc *var) {
  // Collapse the original `if (...) return true; return false;` into a
  // single boolean expression — same truth table, less noise.
  return var->Persistable() &&
         var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
         var->GetType() != framework::proto::VarType::FETCH_LIST;
}
}  // namespace
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -47,32 +37,30 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { ...@@ -47,32 +37,30 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
place = platform::CUDAPlace(argument->gpu_device_id()); place = platform::CUDAPlace(argument->gpu_device_id());
auto *scope = argument->scope_ptr(); auto *scope = argument->scope_ptr();
// Get the program which has been processed by several passes. std::vector<std::string> all_vars = scope->LocalVarNames();
analysis_program_.reset(
new framework::ProgramDesc(argument->ir_analyzed_program()));
const auto &global_block = analysis_program_->Block(0);
// sync the params from cpu to gpu. // We get all the vars from local_scope instead of the ProgramDesc.
for (auto &var : global_block.AllVars()) { // Because there exists the case that new parameter variables are not added to
if (IsPersistable(var)) { // the program in the analysis pass.
std::string var_name = var->Name(); for (auto &var_name : all_vars) {
LOG(INFO) << var_name; auto *var = scope->FindLocalVar(var_name);
auto &t = inference::analysis::GetFromScope<framework::LoDTensor>( PADDLE_ENFORCE(var != nullptr);
*scope, var_name); if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto *t = var->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place; platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor; framework::LoDTensor temp_tensor;
temp_tensor.Resize(t.dims()); temp_tensor.Resize(t->dims());
temp_tensor.mutable_data<float>(cpu_place); temp_tensor.mutable_data<float>(cpu_place);
// Copy the parameter data to a tmp tensor. // Copy the parameter data to a tmp tensor.
TensorCopySync(t, cpu_place, &temp_tensor); TensorCopySync(*t, cpu_place, &temp_tensor);
// Reallocation the space on GPU // Reallocation the space on GPU
t.mutable_data<float>(place); t->mutable_data<float>(place);
// Copy parameter data to newly allocated GPU space. // Copy parameter data to newly allocated GPU space.
TensorCopySync(temp_tensor, place, &t); TensorCopySync(temp_tensor, place, t);
} }
} }
} }
......
...@@ -15,10 +15,10 @@ ...@@ -15,10 +15,10 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
...@@ -32,9 +32,6 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass { ...@@ -32,9 +32,6 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
public: public:
void RunImpl(Argument *argument) override; void RunImpl(Argument *argument) override;
std::string repr() const override; std::string repr() const override;
private:
std::unique_ptr<framework::ProgramDesc> analysis_program_;
}; };
} // namespace analysis } // namespace analysis
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册