diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 30f61ad1a4c5bf80b560cbfa1d05414bfcc4988d..f1a54e914fe7c37543aae0814ba9b32b9e7c9af4 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -59,7 +59,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper) -set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass) +cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope) + +set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass) if (WITH_GPU) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) endif() diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 2f001e54d4f668537953bbaeb14aa21e6745009f..f806a4fa84775c8814846a4f3f33eee3f7034d9d 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -99,10 +99,9 @@ void AllReduceOpHandle::RunImpl() { std::vector lod_tensors; for (size_t i = 0; i < local_scopes_.size(); ++i) { - auto *s = local_scopes_[i]; - auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get(); + auto &local_scope = local_exec_scopes_[i]; auto &lod_tensor = - local_scope.FindVar(in_var_handles[i]->name())->Get(); + local_scope->FindVar(in_var_handles[i]->name())->Get(); lod_tensors.emplace_back(&lod_tensor); VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name() << ", out_name:" << out_var_handles[i]->name(); @@ -140,9 +139,7 @@ void AllReduceOpHandle::RunImpl() { PADDLE_THROW("Not compiled with CUDA"); #endif } else { // Special handle CPU only Operator's gradient. Like CRF - auto &trg = *this->local_scopes_[0] - ->FindVar(kLocalExecScopeName) - ->Get() + auto &trg = *this->local_exec_scopes_[0] ->FindVar(out_var_handles[0]->name()) ->GetMutable(); @@ -151,10 +148,9 @@ void AllReduceOpHandle::RunImpl() { VisitDataType(lod_tensors[0]->type(), func); for (size_t i = 1; i < local_scopes_.size(); ++i) { - auto &scope = - *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); + auto &scope = local_exec_scopes_[i]; auto &p = places_[i]; - auto *var = scope.FindVar(out_var_handles[i]->name()); + auto *var = scope->FindVar(out_var_handles[i]->name()); auto *dev_ctx = dev_ctxes_.at(p); RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h index f206f5fea5c41536a07143e707c53f135b287035..ed5e475a8d8f6018eea6d42149a092cd6ac41214 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/all_reduce_op_handle.h @@ -49,6 +49,9 @@ class AllReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; + + std::vector GetLocalScopes() override { return local_scopes_; } + std::vector local_scopes_; #if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32)) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index da9721ea73d28f48ddfc12672fc6249a2a23c9df..2e247075395f6603922c96bbe69f598265ec7c75 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -24,22 +24,20 @@ namespace paddle { namespace framework { namespace details { -inline void NewTempScopeAndInitVars(const std::vector &var_infos, - Scope *scope) { - VLOG(3) << "NewTempScopeAndInitVars"; - Scope &local_scope = scope->NewScope(); - *scope->Var(details::kLocalExecScopeName)->GetMutable() = - &local_scope; - +inline void InitVarsInScope(const std::vector &var_infos, Scope *scope, + Scope *local_scope) { + VLOG(3) << "InitVarsInScope"; for (auto &info : var_infos) { - if (scope->FindVar(info.name_) != nullptr) { - continue; - } - if (info.persistable_) { // Persistable + auto *var = scope->FindVar(info.name_); + if (var != nullptr) { + VLOG(2) << info.name_ + << " has been initialized beforehand in global scope, skipped"; + continue; + } InitializeVariable(scope->Var(info.name_), info.type_); } else { - InitializeVariable(local_scope.Var(info.name_), info.type_); + InitializeVariable(local_scope->Var(info.name_), info.type_); } } } @@ -101,14 +99,17 @@ void ProcessGraph(std::vector graphs, Scope *scope) { AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, std::vector graphs) : strategy_(std::move(strategy)), local_scopes_(std::move(local_scopes)), + local_exec_scopes_(local_exec_scopes), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), places_(std::move(places)), graphs_(std::move(graphs)) { VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); + PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size()); // set the correct size of thread pool to each device. strategy_.num_threads_ = strategy_.num_threads_ < places_.size() @@ -118,7 +119,8 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( << " to run the operators of the graph on each device."; for (size_t i = 0; i < places.size(); ++i) { executors_.emplace_back(new details::ThreadedSSAGraphExecutor( - strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i])); + strategy_, {local_scopes_[i]}, {local_exec_scopes_[i]}, {places_[i]}, + graphs_[i])); } for (auto &node : graphs_[0]->Nodes()) { @@ -129,8 +131,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( var_infos_.back().persistable_ = node->Var()->Persistable(); } } - for (auto *scope : local_scopes_) { - NewTempScopeAndInitVars(var_infos_, scope); + + for (size_t i = 0; i < local_scopes_.size(); ++i) { + InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]); } ProcessGraph(graphs_, local_scopes_[0]); } diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h index 6aaf8f9a165f2eae3a64874e60084e4d9bdbc182..97472674fada8cc1c531b54be49816e76ebde3f8 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h @@ -36,6 +36,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { public: AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, std::vector graphs); ~AsyncSSAGraphExecutor() final = default; @@ -50,6 +51,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { private: ExecutionStrategy strategy_; std::vector local_scopes_; + std::vector local_exec_scopes_; std::unique_ptr<::ThreadPool> pool_{nullptr}; std::vector places_; std::vector graphs_; diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc index 752c932a215bad53f47f19f143a8008b66617a51..75143b9a1a0c85a24de337ad02afeea1112ca85c 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/broadcast_op_handle.cc @@ -40,18 +40,13 @@ void BroadcastOpHandle::RunImpl() { WaitInputVarGenerated(); - std::vector var_scopes; - for (auto *s : local_scopes_) { - var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); - } - - BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes); + BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_); } void BroadcastOpHandle::BroadcastOneVar( const VarHandle &in_var_handle, const std::vector &out_var_handles, - const std::vector &var_scopes) { + const std::vector &var_scopes) { auto *in_var = var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name()); PADDLE_ENFORCE_NOT_NULL(in_var); @@ -140,10 +135,7 @@ void BroadcastOpHandle::BroadcastOneVar( void BroadcastOpHandle::InitOutputValue( const VarHandle &in_var_handle, const std::vector &out_var_handles) const { - std::vector var_scopes; - for (auto *s : local_scopes_) { - var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); - } + auto &var_scopes = local_exec_scopes_; auto *in_var = var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name()); diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index 0b4d33513506d41a63db8316abaa5cd0458ff352..45ccbb41e0b0efca495f1db8d05285b07ecff910 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -62,9 +62,11 @@ struct BroadcastOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return local_scopes_; } + void BroadcastOneVar(const VarHandle &in_var_handle, const std::vector &out_var_handles, - const std::vector &var_scopes); + const std::vector &var_scopes); std::vector local_scopes_; std::vector places_; diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.h b/paddle/fluid/framework/details/broadcast_op_handle_test.h index df3b3cc9ca012eabc428a7fb4c3af9be5b1c5bd5..abc3f39e6867482dfa1d2c01cd97e96293acc9e5 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle_test.h +++ b/paddle/fluid/framework/details/broadcast_op_handle_test.h @@ -14,7 +14,9 @@ #pragma once +#include #include +#include #include #include "gtest/gtest.h" @@ -92,14 +94,13 @@ struct TestBroadcastOpHandle { void InitBroadcastOp(size_t input_scope_idx) { nodes_.clear(); + std::unordered_map scope_map; for (size_t j = 0; j < place_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); - *local_scopes_.back() - ->Var(details::kLocalExecScopeName) - ->GetMutable() = &local_scope; local_scope.Var("out"); param_scopes_.emplace_back(&local_scope); + scope_map.emplace(local_scopes_.back(), param_scopes_.back()); } param_scopes_[input_scope_idx]->Var("input"); @@ -122,6 +123,8 @@ struct TestBroadcastOpHandle { #endif } + op_handle_->SetLocalExecScopes(scope_map); + nodes_.emplace_back( ir::CreateNodeForTest("node1", ir::Node::Type::kVariable)); auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx, diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 3b57a099c8afeeca05f9fa45eda78e20197dc798..c164996f5eae14767c9a232655a9365d7f7a9c9c 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -92,16 +92,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("fuse_relu_depthwise_conv_pass"); } - // NOTE(dzhwinter): A note for automatical inplace. - // 1. modify program desc passes should put - // before inplace pass. - // 2. manually configured inplace should put - // before inplace_pass - - // Add automatically inplace. - if (strategy_.enable_inplace_) { - VLOG(1) << "Add inplace_pass"; - AppendPass("inplace_pass"); + // TODO(zjl): refactor MemoryOptimizePass to fit + // new strategy, which does not need to set + // var.persistable = True + if (strategy_.use_legacy_memory_optimize_strategy_) { + if (strategy_.enable_inplace_) { + VLOG(5) << "Add inplace_pass"; + AppendPass("inplace_pass"); + } } if (strategy_.fuse_elewise_add_act_ops_) { @@ -160,9 +158,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // the de-fact IR, any reuse on Graph is meaningless. // A side-effect of that, memory optimize cannot forsee the fetched vars // , so fetchlist should be set persistable before call the Run interface. - if (strategy_.memory_optimize_) { - VLOG(1) << "Add memory_optimize_pass"; - AppendPass("memory_optimize_pass"); + if (strategy_.use_legacy_memory_optimize_strategy_) { + if (strategy_.memory_optimize_) { + VLOG(5) << "Add memory_optimize_pass"; + AppendPass("memory_optimize_pass"); + } } // runtime_context_cache pass should be the last pass to enable the attr of diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index ae28a2cc6f9a4979eabff69a36b5fac87c096f87..09e7ca5f21df1c2b8ab3a1c319918cd8085cd1eb 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -114,7 +114,12 @@ struct BuildStrategy { // it is not appropriate, because kStaleProgramOpDescs will be removed in the // near future. bool memory_optimize_{false}; - bool enable_inplace_{false}; + + // Turn on inplace by default. + bool enable_inplace_{true}; + + // TODO(zjl): Remove this flag when MemoryOptimizePass is refactored + bool use_legacy_memory_optimize_strategy_{false}; // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // num_trainers is 1, so the current fields of build_strategy doesn't tell if diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc index 7beb8c8de9fc49aebc66ca44de8736240aabbc30..0b653e57f6d48f9d919ee4f09db5b6ab6b2451b7 100644 --- a/paddle/fluid/framework/details/computation_op_handle.cc +++ b/paddle/fluid/framework/details/computation_op_handle.cc @@ -31,9 +31,7 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, void ComputationOpHandle::RunImpl() { WaitInputVarGenerated(place_); - auto run_func = [this]() { - op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get(), place_); - }; + auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); }; if (is_lock_and_record_event_free_) { run_func(); diff --git a/paddle/fluid/framework/details/computation_op_handle.h b/paddle/fluid/framework/details/computation_op_handle.h index e98b16e6b3a07bfa0994295306e3bfa9e4174834..5a65aaf0d2012f6a42f57b47e9e1c7b0167c8b35 100644 --- a/paddle/fluid/framework/details/computation_op_handle.h +++ b/paddle/fluid/framework/details/computation_op_handle.h @@ -38,6 +38,8 @@ class ComputationOpHandle : public OpHandleBase { const Scope *GetScope() const { return scope_; } + Scope *GetScope() { return scope_; } + const platform::Place &GetPlace() const { return place_; } void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; } @@ -49,6 +51,8 @@ class ComputationOpHandle : public OpHandleBase { bool NeedWait(VarHandleBase *in_var) override; + std::vector GetLocalScopes() override { return {scope_}; } + private: std::unique_ptr op_; Scope *scope_; diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.cc b/paddle/fluid/framework/details/eager_deletion_op_handle.cc index f8723fe75f8f0304e149ab2195f29bc4c7223bc4..817fe03cf5ce06f723e06f9b523c2a98017982c6 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.cc +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/details/eager_deletion_op_handle.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" @@ -30,14 +31,13 @@ namespace framework { namespace details { EagerDeletionOpHandle::EagerDeletionOpHandle( - ir::Node *node, const Scope *scope, const platform::Place &place, - const std::unordered_set &var_names, GarbageCollector *gc, - ir::AtomicReferenceCountMap *ref_cnts) + ir::Node *node, Scope *scope, const platform::Place &place, + const std::unordered_set &vars, GarbageCollector *gc) : OpHandleBase(node), scope_(scope), - var_names_(var_names.begin(), var_names.end()), - gc_(gc), - ref_cnts_(ref_cnts) { + place_(place), + var_infos_(vars.begin(), vars.end()), + gc_(gc) { #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place)) { dev_ctx_ = reinterpret_cast( @@ -50,7 +50,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle( } } #endif - PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty"); + PADDLE_ENFORCE(!vars.empty(), "Var names cannot be empty"); + for (auto *var : var_infos_) { + PADDLE_ENFORCE_NOT_NULL(var); + } } EagerDeletionOpHandle::~EagerDeletionOpHandle() { @@ -63,30 +66,43 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() { #endif } +void EagerDeletionOpHandle::InitCUDA() { +#ifdef PADDLE_WITH_CUDA + int dev_id = + boost::get(dev_ctxes_.begin()->first).device; + events_[dev_id] = nullptr; +#endif +} + +void EagerDeletionOpHandle::CallOnce() { + PADDLE_ENFORCE(vars_.empty(), "vars_ must be initialized here"); + Scope *exec_scope = local_exec_scopes_[0]; + for (auto *var_info : var_infos_) { + auto *var = exec_scope->FindVar(var_info->Name()); + PADDLE_ENFORCE_NOT_NULL(var, "Variable %s should not be nullptr", + var_info->Name()); + vars_.emplace_back(var); + } +} + std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; } void EagerDeletionOpHandle::RunImpl() { + if (vars_.size() != var_infos_.size()) { + CallOnce(); + } + platform::RecordEvent record_event(Name()); - Scope *exec_scope = nullptr; std::deque> garbages; - for (auto &name : var_names_) { - auto it = ref_cnts_->find(name); - // Reference count has not decreased to 0 - if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) { + for (size_t i = 0; i < var_infos_.size(); ++i) { + auto *var_info = var_infos_[i]; + if (var_info->IsSkipped() || !var_info->DecreaseRefCnt()) { continue; } - if (!exec_scope) { - exec_scope = scope_->FindVar(kLocalExecScopeName)->Get(); - } - - // Var not found - auto *var = exec_scope->FindVar(name); - if (var == nullptr) { - continue; - } + VLOG(2) << "Erase variable " << var_info->Name() << " on " << place_; - VLOG(2) << "Erase variable " << name; + Variable *var = vars_[i]; if (var->IsType()) { garbages.emplace_back(var->GetMutable()->MoveMemoryHolder()); @@ -100,7 +116,7 @@ void EagerDeletionOpHandle::RunImpl() { } } else { PADDLE_THROW("Type %s of %s is not supported eager deletion", - framework::ToTypeName(var->Type()), name); + framework::ToTypeName(var->Type()), var_info->Name()); } } diff --git a/paddle/fluid/framework/details/eager_deletion_op_handle.h b/paddle/fluid/framework/details/eager_deletion_op_handle.h index fe723922ca711a6348fddcaabbdf635cb7d2983d..4b2d4a83a6a88b05c8b7710f3d7a114c73a4f1d4 100644 --- a/paddle/fluid/framework/details/eager_deletion_op_handle.h +++ b/paddle/fluid/framework/details/eager_deletion_op_handle.h @@ -26,15 +26,18 @@ namespace paddle { namespace framework { class Scope; +namespace ir { +class MemOptVarInfo; +} // namespace ir + namespace details { class EagerDeletionOpHandle : public OpHandleBase { public: - EagerDeletionOpHandle(ir::Node *node, const Scope *scope, + EagerDeletionOpHandle(ir::Node *node, Scope *scope, const platform::Place &place, - const std::unordered_set &var_names, - GarbageCollector *gc, - ir::AtomicReferenceCountMap *ref_cnts); + const std::unordered_set &vars, + GarbageCollector *gc); ~EagerDeletionOpHandle(); @@ -50,13 +53,20 @@ class EagerDeletionOpHandle : public OpHandleBase { protected: void RunImpl() override; + void InitCUDA() override; + + std::vector GetLocalScopes() override { return {scope_}; } + private: void ClearGarbages(std::deque> *garbages); - const Scope *scope_; - std::vector var_names_; - GarbageCollector *gc_; // not own - ir::AtomicReferenceCountMap *ref_cnts_; // not own + void CallOnce(); + + Scope *scope_; + platform::Place place_; + std::vector var_infos_; // not own + GarbageCollector *gc_; // not own + std::vector vars_; #ifdef PADDLE_WITH_CUDA platform::CUDADeviceContext *dev_ctx_{nullptr}; cudaEvent_t event_{nullptr}; diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index b33162edd2b69ca0703f27041e71fe72da9779e3..11052273d2849b4b8836c55466e205b8fd0789de 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -28,9 +28,11 @@ namespace details { FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph) : strategy_(strategy), local_scopes_(local_scopes), + local_exec_scopes_(local_exec_scopes), places_(places), graph_(graph), fetch_ctxs_(places), @@ -143,7 +145,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps( ir::Node *fetch_node = graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); - auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_); + auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_, + &local_exec_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h index d88e5bbaa97419c6e5229deaa16fbcfa922432d0..5d11c2cfd9ed6a8b49aa6ee01c89969dc75c21a6 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h @@ -33,6 +33,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { public: FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph); FeedFetchList Run(const std::vector &fetch_tensors) override; @@ -43,6 +44,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { // be destroyed first. ExecutionStrategy strategy_; std::vector local_scopes_; + std::vector local_exec_scopes_; std::vector places_; ir::Graph *graph_; diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.cc b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc index 019ecfbb61028537692c8fdeb874c6c490f75430..127183a32e938de57ce4f7cb5aed4e72f8f09682 100644 --- a/paddle/fluid/framework/details/fetch_barrier_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.cc @@ -42,9 +42,7 @@ bool FetchBarrierOpHandle::IsMultiDeviceTransfer() { void FetchBarrierOpHandle::RunImpl() { WaitInputVarGenerated(place_); - auto run_func = [this]() { - op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get(), place_); - }; + auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); }; if (is_lock_and_record_event_free_) { run_func(); diff --git a/paddle/fluid/framework/details/fetch_barrier_op_handle.h b/paddle/fluid/framework/details/fetch_barrier_op_handle.h index b4d12785e0345c887f179bc53c8446dc1438f889..d1f7e08b28e7d8291c11bd61588c978f591060c2 100644 --- a/paddle/fluid/framework/details/fetch_barrier_op_handle.h +++ b/paddle/fluid/framework/details/fetch_barrier_op_handle.h @@ -44,6 +44,8 @@ struct FetchBarrierOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return local_scopes_; } + bool NeedWait(VarHandleBase *in_var) override; private: diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index fe14e3e91da34e5993a68d10a13b275bab576ce6..1ac32ca975d4d8ac2599714bac45ed211f0adc72 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -22,11 +22,13 @@ namespace framework { namespace details { FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, - std::vector *local_scopes) + std::vector *local_scopes, + std::vector *local_exec_scopes) : OpHandleBase(node), data_(data), offset_(offset), - local_scopes_(local_scopes) {} + local_scopes_(local_scopes), + local_exec_scopes_(local_exec_scopes) {} FetchOpHandle::~FetchOpHandle() {} @@ -49,14 +51,12 @@ void FetchOpHandle::RunImpl() { tensors_.resize(inputs_.size()); platform::CPUPlace cpu; - auto &scopes = *local_scopes_; + auto &scopes = *local_exec_scopes_; for (size_t i = 0; i < inputs_.size(); ++i) { auto *var_handle = static_cast(inputs_[i]); auto &scope = scopes.at(var_handle->scope_idx()); - auto *var = scope->FindVar(kLocalExecScopeName) - ->Get() - ->FindVar(var_handle->name()); + auto *var = scope->FindVar(var_handle->name()); PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", var_handle->name()); diff --git a/paddle/fluid/framework/details/fetch_op_handle.h b/paddle/fluid/framework/details/fetch_op_handle.h index dbb7f4f6582f6e0f0b9b5702533852d12da1051c..f3af4e61e2ba7664275eaed5f34c05940d0ec582 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.h +++ b/paddle/fluid/framework/details/fetch_op_handle.h @@ -29,7 +29,8 @@ namespace details { struct FetchOpHandle : public OpHandleBase { public: FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset, - std::vector *local_scopes); + std::vector *local_scopes, + std::vector *local_exec_scopes); ~FetchOpHandle(); @@ -44,12 +45,15 @@ struct FetchOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return *local_scopes_; } + void WaitInputVarGenerated(const platform::Place &place) override; private: FeedFetchList *data_; size_t offset_; std::vector *local_scopes_; + std::vector *local_exec_scopes_; std::vector tensors_; }; diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc index 4d96d820a1d161e76945a1c87e1832d95a8a802e..8066ca7813ce84cf2e2c9700aa8b86689457cfc5 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -185,9 +185,7 @@ void FusedAllReduceOpHandle::RunImpl() { } else { // Special handle CPU only Operator's gradient. Like CRF auto grad_name = grads_tensor.at(0).at(0).first; - auto &trg = *this->local_scopes_[0] - ->FindVar(kLocalExecScopeName) - ->Get() + auto &trg = *this->local_exec_scopes_[0] ->FindVar(grad_name) ->GetMutable(); @@ -195,9 +193,8 @@ void FusedAllReduceOpHandle::RunImpl() { ReduceBufferData func(lod_tensor_data, trg.data(), numel); VisitDataType(trg.type(), func); - for (size_t i = 1; i < local_scopes_.size(); ++i) { - auto &scope = - *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); + for (size_t i = 1; i < local_exec_scopes_.size(); ++i) { + auto &scope = *local_exec_scopes_[i]; auto &p = places_[i]; auto *var = scope.FindVar(grad_name); auto *dev_ctx = dev_ctxes_.at(p); @@ -215,8 +212,7 @@ void FusedAllReduceOpHandle::GetGradLoDTensor( const size_t &scope_idx, const std::vector &in_var_handles, const std::vector &out_var_handles, std::vector> *grad_tensor) const { - auto *local_scope = - local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get(); + auto *local_scope = local_exec_scopes_[scope_idx]; size_t place_num = places_.size(); for (size_t j = 0; j < in_var_handles.size(); j += place_num) { diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h index e0b9123c5b7e40f7d96ef3ea4061c2822aca7eef..fccbd77208b887ae05f6d22038f3ef0f012329f1 100644 --- a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h @@ -52,6 +52,8 @@ struct FusedAllReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return local_scopes_; } + private: std::vector local_scopes_; #if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32)) diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc index f48561ea32e6a3bbc7e9f2a8326b080ad21c6d61..59c5da0de8c114823a1cad3e6d65c92081b5a2b6 100644 --- a/paddle/fluid/framework/details/fused_broadcast_op_handle.cc +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.cc @@ -31,11 +31,6 @@ void FusedBroadcastOpHandle::RunImpl() { WaitInputVarGenerated(); - std::vector var_scopes; - for (auto *s : local_scopes_) { - var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); - } - size_t place_num = places_.size(); PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size()); @@ -44,7 +39,7 @@ void FusedBroadcastOpHandle::RunImpl() { *in_var_handles[i], std::vector(out_var_handles.begin() + i * place_num, out_var_handles.begin() + (i + 1) * place_num), - var_scopes); + local_exec_scopes_); } } diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc b/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc index 6d53dac5c0a20b4340e71274a00a7f3c0cd08ff6..49404509a6fba0a6568c5db39a7bc744418f07a4 100644 --- a/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle_test.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" +#include +#include #include "gtest/gtest.h" #include "paddle/fluid/framework/details/broadcast_op_handle_test.h" @@ -27,17 +29,16 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { void InitFusedBroadcastOp(std::vector input_scope_idxes) { nodes_.clear(); // initialize scope and var + std::unordered_map scope_map; for (size_t i = 0; i < place_list_.size(); ++i) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); - *local_scopes_.back() - ->Var(details::kLocalExecScopeName) - ->GetMutable() = &local_scope; for (size_t j = 0; j < input_scope_idxes.size(); ++j) { local_scope.Var("out_var" + std::to_string(j)); if (i == j) local_scope.Var("in_var" + std::to_string(j)); } param_scopes_.emplace_back(&local_scope); + scope_map.emplace(local_scopes_.back(), param_scopes_.back()); } // create op handle node @@ -60,6 +61,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle { #endif } + op_handle_->SetLocalExecScopes(scope_map); + for (size_t i = 0; i < input_scope_idxes.size(); ++i) { // add input var handle nodes_.emplace_back(ir::CreateNodeForTest("in_node" + std::to_string(i), diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc index 179cca44cb1871bb9667074f6c6b32edee42be09..a039c6200e394eebf6c44846ce2b0bf5d773e764 100644 --- a/paddle/fluid/framework/details/gather_op_handle.cc +++ b/paddle/fluid/framework/details/gather_op_handle.cc @@ -42,10 +42,7 @@ void GatherOpHandle::RunImpl() { out_var_handle = out_var_handles.front(); } - std::vector var_scopes; - for (auto *s : local_scopes_) { - var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); - } + auto &var_scopes = local_exec_scopes_; auto in_0_handle = in_var_handles[0]; auto pre_in_var = diff --git a/paddle/fluid/framework/details/gather_op_handle.h b/paddle/fluid/framework/details/gather_op_handle.h index d9afbc6547e18e8886c414ff150e332cfaf9b0c3..ac87b246b50f8e0df1d0cc082087d4128a79384b 100644 --- a/paddle/fluid/framework/details/gather_op_handle.h +++ b/paddle/fluid/framework/details/gather_op_handle.h @@ -40,6 +40,8 @@ struct GatherOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return local_scopes_; } + private: const std::vector &local_scopes_; const std::vector &places_; diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc index e8cb7feb8bea92a7486b8a9d84ba4b9e2b93dbfb..5d8562e7046fd2f1609ba34ce2dd71b9fa28be77 100644 --- a/paddle/fluid/framework/details/gather_op_handle_test.cc +++ b/paddle/fluid/framework/details/gather_op_handle_test.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/fluid/framework/details/gather_op_handle.h" +#include +#include #include "gtest/gtest.h" #include "paddle/fluid/platform/device_context.h" @@ -72,14 +74,13 @@ struct TestGatherOpHandle { void InitGatherOp(size_t input_scope_idx) { nodes_.clear(); + std::unordered_map scope_map; for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope& local_scope = local_scopes_.back()->NewScope(); - *local_scopes_.back() - ->Var(details::kLocalExecScopeName) - ->GetMutable() = &local_scope; local_scope.Var("input"); param_scopes_.emplace_back(&local_scope); + scope_map.emplace(local_scopes_.back(), param_scopes_.back()); } param_scopes_[input_scope_idx]->Var("out"); @@ -87,6 +88,9 @@ struct TestGatherOpHandle { ir::CreateNodeForTest("node", ir::Node::Type::kOperation).release()); op_handle_ = new GatherOpHandle(nodes_.back().get(), local_scopes_, gpu_list_); + + op_handle_->SetLocalExecScopes(scope_map); + // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get()); diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index b0e6a87bddeecda4f13e1081efeabb1c70be76cf..b2fa31f73b9d96ef7fe56dd59ca9b4b18f114c95 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -35,49 +35,55 @@ std::string OpHandleBase::DebugString() const { OpHandleBase::~OpHandleBase() { #ifdef PADDLE_WITH_CUDA for (auto &ev : events_) { - PADDLE_ENFORCE(cudaEventDestroy(ev.second)); + if (ev.second) { + PADDLE_ENFORCE(cudaEventDestroy(ev.second)); + } } #endif } -void OpHandleBase::Run(bool use_cuda) { +void OpHandleBase::InitCUDA() { #ifdef PADDLE_WITH_CUDA - if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) { - for (auto &p : dev_ctxes_) { - int dev_id = boost::get(p.first).device; - PADDLE_ENFORCE(cudaSetDevice(dev_id)); - PADDLE_ENFORCE( - cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming)); - } - if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) { - for (auto &out_var : outputs_) { - auto *out_var_handle = dynamic_cast(out_var); - if (out_var_handle) { - int dev_id = - boost::get(out_var_handle->place()).device; - out_var_handle->SetGenerateEvent(events_.at(dev_id)); - } + for (auto &p : dev_ctxes_) { + int dev_id = boost::get(p.first).device; + PADDLE_ENFORCE(cudaSetDevice(dev_id)); + PADDLE_ENFORCE( + cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming)); + } + if (IsMultiDeviceTransfer() && dev_ctxes_.size() > 0) { + for (auto &out_var : outputs_) { + auto *out_var_handle = dynamic_cast(out_var); + if (out_var_handle) { + int dev_id = + boost::get(out_var_handle->place()).device; + out_var_handle->SetGenerateEvent(events_.at(dev_id)); } - } else { - PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL, - "%s should have only one dev_ctx.", Name()); - auto &place = dev_ctxes_.begin()->first; - int dev_id = boost::get(place).device; - for (auto &out_var : outputs_) { - auto *out_var_handle = dynamic_cast(out_var); - if (out_var_handle) { - PADDLE_ENFORCE( - platform::is_same_place(place, out_var_handle->place()), - "The place of output(%s) is not consistent with the " - "place of current op(%s).", - out_var_handle->Name(), Name()); - out_var_handle->SetGenerateEvent(events_.at(dev_id)); - } + } + } else { + PADDLE_ENFORCE_EQ(dev_ctxes_.size(), 1UL, + "%s should have only one dev_ctx.", Name()); + auto &place = dev_ctxes_.begin()->first; + int dev_id = boost::get(place).device; + for (auto &out_var : outputs_) { + auto *out_var_handle = dynamic_cast(out_var); + if (out_var_handle) { + PADDLE_ENFORCE(platform::is_same_place(place, out_var_handle->place()), + "The place of output(%s) is not consistent with the " + "place of current op(%s).", + out_var_handle->Name(), Name()); + out_var_handle->SetGenerateEvent(events_.at(dev_id)); } } } -#else +#endif +} +void OpHandleBase::Run(bool use_cuda) { +#ifdef PADDLE_WITH_CUDA + if (events_.empty() && use_cuda && dev_ctxes_.size() > 0) { + InitCUDA(); + } +#else PADDLE_ENFORCE(!use_cuda); #endif @@ -232,6 +238,17 @@ size_t OpHandleBase::NotReadyInputSize() const { return res.size(); } +void OpHandleBase::SetLocalExecScopes( + const std::unordered_map &scope_map) { + local_exec_scopes_.clear(); + auto scopes = GetLocalScopes(); + for (auto *scope : scopes) { + auto iter = scope_map.find(scope); + PADDLE_ENFORCE(iter != scope_map.end(), "Local scope not found"); + local_exec_scopes_.emplace_back(iter->second); + } +} + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 3412fa0bb76fafbef7d1abbee72bf46c361152f9..4c7086918c7dad27c2251076f3cbccc50e34a6b7 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -25,9 +25,10 @@ namespace paddle { namespace framework { -namespace details { -constexpr char kLocalExecScopeName[] = "@LOCAL_EXE_SCOPE@"; +class Scope; + +namespace details { // Wraps ir::Node and provide helper utilities. // It's responsible for populating necessary fields of ir::Node. @@ -107,7 +108,12 @@ class OpHandleBase { ir::Node *Node() { return node_; } + void SetLocalExecScopes( + const std::unordered_map &scope_map); + protected: + virtual std::vector GetLocalScopes() = 0; + void RunAndRecordEvent(const std::function &callback); void RunAndRecordEvent(platform::Place p, @@ -115,11 +121,15 @@ class OpHandleBase { virtual void RunImpl() = 0; + virtual void InitCUDA(); + ir::Node *node_; std::vector inputs_; std::vector outputs_; std::map dev_ctxes_; + std::vector local_exec_scopes_; + #ifdef PADDLE_WITH_CUDA std::unordered_map events_; #endif diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc index 68be353e3464c94e5eb991acc4c3dd6e3de5267a..1a3c753e7d2b075eba9af98f7b206e42b51b650c 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.cc @@ -83,6 +83,7 @@ ParallelSSAGraphExecutor::SeparateMultiDevicesGraph(ir::Graph *graph) { ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph) : strategy_(std::move(strategy)), local_scopes_(std::move(local_scopes)), @@ -108,10 +109,20 @@ ParallelSSAGraphExecutor::ParallelSSAGraphExecutor( << " to run the operators of the graph on each device."; for (size_t i = 0; i < places.size(); ++i) { executors_.emplace_back(new details::FastThreadedSSAGraphExecutor( - strategy_, local_scopes_, {places_[i]}, graphs_.at(i).get())); + strategy_, local_scopes_, local_exec_scopes, {places_[i]}, + graphs_.at(i).get())); } } +std::vector ParallelSSAGraphExecutor::Graphs() { + std::vector result; + result.reserve(graphs_.size()); + for (auto &g : graphs_) { + result.emplace_back(g.get()); + } + return result; +} + FeedFetchList ParallelSSAGraphExecutor::Run( const std::vector &fetch_tensors) { std::vector> run_futures; diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index faf071b05306a49c0049421bc72e4981c0bfc84c..6889c54dd4c6906b179036386f8d38dad04f5c9f 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -30,12 +30,15 @@ class ParallelSSAGraphExecutor : public SSAGraphExecutor { public: ParallelSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph); ~ParallelSSAGraphExecutor() final = default; const ir::Graph &Graph() const override { return *graphs_[0]; } + std::vector Graphs(); + FeedFetchList Run(const std::vector &fetch_tensors) override; private: diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 4e2477c205db5966aa0b2d0c7a608be94a69eb82..26153b7dd9cadd3021623de5c2acbf1780b3476c 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include #include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/variable_visitor.h" @@ -160,10 +161,7 @@ void ReduceOpHandle::RunImpl() { auto in_0_handle = in_var_handles[0]; - std::vector var_scopes; - for (auto *s : local_scopes_) { - var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get()); - } + auto &var_scopes = local_exec_scopes_; auto pre_in_var = var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name()); @@ -250,9 +248,7 @@ void ReduceOpHandle::RunImpl() { } else { // We sum lod_tensors to reduce_sum_trg which is in local_scopes_0 // here, but it doesn't mean reduce_sum_trg must be in local_scopes_0. - auto &reduce_sum_trg = *this->local_scopes_[0] - ->FindVar(kLocalExecScopeName) - ->Get() + auto &reduce_sum_trg = *this->local_exec_scopes_[0] ->FindVar(out_var_handle->name()) ->GetMutable(); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); @@ -317,7 +313,7 @@ void ReduceOpHandle::RunImpl() { template std::vector ReduceOpHandle::GetInputValues( const std::vector &in_var_handles, - const std::vector &var_scopes) const { + const std::vector &var_scopes) const { std::vector in_selected_rows; for (auto *in_handle : in_var_handles) { auto &in_sr = var_scopes.at(in_handle->scope_idx()) diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 5491f00f45e9d48c5eb7455396ac51801f2c40ab..15064a108e79fe5ed307e46b03f90b1d74742203 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include @@ -90,6 +91,8 @@ struct ReduceOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return local_scopes_; } + #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE template void GatherSelectedRows( @@ -106,7 +109,7 @@ struct ReduceOpHandle : public OpHandleBase { template std::vector GetInputValues( const std::vector &in_var_handles, - const std::vector &var_scopes) const; + const std::vector &var_scopes) const; }; } // namespace details diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc index 6cee4770e64354cf8a719b0e11b1816b345dd8bd..664bd00fb68fc37c6d4e7624ed42a2a905f1bd25 100644 --- a/paddle/fluid/framework/details/reduce_op_handle_test.cc +++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/details/reduce_op_handle.h" +#include #include "gtest/gtest.h" #include "paddle/fluid/platform/device_context.h" @@ -86,14 +87,13 @@ struct TestReduceOpHandle { void InitReduceOp(size_t out_scope_idx) { std::vector> nodes; // init scope + std::unordered_map scope_map; for (size_t j = 0; j < gpu_list_.size(); ++j) { local_scopes_.push_back(&(g_scope_.NewScope())); Scope &local_scope = local_scopes_.back()->NewScope(); - *local_scopes_.back() - ->Var(details::kLocalExecScopeName) - ->GetMutable() = &local_scope; local_scope.Var("input"); param_scopes_.emplace_back(&local_scope); + scope_map.emplace(local_scopes_.back(), param_scopes_.back()); } param_scopes_[out_scope_idx]->Var("out"); @@ -115,6 +115,8 @@ struct TestReduceOpHandle { #endif } + op_handle_->SetLocalExecScopes(scope_map); + // init op handle // add input for (size_t j = 0; j < gpu_list_.size(); ++j) { diff --git a/paddle/fluid/framework/details/rpc_op_handle.cc b/paddle/fluid/framework/details/rpc_op_handle.cc index a87b03451bb00643ecb9d9e2339141fe7f25d2e3..8d61a103f98be81309d890f25b8ab6f41d5c3f02 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.cc +++ b/paddle/fluid/framework/details/rpc_op_handle.cc @@ -21,7 +21,7 @@ namespace framework { namespace details { RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, - const Scope *local_scope, const std::string &name, + Scope *local_scope, const std::string &name, const platform::Place &place) : OpHandleBase(node), op_(framework::OpRegistry::CreateOp(op_desc)), @@ -41,10 +41,7 @@ void RPCOpHandle::RunImpl() { in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p)); } } - this->RunAndRecordEvent([this] { - op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get(), - place_); - }); + this->RunAndRecordEvent([this] { op_->Run(*local_exec_scopes_[0], place_); }); } std::string RPCOpHandle::Name() const { return name_; } diff --git a/paddle/fluid/framework/details/rpc_op_handle.h b/paddle/fluid/framework/details/rpc_op_handle.h index 7f99cdeacf618a9496eaef98520685d6d1621ae1..d86d33dd676ca066b8772ac2afbab05cf0d98b38 100644 --- a/paddle/fluid/framework/details/rpc_op_handle.h +++ b/paddle/fluid/framework/details/rpc_op_handle.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -29,7 +30,7 @@ namespace details { struct RPCOpHandle : public OpHandleBase { RPCOpHandle(ir::Node* node, const framework::OpDesc& op_desc, - const Scope* local_scope, const std::string& name, + Scope* local_scope, const std::string& name, const platform::Place& place); std::string Name() const override; @@ -41,9 +42,11 @@ struct RPCOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return {local_scope_}; } + private: std::unique_ptr op_; - const Scope* local_scope_; + Scope* local_scope_; const std::string name_; platform::Place place_; }; diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc index 67b4fed0d3083b105eae4838cf264bba7f7a44c3..7ab216095cade0bef3f188708dcee5d49f26c36f 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc @@ -70,9 +70,9 @@ void ScaleLossGradOpHandle::RunImpl() { platform::RecordEvent record_event(Name()); // Doesn't wait any event std::string var_name = static_cast(this->outputs_[0])->name(); - auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get(); - auto *tensor = local_scope.FindVar(var_name)->GetMutable(); + auto *tensor = + local_exec_scopes_[0]->FindVar(var_name)->GetMutable(); tensor->Resize(make_ddim({1})); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 8bedd1643eb9c5e591fa3c40995fcba08980b9fa..d4f28dbe2b261be9c5d48aa50e38edfe36bfcfd3 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -36,6 +37,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { protected: void RunImpl() override; + std::vector GetLocalScopes() override { return {scope_}; } + private: float coeff_; Scope *scope_; diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 5bbbf07e6d9fb8845d3f93d1d8124d3f557dba3c..070d59517b2a2b7bde7d10c6bf0ed03513a66fe8 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -25,19 +25,24 @@ namespace framework { namespace details { ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( ExecutionStrategy strategy, std::vector local_scopes, - std::vector var_infos, std::vector places, + std::vector local_exec_scopes, std::vector var_infos, + std::vector places, std::unique_ptr &&underlying_executor) : strategy_(std::move(strategy)), underlying_executor_(std::move(underlying_executor)), local_scopes_(std::move(local_scopes)), + local_exec_scopes_(std::move(local_exec_scopes)), var_infos_(std::move(var_infos)), - places_(std::move(places)) {} + places_(std::move(places)) { + PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size()); + PrepareLocalExeScopes(); +} FeedFetchList ScopeBufferedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { if (drop_scope_counter_ == 0) { - platform::RecordEvent e("InitLocalExeScopes"); - PrepareLocalExeScopes(); + platform::RecordEvent e("InitLocalVars"); + InitVariables(); } std::vector fetch_data; @@ -59,39 +64,55 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run( } } +void ScopeBufferedSSAGraphExecutor::InitVariables() { + for (auto &info : tmp_var_infos_) { + for (auto &pair : info) { + InitializeVariable(pair.first, pair.second); + } + } +} + void ScopeBufferedSSAGraphExecutor::DropLocalExeScopes() { platform::RecordEvent drop_scope_event("DropLocalExeScopes"); drop_scope_counter_ = 0; - for (auto p : places_) { + for (auto &p : places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); } - for (auto &scope : local_scopes_) { - auto *local_scope_var = scope->FindLocalVar(details::kLocalExecScopeName); - if (local_scope_var != nullptr) { - auto &local_scope = *local_scope_var->GetMutable(); - scope->DeleteScope(local_scope); - scope->EraseVars({std::string(details::kLocalExecScopeName)}); - VLOG(3) << "Drop local execution scope: " << local_scope; + for (size_t i = 0; i < local_exec_scopes_.size(); ++i) { + local_exec_scopes_[i]->EraseVarsExcept(preserve_vars_[i]); + local_exec_scopes_[i]->DropKids(); + for (auto &preserve_var : preserve_vars_[i]) { + preserve_var->Clear(); } + VLOG(3) << "Drop local execution scope: " << local_scopes_[i]; } } void ScopeBufferedSSAGraphExecutor::PrepareLocalExeScopes() { // Create local scopes. + preserve_vars_.resize(local_scopes_.size()); + tmp_var_infos_.resize(local_scopes_.size()); + for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) { - auto &scope = *it; - Scope &local_scope = scope->NewScope(); - *scope->Var(kLocalExecScopeName)->GetMutable() = &local_scope; + size_t idx = local_scopes_.size() - 1 - (it - local_scopes_.rbegin()); + auto *scope = local_scopes_[idx]; + auto *local_scope = local_exec_scopes_[idx]; for (auto &info : var_infos_) { - if (scope->FindVar(info.name_) != nullptr) { - continue; - } if (info.persistable_) { // Persistable + auto var = scope->FindVar(info.name_); + if (var != nullptr) { + VLOG(2) + << info.name_ + << " has been initialized beforehand in global scope, skipped"; + continue; + } InitializeVariable(scope->Var(info.name_), info.type_); } else { - InitializeVariable(local_scope.Var(info.name_), info.type_); + Variable *tmp_var = local_scope->Var(info.name_); + preserve_vars_[idx].emplace(tmp_var); + tmp_var_infos_[idx].emplace_back(tmp_var, info.type_); } } } diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index e0388be305f2285b941bc7193a8d97e52ce765c9..988882e65dba818989455a4af608aee85c150eae 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/var_handle.h" @@ -39,6 +41,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { public: ScopeBufferedSSAGraphExecutor( ExecutionStrategy strategy, std::vector local_scopes, + std::vector local_exec_scopes, std::vector var_infos, std::vector places, std::unique_ptr&& underlying_executor); @@ -55,10 +58,18 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { void PrepareLocalExeScopes(); private: + void InitVariables(); + size_t drop_scope_counter_{0}; ExecutionStrategy strategy_; std::unique_ptr underlying_executor_; std::vector local_scopes_; + + std::vector local_exec_scopes_; + std::vector> preserve_vars_; + std::vector>> + tmp_var_infos_; + std::vector var_infos_; std::vector places_; }; diff --git a/paddle/fluid/framework/details/share_tensor_buffer_op_handle.cc b/paddle/fluid/framework/details/share_tensor_buffer_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..8539eb9dae90b6558a12e51842c9c98c04ccb925 --- /dev/null +++ b/paddle/fluid/framework/details/share_tensor_buffer_op_handle.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" +#include +#include +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace details { + +// TODO(zjl): support SelectedRows +static inline const Tensor &GetTensorFromVar(const Variable *var) { + if (var->IsType()) { + return var->Get(); + } else { + PADDLE_THROW("Variable must be type of LoDTensor"); + } +} + +static inline Tensor *GetMutableTensorFromVar(Variable *var) { + if (var->IsType()) { + return var->GetMutable(); + } else { + PADDLE_THROW("Variable must be type of LoDTensor"); + } +} + +ShareTensorBufferOpHandle::ShareTensorBufferOpHandle( + ir::Node *node, Scope *scope, size_t scope_idx, const std::string &op_type, + const std::vector &in_var_infos, + const std::vector &out_var_names) + : OpHandleBase(node), + scope_(scope), + scope_idx_(scope_idx), + op_type_(op_type), + in_var_infos_(in_var_infos), + out_var_names_(out_var_names) { + PADDLE_ENFORCE_EQ(in_var_infos_.size(), out_var_names_.size()); + for (size_t i = 0; i < in_var_infos_.size(); ++i) { + Add(in_var_infos_[i], out_var_names_[i]); + } +} + +std::unordered_set ShareTensorBufferOpHandle::ReusedVarSet() + const { + std::unordered_set result; + for (auto &in_var_info : in_var_infos_) { + result.insert(in_var_info->Name()); + } + return result; +} + +void ShareTensorBufferOpHandle::Add(ir::MemOptVarInfo *in_var_info, + const std::string &out_var_name) { + PADDLE_ENFORCE_NOT_NULL(in_var_info, "in_var_info cannot be nullptr"); + PADDLE_ENFORCE_NE(in_var_info->Name(), out_var_name, + "in/out cannot have same name: %s", out_var_name); + in_var_infos_.emplace_back(in_var_info); + out_var_names_.emplace_back(out_var_name); +} + +void ShareTensorBufferOpHandle::InitCUDA() { +#ifdef PADDLE_WITH_CUDA + int dev_id = + boost::get(dev_ctxes_.begin()->first).device; + events_[dev_id] = nullptr; +#endif +} + +void ShareTensorBufferOpHandle::CallOnce() { + PADDLE_ENFORCE(in_out_vars_.empty(), "in_out_vars_ must be initialized here"); + Scope *exec_scope = local_exec_scopes_[0]; + for (size_t i = 0; i < in_var_infos_.size(); ++i) { + auto *in_var = exec_scope->FindVar(in_var_infos_[i]->Name()); + auto *out_var = exec_scope->FindVar(out_var_names_[i]); + PADDLE_ENFORCE_NOT_NULL(in_var); + PADDLE_ENFORCE_NOT_NULL(out_var); + PADDLE_ENFORCE_NE(in_var, out_var); + in_out_vars_.emplace_back(in_var, out_var); + } +} + +void ShareTensorBufferOpHandle::RunImpl() { + if (in_var_infos_.size() != in_out_vars_.size()) { + CallOnce(); + } + + for (size_t i = 0; i < in_var_infos_.size(); ++i) { + const auto &in_tensor = GetTensorFromVar(in_out_vars_[i].first); + auto *out_tensor = GetMutableTensorFromVar(in_out_vars_[i].second); + auto *in_var_info = in_var_infos_[i]; + + if (UNLIKELY(in_var_info->IsSkipped())) { + // If in_var is inplaced in the previous batch and we want to fetch + // in_var in the current batch, we have to reset memory of out_var + // to avoid wrong calculation result. + if (in_tensor.Holder() == out_tensor->Holder()) { + VLOG(1) << "Clear " << out_var_names_[i] + << " because you may want to fetch an inplaced variable " + << in_var_info->Name() + << " in previous batch: " << in_var_info->Name() << " -> " + << out_var_names_[i]; + out_tensor->clear(); + } + } else { + out_tensor->ShareBufferWith(in_tensor); + + VLOG(2) << "Share tensor buffer when running " << op_type_ << " : " + << in_var_info->Name() << " -> " << out_var_names_[i]; + } + } +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/share_tensor_buffer_op_handle.h b/paddle/fluid/framework/details/share_tensor_buffer_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..87e971baf044d23b8c9a73a6b03a489ef6641009 --- /dev/null +++ b/paddle/fluid/framework/details/share_tensor_buffer_op_handle.h @@ -0,0 +1,74 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" + +namespace paddle { +namespace framework { + +class Variable; +class Scope; +class Tensor; + +namespace ir { +class MemOptVarInfo; +} // namespace ir + +namespace details { + +class ShareTensorBufferOpHandle : public OpHandleBase { + public: + ShareTensorBufferOpHandle( + ir::Node *node, Scope *scope, size_t scope_idx, + const std::string &op_type, + const std::vector &in_vars_infos, + const std::vector &out_var_names); + + std::unordered_set ReusedVarSet() const; + + Priority GetPriority() const override { return Priority::kHighest; } + + size_t GetScopeIdx() const { return scope_idx_; } + + void Add(ir::MemOptVarInfo *in_var_info, const std::string &ou_var_name); + + protected: + std::string Name() const override { return "buffer_share"; } + + void RunImpl() final; + + void InitCUDA() override; + + std::vector GetLocalScopes() override { return {scope_}; } + + private: + void CallOnce(); + + Scope *scope_; + size_t scope_idx_; + std::string op_type_; + std::vector in_var_infos_; + std::vector out_var_names_; + + std::vector> in_out_vars_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc index cc3493d849eccbecf3d039dc7b2fc18575fcf9d0..a2461a36bf135295bfb69a7d5df49060cb8f05e5 100644 --- a/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/sparse_all_reduce_op_handle.cc @@ -58,8 +58,7 @@ void SparseAllReduceOpHandle::RunImplEncoded() { std::vector outs; int k = -1; for (size_t i = 0; i < local_scopes_.size(); ++i) { - auto &local_scope = - local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); + auto *local_scope = local_exec_scopes_[i]; auto original_name = paddle::framework::GradOriginalVarName(in_var_handles[i]->name()); auto encode_var_name = original_name + g_dgc_encoded; @@ -135,9 +134,8 @@ int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) { auto var_name = original_name + g_dgc_k; PADDLE_ENFORCE(local_scopes_.size() > 0); - auto *scope = local_scopes_[0]; - auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get(); - auto var = local_scope->FindVar(var_name); + auto *scope = local_exec_scopes_[0]; + auto var = scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(var); auto tensor = var->Get().data(); return *tensor; @@ -151,8 +149,7 @@ bool SparseAllReduceOpHandle::IsEncoded() { auto step_name = g_dgc_rampup_begin_step; PADDLE_ENFORCE(local_scopes_.size() > 0); - auto *scope = local_scopes_[0]; - auto &local_scope = scope->FindVar(kLocalExecScopeName)->Get(); + auto *local_scope = local_exec_scopes_[0]; auto count_var = local_scope->FindVar(counter_name); auto step_var = local_scope->FindVar(step_name); if (count_var == nullptr || step_var == nullptr) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ac62f1dd83397a15830eae02c0ba00920a90dcfd..ed9d7d991f830428f79a56a440cb9c9a5ad86509 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -22,9 +22,11 @@ namespace framework { namespace details { ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph) : graph_(graph), local_scopes_(local_scopes), + local_exec_scopes_(local_exec_scopes), places_(places), fetch_ctxs_(places), strategy_(strategy), @@ -176,7 +178,8 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( ir::Node *fetch_node = graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation); - auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_); + auto *op = new FetchOpHandle(fetch_node, fetch_data, i, &local_scopes_, + &local_exec_scopes_); fetch_ops->emplace_back(op); for (auto &p : places_) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 6c1fb1c6c0a7b55cee89986c00bf650542520355..fe6ef95a135417c0c73cfb3c9a20af66dc5047e6 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { public: ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, + const std::vector &local_exec_scopes, const std::vector &places, ir::Graph *graph); @@ -71,6 +72,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { // be destroyed first. ir::Graph *graph_; std::vector local_scopes_; + std::vector local_exec_scopes_; + std::vector places_; platform::DeviceContextPool fetch_ctxs_; ExceptionHolder exception_holder_; diff --git a/paddle/fluid/framework/inplace_op_inference.h b/paddle/fluid/framework/inplace_op_inference.h index b5eb61f23e56fafca33e85ee4a288af53b9ceb2e..fdc0c2023cc2e3b3838ef0c66914f8c927cc18c9 100644 --- a/paddle/fluid/framework/inplace_op_inference.h +++ b/paddle/fluid/framework/inplace_op_inference.h @@ -48,35 +48,15 @@ class SingleOpInplaceInToOut : public InplaceOpInference { public: std::unordered_map operator()( const OpDesc& op_desc, bool use_cuda) const override { - PADDLE_ENFORCE(!op_desc.InputNames().empty(), - "Op inputs must not be empty"); - PADDLE_ENFORCE(!op_desc.OutputNames().empty(), - "Op outputs must not be empty"); + PADDLE_ENFORCE_EQ(op_desc.InputNames().size(), 1, + "Op inputs must be unique"); + PADDLE_ENFORCE_EQ(op_desc.OutputNames().size(), 1, + "Op outputs must be unique"); auto x_name = op_desc.InputNames().at(0); auto out_name = op_desc.OutputNames().at(0); return std::unordered_map{{x_name, out_name}}; } }; -/* - Gradient op. Inplace output use it's Input. - For example, Input@Grad->Input reuse strategy. - */ -class GradOpInplaceInToOut : public InplaceOpInference { - public: - std::unordered_map operator()( - const OpDesc& op_desc, bool use_cuda) const override { - std::unordered_map ret; - std::unordered_set output_names(op_desc.OutputNames().begin(), - op_desc.OutputNames().end()); - for (auto& input_name : op_desc.InputNames()) { - if (output_names.count(GradVarName(input_name))) { - ret.insert({input_name, GradVarName(input_name)}); - } - } - return ret; - } -}; - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt index 125cd462fa4b043464c314c97e090bbc0cb6d422..615245e7252fef69d269d56cf1b58ad58f7a83ee 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt +++ b/paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt @@ -16,3 +16,7 @@ cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_o cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper) cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper) + +cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle multi_devices_helper graph pass) + +cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass) diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..b5d17ef2978147253d2c1a8f38ba6b6181deec8b --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/buffer_shared_inplace_op_pass.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class BufferSharedInplaceOpPass : public MemoryReusePass { + protected: + std::string ReuseType() const override { return "inplace"; } + + void Run(Graph *graph) const override; +}; + +void BufferSharedInplaceOpPass::Run(Graph *graph) const { + const auto &last_live_ops = + Get>(kLastLiveOpsOfVars); + + bool use_cuda = Get(kUseCuda); + + // Step 1: Build a reverse map of last_live_ops + // i.e.: op -> vars + std::unordered_map> + candidate_ops; + for (auto &each_scope_ops : last_live_ops) { + for (auto &pair : each_scope_ops) { + // If variable has more than 1 last lived ops, this variable cannot + // be inplaced. + if (pair.second.size() != 1) { + continue; + } + + auto *op = *(pair.second.begin()); + const std::string &op_type = op->GetOp()->Type(); + const framework::OpDesc *op_desc = op->Node()->Op(); + PADDLE_ENFORCE_NOT_NULL(op_desc); + + auto &infer_inplace = OpInfoMap::Instance().Get(op_type).infer_inplace_; + if (!infer_inplace) { + continue; + } + + const std::string &var_name = pair.first; + auto in_nodes = this->FindNodesByName(var_name, op->Node()->inputs); + if (in_nodes.size() == 1) { + candidate_ops[op][var_name] = *in_nodes.begin(); + } + } + } + + // Step 2: Check which vars can be inplaced indeed + for (auto &op_vars_pair : candidate_ops) { + auto *op = op_vars_pair.first; + auto &vars = op_vars_pair.second; + + const std::string &op_type = op->GetOp()->Type(); + auto *op_desc = op->Node()->Op(); + + auto in_to_outs = + OpInfoMap::Instance().Get(op_type).infer_inplace_(*op_desc, use_cuda); + for (auto &pair : in_to_outs) { + auto &in_param = pair.first; + auto &in_args = op_desc->Input(in_param); + if (in_args.empty()) { + VLOG(4) << "Cannot inplace because Input(" << in_param + << ") is empty in " << op_type; + continue; + } + + auto &in_arg = in_args[0]; + auto iter = vars.find(in_arg); + if (iter == vars.end()) { + VLOG(4) << "Cannot inplace maybe because Input(" << in_param + << ")=" << in_arg << " is not lastly used in op " << op_type + << ", or it occurs multiple times in input or occurs in output"; + continue; + } + + ir::Node *in_node = iter->second; + + auto &out_param = pair.second; + auto &out_args = op_desc->Output(out_param); + + if (out_args.empty()) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ") is empty in " << op_type; + continue; + } + + auto &out_arg = out_args[0]; + auto out_nodes = this->FindNodesByName(out_arg, op->Node()->outputs); + if (out_nodes.size() != 1) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ")=" << out_arg << " occurs " << out_nodes.size() + << " time(s) in output of op " << op_type; + continue; + } + + auto *out_node = *out_nodes.begin(); + + auto &in_var_handle = in_node->Wrapper(); + auto &out_var_handle = out_node->Wrapper(); + + auto *in_var_handle_ptr = + dynamic_cast(&in_var_handle); + auto *out_var_handle_ptr = + dynamic_cast(&out_var_handle); + + if (in_var_handle_ptr == nullptr || out_var_handle_ptr == nullptr) { + continue; + } + + bool success = this->TryReuseVar(in_var_handle_ptr, out_var_handle_ptr); + if (success) { + VLOG(4) << "Inplace performed in op " << op_type << ": " + << in_var_handle_ptr->Name() << " -> " + << out_var_handle_ptr->Name() + << ". Debug String is: " << op->GetOp()->DebugString(); + } else { + VLOG(4) << "Inplace failed in op " << op_type << ": " + << in_var_handle_ptr->Name() << " -> " + << out_var_handle_ptr->Name(); + } + } + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(buffer_shared_inplace_pass, + paddle::framework::ir::BufferSharedInplaceOpPass) + .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList) + .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars) + .RequirePassAttr(paddle::framework::ir::kUseCuda); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc index 1cdc97338ae8d0745e877071b7939c5c3d9c955c..bbef21908dcba4811fc785f7b2a9d0fe8bbbe023 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc @@ -24,6 +24,7 @@ #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" namespace paddle { namespace framework { @@ -189,13 +190,9 @@ class EagerDeletionPass : public ir::Pass { }; void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { - auto &ref_cnts = - Get>(kRuntimeReferenceCount); - PADDLE_ENFORCE(ref_cnts.empty(), - "kRuntimeReferenceCount should be initialized here!"); + auto &var_infos = Get(kMemOptVarInfoMapList); const auto &vars = graph->Get(details::kGraphVars); - ref_cnts.resize(vars.size()); const auto &last_live_ops = Get>(kLastLiveOpsOfVars); @@ -224,10 +221,15 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { auto *eager_deletion_node = graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); + + std::unordered_set var_info; + for (auto &var_name : var_names) { + var_info.insert(var_infos[op->GetScopeIdx()].at(var_name).get()); + } + auto *eager_deletion_op = new details::EagerDeletionOpHandle( - eager_deletion_node, op->GetScope(), op->GetPlace(), var_names, - gcs.at(places[op->GetScopeIdx()]).get(), - &(ref_cnts[op->GetScopeIdx()])); + eager_deletion_node, op->GetScope(), op->GetPlace(), + std::move(var_info), gcs.at(places[op->GetScopeIdx()]).get()); auto it = std::find_if( op->Outputs().begin(), op->Outputs().end(), @@ -250,6 +252,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { graph->Get(details::kGraphDepVars) .emplace(dummy_leaf); eager_deletion_op->AddOutput(dummy_leaf); + + eager_deletion_op->SetDeviceContext( + places[op->GetScopeIdx()], + platform::DeviceContextPool::Instance().Get(places[op->GetScopeIdx()])); } VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction; @@ -273,7 +279,7 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const { } // namespace paddle REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass) - .RequirePassAttr(paddle::framework::ir::kRuntimeReferenceCount) + .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList) .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars) .RequirePassAttr(paddle::framework::ir::kAllPlaces) .RequirePassAttr(paddle::framework::ir::kGarbageCollector); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc index f57e7bb2301b2b5115de51138f6c531fe94b2bd2..1935f5e31b2b93db05bb5ffe8c825e3de0244913 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/inplace_op_pass.cc @@ -106,6 +106,9 @@ class InplacePass : public ir::Pass { // Check whether var is the last version one in SSA graph bool IsLastVersionVar(ir::Node *var) const; + // Check whether var is the first version one in SSA graph + bool IsFirstVersionVar(ir::Node *var) const; + // Check whether all `ops` is the preceding ops of `op` bool CheckOpDeps(ir::Node *op, const std::vector &ops) const; @@ -155,6 +158,10 @@ bool InplacePass::IsSkipVar(const std::string &var_name) const { return skip_vars_.count(var_name) > 0; } +bool InplacePass::IsFirstVersionVar(ir::Node *var) const { + return AllVersionVars(var->Name())->front() == var; +} + bool InplacePass::IsLastVersionVar(ir::Node *var) const { return AllVersionVars(var->Name())->back() == var; } @@ -429,13 +436,19 @@ void InplacePass::ApplyImpl(ir::Graph *graph) const { } if (!FindNodesByName(out_arg, op_node->inputs).empty()) { - VLOG(4) << "Cannot inplace because Output(" << in_param + VLOG(4) << "Cannot inplace because Output(" << out_param << ")=" << out_arg << " occurs in input of op " << op_type; continue; } auto *out_node = *out_nodes.begin(); + if (!IsFirstVersionVar(out_node)) { + VLOG(4) << "Cannot inplace because Output(" << out_param + << ")=" << out_arg << " does not occur first in op " << op_type; + continue; + } + if (!NodeCanReused(out_node)) { VLOG(4) << "Cannot inplace because Output(" << out_param << ")=" << out_arg << " is not reusable in " << op_type; diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h new file mode 100644 index 0000000000000000000000000000000000000000..0ceac79139ae36ca88b63c9611f2ca3c5e986197 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h @@ -0,0 +1,105 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +class MemOptVarInfo { + public: + MemOptVarInfo(const std::string &name, size_t ref_cnt) : name_(name) { + SetRefCnt(ref_cnt); + } + + bool DecreaseRefCnt() { + return ref_cnt_ == 1 || (runtime_ref_cnt_.fetch_sub(1) == 1); + } + + void ResetRuntimeRefCnt() { runtime_ref_cnt_ = ref_cnt_; } + + void SetRefCnt(size_t ref_cnt) { + PADDLE_ENFORCE_GE(ref_cnt, 1, + "Reference count must be larger than or equal to 1"); + ref_cnt_ = ref_cnt; + runtime_ref_cnt_ = ref_cnt; + } + + bool IsSkipped() const { return skipped_; } + + void SetSkip(bool skipped) { skipped_ = skipped; } + + const std::string &Name() const { return name_; } + + private: + std::string name_; + size_t ref_cnt_; + std::atomic runtime_ref_cnt_; + bool skipped_{false}; +}; + +using MemOptVarInfoMapList = std::vector< + std::unordered_map>>; + +class SkipMemOptVarsGuard { + public: + SkipMemOptVarsGuard(MemOptVarInfoMapList *list, + const std::vector &vars, + bool need_reset_ref_cnt) + : list_(list), need_reset_ref_cnt_(need_reset_ref_cnt) { + if (!list_) return; + + skip_vars_.reserve(vars.size() * list->size()); + for (auto &var : vars) { + for (auto &map : *list_) { + auto iter = map.find(var); + if (iter != map.end() && !iter->second->IsSkipped()) { + iter->second->SetSkip(true); + skip_vars_.emplace_back(iter->second.get()); + } + } + } + } + + ~SkipMemOptVarsGuard() { + for (auto *var : skip_vars_) { + var->SetSkip(false); + } + + if (list_ && need_reset_ref_cnt_) { + for (auto &map : *list_) { + for (auto &pair : map) { + pair.second->ResetRuntimeRefCnt(); + } + } + } + } + + private: + MemOptVarInfoMapList *list_; + bool need_reset_ref_cnt_; + std::vector skip_vars_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..9a8e2530e7c74319bd4a47208d59d5ca737ff469 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.cc @@ -0,0 +1,286 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h" +#include +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +// Each ShareTensorBufferOpHandle should only have one pending +// ComputationOpHandle +static details::ComputationOpHandle *GetUniquePendingComputationOpHandle( + details::ShareTensorBufferOpHandle *share_tensor_op) { + details::ComputationOpHandle *result_op = nullptr; + for (Node *out_var : share_tensor_op->Node()->outputs) { + for (Node *pending_op : out_var->outputs) { + auto &op = pending_op->Wrapper(); + auto *compute_op = dynamic_cast(&op); + PADDLE_ENFORCE_NOT_NULL(compute_op); + + if (result_op == nullptr) { + result_op = compute_op; + } else { + PADDLE_ENFORCE_EQ(result_op, compute_op); + } + } + } + + PADDLE_ENFORCE_NOT_NULL(result_op); + return result_op; +} + +void MemoryReusePass::ApplyImpl(Graph *graph) const { + graph_ = graph; + all_vars_ = &(graph_->Get(details::kGraphVars)); + var_infos_ = &(Get(kMemOptVarInfoMapList)); + last_live_ops_of_vars_ = + &(Get>(kLastLiveOpsOfVars)); + + reused_var_names_.resize(all_vars_->size()); + var_descs_.resize(all_vars_->size()); + + // Collect the existing ShareTensorBufferOpHandles. + // This is because (1) we want to reuse the existing + // ShareTensorBufferOpHandles to avoid inserting too many ops; + // (2) what is more important, a variable cannot be reused + // by two different variables, which may cause wrong calculation + // results. We have to know which variables have been reused. + CollectShareTensorBufferOpHandles(); + CollectReusedVars(); + Run(graph); + + std::map op_num; + for (auto &pair : ops_) { + ++op_num[pair.first->GetScopeIdx()]; + } + + for (auto &pair : op_num) { + VLOG(2) << "Create " << pair.second + << " ShareTensorBufferOpHandles in Scope " << pair.first; + } +} + +bool MemoryReusePass::TryReuseVar(details::VarHandle *in_var, + details::VarHandle *out_var) const { + auto *op = + dynamic_cast(out_var->GeneratedOp()); + PADDLE_ENFORCE_NOT_NULL(op); + if (IsVarsReusable(in_var, out_var)) { + AddReuseVar(op, in_var, out_var); + return true; + } else { + return false; + } +} + +std::unordered_set MemoryReusePass::FindNodesByName( + const std::string &name, const std::vector &nodes) const { + std::unordered_set ret; + for (auto *node : nodes) { + if (node->Name() == name) { + ret.insert(node); + } + } + return ret; +} + +VarDesc *MemoryReusePass::GetVarDesc(details::VarHandle *var) const { + auto iter = var_descs_[var->scope_idx()].find(var->Name()); + if (iter == var_descs_[var->scope_idx()].end()) { + PADDLE_ENFORCE((*all_vars_)[var->scope_idx()].count(var->Name()), + "Variable %s not found", var->Name()); + auto *desc = + TryGetLatestVarDesc((*all_vars_)[var->scope_idx()].at(var->Name())); + PADDLE_ENFORCE_NOT_NULL(desc); + var_descs_[var->scope_idx()].emplace(var->Name(), desc); + return desc; + } else { + return iter->second; + } +} + +void MemoryReusePass::CollectShareTensorBufferOpHandles() const { + auto all_ops = FilterByNodeWrapper(*graph_); + for (auto *op : all_ops) { + auto *share_buffer_op = + dynamic_cast(op); + if (share_buffer_op != nullptr) { + auto *compute_op = GetUniquePendingComputationOpHandle(share_buffer_op); + PADDLE_ENFORCE(ops_.count(compute_op) == 0); + ops_.emplace(compute_op, share_buffer_op); + } + } +} + +void MemoryReusePass::CollectReusedVars() const { + for (auto &pair : ops_) { + auto reused_vars = pair.second->ReusedVarSet(); + reused_var_names_[pair.first->GetScopeIdx()].insert(reused_vars.begin(), + reused_vars.end()); + } +} + +bool MemoryReusePass::IsVarAlreadyReused(details::VarHandle *var) const { + return reused_var_names_[var->scope_idx()].count(var->Name()) > 0; +} + +details::ShareTensorBufferOpHandle * +MemoryReusePass::InsertShareTensorBufferOpHandleToGraph( + details::ComputationOpHandle *op) const { + auto *buffer_share_node = + graph_->CreateEmptyNode("buffer_share", ir::Node::Type::kOperation); + + auto *buffer_share_op = new details::ShareTensorBufferOpHandle( + buffer_share_node, op->GetScope(), op->GetScopeIdx(), op->GetOp()->Type(), + {}, {}); + + buffer_share_op->SetDeviceContext( + op->GetPlace(), + platform::DeviceContextPool::Instance().Get(op->GetPlace())); + + // Inputs of `buffer_share_op` should be all inputs of `op` + for (auto *in_var : op->Inputs()) { + buffer_share_op->AddInput(in_var); + } + + // Add a dep_var to resolve write-after-write data hazard between + // `buffer_share_op` and `op`. + auto *dep_var = new details::DummyVarHandle(graph_->CreateControlDepVar()); + graph_->Get(details::kGraphDepVars).emplace(dep_var); + op->AddInput(dep_var); + buffer_share_op->AddOutput(dep_var); + + ops_.emplace(op, buffer_share_op); + return buffer_share_op; +} + +bool MemoryReusePass::IsVarsReusable(details::VarHandle *in_var, + details::VarHandle *out_var) const { + const auto in_name = in_var->Name(); + const auto out_name = out_var->Name(); + + if (in_name == out_name) { + return false; + } + + if (in_name == kEmptyVarName || out_name == kEmptyVarName) { + return false; + } + + if (IsVarAlreadyReused(in_var)) { + return false; + } + + // out_var must be the first version!!! + auto out_var_iter = (*all_vars_)[out_var->scope_idx()].find(out_name); + PADDLE_ENFORCE(out_var_iter != (*all_vars_)[out_var->scope_idx()].end() && + !out_var_iter->second.empty(), + "Cannot find variable %s", out_name); + + if (out_var_iter->second[0] != out_var) { + return false; + } + + const VarDesc *in_var_desc = GetVarDesc(in_var); + const VarDesc *out_var_desc = GetVarDesc(out_var); + + if (in_var_desc->Persistable() || out_var_desc->Persistable()) { + return false; + } + + if (in_var_desc->GetType() != proto::VarType::LOD_TENSOR || + out_var_desc->GetType() != proto::VarType::LOD_TENSOR) { + return false; + } + + if (!FindNodesByName(in_name, out_var->GeneratedOp()->Node()->outputs) + .empty()) { + return false; + } + + if (!FindNodesByName(out_name, out_var->GeneratedOp()->Node()->inputs) + .empty()) { + return false; + } + + auto all_input_args = + out_var->GeneratedOp()->Node()->Op()->InputArgumentNames(); + if (std::count(all_input_args.begin(), all_input_args.end(), in_name) > 1) { + return false; + } + + return true; +} + +void MemoryReusePass::AddReuseVar(details::ComputationOpHandle *op, + details::VarHandle *in_var, + details::VarHandle *out_var) const { + PADDLE_ENFORCE((*var_infos_)[op->GetScopeIdx()].count(in_var->Name()) > 0, + "%s does not in mem-opt var infos", in_var->Name()); + + if (ops_.count(op) == 0) { + InsertShareTensorBufferOpHandleToGraph(op); + } + + auto *share_buffer_op = ops_[op]; + + auto &all_input_vars = share_buffer_op->Inputs(); + bool has_input = std::find(all_input_vars.begin(), all_input_vars.end(), + in_var) != all_input_vars.end(); + + if (!has_input) { + share_buffer_op->AddInput(in_var); + } + + share_buffer_op->Add( + (*var_infos_)[op->GetScopeIdx()].at(in_var->Name()).get(), + out_var->Name()); + reused_var_names_[op->GetScopeIdx()].insert(in_var->Name()); + + UpdateLastLiveOpOfVar(op, in_var, out_var); +} + +// 1. Set last living op of in_var to be any last living op of out_var +// 2. Set reference count of in_var to be 1 +void MemoryReusePass::UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, + details::VarHandle *in_var, + details::VarHandle *out_var) const { + size_t scope_idx = op->GetScopeIdx(); + auto out_var_op_iter = + (*last_live_ops_of_vars_)[scope_idx].find(out_var->Name()); + PADDLE_ENFORCE(out_var_op_iter != (*last_live_ops_of_vars_)[scope_idx].end(), + "Cannot find variable %s", out_var->Name()); + PADDLE_ENFORCE(!out_var_op_iter->second.empty()); + + auto &last_live_ops_of_in_var = + (*last_live_ops_of_vars_)[scope_idx][in_var->Name()]; + last_live_ops_of_in_var.clear(); + last_live_ops_of_in_var.insert(*(out_var_op_iter->second.begin())); + + auto in_var_info_iter = (*var_infos_)[scope_idx].find(in_var->Name()); + PADDLE_ENFORCE(in_var_info_iter != (*var_infos_)[scope_idx].end(), + "Cannot find variable %s", in_var->Name()); + + in_var_info_iter->second->SetRefCnt(1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..f706b48e2e76dbb14a7f2eb44e66fd1628136f21 --- /dev/null +++ b/paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h @@ -0,0 +1,128 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/share_tensor_buffer_op_handle.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * MemoryReusePass is the base class of InplacePass and MemoryOptimizePass. + * + * Unlike the legacy Python API fluid.memory_optimize() which changes + * variable names in the program/graph, MemoryReusePass inserts + * ShareTensorBufferOpHandle into the graph. It is because if we use the + * way of changing variable names: + * + * 1. There are so many corner cases we should skip. For example, (1) variables + * that relates to send/recv ops cannot be renamed (otherwise, pserver + * and trainer cannot find the matching variables), (2) ins/outs of ops + * containing sub-blocks cannot be optimized, (3) variables inside + * op_role_vars cannot be renamed. + * + * 2. It is very difficult to avoid reusing variables that users want to fetch. + * This is because the memory-optimize passes/transpiler runs before users + * fetch, i.e., exe.run(...). We cannot know what users want to fetch in the + * future. As a result, we have to set var.persistable = True before + * applying memory-optimize passes/transpiler, which is rather ugly and not + * friendly to users. + * + * 3. Dim and LoD of the reused variable would be changed, which may result + * in potential errors in InferShape stage of the following ops. What's + * more, it makes that we cannot use the information from + * NoNeedBufferVarsInference. + * + * Considering the drawbacks of the former renaming strategy, we design a + * novel memory-optimize pass to fix these issues. Whether in-place is + * performed can be decided during run-time. ShareTensorBufferOpHandle + * would only share tensor buffers between in/out, never rename variable, + * and not change dim and LoD of variable. If users want to fetch a certain + * variable, we can skip in-place during run-time. + * + * The only concern on speed performance may be: there are too many + * ShareTensorBufferOpHandles in the graph. This can be avoided by moving + * tensor buffer sharing in each ComputationOpHandle::Run() method. We need + * a pass to clean all ShareTensorBufferOpHandles and move sharing to + * ComputationOpHandle::Run() in the future. + */ +class MemoryReusePass : public Pass { + protected: + void ApplyImpl(Graph *graph) const final; + + virtual void Run(Graph *graph) const = 0; + + virtual std::string ReuseType() const = 0; + + bool TryReuseVar(details::VarHandle *in_var, + details::VarHandle *out_var) const; + + std::unordered_set FindNodesByName( + const std::string &name, const std::vector &nodes) const; + + size_t ScopeNum() const { return all_vars_->size(); } + + private: + VarDesc *GetVarDesc(details::VarHandle *var) const; + + bool IsVarsReusable(details::VarHandle *in_var, + details::VarHandle *out_var) const; + + bool IsVarAlreadyReused(details::VarHandle *var) const; + + details::ShareTensorBufferOpHandle *InsertShareTensorBufferOpHandleToGraph( + details::ComputationOpHandle *op) const; + + void CollectShareTensorBufferOpHandles() const; + + void CollectReusedVars() const; + + void AddReuseVar(details::ComputationOpHandle *op, details::VarHandle *in_var, + details::VarHandle *out_var) const; + + void UpdateLastLiveOpOfVar(details::ComputationOpHandle *op, + details::VarHandle *in_var, + details::VarHandle *out_var) const; + + private: + mutable Graph *graph_; + mutable details::GraphVars *all_vars_; + mutable MemOptVarInfoMapList *var_infos_; + mutable std::vector *last_live_ops_of_vars_; + + mutable std::unordered_map + ops_; + + mutable std::vector> reused_var_names_; + + mutable std::vector> var_descs_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc index b927da2c3fb189dd5bb96371b033019432d5679a..e9114156d01968ed5cefd56a1998ded2670a83d8 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass.cc @@ -26,6 +26,7 @@ #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/framework/ir/pass.h" @@ -295,18 +296,18 @@ ExtractComputationOpFromLastLivedVar(details::VarHandle *var, size_t scope_idx, } void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { - auto &ref_cnts = Get>(kGlobalReferenceCount); + auto &var_infos = Get(kMemOptVarInfoMapList); auto &last_live_ops_of_vars = Get>(kLastLiveOpsOfVars); - PADDLE_ENFORCE(last_live_ops_of_vars.empty() && ref_cnts.empty(), + PADDLE_ENFORCE(last_live_ops_of_vars.empty() && var_infos.empty(), "Last Live Ops and Reference Counts of vars should be " "initialized at here."); const auto &vars = graph->Get(details::kGraphVars); last_live_ops_of_vars.resize(vars.size()); - ref_cnts.resize(vars.size()); + var_infos.resize(vars.size()); ShrinkDepsOpFunctor shrink_func( ir::FilterByNodeWrapper(*graph)); @@ -359,7 +360,8 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { var_name); VLOG(10) << "Extract " << result.size() << " ops of var " << var_name; - ref_cnts[i].emplace(var_name, result.size()); + var_infos[i][var_name].reset( + new MemOptVarInfo(var_name, result.size())); last_live_ops_of_vars[i].emplace(var_name, std::move(result)); break; } @@ -375,5 +377,5 @@ void ReferenceCountPass::ApplyImpl(ir::Graph *graph) const { } // namespace paddle REGISTER_PASS(reference_count_pass, paddle::framework::ir::ReferenceCountPass) - .RequirePassAttr(paddle::framework::ir::kGlobalReferenceCount) + .RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList) .RequirePassAttr(paddle::framework::ir::kLastLiveOpsOfVars); diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h index d5e6fa17fd4e85f3f7bcb2c171d7e20a6ffc583c..3433694b052fdb04270db9d0a8bbbf3e24daf879 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h +++ b/paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h @@ -33,16 +33,10 @@ class VarDesc; namespace ir { -using ReferenceCountMap = std::unordered_map; - -using AtomicReferenceCountMap = - std::unordered_map>; - using GarbageCollectorMap = std::map>; -const char kGlobalReferenceCount[] = "global_reference_count"; -const char kRuntimeReferenceCount[] = "runtime_reference_count"; +const char kMemOptVarInfoMapList[] = "mem_opt_var_info_map_list"; const char kGarbageCollector[] = "garbage_collector"; const char kAllPlaces[] = "all_places"; diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 09a4613ba5484470f87b17b8e1977a7107570881..2e1771360d1e8c133f50292f4c11a74b8bc3cd2e 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -89,7 +89,12 @@ class Node { // Return a reference to the `wrapper`. template T& Wrapper() { - return *boost::any_cast(wrapper_); + try { + return *boost::any_cast(wrapper_); + } catch (boost::bad_any_cast&) { + PADDLE_THROW("Invalid wrapper type error, expected %s, actual %s", + typeid(T).name(), wrapper_type_.name()); + } } // Test if the Node is wrapped by type T. diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index f675742f6de85b111ab0ab98a8460aff10a7bc36..26e6fb6301a8fe3708411bac658eb7a99cd43759 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -22,11 +22,13 @@ limitations under the License. */ #include "paddle/fluid/framework/details/async_ssa_graph_executor.h" #include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h" #include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" #include "paddle/fluid/platform/profiler.h" @@ -76,24 +78,10 @@ class ParallelExecutorPrivate { } } - ir::Graph *PrepareGCAndRefCnts(ir::Graph *graph, size_t max_memory_size); + ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph); inline bool HasGarbageCollectors() const { return !gcs_.empty(); } - void ResetRuntimeReferenceCount(const std::vector &fetch_tensors, - const std::string &fetched_var_name) { - for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) { - for (auto &pair : global_ref_cnts_[i]) { - runtime_ref_cnts_[i][pair.first] = pair.second; - } - - for (auto &fetch_name : fetch_tensors) { - runtime_ref_cnts_[i].erase(fetch_name); - } - runtime_ref_cnts_[i].erase(fetched_var_name); - } - } - #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) { VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_ @@ -201,12 +189,20 @@ class ParallelExecutorPrivate { } #endif + inline bool IsPersistable(const std::string &name) const { + auto iter = is_persistable_.find(name); + return iter != is_persistable_.end() && iter->second; + } + BuildStrategy build_strategy_; std::vector places_; std::vector local_scopes_; + std::vector local_exec_scopes_; Scope *global_scope_; // not owned std::unique_ptr executor_; + std::unordered_map is_persistable_; + #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) platform::NCCLCommunicator *nccl_ctxs_{nullptr}; #endif @@ -215,16 +211,37 @@ class ParallelExecutorPrivate { bool use_all_reduce_; size_t nranks_; - // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and - // then keeps unchanged - // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_ - std::vector global_ref_cnts_; - std::vector runtime_ref_cnts_; + ir::MemOptVarInfoMapList mem_opt_var_infos_; ir::GarbageCollectorMap gcs_; }; -ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( - ir::Graph *graph, size_t max_memory_size) { +ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { + std::vector last_live_ops_of_vars; + + auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); + ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); + graph = ref_cnt_pass->Apply(graph); + VLOG(10) << "ReferenceCountPass Applied"; + + if (build_strategy_.enable_inplace_) { + auto inplace_pass = + ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass"); + inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); + inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); + inplace_pass->SetNotOwned(ir::kUseCuda, &use_cuda_); + VLOG(10) << "Start to apply buffer_shared_inplace_pass"; + graph = inplace_pass->Apply(graph); + VLOG(10) << "buffer_shared_inplace_pass Applied"; + } + + // TODO(zjl): refactor MemoryOptimizePass as well!!! + + if (GetEagerDeletionThreshold() < 0) { + return graph; + } + size_t max_memory_size = static_cast(GetEagerDeletionThreshold()); + for (size_t i = 0; i < places_.size(); ++i) { auto &place = places_[i]; if (gcs_.count(place) > 0) { @@ -258,19 +275,10 @@ ir::Graph *ParallelExecutorPrivate::PrepareGCAndRefCnts( } if (!gcs_.empty()) { - std::vector last_live_ops_of_vars; - - auto ref_cnt_pass = - ir::PassRegistry::Instance().Get("reference_count_pass"); - ref_cnt_pass->SetNotOwned(ir::kGlobalReferenceCount, &global_ref_cnts_); - ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); - graph = ref_cnt_pass->Apply(graph); - VLOG(10) << "ReferenceCountPass Applied"; - auto eager_deletion_pass = ir::PassRegistry::Instance().Get("eager_deletion_pass"); - eager_deletion_pass->SetNotOwned(ir::kRuntimeReferenceCount, - &runtime_ref_cnts_); + eager_deletion_pass->SetNotOwned(ir::kMemOptVarInfoMapList, + &mem_opt_var_infos_); eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_); eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); @@ -386,9 +394,8 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, // same communicators. auto *nccl_ctxs = member_->nccl_ctxs_->GetSyncBatchNormCtx(scope, member_->places_); + auto &pool = platform::DeviceContextPool::Instance(); for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); auto *dev_ctx = static_cast( pool.Get(member_->places_[dev_id])); auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]); @@ -456,13 +463,7 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } #endif - auto max_memory_size = GetEagerDeletionThreshold(); - VLOG(10) << "Eager Deletion Threshold " - << static_cast(max_memory_size) / (1 << 30); - if (max_memory_size >= 0) { - graph = member_->PrepareGCAndRefCnts(graph, - static_cast(max_memory_size)); - } + graph = member_->ApplyMemoryOptimizePass(graph); async_graphs[0] = graph; @@ -475,6 +476,9 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, var_infos.back().name_ = node->Var()->Name(); var_infos.back().type_ = node->Var()->GetType(); var_infos.back().persistable_ = node->Var()->Persistable(); + + member_->is_persistable_.emplace(node->Var()->Name(), + node->Var()->Persistable()); } } @@ -493,17 +497,34 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, } } + std::unordered_map scope_map; + for (auto *scope : member_->local_scopes_) { + auto &local_exec_scope = scope->NewScope(); + member_->local_exec_scopes_.emplace_back(&local_exec_scope); + scope_map.emplace(scope, &local_exec_scope); + } + + PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), + member_->local_exec_scopes_.size()); + + std::vector final_graphs; + if (member_->build_strategy_.async_mode_) { VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset(new details::AsyncSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->places_, async_graphs)); + exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, + member_->places_, async_graphs)); + final_graphs = async_graphs; } else if (member_->build_strategy_.enable_parallel_graph_) { VLOG(3) << "use ParallelSSAGraphExecutor"; #ifdef PADDLE_WITH_CUDA // TODO(Yancey1989): Remove passing in the main_program when // allreduce_seq_pass doesn't need it as the attr. - member_->executor_.reset(new details::ParallelSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->places_, graph)); + auto *pg_exe = new details::ParallelSSAGraphExecutor( + exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, + member_->places_, graph); + final_graphs = pg_exe->Graphs(); + member_->executor_.reset(pg_exe); #else PADDLE_THROW( "Paddle should be compiled with CUDA for ParallelGraph Execution."); @@ -512,19 +533,29 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, if (exec_strategy.type_ == ExecutionStrategy::kDefault) { VLOG(3) << "use ThreadedSSAGraphExecutor"; member_->executor_.reset(new details::ThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->places_, graph)); + exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, + member_->places_, graph)); } else { VLOG(3) << "use FastThreadedSSAGraphExecutor"; member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->places_, graph)); + exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, + member_->places_, graph)); } + final_graphs.emplace_back(graph); } VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; if (!member_->build_strategy_.async_mode_) { member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( - exec_strategy, member_->local_scopes_, std::move(var_infos), - member_->places_, std::move(member_->executor_))); + exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, + std::move(var_infos), member_->places_, std::move(member_->executor_))); + } + + for (auto *g : final_graphs) { + auto ops = ir::FilterByNodeWrapper(*g); + for (auto *op : ops) { + op->SetLocalExecScopes(scope_map); + } } } @@ -616,10 +647,9 @@ void ParallelExecutor::Run(const std::vector &fetch_tensors, #endif platform::RecordBlock b(0); - if (member_->HasGarbageCollectors()) { - platform::RecordEvent event("PrepareGarbageCollectors"); - member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name); - } + + ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors, + member_->HasGarbageCollectors()); VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run"; auto fetch_data = member_->executor_->Run(fetch_tensors); @@ -633,9 +663,13 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( for (size_t i = 0; i < tensors.size(); ++i) { auto &map = tensors[i]; - auto *scope = member_->local_scopes_[i]; for (auto &pair : map) { - auto *trg = scope->Var(pair.first)->GetMutable(); + bool is_persistable = member_->IsPersistable(pair.first); + auto *feed_scope = is_persistable ? member_->local_scopes_[i] + : member_->local_exec_scopes_[i]; + auto *feed_var = feed_scope->Var(pair.first); + + auto *trg = feed_var->GetMutable(); trg->ShareDataWith(pair.second); trg->set_lod(pair.second.lod()); } @@ -644,7 +678,7 @@ void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( const std::unordered_map &tensors) { - for (auto pair : tensors) { + for (auto &pair : tensors) { auto lod_tensors = pair.second.SplitLoDTensor(member_->places_); if (member_->places_.size() != lod_tensors.size()) { bool is_cpu_place = platform::is_cpu_place(member_->places_.front()); @@ -661,10 +695,14 @@ void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( } PADDLE_THROW(error_info); } + + bool is_persistable = member_->IsPersistable(pair.first); for (size_t j = 0; j < member_->places_.size(); ++j) { - // TODO(panxy0718): Do I need to delete this var? - auto t = - member_->local_scopes_[j]->Var(pair.first)->GetMutable(); + auto *feed_scope = is_persistable ? member_->local_scopes_[j] + : member_->local_exec_scopes_[j]; + auto *feed_var = feed_scope->Var(pair.first); + + auto t = feed_var->GetMutable(); t->ShareDataWith(lod_tensors[j]); t->set_lod(lod_tensors[j].lod()); } @@ -724,3 +762,4 @@ bool ParallelExecutor::EnableParallelGraphExecution( USE_PASS(reference_count_pass); USE_PASS(eager_deletion_pass); +USE_PASS(buffer_shared_inplace_pass); diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h index 6943fe62b915e0707dfe40ecbda90f61464338cf..1ac800c9596b174d5d1187802265a766fdd32e74 100644 --- a/paddle/fluid/framework/parallel_executor.h +++ b/paddle/fluid/framework/parallel_executor.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/details/build_strategy.h" #include "paddle/fluid/framework/details/execution_strategy.h" +#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/program_desc.h" diff --git a/paddle/fluid/framework/scope.cc b/paddle/fluid/framework/scope.cc index 49e22a5ad3093c2d61d0ef513974c9938e287729..afafff5218ccf95fdc4baf7282d4f2757a74ac9c 100644 --- a/paddle/fluid/framework/scope.cc +++ b/paddle/fluid/framework/scope.cc @@ -200,6 +200,17 @@ Variable* Scope::FindVarLocally(const std::string& name) const { return nullptr; } +void Scope::EraseVarsExcept(const std::unordered_set& vars) { + SCOPE_VARS_WRITER_LOCK + for (auto iter = vars_.begin(); iter != vars_.end();) { + if (vars.count(iter->second.get()) != 0) { + ++iter; + } else { + vars_.erase(iter++); + } + } +} + std::string GenScopeTreeDebugInfo(Scope* root) { std::stringstream os; diff --git a/paddle/fluid/framework/scope.h b/paddle/fluid/framework/scope.h index 5f3d106e091ace05cfbdbbde2d79d48fe01b4a38..9de2963234d9020afa44706860e947047ab69534 100644 --- a/paddle/fluid/framework/scope.h +++ b/paddle/fluid/framework/scope.h @@ -22,6 +22,7 @@ extern "C" { #include #include #include +#include #include #include @@ -66,6 +67,9 @@ class Scope { void EraseVars(const std::vector& var_names); + // Erase all variables except the given `vars` + void EraseVarsExcept(const std::unordered_set& vars); + /// Find a variable in the scope or any of its ancestors. Returns /// nullptr if cannot find. /// Caller doesn't own the returned Variable. diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index 1ab75e3325740a30c9233b4cef660a869368112a..8fffecfa0e157768a00db893595bb6df4dc51a9d 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -149,7 +149,15 @@ class Tensor { void set_layout(const DataLayout layout) { layout_ = layout; } - void clear() { holder_ = nullptr; } + void clear() { + holder_ = nullptr; + offset_ = 0; + } + + void ShareBufferWith(const Tensor& tensor) { + holder_ = tensor.holder_; + offset_ = tensor.offset_; + } const std::shared_ptr& Holder() const { return holder_; } size_t offset() const { return offset_; } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index acd100a8a69779dab3b452cd7e2b1e4ff8765591..943c6f80ebdab9340b12826d366b2c8b3e76491b 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -751,6 +751,14 @@ class SquareDoubleGradMaker } }; +class ActivationGradOpInplaceInference : public framework::InplaceOpInference { + public: + std::unordered_map operator()( + const framework::OpDesc& op_desc, bool use_cuda) const override { + return {{framework::GradVarName("Out"), framework::GradVarName("X")}}; + } +}; + } // namespace operators } // namespace paddle @@ -765,11 +773,8 @@ namespace plat = paddle::platform; std::conditional>(), \ ::paddle::framework::SingleOpInplaceInToOut, \ void>::type); \ - REGISTER_OPERATOR( \ - KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ - std::conditional>(), \ - ::paddle::framework::SingleOpInplaceInToOut, \ - void>::type) + REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ + ops::ActivationGradOpInplaceInference); #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \ grad_functor) \ @@ -794,7 +799,7 @@ REGISTER_OPERATOR( ops::ActivationGradOpDescMaker::FwdDeps()>, paddle::framework::SingleOpInplaceInToOut); REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad, - paddle::framework::SingleOpInplaceInToOut, + ops::ActivationGradOpInplaceInference, ops::ReluDoubleGradMaker); REGISTER_OPERATOR( relu_grad_grad, @@ -819,7 +824,7 @@ REGISTER_OPERATOR( ops::ActivationGradOpDescMaker::FwdDeps()>, paddle::framework::SingleOpInplaceInToOut); REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad, - paddle::framework::SingleOpInplaceInToOut, + ops::ActivationGradOpInplaceInference, ops::LeakyReluDoubleGradMaker); REGISTER_OPERATOR( leaky_relu_grad_grad, @@ -843,7 +848,7 @@ REGISTER_OPERATOR( ops::ActivationGradOpDescMaker::FwdDeps()>, paddle::framework::SingleOpInplaceInToOut); REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad, - paddle::framework::SingleOpInplaceInToOut, + ops::ActivationGradOpInplaceInference, ops::SqrtDoubleGradMaker); REGISTER_OPERATOR( sqrt_grad_grad, @@ -865,7 +870,7 @@ REGISTER_OPERATOR( ops::ActivationGradOpDescMaker::FwdDeps()>, paddle::framework::SingleOpInplaceInToOut); REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad, - paddle::framework::SingleOpInplaceInToOut, + ops::ActivationGradOpInplaceInference, ops::SquareDoubleGradMaker); REGISTER_OPERATOR( square_grad_grad, diff --git a/paddle/fluid/operators/sum_op.cu b/paddle/fluid/operators/sum_op.cu index 790626a59d0cd19ba0ccf463b1b270e629617078..ba874549ce35fcdfb7026e3368b8736460069ae2 100644 --- a/paddle/fluid/operators/sum_op.cu +++ b/paddle/fluid/operators/sum_op.cu @@ -115,8 +115,15 @@ void SumToLoDTensor(const framework::ExecutionContext &context) { auto *out = context.Output("Out"); bool in_place = in_vars[0] == context.OutputVar("Out"); + if (!in_place) { - out->mutable_data(context.GetPlace()); + auto *out_ptr = out->mutable_data(context.GetPlace()); + if (in_num >= 1 && in_vars[0]->IsType()) { + auto &in_0_tensor = in_vars[0]->Get(); + if (in_0_tensor.numel() > 0) { + in_place = (in_0_tensor.data() == out_ptr); + } + } } // Sum of two tensors diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index 0d60947971ca441b8f6785a7724e0a530e8a8e92..7a3fecace45e053bda736133e8d8a95060074fb8 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -128,10 +128,15 @@ class SumKernel : public framework::OpKernel { bool in_place = out_var == in_vars[0]; if (out_var->IsType()) { - auto *out = context.Output("Out"); - if (!in_place) { - out->mutable_data(context.GetPlace()); + auto *out = out_var->GetMutable(); + auto *out_ptr = out->mutable_data(context.GetPlace()); + if (in_num >= 1 && in_vars[0]->IsType()) { + auto &in_0_tensor = in_vars[0]->Get(); + if (in_0_tensor.numel() > 0) { + in_place = (in_0_tensor.data() == out_ptr); + } } + auto result = EigenVector::Flatten(*out); auto &place = *context.template device_context().eigen_device(); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index b346c62d811472c2256a839eb29f257ad9010e31..4dbafc08b93d7acd7b30b04006499b7e244116c0 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1549,6 +1549,13 @@ All parameter, weight, gradient are variables in Paddle. "enable_inplace", [](const BuildStrategy &self) { return self.enable_inplace_; }, [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; }) + .def_property("_use_legacy_memory_optimize_strategy", + [](const BuildStrategy &self) { + return self.use_legacy_memory_optimize_strategy_; + }, + [](BuildStrategy &self, bool b) { + self.use_legacy_memory_optimize_strategy_ = b; + }) .def_property( "fuse_all_reduce_ops", [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 4956d9387554f9ba98721ccb60467951629dba4c..a13114577bdf910c85accad6f27929b0c0393107 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -211,34 +211,9 @@ class CompiledProgram(object): if self._program: if self._program._is_mem_optimized: self._build_strategy.memory_optimize = False - self._build_strategy.enable_inplace = False - elif not self._build_strategy.memory_optimize or not self._build_strategy.enable_inplace: - # remind the user to try our memmory optimize strategy - six.print_( - """ - You can try our memory optimize feature to save your memory usage: - # create a build_strategy variable to set memory optimize option - build_strategy = compiler.BuildStrategy() - build_strategy.enable_inplace = True - build_strategy.memory_optimize = True - - # pass the build_strategy to with_data_parallel API - compiled_prog = compiler.CompiledProgram(main).with_data_parallel( - loss_name=loss.name, build_strategy=build_strategy) - - !!! Memory optimize is our experimental feature !!! - some variables may be removed/reused internal to save memory usage, - in order to fetch the right value of the fetch_list, please set the - persistable property to true for each variable in fetch_list - - # Sample - conv1 = fluid.layers.conv2d(data, 4, 5, 1, act=None) - # if you need to fetch conv1, then: - conv1.persistable = True - - """, - file=sys.stderr) + if self._build_strategy.memory_optimize: + self._build_strategy._use_legacy_memory_optimize_strategy = True return self def with_inference_optimize(self, config): diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index fe948cdab6f33b17c979fe1a39aadccf846c447f..790f297fb966bdb923ecca0083c6745a92e34531 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -551,7 +551,7 @@ class Executor(object): if not persistable: logging.warn(""" - Detect that memory optimize or inplace is enabled, but the some variables in the fetch + Detect that build_strategy.memory_optimize = True, but the some variables in the fetch list is not persistable, you may get wrong fetched value, or an exeception may be thrown about cannot find variable of the fetch list. @@ -668,9 +668,8 @@ class Executor(object): return_numpy=return_numpy, use_program_cache=use_program_cache) else: - if fetch_list and program._is_data_parallel and program._program and ( - program._build_strategy.memory_optimize or - program._build_strategy.enable_inplace): + if fetch_list and program._is_data_parallel and program._program and \ + program._build_strategy._use_legacy_memory_optimize_strategy: self._check_fetch_vars_persistable(program._program, fetch_list) program._compile(scope, self.place) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 1c664d94c52a4abe9c2f91ee82eb7fba0f48e84f..57ec8cb84819a04d74d6901701fe5bdef9f13f91 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -256,4 +256,4 @@ endif() set_tests_properties(test_recordio_reader test_parallel_executor_test_while_train test_parallel_executor_mnist test_parallel_executor_seresnext test_parallel_executor_crf test_sync_batch_norm_op - PROPERTIES LABELS "RUN_TYPE=DIST") + test_buffer_shared_inplace_pass PROPERTIES LABELS "RUN_TYPE=DIST") diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index b1391749c0d74a6a2a3a111bbb1bdbf0307b688b..816f2b7b6b33d053c8972da2204194e94727c1c7 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -33,7 +33,7 @@ class TestParallelExecutorBase(unittest.TestCase): def check_network_convergence(cls, method, use_cuda=True, - memory_opt=True, + memory_opt=False, iter=50, batch_size=None, allow_op_delay=False, @@ -41,7 +41,7 @@ class TestParallelExecutorBase(unittest.TestCase): seed=None, use_parallel_executor=True, use_reduce=False, - use_ir_memory_optimize=True, + use_ir_memory_optimize=False, enable_inplace=True, fuse_elewise_add_act_ops=False, fuse_all_optimizer_ops=False, @@ -65,7 +65,8 @@ class TestParallelExecutorBase(unittest.TestCase): main.random_seed = seed loss = method(use_feed=feed_dict is not None) - loss.persistable = True + if memory_opt or use_ir_memory_optimize: + loss.persistable = True if optimizer: optimizer().minimize(loss) @@ -88,9 +89,8 @@ class TestParallelExecutorBase(unittest.TestCase): build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize build_strategy.fuse_all_optimizer_ops = fuse_all_optimizer_ops build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops - # python memory optimization is conflict with inplace pass. - # Use ir graph memory optimization after inplace pass is the correct way. - build_strategy.enable_inplace = False if memory_opt else enable_inplace + + build_strategy.enable_inplace = enable_inplace build_strategy.enable_sequential_execution = enable_sequential_execution if use_cuda and core.is_compiled_with_cuda(): diff --git a/python/paddle/fluid/tests/unittests/test_buffer_shared_inplace_pass.py b/python/paddle/fluid/tests/unittests/test_buffer_shared_inplace_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..4ff5ad158eaf542bb93a25529066d05653a431dd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_buffer_shared_inplace_pass.py @@ -0,0 +1,177 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.framework import Parameter +import numpy as np +from simple_nets import simple_fc_net +import random +import unittest +import os + +batch_size = 32 + +feed_dict = { + 'image': np.random.random([batch_size, 784]).astype('float32'), + 'label': np.random.random_integers( + low=0, high=9, size=[batch_size, 1]).astype('int64') +} + + +class InplaceTestBase(unittest.TestCase): + def initParameter(self): + self.use_cuda = True + + def setUp(self): + self.initParameter() + if self.use_cuda and fluid.core.is_compiled_with_cuda(): + self.device_count = fluid.core.get_cuda_device_count() + else: + self.device_count = 4 + + assert batch_size % self.device_count == 0 + + def build_program_and_scope(self): + self.place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() + startup_program = fluid.Program() + main_program = fluid.Program() + startup_program.random_seed = 1 + main_program.random_seed = 1 + + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.unique_name.guard(): + loss = simple_fc_net() + adam = fluid.optimizer.Adam(learning_rate=1e-3) + adam.minimize(loss) + + with fluid.scope_guard(scope): + exe = fluid.Executor( + fluid.CUDAPlace(0) + if self.use_cuda else fluid.CPUPlace()) + exe.run(startup_program) + + return main_program, scope, exe, loss + + def is_invalid_test(self): + return self.use_cuda and not fluid.core.is_compiled_with_cuda() + + def get_all_vars(self, program): + all_vars = program.global_block().vars + all_vars_name = [] + for name, var in all_vars.items(): + if 0 not in var.shape and not var.persistable: + all_vars_name.append(name) + + return all_vars_name + + def test_single_card_fetch_var(self): + if self.is_invalid_test(): + return + + prog1, scope1, exe, loss1 = self.build_program_and_scope() + prog2, scope2, _, loss2 = self.build_program_and_scope() + prog3, scope3, _, loss3 = self.build_program_and_scope() + + build_strategy2 = fluid.BuildStrategy() + build_strategy2.memory_optimize = False + build_strategy2.enable_inplace = True + + compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel( + loss_name=loss2.name, + build_strategy=build_strategy2, + places=self.place) + + build_strategy3 = fluid.BuildStrategy() + build_strategy3.memory_optimize = False + build_strategy3.enable_inplace = False + compiled_prog3 = fluid.CompiledProgram(prog3).with_data_parallel( + loss_name=loss2.name, + build_strategy=build_strategy3, + places=self.place) + + all_vars_name = self.get_all_vars(prog1) + repeated_var_names = all_vars_name * 4 + random.shuffle(repeated_var_names) # add some random + + for fetch_var in repeated_var_names: + for _ in range(4): + with fluid.scope_guard(scope1): + fetch_val1, = exe.run(prog1, + feed=feed_dict, + fetch_list=[fetch_var]) + + with fluid.scope_guard(scope2): + fetch_val2, = exe.run(compiled_prog2, + feed=feed_dict, + fetch_list=[fetch_var]) + + with fluid.scope_guard(scope3): + fetch_val3, = exe.run(compiled_prog3, + feed=feed_dict, + fetch_list=[fetch_var]) + + self.assertTrue(np.array_equal(fetch_val1, fetch_val2)) + self.assertTrue(np.array_equal(fetch_val1, fetch_val3)) + + def test_multi_card_fetch_var(self): + if self.is_invalid_test(): + return + + prog1, scope1, exe, loss1 = self.build_program_and_scope() + prog2, scope2, _, loss2 = self.build_program_and_scope() + + build_strategy1 = fluid.BuildStrategy() + build_strategy1.memory_optimize = False + build_strategy1.enable_inplace = True + + build_strategy2 = fluid.BuildStrategy() + build_strategy2.memory_optimize = False + build_strategy2.enable_inplace = False + + if self.use_cuda: + places = fluid.cuda_places() + else: + places = fluid.cpu_places(self.device_count) + + compiled_prog1 = fluid.CompiledProgram(prog1).with_data_parallel( + loss_name=loss1.name, build_strategy=build_strategy1, places=places) + compiled_prog2 = fluid.CompiledProgram(prog2).with_data_parallel( + loss_name=loss2.name, build_strategy=build_strategy2, places=places) + + repeated_var_names = self.get_all_vars(prog1) * 4 + random.shuffle(repeated_var_names) # add some random + + for fetch_var in repeated_var_names: + for _ in range(4): + with fluid.scope_guard(scope1): + fetch_val1, = exe.run(compiled_prog1, + feed=feed_dict, + fetch_list=[fetch_var]) + + with fluid.scope_guard(scope2): + fetch_val2, = exe.run(compiled_prog2, + feed=feed_dict, + fetch_list=[fetch_var]) + + self.assertTrue(np.array_equal(fetch_val1, fetch_val2)) + + +class CPUInplaceTest(InplaceTestBase): + def initParameter(self): + self.use_cuda = False + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py b/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py index d4e514fa24c5efe6c0253ce3689f87dea4566f8d..873bd61d40bc3df6448a22cdd00211f7815eb985 100644 --- a/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py +++ b/python/paddle/fluid/tests/unittests/test_inplace_softmax_with_cross_entropy.py @@ -61,6 +61,8 @@ class TestSoftmaxWithXe(unittest.TestCase): build_strategy = fluid.BuildStrategy() build_strategy.enable_inplace = inplace + if inplace: + build_strategy._use_legacy_memory_optimize_strategy = True prog = fluid.CompiledProgram(fluid.default_main_program( )).with_data_parallel( build_strategy=build_strategy, places=place)