Unverified · Commit 9acc26ca · authored by Yulong Ao · committed by GitHub

[Auto Parallel] Improve the dist op interface and the compatible computation (#39014)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem
Parent 2a9c993e
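For orientation, here is a hedged sketch of the flow this commit introduces, condensed from the hunks below. It is not code from the patch: `pick_impl_for` is a hypothetical helper, and `dist_op` / `op_dist_attr` stand for a DistributedOperator and its OperatorDistributedAttribute.

# Hypothetical summary of the new lookup flow (not part of the patch).
from paddle.distributed.auto_parallel.operators.common import (
    find_best_compatible_distributed_operator_impl)

def pick_impl_for(dist_op, op_dist_attr, fwd=True):
    # The new lookup takes the dist_op itself (no op-type argument) and already
    # searches the op's own container, then "elementwise", then "default".
    op_dist_impl = find_best_compatible_distributed_operator_impl(dist_op, fwd=fwd)
    assert op_dist_impl is not None, "Cannot find the dist op implementation."
    op_dist_impl.update_dims_mapping(dist_op)
    if op_dist_impl.is_auto_compatible(dist_op):
        # "elementwise" impls are recorded as "default" so that later passes
        # reuse the default forward/backward transformation.
        if op_dist_impl.type == "elementwise":
            op_dist_attr.impl_type = "default"
        else:
            op_dist_attr.impl_type = op_dist_impl.type
        op_dist_attr.impl_idx = op_dist_impl.idx
    return op_dist_attr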
@@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
                    compatible_dims_mapping)
                changed = True
        # Find the most compatible implemenetations from the distributed operator
-        op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
-            op_desc.type(), dist_op, fwd=True)
-        if op_dist_impl is not None:
-            dim_changed = op_dist_impl.update_dims_mapping(dist_op)
-            if dim_changed:
-                changed = True
-            # This statement will be replaced by a good way
-            if op_dist_impl.is_compatible(dist_op):
-                op_dist_attr.impl_type = op_desc.type()
-                op_dist_attr.impl_idx = op_dist_impl_idx
-        elif is_elementwise_like_op(op_desc.type()):
-            dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
-                dist_context, op_node)
-            if dim_changed:
-                changed = True
-            op_dist_attr.impl_type = "element-wise"
-            op_dist_attr.impl_idx = -1
-        else:
-            dim_changed = update_op_dims_mapping_by_default_dist_impl(
-                dist_context, op_node)
-            if dim_changed:
-                changed = True
-            op_dist_attr.impl_type = "default"
-            op_dist_attr.impl_idx = -2
+        op_dist_impl = find_best_compatible_distributed_operator_impl(
+            dist_op, fwd=True)
+        assert op_dist_impl is not None, "Cannot find the dist op implementation."
+        dim_changed = op_dist_impl.update_dims_mapping(dist_op)
+        if dim_changed:
+            changed = True
+        if op_dist_impl.is_auto_compatible(dist_op):
+            if op_dist_impl.type == "elementwise":
+                op_dist_attr.impl_type = "default"
+            else:
+                op_dist_attr.impl_type = op_dist_impl.type
+            op_dist_attr.impl_idx = op_dist_impl.idx
    else:
        for tensor_node in op_node.outputs:
            if tensor_node.var() is not None:
@@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
                    tensor_desc.name(), compatible_dims_mapping)
                changed = True
        # Find the most compatible implemenetations from the distributed operator
-        op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
-            op_desc.type(), dist_op, fwd=False)
-        if op_dist_impl is not None:
-            dim_changed = op_dist_impl.update_dims_mapping(dist_op)
-            if dim_changed:
-                changed = True
-            # This statement will be replaced by a good way
-            if op_dist_impl.is_compatible(dist_op):
-                op_dist_attr.impl_type = op_desc.type()
-                op_dist_attr.impl_idx = op_dist_impl_idx
-        elif is_elementwise_like_op(op_desc.type()):
-            dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
-                dist_context, op_node)
-            if dim_changed:
-                changed = True
-            op_dist_attr.impl_type = "element-wise"
-            op_dist_attr.impl_idx = -1
-        else:
-            dim_changed = update_op_dims_mapping_by_default_dist_impl(
-                dist_context, op_node)
-            if dim_changed:
-                changed = True
-            op_dist_attr.impl_type = "default"
-            op_dist_attr.impl_idx = -2
+        op_dist_impl = find_best_compatible_distributed_operator_impl(
+            dist_op, fwd=False)
+        assert op_dist_impl is not None, "Cannot find the dist op implementation."
+        dim_changed = op_dist_impl.update_dims_mapping(dist_op)
+        if dim_changed:
+            changed = True
+        if op_dist_impl.is_auto_compatible(dist_op):
+            if op_dist_impl.type == "elementwise":
+                op_dist_attr.impl_type = "default"
+            else:
+                op_dist_attr.impl_type = op_dist_impl.type
+            op_dist_attr.impl_idx = op_dist_impl.idx
    return changed
......
@@ -61,6 +61,8 @@ class DistributedContext:
        # Other data members
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []
+        self._serial_ordered_nodes = []
+        self._tensor_id_to_tensor_node_ids = {}
        # Distributed programs
        self._dist_main_programs = {}
@@ -80,6 +82,10 @@ class DistributedContext:
            "This distributed context has already been realted to a serial program"
        self._serial_program = program

+    @property
+    def serial_ordered_nodes(self):
+        return self._serial_ordered_nodes
+
    @property
    def process_meshes(self):
        return self._process_meshes
@@ -186,6 +192,18 @@ class DistributedContext:
        else:
            return None
# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
@@ -218,6 +236,35 @@ class DistributedContext:
        else:
            return None
# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
    def init_dist_attr_for_program(self):
        assert self._serial_program, \
            "Please set the program of this context before initializing its distribute attributes."
@@ -248,6 +295,44 @@ class DistributedContext:
            self.add_dist_op_for_program(dist_op)
        self._is_initialized_for_program = True
def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
return True
return False
ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."
    def init_dist_attr_for_graph(self):
        assert self._is_initialized_for_program, \
            "The program must be initialized before initializing the distributed attributes for its graph."
@@ -257,7 +342,8 @@ class DistributedContext:
        self._serial_graph = framework.IrGraph(
            core.Graph(self._serial_program.desc))
        all_nodes = self._serial_graph.all_nodes()
-        for node in all_nodes:
+        self.order_nodes_by_program_order()
+        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
@@ -397,7 +483,9 @@ class DistributedContext:
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
-            if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
+            if k == "_serial_program" or k == "_serial_graph" \
+                or k == "_dist_main_programs" or k == "_dist_startup_programs" \
+                or k == "_serial_ordered_nodes":
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
......
@@ -98,7 +98,7 @@ class DistributedOperator:
        if self._dist_attr.impl_type is None:
            self._dist_attr.impl_type = "default"
        if self._dist_attr.impl_idx is None:
-            self._dist_attr.impl_idx = -2
+            self._dist_attr.impl_idx = 0
        if self._dist_attr.is_recompute is None:
            self._dist_attr.is_recompute = False
@@ -217,7 +217,8 @@ class DistributedOperator:
        str += ", pipeline stage: {}".format(None)
-        str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
+        str += ", dist_impl idx: {} , dist_impl type {} }}".format(
+            self.dist_attr._impl_idx, self.dist_attr._impl_type)
        return str
......
@@ -23,5 +23,6 @@ from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
from . import dist_default
+from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling
@@ -12,109 +12,196 @@
# See the License for the specific language governing permissions and
# limitations under the License

+import abc
from ..dist_attribute import OperatorDistributedAttribute

-_g_distributed_operator_impl_registries = {}
+_g_distributed_operator_impl_containers = {}
+
+_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}


+def is_elementwise_op(op_type):
+    if op_type in _g_elementwise_ops:
+        return True
+    else:
+        return False
+
+
class DistributedOperatorImplContainer:
-    def __init__(self):
+    def __init__(self, op_type):
+        self._type = op_type
        self._impls = []
-        self._name = None
+
+    @property
+    def type(self):
+        return self._type
+
+    @type.setter
+    def type(self, op_type):
+        self._type = op_type
+
+    @property
+    def impls(self):
+        return self._impls

    def register_impl(self, dist_impl):
+        assert self.type == dist_impl.type, \
+            "Op type of container must be same as that of the implementation."
+        impl_idx = len(self.impls)
+        dist_impl.idx = impl_idx
        self._impls.append(dist_impl)

    def get_impl(self, impl_idx):
        return self._impls[impl_idx]

-    def get_impls(self):
-        return self._impls
+    def get_input_compatible_impls(self, dist_op):
+        compatible_impls = []
+        for impl in self.impls:
+            if impl.is_input_compatible(dist_op):
+                compatible_impls.append(impl)
+        return compatible_impls
+
+    def get_output_compatible_impls(self, dist_op):
+        compatible_impls = []
+        for impl in self.impls:
+            if impl.is_output_compatible(dist_op):
+                compatible_impls.append(impl)
+        return compatible_impls
+
+    def get_compatible_impls(self, dist_op):
+        compatible_impls = []
+        for impl in self.impls:
+            if impl.is_auto_compatible(dist_op):
+                compatible_impls.append(impl)
+        return compatible_impls


-class DistributedOperatorImpl:
-    def __init__(self):
-        self._name = None
+class DistributedOperatorImpl(abc.ABC):
+    def __init__(self, name):
+        self._name = name
+        self._type = None
+        self._idx = None
        self._forward_implemented = False
        self._backward_implemented = False

-    @staticmethod
-    def forward(dist_ctx, *args, **kwargs):
-        raise NotImplementedError("Please Implement this method in Subclass.")
+    @property
+    def name(self):
+        return self._name

-    @staticmethod
-    def backward(dist_ctx, *grad_outputs, **kwargs):
-        raise NotImplementedError("Please Implement this method in Subclass.")
+    @name.setter
+    def name(self, name):
+        self._name = name

-    def get_name(self):
-        return self._name
+    @property
+    def type(self):
+        return self._type
+
+    @type.setter
+    def type(self, op_type):
+        self._type = op_type
+
+    @property
+    def idx(self):
+        return self._idx
+
+    @idx.setter
+    def idx(self, impl_idx):
+        self._idx = impl_idx

+    @abc.abstractmethod
    def is_input_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

+    @abc.abstractmethod
    def is_output_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

-    def is_compatible(self, dist_op):
-        return self.is_input_compatible(dist_op) and \
-            self.is_output_compatible(dist_op)
-
+    @abc.abstractmethod
    def is_auto_compatible(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")

+    @staticmethod
+    @abc.abstractmethod
+    def forward(dist_ctx, *args, **kwargs):
+        raise NotImplementedError("Please Implement this method in Subclass.")
+
+    @staticmethod
+    @abc.abstractmethod
+    def backward(dist_ctx, *grad_outputs, **kwargs):
+        raise NotImplementedError("Please Implement this method in Subclass.")
+
    def update_dims_mapping(self, dist_op):
        raise NotImplementedError("Please Implement this method in Subclass.")


-def register_distributed_operator_impl_container(name, dist_op_impl_container):
-    global _g_distributed_operator_impl_registries
-    _g_distributed_operator_impl_registries[name] = dist_op_impl_container
+def register_distributed_operator_impl_container(container):
+    global _g_distributed_operator_impl_containers
+    _g_distributed_operator_impl_containers[container.type] = container


-def get_distributed_operator_impl_container(name):
-    global _g_distributed_operator_impl_registries
-    return _g_distributed_operator_impl_registries.get(name, None)
+def get_distributed_operator_impl_container(op_type):
+    global _g_distributed_operator_impl_containers
+    return _g_distributed_operator_impl_containers.get(op_type, None)


-def register_distributed_operator_impl(name, dist_impl):
-    dist_op_impl_container = get_distributed_operator_impl_container(name)
+def register_distributed_operator_impl(op_type, dist_impl):
+    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
    if dist_op_impl_container is not None:
+        dist_impl.type = op_type
        dist_op_impl_container.register_impl(dist_impl)
    else:
        assert False, "Must register distributed operator registry first."


-def get_distributed_operator_impl(name, impl_idx):
-    global _g_distributed_operator_impl_registries
-    return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
-
-
-def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
+def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
    """
    Here just return the first compatible implemention.
    This will be improved by cost model in the future.
    """
-    dist_op_impl_container = get_distributed_operator_impl_container(name)
-    if dist_op_impl_container is None:
-        return None, -1
+    op_type = dist_op.serial_op.type
+    dist_op_impl_container = get_distributed_operator_impl_container(op_type)
+    dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
+        "elementwise")
+    dist_op_default_impl_container = get_distributed_operator_impl_container(
+        "default")
    compatible_impls = []
-    impls = dist_op_impl_container.get_impls()
    if fwd:
-        for idx, impl in enumerate(impls):
-            if impl.is_input_compatible(dist_op):
-                compatible_impls.append((impl, idx))
+        # First, find impls in the corresponding container
+        if dist_op_impl_container:
+            compatible_impls.extend(
+                dist_op_impl_container.get_input_compatible_impls(dist_op))
+        # Second, find impls in the elementwise container
+        if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
+            compatible_impls.extend(
+                dist_op_eltwise_impl_container.get_input_compatible_impls(
+                    dist_op))
+        # Third, find impls in the default container
+        if dist_op_default_impl_container:
+            compatible_impls.extend(
+                dist_op_default_impl_container.get_input_compatible_impls(
+                    dist_op))
    else:
-        for idx, impl in enumerate(impls):
-            if impl.is_output_compatible(dist_op):
-                compatible_impls.append((impl, idx))
+        # First, find impls in the corresponding container
+        if dist_op_impl_container:
+            compatible_impls.extend(
+                dist_op_impl_container.get_output_compatible_impls(dist_op))
+        # Second, find impls in the elementwise container
+        if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
+            compatible_impls.extend(
+                dist_op_eltwise_impl_container.get_output_compatible_impls(
+                    dist_op))
+        # Third, find impls in the default container
+        if dist_op_default_impl_container:
+            compatible_impls.extend(
+                dist_op_default_impl_container.get_output_compatible_impls(
+                    dist_op))

    if compatible_impls:
-        best_compatible_impl, idx = compatible_impls[0]
+        # For now, just return the first compatible impl
+        best_compatible_impl = compatible_impls[0]
    else:
-        best_compatible_impl, idx = None, -1
-
-    return best_compatible_impl, idx
+        best_compatible_impl = None
+
+    return best_compatible_impl


def is_parameter_related(varname, block):
......
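To illustrate the refactored container interface above, the following minimal sketch registers a made-up dist op. `DistributedFoo`, `DistributedFooImpl0` and the op type "foo" are invented names (not from the patch); the pattern mirrors the real registrations later in this diff (e.g. the new elementwise container). The container is keyed by its own `type`, and `register_distributed_operator_impl` stamps the impl with the op type and a sequential `idx`.

# Hypothetical registration under the new interface (illustration only).
from paddle.distributed.auto_parallel.operators.common import (
    DistributedOperatorImplContainer, DistributedOperatorImpl,
    register_distributed_operator_impl_container,
    register_distributed_operator_impl)

class DistributedFoo(DistributedOperatorImplContainer):
    def __init__(self, op_type):
        super(DistributedFoo, self).__init__(op_type)

class DistributedFooImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedFooImpl0, self).__init__(name)

    # All abstract hooks must be provided; this toy impl accepts everything.
    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def is_auto_compatible(self, dist_op):
        return True

    @staticmethod
    def forward(ctx, *args, **kwargs):
        pass

    @staticmethod
    def backward(ctx, *args, **kwargs):
        pass

register_distributed_operator_impl_container(DistributedFoo("foo"))
register_distributed_operator_impl("foo", DistributedFooImpl0("replicate_parallel"))
# After registration, the impl knows its own identity: type == "foo", idx == 0.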
@@ -30,19 +30,17 @@ global_process_mesh = get_world_process_group().ranks
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedCheckFiniteAndUnscale, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedCheckFiniteAndUnscale, self).__init__(op_type)

register_distributed_operator_impl_container(
-    "check_finite_and_unscale",
    DistributedCheckFiniteAndUnscale("check_finite_and_unscale"))

class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedCheckFiniteAndUnscaleImpl, self).__init__()
+        super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name)
        self._name = name
        self._forward_implemented = False
        self._backward_implemented = True
...@@ -57,6 +55,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -57,6 +55,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
"DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !" "DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !"
) )
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_auto_compatible should not be called !"
)
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
raise RuntimeError( raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !" "DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !"
......
@@ -34,31 +34,162 @@ from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedDefault, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedDefault, self).__init__(op_type)

-register_distributed_operator_impl_container("default",
-                                             DistributedDefault("default"))
+register_distributed_operator_impl_container(DistributedDefault("default"))

# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedDefaultImpl0, self).__init__()
-        self._name = name
+        super(DistributedDefaultImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True

    def is_input_compatible(self, dist_op):
-        raise NotImplementedError("Please Implement this method.")
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        for arg_name in op_desc.input_arg_names():
+            serial_tensor = dist_op.get_serial_input(arg_name)
+            if serial_tensor.is_parameter:
+                continue
+            dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
+            if len(dims_mapping) > 1:
+                for mapping in dims_mapping[1:]:
+                    if mapping != -1:
+                        return False
+        return True

    def is_output_compatible(self, dist_op):
-        raise NotImplementedError("Please Implement this method.")
+        op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
# Check input compatibility
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
# Check output compatibility
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[1])
# Check batch dim mapping compatibility
if not all(batch_dim_mappings[0] == dim_mapping
for dim_mapping in batch_dim_mappings):
return False
return True
    def update_dims_mapping(self, dist_op):
-        raise NotImplementedError("Please Implement this method.")
+        changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
    @staticmethod
    def forward(ctx, *args, **kwargs):
......
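To make the default implementation's dims-mapping update concrete, here is a toy illustration (not part of the patch). It re-implements the batch-dimension alignment idea with plain lists instead of the real `compute_compatible_dim_mapping` helper from `..utils`; the tensor mappings are invented.

# Toy sketch: how DistributedDefaultImpl0 aligns the batch (first) dimension.
def _toy_compatible_dim(dims):
    # -1 means "replicated"; a sharded value wins, conflicting shards fail.
    result = -1
    for d in dims:
        if d == -1:
            continue
        if result != -1 and result != d:
            return None
        result = d
    return result

# Suppose input X has dims_mapping [0, -1], input Y [-1, -1], output Out [-1, -1].
batch_dims = [0, -1, -1]          # dims_mapping[0] of X, Y and Out
print(_toy_compatible_dim(batch_dims))  # -> 0, so Y and Out get dim 0 set to 0 and changed=True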
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_elementwise_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
class DistributedElementwise(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedElementwise, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedElementwise("elementwise"))
# Replicated Elementwise
class DistributedElementwiseImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedElementwiseImpl0, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if not all(dim_mappings[0] == dim_mapping
for dim_mapping in dim_mappings):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[
arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"elementwise", DistributedElementwiseImpl0("replicate_parallel"))
@@ -34,22 +34,20 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedEmbedding, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedEmbedding, self).__init__(op_type)

-register_distributed_operator_impl_container("lookup_table_v2",
-                                             DistributedEmbedding("embedding"))
-register_distributed_operator_impl_container("c_embedding",
-                                             DistributedEmbedding("embedding"))
+register_distributed_operator_impl_container(
+    DistributedEmbedding("lookup_table_v2"))
+register_distributed_operator_impl_container(
+    DistributedEmbedding("c_embedding"))

# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedEmbeddingImpl, self).__init__()
-        self._name = name
+        super(DistributedEmbeddingImpl, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = True
@@ -81,6 +79,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        ids_name = op_desc.input('Ids')[0]
@@ -89,18 +91,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
        w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
-        if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-                -1]):
-            return False
-        # Other dimensions must be replicate except the batch dimension
-        for mapping in ids_dims_mapping[1:]:
-            if is_dim_shard(mapping):
-                return False
-        for mapping in out_dims_mapping[1:]:
-            if is_dim_shard(mapping):
-                return False
-        if w_dims_mapping[-1] != out_dims_mapping[-1]:
-            return False
+
        if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
            return False
...@@ -248,6 +239,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -248,6 +239,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# matmulv2 # matmulv2
embedding_op_dist_attr = OperatorDistributedAttribute() embedding_op_dist_attr = OperatorDistributedAttribute()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_embedding_op.desc.input_arg_names(): for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname) input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
...@@ -266,6 +258,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -266,6 +258,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# allreduce # allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute() allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names(): for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname) input_var = main_block.var(input_varname)
......
@@ -27,22 +27,20 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
+from .dist_default import DistributedDefaultImpl0

class DistributedReshape2(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedReshape2, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedReshape2, self).__init__(op_type)

-register_distributed_operator_impl_container("reshape2",
-                                             DistributedReshape2("reshape2"))
+register_distributed_operator_impl_container(DistributedReshape2("reshape2"))

class DistributedReshapeImpl0(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedReshapeImpl0, self).__init__()
-        self._name = name
+        super(DistributedReshapeImpl0, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = False
...@@ -76,6 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -76,6 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
@@ -85,17 +87,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
            x_shape_name)
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
-        if len(x_dims_mapping) != len(out_dims_mapping) - 1:
-            return False
-        if is_dim_shard(out_dims_mapping[-1]):
-            return False
-        for idx, item in enumerate(out_dims_mapping[:-2]):
-            if x_dims_mapping[idx] != item:
-                return False
-        if out_dims_mapping[-2] != x_dims_mapping[-1]:
-            return False
+        for idx, dim_mapping in enumerate(out_dims_mapping[:-1]):
+            if x_dims_mapping[idx] != dim_mapping:
+                return False

        if x_shape_dims_mapping[0] != -1:
            return False
...@@ -194,13 +189,12 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -194,13 +189,12 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
    @staticmethod
    def backward(ctx, *args, **kwargs):
-        pass
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)

class DistributedReshapeImpl1(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedReshapeImpl1, self).__init__()
-        self._name = name
+        super(DistributedReshapeImpl1, self).__init__(name)
        self._forward_implemented = True
        self._backward_implemented = False
...@@ -234,6 +228,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -234,6 +228,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
@@ -244,24 +242,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
        x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
            x_shape_name)
-        if len(x_dims_mapping) == len(out_dims_mapping) + 2:
-            if out_dims_mapping[0] != x_dims_mapping[0]:
-                return False
-            if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1:
-                return False
-        elif len(x_dims_mapping) != len(out_dims_mapping) + 1:
-            return False
-
        if is_dim_shard(x_dims_mapping[-1]):
            return False

-        for idx, item in enumerate(x_dims_mapping[:-2]):
+        for idx, item in enumerate(x_dims_mapping[:-1]):
            if out_dims_mapping[idx] != item:
                return False

-        if x_dims_mapping[-2] != out_dims_mapping[-1]:
-            return False
-
        if x_shape_dims_mapping[0] != -1:
            return False
@@ -359,7 +346,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
    @staticmethod
    def backward(ctx, *args, **kwargs):
-        pass
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)

register_distributed_operator_impl("reshape2",
......
@@ -22,22 +22,20 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
+from .dist_default import DistributedDefaultImpl0

class DistributedSoftmax(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedSoftmax, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedSoftmax, self).__init__(op_type)

-register_distributed_operator_impl_container("softmax",
-                                             DistributedSoftmax("softmax"))
+register_distributed_operator_impl_container(DistributedSoftmax("softmax"))

class DistributedSoftmaxImpl(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedSoftmaxImpl, self).__init__()
-        self._name = name
+        super(DistributedSoftmaxImpl, self).__init__(name)
        self._forward_implemented = False
        self._backward_implemented = False
@@ -48,8 +46,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        axis = op_desc.attr('axis')
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)

-        if axis != -1 and axis != len(x_dims_mapping) - 1:
-            return False
+        # if axis != -1 and axis != len(x_dims_mapping) - 1:
+        #     return False

        if is_dim_shard(x_dims_mapping[axis]):
            return False
@@ -63,8 +61,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        axis = op_desc.attr('axis')
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)

-        if axis != -1 and axis != len(out_dims_mapping) - 1:
-            return False
+        # if axis != -1 and axis != len(out_dims_mapping) - 1:
+        #     return False

        if is_dim_shard(out_dims_mapping[axis]):
            return False
@@ -72,6 +70,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        return True

    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or \
+            (not self.is_output_compatible(dist_op)):
+            return False
+
        op_desc = dist_op.serial_op.desc
        op_dist_attr = dist_op.dist_attr
        x_name = op_desc.input('X')[0]
@@ -79,11 +81,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        out_name = op_desc.output('Out')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
-        if axis != -1 and axis != len(x_dims_mapping) - 1:
-            return False
-        if is_dim_shard(x_dims_mapping[axis]):
-            return False
+        # if axis != -1 and axis != len(x_dims_mapping) - 1:
+        #     return False

        if x_dims_mapping != out_dims_mapping:
            return False
@@ -107,9 +106,13 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        return changed

+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
    @staticmethod
    def backward(ctx, *args, **kwargs):
-        pass
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)

register_distributed_operator_impl(
......
@@ -22,22 +22,21 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
+from .dist_default import DistributedDefaultImpl0

class DistributedTranspose2(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedTranspose2, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedTranspose2, self).__init__(op_type)

register_distributed_operator_impl_container(
-    "transpose2", DistributedTranspose2("transpose2"))
+    DistributedTranspose2("transpose2"))

class DistributedTranspose2Impl(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedTranspose2Impl, self).__init__()
-        self._name = name
+        super(DistributedTranspose2Impl, self).__init__(name)
        self._forward_implemented = False
        self._backward_implemented = False
@@ -48,6 +47,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
        return True
def is_auto_compatible(self, dist_op): def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
perm = op_desc.attr('axis') perm = op_desc.attr('axis')
@@ -111,9 +114,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
        return changed

+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
    @staticmethod
    def backward(ctx, *args, **kwargs):
-        pass
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)

register_distributed_operator_impl(
......
@@ -20,18 +20,17 @@ from ..utils import set_dist_op_desc_original_id
class DistributedUpdateLossScaling(DistributedOperatorImplContainer):
-    def __init__(self, name):
-        super(DistributedUpdateLossScaling, self).__init__()
-        self._name = name
+    def __init__(self, op_type):
+        super(DistributedUpdateLossScaling, self).__init__(op_type)

register_distributed_operator_impl_container(
-    "update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling"))
+    DistributedUpdateLossScaling("update_loss_scaling"))

class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
    def __init__(self, name):
-        super(DistributedUpdateLossScalingImpl, self).__init__()
+        super(DistributedUpdateLossScalingImpl, self).__init__(name)
        self._name = name
        self._forward_implemented = False
        self._backward_implemented = True
...@@ -46,6 +45,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): ...@@ -46,6 +45,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
"DistributedUpdateLossScalingImpl's is_output_compatible should not be called !" "DistributedUpdateLossScalingImpl's is_output_compatible should not be called !"
) )
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedUpdateLossScalingImpl's is_auto_compatible should not be called !"
)
def update_dims_mapping(self, dist_op): def update_dims_mapping(self, dist_op):
raise RuntimeError( raise RuntimeError(
"DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !" "DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !"
......
...@@ -63,7 +63,6 @@ class Partitioner(object): ...@@ -63,7 +63,6 @@ class Partitioner(object):
def partition(self, serial_main_program, serial_startup_program, def partition(self, serial_main_program, serial_startup_program,
params_grads): params_grads):
if not isinstance(serial_main_program, (Program)): if not isinstance(serial_main_program, (Program)):
raise TypeError( raise TypeError(
"main_program be paddle.fluid.framework.program, got %s here" % "main_program be paddle.fluid.framework.program, got %s here" %
...@@ -87,7 +86,7 @@ class Partitioner(object): ...@@ -87,7 +86,7 @@ class Partitioner(object):
serial_main_program, serial_startup_program) serial_main_program, serial_startup_program)
dist_op_context.set_dst_startup_program(partitioned_startup_prog) dist_op_context.set_dst_startup_program(partitioned_startup_prog)
# partition main program # partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program( partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
serial_main_program, params_grads) serial_main_program, params_grads)
...@@ -282,7 +281,7 @@ def _get_dist_shape(var, dist_attr): ...@@ -282,7 +281,7 @@ def _get_dist_shape(var, dist_attr):
def _partition_parameter(dist_context, src_var, dst_block, dst_varname, def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape): dst_shape):
# NOTE hack to copied Parameter # NOTE hack to copied Parameter
# not initialized parameter, need to initialize it # not initialized parameter, need to initialize it
copied_kwargs = {} copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr copied_kwargs['optimize_attr'] = src_var.optimize_attr
@@ -371,19 +370,19 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op)
-        dist_op = get_distributed_operator_impl_container(forward_op.type)
-
-        # TODO backward should have its own impl_idx
-        if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
-            forward_op_dist_attr.impl_idx)._backward_implemented:
-            return dist_op.get_impl(forward_op_dist_attr.impl_idx)
+        dist_op_impl_container = get_distributed_operator_impl_container(
+            forward_op_dist_attr.impl_type)
+        dist_op_impl = dist_op_impl_container.get_impl(
+            forward_op_dist_attr.impl_idx)
+        return dist_op_impl

-    # NOTE trick for dist ops that only have backward implement
+    # # NOTE trick for dist ops that only have backward implement
    if backward_op.type in BACKWARD_ONLY_DIST_OPS:
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
-        dist_op = get_distributed_operator_impl_container(backward_op.type)
-        if dist_op and op_dist_attr.impl_idx >= 0:
-            return dist_op.get_impl(op_dist_attr.impl_idx)
+        assert op_dist_attr.impl_idx >= 0
+        dist_op_impl = get_distributed_operator_impl_container(
+            backward_op.type).get_impl(op_dist_attr.impl_idx)
+        return dist_op_impl

    dist_op = get_distributed_operator_impl_container("default")
    return dist_op.get_impl(0)
...@@ -391,12 +390,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context, ...@@ -391,12 +390,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
 def _get_dist_op_forward_implement(forward_op, dist_context):
     dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
-    dist_op = get_distributed_operator_impl_container(forward_op.type)
-    if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
-            dist_attr.impl_idx)._forward_implemented:
-        return dist_op.get_impl(dist_attr.impl_idx)
-    else:
-        dist_op = get_distributed_operator_impl_container("default")
-        return dist_op.get_impl(0)
+    dist_op_impl_container = get_distributed_operator_impl_container(
+        dist_attr.impl_type)
+    dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
+    return dist_op_impl
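Taken together, the two partitioner hunks above replace the old probing of the `_forward_implemented` / `_backward_implemented` flags with a direct lookup keyed by the operator's distributed attribute. A minimal sketch of that lookup convention, using only names that appear in this diff; the helper name `resolve_dist_op_impl` and the absolute import path are assumptions inferred from the relative imports shown here:

# Hedged sketch: the container is selected by dist_attr.impl_type and the
# concrete implementation by dist_attr.impl_idx; "default" impl 0 is the
# fallback. The import path is inferred and may differ between releases.
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container)


def resolve_dist_op_impl(op_dist_attr):
    container = get_distributed_operator_impl_container(op_dist_attr.impl_type)
    return container.get_impl(op_dist_attr.impl_idx)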
...@@ -28,7 +28,7 @@ from .cost_model import estimate_cost ...@@ -28,7 +28,7 @@ from .cost_model import estimate_cost
from .dist_op import DistributedOperator from .dist_op import DistributedOperator
from .process_group import _g_process_group_map from .process_group import _g_process_group_map
from .process_group import ProcessGroup, get_process_group from .process_group import ProcessGroup, get_process_group
from .completion import is_elementwise_like_op from .operators.common import is_elementwise_op
from .operators.common import get_distributed_operator_impl_container from .operators.common import get_distributed_operator_impl_container
from .utils import update_op_dims_mapping_by_default_dist_impl from .utils import update_op_dims_mapping_by_default_dist_impl
from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
...@@ -237,7 +237,7 @@ class PlanSpace: ...@@ -237,7 +237,7 @@ class PlanSpace:
dist_op = DistributedOperator(op, op_dist_attr) dist_op = DistributedOperator(op, op_dist_attr)
if dist_op_impl_container is None: if dist_op_impl_container is None:
if is_elementwise_like_op(op.type): if is_elementwise_op(op.type):
changed = True changed = True
valid = True valid = True
try: try:
...@@ -250,7 +250,8 @@ class PlanSpace: ...@@ -250,7 +250,8 @@ class PlanSpace:
op, dist_op.dist_attr, vars op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op( ) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars): op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -1 dist_op.dist_attr.impl_type = "elementwise"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
continue continue
else: else:
...@@ -266,16 +267,18 @@ class PlanSpace: ...@@ -266,16 +267,18 @@ class PlanSpace:
op, dist_op.dist_attr, vars op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op( ) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars): op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -2 dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
continue continue
# if op has distributed implements, find all valid dist attr of this op # if op has distributed implements, find all valid dist attr of this op
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
if impl.is_auto_compatible(dist_op): if impl.is_auto_compatible(dist_op):
if PlanFilter.check_dims_mapping_for_op( if PlanFilter.check_dims_mapping_for_op(
op, dist_op.dist_attr, vars): op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_type = dist_op.serial_op.type
dist_op.dist_attr.impl_idx = idx dist_op.dist_attr.impl_idx = idx
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
...@@ -290,7 +293,8 @@ class PlanSpace: ...@@ -290,7 +293,8 @@ class PlanSpace:
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
vars[var_name], [-1 for i in vars[var_name].shape]) vars[var_name], [-1 for i in vars[var_name].shape])
dist_op.dist_attr.impl_idx = -1 dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr) op_valid_dist_attrs.append(dist_op.dist_attr)
return op_valid_dist_attrs return op_valid_dist_attrs
......
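With this planner change, every candidate attribute records a string `impl_type` plus a non-negative `impl_idx` instead of the old sentinel indices. A hedged sketch of the enumeration pattern from the hunk above; `dist_op` and `op_valid_dist_attrs` are assumed to be prepared as in the planner code:

# Hedged sketch of enumerating candidate implementations with the new
# interface: `impls` is now a property (formerly get_impls()), and a match
# stores impl_type/impl_idx on the op's dist_attr before it is collected.
container = get_distributed_operator_impl_container(dist_op.serial_op.type)
for idx, impl in enumerate(container.impls):
    if impl.is_auto_compatible(dist_op):
        dist_op.dist_attr.impl_type = dist_op.serial_op.type
        dist_op.dist_attr.impl_idx = idx
        op_valid_dist_attrs.append(dist_op.dist_attr)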
...@@ -105,7 +105,7 @@ class TestAutoParallelAPI(unittest.TestCase): ...@@ -105,7 +105,7 @@ class TestAutoParallelAPI(unittest.TestCase):
self.assertEqual(dist_op.dist_attr.process_mesh, self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(process_mesh2)) ProcessMesh(process_mesh2))
self.assertEqual(dist_op.dist_attr.impl_type, "default") self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2) self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
...@@ -138,7 +138,7 @@ class TestAutoParallelAPI(unittest.TestCase): ...@@ -138,7 +138,7 @@ class TestAutoParallelAPI(unittest.TestCase):
dist_op = dist_context.get_dist_op_for_program(last_op) dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None) self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default") self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2) self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh")) self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
......
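The updated assertions reflect that the former sentinel indices such as -2 are gone: an un-annotated op now carries impl_type "default" with impl_idx 0. A minimal hedged sketch of what that fallback resolves to, using only names that appear in this diff:

# Hedged sketch: the "default" container's first implementation is what an
# op with impl_type "default" and impl_idx 0 resolves to.
default_impl = get_distributed_operator_impl_container("default").get_impl(0)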
...@@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program): ...@@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program return loss, train_program, start_program
class Testcompatible(unittest.TestCase): class TestCompatible(unittest.TestCase):
def test_matmulv2_matmul_2_compatible(self): def test_matmulv2_matmul_2_compatible(self):
valid_op_dist_attr_list = [] valid_op_dist_attr_list = []
program = paddle.static.Program() program = paddle.static.Program()
...@@ -123,7 +123,7 @@ class Testcompatible(unittest.TestCase): ...@@ -123,7 +123,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul': if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0] X = op.input_arg_names[0]
Y = op.input_arg_names[1] Y = op.input_arg_names[1]
...@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase): ...@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible( self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible( self.assertFalse(impls[2].is_auto_compatible(
...@@ -220,7 +220,7 @@ class Testcompatible(unittest.TestCase): ...@@ -220,7 +220,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul': if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0] X = op.input_arg_names[0]
Y = op.input_arg_names[1] Y = op.input_arg_names[1]
...@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase): ...@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible( self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible( self.assertFalse(impls[1].is_auto_compatible(
...@@ -307,7 +307,7 @@ class Testcompatible(unittest.TestCase): ...@@ -307,7 +307,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul': if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0] X = op.input_arg_names[0]
Y = op.input_arg_names[1] Y = op.input_arg_names[1]
...@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase): ...@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1]) op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1]) op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible( self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1]) op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible( self.assertFalse(impls[0].is_auto_compatible(
......
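The matmul compatibility tests above all follow one recipe: build an OperatorDistributedAttribute, set dims mappings for the inputs and output, wrap the op in a DistributedOperator, and ask a specific implementation whether it is auto-compatible. A condensed, hedged sketch of that recipe; the helper `is_impl_compatible` is hypothetical and the import paths are inferred from the modules referenced in this diff:

# Hedged sketch of the compatibility-check recipe used in the tests above.
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute)
from paddle.distributed.auto_parallel.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container)


def is_impl_compatible(op, impl_idx, input_mappings, output_mapping):
    # input_mappings: {arg_name: dims_mapping}; output_mapping applies to the
    # op's first output, mirroring the pattern used in the tests above.
    op_dist_attr = OperatorDistributedAttribute()
    for name, mapping in input_mappings.items():
        op_dist_attr.set_input_dims_mapping(name, mapping)
    op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], output_mapping)
    impls = get_distributed_operator_impl_container(op.type).impls
    return impls[impl_idx].is_auto_compatible(
        DistributedOperator(op, op_dist_attr))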
...@@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program): ...@@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program return loss, train_program, start_program
class Testcompatible(unittest.TestCase): class TestCompatible(unittest.TestCase):
def test_raise_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'transpose2':
op_dist_attr = OperatorDistributedAttribute()
dist_op = DistributedOperator(op, op_dist_attr)
impls = DistributedOperatorImpl()
try:
impls.is_auto_compatible(dist_op)
except NotImplementedError:
e = False
self.assertTrue(e == False)
def test_reshape_remove_compatible(self): def test_reshape_remove_compatible(self):
valid_op_dist_attr_list = [] valid_op_dist_attr_list = []
program = paddle.static.Program() program = paddle.static.Program()
...@@ -124,7 +107,7 @@ class Testcompatible(unittest.TestCase): ...@@ -124,7 +107,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2': if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1]) [-1, -1, -1])
...@@ -172,64 +155,6 @@ class Testcompatible(unittest.TestCase): ...@@ -172,64 +155,6 @@ class Testcompatible(unittest.TestCase):
self.assertFalse(impls[1].is_auto_compatible( self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr))) DistributedOperator(op, op_dist_attr)))
def test_reshape_remove_two_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 0])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[0, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, 1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_add_compatible(self): def test_reshape_add_compatible(self):
valid_op_dist_attr_list = [] valid_op_dist_attr_list = []
program = paddle.static.Program() program = paddle.static.Program()
...@@ -240,7 +165,7 @@ class Testcompatible(unittest.TestCase): ...@@ -240,7 +165,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2': if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1]) op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0], op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
...@@ -298,7 +223,7 @@ class Testcompatible(unittest.TestCase): ...@@ -298,7 +223,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'transpose2': if op.type == 'transpose2':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1]) [-1, -1])
...@@ -349,7 +274,7 @@ class Testcompatible(unittest.TestCase): ...@@ -349,7 +274,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'softmax': if op.type == 'softmax':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1]) [-1, -1])
...@@ -379,7 +304,7 @@ class Testcompatible(unittest.TestCase): ...@@ -379,7 +304,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'c_embedding' or op.type == 'lookup_table_v2': if op.type == 'c_embedding' or op.type == 'lookup_table_v2':
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
op.type) op.type)
impls = dist_op_impl_container.get_impls() impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1]) [-1, -1])
......