From 562b184cc1dcfccd9df7e6624f1c5c6cd838c3cb Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Mon, 13 Jun 2022 15:23:27 +0800 Subject: [PATCH] [AutoParallel] fix fetch list (#43412) * fix fetch list * fix unittest --- .../distributed/auto_parallel/engine.py | 212 +++++++++--------- .../unittests/auto_parallel/engine_api.py | 19 +- 2 files changed, 121 insertions(+), 110 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index a0b2125f166..dcdd098dcd9 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -17,6 +17,7 @@ import logging from collections import defaultdict import paddle +import paddle.utils as utils import paddle.distributed.auto_parallel as auto from paddle import fluid, static @@ -26,9 +27,9 @@ from paddle.static import InputSpec from paddle.fluid import core from paddle.fluid import program_guard from paddle.fluid.layers.utils import flatten -from paddle.fluid.executor import global_scope +from paddle.fluid.executor import global_scope, _to_name_str from paddle.fluid.backward import append_backward -from paddle.fluid.framework import Operator, Variable +from paddle.fluid.framework import Operator from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import fleet @@ -137,7 +138,8 @@ class Engine: metrics = [] serial_main_prog = self._orig_main_prog.clone() serial_startup_prog = self._orig_startup_prog.clone() - with static.program_guard(serial_main_prog, serial_startup_prog): + with static.program_guard(serial_main_prog, serial_startup_prog), \ + utils.unique_name.guard(): inputs_spec = self.inputs_spec labels_spec = self.labels_spec if self.labels_spec else [] inputs = [s._create_feed_layer() for s in inputs_spec] @@ -256,7 +258,7 @@ class Engine: train_data, batch_size=1, epochs=1, - fetch_list=None, + fetches=None, steps_per_epoch=None, use_program_cache=False, return_numpy=True): @@ -267,134 +269,131 @@ class Engine: "train model is not ready, please call `engine.prepare()` first." train_dataloader = self._create_dataloader(train_data, batch_size, epochs, steps_per_epoch) - self._usr_fetch_list = fetch_list - outputs = [] + usr_fetch = self._to_map_fetch(fetches) + fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) + fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) + for epoch in range(epochs): - for step, data in enumerate(train_dataloader): - logs, outs = self._train_step(data, use_program_cache, - return_numpy) - outputs.append(outs) - train_logs = { - "train_" + name: val - for name, val in logs.items() - } + train_logs = {"epoch": epoch} + for step, _ in enumerate(train_dataloader): + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_program_cache, + return_numpy=return_numpy) + train_logs["step"] = step + # inner fetches + if fetch_loss: + train_logs["train_loss"] = outs[0][0] + # user fetches + user_outs = outs[len(fetch_loss):] + user_fetch_list = fetch_list[len(fetch_loss):] + for i, out in enumerate(user_outs): + train_logs["train_" + + fetch_map[user_fetch_list[i]]] = out[0] self._logger.info(train_logs) - return outputs def evaluate(self, eval_data, batch_size=1, - fetch_list=None, + fetches=None, use_program_cache=False, return_numpy=True): self.mode = 'eval' assert self.mode in self._dist_main_progs, \ "eval model is not ready, please call `engine.prepare()` first." eval_dataloader = self._create_dataloader(eval_data, batch_size) - self._usr_fetch_list = fetch_list - - for step, data in enumerate(eval_dataloader): - eval_logs = dict() - logs, outs = self._eval_step(data, use_program_cache, return_numpy) - eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else [] - for metric in self._metrics: - results = metric.accumulate() - for i, res in enumerate(to_list(results)): - eval_logs["eval_" + metric.name()[i]] = res - for name, val in logs.items(): - eval_logs["eval_" + name] = val + + usr_fetch = self._to_map_fetch(fetches) + fetch_loss = self._inner_fetch(self.fetch_vars["loss"]) + fetch_metrics = self._inner_fetch(self.fetch_vars["metrics"]) + inner_fetch = dict(fetch_loss, **fetch_metrics) + fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) + + for step, _ in enumerate(eval_dataloader): + eval_logs = {"step": step} + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_program_cache, + return_numpy=return_numpy) + # inner fetches + if fetch_loss: + eval_logs["eval_loss"] = outs[0] + # Metric + if fetch_metrics: + metric_out = outs[len(fetch_loss):len(inner_fetch)] + for metric in self._metrics: + metric.update(*metric_out) + results = metric.accumulate() + for i, res in enumerate(to_list(results)): + eval_logs["eval_" + metric.name()[i]] = res + # usr fetches + usr_out = outs[len(inner_fetch):] + usr_fetch_list = fetch_list[len(inner_fetch):] + for i, out in enumerate(usr_out): + eval_logs["eval_" + fetch_map[usr_fetch_list[i]]] = out + # logger self._logger.info(eval_logs) - return eval_logs def predict(self, test_data, batch_size=1, - fetch_list=None, + fetches=None, use_program_cache=False, return_numpy=True): self.mode = 'predict' assert self.mode in self._dist_main_progs, \ "predict model is not ready, please call `engine.prepare()` first." test_dataloader = self._create_dataloader(test_data, batch_size) - self._usr_fetch_list = fetch_list + + usr_fetch = self._to_map_fetch(fetches) + fetch_outputs = self._inner_fetch(self.fetch_vars["outputs"]) + fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch) outputs = [] - for step, data in enumerate(test_dataloader): - logs, outs = self._predict_step(data, use_program_cache, - return_numpy) - outputs.append(outs) - predict_logs = {"pred_" + name: val for name, val in logs.items()} + for step, _ in enumerate(test_dataloader): + predict_logs = {"step": step} + outs = self._executor.run(self.main_program, + fetch_list=fetch_list, + use_program_cache=use_program_cache, + return_numpy=return_numpy) + outputs.append(outs[:len(fetch_outputs)]) + for i, out in enumerate(outs): + predict_logs["pred_" + fetch_map[fetch_list[i]]] = out[0] self._logger.info(predict_logs) + return outputs - def _train_step(self, data, use_program_cache=False, return_numpy=True): - logs = {} - fetch_vars = self._fetch_vars[self.mode]["loss"] - fetch_list, usr_fetch_list = self._fetch_list(fetch_vars) - fetch_list += usr_fetch_list - - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_program_cache, - return_numpy=return_numpy) - for i, out in enumerate(outs): - logs[fetch_list[i]] = out - return logs, outs - - def _eval_step(self, data, use_program_cache=False, return_numpy=True): - logs = {} - metrics = self._fetch_vars[self.mode]["metrics"] - losses = self._fetch_vars[self.mode]["loss"] - fetch_loss, usr_fetch_list = self._fetch_list(losses) - fetch_metrics, usr_fetch_list = self._fetch_list(metrics) - fetch_list = fetch_loss + fetch_metrics - - outs = self._executor.run(self.main_program, - fetch_list=fetch_list + usr_fetch_list, - use_program_cache=use_program_cache, - return_numpy=return_numpy) - usr_out = outs[len(fetch_list):] - for i, out in enumerate(usr_out): - logs[usr_fetch_list[i]] = out - outs = outs[:len(fetch_list)] - if not outs[len(fetch_loss):]: - return logs, outs[:len(fetch_loss)] - for metric in self._metrics: - metric.update(*outs[len(fetch_loss):]) - return logs, outs[:len(fetch_loss)] - - def _predict_step(self, data, use_program_cache=False, return_numpy=True): - logs = {} - fetch_vars = self._fetch_vars[self.mode]["outputs"] - fetch_list, usr_fetch_list = self._fetch_list(fetch_vars) - fetch_list += usr_fetch_list - - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_program_cache, - return_numpy=return_numpy) - for i, out in enumerate(outs): - logs[fetch_list[i]] = out - return logs, outs - - def _fetch_list(self, fetch_vars): - fetch_list = [] - for var in fetch_vars: - if var.name in self.main_program.global_block().vars: - fetch_list.append(var.name) - usr_fetch_list = [] - if self._usr_fetch_list: - assert isinstance(self._usr_fetch_list, - list), "'fetch_list' type should be list." - for var in self._usr_fetch_list: - if isinstance(var, str): - if var in self.main_program.global_block().vars: - usr_fetch_list.append(var) - elif isinstance(var, Variable): - if var.name in self.main_program.global_block().vars: - usr_fetch_list.append(var.name) - return fetch_list, usr_fetch_list + def _local_var(self, var): + var_name = _to_name_str(var) + return var_name in self.main_program.global_block().vars + + def _to_map_fetch(self, fetches): + if not fetches: + return {} + if isinstance(fetches, dict): + fetch_var_names = list(map(_to_name_str, fetches.values())) + usr_fetches = dict(zip(fetch_var_names, list(fetches.keys()))) + elif isinstance(fetches, list): + fetch_var_names = list(map(_to_name_str, fetches)) + usr_fetches = dict(zip(fetch_var_names, fetch_var_names)) + return dict(filter(lambda x: self._local_var(x[0]), + usr_fetches.items())) + + def _inner_fetch(self, fetch_vars): + fetch_list = list( + map(lambda x: x.name, list(filter(self._local_var, fetch_vars)))) + inner_fetches = dict(zip(fetch_list, fetch_list)) + return inner_fetches + + def _fetch_map(self, inner_fetch, usr_fetch): + # replace inner fetch name if usr set for it + for iname in inner_fetch: + if iname in usr_fetch: + inner_fetch[iname] = usr_fetch[iname] + usr_fetch.pop(iname) + fetches = dict(inner_fetch, **usr_fetch) + return list(fetches.keys()), fetches def _create_dataloader(self, dataset, @@ -515,7 +514,8 @@ class Engine: mode = self.mode if training: - assert 'train' in self._serial_main_progs, "training model is not ready, please call `engine.prepare(mode='train')` first." + assert 'train' in self._serial_main_progs, \ + "training model is not ready, please call `engine.prepare()` first." serial_program = self._serial_main_progs["train"] dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_context = self._dist_contexts["train"] @@ -571,3 +571,7 @@ class Engine: @property def serial_startup_program(self): return self._serial_startup_progs[self.mode] + + @property + def fetch_vars(self): + return self._fetch_vars[self.mode] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index f7a1a28aa91..ae69ee08768 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -96,10 +96,11 @@ class MLPLayer(nn.Layer): PP_MESH_1})(out)[0] out = self.dropout(out) out = self.linear2(out) + self.out = out return out -def train(): +def train(fetch): mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, dropout_ratio=0.1, @@ -118,7 +119,6 @@ def train(): dist_strategy.amp = False dist_strategy.pipeline = False dist_strategy.recompute = False - # init parallel optimizer dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) @@ -129,20 +129,26 @@ def train(): strategy=dist_strategy) engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + # fetch + if fetch: + fetches = {'out': mlp.out} + else: + fetches = None + # train train_dataset = MyDataset(batch_num * batch_size) engine.fit(train_dataset, batch_size=batch_size, steps_per_epoch=batch_num * batch_size, - fetch_list=['label']) + fetches=fetches) # eval eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size, fetch_list=['label']) + engine.evaluate(eval_dataset, batch_size, fetches=fetches) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size, fetch_list=['label']) + engine.predict(test_dataset, batch_size, fetches=fetches) # save temp_dir = tempfile.TemporaryDirectory() @@ -152,4 +158,5 @@ def train(): if __name__ == "__main__": - train() + train(fetch=True) + train(fetch=False) -- GitLab