Unverified · Commit 37a272e6 authored by Qiao Longfei, committed by GitHub

add executor.prepare (#9022)

optimize executor.run
Parent 30b70323
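For orientation, the user-facing path this commit refactors is `fluid.Executor.run`: it clones the program, inserts feed/fetch operators, copies the feed data into the scope, runs block 0 through the C++ executor, and pulls the fetched variables back out. The sketch below is illustrative only; the tiny `mul` network and the names `a`, `b`, `out` are assumptions chosen to mirror the imports in the touched unit test, not part of this commit.

import numpy
import paddle.fluid as fluid
from paddle.fluid.layers import data, mul

# Illustrative network: out = a x b (a single matrix multiply, no parameters).
a = data(name='a', shape=[784], dtype='float32')
b = data(name='b', shape=[784, 100], dtype='float32', append_batch_size=False)
out = mul(x=a, y=b)

exe = fluid.Executor(fluid.CPUPlace())
a_np = numpy.random.random(size=(100, 784)).astype('float32')
b_np = numpy.random.random(size=(784, 100)).astype('float32')

# One call: feed numpy inputs, execute block 0, fetch `out` as a numpy array.
out_np, = exe.run(fluid.default_main_program(),
                  feed={'a': a_np, 'b': b_np},
                  fetch_list=[out])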
@@ -14,12 +14,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/executor.h"
-#include <set>
-#include "gflags/gflags.h"
 #include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
-#include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -40,14 +36,13 @@ namespace {
 int kProgramId = -1;
 }  // namespace
 
-struct ExecutorPrepareContext {
-  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
-      : prog_(prog), block_id_(block_id) {}
-  const framework::ProgramDesc& prog_;
-  size_t block_id_;
-  std::vector<std::unique_ptr<OperatorBase>> ops_;
-};
+ExecutorPrepareContext::ExecutorPrepareContext(
+    const framework::ProgramDesc& prog, size_t block_id)
+    : prog_(prog), block_id_(block_id) {}
+
+ExecutorPrepareContext::~ExecutorPrepareContext() {
+  VLOG(5) << "destroy ExecutorPrepareContext";
+}
 
 Executor::Executor(const platform::Place& place) : place_(place) {}
@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars) {
   platform::RecordBlock b(block_id);
-  auto* ctx = Prepare(pdesc, block_id);
-  RunPreparedContext(ctx, scope, create_local_scope, create_vars);
-  delete ctx;
+  auto ctx = Prepare(pdesc, block_id);
+  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
 }
 
 // Check whether the block already has feed operators and feed_holder.
@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
   }
 }
 
-ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
-                                          int block_id) {
+std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
+    const ProgramDesc& program, int block_id) {
   auto* ctx = new ExecutorPrepareContext(program, block_id);
   PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
   auto& block = program.Block(block_id);
   for (auto& op_desc : block.AllOps()) {
     ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
-  return ctx;
+  return std::unique_ptr<ExecutorPrepareContext>(ctx);
 }
 
 void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
......
@@ -22,7 +22,16 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-struct ExecutorPrepareContext;
+struct ExecutorPrepareContext {
+  ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
+  ~ExecutorPrepareContext();
+
+  const framework::ProgramDesc& prog_;
+  size_t block_id_;
+  std::vector<std::unique_ptr<OperatorBase>> ops_;
+};
+
 class Executor {
  public:
   // TODO(dzhwinter) : Do not rely on this function, it will be removed
@@ -47,8 +56,8 @@ class Executor {
                    const std::string& feed_holder_name = "feed",
                    const std::string& fetch_holder_name = "fetch");
 
-  static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
-                                         int block_id);
+  static std::unique_ptr<ExecutorPrepareContext> Prepare(
+      const ProgramDesc& program, int block_id);
 
   void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
                           bool create_local_scope = true,
......
@@ -235,6 +235,77 @@ class Executor(object):
             tensor.set_lod(lod)
         return tensor
 
+    def _get_program_cache(self, program_cache_key):
+        return self.program_caches.get(program_cache_key, None)
+
+    def _add_program_cache(self, program_cache_key, program):
+        self.program_caches[program_cache_key] = program
+
+    def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
+                            fetch_var_name):
+        tmp_program = program.clone()
+
+        global_block = tmp_program.global_block()
+
+        if feed_var_name in global_block.vars:
+            feed_var = global_block.var(feed_var_name)
+        else:
+            feed_var = global_block.create_var(
+                name=feed_var_name,
+                type=core.VarDesc.VarType.FEED_MINIBATCH,
+                persistable=True)
+
+        if fetch_var_name in global_block.vars:
+            fetch_var = global_block.var(fetch_var_name)
+        else:
+            fetch_var = global_block.create_var(
+                name=fetch_var_name,
+                type=core.VarDesc.VarType.FETCH_LIST,
+                persistable=True)
+
+        # prepend feed operators
+        if not has_feed_operators(global_block, feed, feed_var_name):
+            for i, name in enumerate(feed):
+                out = global_block.var(name)
+                global_block.prepend_op(
+                    type='feed',
+                    inputs={'X': [feed_var]},
+                    outputs={'Out': [out]},
+                    attrs={'col': i})
+
+        # append fetch_operators
+        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
+            for i, var in enumerate(fetch_list):
+                assert isinstance(var, Variable) or isinstance(var, str), (
+                    "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
+                global_block.append_op(
+                    type='fetch',
+                    inputs={'X': [var]},
+                    outputs={'Out': [fetch_var]},
+                    attrs={'col': i})
+
+        return tmp_program
+
+    def _feed_data(self, program, feed, feed_var_name, scope):
+        # feed var to framework
+        for op in program.global_block().ops:
+            if op.desc.type() == 'feed':
+                feed_target_name = op.desc.output('Out')[0]
+                cur_feed = feed[feed_target_name]
+                if not isinstance(cur_feed, core.LoDTensor):
+                    cur_feed = self.aslodtensor(cur_feed)
+                idx = op.desc.attr('col')
+                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
+            else:
+                break
+
+    def _fetch_data(self, fetch_list, fetch_var_name, scope):
+        outs = [
+            core.get_fetch_variable(scope, fetch_var_name, i)
+            for i in xrange(len(fetch_list))
+        ]
+        return outs
+
     def run(self,
             program=None,
             feed=None,
@@ -268,7 +339,6 @@ class Executor(object):
             raise TypeError("feed should be a map")
         if fetch_list is None:
             fetch_list = []
-
         if program is None:
             program = default_main_program()
@@ -278,79 +348,30 @@ class Executor(object):
         if scope is None:
             scope = global_scope()
 
-        program_cache = None
-        program_cache_key = get_program_cache_key(feed, fetch_list)
-
+        cache_key = get_program_cache_key(feed, fetch_list)
         if use_program_cache:
-            # find program cache by cache_key
-            program_cache = self.program_caches.get(program_cache_key, None)
-            # TODO(qiao): Should check program_cache and program are exactly the same.
+            cached_program = self._get_program_cache(cache_key)
+            if cached_program is None:
+                cached_program = self._add_feed_fetch_ops(
+                    program=program,
+                    feed=feed,
+                    fetch_list=fetch_list,
+                    feed_var_name=feed_var_name,
+                    fetch_var_name=fetch_var_name)
+                self._add_program_cache(cache_key, cached_program)
+            program = cached_program
         else:
-            self.program_caches.pop(program_cache_key, None)
-
-        if program_cache is None:
-            program_cache = program.clone()
-
-            if use_program_cache:
-                self.program_caches[program_cache_key] = program_cache
-
-            global_block = program_cache.global_block()
-
-            if feed_var_name in global_block.vars:
-                feed_var = global_block.var(feed_var_name)
-            else:
-                feed_var = global_block.create_var(
-                    name=feed_var_name,
-                    type=core.VarDesc.VarType.FEED_MINIBATCH,
-                    persistable=True)
-
-            if fetch_var_name in global_block.vars:
-                fetch_var = global_block.var(fetch_var_name)
-            else:
-                fetch_var = global_block.create_var(
-                    name=fetch_var_name,
-                    type=core.VarDesc.VarType.FETCH_LIST,
-                    persistable=True)
-
-            # prepend feed operators
-            if not has_feed_operators(global_block, feed, feed_var_name):
-                for i, name in enumerate(feed):
-                    out = global_block.var(name)
-                    global_block.prepend_op(
-                        type='feed',
-                        inputs={'X': [feed_var]},
-                        outputs={'Out': [out]},
-                        attrs={'col': i})
-
-            # append fetch_operators
-            if not has_fetch_operators(global_block, fetch_list,
-                                       fetch_var_name):
-                for i, var in enumerate(fetch_list):
-                    assert isinstance(var, Variable) or isinstance(var, str), (
-                        "Wrong type for fetch_list[%s]: %s" % (i, type(var)))
-                    global_block.append_op(
-                        type='fetch',
-                        inputs={'X': [var]},
-                        outputs={'Out': [fetch_var]},
-                        attrs={'col': i})
-
-        # feed var to framework
-        for op in program_cache.global_block().ops:
-            if op.desc.type() == 'feed':
-                feed_target_name = op.desc.output('Out')[0]
-                cur_feed = feed[feed_target_name]
-                if not isinstance(cur_feed, core.LoDTensor):
-                    cur_feed = self.aslodtensor(cur_feed)
-                idx = op.desc.attr('col')
-                core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
-            else:
-                break
-
-        self.executor.run(program_cache.desc, scope, 0, True, True)
-        outs = [
-            core.get_fetch_variable(scope, fetch_var_name, i)
-            for i in xrange(len(fetch_list))
-        ]
-
+            self.program_caches.pop(cache_key, None)
+            program = self._add_feed_fetch_ops(
+                program=program,
+                feed=feed,
+                fetch_list=fetch_list,
+                feed_var_name=feed_var_name,
+                fetch_var_name=fetch_var_name)
+
+        self._feed_data(program, feed, feed_var_name, scope)
+        self.executor.run(program.desc, scope, 0, True, True)
+        outs = self._fetch_data(fetch_list, fetch_var_name, scope)
         if return_numpy:
             outs = as_numpy(outs)
         return outs
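A hedged usage sketch of the cached path above, reusing `exe`, `a_np`, `b_np`, and `out` from the sketch near the top of this page (`use_program_cache` is a pre-existing flag that this commit reroutes through the new helpers): the first call builds and caches the feed/fetch-augmented clone via `_add_feed_fetch_ops`, and later calls with the same feed/fetch keys only repeat `_feed_data`, the C++ run, and `_fetch_data`.

# Repeated calls with identical feed/fetch keys hit the program cache keyed by
# get_program_cache_key(feed, fetch_list), so the program clone and the
# feed/fetch operator insertion happen only on the first iteration.
for _ in range(10):
    out_np, = exe.run(fluid.default_main_program(),
                      feed={'a': a_np, 'b': b_np},
                      fetch_list=[out],
                      use_program_cache=True)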
@@ -16,7 +16,6 @@ import unittest
 import numpy
 
 import paddle.fluid.core as core
 from paddle.fluid.executor import Executor
 from paddle.fluid.layers import mul, data
......
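To make the program transformation concrete, here is a minimal inspection sketch of what `_add_feed_fetch_ops` produces. It calls the private helper directly, purely for illustration, and the expected op list assumes the single-`mul` network from the first sketch; none of this is part of the commit itself.

# Inspect the augmented clone; `exe`, `a_np`, `b_np`, `out` come from the
# earlier illustrative sketch.
augmented = exe._add_feed_fetch_ops(
    program=fluid.default_main_program(),
    feed={'a': a_np, 'b': b_np},
    fetch_list=[out],
    feed_var_name='feed',
    fetch_var_name='fetch')
print([op.desc.type() for op in augmented.global_block().ops])
# Illustrative output: one prepended 'feed' op per input, one appended
# 'fetch' op per fetched variable, e.g. ['feed', 'feed', 'mul', 'fetch']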