未验证 提交 37a272e6 编写于 作者: Q Qiao Longfei 提交者: GitHub

add executor.prepare (#9022)

optimize executor.run
上级 30b70323
......@@ -14,12 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include <set>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -40,14 +36,13 @@ namespace {
int kProgramId = -1;
} // namespace
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
ExecutorPrepareContext::ExecutorPrepareContext(
const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
ExecutorPrepareContext::~ExecutorPrepareContext() {
VLOG(5) << "destroy ExecutorPrepareContext";
}
Executor::Executor(const platform::Place& place) : place_(place) {}
......@@ -101,9 +96,8 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
auto* ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
delete ctx;
auto ctx = Prepare(pdesc, block_id);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
}
// Check whether the block already has feed operators and feed_holder.
......@@ -274,15 +268,15 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
int block_id) {
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
const ProgramDesc& program, int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
return ctx;
return std::unique_ptr<ExecutorPrepareContext>(ctx);
}
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
......
......@@ -22,7 +22,16 @@ limitations under the License. */
namespace paddle {
namespace framework {
struct ExecutorPrepareContext;
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id);
~ExecutorPrepareContext();
const framework::ProgramDesc& prog_;
size_t block_id_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
......@@ -47,8 +56,8 @@ class Executor {
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
int block_id);
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
......
......@@ -235,6 +235,77 @@ class Executor(object):
tensor.set_lod(lod)
return tensor
def _get_program_cache(self, program_cache_key):
return self.program_caches.get(program_cache_key, None)
def _add_program_cache(self, program_cache_key, program):
self.program_caches[program_cache_key] = program
def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
fetch_var_name):
tmp_program = program.clone()
global_block = tmp_program.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
return tmp_program
def _feed_data(self, program, feed, feed_var_name, scope):
# feed var to framework
for op in program.global_block().ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
def _fetch_data(self, fetch_list, fetch_var_name, scope):
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
]
return outs
def run(self,
program=None,
feed=None,
......@@ -268,7 +339,6 @@ class Executor(object):
raise TypeError("feed should be a map")
if fetch_list is None:
fetch_list = []
if program is None:
program = default_main_program()
......@@ -278,79 +348,30 @@ class Executor(object):
if scope is None:
scope = global_scope()
program_cache = None
program_cache_key = get_program_cache_key(feed, fetch_list)
cache_key = get_program_cache_key(feed, fetch_list)
if use_program_cache:
# find program cache by cache_key
program_cache = self.program_caches.get(program_cache_key, None)
# TODO(qiao): Should check program_cache and program are exactly the same.
cached_program = self._get_program_cache(cache_key)
if cached_program is None:
cached_program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._add_program_cache(cache_key, cached_program)
program = cached_program
else:
self.program_caches.pop(program_cache_key, None)
if program_cache is None:
program_cache = program.clone()
if use_program_cache:
self.program_caches[program_cache_key] = program_cache
global_block = program_cache.global_block()
if feed_var_name in global_block.vars:
feed_var = global_block.var(feed_var_name)
else:
feed_var = global_block.create_var(
name=feed_var_name,
type=core.VarDesc.VarType.FEED_MINIBATCH,
persistable=True)
if fetch_var_name in global_block.vars:
fetch_var = global_block.var(fetch_var_name)
else:
fetch_var = global_block.create_var(
name=fetch_var_name,
type=core.VarDesc.VarType.FETCH_LIST,
persistable=True)
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list,
fetch_var_name):
for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
# feed var to framework
for op in program_cache.global_block().ops:
if op.desc.type() == 'feed':
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
break
self.executor.run(program_cache.desc, scope, 0, True, True)
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
]
self.program_caches.pop(cache_key, None)
program = self._add_feed_fetch_ops(
program=program,
feed=feed,
fetch_list=fetch_list,
feed_var_name=feed_var_name,
fetch_var_name=fetch_var_name)
self._feed_data(program, feed, feed_var_name, scope)
self.executor.run(program.desc, scope, 0, True, True)
outs = self._fetch_data(fetch_list, fetch_var_name, scope)
if return_numpy:
outs = as_numpy(outs)
return outs
......@@ -16,7 +16,6 @@ import unittest
import numpy
import paddle.fluid.core as core
from paddle.fluid.executor import Executor
from paddle.fluid.layers import mul, data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册