Unverified commit 971e4791, authored by zhaoyingli, committed by GitHub

[AutoParallel] add fetch_list in engine api (#43312)

* add fetch_list

* fix evaluate log

* tiny fix
Parent 07ede118
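The commit threads a user-supplied `fetch_list` through `Engine.fit`, `Engine.evaluate`, and `Engine.predict`, so callers can pull extra variables out of the distributed program alongside the built-in loss, metric, and output fetches. A minimal usage sketch, assuming `engine` is an auto-parallel `Engine` that has already been prepared as in the test diff at the bottom, and that `'label'` names a variable in the built program:

    # engine: a prepared paddle.distributed.auto_parallel Engine
    engine.fit(train_dataset, batch_size=64, epochs=2, fetch_list=['label'])
    eval_logs = engine.evaluate(eval_dataset, batch_size=64, fetch_list=['label'])
    outputs = engine.predict(test_dataset, batch_size=64, fetch_list=['label'])

Entries may be variable names (str) or `Variable` objects; anything not found in the program's global block is silently dropped.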
@@ -28,7 +28,7 @@ from paddle.fluid import program_guard
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.executor import global_scope
 from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import Operator
+from paddle.fluid.framework import Operator, Variable
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed import fleet
@@ -256,6 +256,7 @@ class Engine:
             train_data,
             batch_size=1,
             epochs=1,
+            fetch_list=None,
             steps_per_epoch=None,
             use_program_cache=False,
             return_numpy=True):
@@ -266,13 +267,14 @@ class Engine:
             "train model is not ready, please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)
+        self._usr_fetch_list = fetch_list

         outputs = []
         for epoch in range(epochs):
             for step, data in enumerate(train_dataloader):
-                logs, loss = self._train_step(data, use_program_cache,
+                logs, outs = self._train_step(data, use_program_cache,
                                               return_numpy)
-                outputs.append(loss)
+                outputs.append(outs)
                 train_logs = {
                     "train_" + name: val
                     for name, val in logs.items()
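`_train_step` now returns every fetched value rather than just the loss, and `fit` namespaces the per-step logs with a `train_` prefix. A standalone sketch of the prefixing pattern, with hypothetical log values:

    logs = {'loss_0': 2.31, 'label': [[0], [1]]}   # hypothetical per-step logs
    train_logs = {"train_" + name: val for name, val in logs.items()}
    # {'train_loss_0': 2.31, 'train_label': [[0], [1]]}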
@@ -283,86 +285,97 @@ class Engine:
     def evaluate(self,
                  eval_data,
                  batch_size=1,
+                 fetch_list=None,
                  use_program_cache=False,
                  return_numpy=True):
         self.mode = 'eval'
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
+        self._usr_fetch_list = fetch_list

         for step, data in enumerate(eval_dataloader):
             eval_logs = dict()
-            outs = self._eval_step(data, use_program_cache, return_numpy)
+            logs, outs = self._eval_step(data, use_program_cache, return_numpy)
             eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
             for metric in self._metrics:
                 results = metric.accumulate()
                 for i, res in enumerate(to_list(results)):
                     eval_logs["eval_" + metric.name()[i]] = res
+            for name, val in logs.items():
+                eval_logs["eval_" + name] = val
             self._logger.info(eval_logs)

         return eval_logs
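`evaluate` now merges three sources into `eval_logs`: the loss (`outs[0]`), accumulated metrics, and whatever user fetches `_eval_step` returned in `logs`. A sketch of the resulting dict, with hypothetical names and values:

    eval_logs = dict()
    eval_logs["eval_loss"] = 0.87            # outs[0]
    eval_logs["eval_acc"] = 0.91             # from metric.accumulate()
    logs = {"label": [[3], [7]]}             # user fetches from _eval_step
    for name, val in logs.items():
        eval_logs["eval_" + name] = val      # adds eval_logs["eval_label"]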
     def predict(self,
                 test_data,
                 batch_size=1,
+                fetch_list=None,
                 use_program_cache=False,
                 return_numpy=True):
         self.mode = 'predict'
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
+        self._usr_fetch_list = fetch_list

         outputs = []
         for step, data in enumerate(test_dataloader):
             logs, outs = self._predict_step(data, use_program_cache,
                                             return_numpy)
             outputs.append(outs)
-            predict_logs = {
-                "predict_" + name: val
-                for name, val in logs.items()
-            }
+            predict_logs = {"pred_" + name: val for name, val in logs.items()}
             self._logger.info(predict_logs)

         return outputs
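The multi-line dict comprehension collapses to one line, and the log prefix changes from `predict_` to `pred_`; raw per-step outputs still accumulate in `outputs`. With hypothetical per-step logs:

    logs = {"out": [[0.2, 0.8]], "label": [[1]]}
    predict_logs = {"pred_" + name: val for name, val in logs.items()}
    # {'pred_out': [[0.2, 0.8]], 'pred_label': [[1]]}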
     def _train_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["loss"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

-        loss = self._executor.run(self.main_program,
+        outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["loss"] = loss
-        return logs, loss
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
+        return logs, outs
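`executor.run` returns fetched values positionally, in the same order as `fetch_list`, so indexing with `fetch_list[i]` keys each output by its variable name. An equivalent standalone sketch, with hypothetical names:

    fetch_list = ['loss_0', 'label']            # built-in loss + user fetch
    outs = [[2.31], [[0], [1]]]                 # what executor.run returns
    logs = {name: out for name, out in zip(fetch_list, outs)}
    # {'loss_0': [2.31], 'label': [[0], [1]]}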
     def _eval_step(self, data, use_program_cache=False, return_numpy=True):
+        logs = {}
         metrics = self._fetch_vars[self.mode]["metrics"]
         losses = self._fetch_vars[self.mode]["loss"]
-        fetch_loss = self._fetch_list(losses)
-        fetch_metrics = self._fetch_list(metrics)
+        fetch_loss, usr_fetch_list = self._fetch_list(losses)
+        fetch_metrics, usr_fetch_list = self._fetch_list(metrics)
         fetch_list = fetch_loss + fetch_metrics

-        res = self._executor.run(self.main_program,
-                                 fetch_list=fetch_list,
+        outs = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list + usr_fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        if not res[len(fetch_loss):]:
-            return res[:len(fetch_loss)]
+        usr_out = outs[len(fetch_list):]
+        for i, out in enumerate(usr_out):
+            logs[usr_fetch_list[i]] = out
+        outs = outs[:len(fetch_list)]
+        if not outs[len(fetch_loss):]:
+            return logs, outs[:len(fetch_loss)]
         for metric in self._metrics:
-            metric.update(*res[len(fetch_loss):])
-        return res[:len(fetch_loss)]
+            metric.update(*outs[len(fetch_loss):])
+        return logs, outs[:len(fetch_loss)]
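The combined fetch order is `fetch_loss + fetch_metrics + usr_fetch_list`, so the results split by length arithmetic: user fetches sit past `len(fetch_list)`, loss values occupy the first `len(fetch_loss)` slots, and metric states sit in between. A sketch with one hypothetical entry per group:

    fetch_loss, fetch_metrics, usr_fetch_list = ['loss_0'], ['acc_0'], ['label']
    fetch_list = fetch_loss + fetch_metrics
    outs = [[0.87], [0.91], [[3], [7]]]      # loss, metric state, user fetch
    usr_out = outs[len(fetch_list):]         # [[[3], [7]]]  -> logs
    outs = outs[:len(fetch_list)]            # [[0.87], [0.91]]
    loss_part = outs[:len(fetch_loss)]       # [[0.87]]      -> returned
    metric_part = outs[len(fetch_loss):]     # [[0.91]]      -> metric.update

Note that `_fetch_list` is called twice here and both calls return the same user list, so the second assignment to `usr_fetch_list` is redundant but harmless.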
     def _predict_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
         fetch_vars = self._fetch_vars[self.mode]["outputs"]
-        fetch_list = self._fetch_list(fetch_vars)
+        fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
+        fetch_list += usr_fetch_list

         outs = self._executor.run(self.main_program,
                                   fetch_list=fetch_list,
                                   use_program_cache=use_program_cache,
                                   return_numpy=return_numpy)
-        logs["pred"] = outs
+        for i, out in enumerate(outs):
+            logs[fetch_list[i]] = out
         return logs, outs
     def _fetch_list(self, fetch_vars):
@@ -370,7 +383,18 @@ class Engine:
         for var in fetch_vars:
             if var.name in self.main_program.global_block().vars:
                 fetch_list.append(var.name)
-        return fetch_list
+        usr_fetch_list = []
+        if self._usr_fetch_list:
+            assert isinstance(self._usr_fetch_list,
+                              list), "'fetch_list' type should be list."
+            for var in self._usr_fetch_list:
+                if isinstance(var, str):
+                    if var in self.main_program.global_block().vars:
+                        usr_fetch_list.append(var)
+                elif isinstance(var, Variable):
+                    if var.name in self.main_program.global_block().vars:
+                        usr_fetch_list.append(var.name)
+        return fetch_list, usr_fetch_list

     def _create_dataloader(self,
                            dataset,
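`_fetch_list` now returns a pair: the framework's own fetches plus the validated user list. Strings and `Variable`s are both accepted, and anything that does not resolve to a name in the program's global block is dropped. The filtering rule as a standalone sketch (the helper name and `block_vars` parameter are hypothetical):

    def filter_user_fetches(usr_fetch_list, block_vars):
        # Keep only fetches that resolve to a variable name in block_vars.
        assert isinstance(usr_fetch_list, list), "'fetch_list' type should be list."
        names = []
        for var in usr_fetch_list:
            name = var if isinstance(var, str) else getattr(var, 'name', None)
            if name in block_vars:
                names.append(name)
        return names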
......
@@ -133,15 +133,16 @@ def train():
     train_dataset = MyDataset(batch_num * batch_size)
     engine.fit(train_dataset,
                batch_size=batch_size,
-               steps_per_epoch=batch_num * batch_size)
+               steps_per_epoch=batch_num * batch_size,
+               fetch_list=['label'])

     # eval
     eval_dataset = MyDataset(batch_size)
-    engine.evaluate(eval_dataset, batch_size)
+    engine.evaluate(eval_dataset, batch_size, fetch_list=['label'])

     # predict
     test_dataset = MyDataset(batch_size)
-    engine.predict(test_dataset, batch_size)
+    engine.predict(test_dataset, batch_size, fetch_list=['label'])

     # save
     engine.save('./mlp_inf', training=False, mode='predict')
......
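With these test changes, each phase's log line should also carry the fetched 'label' values; roughly (keys derive from actual variable names, values hypothetical):

    # INFO {'train_loss_0': ..., 'train_label': ...}
    # INFO {'eval_loss': ..., 'eval_acc': ..., 'eval_label': ...}
    # INFO {'pred_out_0': ..., 'pred_label': ...}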