未验证 提交 37a272e6 编写于 作者: Q Qiao Longfei 提交者: GitHub

add executor.prepare (#9022)

optimize executor.run
上级 30b70323
...@@ -14,12 +14,8 @@ limitations under the License. */ ...@@ -14,12 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include <set>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -40,14 +36,13 @@ namespace { ...@@ -40,14 +36,13 @@ namespace {
int kProgramId = -1; int kProgramId = -1;
} // namespace } // namespace
struct ExecutorPrepareContext { ExecutorPrepareContext::ExecutorPrepareContext(
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id) const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {} : prog_(prog), block_id_(block_id) {}
const framework::ProgramDesc& prog_; ExecutorPrepareContext::~ExecutorPrepareContext() {
size_t block_id_; VLOG(5) << "destroy ExecutorPrepareContext";
std::vector<std::unique_ptr<OperatorBase>> ops_; }
};
Executor::Executor(const platform::Place& place) : place_(place) {} Executor::Executor(const platform::Place& place) : place_(place) {}
...@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
auto* ctx = Prepare(pdesc, block_id); auto ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx, scope, create_local_scope, create_vars); RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
delete ctx;
} }
// Check whether the block already has feed operators and feed_holder. // Check whether the block already has feed operators and feed_holder.
...@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
int block_id) { const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id); auto* ctx = new ExecutorPrepareContext(program, block_id);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc)); ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
} }
return ctx; return std::unique_ptr<ExecutorPrepareContext>(ctx);
} }
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
......
...@@ -22,7 +22,16 @@ limitations under the License. */ ...@@ -22,7 +22,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct ExecutorPrepareContext;
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext();
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
class Executor { class Executor {
public: public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed // TODO(dzhwinter) : Do not rely on this function, it will be removed
...@@ -47,8 +56,8 @@ class Executor { ...@@ -47,8 +56,8 @@ class Executor {
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program, static std::unique_ptr<ExecutorPrepareContext> Prepare(
int block_id); const ProgramDesc& program, int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true, bool create_local_scope = true,
......
...@@ -235,66 +235,17 @@ class Executor(object): ...@@ -235,66 +235,17 @@ class Executor(object):
tensor.set_lod(lod) tensor.set_lod(lod)
return tensor return tensor
def run(self, def _get_program_cache(self, program_cache_key):
program=None, return self.program_caches.get(program_cache_key, None)
feed=None,
fetch_list=None,
feed_var_name='feed',
fetch_var_name='fetch',
scope=None,
return_numpy=True,
use_program_cache=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
"""
if feed is None:
feed = {}
if not isinstance(feed, dict):
raise TypeError("feed should be a map")
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
if not isinstance(program, Program): def _add_program_cache(self, program_cache_key, program):
raise TypeError() self.program_caches[program_cache_key] = program
if scope is None:
scope = global_scope()
program_cache = None
program_cache_key = get_program_cache_key(feed, fetch_list)
if use_program_cache: def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
# find program cache by cache_key fetch_var_name):
program_cache = self.program_caches.get(program_cache_key, None) tmp_program = program.clone()
# TODO(qiao): Should check program_cache and program are exactly the same.
else:
self.program_caches.pop(program_cache_key, None)
if program_cache is None:
program_cache = program.clone()
if use_program_cache:
self.program_caches[program_cache_key] = program_cache
global_block = program_cache.global_block() global_block = tmp_program.global_block()
if feed_var_name in global_block.vars: if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name) feed_var = global_block.var(feed_var_name)
...@@ -323,8 +274,7 @@ class Executor(object): ...@@ -323,8 +274,7 @@ class Executor(object):
attrs={'col': i}) attrs={'col': i})
# append fetch_operators # append fetch_operators
if not has_fetch_operators(global_block, fetch_list, if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), ( assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var))) "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
...@@ -334,8 +284,11 @@ class Executor(object): ...@@ -334,8 +284,11 @@ class Executor(object):
outputs={'Out': [fetch_var]}, outputs={'Out': [fetch_var]},
attrs={'col': i}) attrs={'col': i})
return tmp_program
def _feed_data(self, program, feed, feed_var_name, scope):
# feed var to framework # feed var to framework
for op in program_cache.global_block().ops: for op in program.global_block().ops:
if op.desc.type() == 'feed': if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0] feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name] cur_feed = feed[feed_target_name]
...@@ -346,11 +299,79 @@ class Executor(object): ...@@ -346,11 +299,79 @@ class Executor(object):
else: else:
break break
self.executor.run(program_cache.desc, scope, 0, True, True) def _fetch_data(self, fetch_list, fetch_var_name, scope):
outs = [ outs = [
core.get_fetch_variable(scope, fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list)) for i in xrange(len(fetch_list))
] ]
return outs
def run(self,
program=None,
feed=None,
fetch_list=None,
feed_var_name='feed',
fetch_var_name='fetch',
scope=None,
return_numpy=True,
use_program_cache=False):
""" Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run. Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
:param program: the program that need to run, if not provied, then default_main_program will be used.
:param feed: feed variable map, e.g. {"image": ImageData, "label": LableData}
:param fetch_list: a list of variable or variable names that user want to get, run will return them according
to this list.
:param feed_var_name: the name for the input variable of feed Operator.
:param fetch_var_name: the name for the output variable of feed Operator.
:param scope: the scope used to run this program, you can switch it to different scope. default is global_scope
:param return_numpy: if convert the fetched tensor to numpy
:param use_program_cache: set use_program_cache to true if program not changed compare to the last step.
:return: result according to fetch_list.
"""
if feed is None:
feed = {}
if not isinstance(feed, dict):
raise TypeError("feed should be a map")
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
if not isinstance(program, Program):
raise TypeError()
if scope is None:
scope = global_scope()
cache_key = get_program_cache_key(feed, fetch_list)
if use_program_cache:
cached_program = self._get_program_cache(cache_key)
if cached_program is None:
cached_program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program)
program = cached_program
else:
self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope)
self.executor.run(program.desc, scope, 0, True, True)
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
return outs return outs
...@@ -16,7 +16,6 @@ import unittest ...@@ -16,7 +16,6 @@ import unittest
import numpy import numpy
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.executor import Executor from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data from paddle.fluid.layers import mul, data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册