diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index fb12ae4971ae1e91b00fced835ea3283213a206a..a0b2125f166426cbd2e68971b5ff9e888273f3fd 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -28,7 +28,7 @@ from paddle.fluid import program_guard from paddle.fluid.layers.utils import flatten from paddle.fluid.executor import global_scope from paddle.fluid.backward import append_backward -from paddle.fluid.framework import Operator +from paddle.fluid.framework import Operator, Variable from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import fleet @@ -256,6 +256,7 @@ class Engine: train_data, batch_size=1, epochs=1, + fetch_list=None, steps_per_epoch=None, use_program_cache=False, return_numpy=True): @@ -266,13 +267,14 @@ class Engine: "train model is not ready, please call `engine.prepare()` first." train_dataloader = self._create_dataloader(train_data, batch_size, epochs, steps_per_epoch) + self._usr_fetch_list = fetch_list outputs = [] for epoch in range(epochs): for step, data in enumerate(train_dataloader): - logs, loss = self._train_step(data, use_program_cache, + logs, outs = self._train_step(data, use_program_cache, return_numpy) - outputs.append(loss) + outputs.append(outs) train_logs = { "train_" + name: val for name, val in logs.items() @@ -283,86 +285,97 @@ class Engine: def evaluate(self, eval_data, batch_size=1, + fetch_list=None, use_program_cache=False, return_numpy=True): self.mode = 'eval' assert self.mode in self._dist_main_progs, \ "eval model is not ready, please call `engine.prepare()` first." eval_dataloader = self._create_dataloader(eval_data, batch_size) + self._usr_fetch_list = fetch_list for step, data in enumerate(eval_dataloader): eval_logs = dict() - outs = self._eval_step(data, use_program_cache, return_numpy) + logs, outs = self._eval_step(data, use_program_cache, return_numpy) eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else [] for metric in self._metrics: results = metric.accumulate() for i, res in enumerate(to_list(results)): eval_logs["eval_" + metric.name()[i]] = res + for name, val in logs.items(): + eval_logs["eval_" + name] = val self._logger.info(eval_logs) return eval_logs def predict(self, test_data, batch_size=1, + fetch_list=None, use_program_cache=False, return_numpy=True): self.mode = 'predict' assert self.mode in self._dist_main_progs, \ "predict model is not ready, please call `engine.prepare()` first." test_dataloader = self._create_dataloader(test_data, batch_size) + self._usr_fetch_list = fetch_list outputs = [] for step, data in enumerate(test_dataloader): logs, outs = self._predict_step(data, use_program_cache, return_numpy) outputs.append(outs) - predict_logs = { - "predict_" + name: val - for name, val in logs.items() - } + predict_logs = {"pred_" + name: val for name, val in logs.items()} self._logger.info(predict_logs) return outputs def _train_step(self, data, use_program_cache=False, return_numpy=True): logs = {} fetch_vars = self._fetch_vars[self.mode]["loss"] - fetch_list = self._fetch_list(fetch_vars) + fetch_list, usr_fetch_list = self._fetch_list(fetch_vars) + fetch_list += usr_fetch_list - loss = self._executor.run(self.main_program, + outs = self._executor.run(self.main_program, fetch_list=fetch_list, use_program_cache=use_program_cache, return_numpy=return_numpy) - logs["loss"] = loss - return logs, loss + for i, out in enumerate(outs): + logs[fetch_list[i]] = out + return logs, outs def _eval_step(self, data, use_program_cache=False, return_numpy=True): logs = {} metrics = self._fetch_vars[self.mode]["metrics"] losses = self._fetch_vars[self.mode]["loss"] - fetch_loss = self._fetch_list(losses) - fetch_metrics = self._fetch_list(metrics) + fetch_loss, usr_fetch_list = self._fetch_list(losses) + fetch_metrics, usr_fetch_list = self._fetch_list(metrics) fetch_list = fetch_loss + fetch_metrics - res = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_program_cache, - return_numpy=return_numpy) - if not res[len(fetch_loss):]: - return res[:len(fetch_loss)] + outs = self._executor.run(self.main_program, + fetch_list=fetch_list + usr_fetch_list, + use_program_cache=use_program_cache, + return_numpy=return_numpy) + usr_out = outs[len(fetch_list):] + for i, out in enumerate(usr_out): + logs[usr_fetch_list[i]] = out + outs = outs[:len(fetch_list)] + if not outs[len(fetch_loss):]: + return logs, outs[:len(fetch_loss)] for metric in self._metrics: - metric.update(*res[len(fetch_loss):]) - return res[:len(fetch_loss)] + metric.update(*outs[len(fetch_loss):]) + return logs, outs[:len(fetch_loss)] def _predict_step(self, data, use_program_cache=False, return_numpy=True): logs = {} fetch_vars = self._fetch_vars[self.mode]["outputs"] - fetch_list = self._fetch_list(fetch_vars) + fetch_list, usr_fetch_list = self._fetch_list(fetch_vars) + fetch_list += usr_fetch_list outs = self._executor.run(self.main_program, fetch_list=fetch_list, use_program_cache=use_program_cache, return_numpy=return_numpy) - logs["pred"] = outs + for i, out in enumerate(outs): + logs[fetch_list[i]] = out return logs, outs def _fetch_list(self, fetch_vars): @@ -370,7 +383,18 @@ class Engine: for var in fetch_vars: if var.name in self.main_program.global_block().vars: fetch_list.append(var.name) - return fetch_list + usr_fetch_list = [] + if self._usr_fetch_list: + assert isinstance(self._usr_fetch_list, + list), "'fetch_list' type should be list." + for var in self._usr_fetch_list: + if isinstance(var, str): + if var in self.main_program.global_block().vars: + usr_fetch_list.append(var) + elif isinstance(var, Variable): + if var.name in self.main_program.global_block().vars: + usr_fetch_list.append(var.name) + return fetch_list, usr_fetch_list def _create_dataloader(self, dataset, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index e6a730f0a64d6386ffc0ce36ab88001d9be04ac8..0d96c57c2437fac37daeb799bdc32bb27c16c7eb 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -133,15 +133,16 @@ def train(): train_dataset = MyDataset(batch_num * batch_size) engine.fit(train_dataset, batch_size=batch_size, - steps_per_epoch=batch_num * batch_size) + steps_per_epoch=batch_num * batch_size, + fetch_list=['label']) # eval eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size) + engine.evaluate(eval_dataset, batch_size, fetch_list=['label']) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size) + engine.predict(test_dataset, batch_size, fetch_list=['label']) # save engine.save('./mlp_inf', training=False, mode='predict')