Unverified commit 562b184c, authored by zhaoyingli, committed via GitHub

[AutoParallel] fix fetch list (#43412)

* fix fetch list

* fix unittest
Parent commit: c92b3805
......@@ -17,6 +17,7 @@ import logging
from collections import defaultdict
import paddle
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle import fluid, static
......@@ -26,9 +27,9 @@ from paddle.static import InputSpec
from paddle.fluid import core
from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope
from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator, Variable
from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
......@@ -137,7 +138,8 @@ class Engine:
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog):
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
......@@ -256,7 +258,7 @@ class Engine:
train_data,
batch_size=1,
epochs=1,
fetch_list=None,
fetches=None,
steps_per_epoch=None,
use_program_cache=False,
return_numpy=True):
......@@ -267,134 +269,131 @@ class Engine:
"train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
self._usr_fetch_list = fetch_list
outputs = []
usr_fetch = self._to_map_fetch(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
for epoch in range(epochs):
for step, data in enumerate(train_dataloader):
logs, outs = self._train_step(data, use_program_cache,
return_numpy)
outputs.append(outs)
train_logs = {
"train_" + name: val
for name, val in logs.items()
}
train_logs = {"epoch": epoch}
for step, _ in enumerate(train_dataloader):
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
train_logs["step"] = step
# inner fetches
if fetch_loss:
train_logs["train_loss"] = outs[0][0]
# user fetches
user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):]
for i, out in enumerate(user_outs):
train_logs["train_" +
fetch_map[user_fetch_list[i]]] = out[0]
self._logger.info(train_logs)
return outputs
def evaluate(self,
eval_data,
batch_size=1,
fetch_list=None,
fetches=None,
use_program_cache=False,
return_numpy=True):
self.mode = 'eval'
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
self._usr_fetch_list = fetch_list
for step, data in enumerate(eval_dataloader):
eval_logs = dict()
logs, outs = self._eval_step(data, use_program_cache, return_numpy)
eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
for metric in self._metrics:
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
for name, val in logs.items():
eval_logs["eval_" + name] = val
usr_fetch = self._to_map_fetch(fetches)
fetch_loss = self._inner_fetch(self.fetch_vars["loss"])
fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
for step, _ in enumerate(eval_dataloader):
eval_logs = {"step": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
# inner fetches
if fetch_loss:
eval_logs["eval_loss"] = outs[0]
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
# usr fetches
usr_out = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_out):
eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out
# logger
self._logger.info(eval_logs)
return eval_logs
def predict(self,
test_data,
batch_size=1,
fetch_list=None,
fetches=None,
use_program_cache=False,
return_numpy=True):
self.mode = 'predict'
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
self._usr_fetch_list = fetch_list
usr_fetch = self._to_map_fetch(fetches)
fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"])
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = []
for step, data in enumerate(test_dataloader):
logs, outs = self._predict_step(data, use_program_cache,
return_numpy)
outputs.append(outs)
predict_logs = {"pred_" + name: val for name, val in logs.items()}
for step, _ in enumerate(test_dataloader):
predict_logs = {"step": step}
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0]
self._logger.info(predict_logs)
return outputs
# Run one training iteration on the partitioned main program.
# NOTE(review): pre-refactor implementation removed by this commit; the
# executor call was inlined into `fit` with the new fetch-map handling.
# `data` is unused here — the dataloader feeds the program directly
# (presumably via feed ops inserted at prepare time; confirm in caller).
def _train_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
# Loss variables registered for the current mode; only those present on
# this rank's program survive the _fetch_list filter below.
fetch_vars = self._fetch_vars[self.mode]["loss"]
fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
# Append user-requested fetches after the framework's loss fetches.
fetch_list += usr_fetch_list
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
# Map every fetched value back to its variable name for logging.
for i, out in enumerate(outs):
logs[fetch_list[i]] = out
return logs, outs
# Run one evaluation iteration: fetch losses, metric inputs and user
# fetches in a single executor run, then slice the flat result list apart.
# NOTE(review): pre-refactor implementation removed by this commit.
# `data` is unused — the dataloader feeds the program directly.
def _eval_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
metrics = self._fetch_vars[self.mode]["metrics"]
losses = self._fetch_vars[self.mode]["loss"]
# _fetch_list returns the same user list both times (it is derived from
# self._usr_fetch_list, not from its argument), so the rebinding of
# usr_fetch_list on the second call is harmless.
fetch_loss, usr_fetch_list = self._fetch_list(losses)
fetch_metrics, usr_fetch_list = self._fetch_list(metrics)
# Layout of the fetch list: [losses..., metric inputs..., user fetches...]
fetch_list = fetch_loss + fetch_metrics
outs = self._executor.run(self.main_program,
fetch_list=fetch_list + usr_fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
# Tail of `outs` corresponds to the user fetches; log them by name.
usr_out = outs[len(fetch_list):]
for i, out in enumerate(usr_out):
logs[usr_fetch_list[i]] = out
outs = outs[:len(fetch_list)]
# No metric outputs on this rank: return only the loss values.
if not outs[len(fetch_loss):]:
return logs, outs[:len(fetch_loss)]
# Feed the metric inputs to each metric accumulator.
for metric in self._metrics:
metric.update(*outs[len(fetch_loss):])
return logs, outs[:len(fetch_loss)]
# Run one prediction iteration fetching the model outputs plus any user
# fetches. NOTE(review): pre-refactor implementation removed by this
# commit. `data` is unused — the dataloader feeds the program directly.
def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
# Output variables registered for predict mode, filtered to this rank.
fetch_vars = self._fetch_vars[self.mode]["outputs"]
fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
fetch_list += usr_fetch_list
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
# Map each fetched value back to its variable name for logging.
for i, out in enumerate(outs):
logs[fetch_list[i]] = out
return logs, outs
def _fetch_list(self, fetch_vars):
    """Split fetches into framework vars and user-requested names.

    Returns a pair ``(fetch_list, usr_fetch_list)``: names from
    ``fetch_vars`` that exist in this rank's main program, and names
    from ``self._usr_fetch_list`` (strings or Variables) that do.
    Items of any other type are silently skipped.
    """
    # Bind the rank-local variable table once instead of re-resolving
    # the property on every membership test.
    block_vars = self.main_program.global_block().vars
    fetch_list = [var.name for var in fetch_vars if var.name in block_vars]
    usr_fetch_list = []
    if self._usr_fetch_list:
        assert isinstance(self._usr_fetch_list,
                          list), "'fetch_list' type should be list."
        for item in self._usr_fetch_list:
            if isinstance(item, str):
                name = item
            elif isinstance(item, Variable):
                name = item.name
            else:
                # Unsupported entry type: ignore, matching prior behavior.
                continue
            if name in block_vars:
                usr_fetch_list.append(name)
    return fetch_list, usr_fetch_list
def _local_var(self, var):
    """Return True when ``var`` (a Variable or name) exists in this
    rank's main program."""
    return _to_name_str(var) in self.main_program.global_block().vars
def _to_map_fetch(self, fetches):
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
usr_fetches = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
usr_fetches = dict(zip(fetch_var_names, fetch_var_names))
return dict(filter(lambda x: self._local_var(x[0]),
usr_fetches.items()))
def _inner_fetch(self, fetch_vars):
    """Build an identity ``{name: name}`` map for the framework-internal
    fetch variables that exist in this rank's main program."""
    names = [var.name for var in fetch_vars if self._local_var(var)]
    return {name: name for name in names}
def _fetch_map(self, inner_fetch, usr_fetch):
# replace inner fetch name if usr set for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _create_dataloader(self,
dataset,
......@@ -515,7 +514,8 @@ class Engine:
mode = self.mode
if training:
assert 'train' in self._serial_main_progs, "training model is not ready, please call `engine.prepare(mode='train')` first."
assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine.prepare()` first."
serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"]
......@@ -571,3 +571,7 @@ class Engine:
@property
def serial_startup_program(self):
"""Serial (pre-partition) startup program for the current mode."""
return self._serial_startup_progs[self.mode]
@property
def fetch_vars(self):
"""Fetch-variable groups (e.g. "loss", "metrics", "outputs") registered
for the current mode."""
return self._fetch_vars[self.mode]
......@@ -96,10 +96,11 @@ class MLPLayer(nn.Layer):
PP_MESH_1})(out)[0]
out = self.dropout(out)
out = self.linear2(out)
self.out = out
return out
def train():
def train(fetch):
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
......@@ -118,7 +119,6 @@ def train():
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
......@@ -129,20 +129,26 @@ def train():
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
fetch_list=['label'])
fetches=fetches)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetch_list=['label'])
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetch_list=['label'])
engine.predict(test_dataset, batch_size, fetches=fetches)
# save
temp_dir = tempfile.TemporaryDirectory()
......@@ -152,4 +158,5 @@ def train():
if __name__ == "__main__":
    # Exercise both code paths: with user fetches and without.
    # (The argument-less `train()` call from the pre-refactor version is
    # removed — `train` now requires the `fetch` argument.)
    train(fetch=True)
    train(fetch=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.