未验证 提交 d3003a16 编写于 作者: Z Zeng Jinle 提交者: GitHub

Feature/buffer_shared_inplace (#17911)

* feature/buffer_shared_inplace, test=develop

* refine code, test=develop

* fix elementwise_add op cpu inplace and sum inplace bug, test=develop

* add unittest and debug log, test=develop

* fix parallel_executor scope bug, polish code, test=develop

* fix sum op, activation op, single_in_place_inference bug, test=develop

* remove kLocalExecScopeName, test=develop

* fix unittest,test=develop

* fix out_var first version bug, test=develop

* follow comments,test=develop
上级 1c10dac4
......@@ -59,7 +59,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()
......
......@@ -99,10 +99,9 @@ void AllReduceOpHandle::RunImpl() {
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &local_scope = local_exec_scopes_[i];
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
local_scope->FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
......@@ -140,9 +139,7 @@ void AllReduceOpHandle::RunImpl() {
PADDLE_THROW("Not compiled with CUDA");
#endif
} else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
auto &trg = *this->local_exec_scopes_[0]
->FindVar(out_var_handles[0]->name())
->GetMutable<framework::LoDTensor>();
......@@ -151,10 +148,9 @@ void AllReduceOpHandle::RunImpl() {
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &scope = local_exec_scopes_[i];
auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name());
auto *var = scope->FindVar(out_var_handles[i]->name());
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
......
......@@ -49,6 +49,9 @@ class AllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
......
......@@ -24,22 +24,20 @@ namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
Scope *local_scope) {
VLOG(3) << "InitVarsInScope";
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
auto *var = scope->FindVar(info.name_);
if (var != nullptr) {
VLOG(2) << info.name_
<< " has been initialized beforehand in global scope, skipped";
continue;
}
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
InitializeVariable(local_scope->Var(info.name_), info.type_);
}
}
}
......@@ -101,14 +99,17 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
local_exec_scopes_(local_exec_scopes),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size());
// set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
......@@ -118,7 +119,8 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
strategy_, {local_scopes_[i]}, {local_exec_scopes_[i]}, {places_[i]},
graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
......@@ -129,8 +131,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
for (size_t i = 0; i < local_scopes_.size(); ++i) {
InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
......
......@@ -36,6 +36,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
......@@ -50,6 +51,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;
......
......@@ -40,18 +40,13 @@ void BroadcastOpHandle::RunImpl() {
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
}
void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
const std::vector<Scope *> &var_scopes) {
auto *in_var =
var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
PADDLE_ENFORCE_NOT_NULL(in_var);
......@@ -140,10 +135,7 @@ void BroadcastOpHandle::BroadcastOneVar(
void BroadcastOpHandle::InitOutputValue(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles) const {
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto &var_scopes = local_exec_scopes_;
auto *in_var =
var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
......
......@@ -62,9 +62,11 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);
const std::vector<Scope *> &var_scopes);
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
......
......@@ -14,7 +14,9 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
......@@ -92,14 +94,13 @@ struct TestBroadcastOpHandle {
void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
std::unordered_map<Scope*, Scope*> scope_map;
for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("out");
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
param_scopes_[input_scope_idx]->Var("input");
......@@ -122,6 +123,8 @@ struct TestBroadcastOpHandle {
#endif
}
op_handle_->SetLocalExecScopes(scope_map);
nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,
......
......@@ -92,16 +92,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("fuse_relu_depthwise_conv_pass");
}
// NOTE(dzhwinter): A note for automatical inplace.
// 1. modify program desc passes should put
// before inplace pass.
// 2. manually configured inplace should put
// before inplace_pass
// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(1) << "Add inplace_pass";
AppendPass("inplace_pass");
// TODO(zjl): refactor MemoryOptimizePass to fit
// new strategy, which does not need to set
// var.persistable = True
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.enable_inplace_) {
VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass");
}
}
if (strategy_.fuse_elewise_add_act_ops_) {
......@@ -160,9 +158,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) {
VLOG(1) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.memory_optimize_) {
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
}
// runtime_context_cache pass should be the last pass to enable the attr of
......
......@@ -114,7 +114,12 @@ struct BuildStrategy {
// it is not appropriate, because kStaleProgramOpDescs will be removed in the
// near future.
bool memory_optimize_{false};
bool enable_inplace_{false};
// Turn on inplace by default.
bool enable_inplace_{true};
// TODO(zjl): Remove this flag when MemoryOptimizePass is refactored
bool use_legacy_memory_optimize_strategy_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
......
......@@ -31,9 +31,7 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
if (is_lock_and_record_event_free_) {
run_func();
......
......@@ -38,6 +38,8 @@ class ComputationOpHandle : public OpHandleBase {
const Scope *GetScope() const { return scope_; }
Scope *GetScope() { return scope_; }
const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
......@@ -49,6 +51,8 @@ class ComputationOpHandle : public OpHandleBase {
bool NeedWait(VarHandleBase *in_var) override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
......
......@@ -17,6 +17,7 @@
#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
......@@ -30,14 +31,13 @@ namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
ir::AtomicReferenceCountMap *ref_cnts)
ir::Node *node, Scope *scope, const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names.begin(), var_names.end()),
gc_(gc),
ref_cnts_(ref_cnts) {
place_(place),
var_infos_(vars.begin(), vars.end()),
gc_(gc) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
......@@ -50,7 +50,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
}
}
#endif
PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
PADDLE_ENFORCE(!vars.empty(), "Var names cannot be empty");
for (auto *var : var_infos_) {
PADDLE_ENFORCE_NOT_NULL(var);
}
}
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
......@@ -63,30 +66,43 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#endif
}
void EagerDeletionOpHandle::InitCUDA() {
#ifdef PADDLE_WITH_CUDA
int dev_id =
boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
events_[dev_id] = nullptr;
#endif
}
void EagerDeletionOpHandle::CallOnce() {
PADDLE_ENFORCE(vars_.empty(), "vars_ must be initialized here");
Scope *exec_scope = local_exec_scopes_[0];
for (auto *var_info : var_infos_) {
auto *var = exec_scope->FindVar(var_info->Name());
PADDLE_ENFORCE_NOT_NULL(var, "Variable %s should not be nullptr",
var_info->Name());
vars_.emplace_back(var);
}
}
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size()) {
CallOnce();
}
platform::RecordEvent record_event(Name());
Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
// Reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
for (size_t i = 0; i < var_infos_.size(); ++i) {
auto *var_info = var_infos_[i];
if (var_info->IsSkipped() || !var_info->DecreaseRefCnt()) {
continue;
}
if (!exec_scope) {
exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
// Var not found
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << var_info->Name() << " on " << place_;
VLOG(2) << "Erase variable " << name;
Variable *var = vars_[i];
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
......@@ -100,7 +116,7 @@ void EagerDeletionOpHandle::RunImpl() {
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
framework::ToTypeName(var->Type()), name);
framework::ToTypeName(var->Type()), var_info->Name());
}
}
......
......@@ -26,15 +26,18 @@ namespace paddle {
namespace framework {
class Scope;
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
EagerDeletionOpHandle(ir::Node *node, Scope *scope,
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
ir::AtomicReferenceCountMap *ref_cnts);
const std::unordered_set<ir::MemOptVarInfo *> &vars,
GarbageCollector *gc);
~EagerDeletionOpHandle();
......@@ -50,13 +53,20 @@ class EagerDeletionOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void InitCUDA() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
const Scope *scope_;
std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
ir::AtomicReferenceCountMap *ref_cnts_; // not own
void CallOnce();
Scope *scope_;
platform::Place place_;
std::vector<ir::MemOptVarInfo *> var_infos_; // not own
GarbageCollector *gc_; // not own
std::vector<Variable *> vars_;
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
......
......@@ -28,9 +28,11 @@ namespace details {
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: strategy_(strategy),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places),
graph_(graph),
fetch_ctxs_(places),
......@@ -143,7 +145,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......
......@@ -33,6 +33,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
......@@ -43,6 +44,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_;
ir::Graph *graph_;
......
......@@ -42,9 +42,7 @@ bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
void FetchBarrierOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto run_func = [this]() {
op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
if (is_lock_and_record_event_free_) {
run_func();
......
......@@ -44,6 +44,8 @@ struct FetchBarrierOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
bool NeedWait(VarHandleBase *in_var) override;
private:
......
......@@ -22,11 +22,13 @@ namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes)
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes) {}
FetchOpHandle::~FetchOpHandle() {}
......@@ -49,14 +51,12 @@ void FetchOpHandle::RunImpl() {
tensors_.resize(inputs_.size());
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;
auto &scopes = *local_exec_scopes_;
for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_handle->name());
auto *var = scope->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_handle->name());
......
......@@ -29,7 +29,8 @@ namespace details {
struct FetchOpHandle : public OpHandleBase {
public:
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes);
~FetchOpHandle();
......@@ -44,12 +45,15 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
void WaitInputVarGenerated(const platform::Place &place) override;
private:
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<Scope *> *local_exec_scopes_;
std::vector<LoDTensor> tensors_;
};
......
......@@ -185,9 +185,7 @@ void FusedAllReduceOpHandle::RunImpl() {
} else {
// Special handle CPU only Operator's gradient. Like CRF
auto grad_name = grads_tensor.at(0).at(0).first;
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
auto &trg = *this->local_exec_scopes_[0]
->FindVar(grad_name)
->GetMutable<framework::LoDTensor>();
......@@ -195,9 +193,8 @@ void FusedAllReduceOpHandle::RunImpl() {
ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
VisitDataType(trg.type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
auto &scope = *local_exec_scopes_[i];
auto &p = places_[i];
auto *var = scope.FindVar(grad_name);
auto *dev_ctx = dev_ctxes_.at(p);
......@@ -215,8 +212,7 @@ void FusedAllReduceOpHandle::GetGradLoDTensor(
const size_t &scope_idx, const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles,
std::vector<std::pair<std::string, const LoDTensor *>> *grad_tensor) const {
auto *local_scope =
local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *local_scope = local_exec_scopes_[scope_idx];
size_t place_num = places_.size();
for (size_t j = 0; j < in_var_handles.size(); j += place_num) {
......
......@@ -52,6 +52,8 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
private:
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
......
......@@ -31,11 +31,6 @@ void FusedBroadcastOpHandle::RunImpl() {
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
......@@ -44,7 +39,7 @@ void FusedBroadcastOpHandle::RunImpl() {
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
local_exec_scopes_);
}
}
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include <memory>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/broadcast_op_handle_test.h"
......@@ -27,17 +29,16 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var
std::unordered_map<Scope*, Scope*> scope_map;
for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
for (size_t j = 0; j < input_scope_idxes.size(); ++j) {
local_scope.Var("out_var" + std::to_string(j));
if (i == j) local_scope.Var("in_var" + std::to_string(j));
}
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
// create op handle node
......@@ -60,6 +61,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
#endif
}
op_handle_->SetLocalExecScopes(scope_map);
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle
nodes_.emplace_back(ir::CreateNodeForTest("in_node" + std::to_string(i),
......
......@@ -42,10 +42,7 @@ void GatherOpHandle::RunImpl() {
out_var_handle = out_var_handles.front();
}
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto &var_scopes = local_exec_scopes_;
auto in_0_handle = in_var_handles[0];
auto pre_in_var =
......
......@@ -40,6 +40,8 @@ struct GatherOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/framework/details/gather_op_handle.h"
#include <memory>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -72,14 +74,13 @@ struct TestGatherOpHandle {
void InitGatherOp(size_t input_scope_idx) {
nodes_.clear();
std::unordered_map<Scope*, Scope*> scope_map;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("input");
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
param_scopes_[input_scope_idx]->Var("out");
......@@ -87,6 +88,9 @@ struct TestGatherOpHandle {
ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release());
op_handle_ =
new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_);
op_handle_->SetLocalExecScopes(scope_map);
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
......
......@@ -35,49 +35,55 @@ std::string OpHandleBase::DebugString() const {
OpHandleBase::~OpHandleBase() {
#ifdef PADDLE_WITH_CUDA
for (auto &ev : events_) {
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
if (ev.second) {
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
}
}
#endif
}
void OpHandleBase::Run(bool use_cuda) {
void OpHandleBase::InitCUDA() {
#ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
PADDLE_ENFORCE(
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
}
if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
PADDLE_ENFORCE(
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
}
if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) {
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
} else {
PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
"%s should have only one dev_ctx.", Name());
auto &place = dev_ctxes_.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
PADDLE_ENFORCE(
platform::is_same_place(place, out_var_handle->place()),
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
} else {
PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL,
"%s should have only one dev_ctx.", Name());
auto &place = dev_ctxes_.begin()->first;
int dev_id = boost::get<platform::CUDAPlace>(place).device;
for (auto &out_var : outputs_) {
auto *out_var_handle = dynamic_cast<VarHandle *>(out_var);
if (out_var_handle) {
PADDLE_ENFORCE(platform::is_same_place(place, out_var_handle->place()),
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
}
#else
#endif
}
void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) {
InitCUDA();
}
#else
PADDLE_ENFORCE(!use_cuda);
#endif
......@@ -232,6 +238,17 @@ size_t OpHandleBase::NotReadyInputSize() const {
return res.size();
}
void OpHandleBase::SetLocalExecScopes(
const std::unordered_map<Scope *, Scope *> &scope_map) {
local_exec_scopes_.clear();
auto scopes = GetLocalScopes();
for (auto *scope : scopes) {
auto iter = scope_map.find(scope);
PADDLE_ENFORCE(iter != scope_map.end(), "Local scope not found");
local_exec_scopes_.emplace_back(iter->second);
}
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -25,9 +25,10 @@
namespace paddle {
namespace framework {
namespace details {
constexpr char kLocalExecScopeName[] = "@LOCAL_EXE_SCOPE@";
class Scope;
namespace details {
// Wraps ir::Node and provide helper utilities.
// It's responsible for populating necessary fields of ir::Node.
......@@ -107,7 +108,12 @@ class OpHandleBase {
ir::Node *Node() { return node_; }
void SetLocalExecScopes(
const std::unordered_map<Scope *, Scope *> &scope_map);
protected:
virtual std::vector<Scope *> GetLocalScopes() = 0;
void RunAndRecordEvent(const std::function<void()> &callback);
void RunAndRecordEvent(platform::Place p,
......@@ -115,11 +121,15 @@ class OpHandleBase {
virtual void RunImpl() = 0;
virtual void InitCUDA();
ir::Node *node_;
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
std::vector<Scope *> local_exec_scopes_;
#ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_;
#endif
......
......@@ -83,6 +83,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) {
ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
......@@ -108,10 +109,20 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::FastThreadedSSAGraphExecutor(
strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get()));
strategy_, local_scopes_, local_exec_scopes, {places_[i]},
graphs_.at(i).get()));
}
}
std::vector<ir::Graph *> ParallelSSAGraphExecutor::Graphs() {
std::vector<ir::Graph *> result;
result.reserve(graphs_.size());
for (auto &g : graphs_) {
result.emplace_back(g.get());
}
return result;
}
FeedFetchList ParallelSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::vector<std::future<FeedFetchList>> run_futures;
......
......@@ -30,12 +30,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor {
public:
ParallelSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
~ParallelSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
std::vector<ir::Graph *> Graphs();
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include <memory>
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
......@@ -160,10 +161,7 @@ void ReduceOpHandle::RunImpl() {
auto in_0_handle = in_var_handles[0];
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto &var_scopes = local_exec_scopes_;
auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
......@@ -250,9 +248,7 @@ void ReduceOpHandle::RunImpl() {
} else {
// We sum lod_tensors to reduce_sum_trg which is in local_scopes_0
// here, but it doesn't mean reduce_sum_trg must be in local_scopes_0.
auto &reduce_sum_trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
auto &reduce_sum_trg = *this->local_exec_scopes_[0]
->FindVar(out_var_handle->name())
->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
......@@ -317,7 +313,7 @@ void ReduceOpHandle::RunImpl() {
template <typename T>
std::vector<const T *> ReduceOpHandle::GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<const Scope *> &var_scopes) const {
const std::vector<Scope *> &var_scopes) const {
std::vector<const T *> in_selected_rows;
for (auto *in_handle : in_var_handles) {
auto &in_sr = var_scopes.at(in_handle->scope_idx())
......
......@@ -15,6 +15,7 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
......@@ -90,6 +91,8 @@ struct ReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
template <typename DevCtx, typename DataType>
void GatherSelectedRows(
......@@ -106,7 +109,7 @@ struct ReduceOpHandle : public OpHandleBase {
template <typename T>
std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles,
const std::vector<const Scope *> &var_scopes) const;
const std::vector<Scope *> &var_scopes) const;
};
} // namespace details
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -86,14 +87,13 @@ struct TestReduceOpHandle {
void InitReduceOp(size_t out_scope_idx) {
std::vector<std::unique_ptr<ir::Node>> nodes;
// init scope
std::unordered_map<Scope *, Scope *> scope_map;
for (size_t j = 0; j < gpu_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope &local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope *>() = &local_scope;
local_scope.Var("input");
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
param_scopes_[out_scope_idx]->Var("out");
......@@ -115,6 +115,8 @@ struct TestReduceOpHandle {
#endif
}
op_handle_->SetLocalExecScopes(scope_map);
// init op handle
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
......
......@@ -21,7 +21,7 @@ namespace framework {
namespace details {
RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
const Scope *local_scope, const std::string &name,
Scope *local_scope, const std::string &name,
const platform::Place &place)
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(op_desc)),
......@@ -41,10 +41,7 @@ void RPCOpHandle::RunImpl() {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
}
}
this->RunAndRecordEvent([this] {
op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
place_);
});
this->RunAndRecordEvent([this] { op_->Run(*local_exec_scopes_[0], place_); });
}
std::string RPCOpHandle::Name() const { return name_; }
......
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include <string>
#include <vector>
......@@ -29,7 +30,7 @@ namespace details {
struct RPCOpHandle : public OpHandleBase {
RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc,
const Scope* local_scope, const std::string& name,
Scope* local_scope, const std::string& name,
const platform::Place& place);
std::string Name() const override;
......@@ -41,9 +42,11 @@ struct RPCOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope*> GetLocalScopes() override { return {local_scope_}; }
private:
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
Scope* local_scope_;
const std::string name_;
platform::Place place_;
};
......
......@@ -70,9 +70,9 @@ void ScaleLossGradOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
// Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
auto *tensor =
local_exec_scopes_[0]->FindVar(var_name)->GetMutable<LoDTensor>();
tensor->Resize(make_ddim({1}));
#ifdef PADDLE_WITH_CUDA
......
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -36,6 +37,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
float coeff_;
Scope *scope_;
......
......@@ -25,19 +25,24 @@ namespace framework {
namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::vector<Scope *> local_exec_scopes, std::vector<VariableInfo> var_infos,
std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)),
underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)),
local_exec_scopes_(std::move(local_exec_scopes)),
var_infos_(std::move(var_infos)),
places_(std::move(places)) {}
places_(std::move(places)) {
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size());
PrepareLocalExeScopes();
}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
platform::RecordEvent e("InitLocalExeScopes");
PrepareLocalExeScopes();
platform::RecordEvent e("InitLocalVars");
InitVariables();
}
std::vector<framework::LoDTensor> fetch_data;
......@@ -59,39 +64,55 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
void ScopeBufferedSSAGraphExecutor::InitVariables() {
for (auto &info : tmp_var_infos_) {
for (auto &pair : info) {
InitializeVariable(pair.first, pair.second);
}
}
}
void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() {
platform::RecordEvent drop_scope_event("DropLocalExeScopes");
drop_scope_counter_ = 0;
for (auto p : places_) {
for (auto &p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName);
if (local_scope_var != nullptr) {
auto &local_scope = *local_scope_var->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
scope->EraseVars({std::string(details::kLocalExecScopeName)});
VLOG(3) << "Drop local execution scope: " << local_scope;
for (size_t i = 0; i < local_exec_scopes_.size(); ++i) {
local_exec_scopes_[i]->EraseVarsExcept(preserve_vars_[i]);
local_exec_scopes_[i]->DropKids();
for (auto &preserve_var : preserve_vars_[i]) {
preserve_var->Clear();
}
VLOG(3) << "Drop local execution scope: " << local_scopes_[i];
}
}
void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() {
// Create local scopes.
preserve_vars_.resize(local_scopes_.size());
tmp_var_infos_.resize(local_scopes_.size());
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(kLocalExecScopeName)->GetMutable<Scope *>() = &local_scope;
size_t idx = local_scopes_.size() - 1 - (it - local_scopes_.rbegin());
auto *scope = local_scopes_[idx];
auto *local_scope = local_exec_scopes_[idx];
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
auto var = scope->FindVar(info.name_);
if (var != nullptr) {
VLOG(2)
<< info.name_
<< " has been initialized beforehand in global scope, skipped";
continue;
}
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
Variable *tmp_var = local_scope->Var(info.name_);
preserve_vars_[idx].emplace(tmp_var);
tmp_var_infos_[idx].emplace_back(tmp_var, info.type_);
}
}
}
......
......@@ -17,6 +17,8 @@
#include <list>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/var_handle.h"
......@@ -39,6 +41,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public:
ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<Scope*> local_exec_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
......@@ -55,10 +58,18 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
void PrepareLocalExeScopes();
private:
void InitVariables();
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
std::vector<Scope*> local_exec_scopes_;
std::vector<std::unordered_set<Variable*>> preserve_vars_;
std::vector<std::vector<std::pair<Variable*, proto::VarType::Type>>>
tmp_var_infos_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace details {
// TODO(zjl): support SelectedRows
static inline const Tensor &GetTensorFromVar(const Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
static inline Tensor *GetMutableTensorFromVar(Variable *var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable must be type of LoDTensor");
}
}
ShareTensorBufferOpHandle::ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type,
const std::vector<ir::MemOptVarInfo *> &in_var_infos,
const std::vector<std::string> &out_var_names)
: OpHandleBase(node),
scope_(scope),
scope_idx_(scope_idx),
op_type_(op_type),
in_var_infos_(in_var_infos),
out_var_names_(out_var_names) {
PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size());
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
Add(in_var_infos_[i], out_var_names_[i]);
}
}
std::unordered_set<std::string> ShareTensorBufferOpHandle::ReusedVarSet()
const {
std::unordered_set<std::string> result;
for (auto &in_var_info : in_var_infos_) {
result.insert(in_var_info->Name());
}
return result;
}
void ShareTensorBufferOpHandle::Add(ir::MemOptVarInfo *in_var_info,
const std::string &out_var_name) {
PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr");
PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name,
"in/out cannot have same name: %s", out_var_name);
in_var_infos_.emplace_back(in_var_info);
out_var_names_.emplace_back(out_var_name);
}
void ShareTensorBufferOpHandle::InitCUDA() {
#ifdef PADDLE_WITH_CUDA
int dev_id =
boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
events_[dev_id] = nullptr;
#endif
}
void ShareTensorBufferOpHandle::CallOnce() {
PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be initialized here");
Scope *exec_scope = local_exec_scopes_[0];
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
auto *in_var = exec_scope->FindVar(in_var_infos_[i]->Name());
auto *out_var = exec_scope->FindVar(out_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(in_var);
PADDLE_ENFORCE_NOT_NULL(out_var);
PADDLE_ENFORCE_NE(in_var, out_var);
in_out_vars_.emplace_back(in_var, out_var);
}
}
void ShareTensorBufferOpHandle::RunImpl() {
if (in_var_infos_.size() != in_out_vars_.size()) {
CallOnce();
}
for (size_t i = 0; i < in_var_infos_.size(); ++i) {
const auto &in_tensor = GetTensorFromVar(in_out_vars_[i].first);
auto *out_tensor = GetMutableTensorFromVar(in_out_vars_[i].second);
auto *in_var_info = in_var_infos_[i];
if (UNLIKELY(in_var_info->IsSkipped())) {
// If in_var is inplaced in the previous batch and we want to fetch
// in_var in the current batch, we have to reset memory of out_var
// to avoid wrong calculation result.
if (in_tensor.Holder() == out_tensor->Holder()) {
VLOG(1) << "Clear " << out_var_names_[i]
<< " because you may want to fetch an inplaced variable "
<< in_var_info->Name()
<< " in previous batch: " << in_var_info->Name() << " -> "
<< out_var_names_[i];
out_tensor->clear();
}
} else {
out_tensor->ShareBufferWith(in_tensor);
VLOG(2) << "Share tensor buffer when running " << op_type_ << " : "
<< in_var_info->Name() << " -> " << out_var_names_[i];
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
namespace paddle {
namespace framework {
class Variable;
class Scope;
class Tensor;
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace details {
class ShareTensorBufferOpHandle : public OpHandleBase {
public:
ShareTensorBufferOpHandle(
ir::Node *node, Scope *scope, size_t scope_idx,
const std::string &op_type,
const std::vector<ir::MemOptVarInfo *> &in_vars_infos,
const std::vector<std::string> &out_var_names);
std::unordered_set<std::string> ReusedVarSet() const;
Priority GetPriority() const override { return Priority::kHighest; }
size_t GetScopeIdx() const { return scope_idx_; }
void Add(ir::MemOptVarInfo *in_var_info, const std::string &ou_var_name);
protected:
std::string Name() const override { return "buffer_share"; }
void RunImpl() final;
void InitCUDA() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
void CallOnce();
Scope *scope_;
size_t scope_idx_;
std::string op_type_;
std::vector<ir::MemOptVarInfo *> in_var_infos_;
std::vector<std::string> out_var_names_;
std::vector<std::pair<const Variable *, Variable *>> in_out_vars_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -58,8 +58,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
std::vector<LoDTensor *> outs;
int k = -1;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &local_scope =
local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *local_scope = local_exec_scopes_[i];
auto original_name =
paddle::framework::GradOriginalVarName(in_var_handles[i]->name());
auto encode_var_name = original_name + g_dgc_encoded;
......@@ -135,9 +134,8 @@ int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
auto var_name = original_name + g_dgc_k;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto var = local_scope->FindVar(var_name);
auto *scope = local_exec_scopes_[0];
auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto tensor = var->Get<LoDTensor>().data<float>();
return *tensor;
......@@ -151,8 +149,7 @@ bool SparseAllReduceOpHandle::IsEncoded() {
auto step_name = g_dgc_rampup_begin_step;
PADDLE_ENFORCE(local_scopes_.size() > 0);
auto *scope = local_scopes_[0];
auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *local_scope = local_exec_scopes_[0];
auto count_var = local_scope->FindVar(counter_name);
auto step_var = local_scope->FindVar(step_name);
if (count_var == nullptr || step_var == nullptr) {
......
......@@ -22,9 +22,11 @@ namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: graph_(graph),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places),
fetch_ctxs_(places),
strategy_(strategy),
......@@ -176,7 +178,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_);
auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_,
&local_exec_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......
......@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
......@@ -71,6 +72,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
// be destroyed first.
ir::Graph *graph_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_;
platform::DeviceContextPool fetch_ctxs_;
ExceptionHolder exception_holder_;
......
......@@ -48,35 +48,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
PADDLE_ENFORCE(!op_desc.InputNames().empty(),
"Op inputs must not be empty");
PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
"Op outputs must not be empty");
PADDLE_ENFORCE_EQ(op_desc.InputNames().size(), 1,
"Op inputs must be unique");
PADDLE_ENFORCE_EQ(op_desc.OutputNames().size(), 1,
"Op outputs must be unique");
auto x_name = op_desc.InputNames().at(0);
auto out_name = op_desc.OutputNames().at(0);
return std::unordered_map<std::string, std::string>{{x_name, out_name}};
}
};
/*
Gradient op. Inplace output use it's Input.
For example, Input@Grad->Input reuse strategy.
*/
class GradOpInplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, bool use_cuda) const override {
std::unordered_map<std::string, std::string> ret;
std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
op_desc.OutputNames().end());
for (auto& input_name : op_desc.InputNames()) {
if (output_names.count(GradVarName(input_name))) {
ret.insert({input_name, GradVarName(input_name)});
}
}
return ret;
}
};
} // namespace framework
} // namespace paddle
......@@ -16,3 +16,7 @@ cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_o
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass)
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class BufferSharedInplaceOpPass : public MemoryReusePass {
protected:
std::string ReuseType() const override { return "inplace"; }
void Run(Graph *graph) const override;
};
void BufferSharedInplaceOpPass::Run(Graph *graph) const {
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
bool use_cuda = Get<bool>(kUseCuda);
// Step 1: Build a reverse map of last_live_ops
// i.e.: op -> vars
std::unordered_map<details::ComputationOpHandle *,
std::unordered_map<std::string, ir::Node *>>
candidate_ops;
for (auto &each_scope_ops : last_live_ops) {
for (auto &pair : each_scope_ops) {
// If variable has more than 1 last lived ops, this variable cannot
// be inplaced.
if (pair.second.size() != 1) {
continue;
}
auto *op = *(pair.second.begin());
const std::string &op_type = op->GetOp()->Type();
const framework::OpDesc *op_desc = op->Node()->Op();
PADDLE_ENFORCE_NOT_NULL(op_desc);
auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_;
if (!infer_inplace) {
continue;
}
const std::string &var_name = pair.first;
auto in_nodes = this->FindNodesByName(var_name, op->Node()->inputs);
if (in_nodes.size() == 1) {
candidate_ops[op][var_name] = *in_nodes.begin();
}
}
}
// Step 2: Check which vars can be inplaced indeed
for (auto &op_vars_pair : candidate_ops) {
auto *op = op_vars_pair.first;
auto &vars = op_vars_pair.second;
const std::string &op_type = op->GetOp()->Type();
auto *op_desc = op->Node()->Op();
auto in_to_outs =
OpInfoMap::Instance().Get(op_type).infer_inplace_(*op_desc, use_cuda);
for (auto &pair : in_to_outs) {
auto &in_param = pair.first;
auto &in_args = op_desc->Input(in_param);
if (in_args.empty()) {
VLOG(4) << "Cannot inplace because Input(" << in_param
<< ") is empty in " << op_type;
continue;
}
auto &in_arg = in_args[0];
auto iter = vars.find(in_arg);
if (iter == vars.end()) {
VLOG(4) << "Cannot inplace maybe because Input(" << in_param
<< ")=" << in_arg << " is not lastly used in op " << op_type
<< ", or it occurs multiple times in input or occurs in output";
continue;
}
ir::Node *in_node = iter->second;
auto &out_param = pair.second;
auto &out_args = op_desc->Output(out_param);
if (out_args.empty()) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ") is empty in " << op_type;
continue;
}
auto &out_arg = out_args[0];
auto out_nodes = this->FindNodesByName(out_arg, op->Node()->outputs);
if (out_nodes.size() != 1) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " occurs " << out_nodes.size()
<< " time(s) in output of op " << op_type;
continue;
}
auto *out_node = *out_nodes.begin();
auto &in_var_handle = in_node->Wrapper<details::VarHandleBase>();
auto &out_var_handle = out_node->Wrapper<details::VarHandleBase>();
auto *in_var_handle_ptr =
dynamic_cast<details::VarHandle *>(&in_var_handle);
auto *out_var_handle_ptr =
dynamic_cast<details::VarHandle *>(&out_var_handle);
if (in_var_handle_ptr == nullptr || out_var_handle_ptr == nullptr) {
continue;
}
bool success = this->TryReuseVar(in_var_handle_ptr, out_var_handle_ptr);
if (success) {
VLOG(4) << "Inplace performed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name()
<< ". Debug String is: " << op->GetOp()->DebugString();
} else {
VLOG(4) << "Inplace failed in op " << op_type << ": "
<< in_var_handle_ptr->Name() << " -> "
<< out_var_handle_ptr->Name();
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(buffer_shared_inplace_pass,
paddle::framework::ir::BufferSharedInplaceOpPass)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kUseCuda);
......@@ -24,6 +24,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
namespace paddle {
namespace framework {
......@@ -189,13 +190,9 @@ class EagerDeletionPass : public ir::Pass {
};
void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts =
Get<std::vector<AtomicReferenceCountMap>>(kRuntimeReferenceCount);
PADDLE_ENFORCE(ref_cnts.empty(),
"kRuntimeReferenceCount should be initialized here!");
auto &var_infos = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
ref_cnts.resize(vars.size());
const auto &last_live_ops =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
......@@ -224,10 +221,15 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
auto *eager_deletion_node =
graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation);
std::unordered_set<MemOptVarInfo *> var_info;
for (auto &var_name : var_names) {
var_info.insert(var_infos[op->GetScopeIdx()].at(var_name).get());
}
auto *eager_deletion_op = new details::EagerDeletionOpHandle(
eager_deletion_node, op->GetScope(), op->GetPlace(), var_names,
gcs.at(places[op->GetScopeIdx()]).get(),
&(ref_cnts[op->GetScopeIdx()]));
eager_deletion_node, op->GetScope(), op->GetPlace(),
std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get());
auto it = std::find_if(
op->Outputs().begin(), op->Outputs().end(),
......@@ -250,6 +252,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
graph->Get<details::GraphDepVars>(details::kGraphDepVars)
.emplace(dummy_leaf);
eager_deletion_op->AddOutput(dummy_leaf);
eager_deletion_op->SetDeviceContext(
places[op->GetScopeIdx()],
platform::DeviceContextPool::Instance().Get(places[op->GetScopeIdx()]));
}
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
......@@ -273,7 +279,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
} // namespace paddle
REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
.RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars)
.RequirePassAttr(paddle::framework::ir::kAllPlaces)
.RequirePassAttr(paddle::framework::ir::kGarbageCollector);
......
......@@ -106,6 +106,9 @@ class InplacePass : public ir::Pass {
// Check whether var is the last version one in SSA graph
bool IsLastVersionVar(ir::Node *var) const;
// Check whether var is the first version one in SSA graph
bool IsFirstVersionVar(ir::Node *var) const;
// Check whether all `ops` is the preceding ops of `op`
bool CheckOpDeps(ir::Node *op, const std::vector<ir::Node *> &ops) const;
......@@ -155,6 +158,10 @@ bool InplacePass::IsSkipVar(const std::string &var_name) const {
return skip_vars_.count(var_name) > 0;
}
bool InplacePass::IsFirstVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->front() == var;
}
bool InplacePass::IsLastVersionVar(ir::Node *var) const {
return AllVersionVars(var->Name())->back() == var;
}
......@@ -429,13 +436,19 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const {
}
if (!FindNodesByName(out_arg, op_node->inputs).empty()) {
VLOG(4) << "Cannot inplace because Output(" << in_param
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " occurs in input of op " << op_type;
continue;
}
auto *out_node = *out_nodes.begin();
if (!IsFirstVersionVar(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " does not occur first in op " << op_type;
continue;
}
if (!NodeCanReused(out_node)) {
VLOG(4) << "Cannot inplace because Output(" << out_param
<< ")=" << out_arg << " is not reusable in " << op_type;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class MemOptVarInfo {
public:
MemOptVarInfo(const std::string &name, size_t ref_cnt) : name_(name) {
SetRefCnt(ref_cnt);
}
bool DecreaseRefCnt() {
return ref_cnt_ == 1 || (runtime_ref_cnt_.fetch_sub(1) == 1);
}
void ResetRuntimeRefCnt() { runtime_ref_cnt_ = ref_cnt_; }
void SetRefCnt(size_t ref_cnt) {
PADDLE_ENFORCE_GE(ref_cnt, 1,
"Reference count must be larger than or equal to 1");
ref_cnt_ = ref_cnt;
runtime_ref_cnt_ = ref_cnt;
}
bool IsSkipped() const { return skipped_; }
void SetSkip(bool skipped) { skipped_ = skipped; }
const std::string &Name() const { return name_; }
private:
std::string name_;
size_t ref_cnt_;
std::atomic<size_t> runtime_ref_cnt_;
bool skipped_{false};
};
using MemOptVarInfoMapList = std::vector<
std::unordered_map<std::string, std::unique_ptr<MemOptVarInfo>>>;
class SkipMemOptVarsGuard {
public:
SkipMemOptVarsGuard(MemOptVarInfoMapList *list,
const std::vector<std::string> &vars,
bool need_reset_ref_cnt)
: list_(list), need_reset_ref_cnt_(need_reset_ref_cnt) {
if (!list_) return;
skip_vars_.reserve(vars.size() * list->size());
for (auto &var : vars) {
for (auto &map : *list_) {
auto iter = map.find(var);
if (iter != map.end() && !iter->second->IsSkipped()) {
iter->second->SetSkip(true);
skip_vars_.emplace_back(iter->second.get());
}
}
}
}
~SkipMemOptVarsGuard() {
for (auto *var : skip_vars_) {
var->SetSkip(false);
}
if (list_ && need_reset_ref_cnt_) {
for (auto &map : *list_) {
for (auto &pair : map) {
pair.second->ResetRuntimeRefCnt();
}
}
}
}
private:
MemOptVarInfoMapList *list_;
bool need_reset_ref_cnt_;
std::vector<MemOptVarInfo *> skip_vars_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h"
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
// Each ShareTensorBufferOpHandle should only have one pending
// ComputationOpHandle
static details::ComputationOpHandle *GetUniquePendingComputationOpHandle(
details::ShareTensorBufferOpHandle *share_tensor_op) {
details::ComputationOpHandle *result_op = nullptr;
for (Node *out_var : share_tensor_op->Node()->outputs) {
for (Node *pending_op : out_var->outputs) {
auto &op = pending_op->Wrapper<details::OpHandleBase>();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(&op);
PADDLE_ENFORCE_NOT_NULL(compute_op);
if (result_op == nullptr) {
result_op = compute_op;
} else {
PADDLE_ENFORCE_EQ(result_op, compute_op);
}
}
}
PADDLE_ENFORCE_NOT_NULL(result_op);
return result_op;
}
void MemoryReusePass::ApplyImpl(Graph *graph) const {
graph_ = graph;
all_vars_ = &(graph_->Get<details::GraphVars>(details::kGraphVars));
var_infos_ = &(Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList));
last_live_ops_of_vars_ =
&(Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars));
reused_var_names_.resize(all_vars_->size());
var_descs_.resize(all_vars_->size());
// Collect the existing ShareTensorBufferOpHandles.
// This is because (1) we want to reuse the existing
// ShareTensorBufferOpHandles to avoid inserting too many ops;
// (2) what is more important, a variable cannot be reused
// by two different variables, which may cause wrong calculation
// results. We have to know which variables have been reused.
CollectShareTensorBufferOpHandles();
CollectReusedVars();
Run(graph);
std::map<size_t, size_t> op_num;
for (auto &pair : ops_) {
++op_num[pair.first->GetScopeIdx()];
}
for (auto &pair : op_num) {
VLOG(2) << "Create " << pair.second
<< " ShareTensorBufferOpHandles in Scope " << pair.first;
}
}
bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const {
auto *op =
dynamic_cast<details::ComputationOpHandle *>(out_var->GeneratedOp());
PADDLE_ENFORCE_NOT_NULL(op);
if (IsVarsReusable(in_var, out_var)) {
AddReuseVar(op, in_var, out_var);
return true;
} else {
return false;
}
}
std::unordered_set<Node *> MemoryReusePass::FindNodesByName(
const std::string &name, const std::vector<Node *> &nodes) const {
std::unordered_set<ir::Node *> ret;
for (auto *node : nodes) {
if (node->Name() == name) {
ret.insert(node);
}
}
return ret;
}
VarDesc *MemoryReusePass::GetVarDesc(details::VarHandle *var) const {
auto iter = var_descs_[var->scope_idx()].find(var->Name());
if (iter == var_descs_[var->scope_idx()].end()) {
PADDLE_ENFORCE((*all_vars_)[var->scope_idx()].count(var->Name()),
"Variable %s not found", var->Name());
auto *desc =
TryGetLatestVarDesc((*all_vars_)[var->scope_idx()].at(var->Name()));
PADDLE_ENFORCE_NOT_NULL(desc);
var_descs_[var->scope_idx()].emplace(var->Name(), desc);
return desc;
} else {
return iter->second;
}
}
void MemoryReusePass::CollectShareTensorBufferOpHandles() const {
auto all_ops = FilterByNodeWrapper<details::OpHandleBase>(*graph_);
for (auto *op : all_ops) {
auto *share_buffer_op =
dynamic_cast<details::ShareTensorBufferOpHandle *>(op);
if (share_buffer_op != nullptr) {
auto *compute_op = GetUniquePendingComputationOpHandle(share_buffer_op);
PADDLE_ENFORCE(ops_.count(compute_op) == 0);
ops_.emplace(compute_op, share_buffer_op);
}
}
}
void MemoryReusePass::CollectReusedVars() const {
for (auto &pair : ops_) {
auto reused_vars = pair.second->ReusedVarSet();
reused_var_names_[pair.first->GetScopeIdx()].insert(reused_vars.begin(),
reused_vars.end());
}
}
bool MemoryReusePass::IsVarAlreadyReused(details::VarHandle *var) const {
return reused_var_names_[var->scope_idx()].count(var->Name()) > 0;
}
details::ShareTensorBufferOpHandle *
MemoryReusePass::InsertShareTensorBufferOpHandleToGraph(
details::ComputationOpHandle *op) const {
auto *buffer_share_node =
graph_->CreateEmptyNode("buffer_share", ir::Node::Type::kOperation);
auto *buffer_share_op = new details::ShareTensorBufferOpHandle(
buffer_share_node, op->GetScope(), op->GetScopeIdx(), op->GetOp()->Type(),
{}, {});
buffer_share_op->SetDeviceContext(
op->GetPlace(),
platform::DeviceContextPool::Instance().Get(op->GetPlace()));
// Inputs of `buffer_share_op` should be all inputs of `op`
for (auto *in_var : op->Inputs()) {
buffer_share_op->AddInput(in_var);
}
// Add a dep_var to resolve write-after-write data hazard between
// `buffer_share_op` and `op`.
auto *dep_var = new details::DummyVarHandle(graph_->CreateControlDepVar());
graph_->Get<details::GraphDepVars>(details::kGraphDepVars).emplace(dep_var);
op->AddInput(dep_var);
buffer_share_op->AddOutput(dep_var);
ops_.emplace(op, buffer_share_op);
return buffer_share_op;
}
bool MemoryReusePass::IsVarsReusable(details::VarHandle *in_var,
details::VarHandle *out_var) const {
const auto in_name = in_var->Name();
const auto out_name = out_var->Name();
if (in_name == out_name) {
return false;
}
if (in_name == kEmptyVarName || out_name == kEmptyVarName) {
return false;
}
if (IsVarAlreadyReused(in_var)) {
return false;
}
// out_var must be the first version!!!
auto out_var_iter = (*all_vars_)[out_var->scope_idx()].find(out_name);
PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var->scope_idx()].end() &&
!out_var_iter->second.empty(),
"Cannot find variable %s", out_name);
if (out_var_iter->second[0] != out_var) {
return false;
}
const VarDesc *in_var_desc = GetVarDesc(in_var);
const VarDesc *out_var_desc = GetVarDesc(out_var);
if (in_var_desc->Persistable() || out_var_desc->Persistable()) {
return false;
}
if (in_var_desc->GetType() != proto::VarType::LOD_TENSOR ||
out_var_desc->GetType() != proto::VarType::LOD_TENSOR) {
return false;
}
if (!FindNodesByName(in_name, out_var->GeneratedOp()->Node()->outputs)
.empty()) {
return false;
}
if (!FindNodesByName(out_name, out_var->GeneratedOp()->Node()->inputs)
.empty()) {
return false;
}
auto all_input_args =
out_var->GeneratedOp()->Node()->Op()->InputArgumentNames();
if (std::count(all_input_args.begin(), all_input_args.end(), in_name) > 1) {
return false;
}
return true;
}
void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const {
PADDLE_ENFORCE((*var_infos_)[op->GetScopeIdx()].count(in_var->Name()) > 0,
"%s does not in mem-opt var infos", in_var->Name());
if (ops_.count(op) == 0) {
InsertShareTensorBufferOpHandleToGraph(op);
}
auto *share_buffer_op = ops_[op];
auto &all_input_vars = share_buffer_op->Inputs();
bool has_input = std::find(all_input_vars.begin(), all_input_vars.end(),
in_var) != all_input_vars.end();
if (!has_input) {
share_buffer_op->AddInput(in_var);
}
share_buffer_op->Add(
(*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(),
out_var->Name());
reused_var_names_[op->GetScopeIdx()].insert(in_var->Name());
UpdateLastLiveOpOfVar(op, in_var, out_var);
}
// 1. Set last living op of in_var to be any last living op of out_var
// 2. Set reference count of in_var to be 1
void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const {
size_t scope_idx = op->GetScopeIdx();
auto out_var_op_iter =
(*last_live_ops_of_vars_)[scope_idx].find(out_var->Name());
PADDLE_ENFORCE(out_var_op_iter != (*last_live_ops_of_vars_)[scope_idx].end(),
"Cannot find variable %s", out_var->Name());
PADDLE_ENFORCE(!out_var_op_iter->second.empty());
auto &last_live_ops_of_in_var =
(*last_live_ops_of_vars_)[scope_idx][in_var->Name()];
last_live_ops_of_in_var.clear();
last_live_ops_of_in_var.insert(*(out_var_op_iter->second.begin()));
auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name());
PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(),
"Cannot find variable %s", in_var->Name());
in_var_info_iter->second->SetRefCnt(1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* MemoryReusePass is the base class of InplacePass and MemoryOptimizePass.
*
* Unlike the legacy Python API fluid.memory_optimize() which changes
* variable names in the program/graph, MemoryReusePass inserts
* ShareTensorBufferOpHandle into the graph. It is because if we use the
* way of changing variable names:
*
* 1. There are so many corner cases we should skip. For example, (1) variables
* that relates to send/recv ops cannot be renamed (otherwise, pserver
* and trainer cannot find the matching variables), (2) ins/outs of ops
* containing sub-blocks cannot be optimized, (3) variables inside
* op_role_vars cannot be renamed.
*
* 2. It is very difficult to avoid reusing variables that users want to fetch.
* This is because the memory-optimize passes/transpiler runs before users
* fetch, i.e., exe.run(...). We cannot know what users want to fetch in the
* future. As a result, we have to set var.persistable = True before
* applying memory-optimize passes/transpiler, which is rather ugly and not
* friendly to users.
*
* 3. Dim and LoD of the reused variable would be changed, which may result
* in potential errors in InferShape stage of the following ops. What's
* more, it makes that we cannot use the information from
* NoNeedBufferVarsInference.
*
* Considering the drawbacks of the former renaming strategy, we design a
* novel memory-optimize pass to fix these issues. Whether in-place is
* performed can be decided during run-time. ShareTensorBufferOpHandle
* would only share tensor buffers between in/out, never rename variable,
* and not change dim and LoD of variable. If users want to fetch a certain
* variable, we can skip in-place during run-time.
*
* The only concern on speed performance may be: there are too many
* ShareTensorBufferOpHandles in the graph. This can be avoided by moving
* tensor buffer sharing in each ComputationOpHandle::Run() method. We need
* a pass to clean all ShareTensorBufferOpHandles and move sharing to
* ComputationOpHandle::Run() in the future.
*/
class MemoryReusePass : public Pass {
protected:
void ApplyImpl(Graph *graph) const final;
virtual void Run(Graph *graph) const = 0;
virtual std::string ReuseType() const = 0;
bool TryReuseVar(details::VarHandle *in_var,
details::VarHandle *out_var) const;
std::unordered_set<ir::Node *> FindNodesByName(
const std::string &name, const std::vector<ir::Node *> &nodes) const;
size_t ScopeNum() const { return all_vars_->size(); }
private:
VarDesc *GetVarDesc(details::VarHandle *var) const;
bool IsVarsReusable(details::VarHandle *in_var,
details::VarHandle *out_var) const;
bool IsVarAlreadyReused(details::VarHandle *var) const;
details::ShareTensorBufferOpHandle *InsertShareTensorBufferOpHandleToGraph(
details::ComputationOpHandle *op) const;
void CollectShareTensorBufferOpHandles() const;
void CollectReusedVars() const;
void AddReuseVar(details::ComputationOpHandle *op, details::VarHandle *in_var,
details::VarHandle *out_var) const;
void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op,
details::VarHandle *in_var,
details::VarHandle *out_var) const;
private:
mutable Graph *graph_;
mutable details::GraphVars *all_vars_;
mutable MemOptVarInfoMapList *var_infos_;
mutable std::vector<LastLiveOpsOfVars> *last_live_ops_of_vars_;
mutable std::unordered_map<details::ComputationOpHandle *,
details::ShareTensorBufferOpHandle *>
ops_;
mutable std::vector<std::unordered_set<std::string>> reused_var_names_;
mutable std::vector<std::unordered_map<std::string, VarDesc *>> var_descs_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -26,6 +26,7 @@
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -295,18 +296,18 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx,
}
void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
auto &ref_cnts = Get<std::vector<ReferenceCountMap>>(kGlobalReferenceCount);
auto &var_infos = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
auto &last_live_ops_of_vars =
Get<std::vector<LastLiveOpsOfVars>>(kLastLiveOpsOfVars);
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(),
PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(),
"Last Live Ops and Reference Counts of vars should be "
"initialized at here.");
const auto &vars = graph->Get<details::GraphVars>(details::kGraphVars);
last_live_ops_of_vars.resize(vars.size());
ref_cnts.resize(vars.size());
var_infos.resize(vars.size());
ShrinkDepsOpFunctor shrink_func(
ir::FilterByNodeWrapper<details::OpHandleBase>(*graph));
......@@ -359,7 +360,8 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
var_name);
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
ref_cnts[i].emplace(var_name, result.size());
var_infos[i][var_name].reset(
new MemOptVarInfo(var_name, result.size()));
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
break;
}
......@@ -375,5 +377,5 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const {
} // namespace paddle
REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass)
.RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount)
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList)
.RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars);
......@@ -33,16 +33,10 @@ class VarDesc;
namespace ir {
using ReferenceCountMap = std::unordered_map<std::string, size_t>;
using AtomicReferenceCountMap =
std::unordered_map<std::string, std::atomic<size_t>>;
using GarbageCollectorMap =
std::map<platform::Place, std::unique_ptr<GarbageCollector>>;
const char kGlobalReferenceCount[] = "global_reference_count";
const char kRuntimeReferenceCount[] = "runtime_reference_count";
const char kMemOptVarInfoMapList[] = "mem_opt_var_info_map_list";
const char kGarbageCollector[] = "garbage_collector";
const char kAllPlaces[] = "all_places";
......
......@@ -89,7 +89,12 @@ class Node {
// Return a reference to the `wrapper`.
template <typename T>
T& Wrapper() {
return *boost::any_cast<T*>(wrapper_);
try {
return *boost::any_cast<T*>(wrapper_);
} catch (boost::bad_any_cast&) {
PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s",
typeid(T).name(), wrapper_type_.name());
}
}
// Test if the Node is wrapped by type T.
......
......@@ -22,11 +22,13 @@ limitations under the License. */
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -76,24 +78,10 @@ class ParallelExecutorPrivate {
}
}
ir::Graph *PrepareGCAndRefCnts(ir::Graph *graph, size_t max_memory_size);
ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph);
inline bool HasGarbageCollectors() const { return !gcs_.empty(); }
void ResetRuntimeReferenceCount(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) {
for (auto &pair : global_ref_cnts_[i]) {
runtime_ref_cnts_[i][pair.first] = pair.second;
}
for (auto &fetch_name : fetch_tensors) {
runtime_ref_cnts_[i].erase(fetch_name);
}
runtime_ref_cnts_[i].erase(fetched_var_name);
}
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) {
VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_
......@@ -201,12 +189,20 @@ class ParallelExecutorPrivate {
}
#endif
inline bool IsPersistable(const std::string &name) const {
auto iter = is_persistable_.find(name);
return iter != is_persistable_.end() && iter->second;
}
BuildStrategy build_strategy_;
std::vector<platform::Place> places_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
Scope *global_scope_; // not owned
std::unique_ptr<details::SSAGraphExecutor> executor_;
std::unordered_map<std::string, bool> is_persistable_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
platform::NCCLCommunicator *nccl_ctxs_{nullptr};
#endif
......@@ -215,16 +211,37 @@ class ParallelExecutorPrivate {
bool use_all_reduce_;
size_t nranks_;
// global_ref_cnts_ is only initialized when ParallelExecutor constructs, and
// then keeps unchanged
// Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_
std::vector<ir::ReferenceCountMap> global_ref_cnts_;
std::vector<ir::AtomicReferenceCountMap> runtime_ref_cnts_;
ir::MemOptVarInfoMapList mem_opt_var_infos_;
ir::GarbageCollectorMap gcs_;
};
ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
ir::Graph *graph, size_t max_memory_size) {
ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
if (build_strategy_.enable_inplace_) {
auto inplace_pass =
ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass");
inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_);
inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
inplace_pass->SetNotOwned(ir::kUseCuda, &use_cuda_);
VLOG(10) << "Start to apply buffer_shared_inplace_pass";
graph = inplace_pass->Apply(graph);
VLOG(10) << "buffer_shared_inplace_pass Applied";
}
// TODO(zjl): refactor MemoryOptimizePass as well!!!
if (GetEagerDeletionThreshold() < 0) {
return graph;
}
size_t max_memory_size = static_cast<size_t>(GetEagerDeletionThreshold());
for (size_t i = 0; i < places_.size(); ++i) {
auto &place = places_[i];
if (gcs_.count(place) > 0) {
......@@ -258,19 +275,10 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts(
}
if (!gcs_.empty()) {
std::vector<ir::LastLiveOpsOfVars> last_live_ops_of_vars;
auto ref_cnt_pass =
ir::PassRegistry::Instance().Get("reference_count_pass");
ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_);
ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars);
graph = ref_cnt_pass->Apply(graph);
VLOG(10) << "ReferenceCountPass Applied";
auto eager_deletion_pass =
ir::PassRegistry::Instance().Get("eager_deletion_pass");
eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount,
&runtime_ref_cnts_);
eager_deletion_pass->SetNotOwned(ir::kMemOptVarInfoMapList,
&mem_opt_var_infos_);
eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_);
eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars,
&last_live_ops_of_vars);
......@@ -386,9 +394,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
// same communicators.
auto *nccl_ctxs =
member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_);
auto &pool = platform::DeviceContextPool::Instance();
for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
pool.Get(member_->places_[dev_id]));
auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]);
......@@ -456,13 +463,7 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
}
#endif
auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30);
if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(graph,
static_cast<size_t>(max_memory_size));
}
graph = member_->ApplyMemoryOptimizePass(graph);
async_graphs[0] = graph;
......@@ -475,6 +476,9 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
member_->is_persistable_.emplace(node->Var()->Name(),
node->Var()->Persistable());
}
}
......@@ -493,17 +497,34 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
}
}
std::unordered_map<Scope *, Scope *> scope_map;
for (auto *scope : member_->local_scopes_) {
auto &local_exec_scope = scope->NewScope();
member_->local_exec_scopes_.emplace_back(&local_exec_scope);
scope_map.emplace(scope, &local_exec_scope);
}
PADDLE_ENFORCE_EQ(member_->local_scopes_.size(),
member_->local_exec_scopes_.size());
std::vector<ir::Graph *> final_graphs;
if (member_->build_strategy_.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, async_graphs));
final_graphs = async_graphs;
} else if (member_->build_strategy_.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA
// TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr.
member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph));
auto *pg_exe = new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph);
final_graphs = pg_exe->Graphs();
member_->executor_.reset(pg_exe);
#else
PADDLE_THROW(
"Paddle should be compiled with CUDA for ParallelGraph Execution.");
......@@ -512,19 +533,29 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph));
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph));
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
}
final_graphs.emplace_back(graph);
}
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
if (!member_->build_strategy_.async_mode_) {
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
std::move(var_infos), member_->places_, std::move(member_->executor_)));
}
for (auto *g : final_graphs) {
auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*g);
for (auto *op : ops) {
op->SetLocalExecScopes(scope_map);
}
}
}
......@@ -616,10 +647,9 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
#endif
platform::RecordBlock b(0);
if (member_->HasGarbageCollectors()) {
platform::RecordEvent event("PrepareGarbageCollectors");
member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name);
}
ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors,
member_->HasGarbageCollectors());
VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run";
auto fetch_data = member_->executor_->Run(fetch_tensors);
......@@ -633,9 +663,13 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
for (size_t i = 0; i < tensors.size(); ++i) {
auto &map = tensors[i];
auto *scope = member_->local_scopes_[i];
for (auto &pair : map) {
auto *trg = scope->Var(pair.first)->GetMutable<LoDTensor>();
bool is_persistable = member_->IsPersistable(pair.first);
auto *feed_scope = is_persistable ? member_->local_scopes_[i]
: member_->local_exec_scopes_[i];
auto *feed_var = feed_scope->Var(pair.first);
auto *trg = feed_var->GetMutable<LoDTensor>();
trg->ShareDataWith(pair.second);
trg->set_lod(pair.second.lod());
}
......@@ -644,7 +678,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes(
void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
const std::unordered_map<std::string, LoDTensor> &tensors) {
for (auto pair : tensors) {
for (auto &pair : tensors) {
auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
if (member_->places_.size() != lod_tensors.size()) {
bool is_cpu_place = platform::is_cpu_place(member_->places_.front());
......@@ -661,10 +695,14 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
}
PADDLE_THROW(error_info);
}
bool is_persistable = member_->IsPersistable(pair.first);
for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var?
auto t =
member_->local_scopes_[j]->Var(pair.first)->GetMutable<LoDTensor>();
auto *feed_scope = is_persistable ? member_->local_scopes_[j]
: member_->local_exec_scopes_[j];
auto *feed_var = feed_scope->Var(pair.first);
auto t = feed_var->GetMutable<LoDTensor>();
t->ShareDataWith(lod_tensors[j]);
t->set_lod(lod_tensors[j].lod());
}
......@@ -724,3 +762,4 @@ bool ParallelExecutor::EnableParallelGraphExecution(
USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
......
......@@ -200,6 +200,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr;
}
void Scope::EraseVarsExcept(const std::unordered_set<Variable*>& vars) {
SCOPE_VARS_WRITER_LOCK
for (auto iter = vars_.begin(); iter != vars_.end();) {
if (vars.count(iter->second.get()) != 0) {
++iter;
} else {
vars_.erase(iter++);
}
}
}
std::string GenScopeTreeDebugInfo(Scope* root) {
std::stringstream os;
......
......@@ -22,6 +22,7 @@ extern "C" {
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -66,6 +67,9 @@ class Scope {
void EraseVars(const std::vector<std::string>& var_names);
// Erase all variables except the given `vars`
void EraseVarsExcept(const std::unordered_set<Variable*>& vars);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
/// Caller doesn't own the returned Variable.
......
......@@ -149,7 +149,15 @@ class Tensor {
void set_layout(const DataLayout layout) { layout_ = layout; }
void clear() { holder_ = nullptr; }
void clear() {
holder_ = nullptr;
offset_ = 0;
}
void ShareBufferWith(const Tensor& tensor) {
holder_ = tensor.holder_;
offset_ = tensor.offset_;
}
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; }
......
......@@ -751,6 +751,14 @@ class SquareDoubleGradMaker
}
};
class ActivationGradOpInplaceInference : public framework::InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const framework::OpDesc& op_desc, bool use_cuda) const override {
return {{framework::GradVarName("Out"), framework::GradVarName("X")}};
}
};
} // namespace operators
} // namespace paddle
......@@ -765,11 +773,8 @@ namespace plat = paddle::platform;
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR( \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
......@@ -794,7 +799,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
relu_grad_grad,
......@@ -819,7 +824,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
......@@ -843,7 +848,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
sqrt_grad_grad,
......@@ -865,7 +870,7 @@ REGISTER_OPERATOR(
ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
paddle::framework::SingleOpInplaceInToOut,
ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
square_grad_grad,
......
......@@ -115,8 +115,15 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
auto *out = context.Output<LoDTensor>("Out");
bool in_place = in_vars[0] == context.OutputVar("Out");
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
auto *out_ptr = out->mutable_data<T>(context.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
}
// Sum of two tensors
......
......@@ -128,10 +128,15 @@ class SumKernel : public framework::OpKernel<T> {
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = context.Output<LoDTensor>("Out");
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
auto *out = out_var->GetMutable<framework::LoDTensor>();
auto *out_ptr = out->mutable_data<T>(context.GetPlace());
if (in_num >= 1 && in_vars[0]->IsType<framework::LoDTensor>()) {
auto &in_0_tensor = in_vars[0]->Get<framework::LoDTensor>();
if (in_0_tensor.numel() > 0) {
in_place = (in_0_tensor.data<T>() == out_ptr);
}
}
auto result = EigenVector<T>::Flatten(*out);
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
......
......@@ -1549,6 +1549,13 @@ All parameter, weight, gradient are variables in Paddle.
"enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; },
[](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
.def_property("_use_legacy_memory_optimize_strategy",
[](const BuildStrategy &self) {
return self.use_legacy_memory_optimize_strategy_;
},
[](BuildStrategy &self, bool b) {
self.use_legacy_memory_optimize_strategy_ = b;
})
.def_property(
"fuse_all_reduce_ops",
[](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; },
......
......@@ -211,34 +211,9 @@ class CompiledProgram(object):
if self._program:
if self._program._is_mem_optimized:
self._build_strategy.memory_optimize = False
self._build_strategy.enable_inplace = False
elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace:
# remind the user to try our memmory optimize strategy
six.print_(
"""
You can try our memory optimize feature to save your memory usage:
# create a build_strategy variable to set memory optimize option
build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.memory_optimize = True
# pass the build_strategy to with_data_parallel API
compiled_prog = compiler.CompiledProgram(main).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
!!! Memory optimize is our experimental feature !!!
some variables may be removed/reused internal to save memory usage,
in order to fetch the right value of the fetch_list, please set the
persistable property to true for each variable in fetch_list
# Sample
conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None)
# if you need to fetch conv1, then:
conv1.persistable = True
""",
file=sys.stderr)
if self._build_strategy.memory_optimize:
self._build_strategy._use_legacy_memory_optimize_strategy = True
return self
def with_inference_optimize(self, config):
......
......@@ -551,7 +551,7 @@ class Executor(object):
if not persistable:
logging.warn("""
Detect that memory optimize or inplace is enabled, but the some variables in the fetch
Detect that build_strategy.memory_optimize = True, but the some variables in the fetch
list is not persistable, you may get wrong fetched value, or an exeception may be thrown
about cannot find variable of the fetch list.
......@@ -668,9 +668,8 @@ class Executor(object):
return_numpy=return_numpy,
use_program_cache=use_program_cache)
else:
if fetch_list and program._is_data_parallel and program._program and (
program._build_strategy.memory_optimize or
program._build_strategy.enable_inplace):
if fetch_list and program._is_data_parallel and program._program and \
program._build_strategy._use_legacy_memory_optimize_strategy:
self._check_fetch_vars_persistable(program._program, fetch_list)
program._compile(scope, self.place)
......
......@@ -256,4 +256,4 @@ endif()
set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist
test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op
PROPERTIES LABELS "RUN_TYPE=DIST")
test_buffer_shared_inplace_pass PROPERTIES LABELS "RUN_TYPE=DIST")
......@@ -33,7 +33,7 @@ class TestParallelExecutorBase(unittest.TestCase):
def check_network_convergence(cls,
method,
use_cuda=True,
memory_opt=True,
memory_opt=False,
iter=50,
batch_size=None,
allow_op_delay=False,
......@@ -41,7 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase):
seed=None,
use_parallel_executor=True,
use_reduce=False,
use_ir_memory_optimize=True,
use_ir_memory_optimize=False,
enable_inplace=True,
fuse_elewise_add_act_ops=False,
fuse_all_optimizer_ops=False,
......@@ -65,7 +65,8 @@ class TestParallelExecutorBase(unittest.TestCase):
main.random_seed = seed
loss = method(use_feed=feed_dict is not None)
loss.persistable = True
if memory_opt or use_ir_memory_optimize:
loss.persistable = True
if optimizer:
optimizer().minimize(loss)
......@@ -88,9 +89,8 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize
build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops
build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops
# python memory optimization is conflict with inplace pass.
# Use ir graph memory optimization after inplace pass is the correct way.
build_strategy.enable_inplace = False if memory_opt else enable_inplace
build_strategy.enable_inplace = enable_inplace
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.framework import Parameter
import numpy as np
from simple_nets import simple_fc_net
import random
import unittest
import os
batch_size = 32
feed_dict = {
'image': np.random.random([batch_size, 784]).astype('float32'),
'label': np.random.random_integers(
low=0, high=9, size=[batch_size, 1]).astype('int64')
}
class InplaceTestBase(unittest.TestCase):
def initParameter(self):
self.use_cuda = True
def setUp(self):
self.initParameter()
if self.use_cuda and fluid.core.is_compiled_with_cuda():
self.device_count = fluid.core.get_cuda_device_count()
else:
self.device_count = 4
assert batch_size % self.device_count == 0
def build_program_and_scope(self):
self.place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
startup_program = fluid.Program()
main_program = fluid.Program()
startup_program.random_seed = 1
main_program.random_seed = 1
scope = fluid.Scope()
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
loss = simple_fc_net()
adam = fluid.optimizer.Adam(learning_rate=1e-3)
adam.minimize(loss)
with fluid.scope_guard(scope):
exe = fluid.Executor(
fluid.CUDAPlace(0)
if self.use_cuda else fluid.CPUPlace())
exe.run(startup_program)
return main_program, scope, exe, loss
def is_invalid_test(self):
return self.use_cuda and not fluid.core.is_compiled_with_cuda()
def get_all_vars(self, program):
all_vars = program.global_block().vars
all_vars_name = []
for name, var in all_vars.items():
if 0 not in var.shape and not var.persistable:
all_vars_name.append(name)
return all_vars_name
def test_single_card_fetch_var(self):
if self.is_invalid_test():
return
prog1, scope1, exe, loss1 = self.build_program_and_scope()
prog2, scope2, _, loss2 = self.build_program_and_scope()
prog3, scope3, _, loss3 = self.build_program_and_scope()
build_strategy2 = fluid.BuildStrategy()
build_strategy2.memory_optimize = False
build_strategy2.enable_inplace = True
compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
loss_name=loss2.name,
build_strategy=build_strategy2,
places=self.place)
build_strategy3 = fluid.BuildStrategy()
build_strategy3.memory_optimize = False
build_strategy3.enable_inplace = False
compiled_prog3 = fluid.CompiledProgram(prog3).with_data_parallel(
loss_name=loss2.name,
build_strategy=build_strategy3,
places=self.place)
all_vars_name = self.get_all_vars(prog1)
repeated_var_names = all_vars_name * 4
random.shuffle(repeated_var_names) # add some random
for fetch_var in repeated_var_names:
for _ in range(4):
with fluid.scope_guard(scope1):
fetch_val1, = exe.run(prog1,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope2):
fetch_val2, = exe.run(compiled_prog2,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope3):
fetch_val3, = exe.run(compiled_prog3,
feed=feed_dict,
fetch_list=[fetch_var])
self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
self.assertTrue(np.array_equal(fetch_val1, fetch_val3))
def test_multi_card_fetch_var(self):
if self.is_invalid_test():
return
prog1, scope1, exe, loss1 = self.build_program_and_scope()
prog2, scope2, _, loss2 = self.build_program_and_scope()
build_strategy1 = fluid.BuildStrategy()
build_strategy1.memory_optimize = False
build_strategy1.enable_inplace = True
build_strategy2 = fluid.BuildStrategy()
build_strategy2.memory_optimize = False
build_strategy2.enable_inplace = False
if self.use_cuda:
places = fluid.cuda_places()
else:
places = fluid.cpu_places(self.device_count)
compiled_prog1 = fluid.CompiledProgram(prog1).with_data_parallel(
loss_name=loss1.name, build_strategy=build_strategy1, places=places)
compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel(
loss_name=loss2.name, build_strategy=build_strategy2, places=places)
repeated_var_names = self.get_all_vars(prog1) * 4
random.shuffle(repeated_var_names) # add some random
for fetch_var in repeated_var_names:
for _ in range(4):
with fluid.scope_guard(scope1):
fetch_val1, = exe.run(compiled_prog1,
feed=feed_dict,
fetch_list=[fetch_var])
with fluid.scope_guard(scope2):
fetch_val2, = exe.run(compiled_prog2,
feed=feed_dict,
fetch_list=[fetch_var])
self.assertTrue(np.array_equal(fetch_val1, fetch_val2))
class CPUInplaceTest(InplaceTestBase):
def initParameter(self):
self.use_cuda = False
if __name__ == '__main__':
unittest.main()
......@@ -61,6 +61,8 @@ class TestSoftmaxWithXe(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = inplace
if inplace:
build_strategy._use_legacy_memory_optimize_strategy = True
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
build_strategy=build_strategy, places=place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册