未验证 提交 7f8bc49d 编写于 作者: G guru4elephant 提交者: GitHub

polish_executor_and_add_ctx_cache (#17536)

* polish_executor_and_add_ctx_cache
上级 7ae461eb
...@@ -244,6 +244,25 @@ static bool has_fetch_operators( ...@@ -244,6 +244,25 @@ static bool has_fetch_operators(
return fetch_count > 0; return fetch_count > 0;
} }
std::unique_ptr<ExecutorPrepareContext> Executor::PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
std::unique_ptr<ExecutorPrepareContext> ctx;
ctx.reset(new ExecutorPrepareContext(program, block_id));
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
#ifdef PADDLE_WITH_NGRAPH
if (FLAGS_use_ngraph) {
paddle::operators::NgraphEngine::FuseNgraphOps(
ctx->prog_.Block(ctx->block_id_), &ctx->ops_);
}
#endif
ctx->PrepareUnusedVars(skip_ref_cnt_vars, force_disable_gc);
return ctx;
}
void Executor::Run(const ProgramDesc& program, Scope* scope, void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets, std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets, std::map<std::string, LoDTensor*>* fetch_targets,
...@@ -368,6 +387,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -368,6 +387,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars, bool create_local_scope, bool create_vars,
bool keep_kids) { bool keep_kids) {
platform::RecordBlock b(kProgramId);
PADDLE_ENFORCE_NOT_NULL(scope); PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
...@@ -407,7 +427,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -407,7 +427,6 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
if (gc) { if (gc) {
DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get()); DeleteUnusedTensors(*local_scope, op.get(), ctx->unused_vars_, gc.get());
} }
......
...@@ -83,6 +83,21 @@ class Executor { ...@@ -83,6 +83,21 @@ class Executor {
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
std::unique_ptr<ExecutorPrepareContext> PrepareCtxCache(
const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id, const ProgramDesc& program, int block_id,
const std::vector<std::string>& skip_ref_cnt_vars = const std::vector<std::string>& skip_ref_cnt_vars =
...@@ -101,15 +116,6 @@ class Executor { ...@@ -101,15 +116,6 @@ class Executor {
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true, bool keep_kids = false); bool create_vars = true, bool keep_kids = false);
// This API is very slow.
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromDataset(const ProgramDesc& main_program, Scope* scope, void RunFromDataset(const ProgramDesc& main_program, Scope* scope,
......
...@@ -24,7 +24,7 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) { ...@@ -24,7 +24,7 @@ void HogwildWorker::Initialize(const TrainerDesc& desc) {
fetch_config_ = desc.fetch_config(); fetch_config_ = desc.fetch_config();
param_ = desc.hogwild_param(); param_ = desc.hogwild_param();
skip_ops_.resize(param_.skip_ops_size()); skip_ops_.resize(param_.skip_ops_size());
for (size_t i = 0; i < param_.skip_ops_size(); ++i) { for (int i = 0; i < param_.skip_ops_size(); ++i) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
use_cvm_ = desc.use_cvm(); use_cvm_ = desc.use_cvm();
......
...@@ -1032,10 +1032,28 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1032,10 +1032,28 @@ All parameter, weight, gradient are variables in Paddle.
[](const OperatorBase &op) { return op.OutputVars(false); }) [](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU); .def("support_gpu", &OperatorBase::SupportGPU);
py::class_<framework::ExecutorPrepareContext>(m, "ExecutorPrepareContext")
.def(py::init<const ProgramDesc &, size_t>());
py::class_<framework::Executor>(m, "Executor") py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>()) .def(py::init<const platform::Place &>())
.def("close", &Executor::Close) .def("close", &Executor::Close)
.def("run_from_dataset", &Executor::RunFromDataset) .def("run_from_dataset", &Executor::RunFromDataset,
py::call_guard<py::gil_scoped_release>())
.def("run_prepared_ctx",
[](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
std::map<std::string, const LoDTensor *> *feed_targets,
std::map<std::string, LoDTensor *> *fetch_targets,
bool create_local_scope = true, bool create_vars = true,
const std::string &feed_holder_name = "feed",
const std::string &fetch_holder_name = "fetch") {
pybind11::gil_scoped_release release;
self.RunPreparedContext(ctx, scope, feed_targets, fetch_targets,
create_local_scope, create_vars,
feed_holder_name, fetch_holder_name);
})
.def("prepare_ctx_cache", &Executor::PrepareCtxCache,
py::call_guard<py::gil_scoped_release>())
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope, .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars, int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) { const std::vector<std::string> &fetch_vars) {
......
...@@ -247,6 +247,10 @@ def _to_name_str(var): ...@@ -247,6 +247,10 @@ def _to_name_str(var):
raise TypeError(str(var) + " should be Variable or str") raise TypeError(str(var) + " should be Variable or str")
def _get_strong_program_cache_key(program, feed, fetch_list):
return str(id(program)) + _get_program_cache_key(feed, fetch_list)
def _get_program_cache_key(feed, fetch_list): def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys()) feed_var_names = list(feed.keys())
fetch_var_names = list(map(_to_name_str, fetch_list)) fetch_var_names = list(map(_to_name_str, fetch_list))
...@@ -356,17 +360,24 @@ class Executor(object): ...@@ -356,17 +360,24 @@ class Executor(object):
def __init__(self, place): def __init__(self, place):
self.place = place self.place = place
self.program_caches = dict() self.program_caches = dict()
self.ctx_caches = dict()
p = core.Place() p = core.Place()
p.set_place(self.place) p.set_place(self.place)
self._default_executor = core.Executor(p) self._default_executor = core.Executor(p)
self._closed = False self._closed = False
def _get_ctx_cache(self, program_cache_key):
return self.ctx_caches.get(program_cache_key, None)
def _get_program_cache(self, program_cache_key): def _get_program_cache(self, program_cache_key):
return self.program_caches.get(program_cache_key, None) return self.program_caches.get(program_cache_key, None)
def _add_program_cache(self, program_cache_key, program): def _add_program_cache(self, program_cache_key, program):
self.program_caches[program_cache_key] = program self.program_caches[program_cache_key] = program
def _add_ctx_cache(self, ctx_cache_key, ctx):
self.ctx_caches[ctx_cache_key] = ctx
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name, def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name): fetch_var_name):
tmp_program = program.clone() tmp_program = program.clone()
...@@ -645,6 +656,7 @@ class Executor(object): ...@@ -645,6 +656,7 @@ class Executor(object):
# performance. # performance.
# TODO(panyx0718): executor should be able to run graph. # TODO(panyx0718): executor should be able to run graph.
assert program._program, "CompiledProgram is compiled from graph, can only run with_data_parallel." assert program._program, "CompiledProgram is compiled from graph, can only run with_data_parallel."
# use_program_cache is not valid with CompiledProgram
return self._run( return self._run(
program._program, program._program,
self._default_executor, self._default_executor,
...@@ -654,7 +666,7 @@ class Executor(object): ...@@ -654,7 +666,7 @@ class Executor(object):
fetch_var_name=fetch_var_name, fetch_var_name=fetch_var_name,
scope=scope, scope=scope,
return_numpy=return_numpy, return_numpy=return_numpy,
use_program_cache=use_program_cache) use_program_cache=False)
def _run(self, program, exe, feed, fetch_list, feed_var_name, def _run(self, program, exe, feed, fetch_list, feed_var_name,
fetch_var_name, scope, return_numpy, use_program_cache): fetch_var_name, scope, return_numpy, use_program_cache):
...@@ -677,9 +689,10 @@ class Executor(object): ...@@ -677,9 +689,10 @@ class Executor(object):
"Executor requires Program as its Parameter. But you passed in %s" "Executor requires Program as its Parameter. But you passed in %s"
% (type(program))) % (type(program)))
cache_key = _get_program_cache_key(feed, fetch_list) cache_key = _get_strong_program_cache_key(program, feed, fetch_list)
if use_program_cache: if use_program_cache:
cached_program = self._get_program_cache(cache_key) cached_program = self._get_program_cache(cache_key)
cached_ctx = self._get_ctx_cache(cache_key)
if cached_program is None: if cached_program is None:
cached_program = self._add_feed_fetch_ops( cached_program = self._add_feed_fetch_ops(
program=program, program=program,
...@@ -688,7 +701,11 @@ class Executor(object): ...@@ -688,7 +701,11 @@ class Executor(object):
feed_var_name=feed_var_name, feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name) fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program) self._add_program_cache(cache_key, cached_program)
cached_ctx = self._default_executor.prepare_ctx_cache(
cached_program.desc, 0, fetch_list, False)
self._add_ctx_cache(cache_key, cached_ctx)
program = cached_program program = cached_program
ctx = cached_ctx
else: else:
self.program_caches.pop(cache_key, None) self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops( program = self._add_feed_fetch_ops(
...@@ -699,7 +716,10 @@ class Executor(object): ...@@ -699,7 +716,10 @@ class Executor(object):
fetch_var_name=fetch_var_name) fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope) self._feed_data(program, feed, feed_var_name, scope)
if not use_program_cache:
exe.run(program.desc, scope, 0, True, True, fetch_var_name) exe.run(program.desc, scope, 0, True, True, fetch_var_name)
else:
exe.run_prepared_ctx(ctx, scope, True, True, False)
outs = self._fetch_data(fetch_list, fetch_var_name, scope) outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册