未验证 提交 562b184c 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] fix fetch list (#43412)

* fix fetch list

* fix unittest
上级 c92b3805
...@@ -17,6 +17,7 @@ import logging ...@@ -17,6 +17,7 @@ import logging
from collections import defaultdict from collections import defaultdict
import paddle import paddle
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle import fluid, static from paddle import fluid, static
...@@ -26,9 +27,9 @@ from paddle.static import InputSpec ...@@ -26,9 +27,9 @@ from paddle.static import InputSpec
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import program_guard from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator, Variable from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet from paddle.distributed import fleet
...@@ -137,7 +138,8 @@ class Engine: ...@@ -137,7 +138,8 @@ class Engine:
metrics = [] metrics = []
serial_main_prog = self._orig_main_prog.clone() serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone() serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog): with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else [] labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec] inputs = [s._create_feed_layer() for s in inputs_spec]
...@@ -256,7 +258,7 @@ class Engine: ...@@ -256,7 +258,7 @@ class Engine:
train_data, train_data,
batch_size=1, batch_size=1,
epochs=1, epochs=1,
fetch_list=None, fetches=None,
steps_per_epoch=None, steps_per_epoch=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
...@@ -267,134 +269,131 @@ class Engine: ...@@ -267,134 +269,131 @@ class Engine:
"train model is not ready, please call `engine.prepare()` first." "train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size, train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch) epochs, steps_per_epoch)
self._usr_fetch_list = fetch_list
outputs = [] usr_fetch = self._to_map_fetch(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
for epoch in range(epochs): for epoch in range(epochs):
for step, data in enumerate(train_dataloader): train_logs = {"epoch": epoch}
logs, outs = self._train_step(data, use_program_cache, for step, _ in enumerate(train_dataloader):
return_numpy) outs = self._executor.run(self.main_program,
outputs.append(outs) fetch_list=fetch_list,
train_logs = { use_program_cache=use_program_cache,
"train_" + name: val return_numpy=return_numpy)
for name, val in logs.items() train_logs["step"] = step
} # inner fetches
if fetch_loss:
train_logs["train_loss"] = outs[0][0]
# user fetches
user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):]
for i, out in enumerate(user_outs):
train_logs["train_" +
fetch_map[user_fetch_list[i]]] = out[0]
self._logger.info(train_logs) self._logger.info(train_logs)
return outputs
def evaluate(self, def evaluate(self,
eval_data, eval_data,
batch_size=1, batch_size=1,
fetch_list=None, fetches=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'eval' self.mode = 'eval'
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first." "eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size) eval_dataloader = self._create_dataloader(eval_data, batch_size)
self._usr_fetch_list = fetch_list
usr_fetch = self._to_map_fetch(fetches)
for step, data in enumerate(eval_dataloader): fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
eval_logs = dict() fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
logs, outs = self._eval_step(data, use_program_cache, return_numpy) inner_fetch = dict(fetch_loss, **fetch_metrics)
eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else [] fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
for metric in self._metrics:
results = metric.accumulate() for step, _ in enumerate(eval_dataloader):
for i, res in enumerate(to_list(results)): eval_logs = {"step": step}
eval_logs["eval_" + metric.name()[i]] = res outs = self._executor.run(self.main_program,
for name, val in logs.items(): fetch_list=fetch_list,
eval_logs["eval_" + name] = val use_program_cache=use_program_cache,
return_numpy=return_numpy)
# inner fetches
if fetch_loss:
eval_logs["eval_loss"] = outs[0]
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
# usr fetches
usr_out = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_out):
eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
# logger
self._logger.info(eval_logs) self._logger.info(eval_logs)
return eval_logs
def predict(self, def predict(self,
test_data, test_data,
batch_size=1, batch_size=1,
fetch_list=None, fetches=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'predict' self.mode = 'predict'
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first." "predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size) test_dataloader = self._create_dataloader(test_data, batch_size)
self._usr_fetch_list = fetch_list
usr_fetch = self._to_map_fetch(fetches)
fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = [] outputs = []
for step, data in enumerate(test_dataloader): for step, _ in enumerate(test_dataloader):
logs, outs = self._predict_step(data, use_program_cache, predict_logs = {"step": step}
return_numpy) outs = self._executor.run(self.main_program,
outputs.append(outs) fetch_list=fetch_list,
predict_logs = {"pred_" + name: val for name, val in logs.items()} use_program_cache=use_program_cache,
return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0]
self._logger.info(predict_logs) self._logger.info(predict_logs)
return outputs return outputs
def _train_step(self, data, use_program_cache=False, return_numpy=True): def _local_var(self, var):
logs = {} var_name = _to_name_str(var)
fetch_vars = self._fetch_vars[self.mode]["loss"] return var_name in self.main_program.global_block().vars
fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
fetch_list += usr_fetch_list def _to_map_fetch(self, fetches):
if not fetches:
outs = self._executor.run(self.main_program, return {}
fetch_list=fetch_list, if isinstance(fetches, dict):
use_program_cache=use_program_cache, fetch_var_names = list(map(_to_name_str, fetches.values()))
return_numpy=return_numpy) usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
for i, out in enumerate(outs): elif isinstance(fetches, list):
logs[fetch_list[i]] = out fetch_var_names = list(map(_to_name_str, fetches))
return logs, outs usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
return dict(filter(lambda x: self._local_var(x[0]),
def _eval_step(self, data, use_program_cache=False, return_numpy=True): usr_fetches.items()))
logs = {}
metrics = self._fetch_vars[self.mode]["metrics"] def _inner_fetch(self, fetch_vars):
losses = self._fetch_vars[self.mode]["loss"] fetch_list = list(
fetch_loss, usr_fetch_list = self._fetch_list(losses) map(lambda x: x.name, list(filter(self._local_var, fetch_vars))))
fetch_metrics, usr_fetch_list = self._fetch_list(metrics) inner_fetches = dict(zip(fetch_list, fetch_list))
fetch_list = fetch_loss + fetch_metrics return inner_fetches
outs = self._executor.run(self.main_program, def _fetch_map(self, inner_fetch, usr_fetch):
fetch_list=fetch_list + usr_fetch_list, # replace inner fetch name if usr set for it
use_program_cache=use_program_cache, for iname in inner_fetch:
return_numpy=return_numpy) if iname in usr_fetch:
usr_out = outs[len(fetch_list):] inner_fetch[iname] = usr_fetch[iname]
for i, out in enumerate(usr_out): usr_fetch.pop(iname)
logs[usr_fetch_list[i]] = out fetches = dict(inner_fetch, **usr_fetch)
outs = outs[:len(fetch_list)] return list(fetches.keys()), fetches
if not outs[len(fetch_loss):]:
return logs, outs[:len(fetch_loss)]
for metric in self._metrics:
metric.update(*outs[len(fetch_loss):])
return logs, outs[:len(fetch_loss)]
def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
fetch_vars = self._fetch_vars[self.mode]["outputs"]
fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
fetch_list += usr_fetch_list
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
for i, out in enumerate(outs):
logs[fetch_list[i]] = out
return logs, outs
def _fetch_list(self, fetch_vars):
fetch_list = []
for var in fetch_vars:
if var.name in self.main_program.global_block().vars:
fetch_list.append(var.name)
usr_fetch_list = []
if self._usr_fetch_list:
assert isinstance(self._usr_fetch_list,
list), "'fetch_list' type should be list."
for var in self._usr_fetch_list:
if isinstance(var, str):
if var in self.main_program.global_block().vars:
usr_fetch_list.append(var)
elif isinstance(var, Variable):
if var.name in self.main_program.global_block().vars:
usr_fetch_list.append(var.name)
return fetch_list, usr_fetch_list
def _create_dataloader(self, def _create_dataloader(self,
dataset, dataset,
...@@ -515,7 +514,8 @@ class Engine: ...@@ -515,7 +514,8 @@ class Engine:
mode = self.mode mode = self.mode
if training: if training:
assert 'train' in self._serial_main_progs, "training model is not ready, please call `engine.prepare(mode='train')` first." assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine.prepare()` first."
serial_program = self._serial_main_progs["train"] serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"] dist_context = self._dist_contexts["train"]
...@@ -571,3 +571,7 @@ class Engine: ...@@ -571,3 +571,7 @@ class Engine:
@property @property
def serial_startup_program(self): def serial_startup_program(self):
return self._serial_startup_progs[self.mode] return self._serial_startup_progs[self.mode]
@property
def fetch_vars(self):
return self._fetch_vars[self.mode]
...@@ -96,10 +96,11 @@ class MLPLayer(nn.Layer): ...@@ -96,10 +96,11 @@ class MLPLayer(nn.Layer):
PP_MESH_1})(out)[0] PP_MESH_1})(out)[0]
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
self.out = out
return out return out
def train(): def train(fetch):
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size, intermediate_size=4 * hidden_size,
dropout_ratio=0.1, dropout_ratio=0.1,
...@@ -118,7 +119,6 @@ def train(): ...@@ -118,7 +119,6 @@ def train():
dist_strategy.amp = False dist_strategy.amp = False
dist_strategy.pipeline = False dist_strategy.pipeline = False
dist_strategy.recompute = False dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy) fleet.init(is_collective=True, strategy=dist_strategy)
...@@ -129,20 +129,26 @@ def train(): ...@@ -129,20 +129,26 @@ def train():
strategy=dist_strategy) strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, engine.fit(train_dataset,
batch_size=batch_size, batch_size=batch_size,
steps_per_epoch=batch_num * batch_size, steps_per_epoch=batch_num * batch_size,
fetch_list=['label']) fetches=fetches)
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetch_list=['label']) engine.evaluate(eval_dataset, batch_size, fetches=fetches)
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetch_list=['label']) engine.predict(test_dataset, batch_size, fetches=fetches)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
...@@ -152,4 +158,5 @@ def train(): ...@@ -152,4 +158,5 @@ def train():
if __name__ == "__main__": if __name__ == "__main__":
train() train(fetch=True)
train(fetch=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册