From 732fa00eaf905e531d907493132ba948bf159639 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Fri, 8 Mar 2019 13:05:11 +0000 Subject: [PATCH] disable gc in recurrent_op currently test=develop --- paddle/fluid/framework/executor.cc | 38 ++++++++++++++++---------- paddle/fluid/framework/executor.h | 17 +++++++++--- paddle/fluid/operators/recurrent_op.cc | 8 ++++-- paddle/fluid/pybind/pybind.cc | 6 ++-- python/paddle/fluid/executor.py | 2 +- 5 files changed, 47 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7eef9ec50..f3869ceb6 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -80,11 +80,11 @@ static std::unordered_map GetNonPersistableReferenceCounts( ExecutorPrepareContext::ExecutorPrepareContext( const framework::ProgramDesc& prog, size_t block_id, - const std::vector& skip_ref_cnt_vars) - : prog_(prog), block_id_(block_id) { - if (GetEagerDeletionThreshold() >= 0) { - global_ref_cnts_ = GetNonPersistableReferenceCounts(prog.Block(block_id), - skip_ref_cnt_vars); + const std::vector& keep_vars, bool force_disable_gc) + : prog_(prog), block_id_(block_id), force_disable_gc_(force_disable_gc) { + if (GetEagerDeletionThreshold() >= 0 && !force_disable_gc_) { + global_ref_cnts_ = + GetNonPersistableReferenceCounts(prog.Block(block_id), keep_vars); } } @@ -189,13 +189,15 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, } void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, - bool create_local_scope, bool create_vars) { + bool create_local_scope, bool create_vars, + const std::vector& skip_ref_cnt_vars, + bool force_disable_gc) { platform::RecordBlock b(block_id); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); #ifdef PADDLE_WITH_NGRAPH if (FLAGS_use_ngraph) operators::NgraphEngine::EnableNgraph(pdesc); #endif - auto ctx = Prepare(pdesc, block_id); + auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); } @@ -362,9 +364,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, std::unique_ptr Executor::Prepare( const ProgramDesc& program, int block_id, - const std::vector& skip_ref_cnt_vars) { - std::unique_ptr ctx( - new ExecutorPrepareContext(program, block_id, skip_ref_cnt_vars)); + const std::vector& skip_ref_cnt_vars, bool force_disable_gc) { + std::unique_ptr ctx(new ExecutorPrepareContext( + program, block_id, skip_ref_cnt_vars, force_disable_gc)); PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); auto& block = program.Block(block_id); for (auto& op_desc : block.AllOps()) { @@ -375,7 +377,8 @@ std::unique_ptr Executor::Prepare( std::vector> Executor::Prepare( const ProgramDesc& program, const std::vector& block_ids, - const std::vector>& skip_ref_cnt_vars) { + const std::vector>& skip_ref_cnt_vars, + bool force_disable_gc) { PADDLE_ENFORCE( skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(), "skip_ref_cnt_vars should be either empty or equals to block number %d", @@ -385,9 +388,11 @@ std::vector> Executor::Prepare( for (auto& bid : block_ids) { ExecutorPrepareContext* ctx; if (skip_ref_cnt_vars.empty()) { - ctx = new ExecutorPrepareContext(program, bid); + ctx = new ExecutorPrepareContext(program, bid, std::vector(), + force_disable_gc); } else { - ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx]); + ctx = new ExecutorPrepareContext(program, bid, skip_ref_cnt_vars[idx], + force_disable_gc); } PADDLE_ENFORCE_LT(static_cast(bid), program.Size()); auto& block = program.Block(bid); @@ -414,7 +419,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, int64_t max_memory_size = GetEagerDeletionThreshold(); std::unique_ptr gc; - if (max_memory_size >= 0) { + // FIXME(zjl): recurrent_op is rather complex, we would + // disable gc forcely in recurrent_op + if (!ctx->force_disable_gc_ && max_memory_size >= 0) { ctx->ResetReferenceCount(); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(place_)) { @@ -432,7 +439,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, #ifdef PADDLE_WITH_CUDA } #endif - if (gc && keep_kids) { + // If gc is enabled and block size > 1 + if (gc && ctx->prog_.Size() > 1) { operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_, ctx->ops_); } diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 5a040ac64..65cb9e51a 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -15,7 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include +#include #include #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/op_info.h" @@ -30,7 +32,8 @@ namespace framework { struct ExecutorPrepareContext { ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id, const std::vector& skip_ref_cnt_vars = - std::vector()); + std::vector(), + bool force_disable_gc = false); ~ExecutorPrepareContext(); @@ -38,6 +41,7 @@ struct ExecutorPrepareContext { const framework::ProgramDesc& prog_; size_t block_id_; + bool force_disable_gc_; std::vector> ops_; std::unordered_map global_ref_cnts_; @@ -66,7 +70,10 @@ class Executor { * Scope */ void Run(const ProgramDesc& prog, Scope* scope, int block_id, - bool create_local_scope = true, bool create_vars = true); + bool create_local_scope = true, bool create_vars = true, + const std::vector& skip_ref_cnt_vars = + std::vector(), + bool force_disable_gc = false); // This API is very slow. void Run(const ProgramDesc& program, Scope* scope, @@ -79,12 +86,14 @@ class Executor { static std::unique_ptr Prepare( const ProgramDesc& program, int block_id, const std::vector& skip_ref_cnt_vars = - std::vector()); + std::vector(), + bool force_disable_gc = false); static std::vector> Prepare( const ProgramDesc& program, const std::vector& block_ids, const std::vector>& skip_ref_cnt_vars = - std::vector>()); + std::vector>(), + bool force_disable_gc = false); void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id); diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index a1e02a3fd..eb39b3119 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -270,7 +270,9 @@ class RecurrentOp : public RecurrentBase { // Every inputs are linked now, execute! executor.Run(*program, &cur_scope, block->ID(), - false /*create_local_scope*/); + false /*create_local_scope*/, true /*create_vars*/, + std::vector() /*skip_ref_cnt_vars*/, + true /*force_disable_gc*/); // get device context from pool platform::DeviceContextPool &pool = @@ -385,7 +387,9 @@ class RecurrentGradOp : public RecurrentBase { VLOG(5) << "Recurrent memory linking finished "; // Run step block with cur_scope executor.Run(*program, &cur_scope, block->ID(), - false /*create_local_scope*/); + false /*create_local_scope*/, true /*create_vars*/, + std::vector() /*skip_ref_cnt_vars*/, + true /*force_disable_gc*/); VLOG(5) << "executor.Run finished "; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index cf59ff6d3..439d9aa83 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -876,9 +876,11 @@ All parameter, weight, gradient are variables in Paddle. .def(py::init()) .def("close", &Executor::Close) .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, - int block_id, bool create_local_scope, bool create_vars) { + int block_id, bool create_local_scope, bool create_vars, + const std::vector &fetch_vars) { pybind11::gil_scoped_release release; - self.Run(prog, scope, block_id, create_local_scope, create_vars); + self.Run(prog, scope, block_id, create_local_scope, create_vars, + fetch_vars); }); m.def("init_gflags", framework::InitGflags); diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index dfa50e721..cc3c0dd68 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -590,7 +590,7 @@ class Executor(object): fetch_var_name=fetch_var_name) self._feed_data(program, feed, feed_var_name, scope) - exe.run(program.desc, scope, 0, True, True) + exe.run(program.desc, scope, 0, True, True, fetch_var_name) outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) -- GitLab