From 010aba33ee5655555ce1e9bf92e9596828d446ae Mon Sep 17 00:00:00 2001 From: Yulong Ao Date: Wed, 1 Jun 2022 10:18:26 +0800 Subject: [PATCH] [Auto Parallel] Add miscellaneous improvements (#43108) * [Auto Parallel] Add the parallel tuner * [Auto Parallel] Improve the parallel tuner and fix some bugs * upodate cost model * update import Resharder by dist op * update cost model * fix comp cost bug * update cost model * [Auto Parallel] Amend the dist attr for #processses=1 * update cost model and tuner * update cost model and tuner * update cost model and tuner * update cluster * update reshard * [Auto Parallel] Add the estimation from the cost model * [Auto Parallel] Reimplement the backup and restore functions * [Auto Parallel] Fix the bugs of the parallel tuner * [Auto Parallel] Update the engine api and dist context * [Auto Parallel] Work around the high order grad problem * [Auto Parallel] Add some miscellaneous improvements * [Auto Parallel] Add a unittest for DistributedContext Co-authored-by: caozhou --- .../distributed/auto_parallel/completion.py | 77 +++-- .../auto_parallel/dist_attribute.py | 49 +-- .../distributed/auto_parallel/dist_context.py | 303 +++++++++++++----- .../distributed/auto_parallel/dist_tensor.py | 7 +- .../distributed/auto_parallel/engine.py | 34 +- .../auto_parallel/operators/__init__.py | 2 +- .../auto_parallel/operators/common.py | 4 +- .../auto_parallel/operators/dist_default.py | 4 +- .../auto_parallel/operators/dist_pnorm.py | 3 +- .../auto_parallel/parallelizer_v2.py | 6 +- .../distributed/auto_parallel/planner_v2.py | 13 +- .../paddle/distributed/auto_parallel/utils.py | 11 +- .../unittests/auto_parallel/CMakeLists.txt | 1 + .../auto_parallel/test_dist_context.py | 204 ++++++++++++ .../auto_parallel/test_dist_slice.py | 3 +- .../auto_parallel/test_while_op_completion.py | 2 +- .../auto_parallel/test_while_op_partition.py | 2 +- 17 files changed, 574 insertions(+), 151 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 03996ec350..465c450c0b 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -20,7 +20,7 @@ from paddle.fluid import core from paddle.fluid import framework from .utils import print_program_with_dist_attr -from .operators import find_best_compatible_distributed_operator_impl +from .operators import find_compatible_distributed_operator_impls from .dist_context import get_default_distributed_context, _node_id from .dist_tensor import DistributedTensor from .dist_op import DistributedOperator @@ -238,13 +238,17 @@ class Completer: tensor_desc.name()) compatible_dims_mapping = compute_compatible_dims_mapping( [op_dims_mapping, tensor_dims_mapping]) + if not _validate_dims_mapping( + compatible_dims_mapping, + op_dist_attr.process_mesh): + continue if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != op_dims_mapping): op_dist_attr.set_input_dims_mapping( tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impls = find_best_compatible_distributed_operator_impl( + op_dist_impls = find_compatible_distributed_operator_impls( dist_op, fwd=True) if op_dist_impls is not None: not_compatible = True @@ -254,7 +258,8 @@ class Completer: dim_changed = op_dist_impl.update_dims_mapping(dist_op) 
if dim_changed: changed = True - if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.is_auto_compatible(dist_op) \ + and dist_op.validate_dist_attr(): if op_dist_impl.type == "elementwise": op_dist_attr.impl_type = "default" else: @@ -289,13 +294,17 @@ class Completer: tensor_desc.name()) compatible_dims_mapping = compute_compatible_dims_mapping( [op_dims_mapping, tensor_dims_mapping]) + if not _validate_dims_mapping( + compatible_dims_mapping, + op_dist_attr.process_mesh): + continue if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != op_dims_mapping): op_dist_attr.set_output_dims_mapping( tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impls = find_best_compatible_distributed_operator_impl( + op_dist_impls = find_compatible_distributed_operator_impls( dist_op, fwd=False) if op_dist_impls is not None: not_compatible = True @@ -305,8 +314,8 @@ class Completer: dim_changed = op_dist_impl.update_dims_mapping(dist_op) if dim_changed: changed = True - if op_dist_impl.is_auto_compatible(dist_op): - not_compatible = False + if op_dist_impl.is_auto_compatible(dist_op) \ + and dist_op.validate_dist_attr(): if op_dist_impl.type == "elementwise": op_dist_attr.impl_type = "default" else: @@ -352,6 +361,23 @@ class Completer: changed = True return changed + def _update_dims_mapping_for_special(self): + # Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it + op_nodes = self._dist_context._serial_ordered_op_nodes + for op_node in op_nodes: + op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) + for tensor_node in op_node.outputs: + if tensor_node.is_var() and tensor_node.var() is not None: + if tensor_node.var().type() == core.VarDesc.VarType.READER: + continue + tensor_desc = tensor_node.var() + tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node) + if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + tensor_dist_attr.dims_mapping = op_dims_mapping + def _update_dims_mapping(self): # Complete dims_mapping for each node reach_fix_point = False @@ -378,6 +404,7 @@ class Completer: reach_fix_point = False else: reach_fix_point = True + self._update_dims_mapping_for_special() def _update_process_mesh_by_nearest(self, op_node, nearest_op_node): op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) @@ -685,7 +712,7 @@ class Completer: # Step 3: adjust the process meshes for special ops self._update_process_mesh_for_specials() - # Step 4: adjust the process meshes between graphs + # Step 4: adjust the process meshes between graphs self._update_process_mesh_between_graphs() def _prepare(self): @@ -727,14 +754,14 @@ class Completer: """ Complete annotation for the partial annotated serial_main_program. Arguments: serial_main_program: partial annotated serial_main_program. - Returns: + Returns:e serial_main_program: completed annotated serial_main_program. 
""" if serial_main_program is None: serial_main_program = self._dist_context.serial_main_program else: - self._dist_context.serial_main_program = serial_main_program + self._dist_context._serial_main_program = serial_main_program self._dist_context.initialize() @@ -757,13 +784,18 @@ class Completer: return serial_main_program - def _complete_high_order_grad_annotation(self, serial_main_program): + def _complete_high_order_grad_annotation(self, serial_main_program=None): """ NOTE: [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. This function is temporary to support high order gradient, and will be removed in the future. """ + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context._serial_main_program = serial_main_program + def _is_grad_var_name(name): if "@GRAD" in name: return True @@ -917,12 +949,13 @@ class Completer: self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) - def complete_backward_annotation(self, serial_main_program): + def complete_backward_annotation(self, serial_main_program=None): """Complete the annotation of vars and ops in the backward phase for parallel program.""" + if serial_main_program is None: serial_main_program = self._dist_context.serial_main_program else: - self._dist_context.serial_main_program = serial_main_program + self._dist_context._serial_main_program = serial_main_program def _is_grad_var_name(name): if "@GRAD" in name: @@ -1032,6 +1065,9 @@ class Completer: grad_op_dist_attr.process_mesh = ref_mesh self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) + grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type + grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx + continue fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( @@ -1078,6 +1114,8 @@ class Completer: grad_op_dist_attr.set_output_dims_mapping(output_name, ref_dims_mapping) + grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type + grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) @@ -1111,6 +1149,8 @@ class Completer: var_name, ref_fwd_dims_mapping) grad_op_dist_attr.set_output_dims_mapping( output_name, ref_fwd_dims_mapping) + grad_op_dist_attr.impl_type = "default" + grad_op_dist_attr.impl_idx = 0 elif grad_op.type == 'fill_zeros_like': ref_var_name = grad_op.input_arg_names[0] @@ -1142,12 +1182,13 @@ class Completer: self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) - def complete_update_annotation(self, serial_main_program=None): + def complete_update_annotation(self, serial_main_program): """Complete the annotation of vars and ops in the update phase for parallel program.""" - if serial_main_program is None: - serial_main_program = self._dist_context.serial_main_program - else: - self._dist_context.serial_main_program = serial_main_program + + # Notice: serial_main_program is actually a dist_main_program of current rank, + # and must be passed into this function. + # TODO: We should fix this behavior. 
+ ops = list(serial_main_program.global_block().ops) vars = serial_main_program.global_block().vars learning_rate_completed = False @@ -1304,7 +1345,7 @@ class Completer: dist_op.dist_attr.process_mesh = world_ranks # Find the most compatible implemenetations from the distributed operator - op_dist_impls = find_best_compatible_distributed_operator_impl( + op_dist_impls = find_compatible_distributed_operator_impls( dist_op, fwd=True) if op_dist_impls is not None: backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr) diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index 6fa5b756c7..3dbdb79f48 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -132,15 +132,17 @@ class TensorDistributedAttribute: key, dist_attr) self._is_annotated = copy.deepcopy(dist_attr._is_annotated) - # def reset(self, skip_dist_attr_field_names): - # if skip_dist_attr_field_names is not None \ - # and "process_mesh" not in skip_dist_attr_field_names: - # self._process_mesh = None - # if skip_dist_attr_field_names is not None \ - # and "dims_mapping" not in skip_dist_attr_field_names: - # for i in enumerate(self._dims_mapping): - # self._dims_mapping[i] = -1 - # self._is_annotated = {} + def reset(self, skip_dist_attr_field_names=None): + if skip_dist_attr_field_names is None or \ + (skip_dist_attr_field_names is not None \ + and "process_mesh" not in skip_dist_attr_field_names): + self._process_mesh = None + if skip_dist_attr_field_names is None or \ + (skip_dist_attr_field_names is not None \ + and "dims_mapping" not in skip_dist_attr_field_names): + for i, _ in enumerate(self._dims_mapping): + self._dims_mapping[i] = -1 + self._is_annotated = {} def is_annotated(self, dist_attr_field_name): return self._is_annotated.get(dist_attr_field_name, False) @@ -272,6 +274,9 @@ class OperatorDistributedAttribute: dist_attr_object.init(dist_attr) self._inputs_dist_attrs[name] = dist_attr_object + # def del_input_dist_attr(self, name): + # del self._inputs_dist_attrs[name] + def get_output_dist_attr(self, name): return self._outputs_dist_attrs.get(name, None) @@ -280,6 +285,9 @@ class OperatorDistributedAttribute: dist_attr_object.init(dist_attr) self._outputs_dist_attrs[name] = dist_attr_object + # def del_output_dist_attr(self, name): + # del self._inputs_dist_attrs[name] + def get_input_dims_mapping(self, name): input_dist_attr = self.get_input_dist_attr(name) if input_dist_attr: @@ -374,17 +382,18 @@ class OperatorDistributedAttribute: "ProcessMeshes in DistributedOperator must be the same." 
self.process_mesh = shared_process_mesh - # def reset(self, skip_dist_attr_field_names): - # for tensor_dist_attr in self.inputs_dist_attrs.values(): - # tensor_dist_attr.reset(skip_dist_attr_field_names) - # for tensor_dist_attr in self.outputs_dist_attrs.values(): - # tensor_dist_attr.reset(skip_dist_attr_field_names) - # if skip_dist_attr_field_names is not None \ - # and "process_mesh" not in skip_dist_attr_field_names: - # self.process_mesh = None - # self.impl_type = "default" - # self.impl_idx = 0 - # self._is_annotated = {} + def reset(self, skip_dist_attr_field_names=None): + for tensor_dist_attr in self.inputs_dist_attrs.values(): + tensor_dist_attr.reset(skip_dist_attr_field_names) + for tensor_dist_attr in self.outputs_dist_attrs.values(): + tensor_dist_attr.reset(skip_dist_attr_field_names) + if skip_dist_attr_field_names is None or \ + (skip_dist_attr_field_names is not None \ + and "process_mesh" not in skip_dist_attr_field_names): + self._process_mesh = None + self.impl_type = "default" + self.impl_idx = 0 + self._is_annotated = {} def is_annotated(self, attr_name): return self._is_annotated.get(attr_name, False) diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index a47ef66ee8..6a38b53cf2 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -57,33 +57,30 @@ class DistributedContext: serial_startup_prog=None, serial_optimizer=None, serial_loss=None, - feed_vars=None, - fetch_vars=None, + feed_vars={}, + fetch_vars={}, + cluster=None, strategy=None): # Data members related to original programs (unchanged) self._original_serial_main_program = serial_main_prog self._original_serial_startup_program = serial_startup_prog + self._original_serial_optimizer = serial_optimizer self._original_serial_loss = serial_loss + self._original_serial_feed_vars = feed_vars + self._original_serial_fetch_vars = fetch_vars self._original_serial_optimizer = serial_optimizer - if self._original_serial_main_program is None: - self._original_serial_main_program = paddle.fluid.default_main_program( - ) - if self._original_serial_startup_program is None: - self._original_serial_startup_program = paddle.fluid.default_startup_program( - ) # Data members related to programs (changed) self._serial_main_program = None self._serial_startup_program = None - self._serial_loss = serial_loss - self._serial_optimizer = serial_optimizer - self._serial_feed_vars = feed_vars - self._serial_fetch_vars = fetch_vars + self._serial_loss = None + self._serial_optimizer = None + self._serial_feed_vars = {} + self._serial_fetch_vars = {} # Data members related to the program self._dist_tensors_for_program = {} self._dist_ops_for_program = {} - self._block_state = BlockState() # Data members related to the graph self._serial_graph = None @@ -96,24 +93,30 @@ class DistributedContext: # Distributed programs self._dist_main_programs = {} self._dist_startup_programs = {} + self._dist_op_context = DistributedOperatorContext() + self._process_meshes = [] - # Distributed Strategy + self._cluster = cluster self._strategy = strategy # Pass Context self._pass_context = PassContext() - - # Distributed Operator Context - self._dist_op_context = DistributedOperatorContext() + self._block_state = BlockState() # Other data members - self._process_meshes = [] self._serial_ordered_tensor_nodes = [] self._serial_ordered_op_nodes = [] self._serial_ordered_nodes = [] # 
self._tensor_id_to_tensor_node_ids = {} self._is_initialized = False + self._need_copy_dist_attr_to_graph = False + self._backup_pass_context_stack = [] + self._backup_block_state_stack = [] + self._backup_dist_tensors_for_program_stack = [] + self._backup_dist_ops_for_program_stack = [] + self._backup_serial_main_program_stack = [] + self._backup_serial_startup_program_stack = [] # flag whether scale gradient with dp size self._gradient_scale = True @@ -122,13 +125,6 @@ class DistributedContext: def serial_main_program(self): return self._serial_main_program - @serial_main_program.setter - def serial_main_program(self, program): - # if self._serial_main_program: - # print("WARNING: The program attached to this distributed context will be replaced by the new one.") - self._original_serial_main_program = program - self._serial_main_program = program - @property def serial_startup_program(self): return self._serial_startup_program @@ -149,6 +145,18 @@ class DistributedContext: def serial_fetch_vars(self): return self._serial_fetch_vars + @property + def dist_main_programs(self): + return self._dist_main_programs + + @property + def dist_startup_programs(self): + return self._dist_startup_programs + + @property + def cluster(self): + return self._cluster + @property def strategy(self): return self._strategy @@ -177,14 +185,6 @@ class DistributedContext: def block_state(self): return self._block_state - @property - def dist_main_programs(self): - return self._dist_main_programs - - @property - def dist_startup_programs(self): - return self._dist_startup_programs - @property def has_annotation(self): return len(self._dist_tensors_for_program) or len( @@ -198,21 +198,168 @@ class DistributedContext: def gradient_scale(self, gs): self._gradient_scale = gs - def initialize(self): - if not self._is_initialized: + def _backup_serial_info(self, mode): + self._backup_serial_main_program_stack.append( + self._serial_main_program.clone()) + self._backup_serial_startup_program_stack.append( + self._serial_startup_program.clone()) + self._backup_pass_context_stack.append( + copy.deepcopy(self._pass_context)) + self._backup_block_state_stack.append(copy.deepcopy(self._block_state)) + + def _backup_dist_info(self, mode): + self._backup_dist_tensors_for_program_stack.append( + copy.deepcopy(self._dist_tensors_for_program)) + self._backup_dist_ops_for_program_stack.append( + copy.deepcopy(self._dist_ops_for_program)) + + def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None): + # Use this function carefully + if serial: + self._backup_serial_info(serial_mode) + if dist: + self._backup_dist_info(dist_mode) + + def _restore_serial_info(self, mode="to_backup"): + if mode == "to_backup": + self._serial_main_program = self._backup_serial_main_program_stack.pop( + ) + self._serial_startup_program = self._backup_serial_startup_program_stack.pop( + ) + elif mode == "to_original": + assert self._original_serial_main_program is not None + assert self._original_serial_startup_program is not None self._serial_main_program = self._original_serial_main_program.clone( ) self._serial_startup_program = self._original_serial_startup_program.clone( ) - # self._serial_main_program = self._original_serial_main_program - # self._serial_startup_program = self._original_serial_startup_program - if self._original_serial_loss: - self._serial_loss = self._serial_main_program.global_block( - ).vars[self._original_serial_loss[0].name] + + self._serial_optimizer = self._original_serial_optimizer + + if 
self._original_serial_loss: + if isinstance(self._original_serial_loss, list): + assert len(self._original_serial_loss) == 1 + loss = self._original_serial_loss[0] + block_idx = loss.block.idx + var_name = loss.name + var = self._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + self._serial_loss = var else: - self._serial_loss = self._original_serial_loss - self._serial_optimizer = self._original_serial_optimizer + block_idx = self._original_serial_loss.block.idx + var_name = self._original_serial_loss.name + var = self._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + self._serial_loss = var + + for key, var_list in self._original_serial_feed_vars.items(): + new_var_list = [] + for var in var_list: + block_idx = var.block.idx + var_name = var.name + var = self._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + new_var_list.append(var) + self._serial_feed_vars[key] = new_var_list + + for key, var_list in self._original_serial_fetch_vars.items(): + new_var_list = [] + for var in var_list: + block_idx = var.block.idx + var_name = var.name + var = self._serial_main_program.blocks[ + block_idx]._var_recursive(var_name) + new_var_list.append(var) + self._serial_fetch_vars[key] = new_var_list + + self._pass_context = self._backup_pass_context_stack.pop() + self._block_state = self._backup_block_state_stack.pop() + + def _restore_dist_info(self, mode="to_backup"): + if mode == "to_backup": + self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop( + ) + self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop( + ) + elif mode == "to_original": + assert self._original_dist_tensors_for_program + assert self._original_dist_ops_for_program + self._dist_tensors_for_program = copy.deepcopy( + self._original_dist_tensors_for_program) + self._dist_ops_for_program = copy.deepcopy( + self._original_dist_ops_for_program) + elif mode == "to_default": + new_tensors_ids = [] + for tensor_id, dist_tensor in self._dist_tensors_for_program.items( + ): + if tensor_id in self._tensors_ids: + dist_tensor.dist_attr.reset() + else: + new_tensors_ids.append(tensor_id) + for tensor_id in new_tensors_ids: + self._dist_tensors_for_program.pop(tensor_id) + new_ops_ids = [] + for op_id, dist_op in self._dist_ops_for_program.items(): + if op_id in self._ops_ids: + dist_op.dist_attr.reset() + else: + new_ops_ids.append(op_id) + for op_id in new_ops_ids: + self._dist_ops_for_program.pop(op_id) + else: + new_tensors_ids = [] + for tensor_id, dist_tensor in self._dist_tensors_for_program.items( + ): + new_tensors_ids.append(tensor_id) + for tensor_id in new_tensors_ids: + self._dist_tensors_for_program.pop(tensor_id) + new_ops_ids = [] + for op_id, dist_op in self._dist_ops_for_program.items(): + new_ops_ids.append(op_id) + for op_id in new_ops_ids: + self._dist_ops_for_program.pop(op_id) + self._dist_main_programs = {} + self._dist_startup_programs = {} + self._dist_op_context = DistributedOperatorContext() + self._need_copy_dist_attr_to_graph = True + self._process_meshes = [] + + def _restore(self, + serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_backup"): + # Use this function carefully + if serial: + self._restore_serial_info(serial_mode) + if dist: + self._restore_dist_info(dist_mode) + + def initialize(self): + if not self._is_initialized: + if not self._serial_main_program: + self._serial_main_program = self._original_serial_main_program + if not self._serial_startup_program: + 
self._serial_startup_program = self._original_serial_startup_program + if not self._serial_loss: + if isinstance(self._original_serial_loss, list): + assert len(self._original_serial_loss) == 1 + self._serial_loss = self._original_serial_loss[0] + else: + self._serial_loss = self._original_serial_loss + if not self._serial_optimizer: + self._serial_optimizer = self._original_serial_optimizer + if not self._serial_feed_vars: + self._serial_feed_vars = self._original_serial_feed_vars + if not self._serial_fetch_vars: + self._serial_fetch_vars = self._original_serial_fetch_vars + self._init_dist_attr_for_program() + # Backup the original distributed information for later restore + self._original_dist_tensors_for_program = copy.deepcopy( + self._dist_tensors_for_program) + self._original_dist_ops_for_program = copy.deepcopy( + self._dist_ops_for_program) self._tensors_ids = list(self._dist_tensors_for_program.keys()) self._ops_ids = list(self._dist_ops_for_program.keys()) set_flags({"FLAGS_convert_all_blocks": True}) @@ -220,41 +367,9 @@ class DistributedContext: core.Graph(self._serial_main_program.desc)) self._init_dist_attr_for_graph() self._is_initialized = True - - # def reset(self, - # skip_dist_tensors=None, - # skip_dist_ops=None, - # skip_tensor_dist_attr_fields=None, - # skip_op_dist_attr_fields=None): - # self._serial_main_program = self._original_serial_main_program.clone() - # self._serial_startup_program = self._original_serial_startup_program.clone() - # new_tensors_ids = [] - # for tensor_id, dist_tensor in self._dist_tensors_for_program.items(): - # if tensor_id in self._tensors_ids: - # dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields) - # else: - # new_tensors_ids.append(tensor_id) - # for tensor_id in new_tensors_ids: - # self._dist_tensors_for_program.pop(tensor_id) - # new_ops_ids = [] - # for op_id, dist_op in self._dist_ops_for_program.items(): - # if op_id in self._ops_ids: - # dist_op.dist_attr.reset(skip_op_dist_attr_fields) - # else: - # new_ops_ids.append(op_id) - # for op_id in new_ops_ids: - # self._dist_ops_for_program.pop(op_id) - - # self.copy_dist_attr_from_program_to_graph() - - # self._dist_main_programs = {} - # self._dist_startup_programs = {} - - # self._pass_context = PassContext() - - # self._dist_op_context = DistributedOperatorContext() - - # self._process_meshes = [] + self._need_copy_dist_attr_to_graph = False + if self._need_copy_dist_attr_to_graph: + self.copy_dist_attr_from_program_to_graph() def add_process_mesh(self, process_mesh): assert isinstance(process_mesh, ProcessMesh), \ @@ -423,6 +538,10 @@ class DistributedContext: if current_dist_op is None: dist_op = DistributedOperator(op) self.add_dist_op_for_program(dist_op) + self._original_dist_tensors_for_program = copy.deepcopy( + self._dist_tensors_for_program) + self._original_dist_ops_for_program = copy.deepcopy( + self._dist_ops_for_program) def _order_nodes_by_program_order(self): def _contains(nodes, target_node): @@ -592,7 +711,7 @@ class DistributedContext: op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) dist_op_for_program = self._dist_ops_for_program[op_id] dist_op_for_program.dist_attr = op_dist_attr_for_graph - # TODO: the completion algorithm will skip orphan tensors, + # TODO: the completion algorithm will skipped orphan tensors, # here we just set there process_mesh to the first one. 
for orphan_node in self._serial_orphan_tensor_nodes: serial_tensor_id = orphan_node.var().id() @@ -618,16 +737,21 @@ class DistributedContext: tensor_shape = serial_tensor.shape dims_mapping = dist_attr.dims_mapping process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_processes = dist_attr.process_mesh.processes # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 + if dims_mapping[i] != -1 and len(process_mesh_processes) == 1: + dims_mapping[i] = -1 for dist_op in self._dist_ops_for_program.values(): serial_op = dist_op.serial_op dist_attr = dist_op.dist_attr + process_mesh_shape = dist_attr.process_mesh.topology + process_mesh_processes = dist_attr.process_mesh.processes for arg_name in serial_op.input_arg_names: if dist_op.get_serial_input(arg_name) is None: tensor_shape = [] @@ -639,13 +763,15 @@ class DistributedContext: else: tensor_shape = dist_op.get_serial_input(arg_name).shape dims_mapping = dist_attr.get_input_dims_mapping(arg_name) - process_mesh_shape = dist_attr.process_mesh.topology # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 + if dims_mapping[i] != -1 and len( + process_mesh_processes) == 1: + dims_mapping[i] = -1 for arg_name in serial_op.output_arg_names: if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \ or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ @@ -654,13 +780,18 @@ class DistributedContext: else: tensor_shape = dist_op.get_serial_output(arg_name).shape dims_mapping = dist_attr.get_output_dims_mapping(arg_name) - process_mesh_shape = dist_attr.process_mesh.topology # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) 
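# Likewise, when the process mesh contains only a single process there is nothing
# to shard, so the checks added below reset such dims_mapping entries to -1 and
# make the op fall back to the default implementation (impl_type = "default").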
for i in range(len(tensor_shape)): if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: dims_mapping[i] = -1 + if dims_mapping[i] != -1 and len( + process_mesh_processes) == 1: + dims_mapping[i] = -1 + if len(process_mesh_processes) == 1: + dist_op.dist_attr.impl_type = "default" + dist_op.dist_attr.impl_idx = 0 def validate_dist_attr_for_program(self): if not self._is_initialized: @@ -674,16 +805,20 @@ class DistributedContext: dist_tensor.serial_tensor.name) if (dist_tensor is not None) and ( not dist_tensor.validate_dist_attr()): - assert False, "Tensor {} has a wrong distributed attributes {}.".format( - dist_tensor.serial_tensor.name, dist_tensor.dist_attr) + assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format( + dist_tensor.serial_tensor.name, + dist_tensor.desc.id(), + dist_tensor.desc.original_id(), dist_tensor.dist_attr) for op in block.ops: dist_op = self.get_dist_op_for_program(op) assert dist_op is not None, \ "Operator {} does not have a distributed attribute.".format( dist_op.serial_op.type) if (dist_op is not None) and (not dist_op.validate_dist_attr()): - assert False, "Operator {} has a wrong distributed attributes {}.".format( - dist_op.serial_op.type, dist_op.dist_attr) + assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format( + dist_op.serial_op.type, + dist_op.serial_op.desc.id(), + dist_op.serial_op.desc.original_id(), dist_op.dist_attr) return True def __deepcopy__(self, memo): diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index a42ce86349..e3f06da275 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -41,7 +41,7 @@ class DistributedTensor: rank=None, shard_sizes=None): if not (isinstance(sizes, (list, tuple)) and - all(map(lambda x: isinstance(x, int) and x > 0, sizes))): + all(map(lambda x: isinstance(x, int) and x >= 0, sizes))): raise ValueError( "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". 
format(sizes)) @@ -79,8 +79,11 @@ class DistributedTensor: local_sizes = [] # for even sharding, the local sizes of every rank are equal + for idx, item in enumerate(global_sizes): - if dims_mapping[idx] == -1: + # This is a trick to avoid dims_mapping is [] + val = dims_mapping[idx] if idx < len(dims_mapping) else -1 + if val == -1: local_sizes.append(item) else: local_sizes.append(item // topology[dims_mapping[idx]]) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index c38953ca9e..ab9391cf66 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -31,10 +31,11 @@ from paddle.fluid.backward import append_backward from paddle.fluid.framework import Operator from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed import fleet from paddle.distributed.utils import get_logger from paddle.distributed.passes import new_pass, PassContext -from .cluster import Cluster +# from .cluster import Cluster, get_default_cluster from .planner_v2 import Planner from .parallelizer_v2 import Parallelizer from .dist_op import DistributedOperator @@ -57,7 +58,11 @@ class Engine: self.inputs_spec = self._validate_spec(inputs_spec) self.labels_spec = self._validate_spec(labels_spec) self.cluster = cluster + # if self.cluster is None: + # self.cluster = get_default_cluster() self.strategy = strategy + if self.strategy is None: + self.strategy = fleet.DistributedStrategy() self._executor = None self._cur_rank = paddle.distributed.get_rank() @@ -69,11 +74,11 @@ class Engine: self._orig_main_prog = fluid.default_main_program() self._orig_startup_prog = fluid.default_startup_program() self._orig_dist_context = get_default_distributed_context() + self._dist_contexts = {} self._serial_main_progs = {} self._serial_startup_progs = {} self._dist_main_progs = defaultdict(dict) # dist main programs self._dist_startup_progs = defaultdict(dict) # dist startup programs - self._dist_contexts = {} self._feed_vars = {} self._fetch_vars = {} @@ -104,11 +109,17 @@ class Engine: parallelizer.parallel(self._cur_rank) else: parallelizer.parallel_all() - # Get the distributed main programs and startup programs + # Get the current content from the distributed context + self._serial_main_progs[mode] = self._dist_contexts[ + mode].serial_main_program + self._serial_startup_progs[mode] = self._dist_contexts[ + mode].serial_startup_program self._dist_main_progs[mode] = self._dist_contexts[ mode].dist_main_programs self._dist_startup_progs[mode] = self._dist_contexts[ mode].dist_startup_programs + self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars + self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars # Init comm and startup program self._initialize(mode) @@ -135,20 +146,23 @@ class Engine: inputs = [self._set_data_parallel(var) for var in inputs] labels = [self._set_data_parallel(var) for var in labels] - self._feed_vars[mode] = {"inputs": inputs, "labels": labels} + # self._feed_vars[mode] = {"inputs": inputs, "labels": labels} + feed_vars = {"inputs": inputs, "labels": labels} - self._fetch_vars[mode] = { + # self._fetch_vars[mode] = { + # "outputs": flatten(outputs), + # "loss": losses, + # "metrics": metrics + # } + fetch_vars = { "outputs": flatten(outputs), "loss": losses, "metrics": metrics } - self._serial_main_progs[mode] = serial_main_prog - self._serial_startup_progs[mode] 
= serial_startup_prog self._dist_contexts[mode] = DistributedContext( - self._serial_main_progs[mode], self._serial_startup_progs[mode], - self._optimizer, losses, self._feed_vars[mode], - self._fetch_vars[mode], self.strategy) + serial_main_prog, serial_startup_prog, self._optimizer, losses, + feed_vars, fetch_vars, self.cluster, self.strategy) self._dist_contexts[mode].gradient_scale = self._gradient_scale def _initialize(self, mode): diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 3ff4746972..295e3557df 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -16,7 +16,7 @@ from .common import DistributedOperatorImplContainer from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl -from .common import find_best_compatible_distributed_operator_impl +from .common import find_compatible_distributed_operator_impls from . import dist_embedding from . import dist_matmul from . import dist_reshape diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 441eb88a9f..6b3c655f29 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -157,9 +157,7 @@ def register_distributed_operator_impl(op_type, dist_impl): assert False, "Must register distributed operator registry first." -def find_best_compatible_distributed_operator_impl(dist_op, - fwd=True, - partial=True): +def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. 
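For reference, the renamed helper keeps its old behaviour: it returns the implementations whose conditions match the given dist_op, and the completer takes the first one that also survives validation. The sketch below is not code from this patch; select_op_impl is a hypothetical wrapper, but every call it makes (find_compatible_distributed_operator_impls, update_dims_mapping, is_auto_compatible, validate_dist_attr, impl_type) appears in the completion.py hunks above.

import copy

from paddle.distributed.auto_parallel.operators import (
    find_compatible_distributed_operator_impls)


def select_op_impl(dist_op, fwd=True):
    # Return True when a compatible implementation was recorded on dist_op.
    op_dist_attr = dist_op.dist_attr
    backup_op_dist_attr = copy.deepcopy(op_dist_attr)
    op_dist_impls = find_compatible_distributed_operator_impls(dist_op, fwd=fwd)
    if op_dist_impls is None:
        return False
    for op_dist_impl in op_dist_impls:
        # Probing may rewrite the dims_mapping of the op's inputs/outputs.
        op_dist_impl.update_dims_mapping(dist_op)
        if op_dist_impl.is_auto_compatible(dist_op) \
                and dist_op.validate_dist_attr():
            # Elementwise implementations are recorded under the default type.
            if op_dist_impl.type == "elementwise":
                op_dist_attr.impl_type = "default"
            else:
                op_dist_attr.impl_type = op_dist_impl.type
            return True
    # Nothing matched: roll back the dims_mapping changes made while probing.
    dist_op.dist_attr = backup_op_dist_attr
    return False

Because only the first match is taken, replacing this selection with a cost-model ranking is the natural next step flagged in the docstring above.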
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 6d9b48ea1e..78f30422e7 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -187,7 +187,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) - if serial_tensor.is_parameter: + if serial_tensor is not None and serial_tensor.is_parameter: for mapping in dims_mapping: if mapping != -1: return False @@ -217,7 +217,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for arg_name in op_desc.output_arg_names(): serial_tensor = dist_op.get_serial_output(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) - if serial_tensor.is_parameter: + if serial_tensor is not None and serial_tensor.is_parameter: for mapping in dims_mapping: if mapping != -1: return False diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py index 89cd2c9d9e..4d52e5a94b 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py @@ -22,7 +22,6 @@ from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from .common import set_comm_op_dist_attr_for_program from .dist_default import DistributedDefaultImpl0 -from ..reshard import Resharder from ..process_group import new_process_group from ..utils import is_dim_shard, is_dim_replicate, _get_corresponding_rank from ..utils import compute_compatible_dim_mapping, set_dist_op_desc_original_id, _get_comm_group @@ -324,6 +323,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_group = op_dist_attr.process_mesh.processes dims_mapping = [0] + [-1 for _ in range(len(new_X_grad.shape) - 1)] + from ..reshard import Resharder + partition_idx = Resharder.compute_partition_index( rank_id, new_X_grad.shape, dims_mapping, process_mesh_shape, process_mesh_group) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 4d73632761..218513323d 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -35,7 +35,7 @@ class Parallelizer: self._mode = mode self._completer = completer self._dist_context = dist_context - self._dist_context.initialize() + assert self._dist_context._is_initialized self._pass_context = self._dist_context.pass_context self._strategy = self._dist_context.strategy @@ -43,7 +43,9 @@ class Parallelizer: world_process_group = get_world_process_group() all_ranks = world_process_group.ranks for rank in all_ranks: + # self._dist_context._backup(serial=True, dist=True) self.parallel(rank) + # self._dist_context._restore(serial=True, dist=True) def parallel(self, rank): serial_main_program = self._dist_context.serial_main_program @@ -58,6 +60,7 @@ class Parallelizer: self._apply_pre_optimization(serial_main_program, serial_startup_program, serial_loss, serial_optimizer, params_grads) + # Do logical partition partitioner = Partitioner(self._dist_context, rank) dist_main_prog, 
dist_startup_prog, dist_params_grads = partitioner.partition( @@ -85,7 +88,6 @@ class Parallelizer: resharder = Resharder(dist_main_prog, dist_startup_prog, rank, self._dist_context, [], 1) resharder.reshard() - # Clone program for test if self._mode != 'train': dist_main_prog = dist_main_prog.clone(for_test=True) diff --git a/python/paddle/distributed/auto_parallel/planner_v2.py b/python/paddle/distributed/auto_parallel/planner_v2.py index 7db17e98d0..3625a25d74 100755 --- a/python/paddle/distributed/auto_parallel/planner_v2.py +++ b/python/paddle/distributed/auto_parallel/planner_v2.py @@ -16,6 +16,8 @@ from .completion import Completer from .dist_context import get_default_distributed_context from .utils import print_program_with_dist_attr +# from .tuner.parallel_tuner import ParallelTuner + class Planner: def __init__(self, mode, dist_context): @@ -24,19 +26,28 @@ class Planner: # NOTE: [HighOrderGrad]. There are grad ops in forward phase, and it need # dependency of backward-forward ops in forward completion. + # TODO: The id mapping will be lost if we clone the original program. default_ctx = get_default_distributed_context() self._dist_context._dist_op_context = default_ctx.dist_op_context self._dist_context.initialize() self._completer = Completer(self._dist_context) + self._strategy = dist_context.strategy + # if self._strategy.auto_search: + # self._parallel_tuner = ParallelTuner( + # self._dist_context, mode=self._mode) + @property def completer(self): return self._completer def plan(self): self._completer.complete_forward_annotation() + # if self._strategy.auto_search: + # self._parallel_tuner.tune() + # else: + # self._completer.complete_forward_annotation() # parse forward sub block self._dist_context.block_state.parse_forward_blocks( self._dist_context.serial_main_program) - # TODO: add the auto searcher diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index fbe3a43a79..42d90b0d4d 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -324,10 +324,13 @@ def _get_corresponding_rank(dist_context, target_mesh, rank): mesh.processes.index(rank)) break - assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( - rank) - return target_mesh.processes[_coordinate2linear_idx(mesh.topology, - coordinate)] + # assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( + # rank) + if coordinate is not None: + return target_mesh.processes[_coordinate2linear_idx(mesh.topology, + coordinate)] + else: + return target_mesh.processes[0] def _get_unshard_dist_shape(var, dist_attr): diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 346939fb5c..381461130e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -31,4 +31,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS}) + py_test_modules(test_dist_context MODULES test_dist_context ENVS ${dist_ENVS}) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py new 
file mode 100644 index 0000000000..f7718e584f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py @@ -0,0 +1,204 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import json + +import paddle +import numpy as np +import paddle.nn as nn +import paddle.utils as utils +import paddle.static as static +import paddle.nn.functional as F + +from paddle.distributed import fleet +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr + +paddle.enable_static() + +batch_size = 4 +hidden_size = 1024 +sequence_len = 512 +_g_process_mesh = [[0, 1], [2, 3]] + + +def get_random_inputs_and_labels(input_shape, label_shape): + input = np.random.random(size=input_shape).astype('float32') + label = np.random.random(size=label_shape).astype('float32') + return input, label + + +def batch_generator_creator(): + def __reader__(): + for _ in range(batch_size): + batch_input, batch_label = get_random_inputs_and_labels( + [batch_size, sequence_len, hidden_size], + [batch_size, sequence_len, 1]) + yield batch_input, batch_label + + return __reader__ + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + param_initializer = nn.initializer.Normal( + mean=0.0, std=initializer_range) + + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.linear0 = nn.Linear( + d_model, + dim_feedforward, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + self.linear1 = nn.Linear( + dim_feedforward, + d_model, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + + def forward(self, input): + out = self.norm(input) + auto.shard_tensor( + self.linear0.weight, + dist_attr={ + "process_mesh": _g_process_mesh[0], + "dims_mapping": [-1, 0] + }) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _g_process_mesh[1], + "dims_mapping": [0, -1] + }) + out = self.linear1(out) + + return out + + +def get_program(): + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + # fleet.init(is_collective=True, strategy=dist_strategy) + + train_program = static.Program() + start_program = static.Program() + with static.program_guard(train_program, start_program): + # input + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + label = static.data( + name="label", shape=[batch_size, sequence_len, 1], dtype='float32') + data_holder = [input, label] + # dataloader + dataloader = paddle.io.DataLoader.from_generator( + feed_list=data_holder, capacity=4 * batch_size, 
iterable=False) + dataloader.set_batch_generator( + batch_generator_creator(), places=paddle.static.cuda_places()) + # data dist_attr + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _g_process_mesh[0], + "dims_mapping": [0, -1, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": _g_process_mesh[0], + "dims_mapping": [0, -1, -1] + }) + + mlp_start = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + pred = mlp_start(input) + + mlp_mid = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + pred = mlp_mid(pred) + + mlp_end = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + pred = mlp_end(pred) + + error_cost = paddle.nn.functional.square_error_cost(pred, label) + loss = paddle.mean(error_cost) + + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + + feed_vars = {"inputs": [input], "labels": [label]} + fetch_vars = {"loss": [loss]} + + return train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars + + +class TestDistributedContext(unittest.TestCase): + def test_backup_restore(self): + train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program( + ) + dist_context = DistributedContext(train_program, start_program, + optimizer, loss, feed_vars, + fetch_vars) + dist_context.initialize() + + dist_context._backup(serial=True, dist=True) + dist_context._restore( + serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_backup") + + dist_context._backup(serial=True, dist=True) + dist_context._restore( + serial=True, + serial_mode="to_original", + dist=True, + dist_mode="to_original") + + dist_context._backup(serial=True, dist=True) + dist_context._restore(serial=True, dist=True, dist_mode="to_default") + + dist_context._backup(serial=True, dist=True) + dist_context._restore(serial=True, dist=True, dist_mode="to_nothing") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py index aa0bf719fa..8af055a09a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py @@ -94,7 +94,8 @@ class TestDistSlice(unittest.TestCase): ops = dist_main_prog.global_block().ops for op in ops: op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - assert op_dist_attr.impl_type == "slice" + # We amend this impl_type after completion + assert op_dist_attr.impl_type == "default" for out in op.output_arg_names: var_dims_mapping = op_dist_attr.get_output_dims_mapping(out) ref_dims_mapping = [-1 for i in range(len(var_dims_mapping))] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py index 1179fd9a9f..9989f5bbdc 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py @@ -27,7 +27,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import make_data_unshard from 
paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context -from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl +from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py index 894bed7108..d296d94333 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py @@ -28,7 +28,7 @@ from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context -from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl +from paddle.distributed.auto_parallel.operators import find_compatible_distributed_operator_impls from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() -- GitLab
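As a usage reference, the backup/restore flow exercised by the new test_dist_context.py above can be condensed as follows. This is a sketch rather than additional test code: it assumes train_program, start_program, loss, optimizer, feed_vars and fetch_vars are built exactly as in get_program() above, and it only relies on the constructor signature and the _backup/_restore modes introduced by this patch.

from paddle.distributed.auto_parallel.dist_context import DistributedContext

# Assumed to come from a helper like get_program() in test_dist_context.py above.
train_program, start_program, _, loss, optimizer, feed_vars, fetch_vars = get_program()

dist_context = DistributedContext(train_program, start_program, optimizer,
                                  loss, feed_vars, fetch_vars)
dist_context.initialize()

# Snapshot the current serial programs and dist attributes, then roll back to them.
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, serial_mode="to_backup",
                      dist=True, dist_mode="to_backup")

# Roll back to the original programs and the dist attributes captured at initialize().
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, serial_mode="to_original",
                      dist=True, dist_mode="to_original")

# Serial side still restores from the last backup (the default mode); the dist side
# resets existing dist attributes to defaults and drops tensors/ops added later.
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_default")

# Clear all recorded distributed information.
dist_context._backup(serial=True, dist=True)
dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")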