From 8c7cb3d6a9cd95b7f552e48202b0d778c15cb4f7 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Thu, 2 Jun 2022 10:07:27 +0800
Subject: [PATCH] [AutoParallel] engine.prepare only once (#43093)

* prepare only once
---
 .../distributed/auto_parallel/dist_context.py |   8 +-
 .../distributed/auto_parallel/engine.py       | 281 +++++++++++-------
 .../unittests/auto_parallel/engine_api.py     |  15 +-
 3 files changed, 184 insertions(+), 120 deletions(-)
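
Usage note: with this patch, a single `engine.prepare(...)` call builds,
plans, parallelizes and initializes the programs for all three modes
('train', 'eval' and 'predict'), so the per-mode `prepare(..., mode=...)`
calls used before are no longer needed. Below is a minimal sketch of the
new flow, mirroring the updated engine_api.py test; `MLPLayer`, `MyDataset`,
`batch_size`, `hidden_size` and `batch_num` are the helpers and constants
defined in that test file, not part of this patch:

    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.static import InputSpec
    from paddle.distributed.auto_parallel.engine import Engine

    # Model, loss and optimizer as constructed in engine_api.py
    # (most constructor arguments elided here; see the test).
    mlp = MLPLayer()
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(epsilon=1e-08, grad_clip=None)

    inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
    labels_spec = InputSpec([batch_size], 'int64', 'label')

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    engine = Engine(mlp,
                    inputs_spec=inputs_spec,
                    labels_spec=labels_spec,
                    strategy=dist_strategy)
    # prepare() is now called exactly once and covers train/eval/predict.
    engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())

    engine.fit(MyDataset(batch_num * batch_size),
               batch_size=batch_size,
               steps_per_epoch=batch_num * batch_size)
    engine.evaluate(MyDataset(batch_size), batch_size)
    engine.predict(MyDataset(batch_size), batch_size)
    engine.save('./mlp_inf', training=False, mode='predict')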
diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index 6a38b53cf2..df4c92641f 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -343,8 +343,12 @@ class DistributedContext:
         self._serial_startup_program = self._original_serial_startup_program
         if not self._serial_loss:
             if isinstance(self._original_serial_loss, list):
-                assert len(self._original_serial_loss) == 1
-                self._serial_loss = self._original_serial_loss[0]
+                if len(self._original_serial_loss) == 1:
+                    self._serial_loss = self._original_serial_loss[0]
+                elif len(self._original_serial_loss) == 0:
+                    self._serial_loss = self._original_serial_loss
+                else:
+                    raise ValueError("Multiple loss vars are not supported.")
             else:
                 self._serial_loss = self._original_serial_loss
         if not self._serial_optimizer:
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index ab9391cf66..ea003dfff4 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -19,7 +19,7 @@ from collections import defaultdict
 
 import paddle
 import paddle.distributed.auto_parallel as auto
-from paddle import fluid
+from paddle import fluid, static
 from paddle.io import Dataset
 from paddle.metric import Metric
 from paddle.static import InputSpec
@@ -71,8 +71,8 @@ class Engine:
         self._logger = get_logger(logging.INFO)
 
         self._default_strategy = None
-        self._orig_main_prog = fluid.default_main_program()
-        self._orig_startup_prog = fluid.default_startup_program()
+        self._orig_main_prog = static.default_main_program()
+        self._orig_startup_prog = static.default_startup_program()
         self._orig_dist_context = get_default_distributed_context()
         self._dist_contexts = {}
         self._serial_main_progs = {}
@@ -87,28 +87,131 @@ class Engine:
                 loss=None,
                 gradient_scale=True,
                 metrics=None,
-                mode='train',
                 all_ranks=False):
+        if optimizer and not isinstance(optimizer, (
+                paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
+            raise TypeError(
+                "'optimizer' must be an instance of `paddle.optimizer.Optimizer`" \
+                " or `paddle.fluid.optimizer.Optimizer`."
+            )
         self._optimizer = optimizer
-        # TODO: check loss type
+
+        if loss and not isinstance(loss,
+                                   paddle.nn.Layer) and not callable(loss):
+            raise TypeError(
+                "'loss' must be an instance of `paddle.nn.Layer` or a callable function."
+            )
         self._loss = loss
+
+        metrics = metrics or []
+        for metric in to_list(metrics):
+            assert isinstance(metric, Metric), \
+                "{} is not a subclass of Metric".format(
+                    metric.__class__.__name__)
         self._metrics = to_list(metrics)
-        self._mode = mode
         self._gradient_scale = gradient_scale
+
+        self._planned_mode = None
+        self._modes = ['train', 'eval', 'predict']
         # Build forward program
-        self._build(mode)
-        # Do the planning process
-        planner = Planner(mode, self._dist_contexts[mode])
-        planner.plan()
+        self._build()
+
+        # Do the auto parallel process for each mode
+        for mode in self._modes:
+            # Do the planning process
+            self._plan(mode)
+            # Do the parallel process
+            self._parallel(mode, all_ranks)
+            # Init comm and startup program
+            self._initialize(mode)
+
+    def _build(self):
+        for mode in self._modes:
+            serial_main_prog = self._serial_main_progs.get(mode, None)
+            if serial_main_prog is not None:
+                continue
+
+            losses = []
+            metrics = []
+            serial_main_prog = self._orig_main_prog.clone()
+            serial_startup_prog = self._orig_startup_prog.clone()
+            with static.program_guard(serial_main_prog, serial_startup_prog):
+                inputs_spec = self.inputs_spec
+                labels_spec = self.labels_spec if self.labels_spec else []
+                inputs = [s._create_feed_layer() for s in inputs_spec]
+                labels = [s._create_feed_layer() for s in labels_spec]
+                outputs = to_list(self.model(*inputs))
+                if mode != "predict" and self._loss:
+                    losses = to_list(self._loss(*(outputs + labels)))
+
+                if mode != "predict":
+                    for metric in self._metrics:
+                        metrics.extend(
+                            to_list(metric.compute(*(outputs + labels))))
+
+            default_ctx = get_default_distributed_context()
+            if not default_ctx.has_annotation or self._default_strategy:
+                inputs = [self._set_data_parallel(var) for var in inputs]
+                labels = [self._set_data_parallel(var) for var in labels]
+
+            # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
+            feed_vars = {"inputs": inputs, "labels": labels}
+
+            # self._fetch_vars[mode] = {
+            #     "outputs": flatten(outputs),
+            #     "loss": losses,
+            #     "metrics": metrics
+            # }
+            fetch_vars = {
+                "outputs": flatten(outputs),
+                "loss": losses,
+                "metrics": metrics
+            }
+
+            self._dist_contexts[mode] = DistributedContext(
+                serial_main_prog, serial_startup_prog, self._optimizer, losses,
+                feed_vars, fetch_vars, self.cluster, self.strategy)
+            self._dist_contexts[mode].gradient_scale = self._gradient_scale
+
+    def _plan(self, mode):
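+        # The first planned mode supplies the reference parallel strategy;
+        # every later mode copies its op dist_attrs via _init_dist_context().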
+        if self._planned_mode is None:
+            self._planned_mode = mode
+        else:
+            self._init_dist_context(mode)
+
+        self.planner = Planner(mode, self._dist_contexts[mode])
+        self.planner.plan()
+
+    def _parallel(self, mode, all_ranks):
         # Parallelize program based on the planner's results
         # For now, the completer has to be passed to the planner,
         # because we may use it to complete the annotation of the backward and update.
-        parallelizer = Parallelizer(mode, planner.completer,
+        parallelizer = Parallelizer(mode, self.planner.completer,
                                     self._dist_contexts[mode])
         if not all_ranks:
             parallelizer.parallel(self._cur_rank)
         else:
             parallelizer.parallel_all()
+
+    def _init_dist_context(self, mode):
+        # Init dist_context[mode] with the first planned dist_context to
+        # guarantee that the train/eval/predict modes share the same
+        # parallel strategy.
+        dist_context = self._dist_contexts[mode]
+        origin_main_prog = dist_context._original_serial_main_program
+        ref_mode = self._planned_mode
+        ref_dist_context = self._dist_contexts[ref_mode]
+        ref_origin_main_prog = ref_dist_context._original_serial_main_program
+        ref_blocks = ref_origin_main_prog.blocks
+        for ib, block in enumerate(origin_main_prog.blocks):
+            for iop, op in enumerate(block.ops):
+                ref_op = ref_blocks[ib].ops[iop]
+                assert op.type == ref_op.type, \
+                    "'{}' mode op '{}' is different from '{}' mode op '{}'.".format(
+                        mode, op.type, ref_mode, ref_op.type)
+                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
+                    ref_op)
+                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
+
+    def _initialize(self, mode):
         # Get the current content from the distributed context
         self._serial_main_progs[mode] = self._dist_contexts[
             mode].serial_main_program
@@ -120,52 +223,7 @@ class Engine:
             mode].dist_startup_programs
         self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
         self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
-        # Init comm and startup program
-        self._initialize(mode)
-
-    def _build(self, mode):
-        serial_main_prog = self._serial_main_progs.get(mode, None)
-        if serial_main_prog is not None:
-            return
-
-        losses = []
-        metrics = []
-        serial_main_prog = self._orig_main_prog.clone()
-        serial_startup_prog = self._orig_startup_prog.clone()
-        with fluid.program_guard(serial_main_prog, serial_startup_prog):
-            inputs_spec = self.inputs_spec
-            labels_spec = self.labels_spec if self.labels_spec else []
-            inputs = [s._create_feed_layer() for s in inputs_spec]
-            labels = [s._create_feed_layer() for s in labels_spec]
-            outputs = to_list(self.model(*inputs))
-            if mode != "predict" and self._loss:
-                losses = to_list(self._loss(*(outputs + labels)))
-
-        default_ctx = get_default_distributed_context()
-        if not default_ctx.has_annotation or self._default_strategy:
-            inputs = [self._set_data_parallel(var) for var in inputs]
-            labels = [self._set_data_parallel(var) for var in labels]
-
-        # self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
-        feed_vars = {"inputs": inputs, "labels": labels}
-
-        # self._fetch_vars[mode] = {
-        #     "outputs": flatten(outputs),
-        #     "loss": losses,
-        #     "metrics": metrics
-        # }
-        fetch_vars = {
-            "outputs": flatten(outputs),
-            "loss": losses,
-            "metrics": metrics
-        }
-
-        self._dist_contexts[mode] = DistributedContext(
-            serial_main_prog, serial_startup_prog, self._optimizer, losses,
-            feed_vars, fetch_vars, self.cluster, self.strategy)
-        self._dist_contexts[mode].gradient_scale = self._gradient_scale
-
-    def _initialize(self, mode):
         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
             # instantiate communication by process_mapping.
@@ -203,7 +261,7 @@ class Engine:
         # TODO: evaluate after training
         self.mode = 'train'
         assert self.mode in self._dist_main_progs, \
-            "train model is not ready, please call `engine.prepare(mode='train')` first."
+            "The train model is not ready; please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)
 
@@ -227,16 +285,19 @@ class Engine:
                  return_numpy=True):
         self.mode = 'eval'
         assert self.mode in self._dist_main_progs, \
-            "eval model is not ready, please call `engine.prepare(mode='eval')` first."
+            "The eval model is not ready; please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
 
-        outputs = []
+        eval_logs = dict()
         for step, data in enumerate(eval_dataloader):
-            logs, outs = self._eval_step(data, use_program_cache, return_numpy)
-            outputs.append(outs)
-            predict_logs = {"eval_" + name: val for name, val in logs.items()}
-            self._logger.info(predict_logs)
-        return outputs
+            outs = self._eval_step(data, use_program_cache, return_numpy)
+            eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
+            for metric in self._metrics:
+                results = metric.accumulate()
+                for i, res in enumerate(to_list(results)):
+                    eval_logs["eval_" + metric.name()[i]] = res
+            self._logger.info(eval_logs)
+        return eval_logs
 
     def predict(self,
                 test_data,
@@ -245,7 +306,7 @@ class Engine:
                 return_numpy=True):
         self.mode = 'predict'
         assert self.mode in self._dist_main_progs, \
-            "predict model is not ready, please call `engine.prepare(mode='predict')` first."
+            "The predict model is not ready; please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
 
         outputs = []
@@ -262,57 +323,53 @@ class Engine:
 
     def _train_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = self._fetch_vars[self.mode]["loss"][0]
-        if fetch_var.name not in dist_main_prog.global_block().vars:
-            loss = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["loss"] = None
-        else:
-            loss = self._executor.run(dist_main_prog,
-                                      fetch_list=to_list(fetch_var),
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["loss"] = loss
+        fetch_vars = self._fetch_vars[self.mode]["loss"]
+        fetch_list = self._fetch_list(fetch_vars)
+
+        loss = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        logs["loss"] = loss
         return logs, loss
 
     def _eval_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = self._fetch_vars[self.mode]["loss"][0]
-
-        if fetch_var.name not in dist_main_prog.global_block().vars:
-            outs = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["loss"] = outs
-        else:
-            outs = self._executor.run(dist_main_prog,
-                                      fetch_list=fetch_var,
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["loss"] = outs
-        return logs, outs
+        metrics = self._fetch_vars[self.mode]["metrics"]
+        losses = self._fetch_vars[self.mode]["loss"]
+        fetch_loss = self._fetch_list(losses)
+        fetch_metrics = self._fetch_list(metrics)
+        fetch_list = fetch_loss + fetch_metrics
+
+        res = self._executor.run(self.main_program,
+                                 fetch_list=fetch_list,
+                                 use_program_cache=use_program_cache,
+                                 return_numpy=return_numpy)
+        if not res[len(fetch_loss):]:
+            return res[:len(fetch_loss)]
+        for metric in self._metrics:
+            metric.update(*res[len(fetch_loss):])
+        return res[:len(fetch_loss)]
 
     def _predict_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = []
-        for var in self._fetch_vars[self.mode]["outputs"]:
-            if var.name in dist_main_prog.global_block().vars:
-                fetch_var.append(var)
-
-        if fetch_var is []:
-            outs = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["pred"] = outs
-        else:
-            outs = self._executor.run(dist_main_prog,
-                                      fetch_list=fetch_var,
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["pred"] = outs
+        fetch_vars = self._fetch_vars[self.mode]["outputs"]
+        fetch_list = self._fetch_list(fetch_vars)
+
+        outs = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        logs["pred"] = outs
         return logs, outs
 
+    def _fetch_list(self, fetch_vars):
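+        # Keep only the vars that still exist in this rank's parallelized
+        # program; the parallelizer may prune fetch vars on some ranks.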
+        fetch_list = []
+        for var in fetch_vars:
+            if var.name in self.main_program.global_block().vars:
+                fetch_list.append(var.name)
+        return fetch_list
+
     def _create_dataloader(self,
                            dataset,
                            batch_size,
@@ -323,7 +380,9 @@ class Engine:
         dist_context = self._dist_contexts[self.mode]
         dist_main_block = dist_main_prog.global_block()
 
-        # get feed_list from dist_program
+        # NOTE: Get feed_list from the dist program, then insert the dataloader
+        # op with the sharded var shape. Since the predict program does not
+        # contain a labels var, filter the dataset's values by the length of feed_list.
         inputs_var = self._feed_vars[self.mode]["inputs"]
         labels_var = self._feed_vars[self.mode]["labels"]
         feed_list = []
@@ -342,7 +401,7 @@ class Engine:
 
         # insert read op at the end of program
         places = paddle.static.cuda_places()
-        with fluid.program_guard(dist_main_prog, dist_startup_prog):
+        with static.program_guard(dist_main_prog, dist_startup_prog):
             dataloader = NonIterableGeneratorLoader(
                 dataset,
                 feed_list,
@@ -468,10 +527,6 @@ class Engine:
     def mode(self, mode):
         self._mode = mode
 
-    @property
-    def metrics(self):
-        return self._metrics
-
     @property
     def main_program(self):
         return self._dist_main_progs[self.mode][self._cur_rank]
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
index 66addd1be0..23bab5ffa2 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py
@@ -107,7 +107,6 @@ def train():
         epsilon=1e-08,
         grad_clip=None)
 
-    dataset = MyDataset(batch_num * batch_size)
    inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
     labels_spec = InputSpec([batch_size], 'int64', 'label')
 
@@ -119,23 +118,29 @@ def train():
     dist_strategy.semi_auto = True
     fleet.init(is_collective=True, strategy=dist_strategy)
 
+    # init engine
     engine = Engine(
         mlp,
         inputs_spec=inputs_spec,
         labels_spec=labels_spec,
         strategy=dist_strategy)
-    engine.prepare(optimizer, loss)
-    engine.fit(dataset,
+    engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
+
+    # train
+    train_dataset = MyDataset(batch_num * batch_size)
+    engine.fit(train_dataset,
               batch_size=batch_size,
               steps_per_epoch=batch_num * batch_size)
 
+    # eval
     eval_dataset = MyDataset(batch_size)
-    engine.prepare(optimizer, loss, mode='eval')
     engine.evaluate(eval_dataset, batch_size)
 
+    # predict
     test_dataset = MyDataset(batch_size)
-    engine.prepare(mode='predict')
     engine.predict(test_dataset, batch_size)
+
+    # save
     engine.save('./mlp_inf', training=False, mode='predict')
-- 
GitLab