diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 7155d5ef2febc20aaa684c04a7a59f781857c9e5..a688115b11af164319458207b19e915e8eaf676a 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -14,12 +14,8 @@ limitations under the License. */ #include "paddle/fluid/framework/executor.h" -#include - -#include "gflags/gflags.h" #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/feed_fetch_method.h" -#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" @@ -40,14 +36,13 @@ namespace { int kProgramId = -1; } // namespace -struct ExecutorPrepareContext { - ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id) - : prog_(prog), block_id_(block_id) {} +ExecutorPrepareContext::ExecutorPrepareContext( + const framework::ProgramDesc& prog, size_t block_id) + : prog_(prog), block_id_(block_id) {} - const framework::ProgramDesc& prog_; - size_t block_id_; - std::vector> ops_; -}; +ExecutorPrepareContext::~ExecutorPrepareContext() { + VLOG(5) << "destroy ExecutorPrepareContext"; +} Executor::Executor(const platform::Place& place) : place_(place) {} @@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { platform::RecordBlock b(block_id); - auto* ctx = Prepare(pdesc, block_id); - RunPreparedContext(ctx, scope, create_local_scope, create_vars); - delete ctx; + auto ctx = Prepare(pdesc, block_id); + RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars); } // Check whether the block already has feed operators and feed_holder. @@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } -ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, - int block_id) { +std::unique_ptr Executor::Prepare( + const ProgramDesc& program, int block_id) { auto* ctx = new ExecutorPrepareContext(program, block_id); PADDLE_ENFORCE_LT(static_cast(block_id), program.Size()); auto& block = program.Block(block_id); for (auto& op_desc : block.AllOps()) { ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); } - return ctx; + return std::unique_ptr(ctx); } void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 28ce3315154cea45412984df4daf7385ce2cf572..fb29c70f1456eca7b46e779f737976f5f2da0682 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -22,7 +22,16 @@ limitations under the License. */ namespace paddle { namespace framework { -struct ExecutorPrepareContext; + +struct ExecutorPrepareContext { + ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id); + ~ExecutorPrepareContext(); + + const framework::ProgramDesc& prog_; + size_t block_id_; + std::vector> ops_; +}; + class Executor { public: // TODO(dzhwinter) : Do not rely on this function, it will be removed @@ -47,8 +56,8 @@ class Executor { const std::string& feed_holder_name = "feed", const std::string& fetch_holder_name = "fetch"); - static ExecutorPrepareContext* Prepare(const ProgramDesc& program, - int block_id); + static std::unique_ptr Prepare( + const ProgramDesc& program, int block_id); void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 4490f2bf153f672464ec8bca2a44109c9fe0dd04..2612fb1ae41986ae0d5c6e942cc3accebcb00e19 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -235,6 +235,77 @@ class Executor(object): tensor.set_lod(lod) return tensor + def _get_program_cache(self, program_cache_key): + return self.program_caches.get(program_cache_key, None) + + def _add_program_cache(self, program_cache_key, program): + self.program_caches[program_cache_key] = program + + def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, + fetch_var_name): + tmp_program = program.clone() + + global_block = tmp_program.global_block() + + if feed_var_name in global_block.vars: + feed_var = global_block.var(feed_var_name) + else: + feed_var = global_block.create_var( + name=feed_var_name, + type=core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True) + + if fetch_var_name in global_block.vars: + fetch_var = global_block.var(fetch_var_name) + else: + fetch_var = global_block.create_var( + name=fetch_var_name, + type=core.VarDesc.VarType.FETCH_LIST, + persistable=True) + + # prepend feed operators + if not has_feed_operators(global_block, feed, feed_var_name): + for i, name in enumerate(feed): + out = global_block.var(name) + global_block.prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + # append fetch_operators + if not has_fetch_operators(global_block, fetch_list, fetch_var_name): + for i, var in enumerate(fetch_list): + assert isinstance(var, Variable) or isinstance(var, str), ( + "Wrong type for fetch_list[%s]: %s" % (i, type(var))) + global_block.append_op( + type='fetch', + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}) + + return tmp_program + + def _feed_data(self, program, feed, feed_var_name, scope): + # feed var to framework + for op in program.global_block().ops: + if op.desc.type() == 'feed': + feed_target_name = op.desc.output('Out')[0] + cur_feed = feed[feed_target_name] + if not isinstance(cur_feed, core.LoDTensor): + cur_feed = self.aslodtensor(cur_feed) + idx = op.desc.attr('col') + core.set_feed_variable(scope, cur_feed, feed_var_name, idx) + else: + break + + def _fetch_data(self, fetch_list, fetch_var_name, scope): + outs = [ + core.get_fetch_variable(scope, fetch_var_name, i) + for i in xrange(len(fetch_list)) + ] + return outs + def run(self, program=None, feed=None, @@ -268,7 +339,6 @@ class Executor(object): raise TypeError("feed should be a map") if fetch_list is None: fetch_list = [] - if program is None: program = default_main_program() @@ -278,79 +348,30 @@ class Executor(object): if scope is None: scope = global_scope() - program_cache = None - program_cache_key = get_program_cache_key(feed, fetch_list) - + cache_key = get_program_cache_key(feed, fetch_list) if use_program_cache: - # find program cache by cache_key - program_cache = self.program_caches.get(program_cache_key, None) - # TODO(qiao): Should check program_cache and program are exactly the same. + cached_program = self._get_program_cache(cache_key) + if cached_program is None: + cached_program = self._add_feed_fetch_ops( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name) + self._add_program_cache(cache_key, cached_program) + program = cached_program else: - self.program_caches.pop(program_cache_key, None) - - if program_cache is None: - program_cache = program.clone() - - if use_program_cache: - self.program_caches[program_cache_key] = program_cache - - global_block = program_cache.global_block() - - if feed_var_name in global_block.vars: - feed_var = global_block.var(feed_var_name) - else: - feed_var = global_block.create_var( - name=feed_var_name, - type=core.VarDesc.VarType.FEED_MINIBATCH, - persistable=True) - - if fetch_var_name in global_block.vars: - fetch_var = global_block.var(fetch_var_name) - else: - fetch_var = global_block.create_var( - name=fetch_var_name, - type=core.VarDesc.VarType.FETCH_LIST, - persistable=True) - - # prepend feed operators - if not has_feed_operators(global_block, feed, feed_var_name): - for i, name in enumerate(feed): - out = global_block.var(name) - global_block.prepend_op( - type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) - - # append fetch_operators - if not has_fetch_operators(global_block, fetch_list, - fetch_var_name): - for i, var in enumerate(fetch_list): - assert isinstance(var, Variable) or isinstance(var, str), ( - "Wrong type for fetch_list[%s]: %s" % (i, type(var))) - global_block.append_op( - type='fetch', - inputs={'X': [var]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) - - # feed var to framework - for op in program_cache.global_block().ops: - if op.desc.type() == 'feed': - feed_target_name = op.desc.output('Out')[0] - cur_feed = feed[feed_target_name] - if not isinstance(cur_feed, core.LoDTensor): - cur_feed = self.aslodtensor(cur_feed) - idx = op.desc.attr('col') - core.set_feed_variable(scope, cur_feed, feed_var_name, idx) - else: - break - - self.executor.run(program_cache.desc, scope, 0, True, True) - outs = [ - core.get_fetch_variable(scope, fetch_var_name, i) - for i in xrange(len(fetch_list)) - ] + self.program_caches.pop(cache_key, None) + program = self._add_feed_fetch_ops( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name) + + self._feed_data(program, feed, feed_var_name, scope) + self.executor.run(program.desc, scope, 0, True, True) + outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) return outs diff --git a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py index 4958bef3ef4d101f934a2776efc21efdd24a9a4d..e1272c1d6dd7131b55ecf33fa0de0fc78a3ac5a7 100644 --- a/python/paddle/fluid/tests/unittests/test_executor_and_mul.py +++ b/python/paddle/fluid/tests/unittests/test_executor_and_mul.py @@ -16,7 +16,6 @@ import unittest import numpy import paddle.fluid.core as core - from paddle.fluid.executor import Executor from paddle.fluid.layers import mul, data