Unverified commit 9acc26ca, authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the dist op interface and the compatible computation (#39014)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem
Parent 2a9c993e
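The central interface change in this commit: find_best_compatible_distributed_operator_impl now takes only the dist_op (plus the fwd flag) and returns a single implementation object that carries its own type and idx, instead of an (impl, idx) pair looked up by op type. Below is a minimal sketch of the new call pattern, using only names that appear in the hunks that follow; the module path and the wrapping function are assumptions of this sketch, not part of the diff.

from paddle.distributed.auto_parallel.operators.common import \
    find_best_compatible_distributed_operator_impl


def choose_impl(dist_op, fwd=True):
    # Sketch: dist_op is an existing DistributedOperator, as in
    # update_op_node_dims_mapping below.
    op_dist_attr = dist_op.dist_attr
    op_dist_impl = find_best_compatible_distributed_operator_impl(dist_op, fwd=fwd)
    assert op_dist_impl is not None, "Cannot find the dist op implementation."
    changed = op_dist_impl.update_dims_mapping(dist_op)
    if op_dist_impl.is_auto_compatible(dist_op):
        # Elementwise impls are recorded on the attribute under the "default" type.
        if op_dist_impl.type == "elementwise":
            op_dist_attr.impl_type = "default"
        else:
            op_dist_attr.impl_type = op_dist_impl.type
        op_dist_attr.impl_idx = op_dist_impl.idx
    return changed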
......@@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
......@@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
......
......@@ -61,6 +61,8 @@ class DistributedContext:
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}
# Distributed programs
self._dist_main_programs = {}
......@@ -80,6 +82,10 @@ class DistributedContext:
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes
@property
def process_meshes(self):
return self._process_meshes
......@@ -186,6 +192,18 @@ class DistributedContext:
else:
return None
# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
......@@ -218,6 +236,35 @@ class DistributedContext:
else:
return None
# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
......@@ -248,6 +295,44 @@ class DistributedContext:
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
return True
return False
ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
......@@ -257,7 +342,8 @@ class DistributedContext:
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
......@@ -397,7 +483,9 @@ class DistributedContext:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
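With the new serial_ordered_nodes bookkeeping, the graph-level attributes are initialized over nodes in program order rather than in the arbitrary order returned by all_nodes(). A minimal usage sketch follows; the constructor call, the serial_program setter, and the module path are assumptions based on the surrounding code, not something shown in this diff.

from paddle.distributed.auto_parallel.dist_context import DistributedContext


def dump_program_order(serial_main_program):
    # Assumed setup: a fresh DistributedContext bound to the serial program.
    dist_context = DistributedContext()
    dist_context.serial_program = serial_main_program
    dist_context.init_dist_attr_for_program()
    dist_context.init_dist_attr_for_graph()  # calls order_nodes_by_program_order()
    for node in dist_context.serial_ordered_nodes:
        if node.is_var() and node.var() is not None:
            print("tensor:", node.var().name())
        elif node.is_op() and node.op() is not None:
            print("op:", node.op().type())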
......@@ -98,7 +98,7 @@ class DistributedOperator:
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
self._dist_attr.impl_idx = 0
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False
......@@ -217,7 +217,8 @@ class DistributedOperator:
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type)
return str
......
......@@ -23,5 +23,6 @@ from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling
......@@ -12,109 +12,196 @@
# See the License for the specific language governing permissions and
# limitations under the License
import abc
from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {}
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
else:
return False
class DistributedOperatorImplContainer:
def __init__(self):
def __init__(self, op_type):
self._type = op_type
self._impls = []
self._name = None
@property
def type(self):
return self._type
@type.setter
def type(self, op_type):
self._type = op_type
@property
def impls(self):
return self._impls
def register_impl(self, dist_impl):
assert self.type == dist_impl.type, \
"Op type of container must be same as that of the implementation."
impl_idx = len(self.impls)
dist_impl.idx = impl_idx
self._impls.append(dist_impl)
def get_impl(self, impl_idx):
return self._impls[impl_idx]
def get_impls(self):
return self._impls
def get_input_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_input_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
class DistributedOperatorImpl:
def __init__(self):
self._name = None
def get_output_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_output_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
def get_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_auto_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
class DistributedOperatorImpl(abc.ABC):
def __init__(self, name):
self._name = name
self._type = None
self._idx = None
self._forward_implemented = False
self._backward_implemented = False
@staticmethod
def forward(dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@property
def name(self):
return self._name
@staticmethod
def backward(dist_ctx, *grad_outputs, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@name.setter
def name(self, name):
self._name = name
def get_name(self):
return self._name
@property
def type(self):
return self._type
@type.setter
def type(self, op_type):
self._type = op_type
@property
def idx(self):
return self._idx
@idx.setter
def idx(self, impl_idx):
self._idx = impl_idx
@abc.abstractmethod
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
@abc.abstractmethod
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, dist_op):
return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
@abc.abstractmethod
def is_auto_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
@staticmethod
@abc.abstractmethod
def forward(dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@staticmethod
@abc.abstractmethod
def backward(dist_ctx, *grad_outputs, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator_impl_container(name, dist_op_impl_container):
global _g_distributed_operator_impl_registries
_g_distributed_operator_impl_registries[name] = dist_op_impl_container
def register_distributed_operator_impl_container(container):
global _g_distributed_operator_impl_containers
_g_distributed_operator_impl_containers[container.type] = container
def get_distributed_operator_impl_container(name):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries.get(name, None)
def get_distributed_operator_impl_container(op_type):
global _g_distributed_operator_impl_containers
return _g_distributed_operator_impl_containers.get(op_type, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op_impl_container = get_distributed_operator_impl_container(name)
def register_distributed_operator_impl(op_type, dist_impl):
dist_op_impl_container = get_distributed_operator_impl_container(op_type)
if dist_op_impl_container is not None:
dist_impl.type = op_type
dist_op_impl_container.register_impl(dist_impl)
else:
assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
"""
Here we just return the first compatible implementation.
This will be improved by a cost model in the future.
"""
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is None:
return None, -1
op_type = dist_op.serial_op.type
dist_op_impl_container = get_distributed_operator_impl_container(op_type)
dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
"elementwise")
dist_op_default_impl_container = get_distributed_operator_impl_container(
"default")
compatible_impls = []
impls = dist_op_impl_container.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_input_compatible(dist_op):
compatible_impls.append((impl, idx))
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
else:
for idx, impl in enumerate(impls):
if impl.is_output_compatible(dist_op):
compatible_impls.append((impl, idx))
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
if compatible_impls:
best_compatible_impl, idx = compatible_impls[0]
# For now, just return the first compatible impl
best_compatible_impl = compatible_impls[0]
else:
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
best_compatible_impl = None
return best_compatible_impl
def is_parameter_related(varname, block):
......
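The container/impl interface above is what every concrete dist op below plugs into. The following is a minimal sketch of the registration pattern, mirroring those files; the op name "my_op", the MyOp classes, and the absolute import path are hypothetical and exist only to illustrate the interface.

from paddle.distributed.auto_parallel.operators.common import (
    DistributedOperatorImplContainer, DistributedOperatorImpl,
    register_distributed_operator_impl_container,
    register_distributed_operator_impl)


class DistributedMyOp(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedMyOp, self).__init__(op_type)


# One container per op type, keyed by its type.
register_distributed_operator_impl_container(DistributedMyOp("my_op"))


class DistributedMyOpImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMyOpImpl0, self).__init__(name)

    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return self.is_input_compatible(dist_op) and \
            self.is_output_compatible(dist_op)

    def update_dims_mapping(self, dist_op):
        return False

    @staticmethod
    def forward(ctx, *args, **kwargs):
        pass

    @staticmethod
    def backward(ctx, *args, **kwargs):
        pass


# register_impl sets the impl's idx; register_distributed_operator_impl sets its type.
register_distributed_operator_impl("my_op", DistributedMyOpImpl0("replicate"))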
......@@ -30,19 +30,17 @@ global_process_mesh = get_world_process_group().ranks
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscale, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedCheckFiniteAndUnscale, self).__init__(op_type)
register_distributed_operator_impl_container(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscale("check_finite_and_unscale"))
class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__()
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name)
self._name = name
self._forward_implemented = False
self._backward_implemented = True
......@@ -57,6 +55,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
"DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !"
)
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_auto_compatible should not be called !"
)
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !"
......
......@@ -34,31 +34,162 @@ from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedDefault, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedDefault, self).__init__(op_type)
register_distributed_operator_impl_container("default",
DistributedDefault("default"))
register_distributed_operator_impl_container(DistributedDefault("default"))
# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__()
self._name = name
super(DistributedDefaultImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
return True
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
# Check input compatibility
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
# Check output compatibility
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[1])
# Check batch dim mapping compatibility
if not all(batch_dim_mappings[0] == dim_mapping
for dim_mapping in batch_dim_mappings):
return False
return True
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
......
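The key rule encoded in DistributedDefaultImpl0 above: non-parameter inputs and outputs may only be sharded on the batch dimension (dimension 0, or dimension 1 for XShape outputs), and all batch-dimension mappings must agree. Below is a simplified stand-alone illustration of that check, ignoring the parameter and XShape special cases; it is an editorial sketch, not the Paddle code.

def default_impl_auto_compatible(input_dims_mappings, output_dims_mappings):
    # Every non-batch dimension must be replicated (-1) and all batch-dim
    # mappings must be identical.
    batch_dim_mappings = []
    for dims_mapping in input_dims_mappings + output_dims_mappings:
        if any(mapping != -1 for mapping in dims_mapping[1:]):
            return False
        batch_dim_mappings.append(dims_mapping[0])
    return all(batch_dim_mappings[0] == m for m in batch_dim_mappings)


# Batch dim sharded on mesh axis 0 everywhere -> compatible.
assert default_impl_auto_compatible([[0, -1, -1]], [[0, -1]])
# A non-batch dim is sharded -> not compatible.
assert not default_impl_auto_compatible([[0, -1, 1]], [[0, -1]])
# Batch-dim mappings disagree -> not compatible.
assert not default_impl_auto_compatible([[0, -1]], [[-1, -1]])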
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_elementwise_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
class DistributedElementwise(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedElementwise, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedElementwise("elementwise"))
# Replicated Elementwise
class DistributedElementwiseImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedElementwiseImpl0, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if not all(dim_mappings[0] == dim_mapping
for dim_mapping in dim_mappings):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[
arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"elementwise", DistributedElementwiseImpl0("replicate_parallel"))
......@@ -34,22 +34,20 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedEmbedding, self).__init__(op_type)
register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table_v2"))
register_distributed_operator_impl_container(
DistributedEmbedding("c_embedding"))
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
super(DistributedEmbeddingImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -81,6 +79,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
......@@ -89,18 +91,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if w_dims_mapping[-1] != out_dims_mapping[-1]:
return False
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
return False
......@@ -248,6 +239,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# matmulv2
embedding_op_dist_attr = OperatorDistributedAttribute()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
......@@ -266,6 +258,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
......
......@@ -27,22 +27,20 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0
class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedReshape2, self).__init__(op_type)
register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
register_distributed_operator_impl_container(DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
super(DistributedReshapeImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = False
......@@ -76,6 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -85,17 +87,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
for idx, item in enumerate(out_dims_mapping[:-2]):
if x_dims_mapping[idx] != item:
for idx, dim_mapping in enumerate(out_dims_mapping[:-1]):
if x_dims_mapping[idx] != dim_mapping:
return False
if out_dims_mapping[-2] != x_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
......@@ -194,13 +189,12 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
super(DistributedReshapeImpl1, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = False
......@@ -234,6 +228,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -244,24 +242,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
if len(x_dims_mapping) == len(out_dims_mapping) + 2:
if out_dims_mapping[0] != x_dims_mapping[0]:
return False
if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1:
return False
elif len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
for idx, item in enumerate(x_dims_mapping[:-2]):
for idx, item in enumerate(x_dims_mapping[:-1]):
if out_dims_mapping[idx] != item:
return False
if x_dims_mapping[-2] != out_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
......@@ -359,7 +346,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("reshape2",
......
......@@ -22,22 +22,20 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedSoftmax, self).__init__(op_type)
register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
register_distributed_operator_impl_container(DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
super(DistributedSoftmaxImpl, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
......@@ -48,8 +46,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
# if axis != -1 and axis != len(x_dims_mapping) - 1:
# return False
if is_dim_shard(x_dims_mapping[axis]):
return False
......@@ -63,8 +61,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(out_dims_mapping) - 1:
return False
# if axis != -1 and axis != len(out_dims_mapping) - 1:
# return False
if is_dim_shard(out_dims_mapping[axis]):
return False
......@@ -72,6 +70,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -79,11 +81,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
# if axis != -1 and axis != len(x_dims_mapping) - 1:
# return False
if x_dims_mapping != out_dims_mapping:
return False
......@@ -107,9 +106,13 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
......
......@@ -22,22 +22,21 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedTranspose2, self).__init__(op_type)
register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
super(DistributedTranspose2Impl, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
......@@ -48,6 +47,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
perm = op_desc.attr('axis')
......@@ -111,9 +114,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
......
......@@ -20,18 +20,17 @@ from ..utils import set_dist_op_desc_original_id
class DistributedUpdateLossScaling(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedUpdateLossScaling, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedUpdateLossScaling, self).__init__(op_type)
register_distributed_operator_impl_container(
"update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling"))
DistributedUpdateLossScaling("update_loss_scaling"))
class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedUpdateLossScalingImpl, self).__init__()
super(DistributedUpdateLossScalingImpl, self).__init__(name)
self._name = name
self._forward_implemented = False
self._backward_implemented = True
......@@ -46,6 +45,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
"DistributedUpdateLossScalingImpl's is_output_compatible should not be called !"
)
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedUpdateLossScalingImpl's is_auto_compatible should not be called !"
)
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !"
......
......@@ -63,7 +63,6 @@ class Partitioner(object):
def partition(self, serial_main_program, serial_startup_program,
params_grads):
if not isinstance(serial_main_program, (Program)):
raise TypeError(
"main_program be paddle.fluid.framework.program, got %s here" %
......@@ -87,7 +86,7 @@ class Partitioner(object):
serial_main_program, serial_startup_program)
dist_op_context.set_dst_startup_program(partitioned_startup_prog)
# partition main program
# partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
serial_main_program, params_grads)
......@@ -282,7 +281,7 @@ def _get_dist_shape(var, dist_attr):
def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape):
# NOTE hack to copied Parameter
# not initialized parameter, need to initialize it
# not initialized parameter, need to initialize it
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr
......@@ -371,19 +370,19 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
dist_op = get_distributed_operator_impl_container(forward_op.type)
# TODO backward should have its own impl_idx
if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_op.get_impl(forward_op_dist_attr.impl_idx)
dist_op_impl_container = get_distributed_operator_impl_container(
forward_op_dist_attr.impl_type)
dist_op_impl = dist_op_impl_container.get_impl(
forward_op_dist_attr.impl_idx)
return dist_op_impl
# NOTE trick for dist ops that only have backward implement
# # NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
dist_op = get_distributed_operator_impl_container(backward_op.type)
if dist_op and op_dist_attr.impl_idx >= 0:
return dist_op.get_impl(op_dist_attr.impl_idx)
assert op_dist_attr.impl_idx >= 0
dist_op_impl = get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)
return dist_op_impl
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
......@@ -391,12 +390,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_op = get_distributed_operator_impl_container(forward_op.type)
if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
dist_attr.impl_idx)._forward_implemented:
return dist_op.get_impl(dist_attr.impl_idx)
else:
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
dist_op_impl_container = get_distributed_operator_impl_container(
dist_attr.impl_type)
dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
return dist_op_impl
......@@ -28,7 +28,7 @@ from .cost_model import estimate_cost
from .dist_op import DistributedOperator
from .process_group import _g_process_group_map
from .process_group import ProcessGroup, get_process_group
from .completion import is_elementwise_like_op
from .operators.common import is_elementwise_op
from .operators.common import get_distributed_operator_impl_container
from .utils import update_op_dims_mapping_by_default_dist_impl
from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
......@@ -237,7 +237,7 @@ class PlanSpace:
dist_op = DistributedOperator(op, op_dist_attr)
if dist_op_impl_container is None:
if is_elementwise_like_op(op.type):
if is_elementwise_op(op.type):
changed = True
valid = True
try:
......@@ -250,7 +250,8 @@ class PlanSpace:
op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -1
dist_op.dist_attr.impl_type = "elementwise"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
continue
else:
......@@ -266,16 +267,18 @@ class PlanSpace:
op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -2
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
continue
# if op has distributed implements, find all valid dist attr of this op
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
for idx, impl in enumerate(impls):
if impl.is_auto_compatible(dist_op):
if PlanFilter.check_dims_mapping_for_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_type = dist_op.serial_op.type
dist_op.dist_attr.impl_idx = idx
op_valid_dist_attrs.append(dist_op.dist_attr)
......@@ -290,7 +293,8 @@ class PlanSpace:
for var_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(
vars[var_name], [-1 for i in vars[var_name].shape])
dist_op.dist_attr.impl_idx = -1
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
return op_valid_dist_attrs
......
......@@ -105,7 +105,7 @@ class TestAutoParallelAPI(unittest.TestCase):
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(process_mesh2))
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
......@@ -138,7 +138,7 @@ class TestAutoParallelAPI(unittest.TestCase):
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
......
......@@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
class TestCompatible(unittest.TestCase):
def test_matmulv2_matmul_2_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -123,7 +123,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
......@@ -220,7 +220,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
......@@ -307,7 +307,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
......
......@@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
def test_raise_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'transpose2':
op_dist_attr = OperatorDistributedAttribute()
dist_op = DistributedOperator(op, op_dist_attr)
impls = DistributedOperatorImpl()
try:
impls.is_auto_compatible(dist_op)
except NotImplementedError:
e = False
self.assertTrue(e == False)
class TestCompatible(unittest.TestCase):
def test_reshape_remove_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -124,7 +107,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
......@@ -172,64 +155,6 @@ class Testcompatible(unittest.TestCase):
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_remove_two_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 0])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[0, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, 1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_add_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -240,7 +165,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
......@@ -298,7 +223,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'transpose2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......@@ -349,7 +274,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'softmax':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......@@ -379,7 +304,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'c_embedding' or op.type == 'lookup_table_v2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......