Unverified commit 6d6642c8, authored by: Leo Chen, committed by: GitHub

[new-exec] Refine standalone executor (#37278)

* init

* add feed ops in python side

* import LRScheduler

* update_feed

* refine code format
Parent: ce3ee9bb
@@ -77,6 +77,24 @@ paddle::framework::FetchList InterpreterCore::Run(
  return *(fetch_var->GetMutable<framework::FetchList>());
}

+paddle::framework::FetchList InterpreterCore::Run() {
+  if (!is_build_) {
+    paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
+    paddle::framework::interpreter::build_op_func_list(
+        place_, block_, &op_func_nodes, global_scope_);
+    is_build_ = true;
+    // convert vec func_list to graph
+    Convert(&op_func_nodes);
+  } else {
+    ExecuteInstructionList(vec_instruction_);
+  }
+
+  // return Fetch Tensors
+  auto* fetch_var = global_scope_->Var(interpreter::kFetchVarName);
+  return *(fetch_var->GetMutable<framework::FetchList>());
+}
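The new no-argument `Run()` builds lazily: on the first call it materializes the variable scope and traces the block into `op_func_nodes` (note that `build_op_func_list` executes each op while recording it), then converts the trace into an instruction list once; every later call just replays that cached list. A minimal, runnable Python sketch of the same pattern, with plain callables standing in for operators (all names here are illustrative, not Paddle's):

```python
# Hedged sketch of InterpreterCore::Run()'s lazy-build pattern.
class LazyCore:
    def __init__(self, ops):
        self._ops = ops
        self._instructions = None              # built on the first call

    def run(self):
        if self._instructions is None:
            # First call: execute each op while recording it, mirroring
            # build_op_func_list, which runs ops as it builds OpFuncNodes.
            self._instructions = []
            for op in self._ops:
                op()
                self._instructions.append(op)
        else:
            for op in self._instructions:      # fast path: replay the cache
                op()
        return "fetch results"

core = LazyCore([lambda: print("op1"), lambda: print("op2")])
core.run()   # builds while executing
core.run()   # replays the cached instruction list
```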
void InterpreterCore::BuildOperatorDependences() {
  // Analyze the dependencies between ops, set dependecy_count_, and call
  // Schedule
@@ -505,6 +523,7 @@ void InterpreterCore::Prepare(
                        feed_names.size(), feed_tensors.size()));

  auto FeedInput = [&] {
+    VLOG(4) << "Feed inputs";
    for (size_t i = 0; i < feed_names.size(); ++i) {
      auto* feed_var = global_scope_->FindVar(feed_names[i]);
      PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
@@ -529,7 +548,9 @@ void InterpreterCore::Prepare(
  // NOTE: feed_tensor will be GC'ed after paddle::framework::build_op_func_list,
  // so we have to call FeedInput again.
- if (prepare_feed) FeedInput();
+ if (prepare_feed) {
+   FeedInput();
+ }
}
interpreter::CostInfo InterpreterCore::DryRun(
......
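`Prepare` copies each feed tensor into its scope variable by name, and because `build_op_func_list` may garbage-collect the fed tensors during the first build, `FeedInput` has to run a second time afterwards. A hedged Python analogue of what one `FeedInput` pass amounts to, reusing the `core.get_variable_tensor` helper that appears later in this diff (the scope, place, and variables are assumed to exist already):

```python
# Hedged sketch: write each feed array into its scope variable by name,
# as the FeedInput lambda does in C++. The variables must already have
# been created (e.g. by build_variable_scope).
from paddle.fluid import core

def feed_input(scope, place, feed_names, feed_arrays):
    for name, array in zip(feed_names, feed_arrays):
        tensor = core.get_variable_tensor(scope, name)
        tensor.set(array, place)
```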
@@ -49,6 +49,8 @@ class InterpreterCore {
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);

+  paddle::framework::FetchList Run();
+
  interpreter::CostInfo DryRun(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);
......
@@ -52,6 +52,14 @@ paddle::framework::FetchList StandaloneExecutor::Run(
  return core->Run(feed_names, feed_tensors);
}

+paddle::framework::FetchList StandaloneExecutor::Run(
+    const std::vector<std::string>& feed_names,
+    const std::vector<std::string>& fetch_names) {
+  auto core = GetInterpreterCore(feed_names, fetch_names);
+  return core->Run();
+}
+
framework::interpreter::CostInfo StandaloneExecutor::DryRun(
    const std::vector<std::string>& feed_names,
    const std::vector<framework::LoDTensor>& feed_tensors) {
......
@@ -40,11 +40,17 @@ class StandaloneExecutor : public ExecutorBase {
  ~StandaloneExecutor() {}

- virtual paddle::framework::FetchList Run(
+ paddle::framework::FetchList Run(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors,
      const std::vector<std::string>& fetch_names);

+ // NOTE(zhiqiu): feed_names are only used for caching the interpretercore;
+ // fetch_names are used both for caching the interpretercore and for
+ // inserting fetch ops, and the latter can be moved to the Python side.
+ paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
+                                  const std::vector<std::string>& fetch_names);
+
  framework::interpreter::CostInfo DryRun(
      const std::vector<std::string>& feed_names,
      const std::vector<framework::LoDTensor>& feed_tensors);
......
@@ -2123,6 +2123,16 @@ All parameter, weight, gradient are variables in Paddle.
             }
             return py::cast(std::move(ret));
           })
+      .def("run",
+           [](StandaloneExecutor &self, std::vector<std::string> feed_names,
+              std::vector<std::string> fetch_names) {
+             paddle::framework::FetchList ret;
+             {
+               pybind11::gil_scoped_release release;
+               ret = self.Run(feed_names, fetch_names);
+             }
+             return py::cast(std::move(ret));
+           })
      .def("dry_run",
           [](StandaloneExecutor &self,
              const std::unordered_map<std::string, py::array> &input_dict) {
......
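The binding releases the GIL around `Run`, so other Python threads can make progress while a step executes on the C++ side. A hedged sketch of how the wrapper layer drives it, following the `_StandaloneExecutor` code further down in this diff (`new_exe` is assumed to be an already-constructed `core.StandaloneExecutor`, with the feed data previously written into the scope by the Python-side feed ops):

```python
# Hedged usage sketch of the new name-based binding: only variable names
# cross the boundary; the data itself is already in the scope.
fetch_list = new_exe.run(["image", "label"], ["loss"])  # returns a FetchList
tensors = fetch_list._move_to_list()                    # list of LoDTensor
```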
@@ -493,29 +493,19 @@ class _StandaloneExecutor(object):
        self._scope = scope
        self._new_exe = self._create_new_executor()

-   def run(self, feed, fetch_list, return_numpy=True):
+   def run(self, feed_names, fetch_list, return_numpy=True):
        """
        Args:
-           feed(list|dict): This parameter represents the input Tensors of the model.
-               If it is single card training, the feed is dict type, and if it is multi-card
-               training, the parameter feed can be dict or list of Tensors. If the
-               parameter type is dict, the data in the feed will be split and sent to
-               multiple devices (CPU/GPU), that is to say, the input data will be evenly
-               sent to different devices, so you should make sure the number of samples of
-               the current mini-batch must be greater than the number of places;
-               if the parameter type is list, those data are copied directly to each device,
-               so the length of this list should be equal to the number of places.
-               The default is None.
+           feed_names(list): This parameter represents the input names of the model.
            fetch_list(list): This parameter represents the Tensors that need to be returned
                after the model runs. The default is None.
            return_numpy(bool): This parameter indicates whether to convert the fetched Tensors
                (the Tensors specified in the fetch list) to numpy.ndarray. If it is False,
                the return value is a list of :code:`LoDTensor`. The default is True.
        """
-       feed = self._update_feed(feed)
-       fetch_list = self._check_fetch(fetch_list)
-       tensors = self._new_exe.run(feed, fetch_list)._move_to_list()
+       tensors = self._new_exe.run(feed_names, fetch_list)._move_to_list()
        if return_numpy:
            return as_numpy(tensors, copy=True)
        else:
@@ -598,9 +588,9 @@ class _ExecutorCache(object):
        assert isinstance(
            program, Program), "Required type(Program), but received {}".format(
                type(program).__name__)

        if str(program) not in self._cached_executors:
            new_program = program.clone()
            _prune_feed_ops(new_program)
            new_exe = _StandaloneExecutor(self._place, new_program, scope)
            self._cached_executors[str(program)] = new_exe
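`_ExecutorCache` keys executors by the program's text form, so each distinct `Program` is cloned, pruned of stale feed ops, and wrapped in a `_StandaloneExecutor` exactly once. A stripped-down, runnable sketch of that cache-by-program-text pattern (names are illustrative, not Paddle's):

```python
# Minimal sketch of caching one executor per program text, as above.
class ExecutorCache:
    def __init__(self, build_fn):
        self._build_fn = build_fn
        self._cached = {}

    def get(self, program):
        key = str(program)               # program text identifies the graph
        if key not in self._cached:
            self._cached[key] = self._build_fn(program)
        return self._cached[key]

cache = ExecutorCache(lambda p: f"executor-for-{p}")
assert cache.get("prog-A") is cache.get("prog-A")   # built only once
```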
@@ -744,8 +734,13 @@ class Executor(object):
    def _add_scope_cache(self, scope_cache_key, scope):
        self.scope_caches[scope_cache_key] = scope

-   def _add_feed_fetch_ops(self, program, feed, fetch_list, feed_var_name,
-                           fetch_var_name):
+   def _add_feed_fetch_ops(self,
+                           program,
+                           feed,
+                           fetch_list,
+                           feed_var_name,
+                           fetch_var_name,
+                           skip_fetch=False):
        tmp_program = program.clone()
        global_block = tmp_program.global_block()
@@ -780,6 +775,9 @@ class Executor(object):
                warnings.warn(
                    "The variable %s is not found in program. It is not declared or is pruned."
                    % name)

+       if skip_fetch:
+           return tmp_program
+
        # append fetch_operators
        if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
            for i, var in enumerate(fetch_list):
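With `skip_fetch=True` the helper inserts the feed ops and returns before appending any fetch ops: under the standalone executor the fetch side is driven by `fetch_names` on the C++ side instead. A toy, runnable sketch of that control flow (purely illustrative; real feed/fetch ops operate on program blocks, not lists):

```python
# Toy sketch of the skip_fetch branch above: feed ops are always appended,
# fetch ops only when the core will not insert them itself.
def add_feed_fetch(ops, feed, fetch_list, skip_fetch=False):
    ops = list(ops) + [("feed", name) for name in feed]
    if skip_fetch:
        return ops                      # new-executor path stops here
    return ops + [("fetch", var) for var in fetch_list]

print(add_feed_fetch(["matmul"], {"x": None}, ["out"], skip_fetch=True))
# ['matmul', ('feed', 'x')]  -- feed op appended, no fetch op
```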
@@ -1325,8 +1323,40 @@ class Executor(object):
            program, compiler.CompiledProgram) else program
        assert isinstance(inner_program_, framework.Program)
        if not inner_program_._is_start_up_program_:
-           return self._executor_cache.run(inner_program_, scope, feed,
-                                           fetch_list, return_numpy)
+           if feed is None:
+               feed = {}
+           elif isinstance(feed, (list, tuple)):
+               assert len(feed) == 1, "Not compiled with data parallel"
+               feed = feed[0]
+           if not isinstance(feed, dict):
+               raise TypeError(
+                   "feed requires dict as its parameter, but you passed in %s"
+                   % (type(feed)))
+           feed = self._update_feed(program, feed)
+           program = self._add_feed_fetch_ops(
+               program=inner_program_,
+               feed=feed,
+               fetch_list=fetch_list,
+               feed_var_name=feed_var_name,
+               fetch_var_name=fetch_var_name,
+               skip_fetch=True)
+           self._feed_data(program, feed, feed_var_name, scope)
+           if hasattr(program, 'lr_sheduler'):
+               from paddle.optimizer.lr import LRScheduler
+               assert isinstance(program.lr_sheduler,
+                                 LRScheduler), "must be LRScheduler"
+               lr_sheduler = program.lr_sheduler
+               lr_value = lr_sheduler()
+               lr_var = program.global_block().vars[lr_sheduler._var_name]
+               data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
+               tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
+               tensor.set(data, self.place)
+           return self._executor_cache.run(program, scope,
+                                           list(feed.keys()), fetch_list,
+                                           return_numpy)

        # use_prune can be overridden by putting optimize_ops in fetch_list
        _origin_fetch_list = fetch_list
......
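Taken together, the new path normalizes `feed` to a dict, inserts feed ops via `_add_feed_fetch_ops(..., skip_fetch=True)`, writes the data (and the current `LRScheduler` value, if any) into the scope, and hands only the feed names and fetch list to the cached standalone executor. A hedged end-to-end sketch; enabling the path through the `FLAGS_USE_STANDALONE_EXECUTOR` environment variable follows the tests below and is an assumption here:

```python
# Hedged end-to-end sketch of the new run path (flag name follows the tests;
# it must be set before the first run).
import os
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'

import numpy as np
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[2, 2], dtype='float32')
    out = paddle.sum(x)

exe = paddle.static.Executor(paddle.CPUPlace())
res, = exe.run(main, feed={'x': np.ones([2, 2], np.float32)},
               fetch_list=[out])
print(res)  # 4.0
```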
@@ -251,6 +251,7 @@ class SwitchExecutorInterfaceWithFeed(unittest.TestCase):
class TestException(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CPUPlace()
+       self.fetch_vars = None

    def build_program(self):
        main_program = paddle.static.Program()
@@ -276,6 +277,7 @@ class TestException(unittest.TestCase):
        for feed in feeds:
            out = exe.run(main_program, feed=feed, fetch_list=fetch_vars)
        print(main_program)
+       self.fetch_vars = fetch_vars
        return out

    def run_new_executor(self, feed):
@@ -317,7 +319,7 @@ class TestException(unittest.TestCase):
        }]
        self.run_new_executor(feed)
        self.assertIsNotNone(paddle.static.global_scope().find_var(
-           'embedding.tmp_2'))
+           self.fetch_vars.name))
if __name__ == "__main__":
......