未验证 提交 971e4791 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] add fetch_list in engine api (#43312)

* add fetch_list

* fix evaluate log

* tiny fix
上级 07ede118
...@@ -28,7 +28,7 @@ from paddle.fluid import program_guard ...@@ -28,7 +28,7 @@ from paddle.fluid import program_guard
from paddle.fluid.layers.utils import flatten from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope from paddle.fluid.executor import global_scope
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator from paddle.fluid.framework import Operator, Variable
from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet from paddle.distributed import fleet
...@@ -256,6 +256,7 @@ class Engine: ...@@ -256,6 +256,7 @@ class Engine:
train_data, train_data,
batch_size=1, batch_size=1,
epochs=1, epochs=1,
fetch_list=None,
steps_per_epoch=None, steps_per_epoch=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
...@@ -266,13 +267,14 @@ class Engine: ...@@ -266,13 +267,14 @@ class Engine:
"train model is not ready, please call `engine.prepare()` first." "train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size, train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch) epochs, steps_per_epoch)
self._usr_fetch_list = fetch_list
outputs = [] outputs = []
for epoch in range(epochs): for epoch in range(epochs):
for step, data in enumerate(train_dataloader): for step, data in enumerate(train_dataloader):
logs, loss = self._train_step(data, use_program_cache, logs, outs = self._train_step(data, use_program_cache,
return_numpy) return_numpy)
outputs.append(loss) outputs.append(outs)
train_logs = { train_logs = {
"train_" + name: val "train_" + name: val
for name, val in logs.items() for name, val in logs.items()
...@@ -283,86 +285,97 @@ class Engine: ...@@ -283,86 +285,97 @@ class Engine:
def evaluate(self, def evaluate(self,
eval_data, eval_data,
batch_size=1, batch_size=1,
fetch_list=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'eval' self.mode = 'eval'
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first." "eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size) eval_dataloader = self._create_dataloader(eval_data, batch_size)
self._usr_fetch_list = fetch_list
for step, data in enumerate(eval_dataloader): for step, data in enumerate(eval_dataloader):
eval_logs = dict() eval_logs = dict()
outs = self._eval_step(data, use_program_cache, return_numpy) logs, outs = self._eval_step(data, use_program_cache, return_numpy)
eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else [] eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
for metric in self._metrics: for metric in self._metrics:
results = metric.accumulate() results = metric.accumulate()
for i, res in enumerate(to_list(results)): for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res eval_logs["eval_" + metric.name()[i]] = res
for name, val in logs.items():
eval_logs["eval_" + name] = val
self._logger.info(eval_logs) self._logger.info(eval_logs)
return eval_logs return eval_logs
def predict(self, def predict(self,
test_data, test_data,
batch_size=1, batch_size=1,
fetch_list=None,
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'predict' self.mode = 'predict'
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first." "predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size) test_dataloader = self._create_dataloader(test_data, batch_size)
self._usr_fetch_list = fetch_list
outputs = [] outputs = []
for step, data in enumerate(test_dataloader): for step, data in enumerate(test_dataloader):
logs, outs = self._predict_step(data, use_program_cache, logs, outs = self._predict_step(data, use_program_cache,
return_numpy) return_numpy)
outputs.append(outs) outputs.append(outs)
predict_logs = { predict_logs = {"pred_" + name: val for name, val in logs.items()}
"predict_" + name: val
for name, val in logs.items()
}
self._logger.info(predict_logs) self._logger.info(predict_logs)
return outputs return outputs
def _train_step(self, data, use_program_cache=False, return_numpy=True): def _train_step(self, data, use_program_cache=False, return_numpy=True):
logs = {} logs = {}
fetch_vars = self._fetch_vars[self.mode]["loss"] fetch_vars = self._fetch_vars[self.mode]["loss"]
fetch_list = self._fetch_list(fetch_vars) fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
fetch_list += usr_fetch_list
loss = self._executor.run(self.main_program, outs = self._executor.run(self.main_program,
fetch_list=fetch_list, fetch_list=fetch_list,
use_program_cache=use_program_cache, use_program_cache=use_program_cache,
return_numpy=return_numpy) return_numpy=return_numpy)
logs["loss"] = loss for i, out in enumerate(outs):
return logs, loss logs[fetch_list[i]] = out
return logs, outs
def _eval_step(self, data, use_program_cache=False, return_numpy=True): def _eval_step(self, data, use_program_cache=False, return_numpy=True):
logs = {} logs = {}
metrics = self._fetch_vars[self.mode]["metrics"] metrics = self._fetch_vars[self.mode]["metrics"]
losses = self._fetch_vars[self.mode]["loss"] losses = self._fetch_vars[self.mode]["loss"]
fetch_loss = self._fetch_list(losses) fetch_loss, usr_fetch_list = self._fetch_list(losses)
fetch_metrics = self._fetch_list(metrics) fetch_metrics, usr_fetch_list = self._fetch_list(metrics)
fetch_list = fetch_loss + fetch_metrics fetch_list = fetch_loss + fetch_metrics
res = self._executor.run(self.main_program, outs = self._executor.run(self.main_program,
fetch_list=fetch_list, fetch_list=fetch_list + usr_fetch_list,
use_program_cache=use_program_cache, use_program_cache=use_program_cache,
return_numpy=return_numpy) return_numpy=return_numpy)
if not res[len(fetch_loss):]: usr_out = outs[len(fetch_list):]
return res[:len(fetch_loss)] for i, out in enumerate(usr_out):
logs[usr_fetch_list[i]] = out
outs = outs[:len(fetch_list)]
if not outs[len(fetch_loss):]:
return logs, outs[:len(fetch_loss)]
for metric in self._metrics: for metric in self._metrics:
metric.update(*res[len(fetch_loss):]) metric.update(*outs[len(fetch_loss):])
return res[:len(fetch_loss)] return logs, outs[:len(fetch_loss)]
def _predict_step(self, data, use_program_cache=False, return_numpy=True): def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {} logs = {}
fetch_vars = self._fetch_vars[self.mode]["outputs"] fetch_vars = self._fetch_vars[self.mode]["outputs"]
fetch_list = self._fetch_list(fetch_vars) fetch_list, usr_fetch_list = self._fetch_list(fetch_vars)
fetch_list += usr_fetch_list
outs = self._executor.run(self.main_program, outs = self._executor.run(self.main_program,
fetch_list=fetch_list, fetch_list=fetch_list,
use_program_cache=use_program_cache, use_program_cache=use_program_cache,
return_numpy=return_numpy) return_numpy=return_numpy)
logs["pred"] = outs for i, out in enumerate(outs):
logs[fetch_list[i]] = out
return logs, outs return logs, outs
def _fetch_list(self, fetch_vars): def _fetch_list(self, fetch_vars):
...@@ -370,7 +383,18 @@ class Engine: ...@@ -370,7 +383,18 @@ class Engine:
for var in fetch_vars: for var in fetch_vars:
if var.name in self.main_program.global_block().vars: if var.name in self.main_program.global_block().vars:
fetch_list.append(var.name) fetch_list.append(var.name)
return fetch_list usr_fetch_list = []
if self._usr_fetch_list:
assert isinstance(self._usr_fetch_list,
list), "'fetch_list' type should be list."
for var in self._usr_fetch_list:
if isinstance(var, str):
if var in self.main_program.global_block().vars:
usr_fetch_list.append(var)
elif isinstance(var, Variable):
if var.name in self.main_program.global_block().vars:
usr_fetch_list.append(var.name)
return fetch_list, usr_fetch_list
def _create_dataloader(self, def _create_dataloader(self,
dataset, dataset,
......
...@@ -133,15 +133,16 @@ def train(): ...@@ -133,15 +133,16 @@ def train():
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, engine.fit(train_dataset,
batch_size=batch_size, batch_size=batch_size,
steps_per_epoch=batch_num * batch_size) steps_per_epoch=batch_num * batch_size,
fetch_list=['label'])
# eval # eval
eval_dataset = MyDataset(batch_size) eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size) engine.evaluate(eval_dataset, batch_size, fetch_list=['label'])
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size) engine.predict(test_dataset, batch_size, fetch_list=['label'])
# save # save
engine.save('./mlp_inf', training=False, mode='predict') engine.save('./mlp_inf', training=False, mode='predict')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册