diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 408a1fdaafeefefb5065de53093da0da7a92587c..8c286c02015bf03591256efe0cf3046640a5a0ed 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -123,6 +123,19 @@ def merge_process_mesh_two(pm1, pm2): return merged_process_mesh +def _validate_dims_mapping(dims_mapping, process_mesh): + if dims_mapping is None: + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len( + process_mesh.topology): + return False + for i in range(len(process_mesh.topology)): + if dims_mapping.count(i) > 1: + return False + return True + + class Completer: def __init__(self, dist_context): assert dist_context is not None @@ -161,6 +174,9 @@ class Completer: dims_mapping_list.append(tensor_dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping( dims_mapping_list) + if not _validate_dims_mapping(compatible_dims_mapping, + tensor_dist_attr.process_mesh): + return False if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != tensor_dims_mapping): tensor_dist_attr.dims_mapping = compatible_dims_mapping @@ -182,6 +198,9 @@ class Completer: dims_mapping_list.append(tensor_dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping( dims_mapping_list) + if not _validate_dims_mapping(compatible_dims_mapping, + tensor_dist_attr.process_mesh): + return False if (compatible_dims_mapping is not None) and \ (compatible_dims_mapping != tensor_dims_mapping): tensor_dist_attr.dims_mapping = compatible_dims_mapping @@ -196,10 +215,12 @@ class Completer: op_desc = op_node.op() if op_desc.type() == "create_py_reader" \ or op_desc.type() == "create_double_buffer_reader" \ + or op_desc.type() == "while" \ or op_desc.type() == "read": return False dist_op = self._dist_context.get_dist_op_for_graph(op_node) op_dist_attr = dist_op.dist_attr + original_op_dist_attr = copy.deepcopy(op_dist_attr) if fwd: for tensor_node in op_node.inputs: if tensor_node.is_var() and tensor_node.var() is not None: @@ -223,18 +244,34 @@ class Completer: tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( + op_dist_impls = find_best_compatible_distributed_operator_impl( dist_op, fwd=True) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" + if op_dist_impls is not None: + not_compatible = True + backup_op_dist_attr = copy.deepcopy(op_dist_attr) + backup_changed = changed + for op_dist_impl in op_dist_impls: + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + # op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + not_compatible = False + break else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx + dist_op.dist_attr = backup_op_dist_attr + changed = backup_changed + if not_compatible: + dist_op.dist_attr = original_op_dist_attr + changed = False + else: + dist_op.dist_attr = original_op_dist_attr + changed = False else: for tensor_node in op_node.outputs: if tensor_node.is_var() and tensor_node.var() is not None: @@ -258,18 +295,35 @@ class Completer: tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl = find_best_compatible_distributed_operator_impl( + op_dist_impls = find_best_compatible_distributed_operator_impl( dist_op, fwd=False) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - if op_dist_impl.is_auto_compatible(dist_op): - if op_dist_impl.type == "elementwise": - op_dist_attr.impl_type = "default" + if op_dist_impls is not None: + not_compatible = True + backup_op_dist_attr = copy.deepcopy(op_dist_attr) + backup_changed = changed + for op_dist_impl in op_dist_impls: + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + not_compatible = False + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + # op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx + not_compatible = False + break else: - op_dist_attr.impl_type = op_dist_impl.type - op_dist_attr.impl_idx = op_dist_impl.idx + dist_op.dist_attr = backup_op_dist_attr + changed = backup_changed + if not_compatible: + dist_op.dist_attr = original_op_dist_attr + changed = False + else: + dist_op.dist_attr = original_op_dist_attr + changed = False return changed def _update_dims_mapping_between_graphs(self): @@ -279,17 +333,22 @@ class Completer: parent_node) child_node_dist_attr = self._dist_context.get_dist_attr_for_graph( child_node) + if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh: + continue parent_node_dims_mapping = parent_node_dist_attr.dims_mapping child_node_dims_mapping = child_node_dist_attr.dims_mapping compatible_dims_mapping = compute_compatible_dims_mapping( [parent_node_dims_mapping, child_node_dims_mapping]) + if not _validate_dims_mapping(compatible_dims_mapping, + parent_node_dist_attr.process_mesh): + return False if (compatible_dims_mapping is not None) \ and (compatible_dims_mapping != parent_node_dims_mapping): parent_node_dist_attr.dims_mapping = compatible_dims_mapping changed = True if (compatible_dims_mapping is not None) \ and (compatible_dims_mapping != child_node_dims_mapping): - parent_node_dist_attr.dims_mapping = compatible_dims_mapping + child_node_dist_attr.dims_mapping = compatible_dims_mapping changed = True return changed @@ -351,7 +410,7 @@ class Completer: if compatible_process_mesh is not None \ and tensor_dist_attr.process_mesh != compatible_process_mesh: tensor_dist_attr.process_mesh = compatible_process_mesh - # Set the process mesh of the op node's outputs + # Set the process mesh of the op node's outputs for tensor_node in op_node.outputs: if tensor_node.is_var() and tensor_node.var() is not None: tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( @@ -389,7 +448,8 @@ class Completer: if _node_id(cur) in visited: continue # TODO: need more restrictions - for node in cur.inputs: + neighbors = cur.inputs + cur.outputs + for node in neighbors: if node.is_var() and node.var() is not None: if node.var().type() != core.VarDesc.VarType.READER \ and len(node.var().shape()) == 1: @@ -421,10 +481,29 @@ class Completer: visited.add(_node_id(cur)) return related_nodes + def _make_dims_mapping_replicate(dist_attr): + if isinstance(dist_attr, TensorDistributedAttribute): + for i, _ in enumerate(dist_attr.dims_mapping): + dist_attr.dims_mapping[i] = -1 + if isinstance(dist_attr, OperatorDistributedAttribute): + for arg_name in dist_attr.inputs_dist_attrs.keys(): + new_dims_mapping = [] + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + for _ in dims_mapping: + new_dims_mapping.append(-1) + dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + for arg_name in dist_attr.outputs_dist_attrs.keys(): + new_dims_mapping = [] + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + for _ in dims_mapping: + new_dims_mapping.append(-1) + dist_attr.set_output_dims_mapping(arg_name, + new_dims_mapping) + # Amend the process meshes related to while_op for while_op_node, while_op_node_idx in self._while_op_nodes.values(): sub_graph_id = while_op_node.op()._block_attr_id("sub_block") - sub_graph = self._dist_context._serial_graph.get_sub_graph( + sub_graph = self._dist_context.serial_graph.get_sub_graph( sub_graph_id) sub_graph_nodes = list(sub_graph.all_nodes()) while_dist_op = self._dist_context.get_dist_op_for_graph( @@ -440,6 +519,7 @@ class Completer: merged_process_mesh = merge_process_mesh_two( merged_process_mesh, dist_attr.process_mesh) while_op_dist_attr.process_mesh = merged_process_mesh + _make_dims_mapping_replicate(while_op_dist_attr) # Step 2: set the related nodes of while_op to the process mesh of while_op # Step 2.1: Find related nodes of cond var the graph of while_op @@ -480,6 +560,7 @@ class Completer: tensor_dist_attr = self._dist_context.get_dist_attr_for_graph( node) tensor_dist_attr.process_mesh = merged_process_mesh + _make_dims_mapping_replicate(tensor_dist_attr) # Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside input nodes while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs @@ -519,6 +600,25 @@ class Completer: dist_attr = self._dist_context.get_dist_attr_for_graph( array_node) dist_attr.process_mesh = merged_process_mesh + _make_dims_mapping_replicate(dist_attr) + + def _update_process_mesh_between_graphs(self): + for parent_node, child_node in self._node_pairs_between_graphs: + parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph( + parent_node) + child_node_dist_attr = self._dist_context.get_dist_attr_for_graph( + child_node) + parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh + compatible_process_mesh = compute_compatible_process_mesh([ + parent_node_dist_attr.process_mesh, + child_node_dist_attr.process_mesh + ]) + if compatible_process_mesh is not None \ + and parent_node_dist_attr.process_mesh != compatible_process_mesh: + parent_node_dist_attr.process_mesh = compatible_process_mesh + if compatible_process_mesh is not None \ + and child_node_dist_attr.process_mesh != compatible_process_mesh: + child_node_dist_attr.process_mesh = compatible_process_mesh def _update_process_mesh(self): ordered_op_nodes = self._dist_context._serial_ordered_op_nodes @@ -569,7 +669,7 @@ class Completer: return None for idx, op_node in enumerate(ordered_op_nodes[ idx_of_first_op_node_has_process_mesh + 1:]): - original_idx = idx_of_first_op_node_has_process_mesh + +idx + 1 + original_idx = idx_of_first_op_node_has_process_mesh + idx + 1 nearest_op_node = ordered_op_nodes[original_idx - 1] nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph( nearest_op_node) @@ -585,6 +685,9 @@ class Completer: # Step 3: adjust the process meshes for special ops self._update_process_mesh_for_specials() + # Step 4: adjust the process meshes between graphs + self._update_process_mesh_between_graphs() + def _prepare(self): self._while_op_nodes = {} self._array_nodes = {} @@ -620,7 +723,7 @@ class Completer: self._node_pairs_between_graphs.append( (after_node, node)) - def complete_forward_annotation(self, serial_main_program): + def complete_forward_annotation(self, serial_main_program=None): """ Complete annotation for the partial annotated serial_main_program. Arguments: serial_main_program: partial annotated serial_main_program. @@ -628,15 +731,12 @@ class Completer: serial_main_program: completed annotated serial_main_program. """ - # Use the default distribted context for completeion if there is no one - self._dist_context.serial_program = serial_main_program - - # Initialize distributed attributes for all var and op node in serial_main_program - self._dist_context.init_dist_attr_for_program() - # print_program_with_dist_attr(serial_main_program, self._dist_context) + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context.serial_main_program = serial_main_program - # Initialize distributed attributes for all var and op node in graph - self._dist_context.init_dist_attr_for_graph() + self._dist_context.initialize() self._prepare() @@ -646,10 +746,9 @@ class Completer: # Copy the corresponding distributed attribute from graph to serial_main_program self._dist_context.copy_dist_attr_from_graph_to_program() - self._dist_context.clear_dist_info_for_graph() # NOTE:[HighOrderGrad] update vars and ops distributed attribute in high order gradient - self.complete_high_order_grad_annotation(serial_main_program) + self._complete_high_order_grad_annotation(serial_main_program) # Do the validation check and amend some completion self._dist_context.amend_dist_attr_for_program() @@ -658,7 +757,7 @@ class Completer: return serial_main_program - def complete_high_order_grad_annotation(self, serial_main_program): + def _complete_high_order_grad_annotation(self, serial_main_program): """ NOTE: [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. @@ -818,6 +917,10 @@ class Completer: def complete_backward_annotation(self, serial_main_program): """Complete the annotation of vars and ops in the backward phase for parallel program.""" + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context.serial_main_program = serial_main_program def _is_grad_var_name(name): if "@GRAD" in name: @@ -1036,8 +1139,12 @@ class Completer: self._dist_context.set_op_dist_attr_for_program( grad_op, grad_op_dist_attr) - def complete_update_annotation(self, serial_main_program): + def complete_update_annotation(self, serial_main_program=None): """Complete the annotation of vars and ops in the update phase for parallel program.""" + if serial_main_program is None: + serial_main_program = self._dist_context.serial_main_program + else: + self._dist_context.serial_main_program = serial_main_program ops = list(serial_main_program.global_block().ops) vars = serial_main_program.global_block().vars learning_rate_completed = False diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index 8ec702ffcb0b65af96833b4d4d2be1c8ff08d788..857f141f30b1f277780882119f7225b1ab37f8ad 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -52,7 +52,7 @@ def append_op_output_suffix(name): class TensorDistributedAttribute: def __init__(self): - # The process mesh of distributed operator attribute must is the same as + # The process mesh of distributed operator attribute must is the same as # the process meshes of all input and output distributed attributed self._process_mesh = None self._dims_mapping = None @@ -132,12 +132,29 @@ class TensorDistributedAttribute: key, dist_attr) self._is_annotated = copy.deepcopy(dist_attr._is_annotated) + # def reset(self, skip_dist_attr_field_names): + # if skip_dist_attr_field_names is not None \ + # and "process_mesh" not in skip_dist_attr_field_names: + # self._process_mesh = None + # if skip_dist_attr_field_names is not None \ + # and "dims_mapping" not in skip_dist_attr_field_names: + # for i in enumerate(self._dims_mapping): + # self._dims_mapping[i] = -1 + # self._is_annotated = {} + def is_annotated(self, dist_attr_field_name): return self._is_annotated.get(dist_attr_field_name, False) + # def mark_annotated_all(self): + # for key in get_tensor_dist_attr_field_keys(): + # self.mark_annotated(key) + def mark_annotated(self, dist_attr_field_name): self._is_annotated[dist_attr_field_name] = True + # def unmark_annotated(self, dist_attr_field_name): + # self._is_annotated[dist_attr_field_name] = False + def mark_annotated_as(self, dist_attr): if dist_attr is None: return @@ -195,7 +212,7 @@ class OperatorDistributedAttribute: if isinstance(process_mesh, list): process_mesh = ProcessMesh(process_mesh) self._process_mesh = copy.deepcopy(process_mesh) - # In while op, the proess mesh is not shared by all inputs and outputs + # In while op, the proess mesh is not shared by all inputs and outputs if self._op_type == "while": return None for dist_attr in self._inputs_dist_attrs.values(): @@ -357,9 +374,25 @@ class OperatorDistributedAttribute: "ProcessMeshes in DistributedOperator must be the same." self.process_mesh = shared_process_mesh + # def reset(self, skip_dist_attr_field_names): + # for tensor_dist_attr in self.inputs_dist_attrs.values(): + # tensor_dist_attr.reset(skip_dist_attr_field_names) + # for tensor_dist_attr in self.outputs_dist_attrs.values(): + # tensor_dist_attr.reset(skip_dist_attr_field_names) + # if skip_dist_attr_field_names is not None \ + # and "process_mesh" not in skip_dist_attr_field_names: + # self.process_mesh = None + # self.impl_type = "default" + # self.impl_idx = 0 + # self._is_annotated = {} + def is_annotated(self, attr_name): return self._is_annotated.get(attr_name, False) + # def mark_annotated_all(self): + # for key in get_op_dist_attr_field_keys(): + # self.mark_annotated(key) + def mark_annotated(self, attr_name): if attr_name == "process_mesh": # Make sure proscess_mesh be annotated consistently diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 7e245358d4bccaad4b6ffeb0648350459d6212e9..5082ac987f456a06e1da0c91ed627ef966dcf5e3 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -14,9 +14,11 @@ import copy from collections import defaultdict +import paddle.fluid from paddle.fluid import framework from paddle.fluid.framework import get_flags, set_flags from paddle.fluid import core +from paddle.distributed.passes import PassContext from .dist_attribute import TensorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute from .dist_tensor import DistributedTensor @@ -54,26 +56,41 @@ class DistributedContext: serial_main_prog=None, serial_startup_prog=None, dist_main_progs=None, - dist_startup_progs=None): - # Program related data members - self._serial_program = serial_main_prog - self._is_initialized_for_program = False + dist_startup_progs=None, + serial_loss=None, + serial_optimizer=None, + strategy=None): + # Data members related to original programs (unchanged) + self._original_serial_main_program = serial_main_prog + self._original_serial_startup_program = serial_startup_prog + self._original_serial_loss = serial_loss + self._original_serial_optimizer = serial_optimizer + if self._original_serial_main_program is None: + self._original_serial_main_program = paddle.fluid.default_main_program( + ) + if self._original_serial_startup_program is None: + self._original_serial_startup_program = paddle.fluid.default_startup_program( + ) + + # Data members related to programs (changed) + self._serial_main_program = None + self._serial_startup_program = None + self._serial_loss = None + self._serial_optimizer = None + + # Data members related to the program self._dist_tensors_for_program = {} self._dist_ops_for_program = {} self._block_state = BlockState() - # Graph related data members - self._is_initialized_for_graph = False + + # Data members related to the graph self._serial_graph = None self._dist_tensors_for_graph = {} self._dist_ops_for_graph = {} self._node_id_to_tensor_id = {} self._node_id_to_op_id = {} - # Other data members - self._dist_op_context = DistributedOperatorContext() - self._process_meshes = [] - self._serial_ordered_nodes = [] - self._tensor_id_to_tensor_node_ids = {} + # Data members related to the distributed programs # Distributed programs self._dist_main_programs = dist_main_progs if not self._dist_main_programs: @@ -82,20 +99,71 @@ class DistributedContext: if not self._dist_startup_programs: self._dist_startup_programs = {} + # Distributed Strategy + self._strategy = strategy + + # Pass Context + self._pass_context = PassContext() + + # Distributed Operator Context + self._dist_op_context = DistributedOperatorContext() + + # Other data members + self._process_meshes = [] + self._serial_ordered_tensor_nodes = [] + self._serial_ordered_op_nodes = [] + self._serial_ordered_nodes = [] + # self._tensor_id_to_tensor_node_ids = {} + + self._is_initialized = False + @property - def serial_program(self): - return self._serial_program + def serial_main_program(self): + return self._serial_main_program + + @serial_main_program.setter + def serial_main_program(self, program): + # if self._serial_main_program: + # print("WARNING: The program attached to this distributed context will be replaced by the new one.") + self._original_serial_main_program = program + self._serial_main_program = program + + @property + def serial_startup_program(self): + return self._serial_startup_program + + # @serial_startup_program.setter + # def serial_startup_program(self, serial_startup_program): + # self._serial_startup_program = serial_startup_program + + @property + def serial_loss(self): + return self._serial_loss + + # @serial_loss.setter + # def serial_loss(self, serial_loss): + # self._serial_loss = serial_loss + + @property + def serial_optimizer(self): + return self._serial_optimizer + + # @serial_optimizer.setter + # def serial_optimizer(self, serial_optimizer): + # self._serial_optimizer = serial_optimizer + + @property + def strategy(self): + return self._strategy + + # @strategy.setter + # def strategy(self, strategy): + # self._strategy = strategy @property def serial_graph(self): return self._serial_graph - @serial_program.setter - def serial_program(self, program): - # assert self._serial_program is None, \ - # "This distributed context has already been realted to a serial program" - self._serial_program = program - @property def serial_ordered_nodes(self): return self._serial_ordered_nodes @@ -104,6 +172,10 @@ class DistributedContext: def process_meshes(self): return self._process_meshes + @property + def pass_context(self): + return self._pass_context + @property def dist_op_context(self): return self._dist_op_context @@ -121,10 +193,64 @@ class DistributedContext: return self._dist_startup_programs @property - def is_annotation(self): + def has_annotation(self): return len(self._dist_tensors_for_program) or len( self._dist_ops_for_program) + def initialize(self): + if not self._is_initialized: + self._serial_main_program = self._original_serial_main_program.clone( + ) + self._serial_startup_program = self._original_serial_startup_program.clone( + ) + self._serial_main_program = self._original_serial_main_program + self._serial_startup_program = self._original_serial_startup_program + self._serial_loss = self._original_serial_loss + self._serial_optimizer = self._original_serial_optimizer + self._init_dist_attr_for_program() + self._tensors_ids = list(self._dist_tensors_for_program.keys()) + self._ops_ids = list(self._dist_ops_for_program.keys()) + set_flags({"FLAGS_convert_all_blocks": True}) + self._serial_graph = framework.IrGraph( + core.Graph(self._serial_main_program.desc)) + self._init_dist_attr_for_graph() + self._is_initialized = True + + # def reset(self, + # skip_dist_tensors=None, + # skip_dist_ops=None, + # skip_tensor_dist_attr_fields=None, + # skip_op_dist_attr_fields=None): + # self._serial_main_program = self._original_serial_main_program.clone() + # self._serial_startup_program = self._original_serial_startup_program.clone() + # new_tensors_ids = [] + # for tensor_id, dist_tensor in self._dist_tensors_for_program.items(): + # if tensor_id in self._tensors_ids: + # dist_tensor.dist_attr.reset(skip_tensor_dist_attr_fields) + # else: + # new_tensors_ids.append(tensor_id) + # for tensor_id in new_tensors_ids: + # self._dist_tensors_for_program.pop(tensor_id) + # new_ops_ids = [] + # for op_id, dist_op in self._dist_ops_for_program.items(): + # if op_id in self._ops_ids: + # dist_op.dist_attr.reset(skip_op_dist_attr_fields) + # else: + # new_ops_ids.append(op_id) + # for op_id in new_ops_ids: + # self._dist_ops_for_program.pop(op_id) + + # self.copy_dist_attr_from_program_to_graph() + + # self._dist_main_programs = {} + # self._dist_startup_programs = {} + + # self._pass_context = PassContext() + + # self._dist_op_context = DistributedOperatorContext() + + # self._process_meshes = [] + def add_process_mesh(self, process_mesh): assert isinstance(process_mesh, ProcessMesh), \ 'The type of dim_mapping must be ProcessMesh.' @@ -133,12 +259,12 @@ class DistributedContext: def add_dist_tensor_for_program(self, dist_tensor): inner_serial_tensor = dist_tensor.serial_tensor - inner_serial_tensor_id = inner_serial_tensor.desc.id() + inner_serial_tensor_id = inner_serial_tensor.desc.original_id() self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor def add_dist_op_for_program(self, dist_op): inner_serial_op = dist_op.serial_op - inner_serial_op_id = inner_serial_op.desc.id() + inner_serial_op_id = inner_serial_op.desc.original_id() self._dist_ops_for_program[inner_serial_op_id] = dist_op def get_dist_tensor_for_program(self, serial_tensor): @@ -215,18 +341,6 @@ class DistributedContext: else: return None - # def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): - # assert serial_tensor_node.is_var() and \ - # serial_tensor_node.var() is not None - # serial_tensor_id = serial_tensor_node.node.original_desc_id() - # dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) - # assert dist_tensor is not None, \ - # "The distributed tensor of the program has not been added to this context." - # serial_tensor_node_id = serial_tensor_node.id() - # new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, - # dist_attr) - # self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor - def get_op_dist_attr_for_program(self, serial_op): serial_op_id = serial_op.desc.id() dist_op = self._dist_ops_for_program.get(serial_op_id, None) @@ -259,17 +373,6 @@ class DistributedContext: else: return None - # def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): - # assert serial_op_node.is_op() and \ - # serial_op_node.op() is not None - # serial_op_id = serial_op_node.node.original_desc_id() - # dist_op = self._dist_ops_for_program.get(serial_op_id, None) - # assert dist_op is not None, \ - # "The distributed operator of the program has not been added to this context." - # serial_op_node_id = serial_op_node.id() - # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) - # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op - def get_dist_attr_for_graph(self, serial_node): if serial_node.is_var() and serial_node.var() is not None: serial_tensor_node_id = _node_id(serial_node) @@ -288,15 +391,14 @@ class DistributedContext: return None return None - def init_dist_attr_for_program(self): - assert self._serial_program, \ - "Please set the program of this context before initializing its distribute attributes." - if self._is_initialized_for_program: - return + def _init_dist_attr_for_program(self, no_default=False): # Copy the dist tensors and dist ops annotated by users from the default context - default_ctx = get_default_distributed_context() - self._process_meshes = copy.deepcopy(default_ctx.process_meshes) - for block in self._serial_program.blocks: + if not no_default: + default_ctx = get_default_distributed_context() + self._process_meshes = copy.deepcopy(default_ctx.process_meshes) + else: + default_ctx = self + for block in self._serial_main_program.blocks: for tensor in block.vars.values(): # Copy the distributed tensors in the default context default_dist_tensor = default_ctx.get_dist_tensor_for_program( @@ -316,9 +418,8 @@ class DistributedContext: if current_dist_op is None: dist_op = DistributedOperator(op) self.add_dist_op_for_program(dist_op) - self._is_initialized_for_program = True - def order_nodes_by_program_order(self): + def _order_nodes_by_program_order(self): def _contains(nodes, target_node): for node in nodes: if _node_id(node) == _node_id(target_node): @@ -328,7 +429,6 @@ class DistributedContext: serial_ordered_tensor_nodes = [] serial_ordered_op_nodes = [] all_nodes = [] - # for idx, graph in enumerate(self._serial_graph.all_sub_graphs()): for idx, graph in enumerate(self._serial_graph.all_sub_graphs()): for node in graph.all_nodes(): all_nodes.append(node) @@ -346,33 +446,35 @@ class DistributedContext: new_serial_ordered_tensor_nodes = [] new_serial_ordered_op_nodes = [] + new_serial_ordered_nodes = [] for op_node in serial_ordered_op_nodes: tensor_nodes = [] for tensor_node in op_node.inputs: if tensor_node.is_var() \ and tensor_node.var() is not None \ - and not _contains(self._serial_ordered_nodes, tensor_node): + and not _contains(new_serial_ordered_nodes, tensor_node): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) - self._serial_ordered_nodes.extend(tensor_nodes) - self._serial_ordered_nodes.append(op_node) + new_serial_ordered_nodes.extend(tensor_nodes) + new_serial_ordered_nodes.append(op_node) new_serial_ordered_op_nodes.append(op_node) tensor_nodes = [] for tensor_node in op_node.outputs: if tensor_node.is_var() \ and tensor_node.var() is not None \ - and not _contains(self._serial_ordered_nodes, tensor_node): + and not _contains(new_serial_ordered_nodes, tensor_node): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) - self._serial_ordered_nodes.extend(tensor_nodes) + new_serial_ordered_nodes.extend(tensor_nodes) new_serial_ordered_tensor_nodes.sort( key=lambda node: node.node.original_desc_id()) new_serial_ordered_op_nodes.sort( key=lambda node: node.node.original_desc_id()) self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes self._serial_ordered_op_nodes = new_serial_ordered_op_nodes + self._serial_ordered_nodes = new_serial_ordered_nodes assert len(self._serial_ordered_nodes) == len( self._serial_ordered_tensor_nodes) + len( self._serial_ordered_op_nodes) @@ -385,16 +487,9 @@ class DistributedContext: "WARNING: there are some orphan tensors or ops which are not used in the execution." ) - def init_dist_attr_for_graph(self): - assert self._is_initialized_for_program, \ - "The program must be initialized before initializing the distributed attributes for its graph." - if self._is_initialized_for_graph: - return - # Convert program to graph - set_flags({"FLAGS_convert_all_blocks": True}) - self._serial_graph = framework.IrGraph( - core.Graph(self._serial_program.desc)) - self.order_nodes_by_program_order() + def _init_dist_attr_for_graph(self): + # Convert program to graph and initialize the distributed attributes + self._order_nodes_by_program_order() for node in self.serial_ordered_nodes: if node.is_var() and node.var() is not None: dist_tensor = None @@ -428,7 +523,6 @@ class DistributedContext: new_dist_op = DistributedOperator(dist_op.serial_op, dist_op.dist_attr) self._dist_ops_for_graph[serial_op_node_id] = new_dist_op - self._is_initialized_for_graph = True def clear_dist_info_for_program(self): self._dist_tensors_for_program.clear() @@ -438,8 +532,40 @@ class DistributedContext: self._dist_tensors_for_graph.clear() self._dist_ops_for_graph.clear() + def copy_dist_attr_from_program_to_graph(self): + for node in self.serial_ordered_nodes: + if node.is_var() and node.var() is not None: + dist_tensor = None + tensor_id = node.node.original_desc_id() + for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items( + ): + if tensor_id == cur_tensor_id \ + or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id(): + dist_tensor = cur_dist_tensor + assert dist_tensor is not None, \ + "Tensor must have a distributed tensor after the initialization for program." + serial_tensor_node_id = _node_id(node) + new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + dist_tensor.dist_attr) + self._dist_tensors_for_graph[ + serial_tensor_node_id] = new_dist_tensor + if node.is_op() and node.op() is not None: + dist_op = None + op_id = node.node.original_desc_id() + for cur_op_id, cur_dist_op in self._dist_ops_for_program.items( + ): + if op_id == cur_op_id \ + or op_id == cur_dist_op.serial_op.desc.original_id(): + dist_op = cur_dist_op + assert dist_op is not None, \ + "Operator must have a distributed operator after the initialization for program." + serial_op_node_id = _node_id(node) + new_dist_op = DistributedOperator(dist_op.serial_op, + dist_op.dist_attr) + self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + def copy_dist_attr_from_graph_to_program(self): - assert self._is_initialized_for_program and self._is_initialized_for_graph, \ + assert self._is_initialized, \ "Both program and graph must be initialized." updated_tensors = {} # all_nodes = self._serial_graph.all_nodes() @@ -461,7 +587,7 @@ class DistributedContext: op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) dist_op_for_program = self._dist_ops_for_program[op_id] dist_op_for_program.dist_attr = op_dist_attr_for_graph - # TODO: the completion algorithm will skip orphan tensors, + # TODO: the completion algorithm will skip orphan tensors, # here we just set there process_mesh to the first one. for orphan_node in self._serial_orphan_tensor_nodes: serial_tensor_id = orphan_node.var().id() @@ -532,18 +658,24 @@ class DistributedContext: dims_mapping[i] = -1 def validate_dist_attr_for_program(self): - if not self._is_initialized_for_program: + if not self._is_initialized: assert False, \ "Program must be initialized before validating its distributed attributes" - for block in self.serial_program.blocks: + for block in self.serial_main_program.blocks: for tensor in block.vars.values(): dist_tensor = self.get_dist_tensor_for_program(tensor) + assert dist_tensor is not None, \ + "Tensor {} does not have a distributed attribute.".format( + dist_tensor.serial_tensor.name) if (dist_tensor is not None) and ( not dist_tensor.validate_dist_attr()): assert False, "Tensor {} has a wrong distributed attributes {}.".format( dist_tensor.serial_tensor.name, dist_tensor.dist_attr) for op in block.ops: dist_op = self.get_dist_op_for_program(op) + assert dist_op is not None, \ + "Operator {} does not have a distributed attribute.".format( + dist_op.serial_op.type) if (dist_op is not None) and (not dist_op.validate_dist_attr()): assert False, "Operator {} has a wrong distributed attributes {}.".format( dist_op.serial_op.type, dist_tensor.dist_attr) @@ -554,10 +686,12 @@ class DistributedContext: result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_program" or k == "_serial_graph" \ - or k == "_dist_main_programs" or k == "_dist_startup_programs" \ - or k == "_serial_ordered_nodes" or k == "_serial_ordered_tensor_nodes" \ - or k == "_serial_ordered_op_nodes": + if k in [ + "_original_serial_main_program", "_original_serial_startup_program", \ + "_serial_main_program", "_serial_startup_program", "_serial_graph", \ + "_dist_main_programs", "_dist_startup_programs", \ + "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \ + "_serial_ordered_op_nodes"]: setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 2cd841ef80979bb89b90460fbb106f464d74145f..ea6aeb513ffb97d4089fb9aa3208b52945888da0 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -118,11 +118,10 @@ class Engine: losses = to_list(self._loss(*(outputs + labels))) default_ctx = get_default_distributed_context() - if not default_ctx.is_annotation or self._default_strategy: + if not default_ctx.has_annotation or self._default_strategy: inputs = [self._set_data_parallel(var) for var in inputs] labels = [self._set_data_parallel(var) for var in labels] - # print(serial_main_prog) self._feed_vars[mode] = {"inputs": inputs, "labels": labels} self._fetch_vars[mode] = { diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 47f76353e465529f1d29a05852a952d151c76c93..5d43c5682727421ed596d89a5345b3e3481200a5 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -18,16 +18,16 @@ from ..dist_attribute import OperatorDistributedAttribute _g_distributed_operator_impl_containers = {} _g_elementwise_ops = [ - "elementwise_add", "gelu", "dropout", "cast", "gather", "concat" + "elementwise", "gelu", "dropout", "cast", "gather", "concat" ] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} def is_elementwise_op(op_type): - if op_type in _g_elementwise_ops: - return True - else: - return False + for eltwise_op in _g_elementwise_ops: + if eltwise_op in op_type: + return True + return False class DistributedOperatorImplContainer: @@ -156,7 +156,9 @@ def register_distributed_operator_impl(op_type, dist_impl): assert False, "Must register distributed operator registry first." -def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): +def find_best_compatible_distributed_operator_impl(dist_op, + fwd=True, + partial=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. @@ -168,39 +170,55 @@ def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): dist_op_default_impl_container = get_distributed_operator_impl_container( "default") compatible_impls = [] - if fwd: - # First, find impls in the corresponding container - if dist_op_impl_container: - compatible_impls.extend( - dist_op_impl_container.get_input_compatible_impls(dist_op)) - # Second, find impls in the elementwise container - if dist_op_eltwise_impl_container and is_elementwise_op(op_type): - compatible_impls.extend( - dist_op_eltwise_impl_container.get_input_compatible_impls( - dist_op)) - # Third, find impls in the default container - if dist_op_default_impl_container: - compatible_impls.extend( - dist_op_default_impl_container.get_input_compatible_impls( - dist_op)) + if partial: + if fwd: + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_input_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_input_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_input_compatible_impls( + dist_op)) + else: + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_output_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_output_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_output_compatible_impls( + dist_op)) else: # First, find impls in the corresponding container if dist_op_impl_container: compatible_impls.extend( - dist_op_impl_container.get_output_compatible_impls(dist_op)) + dist_op_impl_container.get_compatible_impls(dist_op)) # Second, find impls in the elementwise container if dist_op_eltwise_impl_container and is_elementwise_op(op_type): compatible_impls.extend( - dist_op_eltwise_impl_container.get_output_compatible_impls( - dist_op)) + dist_op_eltwise_impl_container.get_compatible_impls(dist_op)) # Third, find impls in the default container if dist_op_default_impl_container: compatible_impls.extend( - dist_op_default_impl_container.get_output_compatible_impls( - dist_op)) + dist_op_default_impl_container.get_compatible_impls(dist_op)) + if compatible_impls: # For now, just return the first compatible impl - best_compatible_impl = compatible_impls[0] + # best_compatible_impl = compatible_impls[0] + best_compatible_impl = compatible_impls else: best_compatible_impl = None return best_compatible_impl diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index 4795050d15dcc0b60328e0a5be97bac46cfdea88..0696b728d161b39c63da39c68a641efde946922b 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -53,6 +53,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr + batch_dim_mappings = [] input_names = op_desc.input_names() xshape_arg_names = [] if "XShape" in input_names: @@ -64,14 +65,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for mapping in dims_mapping: if mapping != -1: return False - # continue - # if len(dims_mapping) < 1: - # continue + continue if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for mapping in dims_mapping[1:]: if mapping != -1: return False + if len(dims_mapping) >= 1: + batch_dim_mappings.append(dims_mapping[0]) else: if dims_mapping[0] != -1: return False @@ -79,12 +80,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for mapping in dims_mapping[2:]: if mapping != -1: return False + if len(dims_mapping) >= 2: + batch_dim_mappings.append(dims_mapping[1]) + + if compute_compatible_dim_mapping(batch_dim_mappings) is None: + return False + return True def is_output_compatible(self, dist_op): op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr output_names = op_desc.output_names() + batch_dim_mappings = [] xshape_arg_names = [] if "XShape" in output_names: xshape_arg_names = op_desc.output("XShape") @@ -95,14 +103,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for mapping in dims_mapping: if mapping != -1: return False - # continue - # if len(dims_mapping) < 1: - # continue + continue if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for mapping in dims_mapping[1:]: if mapping != -1: return False + if len(dims_mapping) >= 1: + batch_dim_mappings.append(dims_mapping[0]) else: if dims_mapping[0] != -1: return False @@ -110,6 +118,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for mapping in dims_mapping[2:]: if mapping != -1: return False + if len(dims_mapping) >= 2: + batch_dim_mappings.append(dims_mapping[1]) + + if compute_compatible_dim_mapping(batch_dim_mappings) is None: + return False + return True def is_auto_compatible(self, dist_op): @@ -123,9 +137,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): xshape_arg_names = op_desc.input("XShape") for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if serial_tensor.is_parameter: + for mapping in dims_mapping: + if mapping != -1: + return False continue - dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for mapping in dims_mapping[1:]: @@ -150,9 +167,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): xshape_arg_names = op_desc.output("XShape") for arg_name in op_desc.output_arg_names(): serial_tensor = dist_op.get_serial_output(arg_name) + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if serial_tensor.is_parameter: + for mapping in dims_mapping: + if mapping != -1: + return False continue - dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in xshape_arg_names: if len(dims_mapping) > 1: for mapping in dims_mapping[1:]: @@ -229,7 +249,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): compatible_dim_mapping = compute_compatible_dim_mapping( batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + if compatible_dim_mapping is None: + return False + for arg_name in op_desc.input_arg_names(): serial_tensor = dist_op.get_serial_input(arg_name) if serial_tensor.is_parameter: diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py index 37d7d93a2e93427f41970c3be2fe85f0e137c569..aac7f16b6909bdd5a3c2f4e0f4d7ff1ec0f9a6a2 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py @@ -52,21 +52,46 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc - if is_elementwise_op(op_desc.type()): - return True - else: + if not is_elementwise_op(op_desc.type()): return False + op_dist_attr = dist_op.dist_attr + dims_mapping_list = [] + input_arg_names = op_desc.input_arg_names() + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + dims_mapping_list.append(dims_mapping) + + for idx in range(max_dims_mapping_len): + dim_mappings = [] + for dims_mapping in dims_mapping_list: + if idx < len(dims_mapping): + dim_mappings.append(dims_mapping[-(idx + 1)]) + if compute_compatible_dim_mapping(dim_mappings) is None: + return False + return True def is_output_compatible(self, dist_op): op_desc = dist_op.serial_op.desc - op_desc = dist_op.serial_op.desc - if is_elementwise_op(op_desc.type()): - return True - else: + if not is_elementwise_op(op_desc.type()): + return False + op_dist_attr = dist_op.dist_attr + dims_mapping_list = [] + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + dims_mapping_list.append(dims_mapping) + + if compute_compatible_dims_mapping(dims_mapping_list) is None: return False + return True def is_auto_compatible(self, dist_op): op_desc = dist_op.serial_op.desc + if not is_elementwise_op(op_desc.type()): + return False op_dist_attr = dist_op.dist_attr dims_mapping_list = [] input_arg_names = op_desc.input_arg_names() @@ -127,7 +152,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): compatible_dims_mapping = compute_compatible_dims_mapping( dims_mapping_list) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + if compatible_dims_mapping is None: + return False for arg_name in input_arg_names: if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index 68167c1c4f7e8250ea499b7304c580970136f4e9..69e1c866de691cfd88557f0f039dc96b65d2826b 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -95,7 +95,8 @@ def _update_dims_mapping_for_matmul(dist_op): broadcast_x_dims_mapping, broadcast_y_dims_mapping, broadcast_out_dims_mapping ]) - assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + if compatible_dims_mapping is None: + return False for i in range(x_dims_mapping_len - 2): new_idx = i + (out_dims_mapping_len - x_dims_mapping_len) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py index ce68e2060218d5294f9434788cbfbc2e4421c3d6..89cd2c9d9e41a66d11d60c83a4e6f85d014dc051 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py @@ -117,7 +117,8 @@ class DistributedPNormImpl(DistributedOperatorImpl): compatible_dim_mapping = compute_compatible_dim_mapping( batch_dim_mappings) - assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + if compatible_dim_mapping is None: + return False for arg_name in op_desc.input_arg_names(): dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py index 4bc0a471dcf1c29ddc4d2912a8a4f34cc687343c..e3da47fd172ea278401ca93000ac48039c0e6be4 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,6 +17,7 @@ from .common import DistributedOperatorImpl from .common import register_distributed_operator_impl_container from .common import register_distributed_operator_impl from ..utils import is_dim_shard +from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping from .dist_default import DistributedDefaultImpl0 @@ -47,6 +48,29 @@ class DistributedSliceImpl(DistributedOperatorImpl): return True def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + in_name = op_desc.input('Input')[0] + out_name = op_desc.output('Out')[0] + axes = op_desc.attr('axes') + decrease_axis = op_desc.attr('decrease_axis') + in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + ref_indices = [] + for i in range(len(in_dims_mapping)): + if i not in decrease_axis: + ref_indices.append(i) + if ref_indices == []: + assert len(out_dims_mapping) == 1 + if is_dim_shard(out_dims_mapping[0]): + return False + else: + for i in range(len(out_dims_mapping)): + ref_index = ref_indices[i] + if ref_index in axes and is_dim_shard(out_dims_mapping[i]): + return False + return True def is_compatible(self, dist_op): @@ -95,17 +119,30 @@ class DistributedSliceImpl(DistributedOperatorImpl): out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) ref_dims_mapping = [] + ref_indices = [] for i in range(len(in_dims_mapping)): if i not in decrease_axis: ref_dims_mapping.append(in_dims_mapping[i]) + ref_indices.append(i) + if ref_dims_mapping == []: ref_dims_mapping = [-1] - - assert len(ref_dims_mapping) == len(out_dims_mapping) - for i in range(len(out_dims_mapping)): - if out_dims_mapping[i] != ref_dims_mapping[i]: - out_dims_mapping[i] = ref_dims_mapping[i] - changed = True + assert len(ref_dims_mapping) == len(out_dims_mapping) + assert ref_dims_mapping[0] == out_dims_mapping[0] + changed = False + else: + assert len(ref_dims_mapping) == len(out_dims_mapping) + for i in range(len(out_dims_mapping)): + compatible_dim_mapping = compute_compatible_dim_mapping( + [out_dims_mapping[i], ref_dims_mapping[i]]) + if compatible_dim_mapping is None: + continue + if ref_dims_mapping[i] != compatible_dim_mapping: + in_dims_mapping[ref_indices[i]] = compatible_dim_mapping + changed = True + if out_dims_mapping[i] != compatible_dim_mapping: + out_dims_mapping[i] = compatible_dim_mapping + changed = True return changed diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index fc5f1686d0f8c91ac16644e67380084a9cc74933..2ea1223c6f2f326ad6a98e4e8b834b4f44837066 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -230,7 +230,7 @@ class AutoParallelizer: g_process_group_map = copy.deepcopy(_g_process_group_map) _g_process_group_map.clear() _g_process_group_map[0] = ProcessGroup(0, []) - for process_mesh in dist_context._process_meshes: + for process_mesh in self._dist_context._process_meshes: _g_process_group_map[0].add_ranks(process_mesh.processes) return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map diff --git a/python/paddle/distributed/auto_parallel/tuner/recorder.py b/python/paddle/distributed/auto_parallel/tuner/recorder.py index d0f181a6354136a37fba5af4c6ef2d172958a202..ba61843831a25f47e330d36fe32acd37284908ae 100644 --- a/python/paddle/distributed/auto_parallel/tuner/recorder.py +++ b/python/paddle/distributed/auto_parallel/tuner/recorder.py @@ -138,7 +138,6 @@ class MetricRecords(object): def from_state(cls, state): records = cls(state["direction"]) records.records = [MetricRecord.from_state(r) for r in state["records"]] - print("here 1", records.records) return records diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 9c40034498dbc504cd106fef35a886fd1054990a..ac07b49f45c3b312214e6ec2688a126abadd9d74 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -159,11 +159,11 @@ def print_program_with_dist_attr(program, dist_context=None): from .dist_context import set_default_distributed_context if dist_context is None: dist_context = get_default_distributed_context() - print(program) + print(program, flush=True) else: original_default_context = get_default_distributed_context() set_default_distributed_context(dist_context) - print(program) + print(program, flush=True) set_default_distributed_context(original_default_context) lock.release() diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py index 185fb453412eab69579197e10c43d5c9f8d7fc1a..258f46304d18902cf2dc4908fa21fe36217b3728 100644 --- a/python/paddle/distributed/passes/auto_parallel_recompute.py +++ b/python/paddle/distributed/passes/auto_parallel_recompute.py @@ -350,11 +350,12 @@ class RecomputePass(PassBase): for _, op_desc in reversed(list(enumerate(segment_descs))): rc_desc = main_block.desc._insert_op(idx) rc_desc.copy_from(op_desc) + rc_desc.set_original_id(rc_desc.id()) rc_op = Operator(main_block, rc_desc) main_block.ops.insert(idx, rc_op) # set recomputed ops' dist attr fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id( - rc_desc.original_id()) + op_desc.original_id()) assert fwd_op_dist_attr is not None self.set_op_dist_attr(rc_op, fwd_op_dist_attr, var_name_dict) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 4d052f7e90cd3d11e8607bd4a60546f9eca27ae1..7c747338593a393f135b64faa509ca736074da8d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -3,18 +3,23 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_auto_parallel_relaunch MODULES test_auto_parallel_relaunch ENVS ${dist_ENVS}) set_tests_properties(test_auto_parallel_relaunch PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_relaunch_with_planner MODULES test_relaunch_with_planner ENVS ${dist_ENVS}) set_tests_properties(test_relaunch_with_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + py_test_modules(test_relaunch_with_gpt_planner MODULES test_relaunch_with_gpt_planner ENVS ${dist_ENVS}) set_tests_properties(test_relaunch_with_gpt_planner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 240) + py_test_modules(test_engine_api MODULES test_engine_api ENVS ${dist_ENVS}) set_tests_properties(test_engine_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) - py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) + py_test_modules(test_converter MODULES test_converter ENVS ${dist_ENVS}) set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) py_test_modules(test_high_order_grad MODULES test_high_order_grad ENVS ${dist_ENVS}) set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) + py_test_modules(test_while_op_partition MODULES test_while_op_partition ENVS ${dist_ENVS}) py_test_modules(test_tunable_variable MODULES test_tunable_variable ENVS ${dist_ENVS}) py_test_modules(test_tunable_space MODULES test_tunable_space ENVS ${dist_ENVS}) py_test_modules(test_recorder MODULES test_recorder ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py index f170dbc9095f2d3790941988547083e7b7b3efb7..8777bf3ff1f2eb6ce2dc3f02ab9ff779c4240945 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py @@ -66,7 +66,6 @@ class TestDistReshape(unittest.TestCase): for rank in range(2): dist_main_prog, dist_context = parallelizer(make_program_dp2, rank) ops = dist_main_prog.global_block().ops - print_program_with_dist_attr(dist_main_prog, dist_context) for idx, op in enumerate(ops): op_dist_attr = dist_context.get_op_dist_attr_for_program(op) assert op_dist_attr.impl_type == "reshape2" diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py index 6cf4621dbb0ce840515475b0e82e10605a7e3f06..0914126feb852a02002cdf319a72452ea51a7cf9 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py @@ -15,6 +15,7 @@ import unittest import paddle import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() @@ -85,14 +86,9 @@ class TestDistSlice(unittest.TestCase): for op in ops: axes = op.desc.attr('axes') op_dist_attr = dist_context.get_op_dist_attr_for_program(op) - if axes[0] == 0: - assert op_dist_attr.impl_type == "default" - else: - assert op_dist_attr.impl_type == "slice" - for out in op.output_arg_names: - var_dims_mapping = op_dist_attr.get_output_dims_mapping( - out) - assert var_dims_mapping[0] == 0 + assert op_dist_attr.impl_type == "slice" + for out in op.output_arg_names: + var_dims_mapping = op_dist_attr.get_output_dims_mapping(out) def test_dist_slice_serial(self): dist_main_prog, dist_context = parallelizer(make_program_serial, 0) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py similarity index 60% rename from python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py rename to python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py index 07e6a2c4346da42fd9ff0aafef12bf453cc6f463..894bed7108a1d99f569868e693c91d266ce5593b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_auto_parallel_while_op.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py @@ -23,12 +23,13 @@ import paddle.nn.functional as F import paddle.distributed.auto_parallel as auto from paddle.distributed import fleet - +from paddle.distributed.auto_parallel.completion import Completer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.utils import make_data_unshard from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() @@ -283,139 +284,143 @@ def get_program(): def completion(train_program, start_program, dist_context): - blocks = train_program.blocks - # completion tensors - for block in blocks: - for op in block.ops: - if op.type == "layer_norm": - for out_name in op.output_arg_names: - out_var = block.vars[out_name] - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - out_var) - if tensor_dist_attr: - continue - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.process_mesh = _g_process_mesh - tensor_dist_attr.dims_mapping = [-1] - dist_context.set_tensor_dist_attr_for_program( - out_var, tensor_dist_attr) - - elif op.type == "elementwise_sub": - for out_name in op.output_arg_names: - out_var = block.vars[out_name] - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.process_mesh = _g_process_mesh - tensor_dist_attr.dims_mapping = [-1, -1, -1] - dist_context.set_tensor_dist_attr_for_program( - out_var, tensor_dist_attr) - - elif op.type == "matmul_v2": - col = False - for in_name in op.input_arg_names: - if ".w_" not in in_name: - continue - if in_name not in block.vars: - in_var = blocks[0].vars[in_name] - else: - in_var = block.vars[in_name] - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - in_var) - assert tensor_dist_attr is not None - if tensor_dist_attr.dims_mapping == [-1, 0]: - col = True - for out_name in op.output_arg_names: - out_var = block.vars[out_name] - tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( - out_var) - if tensor_dist_attr: - continue - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.process_mesh = _g_process_mesh - if col: - tensor_dist_attr.dims_mapping = [-1, -1, 0] - else: - tensor_dist_attr.dims_mapping = [-1, -1, -1] - dist_context.set_tensor_dist_attr_for_program( - out_var, tensor_dist_attr) - elif op.type == "while": - out_name = op.desc.output("StepScopes")[0] - out_var = block.vars[out_name] - tensor_dist_attr = TensorDistributedAttribute() - tensor_dist_attr.process_mesh = _g_process_mesh - tensor_dist_attr.dims_mapping = [-1] - dist_context.set_tensor_dist_attr_for_program(out_var, - tensor_dist_attr) - - # completion ops - for block in blocks: - for op in block.ops: - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.process_mesh = _g_process_mesh - if op.type == "create_by_read" or op.type == "create_double_buffer_reader": - for in_name in op.input_arg_names: - op_dist_attr.set_input_dims_mapping(in_name, []) - for out_name in op.output_arg_names: - op_dist_attr.set_output_dims_mapping(out_name, []) - elif op.type == "read": - for in_name in op.input_arg_names: - op_dist_attr.set_output_dims_mapping(in_name, []) - for out_name in op.output_arg_names: - out_var = block.vars[out_name] - out_dist_attr = dist_context.get_tensor_dist_attr_for_program( - out_var) - op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) - elif op.type == "while": - for in_name in op.input_arg_names: - in_var = block.vars[in_name] - in_dist_attr = dist_context.get_tensor_dist_attr_for_program( - in_var) - op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) - for out_name in op.output_arg_names: - if out_name == op.desc.output("StepScopes")[0]: - op_dist_attr.set_output_dims_mapping(out_name, []) - else: - out_var = block.vars[out_name] - out_dist_attr = dist_context.get_tensor_dist_attr_for_program( - out_var) - op_dist_attr.set_output_dist_attr(out_name, - out_dist_attr) - else: - for in_name in op.input_arg_names: - if in_name == "lod_tensor_blocking_queue_0": - continue - if in_name not in block.vars: - in_var = blocks[0].vars[in_name] - else: - in_var = block.vars[in_name] - in_dist_attr = dist_context.get_tensor_dist_attr_for_program( - in_var) - op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) - for out_name in op.output_arg_names: - if out_name not in block.vars: - out_var = blocks[0].vars[out_name] - else: - out_var = block.vars[out_name] - out_dist_attr = dist_context.get_tensor_dist_attr_for_program( - out_var) - op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) - - if op.type == "matmul_v2": - op_dist_attr.impl_type = "matmul_v2" - for in_name in op_dist_attr.inputs_dist_attrs.keys(): - in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name] - if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0: - op_dist_attr.impl_idx = 0 - else: - op_dist_attr.impl_idx = 1 - elif op.type == "fill_constant_batch_size_like": - op_dist_attr.impl_type = "fill_constant_batch_size_like" - op_dist_attr.impl_idx = 0 - else: - op_dist_attr.impl_type = "default" - op_dist_attr.impl_idx = 0 - - dist_context.set_op_dist_attr_for_program(op, op_dist_attr) - make_data_unshard(train_program, start_program, dist_context) + # blocks = train_program.blocks + # # completion tensors + # for block in blocks: + # for op in block.ops: + # if op.type == "layer_norm": + # for out_name in op.output_arg_names: + # out_var = block.vars[out_name] + # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # out_var) + # if tensor_dist_attr: + # continue + # tensor_dist_attr = TensorDistributedAttribute() + # tensor_dist_attr.process_mesh = _g_process_mesh + # tensor_dist_attr.dims_mapping = [-1] + # dist_context.set_tensor_dist_attr_for_program( + # out_var, tensor_dist_attr) + + # elif op.type == "elementwise_sub": + # for out_name in op.output_arg_names: + # out_var = block.vars[out_name] + # tensor_dist_attr = TensorDistributedAttribute() + # tensor_dist_attr.process_mesh = _g_process_mesh + # tensor_dist_attr.dims_mapping = [-1, -1, -1] + # dist_context.set_tensor_dist_attr_for_program( + # out_var, tensor_dist_attr) + + # elif op.type == "matmul_v2": + # col = False + # for in_name in op.input_arg_names: + # if ".w_" not in in_name: + # continue + # if in_name not in block.vars: + # in_var = blocks[0].vars[in_name] + # else: + # in_var = block.vars[in_name] + # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # in_var) + # assert tensor_dist_attr is not None + # if tensor_dist_attr.dims_mapping == [-1, 0]: + # col = True + # for out_name in op.output_arg_names: + # out_var = block.vars[out_name] + # tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # out_var) + # if tensor_dist_attr: + # continue + # tensor_dist_attr = TensorDistributedAttribute() + # tensor_dist_attr.process_mesh = _g_process_mesh + # if col: + # tensor_dist_attr.dims_mapping = [-1, -1, 0] + # else: + # tensor_dist_attr.dims_mapping = [-1, -1, -1] + # dist_context.set_tensor_dist_attr_for_program( + # out_var, tensor_dist_attr) + # elif op.type == "while": + # out_name = op.desc.output("StepScopes")[0] + # out_var = block.vars[out_name] + # tensor_dist_attr = TensorDistributedAttribute() + # tensor_dist_attr.process_mesh = _g_process_mesh + # tensor_dist_attr.dims_mapping = [-1] + # dist_context.set_tensor_dist_attr_for_program(out_var, + # tensor_dist_attr) + + # # completion ops + # for block in blocks: + # for op in block.ops: + # op_dist_attr = OperatorDistributedAttribute() + # op_dist_attr.process_mesh = _g_process_mesh + # if op.type == "create_by_read" or op.type == "create_double_buffer_reader": + # for in_name in op.input_arg_names: + # op_dist_attr.set_input_dims_mapping(in_name, []) + # for out_name in op.output_arg_names: + # op_dist_attr.set_output_dims_mapping(out_name, []) + # elif op.type == "read": + # for in_name in op.input_arg_names: + # op_dist_attr.set_output_dims_mapping(in_name, []) + # for out_name in op.output_arg_names: + # out_var = block.vars[out_name] + # out_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # out_var) + # op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) + # elif op.type == "while": + # for in_name in op.input_arg_names: + # in_var = block.vars[in_name] + # in_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # in_var) + # op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) + # for out_name in op.output_arg_names: + # if out_name == op.desc.output("StepScopes")[0]: + # op_dist_attr.set_output_dims_mapping(out_name, []) + # else: + # out_var = block.vars[out_name] + # out_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # out_var) + # op_dist_attr.set_output_dist_attr(out_name, + # out_dist_attr) + # else: + # for in_name in op.input_arg_names: + # if in_name == "lod_tensor_blocking_queue_0": + # continue + # if in_name not in block.vars: + # in_var = blocks[0].vars[in_name] + # else: + # in_var = block.vars[in_name] + # in_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # in_var) + # op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) + # for out_name in op.output_arg_names: + # if out_name not in block.vars: + # out_var = blocks[0].vars[out_name] + # else: + # out_var = block.vars[out_name] + # out_dist_attr = dist_context.get_tensor_dist_attr_for_program( + # out_var) + # op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) + + # if op.type == "matmul_v2": + # op_dist_attr.impl_type = "matmul_v2" + # for in_name in op_dist_attr.inputs_dist_attrs.keys(): + # in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name] + # if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0: + # op_dist_attr.impl_idx = 0 + # else: + # op_dist_attr.impl_idx = 1 + # elif op.type == "fill_constant_batch_size_like": + # op_dist_attr.impl_type = "fill_constant_batch_size_like" + # op_dist_attr.impl_idx = 0 + # else: + # op_dist_attr.impl_type = "default" + # op_dist_attr.impl_idx = 0 + + # dist_context.set_op_dist_attr_for_program(op, op_dist_attr) + # make_data_unshard(train_program, start_program, dist_context) + + completer = Completer(dist_context) + train_program = completer.complete_forward_annotation(train_program) + make_data_unshard(train_program, start_program, dist_context) return train_program, start_program diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py index 036b46470a762503d16c25763fa42f75e8740aa3..3ddd41158a69eec726315c80244e3990a0a676e9 100755 --- a/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_parallelizer.py @@ -134,7 +134,6 @@ class TestMLPAutoParallelizer(unittest.TestCase): for op in block.ops: for attr_name in op.attr_names: self.assertTrue(suffix not in attr_name) - # print_program_with_dist_attr(distributed_main_program) self.assertIsNotNone(distributed_startup_program) self.assertIsNotNone(distributed_main_program) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index a33874a330a21ad7e28bea266cffac005d46f18d..9888d2c68f195d9ac7f8557fea61dd440c9f10e7 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -332,7 +332,6 @@ class TestMLPReshard(unittest.TestCase): resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() - print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))