diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index 34c46e446af2731422d5d9b0fcc5baab86d2f377..660b1a54221a793cd55d000af7eae4cf44d0bffe 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), dist_op, fwd=True) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - # This statement will be replaced by a good way - if op_dist_impl.is_compatible(dist_op): - op_dist_attr.impl_type = op_desc.type() - op_dist_attr.impl_idx = op_dist_impl_idx - elif is_elementwise_like_op(op_desc.type()): - dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "element-wise" - op_dist_attr.impl_idx = -1 - else: - dim_changed = update_op_dims_mapping_by_default_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "default" - op_dist_attr.impl_idx = -2 + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=True) + assert op_dist_impl is not None, "Cannot find the dist op implementation." + dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx else: for tensor_node in op_node.outputs: if tensor_node.var() is not None: @@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): tensor_desc.name(), compatible_dims_mapping) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( - op_desc.type(), dist_op, fwd=False) - if op_dist_impl is not None: - dim_changed = op_dist_impl.update_dims_mapping(dist_op) - if dim_changed: - changed = True - # This statement will be replaced by a good way - if op_dist_impl.is_compatible(dist_op): - op_dist_attr.impl_type = op_desc.type() - op_dist_attr.impl_idx = op_dist_impl_idx - elif is_elementwise_like_op(op_desc.type()): - dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "element-wise" - op_dist_attr.impl_idx = -1 - else: - dim_changed = update_op_dims_mapping_by_default_dist_impl( - dist_context, op_node) - if dim_changed: - changed = True - op_dist_attr.impl_type = "default" - op_dist_attr.impl_idx = -2 + op_dist_impl = find_best_compatible_distributed_operator_impl( + dist_op, fwd=False) + assert op_dist_impl is not None, "Cannot find the dist op implementation." 
+ dim_changed = op_dist_impl.update_dims_mapping(dist_op) + if dim_changed: + changed = True + if op_dist_impl.is_auto_compatible(dist_op): + if op_dist_impl.type == "elementwise": + op_dist_attr.impl_type = "default" + else: + op_dist_attr.impl_type = op_dist_impl.type + op_dist_attr.impl_idx = op_dist_impl.idx return changed diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index d3bf9e22db4387012d7d562da7ec4cc1b4a5b35c..ad3a53ff17d769af81a15b930d7607014bbb286d 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -61,6 +61,8 @@ class DistributedContext: # Other data members self._dist_op_context = DistributedOperatorContext() self._process_meshes = [] + self._serial_ordered_nodes = [] + self._tensor_id_to_tensor_node_ids = {} # Distributed programs self._dist_main_programs = {} @@ -80,6 +82,10 @@ class DistributedContext: "This distributed context has already been realted to a serial program" self._serial_program = program + @property + def serial_ordered_nodes(self): + return self._serial_ordered_nodes + @property def process_meshes(self): return self._process_meshes @@ -186,6 +192,18 @@ class DistributedContext: else: return None + # def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr): + # assert serial_tensor_node.is_var() and \ + # serial_tensor_node.var() is not None + # serial_tensor_id = serial_tensor_node.node.original_desc_id() + # dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + # assert dist_tensor is not None, \ + # "The distributed tensor of the program has not been added to this context." + # serial_tensor_node_id = serial_tensor_node.id() + # new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + # dist_attr) + # self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor + def get_op_dist_attr_for_program(self, serial_op): serial_op_id = serial_op.desc.id() dist_op = self._dist_ops_for_program.get(serial_op_id, None) @@ -218,6 +236,35 @@ class DistributedContext: else: return None + # def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr): + # assert serial_op_node.is_op() and \ + # serial_op_node.op() is not None + # serial_op_id = serial_op_node.node.original_desc_id() + # dist_op = self._dist_ops_for_program.get(serial_op_id, None) + # assert dist_op is not None, \ + # "The distributed operator of the program has not been added to this context." + # serial_op_node_id = serial_op_node.id() + # new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr) + # self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + + # def get_dist_attr_for_graph(self, serial_node): + # if serial_node.is_var() and serial_node.var() is not None: + # serial_tensor_node_id = serial_node.id() + # dist_tensor = self._dist_tensors_for_graph.get( + # serial_tensor_node_id, None) + # if dist_tensor: + # return dist_tensor.dist_attr + # else: + # return None + # if serial_node.is_op() and serial_node.op() is not None: + # serial_op_node_id = serial_node.id() + # dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + # if dist_op: + # return dist_op.dist_attr + # else: + # return None + # return None + def init_dist_attr_for_program(self): assert self._serial_program, \ "Please set the program of this context before initializing its distribute attributes." 
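The two completion.py hunks above collapse the old three-way branch (op-specific impl, element-wise fallback, default fallback) into a single lookup call. As a reading aid, here is a minimal sketch of the selection behaviour; it assumes the container lookup order implemented by the operators/common.py changes further down in this diff, and `pick_impl` is a hypothetical name used only for illustration, not a function added by the patch.

```python
# Illustrative sketch only -- not part of this patch. It mirrors the refactored
# find_best_compatible_distributed_operator_impl in operators/common.py (below).
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container, is_elementwise_op)


def pick_impl(dist_op, fwd=True):
    op_type = dist_op.serial_op.type
    containers = [
        # 1. impls registered for the op type itself (e.g. "matmul_v2")
        get_distributed_operator_impl_container(op_type),
        # 2. the shared "elementwise" impls, if the op type is in the
        #    element-wise list (_g_elementwise_ops)
        get_distributed_operator_impl_container("elementwise")
        if is_elementwise_op(op_type) else None,
        # 3. the replicated "default" impl as the final fallback
        get_distributed_operator_impl_container("default"),
    ]
    candidates = []
    for container in containers:
        if container is None:
            continue
        if fwd:
            candidates.extend(container.get_input_compatible_impls(dist_op))
        else:
            candidates.extend(container.get_output_compatible_impls(dist_op))
    # For now the first compatible impl wins; a cost model may refine this later.
    return candidates[0] if candidates else None
```

Note that when the chosen impl comes from the "elementwise" container, the completion pass records `impl_type = "default"` on the op's dist_attr, presumably so that downstream passes keep handling such ops through the replicated default path.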
@@ -248,6 +295,44 @@ class DistributedContext: self.add_dist_op_for_program(dist_op) self._is_initialized_for_program = True + def order_nodes_by_program_order(self): + def _contains(nodes, target_node): + for node in nodes: + if node.id() == target_node.id(): + return True + return False + + ordered_tensor_nodes = [] + ordered_op_nodes = [] + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + ordered_tensor_nodes.append(node) + if node.is_op() and node.op() is not None: + ordered_op_nodes.append(node) + ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id()) + for op_node in ordered_op_nodes: + tensor_nodes = [] + for tensor_node in op_node.inputs: + if tensor_node.is_var() \ + and tensor_node.var() is not None \ + and not _contains(self._serial_ordered_nodes, tensor_node): + tensor_nodes.append(tensor_node) + tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + self._serial_ordered_nodes.extend(tensor_nodes) + self._serial_ordered_nodes.append(op_node) + tensor_nodes = [] + for tensor_node in op_node.outputs: + if tensor_node.is_var() \ + and tensor_node.var() is not None \ + and not _contains(self._serial_ordered_nodes, tensor_node): + tensor_nodes.append(tensor_node) + self._serial_ordered_nodes.extend(tensor_nodes) + num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes) + assert len(self._serial_ordered_nodes) == num_nodes_before, \ + "The number of nodes before ordering is not the same after ordering." + def init_dist_attr_for_graph(self): assert self._is_initialized_for_program, \ "The program must be initialized before initializing the distributed attributes for its graph." 
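Before the next hunk wires it into init_dist_attr_for_graph, a short sketch of the idea behind `order_nodes_by_program_order` above may help: `IrGraph.all_nodes()` returns nodes in no particular order, but each graph node keeps the id of the program desc it was built from, so sorting by `node.node.original_desc_id()` recovers program order, with each op preceded by its not-yet-seen input tensors and followed by its outputs. The helper below is illustrative only and is not part of the patch; it deduplicates with a set of node ids instead of the `_contains` linear scan used above.

```python
# Illustrative sketch only -- not part of this patch.
def order_by_program(all_nodes):
    op_nodes = sorted(
        (n for n in all_nodes if n.is_op() and n.op() is not None),
        key=lambda n: n.node.original_desc_id())
    ordered, seen = [], set()
    for op_node in op_nodes:
        # unseen inputs first (in program order), then the op, then its outputs
        inputs = sorted(
            (t for t in op_node.inputs if t.is_var() and t.var() is not None),
            key=lambda n: n.node.original_desc_id())
        for t in inputs:
            if t.id() not in seen:
                ordered.append(t)
                seen.add(t.id())
        ordered.append(op_node)
        for t in op_node.outputs:
            if t.is_var() and t.var() is not None and t.id() not in seen:
                ordered.append(t)
                seen.add(t.id())
    return ordered
```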
@@ -257,7 +342,8 @@ class DistributedContext: self._serial_graph = framework.IrGraph( core.Graph(self._serial_program.desc)) all_nodes = self._serial_graph.all_nodes() - for node in all_nodes: + self.order_nodes_by_program_order() + for node in self.serial_ordered_nodes: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() @@ -397,7 +483,9 @@ class DistributedContext: result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs": + if k == "_serial_program" or k == "_serial_graph" \ + or k == "_dist_main_programs" or k == "_dist_startup_programs" \ + or k == "_serial_ordered_nodes": setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index e2de876f01cade351523cbb75af6b1aea44dade1..a7cc2a9600c05f8e528860ffe4ed28a729ac0bbb 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -98,7 +98,7 @@ class DistributedOperator: if self._dist_attr.impl_type is None: self._dist_attr.impl_type = "default" if self._dist_attr.impl_idx is None: - self._dist_attr.impl_idx = -2 + self._dist_attr.impl_idx = 0 if self._dist_attr.is_recompute is None: self._dist_attr.is_recompute = False @@ -217,7 +217,8 @@ class DistributedOperator: str += ", pipeline stage: {}".format(None) - str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx) + str += ", dist_impl idx: {} , dist_impl type {} }}".format( + self.dist_attr._impl_idx, self.dist_attr._impl_type) return str diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index c28b7930124dd6bec09716ea3a2c84ca6c4eff30..ea743df8d643b179f8b8194bb1771c96af3c7543 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -23,5 +23,6 @@ from . import dist_reshape from . import dist_softmax from . import dist_transpose from . import dist_default +from . import dist_eltwise from . import dist_check_finite_and_unscale from . 
import dist_update_loss_scaling diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index 505e29282b87068bd822388270065f6d1ddbd12b..4b079e7b6b575a6bcfd372782529ccc2958cf5db 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -12,109 +12,196 @@ # See the License for the specific language governing permissions and # limitations under the License +import abc from ..dist_attribute import OperatorDistributedAttribute -_g_distributed_operator_impl_registries = {} +_g_distributed_operator_impl_containers = {} + +_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} +def is_elementwise_op(op_type): + if op_type in _g_elementwise_ops: + return True + else: + return False + + class DistributedOperatorImplContainer: - def __init__(self): + def __init__(self, op_type): + self._type = op_type self._impls = [] - self._name = None + + @property + def type(self): + return self._type + + @type.setter + def type(self, op_type): + self._type = op_type + + @property + def impls(self): + return self._impls def register_impl(self, dist_impl): + assert self.type == dist_impl.type, \ + "Op type of container must be same as that of the implementation." + impl_idx = len(self.impls) + dist_impl.idx = impl_idx self._impls.append(dist_impl) def get_impl(self, impl_idx): return self._impls[impl_idx] - def get_impls(self): - return self._impls - + def get_input_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_input_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls -class DistributedOperatorImpl: - def __init__(self): - self._name = None + def get_output_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_output_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls + + def get_compatible_impls(self, dist_op): + compatible_impls = [] + for impl in self.impls: + if impl.is_auto_compatible(dist_op): + compatible_impls.append(impl) + return compatible_impls + + +class DistributedOperatorImpl(abc.ABC): + def __init__(self, name): + self._name = name + self._type = None + self._idx = None self._forward_implemented = False self._backward_implemented = False - @staticmethod - def forward(dist_ctx, *args, **kwargs): - raise NotImplementedError("Please Implement this method in Subclass.") + @property + def name(self): + return self._name - @staticmethod - def backward(dist_ctx, *grad_outputs, **kwargs): - raise NotImplementedError("Please Implement this method in Subclass.") + @name.setter + def name(self, name): + self._name = name - def get_name(self): - return self._name + @property + def type(self): + return self._type + + @type.setter + def type(self, op_type): + self._type = op_type + + @property + def idx(self): + return self._idx + @idx.setter + def idx(self, impl_idx): + self._idx = impl_idx + + @abc.abstractmethod def is_input_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") + @abc.abstractmethod def is_output_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") - def is_compatible(self, dist_op): - return self.is_input_compatible(dist_op) and \ - self.is_output_compatible(dist_op) - + @abc.abstractmethod def 
is_auto_compatible(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") + @staticmethod + @abc.abstractmethod + def forward(dist_ctx, *args, **kwargs): + raise NotImplementedError("Please Implement this method in Subclass.") + + @staticmethod + @abc.abstractmethod + def backward(dist_ctx, *grad_outputs, **kwargs): + raise NotImplementedError("Please Implement this method in Subclass.") + def update_dims_mapping(self, dist_op): raise NotImplementedError("Please Implement this method in Subclass.") -def register_distributed_operator_impl_container(name, dist_op_impl_container): - global _g_distributed_operator_impl_registries - _g_distributed_operator_impl_registries[name] = dist_op_impl_container +def register_distributed_operator_impl_container(container): + global _g_distributed_operator_impl_containers + _g_distributed_operator_impl_containers[container.type] = container -def get_distributed_operator_impl_container(name): - global _g_distributed_operator_impl_registries - return _g_distributed_operator_impl_registries.get(name, None) +def get_distributed_operator_impl_container(op_type): + global _g_distributed_operator_impl_containers + return _g_distributed_operator_impl_containers.get(op_type, None) -def register_distributed_operator_impl(name, dist_impl): - dist_op_impl_container = get_distributed_operator_impl_container(name) +def register_distributed_operator_impl(op_type, dist_impl): + dist_op_impl_container = get_distributed_operator_impl_container(op_type) if dist_op_impl_container is not None: + dist_impl.type = op_type dist_op_impl_container.register_impl(dist_impl) else: assert False, "Must register distributed operator registry first." -def get_distributed_operator_impl(name, impl_idx): - global _g_distributed_operator_impl_registries - return _g_distributed_operator_impl_registries[name].get_impl(impl_idx) - - -def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): +def find_best_compatible_distributed_operator_impl(dist_op, fwd=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. 
""" - dist_op_impl_container = get_distributed_operator_impl_container(name) - if dist_op_impl_container is None: - return None, -1 + op_type = dist_op.serial_op.type + dist_op_impl_container = get_distributed_operator_impl_container(op_type) + dist_op_eltwise_impl_container = get_distributed_operator_impl_container( + "elementwise") + dist_op_default_impl_container = get_distributed_operator_impl_container( + "default") compatible_impls = [] - impls = dist_op_impl_container.get_impls() if fwd: - for idx, impl in enumerate(impls): - if impl.is_input_compatible(dist_op): - compatible_impls.append((impl, idx)) + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_input_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_input_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_input_compatible_impls( + dist_op)) else: - for idx, impl in enumerate(impls): - if impl.is_output_compatible(dist_op): - compatible_impls.append((impl, idx)) - + # First, find impls in the corresponding container + if dist_op_impl_container: + compatible_impls.extend( + dist_op_impl_container.get_output_compatible_impls(dist_op)) + # Second, find impls in the elementwise container + if dist_op_eltwise_impl_container and is_elementwise_op(op_type): + compatible_impls.extend( + dist_op_eltwise_impl_container.get_output_compatible_impls( + dist_op)) + # Third, find impls in the default container + if dist_op_default_impl_container: + compatible_impls.extend( + dist_op_default_impl_container.get_output_compatible_impls( + dist_op)) if compatible_impls: - best_compatible_impl, idx = compatible_impls[0] + # For now, just return the first compatible impl + best_compatible_impl = compatible_impls[0] else: - best_compatible_impl, idx = None, -1 - - return best_compatible_impl, idx + best_compatible_impl = None + return best_compatible_impl def is_parameter_related(varname, block): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py index 00dc346f9a2ac52c55299dd0523dc0c565aa3e4e..52d5e85c962eb2cb28578c43abf4dd7c6c5cce82 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py @@ -30,19 +30,17 @@ global_process_mesh = get_world_process_group().ranks class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedCheckFiniteAndUnscale, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedCheckFiniteAndUnscale, self).__init__(op_type) register_distributed_operator_impl_container( - "check_finite_and_unscale", DistributedCheckFiniteAndUnscale("check_finite_and_unscale")) class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedCheckFiniteAndUnscaleImpl, self).__init__() + super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name) self._name = name self._forward_implemented = False self._backward_implemented = True @@ -57,6 +55,11 @@ class 
DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): "DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !" ) + def is_auto_compatible(self, dist_op): + raise RuntimeError( + "DistributedCheckFiniteAndUnscaleImpl's is_auto_compatible should not be called !" + ) + def update_dims_mapping(self, dist_op): raise RuntimeError( "DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !" diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index a98ec89a5099a79301de6865b8b2830a091121f5..48f9b5a78dd8a371962ed4b72babe01dcc1ac5d4 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -34,31 +34,162 @@ from ..utils import _get_comm_group, _get_corresponding_rank class DistributedDefault(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedDefault, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedDefault, self).__init__(op_type) -register_distributed_operator_impl_container("default", - DistributedDefault("default")) +register_distributed_operator_impl_container(DistributedDefault("default")) # Replicated Default class DistributedDefaultImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedDefaultImpl0, self).__init__() - self._name = name + super(DistributedDefaultImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True def is_input_compatible(self, dist_op): - raise NotImplementedError("Please Implement this method.") + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + return True def is_output_compatible(self, dist_op): - raise NotImplementedError("Please Implement this method.") + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False + return True + + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + batch_dim_mappings = [] + # Check input compatibility + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[0]) + + # Check output compatibility + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in 
output_names: + xshape_arg_names = op_desc.output("XShape") + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for mapping in dims_mapping[1:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[0]) + else: + if dims_mapping[0] != -1: + return False + if len(dims_mapping) > 2: + for mapping in dims_mapping[2:]: + if mapping != -1: + return False + batch_dim_mappings.append(dims_mapping[1]) + + # Check batch dim mapping compatibility + if not all(batch_dim_mappings[0] == dim_mapping + for dim_mapping in batch_dim_mappings): + return False + + return True def update_dims_mapping(self, dist_op): - raise NotImplementedError("Please Implement this method.") + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + # The following statement will be replaced by a more elegent way + if op_desc.type() == "shape" or op_desc.type() == "slice": + return False + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + batch_dim_mappings = [] + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + batch_dim_mappings.append(dims_mapping[0]) + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + batch_dim_mappings.append(dims_mapping[0]) + else: + batch_dim_mappings.append(dims_mapping[1]) + + compatible_dim_mapping = compute_compatible_dim_mapping( + batch_dim_mappings) + assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + for arg_name in op_desc.input_arg_names(): + serial_tensor = dist_op.get_serial_input(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + for arg_name in op_desc.output_arg_names(): + serial_tensor = dist_op.get_serial_output(arg_name) + if serial_tensor.is_parameter: + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + else: + if compatible_dim_mapping != dims_mapping[1]: + dims_mapping[1] = compatible_dim_mapping + changed = True + + return changed @staticmethod def forward(ctx, *args, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py new file mode 100755 index 0000000000000000000000000000000000000000..7d33692e46af9d58e9c9af10d29dd1dca17d020b --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperatorImplContainer +from .common import DistributedOperatorImpl +from .common import register_distributed_operator_impl_container +from .common import register_distributed_operator_impl +from .common import is_elementwise_op +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping +from ..dist_attribute import OperatorDistributedAttribute +from paddle.fluid import core, unique_name +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import Program, Parameter, Variable, program_guard +from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from ..process_group import new_process_group +from ..utils import _get_comm_group, _get_corresponding_rank +from .dist_default import DistributedDefaultImpl0 + + +class DistributedElementwise(DistributedOperatorImplContainer): + def __init__(self, op_type): + super(DistributedElementwise, self).__init__(op_type) + + +register_distributed_operator_impl_container( + DistributedElementwise("elementwise")) + + +# Replicated Elementwise +class DistributedElementwiseImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedElementwiseImpl0, self).__init__(name) + self._forward_implemented = False + self._backward_implemented = False + + def is_input_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + if is_elementwise_op(op_desc.type()): + return True + else: + return False + + def is_output_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_desc = dist_op.serial_op.desc + if is_elementwise_op(op_desc.type()): + return True + else: + return False + + def is_auto_compatible(self, dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + dims_mapping_list = [] + input_arg_names = op_desc.input_arg_names() + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + dims_mapping_list.append(dims_mapping) + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + for idx in range(max_dims_mapping_len): + dim_mappings = [] + for dims_mapping in dims_mapping_list: + if idx < len(dims_mapping): + dim_mappings.append(dims_mapping[-(idx + 1)]) + if not all(dim_mappings[0] == dim_mapping + for dim_mapping in dim_mappings): + return False + return True + + def update_dims_mapping(self, dist_op): + changed = False + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + input_arg_names = op_desc.input_arg_names() + input_dims_mapping_dict = {} 
+ input_dims_mapping_lens = {} + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + input_dims_mapping_dict[arg_name] = dims_mapping + input_dims_mapping_lens[arg_name] = len(dims_mapping) + + dims_mapping_list = [] + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[new_idx] = input_dims_mapping_dict[ + arg_name][i] + dims_mapping_list.append(new_dims_mapping) + else: + dims_mapping_list.append(input_dims_mapping_dict[arg_name]) + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [ + -1 for _ in range(input_dims_mapping_lens[arg_name]) + ] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[i] = compatible_dims_mapping[new_idx] + if new_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + new_dims_mapping) + changed = True + else: + if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if compatible_dims_mapping != dims_mapping: + op_dist_attr.set_output_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + return changed + + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + + @staticmethod + def backward(ctx, *args, **kwargs): + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) + + +register_distributed_operator_impl( + "elementwise", DistributedElementwiseImpl0("replicate_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py index f019f499aa305b8cfde423fd64c1ad232e281656..eac4776f8f3bcdbffc85725a2280b30c6bcff060 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -34,22 +34,20 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank class DistributedEmbedding(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedEmbedding, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedEmbedding, self).__init__(op_type) -register_distributed_operator_impl_container("lookup_table_v2", - DistributedEmbedding("embedding")) -register_distributed_operator_impl_container("c_embedding", - DistributedEmbedding("embedding")) +register_distributed_operator_impl_container( + 
DistributedEmbedding("lookup_table_v2")) +register_distributed_operator_impl_container( + DistributedEmbedding("c_embedding")) # RowParallel class DistributedEmbeddingImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedEmbeddingImpl, self).__init__() - self._name = name + super(DistributedEmbeddingImpl, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -81,6 +79,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr ids_name = op_desc.input('Ids')[0] @@ -89,18 +91,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) - if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[ - -1]): - return False - # Other dimensions must be replicate except the batch dimension - for mapping in ids_dims_mapping[1:]: - if is_dim_shard(mapping): - return False - for mapping in out_dims_mapping[1:]: - if is_dim_shard(mapping): - return False - if w_dims_mapping[-1] != out_dims_mapping[-1]: - return False + if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]: return False @@ -248,6 +239,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): # matmulv2 embedding_op_dist_attr = OperatorDistributedAttribute() embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh + embedding_op_dist_attr.impl_type = op_dist_attr.impl_type embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_embedding_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) @@ -266,6 +258,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py index b0b185819c58ae911c167e86ea1408631ff0d475..737fc3712b1a98e71a514fa235c324fdd58c995e 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank +from .dist_default import DistributedDefaultImpl0 def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): @@ -143,6 +144,68 @@ def _update_dims_mapping_for_matmul(dist_op): return changed +def _is_auto_compatible_for_matmul(dist_op): + op_desc = dist_op.serial_op.desc + op_dist_attr = dist_op.dist_attr + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + out_name = op_desc.output('Out')[0] + # Deep copy these dims_mappings for keeping them unchanged. 
+ x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name)) + y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name)) + out_dims_mapping = copy.deepcopy( + op_dist_attr.get_output_dims_mapping(out_name)) + x_dims_mapping_len = len(x_dims_mapping) + y_dims_mapping_len = len(y_dims_mapping) + out_dims_mapping_len = len(out_dims_mapping) + + # Add dim mapping to Make sure the length dims_mapping be at least 2 + if x_dims_mapping_len == 1: + x_dims_mapping.insert(0, -1) + if y_dims_mapping_len == 1: + y_dims_mapping.insert(1, -1) + + # Deal with dim > 2 and take care of broadcasting + if out_dims_mapping_len > 2: + broadcast_x_dims_mapping = [] + broadcast_y_dims_mapping = [] + broadcast_out_dims_mapping = [] + + for i in range(out_dims_mapping_len - x_dims_mapping_len): + broadcast_x_dims_mapping.append(out_dims_mapping[i]) + for i in range(x_dims_mapping_len - 2): + broadcast_x_dims_mapping.append(x_dims_mapping[i]) + + for i in range(out_dims_mapping_len - y_dims_mapping_len): + broadcast_y_dims_mapping.append(out_dims_mapping[i]) + for i in range(y_dims_mapping_len - 2): + broadcast_y_dims_mapping.append(y_dims_mapping[i]) + + for i in range(out_dims_mapping_len - 2): + broadcast_out_dims_mapping.append(out_dims_mapping[i]) + + is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping) and + (broadcast_x_dims_mapping == broadcast_out_dims_mapping)) + if not is_same: + return False + + # The following which uses negative index can be work + # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 + is_same = (x_dims_mapping[-1] == y_dims_mapping[-2]) + if not is_same: + return False + + is_same = (x_dims_mapping[-2] == out_dims_mapping[-2]) + if not is_same: + return False + + is_same = (y_dims_mapping[-1] == out_dims_mapping[-1]) + if not is_same: + return False + + return True + + def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): # by now the backward function only insert the gradient allreduce for dist op itself @@ -194,10 +257,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name) process_mesh_shape = dist_attr.process_mesh.topology process_mesh_group = dist_attr.process_mesh.processes - assert len( - Y_var_dim_mapping - ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( - Y_var.name, Y_var_dim_mapping) + # assert len( + # Y_var_dim_mapping + # ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format( + # Y_var.name, Y_var_dim_mapping) Y_var_partitioned = False for dim in Y_var_dim_mapping: if dim >= 0 and process_mesh_shape[dim] > 0: @@ -388,20 +451,17 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): class DistributedMatmul(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedMatmul, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedMatmul, self).__init__(op_type) -register_distributed_operator_impl_container("matmul", - DistributedMatmul("matmul")) +register_distributed_operator_impl_container(DistributedMatmul("matmul")) # ColumnParallel class DistributedMatmulImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl0, self).__init__() - self._name = name + super(DistributedMatmulImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -414,8 +474,8 @@ class 
DistributedMatmulImpl0(DistributedOperatorImpl): y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) if is_dim_shard(x_dims_mapping[-1]): return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): + if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[ + -1]): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -435,83 +495,11 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_replicate(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - input_dims_mapping = [] - ordered_input_shard_dims_mapping = [] - - for dim in (x_dims_mapping + y_dims_mapping): - input_dims_mapping.append(dim) - - for item in input_dims_mapping: - if item not in ordered_input_shard_dims_mapping and item != -1: - ordered_input_shard_dims_mapping.append(item) - - for mapping in out_dims_mapping: - if mapping not in input_dims_mapping: - return False - - if is_dim_shard(x_dims_mapping[0]): - order_index = 0 - for idx, item in enumerate(out_dims_mapping): - if item != -1: - if item != ordered_input_shard_dims_mapping[order_index]: - return False - else: - order_index += 1 - if order_index != len(ordered_input_shard_dims_mapping): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): - return False - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_shard(x_dims_mapping[0]): - for mapping in y_dims_mapping[1:]: - if is_dim_shard(mapping) and mapping == x_dims_mapping[0]: - return False - return True def update_dims_mapping(self, dist_op): @@ -635,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): # c_identity identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input input_varname = c_identity_op.desc.input_arg_names()[0] @@ -653,6 +642,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): # matmul matmul_op_dist_attr = 
OperatorDistributedAttribute() matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input for input_varname in matmul_op.desc.input_arg_names(): @@ -692,8 +682,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): # RowParallel class DistributedMatmulImpl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl1, self).__init__() - self._name = name + super(DistributedMatmulImpl1, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -729,93 +718,12 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'): - return False - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - # for gpt2, x dims > y dims, this is a temporary solution - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_shard(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - # Other dimensions must be replicate except the batch dimension - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_replicate(x_dims_mapping[-1]): - return False - - if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ - -1]): - return False - - # Other dimensions must be replicate except the batch dimension - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - x_shard_dim_count = 0 - x_shard_dims = [] - y_shard_dim_count = 0 - y_shard_dims = [] - for dim in x_dims_mapping: - if is_dim_shard(dim): - x_shard_dim_count += 1 - x_shard_dims.append(dim) - - for dim in y_dims_mapping: - if is_dim_shard(dim): - y_shard_dim_count += 1 - y_shard_dims.append(dim) - - if not x_shard_dims and not y_shard_dims: - return False - - if x_shard_dims[-1] != y_shard_dims[0]: - return False - - if x_shard_dim_count == y_shard_dim_count: - for dim in out_dims_mapping: - if is_dim_shard(dim): - return False - if x_shard_dims != y_shard_dims: - return False - else: - if x_shard_dim_count < y_shard_dim_count: - return False - output_shard_dims = [] - for dim in out_dims_mapping: - if is_dim_shard(dim): - output_shard_dims.append(dim) - if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]: - return False return True @@ -933,6 +841,7 @@ class 
DistributedMatmulImpl1(DistributedOperatorImpl): # matmul matmul_op_dist_attr = OperatorDistributedAttribute() matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmul_op_dist_attr.impl_type = op_dist_attr.impl_type matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) @@ -951,6 +860,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) @@ -980,8 +890,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): # ReplicateParallel class DistributedMatmulImpl2(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulImpl2, self).__init__() - self._name = name + super(DistributedMatmulImpl2, self).__init__(name) def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc @@ -1020,56 +929,11 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping - ), "now just support x dims > y dims,but x:{0} and y:{1}".format( - x_dims_mapping, y_dims_mapping) - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_shard(out_dims_mapping[-1]): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if is_valid_list_index(out_dims_mapping, - -2) and is_dim_shard(out_dims_mapping[-2]): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - - if is_valid_list_index(x_dims_mapping, - -2) and is_dim_shard(x_dims_mapping[-2]): - return False - - if is_dim_shard(y_dims_mapping[-1]): - return False - - if is_valid_list_index(y_dims_mapping, - -2) and is_dim_shard(y_dims_mapping[-2]): + if not _is_auto_compatible_for_matmul(dist_op): return False return True @@ -1081,6 +945,10 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): changed = True return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) @@ -1095,20 +963,17 @@ register_distributed_operator_impl("matmul", class 
DistributedMatmulV2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedMatmulV2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedMatmulV2, self).__init__(op_type) -register_distributed_operator_impl_container("matmul_v2", - DistributedMatmulV2("matmul_v2")) +register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2")) # ColumnParallel class DistributedMatmulV2Impl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl0, self).__init__() - self._name = name + super(DistributedMatmulV2Impl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -1121,8 +986,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) if is_dim_shard(x_dims_mapping[-1]): return False - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): + if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[ + -1]): return False for mapping in x_dims_mapping[1:-1]: if is_dim_shard(mapping): @@ -1142,85 +1007,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - - if op_desc.attr('trans_x') or op_desc.attr('trans_y'): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_replicate(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - input_dims_mapping = [] - ordered_input_shard_dims_mapping = [] - - for dim in (x_dims_mapping + y_dims_mapping): - input_dims_mapping.append(dim) - - for item in input_dims_mapping: - if item not in ordered_input_shard_dims_mapping and item != -1: - ordered_input_shard_dims_mapping.append(item) - - for mapping in out_dims_mapping: - if mapping not in input_dims_mapping: - return False - - if is_dim_shard(x_dims_mapping[0]): - order_index = 0 - for idx, item in enumerate(out_dims_mapping): - if item != -1: - if item != ordered_input_shard_dims_mapping[order_index]: - return False - else: - order_index += 1 - if order_index != len(ordered_input_shard_dims_mapping): - return False - - if is_dim_shard(x_dims_mapping[-1]): - return False - - if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ - 1]): - return 
False - - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_shard(x_dims_mapping[0]): - for mapping in y_dims_mapping[1:]: - if is_dim_shard(mapping) and mapping == x_dims_mapping[0]: - return False - return True def update_dims_mapping(self, dist_op): @@ -1342,6 +1135,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): # c_identity identity_op_dist_attr = OperatorDistributedAttribute() identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh + identity_op_dist_attr.impl_type = op_dist_attr.impl_type identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx # input input_varname = c_identity_op.desc.input_arg_names()[0] @@ -1359,6 +1153,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): # matmulv2 matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_v2_op.desc.input_arg_names(): if input_varname in src_op.desc.input_arg_names(): @@ -1395,8 +1190,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): # RowParallel class DistributedMatmulV2Impl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl1, self).__init__() - self._name = name + super(DistributedMatmulV2Impl1, self).__init__(name) self._forward_implemented = True self._backward_implemented = True @@ -1432,93 +1226,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - if op_desc.attr('trans_x') or op_desc.attr('trans_y'): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping), "now just support x dims > y dims" - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - if is_dim_shard(out_dims_mapping[-1]): + if not _is_auto_compatible_for_matmul(dist_op): return False - # Other dimensions must be replicate except the batch dimension - for mapping in out_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - if is_dim_replicate(x_dims_mapping[-1]): - return False - - if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ - -1]): - return False - - # Other dimensions must be replicate except the batch dimension - for mapping in x_dims_mapping[1:-1]: - if is_dim_shard(mapping): - return False - - x_shard_dim_count = 0 - x_shard_dims = [] - y_shard_dim_count = 0 - y_shard_dims = [] - for dim in 
x_dims_mapping: - if is_dim_shard(dim): - x_shard_dim_count += 1 - x_shard_dims.append(dim) - - for dim in y_dims_mapping: - if is_dim_shard(dim): - y_shard_dim_count += 1 - y_shard_dims.append(dim) - - if not x_shard_dims and not y_shard_dims: - return False - - if x_shard_dims[-1] != y_shard_dims[0]: - return False - - if x_shard_dim_count == y_shard_dim_count: - for dim in out_dims_mapping: - if is_dim_shard(dim): - return False - if x_shard_dims != y_shard_dims: - return False - else: - if x_shard_dim_count < y_shard_dim_count: - return False - output_shard_dims = [] - for dim in out_dims_mapping: - if is_dim_shard(dim): - output_shard_dims.append(dim) - if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]: - return False return True def update_dims_mapping(self, dist_op): @@ -1631,6 +1345,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): # matmulv2 matmulv2_op_dist_attr = OperatorDistributedAttribute() matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh + matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in matmul_v2_op.desc.input_arg_names(): input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) @@ -1649,6 +1364,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): # allreduce allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh + allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx for input_varname in c_allreduce_sum_op.desc.input_arg_names(): input_var = main_block.var(input_varname) @@ -1678,8 +1394,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): # ReplicateParallel class DistributedMatmulV2Impl2(DistributedOperatorImpl): def __init__(self, name): - super(DistributedMatmulV2Impl2, self).__init__() - self._name = name + super(DistributedMatmulV2Impl2, self).__init__(name) def is_input_compatible(self, dist_op): op_desc = dist_op.serial_op.desc @@ -1720,57 +1435,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - op_desc = dist_op.serial_op.desc - op_dist_attr = dist_op.dist_attr - x_name = op_desc.input('X')[0] - y_name = op_desc.input('Y')[0] - out_name = op_desc.output('Out')[0] - out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) - assert len(x_dims_mapping) >= len( - y_dims_mapping - ), "now just support x dims > y dims,but x:{0} and y:{1}".format( - x_dims_mapping, y_dims_mapping) - if len(y_dims_mapping) != 2: - return False - if len(x_dims_mapping) == len(y_dims_mapping) and len( - x_dims_mapping) == 4: - if x_dims_mapping[:2] != y_dims_mapping[:2]: - return False - if x_dims_mapping[:2] != out_dims_mapping[:2]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - elif len(x_dims_mapping) != len(y_dims_mapping) and len( - x_dims_mapping) == 3: - if x_dims_mapping[0] != out_dims_mapping[0]: - return False - x_dims_mapping = x_dims_mapping[-2:] - y_dims_mapping = y_dims_mapping[-2:] - out_dims_mapping = out_dims_mapping[-2:] - - if is_dim_shard(out_dims_mapping[-1]): - return False - - if is_valid_list_index(out_dims_mapping, - -2) and is_dim_shard(out_dims_mapping[-2]): - return False - - if 
is_dim_shard(x_dims_mapping[-1]): - return False - - if is_valid_list_index(x_dims_mapping, - -2) and is_dim_shard(x_dims_mapping[-2]): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): return False - if is_dim_shard(y_dims_mapping[-1]): - return False - - if is_valid_list_index(y_dims_mapping, - -2) and is_dim_shard(y_dims_mapping[-2]): + if not _is_auto_compatible_for_matmul(dist_op): return False return True @@ -1782,6 +1451,10 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): changed = True return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): _right_operand_parameter_matmul_backward(ctx, *args, **kwargs) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py index e287bd75b35890233f7a6c30d50b6fd08ab58b4b..93b0d91b7836d64ae6e1dc9b17161746bc6b8444 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -27,22 +27,20 @@ from paddle.fluid import core, unique_name from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype +from .dist_default import DistributedDefaultImpl0 class DistributedReshape2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedReshape2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedReshape2, self).__init__(op_type) -register_distributed_operator_impl_container("reshape2", - DistributedReshape2("reshape2")) +register_distributed_operator_impl_container(DistributedReshape2("reshape2")) class DistributedReshapeImpl0(DistributedOperatorImpl): def __init__(self, name): - super(DistributedReshapeImpl0, self).__init__() - self._name = name + super(DistributedReshapeImpl0, self).__init__(name) self._forward_implemented = True self._backward_implemented = False @@ -76,6 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -85,17 +87,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): x_shape_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if len(x_dims_mapping) != len(out_dims_mapping) - 1: - return False - if is_dim_shard(out_dims_mapping[-1]): - return False - - for idx, item in enumerate(out_dims_mapping[:-2]): - if x_dims_mapping[idx] != item: + for idx, dim_mapping in enumerate(out_dims_mapping[:-1]): + if x_dims_mapping[idx] != dim_mapping: return False - if out_dims_mapping[-2] != x_dims_mapping[-1]: - return False if x_shape_dims_mapping[0] != -1: return False @@ -194,13 +189,12 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) class DistributedReshapeImpl1(DistributedOperatorImpl): def __init__(self, name): - super(DistributedReshapeImpl1, self).__init__() - self._name = name + super(DistributedReshapeImpl1, 
self).__init__(name) self._forward_implemented = True self._backward_implemented = False @@ -234,6 +228,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -244,24 +242,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( x_shape_name) - if len(x_dims_mapping) == len(out_dims_mapping) + 2: - if out_dims_mapping[0] != x_dims_mapping[0]: - return False - if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1: - return False - elif len(x_dims_mapping) != len(out_dims_mapping) + 1: - return False - if is_dim_shard(x_dims_mapping[-1]): return False - for idx, item in enumerate(x_dims_mapping[:-2]): + for idx, item in enumerate(x_dims_mapping[:-1]): if out_dims_mapping[idx] != item: return False - if x_dims_mapping[-2] != out_dims_mapping[-1]: - return False - if x_shape_dims_mapping[0] != -1: return False @@ -359,7 +346,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl("reshape2", diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py index e4624b51222ed62180532b02bf4101a9713bb3ee..f78f1c58dbf074c52d3028e74a777eb93b7495ca 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -22,22 +22,20 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from .dist_default import DistributedDefaultImpl0 class DistributedSoftmax(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedSoftmax, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedSoftmax, self).__init__(op_type) -register_distributed_operator_impl_container("softmax", - DistributedSoftmax("softmax")) +register_distributed_operator_impl_container(DistributedSoftmax("softmax")) class DistributedSoftmaxImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedSoftmaxImpl, self).__init__() - self._name = name + super(DistributedSoftmaxImpl, self).__init__(name) self._forward_implemented = False self._backward_implemented = False @@ -48,8 +46,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): axis = op_desc.attr('axis') x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) - if axis != -1 and axis != len(x_dims_mapping) - 1: - return False + # if axis != -1 and axis != len(x_dims_mapping) - 1: + # return False if is_dim_shard(x_dims_mapping[axis]): return False @@ -63,8 +61,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): axis = op_desc.attr('axis') out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if axis != -1 and axis != len(out_dims_mapping) - 1: - return False + # if axis != -1 and axis != len(out_dims_mapping) - 1: + # return False if is_dim_shard(out_dims_mapping[axis]): return False @@ -72,6 +70,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return True def 
is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr x_name = op_desc.input('X')[0] @@ -79,11 +81,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): out_name = op_desc.output('Out')[0] x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - if axis != -1 and axis != len(x_dims_mapping) - 1: - return False - - if is_dim_shard(x_dims_mapping[axis]): - return False + # if axis != -1 and axis != len(x_dims_mapping) - 1: + # return False if x_dims_mapping != out_dims_mapping: return False @@ -107,9 +106,13 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py index 8b40524e47315260d17e38f12bb95b5d93df39fb..e6a96fb795ef89398ebf89c4cebc6478e53c722d 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -22,22 +22,21 @@ from ..utils import is_valid_list_index from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_and_update_dim_mapping +from .dist_default import DistributedDefaultImpl0 class DistributedTranspose2(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedTranspose2, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedTranspose2, self).__init__(op_type) register_distributed_operator_impl_container( - "transpose2", DistributedTranspose2("transpose2")) + DistributedTranspose2("transpose2")) class DistributedTranspose2Impl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedTranspose2Impl, self).__init__() - self._name = name + super(DistributedTranspose2Impl, self).__init__(name) self._forward_implemented = False self._backward_implemented = False @@ -48,6 +47,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): + if (not self.is_input_compatible(dist_op)) or \ + (not self.is_output_compatible(dist_op)): + return False + op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr perm = op_desc.attr('axis') @@ -111,9 +114,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): return changed + @staticmethod + def forward(ctx, *args, **kwargs): + DistributedDefaultImpl0.forward(ctx, *args, **kwargs) + @staticmethod def backward(ctx, *args, **kwargs): - pass + DistributedDefaultImpl0.backward(ctx, *args, **kwargs) register_distributed_operator_impl( diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py index eccd2742db03feafafd529b1738f00bbd44a5dac..f216fce16f30d0d581248402740b27da41725904 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py @@ -20,18 
+20,17 @@ from ..utils import set_dist_op_desc_original_id class DistributedUpdateLossScaling(DistributedOperatorImplContainer): - def __init__(self, name): - super(DistributedUpdateLossScaling, self).__init__() - self._name = name + def __init__(self, op_type): + super(DistributedUpdateLossScaling, self).__init__(op_type) register_distributed_operator_impl_container( - "update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling")) + DistributedUpdateLossScaling("update_loss_scaling")) class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): def __init__(self, name): - super(DistributedUpdateLossScalingImpl, self).__init__() + super(DistributedUpdateLossScalingImpl, self).__init__(name) self._name = name self._forward_implemented = False self._backward_implemented = True @@ -46,6 +45,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): "DistributedUpdateLossScalingImpl's is_output_compatible should not be called !" ) + def is_auto_compatible(self, dist_op): + raise RuntimeError( + "DistributedUpdateLossScalingImpl's is_auto_compatible should not be called !" + ) + def update_dims_mapping(self, dist_op): raise RuntimeError( "DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !" diff --git a/python/paddle/distributed/auto_parallel/partitioner.py b/python/paddle/distributed/auto_parallel/partitioner.py index 182f6e8b6604a36149c36be9ed6e222bf9673b1c..a0a68efae3c3c1f4ca04144b3243d144651df817 100644 --- a/python/paddle/distributed/auto_parallel/partitioner.py +++ b/python/paddle/distributed/auto_parallel/partitioner.py @@ -63,7 +63,6 @@ class Partitioner(object): def partition(self, serial_main_program, serial_startup_program, params_grads): - if not isinstance(serial_main_program, (Program)): raise TypeError( "main_program be paddle.fluid.framework.program, got %s here" % @@ -87,7 +86,7 @@ class Partitioner(object): serial_main_program, serial_startup_program) dist_op_context.set_dst_startup_program(partitioned_startup_prog) - # partition main program + # partition main program partitioned_main_prog, partitioned_params_grads = self.partition_main_program( serial_main_program, params_grads) @@ -282,7 +281,7 @@ def _get_dist_shape(var, dist_attr): def _partition_parameter(dist_context, src_var, dst_block, dst_varname, dst_shape): # NOTE hack to copied Parameter - # not initialized parameter, need to initialize it + # not initialized parameter, need to initialize it copied_kwargs = {} copied_kwargs['trainable'] = src_var.trainable copied_kwargs['optimize_attr'] = src_var.optimize_attr @@ -371,19 +370,19 @@ def _get_dist_op_backward_implement(backward_op, dist_context, forward_op = forward_op_id2forward_op[forward_op_id] forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op) - dist_op = get_distributed_operator_impl_container(forward_op.type) - - # TODO backward should have its own impl_idx - if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \ - forward_op_dist_attr.impl_idx)._backward_implemented: - return dist_op.get_impl(forward_op_dist_attr.impl_idx) + dist_op_impl_container = get_distributed_operator_impl_container( + forward_op_dist_attr.impl_type) + dist_op_impl = dist_op_impl_container.get_impl( + forward_op_dist_attr.impl_idx) + return dist_op_impl - # NOTE trick for dist ops that only have backward implement + # # NOTE trick for dist ops that only have backward implement if backward_op.type in BACKWARD_ONLY_DIST_OPS: op_dist_attr = 
dist_context.get_op_dist_attr_for_program(backward_op) - dist_op = get_distributed_operator_impl_container(backward_op.type) - if dist_op and op_dist_attr.impl_idx >= 0: - return dist_op.get_impl(op_dist_attr.impl_idx) + assert op_dist_attr.impl_idx >= 0 + dist_op_impl = get_distributed_operator_impl_container( + backward_op.type).get_impl(op_dist_attr.impl_idx) + return dist_op_impl dist_op = get_distributed_operator_impl_container("default") return dist_op.get_impl(0) @@ -391,12 +390,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context, def _get_dist_op_forward_implement(forward_op, dist_context): dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) - dist_op = get_distributed_operator_impl_container(forward_op.type) - - if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl( - dist_attr.impl_idx)._forward_implemented: - return dist_op.get_impl(dist_attr.impl_idx) - - else: - dist_op = get_distributed_operator_impl_container("default") - return dist_op.get_impl(0) + dist_op_impl_container = get_distributed_operator_impl_container( + dist_attr.impl_type) + dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx) + return dist_op_impl diff --git a/python/paddle/distributed/auto_parallel/planner.py b/python/paddle/distributed/auto_parallel/planner.py index 1dfefb41c80a36b941bb10d24ba34f4a9589a80a..f7d4c734feea4f8de4c961c9fc8f3a4a2dcb8f31 100755 --- a/python/paddle/distributed/auto_parallel/planner.py +++ b/python/paddle/distributed/auto_parallel/planner.py @@ -28,7 +28,7 @@ from .cost_model import estimate_cost from .dist_op import DistributedOperator from .process_group import _g_process_group_map from .process_group import ProcessGroup, get_process_group -from .completion import is_elementwise_like_op +from .operators.common import is_elementwise_op from .operators.common import get_distributed_operator_impl_container from .utils import update_op_dims_mapping_by_default_dist_impl from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl @@ -237,7 +237,7 @@ class PlanSpace: dist_op = DistributedOperator(op, op_dist_attr) if dist_op_impl_container is None: - if is_elementwise_like_op(op.type): + if is_elementwise_op(op.type): changed = True valid = True try: @@ -250,7 +250,8 @@ class PlanSpace: op, dist_op.dist_attr, vars ) and PlanFilter.check_dims_mapping_for_special_op( op, dist_op.dist_attr, vars): - dist_op.dist_attr.impl_idx = -1 + dist_op.dist_attr.impl_type = "elementwise" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) continue else: @@ -266,16 +267,18 @@ class PlanSpace: op, dist_op.dist_attr, vars ) and PlanFilter.check_dims_mapping_for_special_op( op, dist_op.dist_attr, vars): - dist_op.dist_attr.impl_idx = -2 + dist_op.dist_attr.impl_type = "default" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) continue # if op has distributed implements, find all valid dist attr of this op - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls for idx, impl in enumerate(impls): if impl.is_auto_compatible(dist_op): if PlanFilter.check_dims_mapping_for_op( op, dist_op.dist_attr, vars): + dist_op.dist_attr.impl_type = dist_op.serial_op.type dist_op.dist_attr.impl_idx = idx op_valid_dist_attrs.append(dist_op.dist_attr) @@ -290,7 +293,8 @@ class PlanSpace: for var_name in op.output_arg_names: op_dist_attr.set_output_dims_mapping( vars[var_name], [-1 for i in vars[var_name].shape]) - dist_op.dist_attr.impl_idx = -1 + dist_op.dist_attr.impl_type = 
"default" + dist_op.dist_attr.impl_idx = 0 op_valid_dist_attrs.append(dist_op.dist_attr) return op_valid_dist_attrs diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py index 8593e44b3d82083dbc43212818919d785e99819e..7d94139e9a8819bc153f9b79c7e6f0e77cf352f2 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py @@ -105,7 +105,7 @@ class TestAutoParallelAPI(unittest.TestCase): self.assertEqual(dist_op.dist_attr.process_mesh, ProcessMesh(process_mesh2)) self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertEqual(dist_op.dist_attr.impl_idx, 0) self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) @@ -138,7 +138,7 @@ class TestAutoParallelAPI(unittest.TestCase): dist_op = dist_context.get_dist_op_for_program(last_op) self.assertEqual(dist_op.dist_attr.process_mesh, None) self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, -2) + self.assertEqual(dist_op.dist_attr.impl_idx, 0) self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh")) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py index c9cbcd1ea8efd59a2e9c978001a9086d7de09eb0..8c5913c66a70d9cb73622e9d6c33e65f48369c7d 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_matmul_op.py @@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -class Testcompatible(unittest.TestCase): +class TestCompatible(unittest.TestCase): def test_matmulv2_matmul_2_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -123,7 +123,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - self.assertFalse(impls[2].is_auto_compatible( + self.assertTrue(impls[2].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[2].is_auto_compatible( @@ -220,7 +220,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) - 
self.assertFalse(impls[1].is_auto_compatible( + self.assertTrue(impls[1].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) self.assertFalse(impls[1].is_auto_compatible( @@ -307,7 +307,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'matmul_v2' or op.type == 'matmul': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() X = op.input_arg_names[0] Y = op.input_arg_names[1] @@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase): op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1]) - self.assertFalse(impls[0].is_auto_compatible( + self.assertTrue(impls[0].is_auto_compatible( DistributedOperator(op, op_dist_attr))) op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1]) self.assertFalse(impls[0].is_auto_compatible( diff --git a/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py index 8f53a0c765d4cc6b996294b8f14ce30080473263..4cb58eac7cc41127d41666cf1d49482a3f85f2b9 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py +++ b/python/paddle/fluid/tests/unittests/test_auto_search_dist_op.py @@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -class Testcompatible(unittest.TestCase): - def test_raise_compatible(self): - valid_op_dist_attr_list = [] - program = paddle.static.Program() - startup_program = paddle.static.Program() - loss, program, start_program = mlp_forward(program, startup_program) - ops = program.global_block().ops - for idx, op in enumerate(ops): - if op.type == 'transpose2': - op_dist_attr = OperatorDistributedAttribute() - dist_op = DistributedOperator(op, op_dist_attr) - impls = DistributedOperatorImpl() - try: - impls.is_auto_compatible(dist_op) - except NotImplementedError: - e = False - self.assertTrue(e == False) - +class TestCompatible(unittest.TestCase): def test_reshape_remove_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -124,7 +107,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'reshape2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1, -1]) @@ -172,64 +155,6 @@ class Testcompatible(unittest.TestCase): self.assertFalse(impls[1].is_auto_compatible( DistributedOperator(op, op_dist_attr))) - def test_reshape_remove_two_compatible(self): - valid_op_dist_attr_list = [] - program = paddle.static.Program() - startup_program = paddle.static.Program() - loss, program, start_program = mlp_forward(program, startup_program) - ops = program.global_block().ops - for idx, op in enumerate(ops): - if op.type == 'reshape2': - dist_op_impl_container = get_distributed_operator_impl_container( - op.type) - impls = dist_op_impl_container.get_impls() - op_dist_attr = OperatorDistributedAttribute() - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, -1, -1]) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], - [-1]) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, -1, 
-1]) - dist_op = DistributedOperator(op, op_dist_attr) - self.assertTrue(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, 1, 0]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [0, 1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [1, -1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [-1, 1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, -1, 1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, 1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - op_dist_attr.set_output_dims_mapping(op.output_arg_names[1], - [-1, -1, 1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - - op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], - [1, -1, -1]) - self.assertFalse(impls[1].is_auto_compatible( - DistributedOperator(op, op_dist_attr))) - def test_reshape_add_compatible(self): valid_op_dist_attr_list = [] program = paddle.static.Program() @@ -240,7 +165,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'reshape2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], @@ -298,7 +223,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'transpose2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1]) @@ -349,7 +274,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'softmax': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1]) @@ -379,7 +304,7 @@ class Testcompatible(unittest.TestCase): if op.type == 'c_embedding' or op.type == 'lookup_table_v2': dist_op_impl_container = get_distributed_operator_impl_container( op.type) - impls = dist_op_impl_container.get_impls() + impls = dist_op_impl_container.impls op_dist_attr = OperatorDistributedAttribute() op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1, -1])
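# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the assertions in
# the matmul tests above flip from assertFalse to assertTrue because every
# matmul / matmul_v2 implementation now answers is_auto_compatible() with the
# same three-step pattern instead of its own long hand-written dims-mapping
# walk. The method and helper names below mirror the patch
# (is_input_compatible, is_output_compatible, _is_auto_compatible_for_matmul);
# the bodies are deliberately trivial stand-ins so the sketch runs on its own.


def _is_auto_compatible_for_matmul(dist_op):
    # Stand-in for the shared helper the patch factors out of
    # DistributedMatmulV2Impl0/1/2; the real one cross-checks the X/Y/Out
    # dims mappings stored in dist_op.dist_attr.
    return True


class SketchMatmulImpl(object):
    def is_input_compatible(self, dist_op):
        return True  # placeholder for the per-implementation input check

    def is_output_compatible(self, dist_op):
        return True  # placeholder for the per-implementation output check

    def is_auto_compatible(self, dist_op):
        # 1) reuse the input/output checks that already existed ...
        if (not self.is_input_compatible(dist_op)) or \
                (not self.is_output_compatible(dist_op)):
            return False
        # 2) ... then run the matmul-specific cross check exactly once.
        if not _is_auto_compatible_for_matmul(dist_op):
            return False
        return True


if __name__ == "__main__":
    assert SketchMatmulImpl().is_auto_compatible(dist_op=None)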
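# ---------------------------------------------------------------------------
# Editor's note (hedged usage sketch): the impl_idx values checked in
# test_auto_parallel_api.py change from -2 to 0 because, after this patch,
# impl_idx is always a non-negative index into a container's `impls` list and
# impl_type names the container ("default", "elementwise", or the op type)
# instead of the old -1/-2 sentinels. Resolving a concrete implementation is
# therefore a plain two-step lookup, which is what the partitioner change above
# does. The import path is the one the planner change in this diff uses;
# running it assumes a Paddle build that already contains the patch, and
# resolve_dist_impl is a hypothetical wrapper, not a function added by the
# patch.

from paddle.distributed.auto_parallel.operators.common import \
    get_distributed_operator_impl_container


def resolve_dist_impl(op_dist_attr):
    # op_dist_attr.impl_type: e.g. "matmul_v2", "reshape2" or "default"
    # op_dist_attr.impl_idx:  index into container.impls, always >= 0 now
    container = get_distributed_operator_impl_container(op_dist_attr.impl_type)
    return container.get_impl(op_dist_attr.impl_idx)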
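# ---------------------------------------------------------------------------
# Editor's note (hedged sketch): the same registration idiom is applied to
# every operator file touched above -- the container receives its op type in
# __init__ and register_distributed_operator_impl_container() no longer takes a
# separate string key, while impls without a specialized kernel delegate
# forward/backward to DistributedDefaultImpl0 instead of `pass`.
# "DistributedFoo"/"foo" are hypothetical names used only to show the shape of
# the idiom, and the import locations of the base classes are assumed to be
# operators/common.py and operators/dist_default.py.

from paddle.distributed.auto_parallel.operators.common import (
    DistributedOperatorImplContainer, DistributedOperatorImpl,
    register_distributed_operator_impl_container,
    register_distributed_operator_impl)
from paddle.distributed.auto_parallel.operators.dist_default import \
    DistributedDefaultImpl0


class DistributedFoo(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedFoo, self).__init__(op_type)


register_distributed_operator_impl_container(DistributedFoo("foo"))


class DistributedFooImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedFooImpl0, self).__init__(name)

    @staticmethod
    def forward(ctx, *args, **kwargs):
        # No specialized kernel for "foo": fall back to the default impl.
        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)

    @staticmethod
    def backward(ctx, *args, **kwargs):
        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)


register_distributed_operator_impl("foo",
                                   DistributedFooImpl0("replicate_parallel"))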