Unverified commit d3003a16, authored by Zeng Jinle, committed by GitHub

Feature/buffer_shared_inplace (#17911)

* feature/buffer_shared_inplace, test=develop

* refine code, test=develop

* fix elementwise_add op cpu inplace and sum inplace bug, test=develop

* add unittest and debug log, test=develop

* fix parallel_executor scope bug, polish code, test=develop

* fix sum op, activation op, single_in_place_inference bug, test=develop

* remove kLocalExecScopeName, test=develop

* fix unittest,test=develop

* fix out_var first version bug, test=develop

* follow comments,test=develop
Parent commit: 1c10dac4
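At its core, this change removes the old pattern where every op handle located its execution scope at run time through the magic variable kLocalExecScopeName. Instead, each handle now reports the outer scopes it uses via GetLocalScopes(), and the executor resolves them once through OpHandleBase::SetLocalExecScopes(), filling local_exec_scopes_ with the scopes that ops actually run in. A minimal sketch of that pattern, using stand-in types rather than the real Paddle classes:

#include <stdexcept>
#include <unordered_map>
#include <vector>

struct Scope {};  // stand-in for paddle::framework::Scope

class OpHandleSketch {
 public:
  virtual ~OpHandleSketch() = default;

  // scope_map maps each outer local scope to the scope ops should execute in.
  void SetLocalExecScopes(const std::unordered_map<Scope *, Scope *> &scope_map) {
    local_exec_scopes_.clear();
    for (auto *scope : GetLocalScopes()) {
      auto iter = scope_map.find(scope);
      if (iter == scope_map.end()) throw std::runtime_error("Local scope not found");
      local_exec_scopes_.push_back(iter->second);
    }
  }

 protected:
  // Each concrete handle declares which outer scopes it touches.
  virtual std::vector<Scope *> GetLocalScopes() = 0;

  // Scopes that RunImpl() reads and writes, resolved once up front.
  std::vector<Scope *> local_exec_scopes_;
};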
@@ -59,7 +59,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
 cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
-set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
+cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope)
+set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass)
 if (WITH_GPU)
   list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
 endif()
......
@@ -99,10 +99,9 @@ void AllReduceOpHandle::RunImpl() {
     std::vector<const LoDTensor *> lod_tensors;
     for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto *s = local_scopes_[i];
-      auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
+      auto &local_scope = local_exec_scopes_[i];
       auto &lod_tensor =
-          local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
+          local_scope->FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
       lod_tensors.emplace_back(&lod_tensor);
       VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
                << ", out_name:" << out_var_handles[i]->name();
@@ -140,9 +139,7 @@ void AllReduceOpHandle::RunImpl() {
     PADDLE_THROW("Not compiled with CUDA");
 #endif
   } else {  // Special handle CPU only Operator's gradient. Like CRF
-    auto &trg = *this->local_scopes_[0]
-                     ->FindVar(kLocalExecScopeName)
-                     ->Get<Scope *>()
+    auto &trg = *this->local_exec_scopes_[0]
                      ->FindVar(out_var_handles[0]->name())
                      ->GetMutable<framework::LoDTensor>();
@@ -151,10 +148,9 @@ void AllReduceOpHandle::RunImpl() {
     VisitDataType(lod_tensors[0]->type(), func);
     for (size_t i = 1; i < local_scopes_.size(); ++i) {
-      auto &scope =
-          *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+      auto &scope = local_exec_scopes_[i];
       auto &p = places_[i];
-      auto *var = scope.FindVar(out_var_handles[i]->name());
+      auto *var = scope->FindVar(out_var_handles[i]->name());
       auto *dev_ctx = dev_ctxes_.at(p);
       RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
......
@@ -49,6 +49,9 @@ class AllReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
   std::vector<Scope *> local_scopes_;
 #if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
......
@@ -24,22 +24,20 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
-                                    Scope *scope) {
-  VLOG(3) << "NewTempScopeAndInitVars";
-  Scope &local_scope = scope->NewScope();
-  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
-      &local_scope;
+inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
+                            Scope *local_scope) {
+  VLOG(3) << "InitVarsInScope";
   for (auto &info : var_infos) {
-    if (scope->FindVar(info.name_) != nullptr) {
-      continue;
-    }
     if (info.persistable_) {  // Persistable
+      auto *var = scope->FindVar(info.name_);
+      if (var != nullptr) {
+        VLOG(2) << info.name_
+                << " has been initialized beforehand in global scope, skipped";
+        continue;
+      }
       InitializeVariable(scope->Var(info.name_), info.type_);
     } else {
-      InitializeVariable(local_scope.Var(info.name_), info.type_);
+      InitializeVariable(local_scope->Var(info.name_), info.type_);
     }
   }
 }
@@ -101,14 +99,17 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
+    const std::vector<Scope *> &local_exec_scopes,
     const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
+      local_exec_scopes_(local_exec_scopes),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
       graphs_(std::move(graphs)) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
+  PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size());
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
@@ -118,7 +119,8 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
           << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
+        strategy_, {local_scopes_[i]}, {local_exec_scopes_[i]}, {places_[i]},
+        graphs_[i]));
   }
 
   for (auto &node : graphs_[0]->Nodes()) {
@@ -129,8 +131,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
       var_infos_.back().persistable_ = node->Var()->Persistable();
     }
   }
-  for (auto *scope : local_scopes_) {
-    NewTempScopeAndInitVars(var_infos_, scope);
+
+  for (size_t i = 0; i < local_scopes_.size(); ++i) {
+    InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]);
   }
   ProcessGraph(graphs_, local_scopes_[0]);
 }
......
@@ -36,6 +36,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
  public:
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                         const std::vector<Scope *> &local_scopes,
+                        const std::vector<Scope *> &local_exec_scopes,
                         const std::vector<platform::Place> &places,
                         std::vector<ir::Graph *> graphs);
   ~AsyncSSAGraphExecutor() final = default;
@@ -50,6 +51,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
  private:
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
+  std::vector<Scope *> local_exec_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
   std::vector<ir::Graph *> graphs_;
......
@@ -40,18 +40,13 @@ void BroadcastOpHandle::RunImpl() {
   WaitInputVarGenerated();
 
-  std::vector<const Scope *> var_scopes;
-  for (auto *s : local_scopes_) {
-    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
-  }
-
-  BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
+  BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
 }
 
 void BroadcastOpHandle::BroadcastOneVar(
     const VarHandle &in_var_handle,
     const std::vector<VarHandle *> &out_var_handles,
-    const std::vector<const Scope *> &var_scopes) {
+    const std::vector<Scope *> &var_scopes) {
   auto *in_var =
       var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
   PADDLE_ENFORCE_NOT_NULL(in_var);
@@ -140,10 +135,7 @@ void BroadcastOpHandle::BroadcastOneVar(
 void BroadcastOpHandle::InitOutputValue(
     const VarHandle &in_var_handle,
     const std::vector<VarHandle *> &out_var_handles) const {
-  std::vector<const Scope *> var_scopes;
-  for (auto *s : local_scopes_) {
-    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
-  }
+  auto &var_scopes = local_exec_scopes_;
   auto *in_var =
       var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
......
@@ -62,9 +62,11 @@ struct BroadcastOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
   void BroadcastOneVar(const VarHandle &in_var_handle,
                        const std::vector<VarHandle *> &out_var_handles,
-                       const std::vector<const Scope *> &var_scopes);
+                       const std::vector<Scope *> &var_scopes);
 
   std::vector<Scope *> local_scopes_;
   std::vector<platform::Place> places_;
......
@@ -14,7 +14,9 @@
 #pragma once
 
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "gtest/gtest.h"
@@ -92,14 +94,13 @@ struct TestBroadcastOpHandle {
   void InitBroadcastOp(size_t input_scope_idx) {
     nodes_.clear();
+    std::unordered_map<Scope*, Scope*> scope_map;
     for (size_t j = 0; j < place_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
-      *local_scopes_.back()
-           ->Var(details::kLocalExecScopeName)
-           ->GetMutable<Scope*>() = &local_scope;
       local_scope.Var("out");
       param_scopes_.emplace_back(&local_scope);
+      scope_map.emplace(local_scopes_.back(), param_scopes_.back());
     }
     param_scopes_[input_scope_idx]->Var("input");
@@ -122,6 +123,8 @@ struct TestBroadcastOpHandle {
 #endif
     }
 
+    op_handle_->SetLocalExecScopes(scope_map);
+
     nodes_.emplace_back(
         ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
     auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
......
@@ -92,16 +92,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("fuse_relu_depthwise_conv_pass");
     }
 
-    // NOTE(dzhwinter): A note for automatical inplace.
-    // 1. modify program desc passes should put
-    // before inplace pass.
-    // 2. manually configured inplace should put
-    // before inplace_pass
-    // Add automatically inplace.
-    if (strategy_.enable_inplace_) {
-      VLOG(1) << "Add inplace_pass";
-      AppendPass("inplace_pass");
+    // TODO(zjl): refactor MemoryOptimizePass to fit
+    // new strategy, which does not need to set
+    // var.persistable = True
+    if (strategy_.use_legacy_memory_optimize_strategy_) {
+      if (strategy_.enable_inplace_) {
+        VLOG(5) << "Add inplace_pass";
+        AppendPass("inplace_pass");
+      }
     }
 
     if (strategy_.fuse_elewise_add_act_ops_) {
@@ -160,9 +158,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     // the de-fact IR, any reuse on Graph is meaningless.
     // A side-effect of that, memory optimize cannot forsee the fetched vars
     // , so fetchlist should be set persistable before call the Run interface.
-    if (strategy_.memory_optimize_) {
-      VLOG(1) << "Add memory_optimize_pass";
-      AppendPass("memory_optimize_pass");
+    if (strategy_.use_legacy_memory_optimize_strategy_) {
+      if (strategy_.memory_optimize_) {
+        VLOG(5) << "Add memory_optimize_pass";
+        AppendPass("memory_optimize_pass");
+      }
     }
 
     // runtime_context_cache pass should be the last pass to enable the attr of
......
@@ -114,7 +114,12 @@ struct BuildStrategy {
   // it is not appropriate, because kStaleProgramOpDescs will be removed in the
   // near future.
   bool memory_optimize_{false};
-  bool enable_inplace_{false};
+
+  // Turn on inplace by default.
+  bool enable_inplace_{true};
+
+  // TODO(zjl): Remove this flag when MemoryOptimizePass is refactored
+  bool use_legacy_memory_optimize_strategy_{false};
 
   // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
   // num_trainers is 1, so the current fields of build_strategy doesn't tell if
......
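The two flags added above interact with the pass builder changes earlier in this commit: the legacy inplace_pass and memory_optimize_pass are only appended when use_legacy_memory_optimize_strategy_ is explicitly enabled, while enable_inplace_ now defaults to true for the new buffer-shared inplace path. A small self-contained sketch of that gating, using hypothetical names rather than the real BuildStrategy and pass builder:

#include <iostream>
#include <string>

// Stand-ins for BuildStrategy and ir::PassBuilder::AppendPass.
struct StrategySketch {
  bool memory_optimize_{false};
  bool enable_inplace_{true};                        // now on by default
  bool use_legacy_memory_optimize_strategy_{false};  // legacy path is opt-in
};

void AppendPassSketch(const std::string &name) {
  std::cout << "AppendPass(" << name << ")\n";
}

void BuildLegacyMemoryPasses(const StrategySketch &s) {
  if (s.use_legacy_memory_optimize_strategy_) {
    if (s.enable_inplace_) AppendPassSketch("inplace_pass");
    if (s.memory_optimize_) AppendPassSketch("memory_optimize_pass");
  }
  // With the defaults, nothing is appended here and the new
  // buffer_shared_inplace_op_pass path is used instead.
}

int main() { BuildLegacyMemoryPasses(StrategySketch{}); }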
@@ -31,9 +31,7 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
 void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
 
-  auto run_func = [this]() {
-    op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
-  };
+  auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
 
   if (is_lock_and_record_event_free_) {
     run_func();
......
@@ -38,6 +38,8 @@ class ComputationOpHandle : public OpHandleBase {
   const Scope *GetScope() const { return scope_; }
 
+  Scope *GetScope() { return scope_; }
+
   const platform::Place &GetPlace() const { return place_; }
 
   void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
@@ -49,6 +51,8 @@ class ComputationOpHandle : public OpHandleBase {
   bool NeedWait(VarHandleBase *in_var) override;
 
+  std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
+
  private:
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
......
@@ -17,6 +17,7 @@
 #include <utility>
 
 #include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
+#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
@@ -30,14 +31,13 @@ namespace framework {
 namespace details {
 
 EagerDeletionOpHandle::EagerDeletionOpHandle(
-    ir::Node *node, const Scope *scope, const platform::Place &place,
-    const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
-    ir::AtomicReferenceCountMap *ref_cnts)
+    ir::Node *node, Scope *scope, const platform::Place &place,
+    const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
     : OpHandleBase(node),
       scope_(scope),
-      var_names_(var_names.begin(), var_names.end()),
-      gc_(gc),
-      ref_cnts_(ref_cnts) {
+      place_(place),
+      var_infos_(vars.begin(), vars.end()),
+      gc_(gc) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(place)) {
     dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
@@ -50,7 +50,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
     }
   }
 #endif
-  PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
+  PADDLE_ENFORCE(!vars.empty(), "Var names cannot be empty");
+  for (auto *var : var_infos_) {
+    PADDLE_ENFORCE_NOT_NULL(var);
+  }
 }
 
 EagerDeletionOpHandle::~EagerDeletionOpHandle() {
@@ -63,30 +66,43 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
 #endif
 }
 
+void EagerDeletionOpHandle::InitCUDA() {
+#ifdef PADDLE_WITH_CUDA
+  int dev_id =
+      boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
+  events_[dev_id] = nullptr;
+#endif
+}
+
+void EagerDeletionOpHandle::CallOnce() {
+  PADDLE_ENFORCE(vars_.empty(), "vars_ must be initialized here");
+  Scope *exec_scope = local_exec_scopes_[0];
+  for (auto *var_info : var_infos_) {
+    auto *var = exec_scope->FindVar(var_info->Name());
+    PADDLE_ENFORCE_NOT_NULL(var, "Variable %s should not be nullptr",
+                            var_info->Name());
+    vars_.emplace_back(var);
+  }
+}
+
 std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
 
 void EagerDeletionOpHandle::RunImpl() {
+  if (vars_.size() != var_infos_.size()) {
+    CallOnce();
+  }
+
   platform::RecordEvent record_event(Name());
-  Scope *exec_scope = nullptr;
   std::deque<std::shared_ptr<memory::Allocation>> garbages;
-  for (auto &name : var_names_) {
-    auto it = ref_cnts_->find(name);
-    // Reference count has not decreased to 0
-    if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
+  for (size_t i = 0; i < var_infos_.size(); ++i) {
+    auto *var_info = var_infos_[i];
+    if (var_info->IsSkipped() || !var_info->DecreaseRefCnt()) {
       continue;
     }
 
-    if (!exec_scope) {
-      exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-    }
-
-    // Var not found
-    auto *var = exec_scope->FindVar(name);
-    if (var == nullptr) {
-      continue;
-    }
+    VLOG(2) << "Erase variable " << var_info->Name() << " on " << place_;
 
-    VLOG(2) << "Erase variable " << name;
+    Variable *var = vars_[i];
     if (var->IsType<LoDTensor>()) {
       garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
@@ -100,7 +116,7 @@ void EagerDeletionOpHandle::RunImpl() {
       }
     } else {
       PADDLE_THROW("Type %s of %s is not supported eager deletion",
-                   framework::ToTypeName(var->Type()), name);
+                   framework::ToTypeName(var->Type()), var_info->Name());
     }
   }
......
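RunImpl above now delegates the delete-or-skip decision to ir::MemOptVarInfo, whose IsSkipped(), DecreaseRefCnt(), and Name() methods are used here but defined in memory_optimization_var_info.h, which is not part of this diff. The stand-in class below is only a plausible sketch of the intended semantics: the op that performs the last use of a variable wins the right to collect it, and variables reused by the buffer-shared inplace pass are skipped entirely.

#include <atomic>
#include <cstddef>
#include <string>

// Hypothetical stand-in for ir::MemOptVarInfo; the real class may differ.
class MemOptVarInfoSketch {
 public:
  MemOptVarInfoSketch(std::string name, size_t ref_cnt)
      : name_(std::move(name)), runtime_ref_cnt_(ref_cnt) {}

  // Returns true only for the caller that drops the last reference, so at most
  // one eager-deletion handle erases the variable in an iteration.
  bool DecreaseRefCnt() {
    return runtime_ref_cnt_.fetch_sub(1, std::memory_order_relaxed) == 1;
  }

  // Variables whose buffers are shared by the inplace pass must not be freed.
  bool IsSkipped() const { return skipped_; }
  void SetSkipped(bool skipped) { skipped_ = skipped; }

  const std::string &Name() const { return name_; }

 private:
  std::string name_;
  std::atomic<size_t> runtime_ref_cnt_;
  bool skipped_{false};
};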
@@ -26,15 +26,18 @@ namespace paddle {
 namespace framework {
 class Scope;
 
+namespace ir {
+class MemOptVarInfo;
+}  // namespace ir
+
 namespace details {
 
 class EagerDeletionOpHandle : public OpHandleBase {
  public:
-  EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
+  EagerDeletionOpHandle(ir::Node *node, Scope *scope,
                         const platform::Place &place,
-                        const std::unordered_set<std::string> &var_names,
-                        GarbageCollector *gc,
-                        ir::AtomicReferenceCountMap *ref_cnts);
+                        const std::unordered_set<ir::MemOptVarInfo *> &vars,
+                        GarbageCollector *gc);
 
   ~EagerDeletionOpHandle();
@@ -50,13 +53,20 @@ class EagerDeletionOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  void InitCUDA() override;
+
+  std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
+
  private:
   void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
 
-  const Scope *scope_;
-  std::vector<std::string> var_names_;
-  GarbageCollector *gc_;                   // not own
-  ir::AtomicReferenceCountMap *ref_cnts_;  // not own
+  void CallOnce();
+
+  Scope *scope_;
+  platform::Place place_;
+  std::vector<ir::MemOptVarInfo *> var_infos_;  // not own
+  GarbageCollector *gc_;                        // not own
+  std::vector<Variable *> vars_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::CUDADeviceContext *dev_ctx_{nullptr};
   cudaEvent_t event_{nullptr};
......
@@ -28,9 +28,11 @@ namespace details {
 FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
+    const std::vector<Scope *> &local_exec_scopes,
     const std::vector<platform::Place> &places, ir::Graph *graph)
     : strategy_(strategy),
       local_scopes_(local_scopes),
+      local_exec_scopes_(local_exec_scopes),
       places_(places),
       graph_(graph),
       fetch_ctxs_(places),
@@ -143,7 +145,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
     ir::Node *fetch_node =
         graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
-    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
+    auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
+                                 &local_exec_scopes_);
     fetch_ops->emplace_back(op);
 
     for (auto &p : places_) {
......
@@ -33,6 +33,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
  public:
   FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
                                const std::vector<Scope *> &local_scopes,
+                               const std::vector<Scope *> &local_exec_scopes,
                                const std::vector<platform::Place> &places,
                                ir::Graph *graph);
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
@@ -43,6 +44,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
   // be destroyed first.
   ExecutionStrategy strategy_;
   std::vector<Scope *> local_scopes_;
+  std::vector<Scope *> local_exec_scopes_;
   std::vector<platform::Place> places_;
   ir::Graph *graph_;
......
@@ -42,9 +42,7 @@ bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
 void FetchBarrierOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
 
-  auto run_func = [this]() {
-    op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
-  };
+  auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
 
   if (is_lock_and_record_event_free_) {
     run_func();
......
@@ -44,6 +44,8 @@ struct FetchBarrierOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
   bool NeedWait(VarHandleBase *in_var) override;
 
  private:
......
@@ -22,11 +22,13 @@ namespace framework {
 namespace details {
 
 FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
-                             std::vector<Scope *> *local_scopes)
+                             std::vector<Scope *> *local_scopes,
+                             std::vector<Scope *> *local_exec_scopes)
     : OpHandleBase(node),
       data_(data),
       offset_(offset),
-      local_scopes_(local_scopes) {}
+      local_scopes_(local_scopes),
+      local_exec_scopes_(local_exec_scopes) {}
 
 FetchOpHandle::~FetchOpHandle() {}
@@ -49,14 +51,12 @@ void FetchOpHandle::RunImpl() {
   tensors_.resize(inputs_.size());
   platform::CPUPlace cpu;
-  auto &scopes = *local_scopes_;
+  auto &scopes = *local_exec_scopes_;
 
   for (size_t i = 0; i < inputs_.size(); ++i) {
     auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
     auto &scope = scopes.at(var_handle->scope_idx());
-    auto *var = scope->FindVar(kLocalExecScopeName)
-                    ->Get<Scope *>()
-                    ->FindVar(var_handle->name());
+    auto *var = scope->FindVar(var_handle->name());
     PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
                             var_handle->name());
......
@@ -29,7 +29,8 @@ namespace details {
 struct FetchOpHandle : public OpHandleBase {
  public:
   FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
-                std::vector<Scope *> *local_scopes);
+                std::vector<Scope *> *local_scopes,
+                std::vector<Scope *> *local_exec_scopes);
 
   ~FetchOpHandle();
@@ -44,12 +45,15 @@ struct FetchOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
+
   void WaitInputVarGenerated(const platform::Place &place) override;
 
  private:
   FeedFetchList *data_;
   size_t offset_;
   std::vector<Scope *> *local_scopes_;
+  std::vector<Scope *> *local_exec_scopes_;
   std::vector<LoDTensor> tensors_;
 };
......
@@ -185,9 +185,7 @@ void FusedAllReduceOpHandle::RunImpl() {
   } else {
     // Special handle CPU only Operator's gradient. Like CRF
     auto grad_name = grads_tensor.at(0).at(0).first;
-    auto &trg = *this->local_scopes_[0]
-                     ->FindVar(kLocalExecScopeName)
-                     ->Get<Scope *>()
+    auto &trg = *this->local_exec_scopes_[0]
                      ->FindVar(grad_name)
                      ->GetMutable<framework::LoDTensor>();
@@ -195,9 +193,8 @@ void FusedAllReduceOpHandle::RunImpl() {
     ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
     VisitDataType(trg.type(), func);
 
-    for (size_t i = 1; i < local_scopes_.size(); ++i) {
-      auto &scope =
-          *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+    for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
+      auto &scope = *local_exec_scopes_[i];
       auto &p = places_[i];
       auto *var = scope.FindVar(grad_name);
       auto *dev_ctx = dev_ctxes_.at(p);
@@ -215,8 +212,7 @@ void FusedAllReduceOpHandle::GetGradLoDTensor(
     const size_t &scope_idx, const std::vector<VarHandle *> &in_var_handles,
     const std::vector<VarHandle *> &out_var_handles,
     std::vector<std::pair<std::string, const LoDTensor *>> *grad_tensor) const {
-  auto *local_scope =
-      local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get<Scope *>();
+  auto *local_scope = local_exec_scopes_[scope_idx];
   size_t place_num = places_.size();
 
   for (size_t j = 0; j < in_var_handles.size(); j += place_num) {
......
@@ -52,6 +52,8 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
  private:
   std::vector<Scope *> local_scopes_;
 #if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
......
@@ -31,11 +31,6 @@ void FusedBroadcastOpHandle::RunImpl() {
   WaitInputVarGenerated();
 
-  std::vector<const Scope *> var_scopes;
-  for (auto *s : local_scopes_) {
-    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
-  }
-
   size_t place_num = places_.size();
   PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
@@ -44,7 +39,7 @@ void FusedBroadcastOpHandle::RunImpl() {
         *in_var_handles[i],
         std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
                                  out_var_handles.begin() + (i + 1) * place_num),
-        var_scopes);
+        local_exec_scopes_);
   }
 }
......
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
+#include <memory>
+#include <unordered_map>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle_test.h"
@@ -27,17 +29,16 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
   void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
     nodes_.clear();
     // initialize scope and var
+    std::unordered_map<Scope*, Scope*> scope_map;
     for (size_t i = 0; i < place_list_.size(); ++i) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
-      *local_scopes_.back()
-           ->Var(details::kLocalExecScopeName)
-           ->GetMutable<Scope*>() = &local_scope;
       for (size_t j = 0; j < input_scope_idxes.size(); ++j) {
         local_scope.Var("out_var" + std::to_string(j));
         if (i == j) local_scope.Var("in_var" + std::to_string(j));
       }
       param_scopes_.emplace_back(&local_scope);
+      scope_map.emplace(local_scopes_.back(), param_scopes_.back());
     }
 
     // create op handle node
@@ -60,6 +61,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
 #endif
     }
 
+    op_handle_->SetLocalExecScopes(scope_map);
+
     for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
       // add input var handle
       nodes_.emplace_back(ir::CreateNodeForTest("in_node" + std::to_string(i),
......
@@ -42,10 +42,7 @@ void GatherOpHandle::RunImpl() {
     out_var_handle = out_var_handles.front();
   }
 
-  std::vector<const Scope *> var_scopes;
-  for (auto *s : local_scopes_) {
-    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
-  }
+  auto &var_scopes = local_exec_scopes_;
 
   auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
......
@@ -40,6 +40,8 @@ struct GatherOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
......
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/gather_op_handle.h"
+#include <memory>
+#include <unordered_map>
 #include "gtest/gtest.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -72,14 +74,13 @@ struct TestGatherOpHandle {
   void InitGatherOp(size_t input_scope_idx) {
     nodes_.clear();
+    std::unordered_map<Scope*, Scope*> scope_map;
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope& local_scope = local_scopes_.back()->NewScope();
-      *local_scopes_.back()
-           ->Var(details::kLocalExecScopeName)
-           ->GetMutable<Scope*>() = &local_scope;
       local_scope.Var("input");
       param_scopes_.emplace_back(&local_scope);
+      scope_map.emplace(local_scopes_.back(), param_scopes_.back());
     }
     param_scopes_[input_scope_idx]->Var("out");
@@ -87,6 +88,9 @@ struct TestGatherOpHandle {
         ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
     op_handle_ =
         new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
+
+    op_handle_->SetLocalExecScopes(scope_map);
+
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
......
@@ -35,49 +35,55 @@ std::string OpHandleBase::DebugString() const {
 OpHandleBase::~OpHandleBase() {
 #ifdef PADDLE_WITH_CUDA
   for (auto &ev : events_) {
-    PADDLE_ENFORCE(cudaEventDestroy(ev.second));
+    if (ev.second) {
+      PADDLE_ENFORCE(cudaEventDestroy(ev.second));
+    }
   }
 #endif
 }
 
-void OpHandleBase::Run(bool use_cuda) {
-#ifdef PADDLE_WITH_CUDA
-  if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
-    for (auto &p : dev_ctxes_) {
-      int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-      PADDLE_ENFORCE(cudaSetDevice(dev_id));
-      PADDLE_ENFORCE(
-          cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
-    }
-    if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
-      for (auto &out_var : outputs_) {
-        auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
-        if (out_var_handle) {
-          int dev_id =
-              boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
-          out_var_handle->SetGenerateEvent(events_.at(dev_id));
-        }
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
-                        "%s should have only one dev_ctx.", Name());
-      auto &place = dev_ctxes_.begin()->first;
-      int dev_id = boost::get<platform::CUDAPlace>(place).device;
-      for (auto &out_var : outputs_) {
-        auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
-        if (out_var_handle) {
-          PADDLE_ENFORCE(
-              platform::is_same_place(place, out_var_handle->place()),
-              "The place of output(%s) is not consistent with the "
-              "place of current op(%s).",
-              out_var_handle->Name(), Name());
-          out_var_handle->SetGenerateEvent(events_.at(dev_id));
-        }
-      }
-    }
-  }
+void OpHandleBase::InitCUDA() {
+#ifdef PADDLE_WITH_CUDA
+  for (auto &p : dev_ctxes_) {
+    int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
+    PADDLE_ENFORCE(cudaSetDevice(dev_id));
+    PADDLE_ENFORCE(
+        cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
+  }
+  if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
+    for (auto &out_var : outputs_) {
+      auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
+      if (out_var_handle) {
+        int dev_id =
+            boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
+        out_var_handle->SetGenerateEvent(events_.at(dev_id));
+      }
+    }
+  } else {
+    PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
+                      "%s should have only one dev_ctx.", Name());
+    auto &place = dev_ctxes_.begin()->first;
+    int dev_id = boost::get<platform::CUDAPlace>(place).device;
+    for (auto &out_var : outputs_) {
+      auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
+      if (out_var_handle) {
+        PADDLE_ENFORCE(platform::is_same_place(place, out_var_handle->place()),
+                       "The place of output(%s) is not consistent with the "
+                       "place of current op(%s).",
+                       out_var_handle->Name(), Name());
+        out_var_handle->SetGenerateEvent(events_.at(dev_id));
+      }
+    }
+  }
+#endif
+}
+
+void OpHandleBase::Run(bool use_cuda) {
+#ifdef PADDLE_WITH_CUDA
+  if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
+    InitCUDA();
+  }
 #else
   PADDLE_ENFORCE(!use_cuda);
 #endif
@@ -232,6 +238,17 @@ size_t OpHandleBase::NotReadyInputSize() const {
   return res.size();
 }
 
+void OpHandleBase::SetLocalExecScopes(
+    const std::unordered_map<Scope *, Scope *> &scope_map) {
+  local_exec_scopes_.clear();
+  auto scopes = GetLocalScopes();
+  for (auto *scope : scopes) {
+    auto iter = scope_map.find(scope);
+    PADDLE_ENFORCE(iter != scope_map.end(), "Local scope not found");
+    local_exec_scopes_.emplace_back(iter->second);
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
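SetLocalExecScopes() expects a map from each outer local scope to the scope in which ops actually execute. Building that map is the caller's responsibility; the ParallelExecutor side of this wiring is not shown in this diff, so the sketch below uses stand-in types and an assumed helper name purely for illustration:

#include <unordered_map>
#include <vector>

struct Scope {};  // stand-in for paddle::framework::Scope

// Hypothetical helper: pairs local_scopes[i] with local_exec_scopes[i], the
// same ordering the executors in this commit are constructed with.
std::unordered_map<Scope *, Scope *> BuildScopeMap(
    const std::vector<Scope *> &local_scopes,
    const std::vector<Scope *> &local_exec_scopes) {
  std::unordered_map<Scope *, Scope *> scope_map;
  for (size_t i = 0; i < local_scopes.size(); ++i) {
    scope_map.emplace(local_scopes[i], local_exec_scopes[i]);
  }
  // Every op handle in the graph would then receive this map via
  // op_handle->SetLocalExecScopes(scope_map).
  return scope_map;
}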
@@ -25,9 +25,10 @@
 namespace paddle {
 namespace framework {
-namespace details {
 
-constexpr char kLocalExecScopeName[] = "@LOCAL_EXE_SCOPE@";
+class Scope;
+
+namespace details {
 
 // Wraps ir::Node and provide helper utilities.
 // It's responsible for populating necessary fields of ir::Node.
@@ -107,7 +108,12 @@ class OpHandleBase {
   ir::Node *Node() { return node_; }
 
+  void SetLocalExecScopes(
+      const std::unordered_map<Scope *, Scope *> &scope_map);
+
  protected:
+  virtual std::vector<Scope *> GetLocalScopes() = 0;
+
   void RunAndRecordEvent(const std::function<void()> &callback);
 
   void RunAndRecordEvent(platform::Place p,
@@ -115,11 +121,15 @@ class OpHandleBase {
   virtual void RunImpl() = 0;
 
+  virtual void InitCUDA();
+
   ir::Node *node_;
   std::vector<VarHandleBase *> inputs_;
   std::vector<VarHandleBase *> outputs_;
   std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
 
+  std::vector<Scope *> local_exec_scopes_;
+
 #ifdef PADDLE_WITH_CUDA
   std::unordered_map<int, cudaEvent_t> events_;
 #endif
......
@@ -83,6 +83,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
 ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
+    const std::vector<Scope *> &local_exec_scopes,
     const std::vector<platform::Place> &places, ir::Graph *graph)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
@@ -108,10 +109,20 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
           << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::FastThreadedSSAGraphExecutor(
-        strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get()));
+        strategy_, local_scopes_, local_exec_scopes, {places_[i]},
+        graphs_.at(i).get()));
   }
 }
 
+std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
+  std::vector<ir::Graph *> result;
+  result.reserve(graphs_.size());
+  for (auto &g : graphs_) {
+    result.emplace_back(g.get());
+  }
+  return result;
+}
+
 FeedFetchList ParallelSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   std::vector<std::future<FeedFetchList>> run_futures;
......
@@ -30,12 +30,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
  public:
   ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
                            const std::vector<Scope *> &local_scopes,
+                           const std::vector<Scope *> &local_exec_scopes,
                            const std::vector<platform::Place> &places,
                            ir::Graph *graph);
   ~ParallelSSAGraphExecutor() final = default;
 
   const ir::Graph &Graph() const override { return *graphs_[0]; }
 
+  std::vector<ir::Graph *> Graphs();
+
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
  private:
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include <memory>
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
@@ -160,10 +161,7 @@ void ReduceOpHandle::RunImpl() {
   auto in_0_handle = in_var_handles[0];
 
-  std::vector<const Scope *> var_scopes;
-  for (auto *s : local_scopes_) {
-    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
-  }
+  auto &var_scopes = local_exec_scopes_;
 
   auto pre_in_var =
       var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
@@ -250,9 +248,7 @@ void ReduceOpHandle::RunImpl() {
     } else {
       // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
       // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
-      auto &reduce_sum_trg = *this->local_scopes_[0]
-                                  ->FindVar(kLocalExecScopeName)
-                                  ->Get<Scope *>()
+      auto &reduce_sum_trg = *this->local_exec_scopes_[0]
                                   ->FindVar(out_var_handle->name())
                                   ->GetMutable<framework::LoDTensor>();
       ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
@@ -317,7 +313,7 @@ void ReduceOpHandle::RunImpl() {
 template <typename T>
 std::vector<const T *> ReduceOpHandle::GetInputValues(
     const std::vector<VarHandle *> &in_var_handles,
-    const std::vector<const Scope *> &var_scopes) const {
+    const std::vector<Scope *> &var_scopes) const {
   std::vector<const T *> in_selected_rows;
   for (auto *in_handle : in_var_handles) {
     auto &in_sr = var_scopes.at(in_handle->scope_idx())
......
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <map>
+#include <memory>
 #include <string>
 #include <vector>
@@ -90,6 +91,8 @@ struct ReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
+
 #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
   template <typename DevCtx, typename DataType>
   void GatherSelectedRows(
@@ -106,7 +109,7 @@ struct ReduceOpHandle : public OpHandleBase {
   template <typename T>
   std::vector<const T *> GetInputValues(
       const std::vector<VarHandle *> &in_var_handles,
-      const std::vector<const Scope *> &var_scopes) const;
+      const std::vector<Scope *> &var_scopes) const;
 };
 
 }  // namespace details
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include <unordered_map>
 #include "gtest/gtest.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -86,14 +87,13 @@ struct TestReduceOpHandle {
   void InitReduceOp(size_t out_scope_idx) {
     std::vector<std::unique_ptr<ir::Node>> nodes;
     // init scope
+    std::unordered_map<Scope *, Scope *> scope_map;
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
       local_scopes_.push_back(&(g_scope_.NewScope()));
       Scope &local_scope = local_scopes_.back()->NewScope();
-      *local_scopes_.back()
-           ->Var(details::kLocalExecScopeName)
-           ->GetMutable<Scope *>() = &local_scope;
       local_scope.Var("input");
       param_scopes_.emplace_back(&local_scope);
+      scope_map.emplace(local_scopes_.back(), param_scopes_.back());
     }
     param_scopes_[out_scope_idx]->Var("out");
@@ -115,6 +115,8 @@ struct TestReduceOpHandle {
 #endif
     }
 
+    op_handle_->SetLocalExecScopes(scope_map);
+
     // init op handle
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
......
@@ -21,7 +21,7 @@ namespace framework {
 namespace details {
 
 RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
-                         const Scope *local_scope, const std::string &name,
+                         Scope *local_scope, const std::string &name,
                          const platform::Place &place)
     : OpHandleBase(node),
       op_(framework::OpRegistry::CreateOp(op_desc)),
@@ -41,10 +41,7 @@ void RPCOpHandle::RunImpl() {
       in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
     }
   }
-  this->RunAndRecordEvent([this] {
-    op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
-             place_);
-  });
+  this->RunAndRecordEvent([this] { op_->Run(*local_exec_scopes_[0], place_); });
 }
 
 std::string RPCOpHandle::Name() const { return name_; }
......
@@ -14,6 +14,7 @@
 #pragma once
 
+#include <memory>
 #include <string>
 #include <vector>
@@ -29,7 +30,7 @@ namespace details {
 struct RPCOpHandle : public OpHandleBase {
   RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
-              const Scope* local_scope, const std::string& name,
+              Scope* local_scope, const std::string& name,
               const platform::Place& place);
 
   std::string Name() const override;
@@ -41,9 +42,11 @@ struct RPCOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  std::vector<Scope*> GetLocalScopes() override { return {local_scope_}; }
+
  private:
   std::unique_ptr<OperatorBase> op_;
-  const Scope* local_scope_;
+  Scope* local_scope_;
   const std::string name_;
   platform::Place place_;
 };
......
@@ -70,9 +70,9 @@ void ScaleLossGradOpHandle::RunImpl() {
   platform::RecordEvent record_event(Name());
   // Doesn't wait any event
   std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
-  auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
+  auto *tensor =
+      local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
   tensor->Resize(make_ddim({1}));
 
 #ifdef PADDLE_WITH_CUDA
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -36,6 +37,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -36,6 +37,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private: private:
float coeff_; float coeff_;
Scope *scope_; Scope *scope_;
......
...@@ -25,19 +25,24 @@ namespace framework { ...@@ -25,19 +25,24 @@ namespace framework {
namespace details { namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes, ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places, std::vector<Scope *> local_exec_scopes, std::vector<VariableInfo> var_infos,
std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor) std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)), : strategy_(std::move(strategy)),
underlying_executor_(std::move(underlying_executor)), underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)), local_scopes_(std::move(local_scopes)),
local_exec_scopes_(std::move(local_exec_scopes)),
var_infos_(std::move(var_infos)), var_infos_(std::move(var_infos)),
places_(std::move(places)) {} places_(std::move(places)) {
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size());
PrepareLocalExeScopes();
}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run( FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) { if (drop_scope_counter_ == 0) {
platform::RecordEvent e("InitLocalExeScopes"); platform::RecordEvent e("InitLocalVars");
PrepareLocalExeScopes(); InitVariables();
} }
std::vector<framework::LoDTensor> fetch_data; std::vector<framework::LoDTensor> fetch_data;
...@@ -59,39 +64,55 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( ...@@ -59,39 +64,55 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
} }
} }
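// Initialize (or re-initialize) the cached temporary variables in every
// local execution scope before a new round of iterations starts.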
void ScopeBufferedSSAGraphExecutor::InitVariables() {
for (auto &info : tmp_var_infos_) {
for (auto &pair : info) {
InitializeVariable(pair.first, pair.second);
}
}
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() { void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes"); platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0; drop_scope_counter_ = 0;
for (auto p : places_) { for (auto &p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait(); platform::DeviceContextPool::Instance().Get(p)->Wait();
} }
for (auto &scope : local_scopes_) { for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName); local_exec_scopes_[i]->EraseVarsExcept(preserve_vars_[i]);
if (local_scope_var != nullptr) { local_exec_scopes_[i]->DropKids();
auto &local_scope = *local_scope_var->GetMutable<Scope *>(); for (auto &preserve_var : preserve_vars_[i]) {
scope->DeleteScope(local_scope); preserve_var->Clear();
scope->EraseVars({std::string(details::kLocalExecScopeName)});
VLOG(3) << "Drop local execution scope: " << local_scope;
} }
VLOG(3) << "Drop local execution scope: " << local_scopes_[i];
} }
} }
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() { void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes. // Create local scopes.
preserve_vars_.resize(local_scopes_.size());
tmp_var_infos_.resize(local_scopes_.size());
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) { for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it; size_t idx = local_scopes_.size() - 1 - (it - local_scopes_.rbegin());
Scope &local_scope = scope->NewScope(); auto *scope = local_scopes_[idx];
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope; auto *local_scope = local_exec_scopes_[idx];
for (auto &info : var_infos_) { for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable if (info.persistable_) { // Persistable
auto var = scope->FindVar(info.name_);
if (var != nullptr) {
VLOG(2)
<< info.name_
<< " has been initialized beforehand in global scope, skipped";
continue;
}
InitializeVariable(scope->Var(info.name_), info.type_); InitializeVariable(scope->Var(info.name_), info.type_);
} else { } else {
InitializeVariable(local_scope.Var(info.name_), info.type_); Variable *tmp_var = local_scope->Var(info.name_);
preserve_vars_[idx].emplace(tmp_var);
tmp_var_infos_[idx].emplace_back(tmp_var, info.type_);
} }
} }
} }
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <list> #include <list>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
...@@ -39,6 +41,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -39,6 +41,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ScopeBufferedSSAGraphExecutor( ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes, ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<Scope*> local_exec_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places, std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor); std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
...@@ -55,10 +58,18 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -55,10 +58,18 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareLocalExeScopes(); void PrepareLocalExeScopes();
private: private:
void InitVariables();
size_t drop_scope_counter_{0}; size_t drop_scope_counter_{0};
ExecutionStrategy strategy_; ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_; std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_; std::vector<Scope*> local_scopes_;
std::vector<Scope*> local_exec_scopes_;
std::vector<std::unordered_set<Variable*>> preserve_vars_;
std::vector<std::vector<std::pair<Variable*, proto::VarType::Type>>>
tmp_var_infos_;
std::vector<VariableInfo> var_infos_; std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
}; };
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// TODO(zjl): support SelectedRows
static inline const Tensor &GetTensorFromVar(const Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
static inline Tensor *GetMutableTensorFromVar(Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names)
: OpHandleBase(node),
scope_(scope),
scope_idx_(scope_idx),
op_type_(op_type),
in_var_infos_(in_var_infos),
out_var_names_(out_var_names) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size());
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
Add(in_var_infos_[i], out_var_names_[i]);
}
}
std::unordered_set<std::string> ShareTensorBufferOpHandle::ReusedVarSet()
const {
std::unordered_set<std::string> result;
for (auto &in_var_info : in_var_infos_) {
result.insert(in_var_info->Name());
}
return result;
}
void ShareTensorBufferOpHandle::Add(ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name) {
PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr");
PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name,
"in/out cannot have same name: %s", out_var_name);
in_var_infos_.emplace_back(in_var_info);
out_var_names_.emplace_back(out_var_name);
}
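// Buffer sharing launches no CUDA kernels, so register a null event for the
// device rather than creating one.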
void ShareTensorBufferOpHandle::InitCUDA() {
#ifdef PADDLE_WITH_CUDA
int dev_id =
boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
events_[dev_id] = nullptr;
#endif
}
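// Cache the in/out Variable pointers from the local execution scope. This is
// done lazily on the first RunImpl() call, after the executor has prepared
// the local scopes.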
void ShareTensorBufferOpHandle::CallOnce() {
PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be empty before initialization");
Scope *exec_scope = local_exec_scopes_[0];
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
auto *in_var = exec_scope->FindVar(in_var_infos_[i]->Name());
auto *out_var = exec_scope->FindVar(out_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(in_var);
PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_NE(in_var, out_var);
in_out_vars_.emplace_back(in_var, out_var);
}
}
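// For each in/out pair, let the output tensor share the input tensor's
// buffer. If the input variable is skipped (e.g. it will be fetched), the
// previously shared buffer of the output is cleared instead, so that the op
// allocates fresh memory and does not overwrite the variable being fetched.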
void ShareTensorBufferOpHandle::RunImpl() {
if (in_var_infos_.size() != in_out_vars_.size()) {
CallOnce();
}
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
const auto &in_tensor = GetTensorFromVar(in_out_vars_[i].first);
auto *out_tensor = GetMutableTensorFromVar(in_out_vars_[i].second);
auto *in_var_info = in_var_infos_[i];
if (UNLIKELY(in_var_info->IsSkipped())) {
// If in_var was inplaced in the previous batch and we want to fetch
// in_var in the current batch, we have to reset the memory of out_var
// to avoid a wrong calculation result.
if (in_tensor.Holder() == out_tensor->Holder()) {
VLOG(1) << "Clear " << out_var_names_[i]
<< " because you may want to fetch an inplaced variable "
<< in_var_info->Name()
<< " in previous batch: " << in_var_info->Name() << " -> "
<< out_var_names_[i];
out_tensor->clear();
}
} else {
out_tensor->ShareBufferWith(in_tensor);
VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
<< in_var_info->Name() << " -> " << out_var_names_[i];
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
class Variable;
class Scope;
class Tensor;
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace details {
class ShareTensorBufferOpHandle : public OpHandleBase {
public:
ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx,
const std::string &op_type,
const std::vector<ir::MemOptVarInfo *> &in_vars_infos,
const std::vector<std::string> &out_var_names);
std::unordered_set<std::string> ReusedVarSet() const;
Priority GetPriority() const override { return Priority::kHighest; }
size_t GetScopeIdx() const { return scope_idx_; }
void Add(ir::MemOptVarInfo *in_var_info, const std::string &out_var_name);
protected:
std::string Name() const override { return "buffer_share"; }
void RunImpl() final;
void InitCUDA() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
void CallOnce();
Scope *scope_;
size_t scope_idx_;
std::string op_type_;
std::vector<ir::MemOptVarInfo *> in_var_infos_;
std::vector<std::string> out_var_names_;
std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -58,8 +58,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { ...@@ -58,8 +58,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
std::vector<LoDTensor *> outs; std::vector<LoDTensor *> outs;
int k = -1; int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) { for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope = auto *local_scope = local_exec_scopes_[i];
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto original_name = auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name()); paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded; auto encode_var_name = original_name + g_dgc_encoded;
...@@ -135,9 +134,8 @@ int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) { ...@@ -135,9 +134,8 @@ int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto var_name = original_name + g_dgc_k; auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0); PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0]; auto *scope = local_exec_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto var = scope->FindVar(var_name);
auto var = local_scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>(); auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor; return *tensor;
...@@ -151,8 +149,7 @@ bool SparseAllReduceOpHandle::IsEncoded() { ...@@ -151,8 +149,7 @@ bool SparseAllReduceOpHandle::IsEncoded() {
auto step_name = g_dgc_rampup_begin_step; auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0); PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0]; auto *local_scope = local_exec_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto count_var = local_scope->FindVar(counter_name); auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name); auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) { if (count_var == nullptr || step_var == nullptr) {
......
...@@ -22,9 +22,11 @@ namespace framework { ...@@ -22,9 +22,11 @@ namespace framework {
namespace details { namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph) const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph), : graph_(graph),
local_scopes_(local_scopes), local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places), places_(places),
fetch_ctxs_(places), fetch_ctxs_(places),
strategy_(strategy), strategy_(strategy),
...@@ -176,7 +178,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ...@@ -176,7 +178,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node = ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_); auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_,
&local_exec_scopes_);
fetch_ops->emplace_back(op); fetch_ops->emplace_back(op);
for (auto &p : places_) { for (auto &p : places_) {
......
...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
ir::Graph *graph); ir::Graph *graph);
...@@ -71,6 +72,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -71,6 +72,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
// be destroyed first. // be destroyed first.
ir::Graph *graph_; ir::Graph *graph_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_; ExceptionHolder exception_holder_;
......
...@@ -48,35 +48,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference { ...@@ -48,35 +48,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
public: public:
std::unordered_map<std::string, std::string> operator()( std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override { const OpDesc& op_desc, bool use_cuda) const override {
PADDLE_ENFORCE(!op_desc.InputNames().empty(), PADDLE_ENFORCE_EQ(op_desc.InputNames().size(), 1,
"Op inputs must not be empty"); "Op inputs must be unique");
PADDLE_ENFORCE(!op_desc.OutputNames().empty(), PADDLE_ENFORCE_EQ(op_desc.OutputNames().size(), 1,
"Op outputs must not be empty"); "Op outputs must be unique");
auto x_name = op_desc.InputNames().at(0); auto x_name = op_desc.InputNames().at(0);
auto out_name = op_desc.OutputNames().at(0); auto out_name = op_desc.OutputNames().at(0);
return std::unordered_map<std::string, std::string>{{x_name, out_name}}; return std::unordered_map<std::string, std::string>{{x_name, out_name}};
} }
}; };
/*
Gradient op. Inplace output uses its Input.
For example, Input@Grad->Input reuse strategy.
*/
class GradOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
std::unordered_map<std::string, std::string> ret;
std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
op_desc.OutputNames().end());
for (auto& input_name : op_desc.InputNames()) {
if (output_names.count(GradVarName(input_name))) {
ret.insert({input_name, GradVarName(input_name)});
}
}
return ret;
}
};
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,3 +16,7 @@ cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_o ...@@ -16,3 +16,7 @@ cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_o
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper) cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper) cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
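// BufferSharedInplaceOpPass decides, for each op, whether an output variable
// can safely reuse the buffer of one of its input variables (based on the
// op's inplace inference and the last-live-op analysis), and records every
// accepted pair through MemoryReusePass::TryReuseVar.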
class BufferSharedInplaceOpPass : public MemoryReusePass {
protected:
std::string ReuseType() const override { return "inplace"; }
void Run(Graph *graph) const override;
};
void BufferSharedInplaceOpPass::Run(Graph *graph) const {
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
bool use_cuda = Get<bool>(kUseCuda);
// Step 1: Build a reverse map of last_live_ops
// i.e.: op -> vars
std::unordered_map<details::ComputationOpHandle *,
std::unordered_map<std::string, ir::Node *>>
candidate_ops;
for (auto &each_scope_ops : last_live_ops) {
for (auto &pair : each_scope_ops) {
// If a variable has more than one last lived op, this variable cannot
// be inplaced.
if (pair.second.size() != 1) {
continue;
}
auto *op = *(pair.second.begin());
const std::string &op_type = op->GetOp()->Type();
const framework::OpDesc *op_desc = op->Node()->Op();
PADDLE_ENFORCE_NOT_NULL(op_desc);
auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (!infer_inplace) {
continue;
}
const std::string &var_name = pair.first;
auto in_nodes = this->FindNodesByName(var_name, op->Node()->inputs);
if (in_nodes.size() == 1) {
candidate_ops[op][var_name] = *in_nodes.begin();
}
}
}
// Step 2: Check which vars can be inplaced indeed
for (auto &op_vars_pair : candidate_ops) {
auto *op = op_vars_pair.first;
auto &vars = op_vars_pair.second;
const std::string &op_type = op->GetOp()->Type();
auto *op_desc = op->Node()->Op();
auto in_to_outs =
OpInfoMap::Instance().Get(op_type).infer_inplace_(*op_desc, use_cuda);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &in_args = op_desc->Input(in_param);
if (in_args.empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param
<< ") is empty in " << op_type;
continue;
}
auto &in_arg = in_args[0];
auto iter = vars.find(in_arg);
if (iter == vars.end()) {
VLOG(4) << "Cannot inplace maybe because Input(" << in_param
<< ")=" << in_arg << " is not lastly used in op " << op_type
<< ", or it occurs multiple times in input or occurs in output";
continue;
}
ir::Node *in_node = iter->second;
auto &out_param = pair.second;
auto &out_args = op_desc->Output(out_param);
if (out_args.empty()) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ") is empty in " << op_type;
continue;
}
auto &out_arg = out_args[0];
auto out_nodes = this->FindNodesByName(out_arg, op->Node()->outputs);
if (out_nodes.size() != 1) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " occurs " << out_nodes.size()
<< " time(s) in output of op " << op_type;
continue;
}
auto *out_node = *out_nodes.begin();
auto &in_var_handle = in_node->Wrapper<details::VarHandleBase>();
auto &out_var_handle = out_node->Wrapper<details::VarHandleBase>();
auto *in_var_handle_ptr =
dynamic_cast<details::VarHandle *>(&in_var_handle);
auto *out_var_handle_ptr =
dynamic_cast<details::VarHandle *>(&out_var_handle);
if (in_var_handle_ptr == nullptr || out_var_handle_ptr == nullptr) {
continue;
}
bool success = this->TryReuseVar(in_var_handle_ptr, out_var_handle_ptr);
if (success) {
VLOG(4) << "Inplace performed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name()
<< ". Debug String is: " << op->GetOp()->DebugString();
} else {
VLOG(4) << "Inplace failed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name();
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(buffer_shared_inplace_pass,
paddle::framework::ir::BufferSharedInplaceOpPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -189,13 +190,9 @@ class EagerDeletionPass : public ir::Pass { ...@@ -189,13 +190,9 @@ class EagerDeletionPass : public ir::Pass {
}; };
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = auto &var_infos = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops = const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
...@@ -224,10 +221,15 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -224,10 +221,15 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node = auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
std::unordered_set<MemOptVarInfo *> var_info;
for (auto &var_name : var_names) {
var_info.insert(var_infos[op->GetScopeIdx()].at(var_name).get());
}
auto *eager_deletion_op = new details::EagerDeletionOpHandle( auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names, eager_deletion_node, op->GetScope(), op->GetPlace(),
gcs.at(places[op->GetScopeIdx()]).get(), std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
&(ref_cnts[op->GetScopeIdx()]));
auto it = std::find_if( auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(), op->Outputs().begin(), op->Outputs().end(),
...@@ -250,6 +252,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -250,6 +252,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
graph->Get<details::GraphDepVars>(details::kGraphDepVars) graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf); .emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf);
eager_deletion_op->SetDeviceContext(
places[op->GetScopeIdx()],
platform::DeviceContextPool::Instance().Get(places[op->GetScopeIdx()]));
} }
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction; VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
...@@ -273,7 +279,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { ...@@ -273,7 +279,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
} // namespace paddle } // namespace paddle
REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount) .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kAllPlaces) .RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector); .RequirePassAttr(paddle::framework::ir::kGarbageCollector);
......
...@@ -106,6 +106,9 @@ class InplacePass : public ir::Pass { ...@@ -106,6 +106,9 @@ class InplacePass : public ir::Pass {
// Check whether var is the last version one in SSA graph // Check whether var is the last version one in SSA graph
bool IsLastVersionVar(ir::Node *var) const; bool IsLastVersionVar(ir::Node *var) const;
// Check whether var is the first version of the variable in the SSA graph
bool IsFirstVersionVar(ir::Node *var) const;
// Check whether all `ops` is the preceding ops of `op` // Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const; bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
...@@ -155,6 +158,10 @@ bool InplacePass::IsSkipVar(const std::string &var_name) const { ...@@ -155,6 +158,10 @@ bool InplacePass::IsSkipVar(const std::string &var_name) const {
return skip_vars_.count(var_name) > 0; return skip_vars_.count(var_name) > 0;
} }
bool InplacePass::IsFirstVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->front() == var;
}
bool InplacePass::IsLastVersionVar(ir::Node *var) const { bool InplacePass::IsLastVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->back() == var; return AllVersionVars(var->Name())->back() == var;
} }
...@@ -429,13 +436,19 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { ...@@ -429,13 +436,19 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
} }
if (!FindNodesByName(out_arg, op_node->inputs).empty()) { if (!FindNodesByName(out_arg, op_node->inputs).empty()) {
VLOG(4) << "Cannot inplace because Output(" << in_param VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " occurs in input of op " << op_type; << ")=" << out_arg << " occurs in input of op " << op_type;
continue; continue;
} }
auto *out_node = *out_nodes.begin(); auto *out_node = *out_nodes.begin();
if (!IsFirstVersionVar(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " does not occur first in op " << op_type;
continue;
}
if (!NodeCanReused(out_node)) { if (!NodeCanReused(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is not reusable in " << op_type; << ")=" << out_arg << " is not reusable in " << op_type;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class MemOptVarInfo {
public:
MemOptVarInfo(const std::string &name, size_t ref_cnt) : name_(name) {
SetRefCnt(ref_cnt);
}
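// Atomically decrease the runtime reference count. Returns true when the
// calling op is the last living op of this variable, i.e. the variable can
// be garbage-collected or reused afterwards.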
bool DecreaseRefCnt() {
return ref_cnt_ == 1 || (runtime_ref_cnt_.fetch_sub(1) == 1);
}
void ResetRuntimeRefCnt() { runtime_ref_cnt_ = ref_cnt_; }
void SetRefCnt(size_t ref_cnt) {
PADDLE_ENFORCE_GE(ref_cnt, 1,
"Reference count must be larger than or equal to 1");
ref_cnt_ = ref_cnt;
runtime_ref_cnt_ = ref_cnt;
}
bool IsSkipped() const { return skipped_; }
void SetSkip(bool skipped) { skipped_ = skipped; }
const std::string &Name() const { return name_; }
private:
std::string name_;
size_t ref_cnt_;
std::atomic<size_t> runtime_ref_cnt_;
bool skipped_{false};
};
using MemOptVarInfoMapList = std::vector<
std::unordered_map<std::string, std::unique_ptr<MemOptVarInfo>>>;
class SkipMemOptVarsGuard {
public:
SkipMemOptVarsGuard(MemOptVarInfoMapList *list,
const std::vector<std::string> &vars,
bool need_reset_ref_cnt)
: list_(list), need_reset_ref_cnt_(need_reset_ref_cnt) {
if (!list_) return;
skip_vars_.reserve(vars.size() * list->size());
for (auto &var : vars) {
for (auto &map : *list_) {
auto iter = map.find(var);
if (iter != map.end() && !iter->second->IsSkipped()) {
iter->second->SetSkip(true);
skip_vars_.emplace_back(iter->second.get());
}
}
}
}
~SkipMemOptVarsGuard() {
for (auto *var : skip_vars_) {
var->SetSkip(false);
}
if (list_ && need_reset_ref_cnt_) {
for (auto &map : *list_) {
for (auto &pair : map) {
pair.second->ResetRuntimeRefCnt();
}
}
}
}
private:
MemOptVarInfoMapList *list_;
bool need_reset_ref_cnt_;
std::vector<MemOptVarInfo *> skip_vars_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
// Each ShareTensorBufferOpHandle should only have one pending
// ComputationOpHandle
static details::ComputationOpHandle *GetUniquePendingComputationOpHandle(
details::ShareTensorBufferOpHandle *share_tensor_op) {
details::ComputationOpHandle *result_op = nullptr;
for (Node *out_var : share_tensor_op->Node()->outputs) {
for (Node *pending_op : out_var->outputs) {
auto &op = pending_op->Wrapper<details::OpHandleBase>();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(&op);
PADDLE_ENFORCE_NOT_NULL(compute_op);
if (result_op == nullptr) {
result_op = compute_op;
} else {
PADDLE_ENFORCE_EQ(result_op, compute_op);
}
}
}
PADDLE_ENFORCE_NOT_NULL(result_op);
return result_op;
}
void MemoryReusePass::ApplyImpl(Graph *graph) const {
graph_ = graph;
all_vars_ = &(graph_->Get<details::GraphVars>(details::kGraphVars));
var_infos_ = &(Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList));
last_live_ops_of_vars_ =
&(Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars));
reused_var_names_.resize(all_vars_->size());
var_descs_.resize(all_vars_->size());
// Collect the existing ShareTensorBufferOpHandles.
// This is because (1) we want to reuse the existing
// ShareTensorBufferOpHandles to avoid inserting too many ops;
// (2) more importantly, a variable cannot be reused
// by two different variables, which may cause wrong calculation
// results. We have to know which variables have been reused.
CollectShareTensorBufferOpHandles();
CollectReusedVars();
Run(graph);
std::map<size_t, size_t> op_num;
for (auto &pair : ops_) {
++op_num[pair.first->GetScopeIdx()];
}
for (auto &pair : op_num) {
VLOG(2) << "Create " << pair.second
<< " ShareTensorBufferOpHandles in Scope " << pair.first;
}
}
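// Try to let out_var reuse the buffer of in_var. If the pair passes all
// reusability checks, the reuse is recorded on the ShareTensorBufferOpHandle
// of out_var's generating op and true is returned.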
bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const {
auto *op =
dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op);
if (IsVarsReusable(in_var, out_var)) {
AddReuseVar(op, in_var, out_var);
return true;
} else {
return false;
}
}
std::unordered_set<Node *> MemoryReusePass::FindNodesByName(
const std::string &name, const std::vector<Node *> &nodes) const {
std::unordered_set<ir::Node *> ret;
for (auto *node : nodes) {
if (node->Name() == name) {
ret.insert(node);
}
}
return ret;
}
VarDesc *MemoryReusePass::GetVarDesc(details::VarHandle *var) const {
auto iter = var_descs_[var->scope_idx()].find(var->Name());
if (iter == var_descs_[var->scope_idx()].end()) {
PADDLE_ENFORCE((*all_vars_)[var->scope_idx()].count(var->Name()),
"Variable %s not found", var->Name());
auto *desc =
TryGetLatestVarDesc((*all_vars_)[var->scope_idx()].at(var->Name()));
PADDLE_ENFORCE_NOT_NULL(desc);
var_descs_[var->scope_idx()].emplace(var->Name(), desc);
return desc;
} else {
return iter->second;
}
}
void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
auto all_ops = FilterByNodeWrapper<details::OpHandleBase>(*graph_);
for (auto *op : all_ops) {
auto *share_buffer_op =
dynamic_cast<details::ShareTensorBufferOpHandle *>(op);
if (share_buffer_op != nullptr) {
auto *compute_op = GetUniquePendingComputationOpHandle(share_buffer_op);
PADDLE_ENFORCE(ops_.count(compute_op) == 0);
ops_.emplace(compute_op, share_buffer_op);
}
}
}
void MemoryReusePass::CollectReusedVars() const {
for (auto &pair : ops_) {
auto reused_vars = pair.second->ReusedVarSet();
reused_var_names_[pair.first->GetScopeIdx()].insert(reused_vars.begin(),
reused_vars.end());
}
}
bool MemoryReusePass::IsVarAlreadyReused(details::VarHandle *var) const {
return reused_var_names_[var->scope_idx()].count(var->Name()) > 0;
}
details::ShareTensorBufferOpHandle *
MemoryReusePass::InsertShareTensorBufferOpHandleToGraph(
details::ComputationOpHandle *op) const {
auto *buffer_share_node =
graph_->CreateEmptyNode("buffer_share", ir::Node::Type::kOperation);
auto *buffer_share_op = new details::ShareTensorBufferOpHandle(
buffer_share_node, op->GetScope(), op->GetScopeIdx(), op->GetOp()->Type(),
{}, {});
buffer_share_op->SetDeviceContext(
op->GetPlace(),
platform::DeviceContextPool::Instance().Get(op->GetPlace()));
// Inputs of `buffer_share_op` should be all inputs of `op`
for (auto *in_var : op->Inputs()) {
buffer_share_op->AddInput(in_var);
}
// Add a dep_var to resolve write-after-write data hazard between
// `buffer_share_op` and `op`.
auto *dep_var = new details::DummyVarHandle(graph_->CreateControlDepVar());
graph_->Get<details::GraphDepVars>(details::kGraphDepVars).emplace(dep_var);
op->AddInput(dep_var);
buffer_share_op->AddOutput(dep_var);
ops_.emplace(op, buffer_share_op);
return buffer_share_op;
}
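// A pair (in_var, out_var) is reusable only if the two names differ and are
// not empty, in_var has not been reused before, out_var is the first version
// of its variable, both are non-persistable LoDTensors, and neither name also
// appears on the opposite side of the op (nor in_var more than once in its
// inputs).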
bool MemoryReusePass::IsVarsReusable(details::VarHandle *in_var,
details::VarHandle *out_var) const {
const auto in_name = in_var->Name();
const auto out_name = out_var->Name();
if (in_name == out_name) {
return false;
}
if (in_name == kEmptyVarName || out_name == kEmptyVarName) {
return false;
}
if (IsVarAlreadyReused(in_var)) {
return false;
}
// out_var must be the first version of the variable in the SSA graph.
auto out_var_iter = (*all_vars_)[out_var->scope_idx()].find(out_name);
PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var->scope_idx()].end() &&
!out_var_iter->second.empty(),
"Cannot find variable %s", out_name);
if (out_var_iter->second[0] != out_var) {
return false;
}
const VarDesc *in_var_desc = GetVarDesc(in_var);
const VarDesc *out_var_desc = GetVarDesc(out_var);
if (in_var_desc->Persistable() || out_var_desc->Persistable()) {
return false;
}
if (in_var_desc->GetType() != proto::VarType::LOD_TENSOR ||
out_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (!FindNodesByName(in_name, out_var->GeneratedOp()->Node()->outputs)
.empty()) {
return false;
}
if (!FindNodesByName(out_name, out_var->GeneratedOp()->Node()->inputs)
.empty()) {
return false;
}
auto all_input_args =
out_var->GeneratedOp()->Node()->Op()->InputArgumentNames();
if (std::count(all_input_args.begin(), all_input_args.end(), in_name) > 1) {
return false;
}
return true;
}
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const {
PADDLE_ENFORCE((*var_infos_)[op->GetScopeIdx()].count(in_var->Name()) > 0,
"%s does not in mem-opt var infos", in_var->Name());
if (ops_.count(op) == 0) {
InsertShareTensorBufferOpHandleToGraph(op);
}
auto *share_buffer_op = ops_[op];
auto &all_input_vars = share_buffer_op->Inputs();
bool has_input = std::find(all_input_vars.begin(), all_input_vars.end(),
in_var) != all_input_vars.end();
if (!has_input) {
share_buffer_op->AddInput(in_var);
}
share_buffer_op->Add(
(*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(),
out_var->Name());
reused_var_names_[op->GetScopeIdx()].insert(in_var->Name());
UpdateLastLiveOpOfVar(op, in_var, out_var);
}
// 1. Set last living op of in_var to be any last living op of out_var
// 2. Set reference count of in_var to be 1
void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const {
size_t scope_idx = op->GetScopeIdx();
auto out_var_op_iter =
(*last_live_ops_of_vars_)[scope_idx].find(out_var->Name());
PADDLE_ENFORCE(out_var_op_iter != (*last_live_ops_of_vars_)[scope_idx].end(),
"Cannot find variable %s", out_var->Name());
PADDLE_ENFORCE(!out_var_op_iter->second.empty());
auto &last_live_ops_of_in_var =
(*last_live_ops_of_vars_)[scope_idx][in_var->Name()];
last_live_ops_of_in_var.clear();
last_live_ops_of_in_var.insert(*(out_var_op_iter->second.begin()));
auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(),
"Cannot find variable %s", in_var->Name());
in_var_info_iter->second->SetRefCnt(1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* MemoryReusePass is the base class of InplacePass and MemoryOptimizePass.
*
* Unlike the legacy Python API fluid.memory_optimize() which changes
* variable names in the program/graph, MemoryReusePass inserts
* ShareTensorBufferOpHandle into the graph. This is because the
* variable-renaming approach has the following drawbacks:
*
* 1. There are so many corner cases we should skip. For example, (1) variables
* that relate to send/recv ops cannot be renamed (otherwise, pserver
* and trainer cannot find the matching variables), (2) ins/outs of ops
* containing sub-blocks cannot be optimized, (3) variables inside
* op_role_vars cannot be renamed.
*
* 2. It is very difficult to avoid reusing variables that users want to fetch.
* This is because the memory-optimize passes/transpiler runs before users
* fetch, i.e., exe.run(...). We cannot know what users want to fetch in the
* future. As a result, we have to set var.persistable = True before
* applying memory-optimize passes/transpiler, which is rather ugly and not
* friendly to users.
*
* 3. Dim and LoD of the reused variable would be changed, which may result
* in potential errors in the InferShape stage of the following ops. What's
* more, it prevents us from using the information from
* NoNeedBufferVarsInference.
*
* Considering the drawbacks of the former renaming strategy, we design a
* novel memory-optimize pass to fix these issues. Whether in-place is
* performed can be decided during run-time. ShareTensorBufferOpHandle
* would only share tensor buffers between in/out, never rename variables,
* and never change the dim or LoD of a variable. If users want to fetch a certain
* variable, we can skip in-place during run-time.
*
* The only concern on speed performance may be: there are too many
* ShareTensorBufferOpHandles in the graph. This can be avoided by moving
* tensor buffer sharing into each ComputationOpHandle::Run() method. We need
* a pass to clean all ShareTensorBufferOpHandles and move sharing to
* ComputationOpHandle::Run() in the future.
*/
class MemoryReusePass : public Pass {
protected:
void ApplyImpl(Graph *graph) const final;
virtual void Run(Graph *graph) const = 0;
virtual std::string ReuseType() const = 0;
bool TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const;
std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes) const;
size_t ScopeNum() const { return all_vars_->size(); }
private:
VarDesc *GetVarDesc(details::VarHandle *var) const;
bool IsVarsReusable(details::VarHandle *in_var,
details::VarHandle *out_var) const;
bool IsVarAlreadyReused(details::VarHandle *var) const;
details::ShareTensorBufferOpHandle *InsertShareTensorBufferOpHandleToGraph(
details::ComputationOpHandle *op) const;
void CollectShareTensorBufferOpHandles() const;
void CollectReusedVars() const;
void AddReuseVar(details::ComputationOpHandle *op, details::VarHandle *in_var,
details::VarHandle *out_var) const;
void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const;
private:
mutable Graph *graph_;
mutable details::GraphVars *all_vars_;
mutable MemOptVarInfoMapList *var_infos_;
mutable std::vector<LastLiveOpsOfVars> *last_live_ops_of_vars_;
mutable std::unordered_map<details::ComputationOpHandle *,
details::ShareTensorBufferOpHandle *>
ops_;
mutable std::vector<std::unordered_set<std::string>> reused_var_names_;
mutable std::vector<std::unordered_map<std::string, VarDesc *>> var_descs_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
...@@ -295,18 +296,18 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, ...@@ -295,18 +296,18 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
} }
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount); auto &var_infos = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
auto &last_live_ops_of_vars = auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars); Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(), PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(),
"Last Live Ops and Reference Counts of vars should be " "Last Live Ops and Reference Counts of vars should be "
"initialized at here."); "initialized at here.");
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars); const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size()); last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size()); var_infos.resize(vars.size());
ShrinkDepsOpFunctor shrink_func( ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<details::OpHandleBase>(*graph)); ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
...@@ -359,7 +360,8 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -359,7 +360,8 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
var_name); var_name);
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size()); var_infos[i][var_name].reset(
new MemOptVarInfo(var_name, result.size()));
last_live_ops_of_vars[i].emplace(var_name, std::move(result)); last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break; break;
} }
...@@ -375,5 +377,5 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { ...@@ -375,5 +377,5 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
} // namespace paddle } // namespace paddle
REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass) REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
.RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount) .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars); .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
...@@ -33,16 +33,10 @@ class VarDesc; ...@@ -33,16 +33,10 @@ class VarDesc;
namespace ir { namespace ir {
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>;
using GarbageCollectorMap = using GarbageCollectorMap =
std::map<platform::Place, std::unique_ptr<GarbageCollector>>; std::map<platform::Place, std::unique_ptr<GarbageCollector>>;
const char kGlobalReferenceCount[] = "global_reference_count"; const char kMemOptVarInfoMapList[] = "mem_opt_var_info_map_list";
const char kRuntimeReferenceCount[] = "runtime_reference_count";
const char kGarbageCollector[] = "garbage_collector"; const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places"; const char kAllPlaces[] = "all_places";
......
...@@ -89,7 +89,12 @@ class Node { ...@@ -89,7 +89,12 @@ class Node {
// Return a reference to the `wrapper`. // Return a reference to the `wrapper`.
template <typename T> template <typename T>
T& Wrapper() { T& Wrapper() {
return *boost::any_cast<T*>(wrapper_); try {
return *boost::any_cast<T*>(wrapper_);
} catch (boost::bad_any_cast&) {
PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s",
typeid(T).name(), wrapper_type_.name());
}
} }
// Test if the Node is wrapped by type T. // Test if the Node is wrapped by type T.
......
...@@ -22,11 +22,13 @@ limitations under the License. */ ...@@ -22,11 +22,13 @@ limitations under the License. */
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h" #include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -76,24 +78,10 @@ class ParallelExecutorPrivate { ...@@ -76,24 +78,10 @@ class ParallelExecutorPrivate {
} }
} }
ir::Graph *PrepareGCAndRefCnts(ir::Graph *graph, size_t max_memory_size); ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); } inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
for (auto &pair : global_ref_cnts_[i]) {
runtime_ref_cnts_[i][pair.first] = pair.second;
}
for (auto &fetch_name : fetch_tensors) {
runtime_ref_cnts_[i].erase(fetch_name);
}
runtime_ref_cnts_[i].erase(fetched_var_name);
}
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) { void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) {
VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_ VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_
...@@ -201,12 +189,20 @@ class ParallelExecutorPrivate { ...@@ -201,12 +189,20 @@ class ParallelExecutorPrivate {
} }
#endif #endif
inline bool IsPersistable(const std::string &name) const {
auto iter = is_persistable_.find(name);
return iter != is_persistable_.end() && iter->second;
}
BuildStrategy build_strategy_; BuildStrategy build_strategy_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
Scope *global_scope_; // not owned Scope *global_scope_; // not owned
std::unique_ptr<details::SSAGraphExecutor> executor_; std::unique_ptr<details::SSAGraphExecutor> executor_;
std::unordered_map<std::string, bool> is_persistable_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nccl_ctxs_{nullptr}; platform::NCCLCommunicator *nccl_ctxs_{nullptr};
#endif #endif
...@@ -215,16 +211,37 @@ class ParallelExecutorPrivate { ...@@ -215,16 +211,37 @@ class ParallelExecutorPrivate {
bool use_all_reduce_; bool use_all_reduce_;
size_t nranks_; size_t nranks_;
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and ir::MemOptVarInfoMapList mem_opt_var_infos_;
// then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
ir::GarbageCollectorMap gcs_; ir::GarbageCollectorMap gcs_;
}; };
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
ir::Graph *graph, size_t max_memory_size) { std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
if (build_strategy_.enable_inplace_) {
auto inplace_pass =
ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass");
inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
inplace_pass->SetNotOwned(ir::kUseCuda, &use_cuda_);
VLOG(10) << "Start to apply buffer_shared_inplace_pass";
graph = inplace_pass->Apply(graph);
VLOG(10) << "buffer_shared_inplace_pass Applied";
}
// TODO(zjl): refactor MemoryOptimizePass as well!!!
if (GetEagerDeletionThreshold() < 0) {
return graph;
}
size_t max_memory_size = static_cast<size_t>(GetEagerDeletionThreshold());
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
auto &place = places_[i]; auto &place = places_[i];
if (gcs_.count(place) > 0) { if (gcs_.count(place) > 0) {
...@@ -258,19 +275,10 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( ...@@ -258,19 +275,10 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
} }
if (!gcs_.empty()) { if (!gcs_.empty()) {
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass = auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass"); ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount, eager_deletion_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
&runtime_ref_cnts_); &mem_opt_var_infos_);
eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_); eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars, eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars); &last_live_ops_of_vars);
...@@ -386,9 +394,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -386,9 +394,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// same communicators. // same communicators.
auto *nccl_ctxs = auto *nccl_ctxs =
member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_); member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_);
auto &pool = platform::DeviceContextPool::Instance();
for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>( auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
pool.Get(member_->places_[dev_id])); pool.Get(member_->places_[dev_id]));
auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]); auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]);
...@@ -456,13 +463,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -456,13 +463,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
} }
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); graph = member_->ApplyMemoryOptimizePass(graph);
VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30);
if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(graph,
static_cast<size_t>(max_memory_size));
}
async_graphs[0] = graph; async_graphs[0] = graph;
...@@ -475,6 +476,9 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -475,6 +476,9 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
var_infos.back().name_ = node->Var()->Name(); var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType(); var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable(); var_infos.back().persistable_ = node->Var()->Persistable();
member_->is_persistable_.emplace(node->Var()->Name(),
node->Var()->Persistable());
} }
} }
...@@ -493,17 +497,34 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -493,17 +497,34 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
} }
} }
std::unordered_map<Scope *, Scope *> scope_map;
for (auto *scope : member_->local_scopes_) {
auto &local_exec_scope = scope->NewScope();
member_->local_exec_scopes_.emplace_back(&local_exec_scope);
scope_map.emplace(scope, &local_exec_scope);
}
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(),
member_->local_exec_scopes_.size());
std::vector<ir::Graph *> final_graphs;
if (member_->build_strategy_.async_mode_) { if (member_->build_strategy_.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor"; VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor( member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs)); exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, async_graphs));
final_graphs = async_graphs;
} else if (member_->build_strategy_.enable_parallel_graph_) { } else if (member_->build_strategy_.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor"; VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
// TODO(Yancey1989): Remove passing in the main_program when // TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr. // allreduce_seq_pass doesn't need it as the attr.
member_->executor_.reset(new details::ParallelSSAGraphExecutor( auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph);
final_graphs = pg_exe->Graphs();
member_->executor_.reset(pg_exe);
#else #else
PADDLE_THROW( PADDLE_THROW(
"Paddle should be compiled with CUDA for ParallelGraph Execution."); "Paddle should be compiled with CUDA for ParallelGraph Execution.");
...@@ -512,19 +533,29 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places, ...@@ -512,19 +533,29 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor"; VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else { } else {
VLOG(3) << "use FastThreadedSSAGraphExecutor"; VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph)); exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} }
final_graphs.emplace_back(graph);
} }
VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
if (!member_->build_strategy_.async_mode_) { if (!member_->build_strategy_.async_mode_) {
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos), exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, std::move(member_->executor_))); std::move(var_infos), member_->places_, std::move(member_->executor_)));
}
for (auto *g : final_graphs) {
auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*g);
for (auto *op : ops) {
op->SetLocalExecScopes(scope_map);
}
} }
} }
...@@ -616,10 +647,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -616,10 +647,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
#endif #endif
platform::RecordBlock b(0); platform::RecordBlock b(0);
if (member_->HasGarbageCollectors()) {
platform::RecordEvent event("PrepareGarbageCollectors"); ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name); member_->HasGarbageCollectors());
}
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run"; VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
...@@ -633,9 +663,13 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( ...@@ -633,9 +663,13 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i]; auto &map = tensors[i];
auto *scope = member_->local_scopes_[i];
for (auto &pair : map) { for (auto &pair : map) {
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>(); bool is_persistable = member_->IsPersistable(pair.first);
auto *feed_scope = is_persistable ? member_->local_scopes_[i]
: member_->local_exec_scopes_[i];
auto *feed_var = feed_scope->Var(pair.first);
auto *trg = feed_var->GetMutable<LoDTensor>();
trg->ShareDataWith(pair.second); trg->ShareDataWith(pair.second);
trg->set_lod(pair.second.lod()); trg->set_lod(pair.second.lod());
} }
...@@ -644,7 +678,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( ...@@ -644,7 +678,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) { const std::unordered_map<std::string, LoDTensor> &tensors) {
for (auto pair : tensors) { for (auto &pair : tensors) {
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_); auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
if (member_->places_.size() != lod_tensors.size()) { if (member_->places_.size() != lod_tensors.size()) {
bool is_cpu_place = platform::is_cpu_place(member_->places_.front()); bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
...@@ -661,10 +695,14 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( ...@@ -661,10 +695,14 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
} }
PADDLE_THROW(error_info); PADDLE_THROW(error_info);
} }
bool is_persistable = member_->IsPersistable(pair.first);
for (size_t j = 0; j < member_->places_.size(); ++j) { for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var? auto *feed_scope = is_persistable ? member_->local_scopes_[j]
auto t = : member_->local_exec_scopes_[j];
member_->local_scopes_[j]->Var(pair.first)->GetMutable<LoDTensor>(); auto *feed_var = feed_scope->Var(pair.first);
auto t = feed_var->GetMutable<LoDTensor>();
t->ShareDataWith(lod_tensors[j]); t->ShareDataWith(lod_tensors[j]);
t->set_lod(lod_tensors[j].lod()); t->set_lod(lod_tensors[j].lod());
} }
...@@ -724,3 +762,4 @@ bool ParallelExecutor::EnableParallelGraphExecution( ...@@ -724,3 +762,4 @@ bool ParallelExecutor::EnableParallelGraphExecution(
USE_PASS(reference_count_pass); USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass); USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
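In Run(), the old ResetRuntimeReferenceCount bookkeeping is replaced by an ir::SkipMemOptVarsGuard built from the fetch list. The guard presumably marks the fetched variables so that eager deletion and buffer reuse leave them alone for this iteration, and undoes that marking when the run finishes. The following is only a rough, self-contained sketch of that RAII idea; MemOptVarInfo here is reduced to a single flag and none of the types match Paddle's real ir::MemOptVarInfoMapList.

#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for a per-variable memory-optimization record.
struct MemOptVarInfo {
  bool skip_memory_opt = false;
};

using MemOptVarInfoMap = std::unordered_map<std::string, MemOptVarInfo>;

// RAII guard: while alive, the fetched variables are excluded from
// memory optimization; the flags are restored on destruction.
class SkipMemOptVarsGuard {
 public:
  SkipMemOptVarsGuard(MemOptVarInfoMap* infos,
                      const std::vector<std::string>& fetch_vars,
                      bool need_guard)
      : infos_(infos) {
    if (!need_guard) return;
    for (const auto& name : fetch_vars) {
      auto iter = infos_->find(name);
      if (iter != infos_->end() && !iter->second.skip_memory_opt) {
        iter->second.skip_memory_opt = true;
        touched_.push_back(&iter->second);
      }
    }
  }

  ~SkipMemOptVarsGuard() {
    for (auto* info : touched_) info->skip_memory_opt = false;
  }

 private:
  MemOptVarInfoMap* infos_;
  std::vector<MemOptVarInfo*> touched_;
};

int main() {
  MemOptVarInfoMap infos{{"conv1", {}}, {"fc_0.tmp_1", {}}};
  {
    SkipMemOptVarsGuard guard(&infos, {"fc_0.tmp_1"}, /*need_guard=*/true);
    std::printf("inside run: fc_0.tmp_1 skipped = %d\n",
                infos["fc_0.tmp_1"].skip_memory_opt);  // 1
  }
  std::printf("after run:  fc_0.tmp_1 skipped = %d\n",
              infos["fc_0.tmp_1"].skip_memory_opt);    // 0
  return 0;
}
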
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
......
...@@ -200,6 +200,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const { ...@@ -200,6 +200,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr; return nullptr;
} }
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) {
if (vars.count(iter->second.get()) != 0) {
++iter;
} else {
vars_.erase(iter++);
}
}
}
std::string GenScopeTreeDebugInfo(Scope* root) { std::string GenScopeTreeDebugInfo(Scope* root) {
std::stringstream os; std::stringstream os;
......
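Scope::EraseVarsExcept lets the executor wipe a local exec scope while keeping a whitelisted set of variables alive. A self-contained sketch of the same erase-except loop over an unordered_map, including the post-increment erase pattern used above (illustrative types, no locking, not Paddle's real Scope):

#include <cstdio>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct Variable { std::string tag; };

class Scope {
 public:
  Variable* Var(const std::string& name) {
    auto& slot = vars_[name];
    if (!slot) slot = std::make_unique<Variable>(Variable{name});
    return slot.get();
  }

  // Erase every variable whose pointer is NOT in `keep`.
  // Note `erase(iter++)`: the iterator is advanced before the erased
  // node becomes invalid.
  void EraseVarsExcept(const std::unordered_set<Variable*>& keep) {
    for (auto iter = vars_.begin(); iter != vars_.end();) {
      if (keep.count(iter->second.get()) != 0) {
        ++iter;
      } else {
        vars_.erase(iter++);
      }
    }
  }

  size_t Size() const { return vars_.size(); }

 private:
  std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};

int main() {
  Scope scope;
  Variable* keep_me = scope.Var("fc_0.w_0");   // e.g. a variable to preserve
  scope.Var("relu_1.tmp_0");                   // temporaries to be dropped
  scope.Var("relu_2.tmp_0");

  scope.EraseVarsExcept({keep_me});
  std::printf("vars left: %zu\n", scope.Size());  // 1
  return 0;
}
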
...@@ -22,6 +22,7 @@ extern "C" { ...@@ -22,6 +22,7 @@ extern "C" {
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -66,6 +67,9 @@ class Scope { ...@@ -66,6 +67,9 @@ class Scope {
void EraseVars(const std::vector<std::string>& var_names); void EraseVars(const std::vector<std::string>& var_names);
// Erase all variables except the given `vars`
void EraseVarsExcept(const std::unordered_set<Variable*>& vars);
/// Find a variable in the scope or any of its ancestors. Returns /// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find. /// nullptr if cannot find.
/// Caller doesn't own the returned Variable. /// Caller doesn't own the returned Variable.
......
...@@ -149,7 +149,15 @@ class Tensor { ...@@ -149,7 +149,15 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; } void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; } void clear() {
holder_ = nullptr;
offset_ = 0;
}
void ShareBufferWith(const Tensor& tensor) {
holder_ = tensor.holder_;
offset_ = tensor.offset_;
}
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; } const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; } size_t offset() const { return offset_; }
......
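ShareBufferWith is the low-level hook the buffer-shared in-place pass relies on: the output tensor adopts the input's allocation (holder_) and offset_, so both views alias one buffer, and clear() now also resets offset_. A minimal stand-alone sketch of that aliasing idea, with a simplified Tensor backed by a shared std::vector rather than Paddle's real allocation classes:

#include <cstdio>
#include <memory>
#include <vector>

// Very small stand-in for framework::Tensor: an allocation shared via
// shared_ptr plus an element offset into it.
class Tensor {
 public:
  float* mutable_data(size_t numel) {
    holder_ = std::make_shared<std::vector<float>>(numel, 0.f);
    offset_ = 0;
    return holder_->data();
  }

  float* data() const { return holder_->data() + offset_; }

  // Alias another tensor's buffer instead of allocating a new one.
  void ShareBufferWith(const Tensor& other) {
    holder_ = other.holder_;
    offset_ = other.offset_;
  }

  void clear() {
    holder_ = nullptr;
    offset_ = 0;  // mirrors the fix in this hunk: reset the offset too
  }

 private:
  std::shared_ptr<std::vector<float>> holder_;
  size_t offset_ = 0;
};

int main() {
  Tensor in, out;
  in.mutable_data(4)[0] = 1.f;

  out.ShareBufferWith(in);   // what the inplace pass arranges for out = op(in)
  out.data()[0] = 42.f;      // writes through to the shared buffer

  std::printf("in[0] = %.1f, same buffer = %d\n", in.data()[0],
              in.data() == out.data());  // 42.0, 1
  return 0;
}
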
...@@ -751,6 +751,14 @@ class SquareDoubleGradMaker ...@@ -751,6 +751,14 @@ class SquareDoubleGradMaker
} }
}; };
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -765,11 +773,8 @@ namespace plat = paddle::platform; ...@@ -765,11 +773,8 @@ namespace plat = paddle::platform;
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \ std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \ ::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \ void>::type); \
REGISTER_OPERATOR( \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ ops::ActivationGradOpInplaceInference);
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \ #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \ grad_functor) \
...@@ -794,7 +799,7 @@ REGISTER_OPERATOR( ...@@ -794,7 +799,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker); ops::ReluDoubleGradMaker);
REGISTER_OPERATOR( REGISTER_OPERATOR(
relu_grad_grad, relu_grad_grad,
...@@ -819,7 +824,7 @@ REGISTER_OPERATOR( ...@@ -819,7 +824,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker); ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR( REGISTER_OPERATOR(
leaky_relu_grad_grad, leaky_relu_grad_grad,
...@@ -843,7 +848,7 @@ REGISTER_OPERATOR( ...@@ -843,7 +848,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker); ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR( REGISTER_OPERATOR(
sqrt_grad_grad, sqrt_grad_grad,
...@@ -865,7 +870,7 @@ REGISTER_OPERATOR( ...@@ -865,7 +870,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>, ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut); paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad, REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut, ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker); ops::SquareDoubleGradMaker);
REGISTER_OPERATOR( REGISTER_OPERATOR(
square_grad_grad, square_grad_grad,
......
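The new ActivationGradOpInplaceInference maps the grad op's input Out@GRAD to its output X@GRAD, declaring that dX may reuse dOut's buffer on both CPU and GPU. A stripped-down sketch of that declarative pattern follows; the real interface also receives the framework::OpDesc, and GradVarName here is a local stand-in for the framework helper.

#include <cstdio>
#include <string>
#include <unordered_map>

// Mimic framework::GradVarName: "X" -> "X@GRAD".
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

// Simplified inplace inference: return which output slot may reuse
// which input slot's buffer.
struct ActivationGradInplaceInference {
  std::unordered_map<std::string, std::string> operator()(bool use_cuda) const {
    // dX may share dOut's buffer, regardless of device.
    return {{GradVarName("Out"), GradVarName("X")}};
  }
};

int main() {
  ActivationGradInplaceInference infer;
  for (const auto& pair : infer(/*use_cuda=*/false)) {
    std::printf("reuse buffer: %s -> %s\n", pair.first.c_str(),
                pair.second.c_str());
  }
  return 0;
}
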
...@@ -115,8 +115,15 @@ void SumToLoDTensor(const framework::ExecutionContext &context) { ...@@ -115,8 +115,15 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
auto *out = context.Output<LoDTensor>("Out"); auto *out = context.Output<LoDTensor>("Out");
bool in_place = in_vars[0] == context.OutputVar("Out"); bool in_place = in_vars[0] == context.OutputVar("Out");
if (!in_place) { if (!in_place) {
out->mutable_data<T>(context.GetPlace()); auto *out_ptr = out->mutable_data<T>(context.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
} }
// Sum of two tensors // Sum of two tensors
......
...@@ -128,10 +128,15 @@ class SumKernel : public framework::OpKernel<T> { ...@@ -128,10 +128,15 @@ class SumKernel : public framework::OpKernel<T> {
bool in_place = out_var == in_vars[0]; bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) { if (out_var->IsType<framework::LoDTensor>()) {
auto *out = context.Output<LoDTensor>("Out"); auto *out = out_var->GetMutable<framework::LoDTensor>();
if (!in_place) { auto *out_ptr = out->mutable_data<T>(context.GetPlace());
out->mutable_data<T>(context.GetPlace()); if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
} }
auto result = EigenVector<T>::Flatten(*out); auto result = EigenVector<T>::Flatten(*out);
auto &place = auto &place =
*context.template device_context<DeviceContext>().eigen_device(); *context.template device_context<DeviceContext>().eigen_device();
......
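With buffer sharing, the sum op's output can alias in_vars[0] even though they are different variables, so the kernels above now also detect in-place execution by comparing data pointers after mutable_data. A self-contained sketch of that pointer-equality trick (plain float buffers, not the actual SumKernel):

#include <cstdio>
#include <vector>

// Sum a list of equally sized buffers into `out`. `out` may alias ins[0],
// so aliasing is detected by pointer comparison, as in the sum kernel.
void Sum(const std::vector<const float*>& ins, float* out, size_t n) {
  bool in_place = !ins.empty() && ins[0] == out;  // pointer equality check
  size_t start = 0;
  if (!in_place) {
    for (size_t i = 0; i < n; ++i) out[i] = 0.f;  // safe: no aliasing
  } else {
    start = 1;  // out already holds ins[0]; do not add it twice
  }
  for (size_t k = start; k < ins.size(); ++k) {
    for (size_t i = 0; i < n; ++i) out[i] += ins[k][i];
  }
}

int main() {
  std::vector<float> a{1.f, 2.f}, b{10.f, 20.f};

  // Out-of-place: separate output buffer.
  std::vector<float> out(2);
  Sum({a.data(), b.data()}, out.data(), 2);
  std::printf("out-of-place: %.0f %.0f\n", out[0], out[1]);  // 11 22

  // In-place: output shares a's buffer (what the inplace pass arranges).
  Sum({a.data(), b.data()}, a.data(), 2);
  std::printf("in-place:     %.0f %.0f\n", a[0], a[1]);      // 11 22
  return 0;
}
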
...@@ -1549,6 +1549,13 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1549,6 +1549,13 @@ All parameter, weight, gradient are variables in Paddle.
"enable_inplace", "enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; }, [](const BuildStrategy &self) { return self.enable_inplace_; },
[](BuildStrategy &self, bool b) { self.enable_inplace_ = b; }) [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
.def_property("_use_legacy_memory_optimize_strategy",
[](const BuildStrategy &self) {
return self.use_legacy_memory_optimize_strategy_;
},
[](BuildStrategy &self, bool b) {
self.use_legacy_memory_optimize_strategy_ = b;
})
.def_property( .def_property(
"fuse_all_reduce_ops", "fuse_all_reduce_ops",
[](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
......
...@@ -211,34 +211,9 @@ class CompiledProgram(object): ...@@ -211,34 +211,9 @@ class CompiledProgram(object):
if self._program: if self._program:
if self._program._is_mem_optimized: if self._program._is_mem_optimized:
self._build_strategy.memory_optimize = False self._build_strategy.memory_optimize = False
self._build_strategy.enable_inplace = False
elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace:
# remind the user to try our memmory optimize strategy
six.print_(
"""
You can try our memory optimize feature to save your memory usage:
# create a build_strategy variable to set memory optimize option
build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
# pass the build_strategy to with_data_parallel API
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
!!! Memory optimize is our experimental feature !!!
some variables may be removed/reused internal to save memory usage,
in order to fetch the right value of the fetch_list, please set the
persistable property to true for each variable in fetch_list
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""",
file=sys.stderr)
if self._build_strategy.memory_optimize:
self._build_strategy._use_legacy_memory_optimize_strategy = True
return self return self
def with_inference_optimize(self, config): def with_inference_optimize(self, config):
......
...@@ -551,7 +551,7 @@ class Executor(object): ...@@ -551,7 +551,7 @@ class Executor(object):
if not persistable: if not persistable:
logging.warn(""" logging.warn("""
Detect that memory optimize or inplace is enabled, but the some variables in the fetch Detect that build_strategy.memory_optimize = True, but the some variables in the fetch
list is not persistable, you may get wrong fetched value, or an exeception may be thrown list is not persistable, you may get wrong fetched value, or an exeception may be thrown
about cannot find variable of the fetch list. about cannot find variable of the fetch list.
...@@ -668,9 +668,8 @@ class Executor(object): ...@@ -668,9 +668,8 @@ class Executor(object):
return_numpy=return_numpy, return_numpy=return_numpy,
use_program_cache=use_program_cache) use_program_cache=use_program_cache)
else: else:
if fetch_list and program._is_data_parallel and program._program and ( if fetch_list and program._is_data_parallel and program._program and \
program._build_strategy.memory_optimize or program._build_strategy._use_legacy_memory_optimize_strategy:
program._build_strategy.enable_inplace):
self._check_fetch_vars_persistable(program._program, fetch_list) self._check_fetch_vars_persistable(program._program, fetch_list)
program._compile(scope, self.place) program._compile(scope, self.place)
......
...@@ -256,4 +256,4 @@ endif() ...@@ -256,4 +256,4 @@ endif()
set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist
test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op
PROPERTIES LABELS "RUN_TYPE=DIST") test_buffer_shared_inplace_pass PROPERTIES LABELS "RUN_TYPE=DIST")
...@@ -33,7 +33,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(cls, def check_network_convergence(cls,
method, method,
use_cuda=True, use_cuda=True,
memory_opt=True, memory_opt=False,
iter=50, iter=50,
batch_size=None, batch_size=None,
allow_op_delay=False, allow_op_delay=False,
...@@ -41,7 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -41,7 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None, seed=None,
use_parallel_executor=True, use_parallel_executor=True,
use_reduce=False, use_reduce=False,
use_ir_memory_optimize=True, use_ir_memory_optimize=False,
enable_inplace=True, enable_inplace=True,
fuse_elewise_add_act_ops=False, fuse_elewise_add_act_ops=False,
fuse_all_optimizer_ops=False, fuse_all_optimizer_ops=False,
...@@ -65,7 +65,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -65,7 +65,8 @@ class TestParallelExecutorBase(unittest.TestCase):
main.random_seed = seed main.random_seed = seed
loss = method(use_feed=feed_dict is not None) loss = method(use_feed=feed_dict is not None)
loss.persistable = True if memory_opt or use_ir_memory_optimize:
loss.persistable = True
if optimizer: if optimizer:
optimizer().minimize(loss) optimizer().minimize(loss)
...@@ -88,9 +89,8 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -88,9 +89,8 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
# python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way. build_strategy.enable_inplace = enable_inplace
build_strategy.enable_inplace = False if memory_opt else enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda(): if use_cuda and core.is_compiled_with_cuda():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.framework import Parameter
import numpy as np
from simple_nets import simple_fc_net
import random
import unittest
import os
batch_size = 32
feed_dict = {
'image': np.random.random([batch_size, 784]).astype('float32'),
'label': np.random.random_integers(
low=0, high=9, size=[batch_size, 1]).astype('int64')
}
class InplaceTestBase(unittest.TestCase):
def initParameter(self):
self.use_cuda = True
def setUp(self):
self.initParameter()
if self.use_cuda and fluid.core.is_compiled_with_cuda():
self.device_count = fluid.core.get_cuda_device_count()
else:
self.device_count = 4
assert batch_size % self.device_count == 0
def build_program_and_scope(self):
self.place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
startup_program = fluid.Program()
main_program = fluid.Program()
startup_program.random_seed = 1
main_program.random_seed = 1
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
loss = simple_fc_net()
adam = fluid.optimizer.Adam(learning_rate=1e-3)
adam.minimize(loss)
with fluid.scope_guard(scope):
exe = fluid.Executor(
fluid.CUDAPlace(0)
if self.use_cuda else fluid.CPUPlace())
exe.run(startup_program)
return main_program, scope, exe, loss
def is_invalid_test(self):
return self.use_cuda and not fluid.core.is_compiled_with_cuda()
def get_all_vars(self, program):
all_vars = program.global_block().vars
all_vars_name = []
for name, var in all_vars.items():
if 0 not in var.shape and not var.persistable:
all_vars_name.append(name)
return all_vars_name
def test_single_card_fetch_var(self):
if self.is_invalid_test():
return
prog1, scope1, exe, loss1 = self.build_program_and_scope()
prog2, scope2, _, loss2 = self.build_program_and_scope()
prog3, scope3, _, loss3 = self.build_program_and_scope()
build_strategy2 = fluid.BuildStrategy()
build_strategy2.memory_optimize = False
build_strategy2.enable_inplace = True
compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
loss_name=loss2.name,
build_strategy=build_strategy2,
places=self.place)
build_strategy3 = fluid.BuildStrategy()
build_strategy3.memory_optimize = False
build_strategy3.enable_inplace = False
compiled_prog3 = fluid.CompiledProgram(prog3).with_data_parallel(
loss_name=loss2.name,
build_strategy=build_strategy3,
places=self.place)
all_vars_name = self.get_all_vars(prog1)
repeated_var_names = all_vars_name * 4
random.shuffle(repeated_var_names) # add some random
for fetch_var in repeated_var_names:
for _ in range(4):
with fluid.scope_guard(scope1):
fetch_val1, = exe.run(prog1,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope2):
fetch_val2, = exe.run(compiled_prog2,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope3):
fetch_val3, = exe.run(compiled_prog3,
feed=feed_dict,
fetch_list=[fetch_var])
self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
self.assertTrue(np.array_equal(fetch_val1, fetch_val3))
def test_multi_card_fetch_var(self):
if self.is_invalid_test():
return
prog1, scope1, exe, loss1 = self.build_program_and_scope()
prog2, scope2, _, loss2 = self.build_program_and_scope()
build_strategy1 = fluid.BuildStrategy()
build_strategy1.memory_optimize = False
build_strategy1.enable_inplace = True
build_strategy2 = fluid.BuildStrategy()
build_strategy2.memory_optimize = False
build_strategy2.enable_inplace = False
if self.use_cuda:
places = fluid.cuda_places()
else:
places = fluid.cpu_places(self.device_count)
compiled_prog1 = fluid.CompiledProgram(prog1).with_data_parallel(
loss_name=loss1.name, build_strategy=build_strategy1, places=places)
compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
loss_name=loss2.name, build_strategy=build_strategy2, places=places)
repeated_var_names = self.get_all_vars(prog1) * 4
random.shuffle(repeated_var_names) # add some random
for fetch_var in repeated_var_names:
for _ in range(4):
with fluid.scope_guard(scope1):
fetch_val1, = exe.run(compiled_prog1,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope2):
fetch_val2, = exe.run(compiled_prog2,
feed=feed_dict,
fetch_list=[fetch_var])
self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
class CPUInplaceTest(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
if __name__ == '__main__':
unittest.main()
...@@ -61,6 +61,8 @@ class TestSoftmaxWithXe(unittest.TestCase): ...@@ -61,6 +61,8 @@ class TestSoftmaxWithXe(unittest.TestCase):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = inplace build_strategy.enable_inplace = inplace
if inplace:
build_strategy._use_legacy_memory_optimize_strategy = True
prog = fluid.CompiledProgram(fluid.default_main_program( prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel( )).with_data_parallel(
build_strategy=build_strategy, places=place) build_strategy=build_strategy, places=place)
......