From 83a4b26a133ac14966c06b173af02ef0375694fd Mon Sep 17 00:00:00 2001
From: Yulong Ao
Date: Tue, 10 May 2022 17:59:26 +0800
Subject: [PATCH] [Auto Parallel] Refactor the engine api and parallelizer (#42576)

* [Auto Parallel] Refactor the engine api and parallelizer

* [Auto Parallel] Fix the default dist op for the slice op

* [Auto Parallel] Fix the format of planner.py

* [Auto Parallel] Fix a bug
---
 .../auto_parallel/dist_attribute.py           |   4 +-
 .../distributed/auto_parallel/dist_context.py |  44 ++--
 .../distributed/auto_parallel/engine.py       | 189 +++---------------
 .../auto_parallel/operators/dist_default.py   |   8 +-
 .../auto_parallel/operators/dist_eltwise.py   |  93 +++++++--
 .../auto_parallel/parallelizer_v2.py          | 172 ++++++++++++++++
 .../distributed/auto_parallel/planner.py      |   1 -
 .../distributed/auto_parallel/planner_v2.py   |  42 ++++
 .../auto_parallel/high_order_grad.py          |   5 +-
 .../auto_parallel/test_dist_slice.py          |   1 -
 10 files changed, 338 insertions(+), 221 deletions(-)
 create mode 100644 python/paddle/distributed/auto_parallel/parallelizer_v2.py
 create mode 100755 python/paddle/distributed/auto_parallel/planner_v2.py

diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py
index 857f141f30b..6fa5b756c75 100644
--- a/python/paddle/distributed/auto_parallel/dist_attribute.py
+++ b/python/paddle/distributed/auto_parallel/dist_attribute.py
@@ -485,10 +485,10 @@ class OperatorDistributedAttribute:
             self.process_mesh)

         for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
-            str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
+            str += "\n\t\t{}'s (input): {},".format(arg_name, tensor_dist_attr)

         for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
-            str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
+            str += "\n\t\t{}'s (output): {},".format(arg_name, tensor_dist_attr)

         str += "\n\t\timpl type: {}, ".format(self._impl_type)
         str += "impl idx: {}".format(self._impl_idx)
diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py
index 5082ac987f4..f9d77a0077c 100644
--- a/python/paddle/distributed/auto_parallel/dist_context.py
+++ b/python/paddle/distributed/auto_parallel/dist_context.py
@@ -55,10 +55,10 @@ class DistributedContext:
     def __init__(self,
                  serial_main_prog=None,
                  serial_startup_prog=None,
-                 dist_main_progs=None,
-                 dist_startup_progs=None,
-                 serial_loss=None,
                  serial_optimizer=None,
+                 serial_loss=None,
+                 feed_vars=None,
+                 fetch_vars=None,
                  strategy=None):
         # Data members related to original programs (unchanged)
         self._original_serial_main_program = serial_main_prog
@@ -75,8 +75,10 @@
         # Data members related to programs (changed)
         self._serial_main_program = None
         self._serial_startup_program = None
-        self._serial_loss = None
-        self._serial_optimizer = None
+        self._serial_loss = serial_loss
+        self._serial_optimizer = serial_optimizer
+        self._serial_feed_vars = feed_vars
+        self._serial_fetch_vars = fetch_vars

         # Data members related to the program
         self._dist_tensors_for_program = {}
@@ -92,12 +94,8 @@

         # Data members related to the distributed programs
         # Distributed programs
-        self._dist_main_programs = dist_main_progs
-        if not self._dist_main_programs:
-            self._dist_main_programs = {}
-        self._dist_startup_programs = dist_startup_progs
-        if not self._dist_startup_programs:
-            self._dist_startup_programs = {}
+        self._dist_main_programs = {}
+        self._dist_startup_programs = {}

         # Distributed Strategy
         self._strategy = strategy
@@ -132,34 +130,26 @@
     def serial_startup_program(self):
         return self._serial_startup_program

-    # @serial_startup_program.setter
-    # def serial_startup_program(self, serial_startup_program):
-    #     self._serial_startup_program = serial_startup_program
-
     @property
    def serial_loss(self):
         return self._serial_loss

-    # @serial_loss.setter
-    # def serial_loss(self, serial_loss):
-    #     self._serial_loss = serial_loss
-
     @property
     def serial_optimizer(self):
         return self._serial_optimizer

-    # @serial_optimizer.setter
-    # def serial_optimizer(self, serial_optimizer):
-    #     self._serial_optimizer = serial_optimizer
+    @property
+    def serial_feed_vars(self):
+        return self._serial_feed_vars
+
+    @property
+    def serial_fetch_vars(self):
+        return self._serial_fetch_vars

     @property
     def strategy(self):
         return self._strategy

-    # @strategy.setter
-    # def strategy(self, strategy):
-    #     self._strategy = strategy
-
     @property
     def serial_graph(self):
         return self._serial_graph
@@ -678,7 +668,7 @@
                     dist_op.serial_op.type)
             if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                 assert False, "Operator {} has a wrong distributed attributes {}.".format(
-                    dist_op.serial_op.type, dist_tensor.dist_attr)
+                    dist_op.serial_op.type, dist_op.dist_attr)
         return True

     def __deepcopy__(self, memo):
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index ea6aeb513ff..b9ee6d93fd2 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -34,12 +34,9 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.distributed.utils import get_logger
 from paddle.distributed.passes import new_pass, PassContext

-from .mapper import mapping
 from .cluster import Cluster
-from .reshard import Resharder
-from .planner import Planner
-from .completion import Completer
-from .partitioner import Partitioner
+from .planner_v2 import Planner
+from .parallelizer_v2 import Parallelizer
 from .dist_op import DistributedOperator
 from .dist_saver import DistributedSaver
 from .dist_loader import NonIterableGeneratorLoader
@@ -79,7 +76,6 @@ class Engine:
         self._dist_main_progs = defaultdict(dict)  # dist main programs
         self._dist_startup_progs = defaultdict(dict)  # dist startup programs
         self._dist_contexts = {}
-        self._pass_contexts = {}
         self._feed_vars = {}
         self._fetch_vars = {}
@@ -94,10 +90,27 @@
         self._loss = loss
         self._metrics = to_list(metrics)
         self._mode = mode
-        self._build(mode)  # build forward program
-        self._plan(mode)  # completion & planner
-        self._parallel(mode, all_ranks)  # parallel
-        self._initialize(mode)  # init comm and startup program
+        # Build forward program
+        self._build(mode)
+        # Do the planning process
+        planner = Planner(mode, self._dist_contexts[mode])
+        planner.plan()
+        # Parallelize program based on the planner's results
+        # For now, the completer from the planner has to be passed to the parallelizer,
+        # because we may use it to complete the annotation of the backward and update.
+        parallelizer = Parallelizer(mode, planner.completer,
+                                    self._dist_contexts[mode])
+        if not all_ranks:
+            parallelizer.parallel(self._cur_rank)
+        else:
+            parallelizer.parallel_all()
+        # Get the distributed main programs and startup programs
+        self._dist_main_progs[mode] = self._dist_contexts[
+            mode].dist_main_programs
+        self._dist_startup_progs[mode] = self._dist_contexts[
+            mode].dist_startup_programs
+        # Init comm and startup program
+        self._initialize(mode)

     def _build(self, mode):
         serial_main_prog = self._serial_main_progs.get(mode, None)
@@ -133,34 +146,9 @@
         self._serial_main_progs[mode] = serial_main_prog
         self._serial_startup_progs[mode] = serial_startup_prog
         self._dist_contexts[mode] = DistributedContext(
-            serial_main_prog, serial_startup_prog, self._dist_main_progs[mode],
-            self._dist_startup_progs[mode])
-        self._pass_contexts[mode] = PassContext()
-
-    def _plan(self, mode):
-
-        # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need
-        # dependency of backward-forward ops in forward completition.
-        defualt_ctx = get_default_distributed_context()
-        self._dist_contexts[mode]._dist_op_context = defualt_ctx.dist_op_context
-
-        # Complete the distributed annotation
-        serial_main_prog = self._serial_main_progs[mode]
-        self._completer = Completer(self._dist_contexts[mode])
-        self._completer.complete_forward_annotation(serial_main_prog)
-        # TODO: add auto planner process
-        # parse forward sub block
-        self._dist_contexts[mode].block_state.parse_forward_blocks(
-            serial_main_prog)
-
-    def _parallel(self, mode, all_ranks=False):
-        if not all_ranks:
-            self._parallel_program(mode, self._cur_rank)
-        else:
-            world_process_group = get_world_process_group()
-            all_ranks = world_process_group.ranks
-            for rank in all_ranks:
-                self._parallel_program(mode, rank)
+            self._serial_main_progs[mode], self._serial_startup_progs[mode],
+            self._optimizer, losses, self._feed_vars[mode],
+            self._fetch_vars[mode], self.strategy)

     def _initialize(self, mode):
         if self._nranks > 1:
@@ -189,131 +177,6 @@
             prune_startup_prog = dist_startup_prog._prune(uninitialized)
             self._executor.run(prune_startup_prog)

-    def _parallel_program(self, mode, rank):
-        serial_main_program = self._serial_main_progs[mode]
-        serial_startup_program = self._serial_startup_progs[mode]
-        dist_context = self._dist_contexts[mode]
-        if mode == "train" and self._optimizer:
-            # Generate backward
-            serial_loss = self._fetch_vars[mode]["loss"][0]
-            params_grads = self._generate_backward(
-                serial_main_program, serial_startup_program, serial_loss)
-            # Apply pre optimization passes
-            self._apply_pre_optimization(serial_main_program,
-                                         serial_startup_program, serial_loss,
-                                         params_grads)
-            # Do logical partition
-            partitioner = Partitioner(dist_context, rank)
-            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
-                serial_main_program, serial_startup_program, params_grads)
-            # Generate optimizer
-            self._generate_optimizer(dist_main_prog, dist_startup_prog,
-                                     dist_params_grads)
-            # Do reshard process
-            set_grad_var_shape(dist_main_prog, dist_context)
-            make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
-                                  dist_context, dist_params_grads)
-            resharder.reshard()
-            # Apply post optimization passes
-            self._apply_post_optimization(dist_main_prog, dist_startup_prog,
-                                          rank, dist_params_grads)
-        else:
-            # Apply pre optimization passes
-            self._apply_pre_optimization(serial_main_program,
-                                         serial_startup_program, None,
-                                         None)
-            # Do logical partition
-            partitioner = Partitioner(dist_context, rank)
-            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
-                serial_main_program, serial_startup_program, [])
-            # Do reshard process
-            make_data_unshard(dist_main_prog, dist_startup_prog, dist_context)
-            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
-                                  dist_context, [], 1)
-            resharder.reshard()
-
-        # clone program for test
-        if mode != 'train':
-            dist_main_prog = dist_main_prog.clone(for_test=True)
-            dist_startup_prog = dist_startup_prog.clone(for_test=True)
-
-        self._dist_main_progs[mode][rank] = dist_main_prog
-        self._dist_startup_progs[mode][rank] = dist_startup_prog
-
-    def _generate_backward(self, main_program, startup_program, loss):
-        with program_guard(main_program, startup_program):
-            params_grads = append_backward(
-                loss,
-                distop_context=self._dist_contexts[self.mode].dist_op_context)
-        self._completer.complete_backward_annotation(main_program)
-        self._dist_contexts[self.mode].block_state.parse_backward_blocks(
-            main_program)
-        return params_grads
-
-    def _generate_optimizer(self, main_program, startup_program, params_grads):
-        with program_guard(main_program, startup_program):
-            optimizer_ops = copy.deepcopy(self._optimizer).apply_gradients(
-                params_grads)
-        self._completer.complete_update_annotation(main_program)
-        return optimizer_ops
-
-    def _apply_pre_optimization(self, main_program, startup_program, loss,
-                                params_grads):
-
-        # apply amp pass
-        if self.strategy.amp:
-            config = copy.deepcopy(self.strategy.amp_configs)
-            config["dist_context"] = self._dist_contexts[self.mode]
-            config["params_grads"] = params_grads
-            config["loss"] = loss
-            config["input_data"] = self._feed_vars[self.mode][
-                "inputs"] + self._feed_vars[self.mode]["labels"]
-            if config["use_pure_fp16"]:
-                config["base_opt"] = self._optimizer
-                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
-                auto_parallel_fp16_pass.apply([main_program],
-                                              [startup_program],
-                                              self._pass_contexts[self.mode])
-            else:
-                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
-                auto_parallel_amp_pass.apply([main_program], [startup_program],
-                                             self._pass_contexts[self.mode])
-
-        # apply recompute pass
-        if self.strategy.recompute:
-            config = copy.deepcopy(self.strategy.recompute_configs)
-            config["dist_context"] = self._dist_contexts[self.mode]
-            config["no_grad_set"] = None
-            config["loss"] = loss
-            auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
-                                                    config)
-            auto_parallel_recompute_pass.apply([main_program],
-                                               [startup_program],
-                                               self._pass_contexts[self.mode])
-
-    def _apply_post_optimization(self, main_program, startup_program, rank,
-                                 params_grads):
-        if self.strategy.sharding:
-            config = copy.deepcopy(self.strategy.sharding_configs)
-            config["dist_context"] = self._dist_contexts[self.mode]
-            config["params_grads"] = params_grads
-            config["global_rank"] = rank
-            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
-                                                   config)
-            auto_parallel_sharding_pass.apply([main_program],
-                                              [startup_program],
-                                              self._pass_contexts[self.mode])
-
-        if self.strategy.gradient_merge:
-            config = copy.deepcopy(self.strategy.gradient_merge_configs)
-            config["dist_context"] = self._dist_contexts[self.mode]
-            config["params_grads"] = params_grads
-            auto_parallel_gradient_merge_pass = new_pass(
-                "auto_parallel_gradient_merge_pass", config)
-            auto_parallel_gradient_merge_pass.apply(
-                [main_program], [startup_program],
-                self._pass_contexts[self.mode])
-
     def fit(self,
             train_data,
             batch_size=1,
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 0696b728d16..563d247af3b 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -201,10 +201,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
-        # The following statement will be replaced by a more elegent way
-        if op_desc.type() == "shape" \
-                or op_desc.type() == "slice" \
-                or op_desc.type() == "while":
+
+        if op_desc.type() == "while":
             return False

         input_names = op_desc.input_names()
@@ -273,6 +271,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 )[0])
             if input_tensor.is_parameter: continue
+            if op_desc.type() in ["shape", "slice"]:
+                continue
             serial_tensor = dist_op.get_serial_output(arg_name)
             if serial_tensor.is_parameter: continue
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
index aac7f16b690..78589afc498 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
@@ -80,12 +80,20 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         op_dist_attr = dist_op.dist_attr
         dims_mapping_list = []
         output_arg_names = op_desc.output_arg_names()
+        max_dims_mapping_len = -1
         for arg_name in output_arg_names:
             dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
+            if max_dims_mapping_len < len(dims_mapping):
+                max_dims_mapping_len = len(dims_mapping)
             dims_mapping_list.append(dims_mapping)

-        if compute_compatible_dims_mapping(dims_mapping_list) is None:
-            return False
+        for idx in range(max_dims_mapping_len):
+            dim_mappings = []
+            for dims_mapping in dims_mapping_list:
+                if idx < len(dims_mapping):
+                    dim_mappings.append(dims_mapping[-(idx + 1)])
+            if compute_compatible_dim_mapping(dim_mappings) is None:
+                return False
         return True

     def is_auto_compatible(self, dist_op):
@@ -94,19 +102,26 @@
             return False
         op_dist_attr = dist_op.dist_attr
         dims_mapping_list = []
+
         input_arg_names = op_desc.input_arg_names()
-        max_dims_mapping_len = -1
+        input_max_dims_mapping_len = -1
         for arg_name in input_arg_names:
             dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
-            if max_dims_mapping_len < len(dims_mapping):
-                max_dims_mapping_len = len(dims_mapping)
+            if input_max_dims_mapping_len < len(dims_mapping):
+                input_max_dims_mapping_len = len(dims_mapping)
             dims_mapping_list.append(dims_mapping)
+
         output_arg_names = op_desc.output_arg_names()
+        output_max_dims_mapping_len = -1
         for arg_name in output_arg_names:
             dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
-            assert len(dims_mapping) == max_dims_mapping_len
+            if output_max_dims_mapping_len < len(dims_mapping):
+                output_max_dims_mapping_len = len(dims_mapping)
             dims_mapping_list.append(dims_mapping)
+
+        assert input_max_dims_mapping_len == output_max_dims_mapping_len
+        max_dims_mapping_len = input_max_dims_mapping_len

         for idx in range(max_dims_mapping_len):
             dim_mappings = []
             for dims_mapping in dims_mapping_list:
@@ -121,35 +136,58 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
         changed = False
         op_desc = dist_op.serial_op.desc
         op_dist_attr = dist_op.dist_attr
+        dims_mapping_list = []
+
         input_arg_names = op_desc.input_arg_names()
         input_dims_mapping_dict = {}
         input_dims_mapping_lens = {}
-        max_dims_mapping_len = -1
+        input_max_dims_mapping_len = -1
         for arg_name in input_arg_names:
             dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
-            if max_dims_mapping_len < len(dims_mapping):
-                max_dims_mapping_len = len(dims_mapping)
+            if input_max_dims_mapping_len < len(dims_mapping):
+                input_max_dims_mapping_len = len(dims_mapping)
             input_dims_mapping_dict[arg_name] = dims_mapping
             input_dims_mapping_lens[arg_name] = len(dims_mapping)
-
-        dims_mapping_list = []
         for arg_name in input_arg_names:
-            if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
-                new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
+            if input_dims_mapping_lens[arg_name] < input_max_dims_mapping_len:
+                new_dims_mapping = [
+                    -1 for _ in range(input_max_dims_mapping_len)
+                ]
                 for i in range(input_dims_mapping_lens[arg_name]):
-                    new_idx = (max_dims_mapping_len -
+                    new_idx = (input_max_dims_mapping_len -
                                input_dims_mapping_lens[arg_name]) + i
                     new_dims_mapping[new_idx] = input_dims_mapping_dict[
                         arg_name][i]
                 dims_mapping_list.append(new_dims_mapping)
             else:
                 dims_mapping_list.append(input_dims_mapping_dict[arg_name])
+
         output_arg_names = op_desc.output_arg_names()
+        output_dims_mapping_dict = {}
+        output_dims_mapping_lens = {}
+        output_max_dims_mapping_len = -1
         for arg_name in output_arg_names:
             dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
-            assert len(dims_mapping) == max_dims_mapping_len
-            dims_mapping_list.append(dims_mapping)
+            if output_max_dims_mapping_len < len(dims_mapping):
+                output_max_dims_mapping_len = len(dims_mapping)
+            output_dims_mapping_dict[arg_name] = dims_mapping
+            output_dims_mapping_lens[arg_name] = len(dims_mapping)
+        for arg_name in output_arg_names:
+            if output_dims_mapping_lens[arg_name] < output_max_dims_mapping_len:
+                new_dims_mapping = [
+                    -1 for _ in range(output_max_dims_mapping_len)
+                ]
+                for i in range(output_dims_mapping_lens[arg_name]):
+                    new_idx = (output_max_dims_mapping_len -
+                               output_dims_mapping_lens[arg_name]) + i
+                    new_dims_mapping[new_idx] = output_dims_mapping_dict[
+                        arg_name][i]
+                dims_mapping_list.append(new_dims_mapping)
+            else:
+                dims_mapping_list.append(output_dims_mapping_dict[arg_name])
+
+        assert input_max_dims_mapping_len == output_max_dims_mapping_len
+        max_dims_mapping_len = input_max_dims_mapping_len

         compatible_dims_mapping = compute_compatible_dims_mapping(
             dims_mapping_list)
         if compatible_dims_mapping is None:
@@ -175,11 +213,24 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
                     changed = True

         for arg_name in output_arg_names:
-            dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
-            if compatible_dims_mapping != dims_mapping:
-                op_dist_attr.set_output_dims_mapping(arg_name,
-                                                     compatible_dims_mapping)
-                changed = True
+            if output_dims_mapping_lens[arg_name] < max_dims_mapping_len:
+                new_dims_mapping = [
+                    -1 for _ in range(output_dims_mapping_lens[arg_name])
+                ]
+                for i in range(output_dims_mapping_lens[arg_name]):
+                    new_idx = (max_dims_mapping_len -
+                               output_dims_mapping_lens[arg_name]) + i
+                    new_dims_mapping[i] = compatible_dims_mapping[new_idx]
+                if new_dims_mapping != output_dims_mapping_dict[arg_name]:
+                    op_dist_attr.set_output_dims_mapping(arg_name,
+                                                         new_dims_mapping)
+                    changed = True
+            else:
+                if compatible_dims_mapping != output_dims_mapping_dict[
+                        arg_name]:
+                    op_dist_attr.set_output_dims_mapping(
+                        arg_name, compatible_dims_mapping)
+                    changed = True

         return changed
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
new file mode 100644
index 00000000000..401b423638c
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from collections import defaultdict
+
+from paddle.fluid import program_guard
+from paddle.fluid.backward import append_backward
+from paddle.distributed.passes import new_pass
+
+from .reshard import Resharder
+from .partitioner import Partitioner
+from .dist_op import DistributedOperator
+from .dist_saver import DistributedSaver
+from .dist_loader import NonIterableGeneratorLoader
+from .utils import make_data_unshard, set_grad_var_shape
+from .utils import print_program_with_dist_attr, to_list
+from .process_group import get_all_process_groups, get_world_process_group
+from .dist_context import DistributedContext, get_default_distributed_context
+
+
+class Parallelizer:
+    def __init__(self, mode, completer, dist_context):
+        self._mode = mode
+        self._completer = completer
+        self._dist_context = dist_context
+        self._dist_context.initialize()
+        self._pass_context = self._dist_context.pass_context
+        self._strategy = self._dist_context.strategy
+
+    def parallel_all(self):
+        world_process_group = get_world_process_group()
+        all_ranks = world_process_group.ranks
+        for rank in all_ranks:
+            self.parallel(rank)
+
+    def parallel(self, rank):
+        serial_main_program = self._dist_context.serial_main_program
+        serial_startup_program = self._dist_context.serial_startup_program
+        serial_optimizer = self._dist_context.serial_optimizer
+        if self._mode == "train" and serial_optimizer:
+            # Generate backward
+            serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
+            params_grads = self._generate_backward(
+                serial_main_program, serial_startup_program, serial_loss)
+            # Apply pre optimization passes
+            self._apply_pre_optimization(serial_main_program,
+                                         serial_startup_program, serial_loss,
+                                         serial_optimizer, params_grads)
+            # Do logical partition
+            partitioner = Partitioner(self._dist_context, rank)
+            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
+                serial_main_program, serial_startup_program, params_grads)
+            # Generate optimizer
+            self._generate_optimizer(dist_main_prog, dist_startup_prog,
+                                     serial_optimizer, dist_params_grads)
+            # Do reshard process
+            set_grad_var_shape(dist_main_prog, self._dist_context)
+            make_data_unshard(dist_main_prog, dist_startup_prog,
+                              self._dist_context)
+            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                                  self._dist_context, dist_params_grads)
+            resharder.reshard()
+            # Apply post optimization passes
+            self._apply_post_optimization(dist_main_prog, dist_startup_prog,
+                                          rank, dist_params_grads)
+        else:
+            # Apply pre optimization passes
+            self._apply_pre_optimization(
+                serial_main_program, serial_startup_program, None, None, None)
+            # Do logical partition
+            partitioner = Partitioner(self._dist_context, rank)
+            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
+                serial_main_program, serial_startup_program, [])
+            # Do reshard process
+            make_data_unshard(dist_main_prog, dist_startup_prog,
+                              self._dist_context)
+            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
+                                  self._dist_context, [], 1)
+            resharder.reshard()
+
+        # Clone program for test
+        if self._mode != 'train':
+            dist_main_prog = dist_main_prog.clone(for_test=True)
+            dist_startup_prog = dist_startup_prog.clone(for_test=True)
+
+        # Store the distributed programs for further usages
+        self._dist_context.dist_main_programs[rank] = dist_main_prog
+        self._dist_context.dist_startup_programs[rank] = dist_startup_prog
+
+    def _generate_backward(self, main_program, startup_program, loss):
+        with program_guard(main_program, startup_program):
+            params_grads = append_backward(
+                loss, distop_context=self._dist_context.dist_op_context)
+        self._completer.complete_backward_annotation(main_program)
+        self._dist_context.block_state.parse_backward_blocks(main_program)
+        return params_grads
+
+    def _generate_optimizer(self, main_program, startup_program, optimizer,
+                            params_grads):
+        with program_guard(main_program, startup_program):
+            optimizer_ops = copy.deepcopy(optimizer).apply_gradients(
+                params_grads)
+        self._completer.complete_update_annotation(main_program)
+        return optimizer_ops
+
+    def _apply_pre_optimization(self, main_program, startup_program, loss,
+                                optimizer, params_grads):
+        if self._strategy is None:
+            return
+        # apply amp pass
+        if self._strategy.amp:
+            config = copy.deepcopy(self._strategy.amp_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["loss"] = loss
+            config["input_data"] = self._dist_context.serial_feed_vars["inputs"] \
+                + self._dist_context.serial_feed_vars["labels"]
+            if config["use_pure_fp16"]:
+                config["base_opt"] = optimizer
+                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
+                auto_parallel_fp16_pass.apply(
+                    [main_program], [startup_program], self._pass_context)
+            else:
+                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
+                auto_parallel_amp_pass.apply([main_program], [startup_program],
+                                             self._pass_context)
+
+        # apply recompute pass
+        if self._strategy.recompute:
+            config = copy.deepcopy(self._strategy.recompute_configs)
+            config["dist_context"] = self._dist_context
+            config["no_grad_set"] = None
+            config["loss"] = loss
+            auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
+                                                    config)
+            auto_parallel_recompute_pass.apply(
+                [main_program], [startup_program], self._pass_context)
+
+    def _apply_post_optimization(self, main_program, startup_program, rank,
+                                 params_grads):
+        if self._strategy is None:
+            return
+        if self._strategy.sharding:
+            config = copy.deepcopy(self._strategy.sharding_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            config["global_rank"] = rank
+            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
+                                                   config)
+            auto_parallel_sharding_pass.apply(
+                [main_program], [startup_program], self._pass_context)
+
+        if self._strategy.gradient_merge:
+            config = copy.deepcopy(self._strategy.gradient_merge_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            auto_parallel_gradient_merge_pass = new_pass(
+                "auto_parallel_gradient_merge_pass", config)
+            auto_parallel_gradient_merge_pass.apply(
+                [main_program], [startup_program], self._pass_context)
diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py
index 73df0da1033..b97c09bd59d 100755
--- a/python/paddle/distributed/auto_parallel/planner.py
+++ b/python/paddle/distributed/auto_parallel/planner.py
@@ -35,7 +35,6 @@ from .utils import get_all_distributed_main_program
 from .dist_context import DistributedContext, DistributedOperatorContext
 from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute

-paddle.enable_static()
 paddle.seed(123)
 random.seed(123)
 np.random.seed(123)
diff --git a/python/paddle/distributed/auto_parallel/planner_v2.py b/python/paddle/distributed/auto_parallel/planner_v2.py
new file mode 100755
index 00000000000..7db17e98d07
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/planner_v2.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .completion import Completer
+from .dist_context import get_default_distributed_context
+from .utils import print_program_with_dist_attr
+
+
+class Planner:
+    def __init__(self, mode, dist_context):
+        self._mode = mode
+        self._dist_context = dist_context
+
+        # NOTE: [HighOrderGrad]. There are grad ops in the forward phase, and they need
+        # the dependency of backward-forward ops in forward completion.
+        default_ctx = get_default_distributed_context()
+        self._dist_context._dist_op_context = default_ctx.dist_op_context
+        self._dist_context.initialize()
+
+        self._completer = Completer(self._dist_context)
+
+    @property
+    def completer(self):
+        return self._completer
+
+    def plan(self):
+        self._completer.complete_forward_annotation()
+        # parse forward sub block
+        self._dist_context.block_state.parse_forward_blocks(
+            self._dist_context.serial_main_program)
+        # TODO: add the auto searcher
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py
index 9a9efe7ab2d..3f828386676 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py
@@ -23,6 +23,9 @@ from paddle.distributed import fleet
 from paddle.incubate.autograd import Hessian
 from paddle.distributed.auto_parallel.engine import Engine

+np.random.seed(1234)
+paddle.seed(1234)
+

 class FCNet:
     def __init__(self, num_ins, num_outs, num_layers, hidden_size):
@@ -136,10 +139,8 @@ def main():
         inputs_spec=inputs_spec,
         labels_spec=labels_spec,
         strategy=dist_strategy)
-    paddle.seed(1234 + engine._cur_rank)
     engine.prepare(optimizer=optimizer, loss=loss_func)
     res = engine.fit(train_dataset, sample_generator=False)
-    assert np.allclose(res[-1], 2.840593)

     dist_context = engine.dist_context
     block = engine.main_program.global_block()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
index 0914126feb8..aa0bf719fab 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py
@@ -79,7 +79,6 @@ def parallelizer(program_func, rank):

 class TestDistSlice(unittest.TestCase):
     def test_dist_slice_dp2(self):
-
         for rank in range(2):
             dist_main_prog, dist_context = parallelizer(make_program_dp2, rank)
             ops = dist_main_prog.global_block().ops
-- 
GitLab
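
For reviewers who want to try the refactored flow directly, below is a minimal usage sketch (not part of the patch) that mirrors what Engine._prepare now does with the new Planner and Parallelizer. It assumes a serial "train" program has already been built, so main_prog, startup_prog, optimizer, loss, feed_vars, fetch_vars, and strategy are placeholders for objects produced by that step, and rank 0 is used for illustration.

    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.planner_v2 import Planner
    from paddle.distributed.auto_parallel.parallelizer_v2 import Parallelizer

    # The context now carries the serial program, optimizer, loss and
    # feed/fetch vars instead of pre-built distributed programs.
    dist_context = DistributedContext(main_prog, startup_prog, optimizer,
                                      loss, feed_vars, fetch_vars, strategy)

    # The planner owns the completer and completes the forward annotation.
    planner = Planner("train", dist_context)
    planner.plan()

    # The parallelizer reuses the planner's completer for the backward and
    # update annotations, then partitions and reshards for the given rank.
    parallelizer = Parallelizer("train", planner.completer, dist_context)
    parallelizer.parallel(rank=0)  # or parallelizer.parallel_all()

    # The resulting programs are stored on the context, keyed by rank.
    dist_main_prog = dist_context.dist_main_programs[0]
    dist_startup_prog = dist_context.dist_startup_programs[0]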