Unverified commit 6d6642c8, authored by Leo Chen and committed by GitHub

[new-exec] Refine standalone executor (#37278)

* init

* add feed ops on the Python side

* import LRScheduler

* update_feed

* refine code format
Parent ce3ee9bb
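In short, this change moves feed handling to the Python side: `Executor.run` now inserts feed ops into the cloned program, writes the feed data into the scope via `_feed_data`, syncs any `LRScheduler` value, and then calls the new name-only `StandaloneExecutor::Run(feed_names, fetch_names)` overload. A minimal user-level sketch of exercising this path follows; the `FLAGS_USE_STANDALONE_EXECUTOR` switch and the toy network are assumptions for illustration, not part of this diff.

```python
# Sketch: drive a static-graph program through the refactored standalone
# executor. FLAGS_USE_STANDALONE_EXECUTOR is an assumed opt-in switch
# (mirroring the related unit tests), not something introduced by this diff.
import os
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'

import numpy as np
import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
    out = paddle.static.nn.fc(x, size=2)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_program)

# Executor.run adds the feed ops itself, writes the feed data into the scope,
# and passes only the feed/fetch *names* down to the standalone executor.
result, = exe.run(main_program,
                  feed={'x': np.random.rand(8, 4).astype('float32')},
                  fetch_list=[out])
print(result.shape)  # (8, 2)
```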
@@ -77,6 +77,24 @@ paddle::framework::FetchList InterpreterCore::Run(
  return *(fetch_var->GetMutable<framework::FetchList>());
}

paddle::framework::FetchList InterpreterCore::Run() {
  if (!is_build_) {
    paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
    paddle::framework::interpreter::build_op_func_list(
        place_, block_, &op_func_nodes, global_scope_);
    is_build_ = true;
    // convert vec func_list to graph
    Convert(&op_func_nodes);
  } else {
    ExecuteInstructionList(vec_instruction_);
  }

  // return Fetch Tensors
  auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
  return *(fetch_var->GetMutable<framework::FetchList>());
}
void InterpreterCore::BuildOperatorDependences() {
  // analyze the dependencies between ops, set the dependecy_count_ and call
  // Schedule
@@ -505,6 +523,7 @@ void InterpreterCore::Prepare(
                        feed_names.size(), feed_tensors.size()));

  auto FeedInput = [&] {
    VLOG(4) << "Feed inputs";
    for (size_t i = 0; i < feed_names.size(); ++i) {
      auto* feed_var = global_scope_->FindVar(feed_names[i]);
      PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
@@ -529,7 +548,9 @@ void InterpreterCore::Prepare(
  // NOTE: Because feed_tensor will be GC'd after
  // paddle::framework::build_op_func_list, we should
  // call FeedInput again.
  if (prepare_feed) {
    FeedInput();
  }
}
interpreter::CostInfo InterpreterCore::DryRun(
......
@@ -49,6 +49,8 @@ class InterpreterCore {
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);

  paddle::framework::FetchList Run();

  interpreter::CostInfo DryRun(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);
......
@@ -52,6 +52,14 @@ paddle::framework::FetchList StandaloneExecutor::Run(
  return core->Run(feed_names, feed_tensors);
}

paddle::framework::FetchList StandaloneExecutor::Run(
    const std::vector<std::string>& feed_names,
    const std::vector<std::string>& fetch_names) {
  auto core = GetInterpreterCore(feed_names, fetch_names);
  return core->Run();
}
framework::interpreter::CostInfo StandaloneExecutor::DryRun(
    const std::vector<std::string>& feed_names,
    const std::vector<framework::LoDTensor>& feed_tensors) {
......
@@ -40,11 +40,17 @@ class StandaloneExecutor : public ExecutorBase {
  ~StandaloneExecutor() {}

  paddle::framework::FetchList Run(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors,
      const std::vector<std::string>& fetch_names);

  // NOTE(zhiqiu): feed_names are only used for caching the interpretercore.
  // fetch_names are used both for caching the interpretercore and for
  // inserting fetch ops; the latter can be moved to the Python side.
  paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
                                   const std::vector<std::string>& fetch_names);

  framework::interpreter::CostInfo DryRun(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);
......
@@ -2123,6 +2123,16 @@ All parameter, weight, gradient are variables in Paddle.
             }
             return py::cast(std::move(ret));
           })
      .def("run",
           [](StandaloneExecutor &self, std::vector<std::string> feed_names,
              std::vector<std::string> fetch_names) {
             paddle::framework::FetchList ret;
             {
               pybind11::gil_scoped_release release;
               ret = self.Run(feed_names, fetch_names);
             }
             return py::cast(std::move(ret));
           })
      .def("dry_run",
           [](StandaloneExecutor &self,
              const std::unordered_map<std::string, py::array> &input_dict) {
......
@@ -493,29 +493,19 @@ class _StandaloneExecutor(object):
        self._scope = scope
        self._new_exe = self._create_new_executor()
    def run(self, feed_names, fetch_list, return_numpy=True):
        """
        Args:
            feed_names(list): This parameter represents the input names of the model.
            fetch_list(list): This parameter represents the Tensors that need to be returned
                after the model runs. The default is None.
            return_numpy(bool): This parameter indicates whether to convert the fetched Tensors
                (the Tensors specified in the fetch list) to numpy.ndarray. If it is False,
                the type of the return value is a list of :code:`LoDTensor`. The default is True.
        """
        fetch_list = self._check_fetch(fetch_list)

        tensors = self._new_exe.run(feed_names, fetch_list)._move_to_list()
        if return_numpy:
            return as_numpy(tensors, copy=True)
        else:
@@ -598,9 +588,9 @@ class _ExecutorCache(object):
        assert isinstance(
            program, Program), "Required type(Program), but received {}".format(
                type(program).__name__)

        if str(program) not in self._cached_executors:
            new_program = program.clone()
            _prune_feed_ops(new_program)
            new_exe = _StandaloneExecutor(self._place, new_program, scope)
            self._cached_executors[str(program)] = new_exe
@@ -744,8 +734,13 @@ class Executor(object):
    def _add_scope_cache(self, scope_cache_key, scope):
        self.scope_caches[scope_cache_key] = scope
    def _add_feed_fetch_ops(self,
                            program,
                            feed,
                            fetch_list,
                            feed_var_name,
                            fetch_var_name,
                            skip_fetch=False):
        tmp_program = program.clone()

        global_block = tmp_program.global_block()
@@ -780,6 +775,9 @@ class Executor(object):
                warnings.warn(
                    "The variable %s is not found in program. It is not declared or is pruned."
                    % name)

        if skip_fetch:
            return tmp_program

        # append fetch_operators
        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
            for i, var in enumerate(fetch_list):
@@ -1325,8 +1323,40 @@ class Executor(object):
            program, compiler.CompiledProgram) else program
        assert isinstance(inner_program_, framework.Program)
        if not inner_program_._is_start_up_program_:
            if feed is None:
                feed = {}
            elif isinstance(feed, (list, tuple)):
                assert len(feed) == 1, "Not compiled with data parallel"
                feed = feed[0]
            if not isinstance(feed, dict):
                raise TypeError(
                    "feed requires dict as its Parameter. But you passed in %s"
                    % (type(feed)))
            feed = self._update_feed(program, feed)
            program = self._add_feed_fetch_ops(
                program=inner_program_,
                feed=feed,
                fetch_list=fetch_list,
                feed_var_name=feed_var_name,
                fetch_var_name=fetch_var_name,
                skip_fetch=True)
            self._feed_data(program, feed, feed_var_name, scope)
            if hasattr(program, 'lr_sheduler'):
                from paddle.optimizer.lr import LRScheduler
                assert isinstance(program.lr_sheduler,
                                  LRScheduler), "must be LRScheduler"
                lr_sheduler = program.lr_sheduler
                lr_value = lr_sheduler()
                lr_var = program.global_block().vars[lr_sheduler._var_name]
                data = np.array(
                    [lr_value]).astype(convert_dtype(lr_var.dtype))
                tensor = core.get_variable_tensor(scope,
                                                  lr_sheduler._var_name)
                tensor.set(data, self.place)
            return self._executor_cache.run(program, scope,
                                            list(feed.keys()), fetch_list,
                                            return_numpy)
        # use_prune can be overridden by putting optimize_ops in fetch_list
        _origin_fetch_list = fetch_list
......
@@ -251,6 +251,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
class TestException(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CPUPlace()
        self.fetch_vars = None

    def build_program(self):
        main_program = paddle.static.Program()
@@ -276,6 +277,7 @@ class TestException(unittest.TestCase):
        for feed in feeds:
            out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
        print(main_program)
        self.fetch_vars = fetch_vars
        return out

    def run_new_executor(self, feed):
@@ -317,7 +319,7 @@ class TestException(unittest.TestCase):
        }]
        self.run_new_executor(feed)
        self.assertIsNotNone(paddle.static.global_scope().find_var(
            self.fetch_vars.name))
if __name__ == "__main__":
......