Commit 3c83a2f7, authored by nhzlx

fix comments

Parent commit: d3e140a5
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
# NOTE: this target was declared twice (diff residue); a duplicate cc_library
# target name is a CMake configure error. Keep the single post-commit rule.
# The pass now reads parameters directly from the Scope, so the extra
# `analysis_helper` dependency is no longer required.
cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass)
set(analysis_deps ${analysis_deps}
......
......@@ -19,16 +19,6 @@
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace {
// Returns true for variables that should be treated as model parameters:
// persistable variables, excluding the framework-internal feed/fetch slots
// (FEED_MINIBATCH and FETCH_LIST), which are persistable but hold no weights.
bool IsPersistable(const framework::VarDesc *var) {
  const auto type = var->GetType();
  return var->Persistable() &&
         type != framework::proto::VarType::FEED_MINIBATCH &&
         type != framework::proto::VarType::FETCH_LIST;
}
} // namespace
namespace inference {
namespace analysis {
......@@ -47,32 +37,30 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
place = platform::CUDAPlace(argument->gpu_device_id());
auto *scope = argument->scope_ptr();
// Get the program which has been processed by several passes.
analysis_program_.reset(
new framework::ProgramDesc(argument->ir_analyzed_program()));
const auto &global_block = analysis_program_->Block(0);
std::vector<std::string> all_vars = scope->LocalVarNames();
// sync the params from cpu to gpu.
for (auto &var : global_block.AllVars()) {
if (IsPersistable(var)) {
std::string var_name = var->Name();
LOG(INFO) << var_name;
auto &t = inference::analysis::GetFromScope<framework::LoDTensor>(
*scope, var_name);
// We get all the vars from local_scope instead of the ProgramDesc.
// Because there exists the case that new parameter variables are not added to
// the program in the analysis pass.
for (auto &var_name : all_vars) {
auto *var = scope->FindLocalVar(var_name);
PADDLE_ENFORCE(var != nullptr);
if (var->IsType<framework::LoDTensor>() ||
var->IsType<framework::Tensor>()) {
auto *t = var->GetMutable<framework::LoDTensor>();
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t.dims());
temp_tensor.Resize(t->dims());
temp_tensor.mutable_data<float>(cpu_place);
// Copy the parameter data to a tmp tensor.
TensorCopySync(t, cpu_place, &temp_tensor);
TensorCopySync(*t, cpu_place, &temp_tensor);
// Reallocation the space on GPU
t.mutable_data<float>(place);
t->mutable_data<float>(place);
// Copy parameter data to newly allocated GPU space.
TensorCopySync(temp_tensor, place, &t);
TensorCopySync(temp_tensor, place, t);
}
}
}
......
......@@ -15,10 +15,10 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
......@@ -32,9 +32,6 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override;
private:
std::unique_ptr<framework::ProgramDesc> analysis_program_;
};
} // namespace analysis
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.