From 387bac46b5e4d95e2888773975d1b6c3a906a588 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Fri, 7 Dec 2018 03:09:43 +0000 Subject: [PATCH] refine code test=develop --- .../framework/details/eager_deletion_pass.cc | 10 +- .../fluid/framework/details/op_graph_view.cc | 2 + .../framework/details/reference_count_pass.cc | 14 +- .../details/reference_count_pass_helper.h | 10 +- .../scope_buffered_ssa_graph_executor.cc | 8 +- .../scope_buffered_ssa_graph_executor.h | 2 +- paddle/fluid/framework/executor.cc | 14 +- paddle/fluid/framework/executor.h | 6 +- paddle/fluid/framework/parallel_executor.cc | 153 ++++++++++-------- .../fluid/operators/controlflow/while_op.cc | 10 +- 10 files changed, 122 insertions(+), 107 deletions(-) diff --git a/paddle/fluid/framework/details/eager_deletion_pass.cc b/paddle/fluid/framework/details/eager_deletion_pass.cc index 3a1b37e5339..85991c71e65 100644 --- a/paddle/fluid/framework/details/eager_deletion_pass.cc +++ b/paddle/fluid/framework/details/eager_deletion_pass.cc @@ -31,10 +31,11 @@ std::unique_ptr EagerDeletionPass::ApplyImpl( const auto &vars = graph->Get(kGraphVars); auto &ref_cnts = - Get>(kCurReferenceCount); + Get>(kRuntimeReferenceCount); const auto &last_live_ops = Get>(kLastLiveOpsOfVars); - auto &gcs = Get(kGarbageCollector); + auto &gcs = Get(kGarbageCollector); + const auto &places = Get>(kAllPlaces); ref_cnts = std::vector(vars.size()); @@ -58,7 +59,7 @@ std::unique_ptr EagerDeletionPass::ApplyImpl( graph->CreateEmptyNode("eager_deletion", ir::Node::Type::kOperation); auto *eager_deletion_op = new EagerDeletionOpHandle( eager_deletion_node, op->GetScope(), op->GetPlace(), - std::move(var_names), gcs[op->GetScopeIdx()].get(), + std::move(var_names), gcs.at(places[op->GetScopeIdx()]).get(), &(ref_cnts[op->GetScopeIdx()])); auto it = std::find_if( @@ -90,6 +91,7 @@ std::unique_ptr EagerDeletionPass::ApplyImpl( REGISTER_PASS(eager_deletion_pass, paddle::framework::details::EagerDeletionPass) - .RequirePassAttr(paddle::framework::details::kCurReferenceCount) + .RequirePassAttr(paddle::framework::details::kRuntimeReferenceCount) .RequirePassAttr(paddle::framework::details::kLastLiveOpsOfVars) + .RequirePassAttr(paddle::framework::details::kAllPlaces) .RequirePassAttr(paddle::framework::details::kGarbageCollector); diff --git a/paddle/fluid/framework/details/op_graph_view.cc b/paddle/fluid/framework/details/op_graph_view.cc index 4838c4198ff..b6b5ad42c46 100644 --- a/paddle/fluid/framework/details/op_graph_view.cc +++ b/paddle/fluid/framework/details/op_graph_view.cc @@ -23,6 +23,8 @@ namespace details { OpGraphView::OpGraphView(const std::vector &ops) { Build(ops); } void OpGraphView::Build(const std::vector &ops) { + preceding_ops_.clear(); + pending_ops_.clear(); for (auto &op : ops) { preceding_ops_[op]; pending_ops_[op]; diff --git a/paddle/fluid/framework/details/reference_count_pass.cc b/paddle/fluid/framework/details/reference_count_pass.cc index 0c096e09800..f2c9dfb5248 100644 --- a/paddle/fluid/framework/details/reference_count_pass.cc +++ b/paddle/fluid/framework/details/reference_count_pass.cc @@ -29,22 +29,22 @@ namespace paddle { namespace framework { namespace details { -class OpConnectionDetector { +class OpRelationDetector { public: enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 }; - explicit OpConnectionDetector(const std::vector &all_ops) + explicit OpRelationDetector(const std::vector &all_ops) : graph_(all_ops) {} template - OpSet MaxNoDepOps(const OpSet &op_set) { - if (op_set.size() <= 1) return op_set; + OpSet MaxNoDepOps(const OpSet &op_set) const { using KeyType = typename OpSet::key_type; static_assert( std::is_base_of::type>::value, - "Key type of OpSet must be or derived of OpHandleBase"); + "Key type of OpSet must be OpHandleBase, or derived of OpHandleBase"); + if (op_set.size() <= 1) return op_set; std::vector ops(op_set.begin(), op_set.end()); OpSet ret; auto rels = GetRelations(ops); @@ -59,7 +59,7 @@ class OpConnectionDetector { private: std::vector> GetRelations( - const std::vector ops) { + const std::vector ops) const { std::unordered_map op_to_idx; for (size_t i = 0; i < ops.size(); ++i) { PADDLE_ENFORCE(graph_.HasOp(ops[i]), "Op does not exist in graph"); @@ -144,7 +144,7 @@ std::unique_ptr ReferenceCountPass::ApplyImpl( last_live_ops_of_vars = std::vector(vars.size()); ref_cnts = std::vector(vars.size()); - OpConnectionDetector detector(ir::FilterByNodeWrapper(*graph)); + OpRelationDetector detector(ir::FilterByNodeWrapper(*graph)); for (size_t i = 0; i < vars.size(); ++i) { for (auto &name_var_pair : vars[i]) { diff --git a/paddle/fluid/framework/details/reference_count_pass_helper.h b/paddle/fluid/framework/details/reference_count_pass_helper.h index 77846f7bdfc..eb534f97015 100644 --- a/paddle/fluid/framework/details/reference_count_pass_helper.h +++ b/paddle/fluid/framework/details/reference_count_pass_helper.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -33,12 +34,13 @@ using ReferenceCountMap = std::unordered_map; using AtomicReferenceCountMap = std::unordered_map>; -using GarbageCollectorList = - std::vector>>; +using GarbageCollectorMap = + std::map>>; -const char kGlobalReferenceCount[] = "reference_count"; -const char kCurReferenceCount[] = "current_reference_count"; +const char kGlobalReferenceCount[] = "global_reference_count"; +const char kRuntimeReferenceCount[] = "runtime_reference_count"; const char kGarbageCollector[] = "garbage_collector"; +const char kAllPlaces[] = "all_places"; using LastLiveOpsOfVars = std::unordered_map>; diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index da5e277f276..b8775fc3291 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -32,15 +32,15 @@ ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor( var_infos_(std::move(var_infos)), places_(std::move(places)) { if (Graph().Has(details::kGarbageCollector)) { - gc_ = &(Graph().Get(details::kGarbageCollector)); + gc_ = &(Graph().Get(details::kGarbageCollector)); } } void ScopeBufferedSSAGraphExecutor::WaitAllGarbageCollectors() { if (gc_) { - for (auto &gc : *gc_) { - gc->Wait(); - gc->Reset(); + for (auto &gc_pair : *gc_) { + gc_pair.second->Wait(); + gc_pair.second->Reset(); } } } diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h index 4d52183a205..6086a219e04 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h @@ -60,7 +60,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor { std::vector var_infos_; std::vector places_; - GarbageCollectorList* gc_{nullptr}; + GarbageCollectorMap* gc_{nullptr}; }; } // namespace details } // namespace framework diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index f443c2d8cf6..04425a59830 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -56,13 +56,7 @@ static std::unordered_map GetNonPersistableReferenceCounts( type != proto::VarType::LOD_TENSOR_ARRAY) { continue; } - - auto it = ref_cnts.find(name); - if (it != ref_cnts.end()) { - ++it->second; - } else { - ref_cnts[name] = 1; - } + ++ref_cnts[name]; } } }; @@ -79,8 +73,8 @@ ExecutorPrepareContext::ExecutorPrepareContext( const std::vector& skip_ref_cnt_vars) : prog_(prog), block_id_(block_id) { if (GetEagerDeletionThreshold() >= 0) { - ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), - skip_ref_cnt_vars); + global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), + skip_ref_cnt_vars); } } @@ -443,7 +437,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, if (gc) { DeleteUnusedTensors(*local_scope, op.get(), gc.get(), - &(ctx->cur_ref_cnts_)); + &(ctx->runtime_ref_cnts_)); } } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 412ebd19045..5a040ac6415 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -34,14 +34,14 @@ struct ExecutorPrepareContext { ~ExecutorPrepareContext(); - void ResetReferenceCount() { cur_ref_cnts_ = ref_cnts_; } + void ResetReferenceCount() { runtime_ref_cnts_ = global_ref_cnts_; } const framework::ProgramDesc& prog_; size_t block_id_; std::vector> ops_; - std::unordered_map ref_cnts_; - std::unordered_map cur_ref_cnts_; + std::unordered_map global_ref_cnts_; + std::unordered_map runtime_ref_cnts_; }; class Executor { diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 3d466e44a19..dfd031f1195 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -51,11 +51,22 @@ class ParallelExecutorPrivate { } } - void ResetRuntimeReferenceCount() { - for (size_t i = 0; i < rt_ref_cnts_.size(); ++i) { - for (auto &pair : rt_ref_cnts_[i]) { - rt_cur_ref_cnts_[i][pair.first] = pair.second; + std::unique_ptr PrepareGCAndRefCnts( + std::unique_ptr graph, size_t max_memory_size); + + inline bool HasGarbageCollectors() const { return !gcs_.empty(); } + + void ResetRuntimeReferenceCount(const std::vector &fetch_tensors, + const std::string &fetched_var_name) { + for (size_t i = 0; i < runtime_ref_cnts_.size(); ++i) { + for (auto &pair : global_ref_cnts_[i]) { + runtime_ref_cnts_[i][pair.first] = pair.second; + } + + for (auto &fetch_name : fetch_tensors) { + runtime_ref_cnts_[i].erase(fetch_name); } + runtime_ref_cnts_[i].erase(fetched_var_name); } } @@ -71,14 +82,75 @@ class ParallelExecutorPrivate { bool use_cuda_; bool use_all_reduce_; - // rt_ref_cnts_ is only initialized when ParallelExecutor constructs, and then - // keeps unchanged - // Before each iteration, rt_cur_ref_cnts_ is reset to ref_cnts_ - std::vector rt_ref_cnts_; - std::vector rt_cur_ref_cnts_; - details::GarbageCollectorList gcs_; + // global_ref_cnts_ is only initialized when ParallelExecutor constructs, and + // then keeps unchanged + // Before each iteration, runtime_ref_cnts_ is reset to global_ref_cnts_ + std::vector global_ref_cnts_; + std::vector runtime_ref_cnts_; + details::GarbageCollectorMap gcs_; }; +std::unique_ptr ParallelExecutorPrivate::PrepareGCAndRefCnts( + std::unique_ptr graph, size_t max_memory_size) { + for (size_t i = 0; i < places_.size(); ++i) { + auto &place = places_[i]; + if (gcs_.count(place) > 0) { + continue; + } +#ifdef PADDLE_WITH_CUDA + GarbageCollector *gc = nullptr; + if (platform::is_gpu_place(place)) { + if (IsFastEagerDeletionModeEnabled()) { + gc = new UnsafeFastGPUGarbageCollector( + boost::get(place), max_memory_size); + } else { + gc = new StreamGarbageCollector( + boost::get(place), max_memory_size); + } + VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; + } else if (platform::is_cpu_place(place)) { +#endif + gc = new CPUGarbageCollector( + boost::get(place), max_memory_size); + VLOG(10) << "Created GarbageCollector at " << place; +#ifdef PADDLE_WITH_CUDA + } +#endif + + if (gc) { + gcs_[place] = std::unique_ptr>(gc); + } + } + + if (gcs_.empty()) { + std::vector last_live_ops_of_vars; + + auto ref_cnt_pass = + ir::PassRegistry::Instance().Get("reference_count_pass"); + ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, + &global_ref_cnts_); + ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + graph = ref_cnt_pass->Apply(std::move(graph)); + VLOG(10) << "ReferenceCountPass Applied"; + + auto eager_deletion_pass = + ir::PassRegistry::Instance().Get("eager_deletion_pass"); + eager_deletion_pass->SetNotOwned(details::kRuntimeReferenceCount, + &runtime_ref_cnts_); + eager_deletion_pass->SetNotOwned(details::kGarbageCollector, &gcs_); + eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars, + &last_live_ops_of_vars); + eager_deletion_pass->SetNotOwned(details::kAllPlaces, &places_); + graph = eager_deletion_pass->Apply(std::move(graph)); + VLOG(10) << "EagerDeletionPass Applied"; + + graph->SetNotOwned(details::kGarbageCollector, &gcs_); + } + + return graph; +} + std::vector &ParallelExecutor::GetLocalScopes() { return member_->local_scopes_; } @@ -153,54 +225,8 @@ ParallelExecutor::ParallelExecutor( auto max_memory_size = GetEagerDeletionThreshold(); if (max_memory_size >= 0) { - size_t place_num = member_->places_.size(); - for (size_t i = 0; i < place_num; ++i) { - auto &place = member_->places_[i]; -#ifdef PADDLE_WITH_CUDA - if (platform::is_gpu_place(place)) { - if (IsFastEagerDeletionModeEnabled()) { - member_->gcs_.emplace_back(new UnsafeFastGPUGarbageCollector( - boost::get(place), max_memory_size)); - } else { - member_->gcs_.emplace_back(new StreamGarbageCollector( - boost::get(place), max_memory_size)); - } - VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; - } else if (platform::is_cpu_place(place)) { -#endif - member_->gcs_.emplace_back(new CPUGarbageCollector( - boost::get(place), max_memory_size)); - VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; -#ifdef PADDLE_WITH_CUDA - } -#endif - } - } - - if (!member_->gcs_.empty()) { - std::vector last_live_ops_of_vars; - - auto ref_cnt_pass = - ir::PassRegistry::Instance().Get("reference_count_pass"); - ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, - &(member_->rt_ref_cnts_)); - ref_cnt_pass->SetNotOwned(details::kLastLiveOpsOfVars, - &last_live_ops_of_vars); - graph = ref_cnt_pass->Apply(std::move(graph)); - VLOG(10) << "ReferenceCountPass Applied"; - - auto eager_deletion_pass = - ir::PassRegistry::Instance().Get("eager_deletion_pass"); - eager_deletion_pass->SetNotOwned(details::kCurReferenceCount, - &(member_->rt_cur_ref_cnts_)); - eager_deletion_pass->SetNotOwned(details::kGarbageCollector, - &(member_->gcs_)); - eager_deletion_pass->SetNotOwned(details::kLastLiveOpsOfVars, - &last_live_ops_of_vars); - graph = eager_deletion_pass->Apply(std::move(graph)); - VLOG(10) << "EagerDeletionPass Applied"; - - graph->SetNotOwned(details::kGarbageCollector, &(member_->gcs_)); + graph = member_->PrepareGCAndRefCnts(std::move(graph), + static_cast(max_memory_size)); } // Step 3. Create vars in each scope. Passes may also create new vars. @@ -316,15 +342,8 @@ void ParallelExecutor::BCastParamsToDevices( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { platform::RecordBlock b(0); - if (!member_->gcs_.empty()) { - member_->ResetRuntimeReferenceCount(); - size_t n = member_->rt_ref_cnts_.size(); - for (size_t i = 0; i < n; ++i) { - for (auto &fetch_name : fetch_tensors) { - member_->rt_cur_ref_cnts_[i].erase(fetch_name); - } - member_->rt_cur_ref_cnts_[i].erase(fetched_var_name); - } + if (member_->HasGarbageCollectors()) { + member_->ResetRuntimeReferenceCount(fetch_tensors, fetched_var_name); } auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index da7cad82d8d..06920a47ee0 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -74,9 +74,7 @@ class WhileOp : public framework::OperatorBase { bool is_test = Attr("is_test"); auto &skip_vars = Attr>(kSkipEagerDeletionVars); - if (framework::GetEagerDeletionThreshold() >= 0) { - VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); - } + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars); while (cond.data()[0]) { @@ -144,9 +142,7 @@ class WhileGradOp : public framework::OperatorBase { auto *program = block->Program(); auto &skip_vars = Attr>(kSkipEagerDeletionVars); - if (framework::GetEagerDeletionThreshold() >= 0) { - VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); - } + VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto *step_scopes = @@ -369,7 +365,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { // while operator could be renamed. while_grad->SetAttr("original_output_grad", output_grads_list); - /* The followi_ng codes are used in eager deletion mode */ + /* The following codes are used in eager deletion mode */ std::unordered_set bwd_skip_vars; if (framework::GetEagerDeletionThreshold() >= 0) { std::unordered_set fwd_skip_vars; -- GitLab