Unverified Commit 9acc26ca authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the dist op interface and the compatible computation (#39014)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem
Parent 2a9c993e
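The core interface change, condensed from the completer hunk below: the lookup no longer takes the op type and no longer returns an index, because each implementation now carries its own type and idx.

    # before: keyed by the op type string, returns (impl, idx) and may return (None, -1)
    op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
        op_desc.type(), dist_op, fwd=True)

    # after: the dist_op alone is enough; the chosen impl carries its own .type and .idx
    op_dist_impl = find_best_compatible_distributed_operator_impl(dist_op, fwd=True)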
......@@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
compatible_dims_mapping)
changed = True
# Find the most compatible implementation for the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
......@@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementation for the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
......
......@@ -61,6 +61,8 @@ class DistributedContext:
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}
# Distributed programs
self._dist_main_programs = {}
......@@ -80,6 +82,10 @@ class DistributedContext:
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes
@property
def process_meshes(self):
return self._process_meshes
......@@ -186,6 +192,18 @@ class DistributedContext:
else:
return None
# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
......@@ -218,6 +236,35 @@ class DistributedContext:
else:
return None
# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
......@@ -248,6 +295,44 @@ class DistributedContext:
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
return True
return False
ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
......@@ -257,7 +342,8 @@ class DistributedContext:
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
......@@ -397,7 +483,9 @@ class DistributedContext:
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
......
......@@ -98,7 +98,7 @@ class DistributedOperator:
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
self._dist_attr.impl_idx = 0
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False
......@@ -217,7 +217,8 @@ class DistributedOperator:
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type)
return str
......
......@@ -23,5 +23,6 @@ from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling
......@@ -12,109 +12,196 @@
# See the License for the specific language governing permissions and
# limitations under the License
import abc
from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_registries = {}
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = ["elementwise_add", "gelu", "dropout", "cast"]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
def is_elementwise_op(op_type):
if op_type in _g_elementwise_ops:
return True
else:
return False
class DistributedOperatorImplContainer:
def __init__(self):
def __init__(self, op_type):
self._type = op_type
self._impls = []
self._name = None
@property
def type(self):
return self._type
@type.setter
def type(self, op_type):
self._type = op_type
@property
def impls(self):
return self._impls
def register_impl(self, dist_impl):
assert self.type == dist_impl.type, \
"Op type of container must be same as that of the implementation."
impl_idx = len(self.impls)
dist_impl.idx = impl_idx
self._impls.append(dist_impl)
def get_impl(self, impl_idx):
return self._impls[impl_idx]
def get_impls(self):
return self._impls
def get_input_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_input_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
class DistributedOperatorImpl:
def __init__(self):
self._name = None
def get_output_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_output_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
def get_compatible_impls(self, dist_op):
compatible_impls = []
for impl in self.impls:
if impl.is_auto_compatible(dist_op):
compatible_impls.append(impl)
return compatible_impls
class DistributedOperatorImpl(abc.ABC):
def __init__(self, name):
self._name = name
self._type = None
self._idx = None
self._forward_implemented = False
self._backward_implemented = False
@staticmethod
def forward(dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@property
def name(self):
return self._name
@staticmethod
def backward(dist_ctx, *grad_outputs, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@name.setter
def name(self, name):
self._name = name
def get_name(self):
return self._name
@property
def type(self):
return self._type
@type.setter
def type(self, op_type):
self._type = op_type
@property
def idx(self):
return self._idx
@idx.setter
def idx(self, impl_idx):
self._idx = impl_idx
@abc.abstractmethod
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
@abc.abstractmethod
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, dist_op):
return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
@abc.abstractmethod
def is_auto_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
@staticmethod
@abc.abstractmethod
def forward(dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
@staticmethod
@abc.abstractmethod
def backward(dist_ctx, *grad_outputs, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator_impl_container(name, dist_op_impl_container):
global _g_distributed_operator_impl_registries
_g_distributed_operator_impl_registries[name] = dist_op_impl_container
def register_distributed_operator_impl_container(container):
global _g_distributed_operator_impl_containers
_g_distributed_operator_impl_containers[container.type] = container
def get_distributed_operator_impl_container(name):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries.get(name, None)
def get_distributed_operator_impl_container(op_type):
global _g_distributed_operator_impl_containers
return _g_distributed_operator_impl_containers.get(op_type, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op_impl_container = get_distributed_operator_impl_container(name)
def register_distributed_operator_impl(op_type, dist_impl):
dist_op_impl_container = get_distributed_operator_impl_container(op_type)
if dist_op_impl_container is not None:
dist_impl.type = op_type
dist_op_impl_container.register_impl(dist_impl)
else:
assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
def find_best_compatible_distributed_operator_impl(dist_op, fwd=True):
"""
Here we just return the first compatible implementation.
This will be improved by a cost model in the future.
"""
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is None:
return None, -1
op_type = dist_op.serial_op.type
dist_op_impl_container = get_distributed_operator_impl_container(op_type)
dist_op_eltwise_impl_container = get_distributed_operator_impl_container(
"elementwise")
dist_op_default_impl_container = get_distributed_operator_impl_container(
"default")
compatible_impls = []
impls = dist_op_impl_container.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_input_compatible(dist_op):
compatible_impls.append((impl, idx))
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_input_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_input_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_input_compatible_impls(
dist_op))
else:
for idx, impl in enumerate(impls):
if impl.is_output_compatible(dist_op):
compatible_impls.append((impl, idx))
# First, find impls in the corresponding container
if dist_op_impl_container:
compatible_impls.extend(
dist_op_impl_container.get_output_compatible_impls(dist_op))
# Second, find impls in the elementwise container
if dist_op_eltwise_impl_container and is_elementwise_op(op_type):
compatible_impls.extend(
dist_op_eltwise_impl_container.get_output_compatible_impls(
dist_op))
# Third, find impls in the default container
if dist_op_default_impl_container:
compatible_impls.extend(
dist_op_default_impl_container.get_output_compatible_impls(
dist_op))
if compatible_impls:
best_compatible_impl, idx = compatible_impls[0]
# For now, just return the first compatible impl
best_compatible_impl = compatible_impls[0]
else:
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
best_compatible_impl = None
return best_compatible_impl
def is_parameter_related(varname, block):
......
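Putting the pieces above together: DistributedOperatorImpl is now an abc.ABC, so a concrete impl must provide the abstract methods (plus update_dims_mapping, which otherwise raises) before it can be instantiated and registered, and selection goes through containers keyed by op type with elementwise/default fallbacks. A minimal, hedged sketch; "my_op" and MyOpImpl are hypothetical placeholders, dist_op stands for an existing DistributedOperator, and the import path is assumed to match this tree:

    from paddle.distributed.auto_parallel.operators.common import (
        DistributedOperatorImpl, DistributedOperatorImplContainer,
        register_distributed_operator_impl,
        register_distributed_operator_impl_container,
        find_best_compatible_distributed_operator_impl)


    class MyOpImpl(DistributedOperatorImpl):
        def __init__(self, name):
            super(MyOpImpl, self).__init__(name)

        def is_input_compatible(self, dist_op):
            return True  # illustration only: accept any input dims mapping

        def is_output_compatible(self, dist_op):
            return True

        def is_auto_compatible(self, dist_op):
            return self.is_input_compatible(dist_op) and \
                self.is_output_compatible(dist_op)

        def update_dims_mapping(self, dist_op):
            return False  # nothing changed

        @staticmethod
        def forward(dist_ctx, *args, **kwargs):
            pass

        @staticmethod
        def backward(dist_ctx, *grad_outputs, **kwargs):
            pass


    # One container per serial op type, keyed by that type;
    # register_distributed_operator_impl stamps the op type and idx onto the impl.
    register_distributed_operator_impl_container(
        DistributedOperatorImplContainer("my_op"))
    register_distributed_operator_impl("my_op", MyOpImpl("replicate_parallel"))

    # Selection reads the op type from dist_op itself and also consults the
    # "elementwise" and "default" containers as fallbacks.
    impl = find_best_compatible_distributed_operator_impl(dist_op, fwd=True)
    if impl is not None and impl.is_auto_compatible(dist_op):
        dist_op.dist_attr.impl_type = impl.type
        dist_op.dist_attr.impl_idx = impl.idx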
......@@ -30,19 +30,17 @@ global_process_mesh = get_world_process_group().ranks
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscale, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedCheckFiniteAndUnscale, self).__init__(op_type)
register_distributed_operator_impl_container(
"check_finite_and_unscale",
DistributedCheckFiniteAndUnscale("check_finite_and_unscale"))
class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__()
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name)
self._name = name
self._forward_implemented = False
self._backward_implemented = True
......@@ -57,6 +55,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
"DistributedCheckFiniteAndUnscaleImpl's is_output_compatible should not be called !"
)
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's is_auto_compatible should not be called !"
)
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedCheckFiniteAndUnscaleImpl's update_dims_mapping should not be called !"
......
......@@ -34,31 +34,162 @@ from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedDefault, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedDefault, self).__init__(op_type)
register_distributed_operator_impl_container("default",
DistributedDefault("default"))
register_distributed_operator_impl_container(DistributedDefault("default"))
# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__()
self._name = name
super(DistributedDefaultImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
return True
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
batch_dim_mappings = []
# Check input compatibility
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
# Check output compatibility
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for mapping in dims_mapping[1:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[0])
else:
if dims_mapping[0] != -1:
return False
if len(dims_mapping) > 2:
for mapping in dims_mapping[2:]:
if mapping != -1:
return False
batch_dim_mappings.append(dims_mapping[1])
# Check batch dim mapping compatibility
if not all(batch_dim_mappings[0] == dim_mapping
for dim_mapping in batch_dim_mappings):
return False
return True
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
batch_dim_mappings.append(dims_mapping[0])
else:
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(
batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
......
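The default (replicated) implementation above only tolerates sharding on the batch (first) axis of non-parameter inputs and outputs (shifted by one for XShape outputs), and that single batch mapping must agree everywhere. A self-contained illustration of the rule on plain dims_mapping lists, with no Paddle objects involved:

    def default_batch_dims_compatible(input_mappings, output_mappings):
        # Each mapping is a list like [0, -1, -1]: the index of the sharded
        # mesh dimension per tensor axis, with -1 meaning replicated.
        batch_dims = []
        for m in input_mappings + output_mappings:
            if any(d != -1 for d in m[1:]):  # only the batch axis may be sharded
                return False
            batch_dims.append(m[0])
        return all(d == batch_dims[0] for d in batch_dims)

    print(default_batch_dims_compatible([[0, -1, -1]], [[0, -1]]))    # True
    print(default_batch_dims_compatible([[0, -1]], [[1, -1]]))        # False: batch dims differ
    print(default_batch_dims_compatible([[0, -1, 1]], [[0, -1, 1]]))  # False: non-batch axis sharded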
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import is_elementwise_op
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
class DistributedElementwise(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedElementwise, self).__init__(op_type)
register_distributed_operator_impl_container(
DistributedElementwise("elementwise"))
# Replicated Elementwise
class DistributedElementwiseImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedElementwiseImpl0, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
if is_elementwise_op(op_desc.type()):
return True
else:
return False
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
dims_mapping_list = []
input_arg_names = op_desc.input_arg_names()
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
dims_mapping_list.append(dims_mapping)
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
for idx in range(max_dims_mapping_len):
dim_mappings = []
for dims_mapping in dims_mapping_list:
if idx < len(dims_mapping):
dim_mappings.append(dims_mapping[-(idx + 1)])
if not all(dim_mappings[0] == dim_mapping
for dim_mapping in dim_mappings):
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[
arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"elementwise", DistributedElementwiseImpl0("replicate_parallel"))
......@@ -34,22 +34,20 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedEmbedding, self).__init__(op_type)
register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container(
DistributedEmbedding("lookup_table_v2"))
register_distributed_operator_impl_container(
DistributedEmbedding("c_embedding"))
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
super(DistributedEmbeddingImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -81,6 +79,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
......@@ -89,18 +91,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
if w_dims_mapping[-1] != out_dims_mapping[-1]:
return False
if ids_dims_mapping != out_dims_mapping[:len(ids_dims_mapping)]:
return False
......@@ -248,6 +239,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# matmulv2
embedding_op_dist_attr = OperatorDistributedAttribute()
embedding_op_dist_attr.process_mesh = op_dist_attr.process_mesh
embedding_op_dist_attr.impl_type = op_dist_attr.impl_type
embedding_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_embedding_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
......@@ -266,6 +258,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
......
......@@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
from .dist_default import DistributedDefaultImpl0
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
......@@ -143,6 +144,68 @@ def _update_dims_mapping_for_matmul(dist_op):
return changed
def _is_auto_compatible_for_matmul(dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
# Deep copy these dims_mappings for keeping them unchanged.
x_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(x_name))
y_dims_mapping = copy.deepcopy(op_dist_attr.get_input_dims_mapping(y_name))
out_dims_mapping = copy.deepcopy(
op_dist_attr.get_output_dims_mapping(out_name))
x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# Add a dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
if y_dims_mapping_len == 1:
y_dims_mapping.insert(1, -1)
# Deal with dim > 2 and take care of broadcasting
if out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(out_dims_mapping_len - x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(out_dims_mapping_len - y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
is_same = ((broadcast_x_dims_mapping == broadcast_y_dims_mapping) and
(broadcast_x_dims_mapping == broadcast_out_dims_mapping))
if not is_same:
return False
# The following checks use negative indices, so they work both when
# len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
is_same = (x_dims_mapping[-1] == y_dims_mapping[-2])
if not is_same:
return False
is_same = (x_dims_mapping[-2] == out_dims_mapping[-2])
if not is_same:
return False
is_same = (y_dims_mapping[-1] == out_dims_mapping[-1])
if not is_same:
return False
return True
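_is_auto_compatible_for_matmul factors out the broadcast-aware check shared by the ReplicateParallel matmul/matmul_v2 impls: the leading (batch) mesh dims of X, Y and Out must coincide after broadcasting, and the trailing dims must satisfy X[-1] == Y[-2], X[-2] == Out[-2], Y[-1] == Out[-1]. A simplified standalone restatement, assuming dims_mappings already padded to length >= 2 and Out's rank at least that of X and Y (the helper itself handles the 1-D padding):

    def matmul_dims_compatible(x, y, out):
        # x, y, out: dims_mappings with len >= 2; leading axes are batch axes.
        bx = out[:len(out) - len(x)] + x[:-2]  # X's batch dims broadcast to Out's rank
        by = out[:len(out) - len(y)] + y[:-2]  # Y's batch dims broadcast to Out's rank
        if not (bx == by == out[:-2]):
            return False
        return x[-1] == y[-2] and x[-2] == out[-2] and y[-1] == out[-1]

    # Column-parallel style: X replicated, Y sharded on its last axis.
    print(matmul_dims_compatible([-1, -1, -1], [-1, 0], [-1, -1, 0]))  # True
    # Contracted axes sharded inconsistently on X and Y.
    print(matmul_dims_compatible([-1, -1, 0], [1, -1], [-1, -1, -1]))  # False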
def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
# by now the backward function only insert the gradient allreduce for dist op itself
......@@ -194,10 +257,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
Y_var_dim_mapping = dist_attr.get_input_dims_mapping(Y_var.name)
process_mesh_shape = dist_attr.process_mesh.topology
process_mesh_group = dist_attr.process_mesh.processes
assert len(
Y_var_dim_mapping
) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
Y_var.name, Y_var_dim_mapping)
# assert len(
# Y_var_dim_mapping
# ) == 2, "dist matmual only support Y operand with 2 dims now but Y({})'s dim is [{}]".format(
# Y_var.name, Y_var_dim_mapping)
Y_var_partitioned = False
for dim in Y_var_dim_mapping:
if dim >= 0 and process_mesh_shape[dim] > 0:
......@@ -388,20 +451,17 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
class DistributedMatmul(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedMatmul, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedMatmul, self).__init__(op_type)
register_distributed_operator_impl_container("matmul",
DistributedMatmul("matmul"))
register_distributed_operator_impl_container(DistributedMatmul("matmul"))
# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
super(DistributedMatmulImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -414,8 +474,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[
-1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
......@@ -435,83 +495,11 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_replicate(out_dims_mapping[-1]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
input_dims_mapping = []
ordered_input_shard_dims_mapping = []
for dim in (x_dims_mapping + y_dims_mapping):
input_dims_mapping.append(dim)
for item in input_dims_mapping:
if item not in ordered_input_shard_dims_mapping and item != -1:
ordered_input_shard_dims_mapping.append(item)
for mapping in out_dims_mapping:
if mapping not in input_dims_mapping:
return False
if is_dim_shard(x_dims_mapping[0]):
order_index = 0
for idx, item in enumerate(out_dims_mapping):
if item != -1:
if item != ordered_input_shard_dims_mapping[order_index]:
return False
else:
order_index += 1
if order_index != len(ordered_input_shard_dims_mapping):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_shard(x_dims_mapping[0]):
for mapping in y_dims_mapping[1:]:
if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
return False
return True
def update_dims_mapping(self, dist_op):
......@@ -635,6 +623,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# c_identity
identity_op_dist_attr = OperatorDistributedAttribute()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
# input
input_varname = c_identity_op.desc.input_arg_names()[0]
......@@ -653,6 +642,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# matmul
matmul_op_dist_attr = OperatorDistributedAttribute()
matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
# input
for input_varname in matmul_op.desc.input_arg_names():
......@@ -692,8 +682,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
super(DistributedMatmulImpl1, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -729,93 +718,12 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('transpose_X') or op_desc.attr('transpose_Y'):
return False
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# for gpt2, x dims > y dims, this is a temporary solution
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
x_shard_dim_count = 0
x_shard_dims = []
y_shard_dim_count = 0
y_shard_dims = []
for dim in x_dims_mapping:
if is_dim_shard(dim):
x_shard_dim_count += 1
x_shard_dims.append(dim)
for dim in y_dims_mapping:
if is_dim_shard(dim):
y_shard_dim_count += 1
y_shard_dims.append(dim)
if not x_shard_dims and not y_shard_dims:
return False
if x_shard_dims[-1] != y_shard_dims[0]:
return False
if x_shard_dim_count == y_shard_dim_count:
for dim in out_dims_mapping:
if is_dim_shard(dim):
return False
if x_shard_dims != y_shard_dims:
return False
else:
if x_shard_dim_count < y_shard_dim_count:
return False
output_shard_dims = []
for dim in out_dims_mapping:
if is_dim_shard(dim):
output_shard_dims.append(dim)
if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
return False
return True
......@@ -933,6 +841,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# matmul
matmul_op_dist_attr = OperatorDistributedAttribute()
matmul_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmul_op_dist_attr.impl_type = op_dist_attr.impl_type
matmul_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in matmul_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
......@@ -951,6 +860,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
......@@ -980,8 +890,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__()
self._name = name
super(DistributedMatmulImpl2, self).__init__(name)
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......@@ -1020,56 +929,11 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
return True
......@@ -1081,6 +945,10 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
......@@ -1095,20 +963,17 @@ register_distributed_operator_impl("matmul",
class DistributedMatmulV2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedMatmulV2, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedMatmulV2, self).__init__(op_type)
register_distributed_operator_impl_container("matmul_v2",
DistributedMatmulV2("matmul_v2"))
register_distributed_operator_impl_container(DistributedMatmulV2("matmul_v2"))
# ColumnParallel
class DistributedMatmulV2Impl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl0, self).__init__()
self._name = name
super(DistributedMatmulV2Impl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -1121,8 +986,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
if is_dim_shard(y_dims_mapping[-2]) or is_dim_replicate(y_dims_mapping[
-1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
......@@ -1142,85 +1007,13 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_replicate(out_dims_mapping[-1]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
input_dims_mapping = []
ordered_input_shard_dims_mapping = []
for dim in (x_dims_mapping + y_dims_mapping):
input_dims_mapping.append(dim)
for item in input_dims_mapping:
if item not in ordered_input_shard_dims_mapping and item != -1:
ordered_input_shard_dims_mapping.append(item)
for mapping in out_dims_mapping:
if mapping not in input_dims_mapping:
return False
if is_dim_shard(x_dims_mapping[0]):
order_index = 0
for idx, item in enumerate(out_dims_mapping):
if item != -1:
if item != ordered_input_shard_dims_mapping[order_index]:
return False
else:
order_index += 1
if order_index != len(ordered_input_shard_dims_mapping):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_shard(x_dims_mapping[0]):
for mapping in y_dims_mapping[1:]:
if is_dim_shard(mapping) and mapping == x_dims_mapping[0]:
return False
return True
def update_dims_mapping(self, dist_op):
......@@ -1342,6 +1135,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# c_identity
identity_op_dist_attr = OperatorDistributedAttribute()
identity_op_dist_attr.process_mesh = op_dist_attr.process_mesh
identity_op_dist_attr.impl_type = op_dist_attr.impl_type
identity_op_dist_attr.impl_idx = op_dist_attr.impl_idx
# input
input_varname = c_identity_op.desc.input_arg_names()[0]
......@@ -1359,6 +1153,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in matmul_v2_op.desc.input_arg_names():
if input_varname in src_op.desc.input_arg_names():
......@@ -1395,8 +1190,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
# RowParallel
class DistributedMatmulV2Impl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl1, self).__init__()
self._name = name
super(DistributedMatmulV2Impl1, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
......@@ -1432,93 +1226,13 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if op_desc.attr('trans_x') or op_desc.attr('trans_y'):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping), "now just support x dims > y dims"
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
x_shard_dim_count = 0
x_shard_dims = []
y_shard_dim_count = 0
y_shard_dims = []
for dim in x_dims_mapping:
if is_dim_shard(dim):
x_shard_dim_count += 1
x_shard_dims.append(dim)
for dim in y_dims_mapping:
if is_dim_shard(dim):
y_shard_dim_count += 1
y_shard_dims.append(dim)
if not x_shard_dims and not y_shard_dims:
return False
if x_shard_dims[-1] != y_shard_dims[0]:
return False
if x_shard_dim_count == y_shard_dim_count:
for dim in out_dims_mapping:
if is_dim_shard(dim):
return False
if x_shard_dims != y_shard_dims:
return False
else:
if x_shard_dim_count < y_shard_dim_count:
return False
output_shard_dims = []
for dim in out_dims_mapping:
if is_dim_shard(dim):
output_shard_dims.append(dim)
if not output_shard_dims or output_shard_dims[0] != x_shard_dims[0]:
return False
return True
def update_dims_mapping(self, dist_op):
......@@ -1631,6 +1345,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# matmulv2
matmulv2_op_dist_attr = OperatorDistributedAttribute()
matmulv2_op_dist_attr.process_mesh = op_dist_attr.process_mesh
matmulv2_op_dist_attr.impl_type = op_dist_attr.impl_type
matmulv2_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in matmul_v2_op.desc.input_arg_names():
input_dist_attr = op_dist_attr.get_input_dist_attr(input_varname)
......@@ -1649,6 +1364,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# allreduce
allreduce_op_dist_attr = OperatorDistributedAttribute()
allreduce_op_dist_attr.process_mesh = op_dist_attr.process_mesh
allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
for input_varname in c_allreduce_sum_op.desc.input_arg_names():
input_var = main_block.var(input_varname)
......@@ -1678,8 +1394,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
# ReplicateParallel
class DistributedMatmulV2Impl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl2, self).__init__()
self._name = name
super(DistributedMatmulV2Impl2, self).__init__(name)
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
......@@ -1720,57 +1435,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
assert len(x_dims_mapping) >= len(
y_dims_mapping
), "now just support x dims > y dims,but x:{0} and y:{1}".format(
x_dims_mapping, y_dims_mapping)
if len(y_dims_mapping) != 2:
return False
if len(x_dims_mapping) == len(y_dims_mapping) and len(
x_dims_mapping) == 4:
if x_dims_mapping[:2] != y_dims_mapping[:2]:
return False
if x_dims_mapping[:2] != out_dims_mapping[:2]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
elif len(x_dims_mapping) != len(y_dims_mapping) and len(
x_dims_mapping) == 3:
if x_dims_mapping[0] != out_dims_mapping[0]:
return False
x_dims_mapping = x_dims_mapping[-2:]
y_dims_mapping = y_dims_mapping[-2:]
out_dims_mapping = out_dims_mapping[-2:]
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
if not _is_auto_compatible_for_matmul(dist_op):
return False
return True
......@@ -1782,6 +1451,10 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
_right_operand_parameter_matmul_backward(ctx, *args, **kwargs)
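# Note: this replicate-parallel impl defines no custom logic of its own here;
# forward simply delegates to DistributedDefaultImpl0, and backward reuses the
# shared _right_operand_parameter_matmul_backward helper shown above.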
......
......@@ -27,22 +27,20 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from .dist_default import DistributedDefaultImpl0
class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedReshape2, self).__init__(op_type)
register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
register_distributed_operator_impl_container(DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
super(DistributedReshapeImpl0, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = False
......@@ -76,6 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -85,17 +87,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
x_shape_name)
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
for idx, item in enumerate(out_dims_mapping[:-2]):
if x_dims_mapping[idx] != item:
for idx, dim_mapping in enumerate(out_dims_mapping[:-1]):
if x_dims_mapping[idx] != dim_mapping:
return False
if out_dims_mapping[-2] != x_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
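# The XShape output of reshape2 stores the pre-reshape shape behind an extra
# leading 0-sized axis (used only by the backward pass), so its first dim is
# expected to remain unsharded (-1).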
......@@ -194,13 +189,12 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
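# The backward is no longer a no-op: it falls back to the common default dist
# impl so the corresponding grad ops are still handled by the shared logic.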
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
super(DistributedReshapeImpl1, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = False
......@@ -234,6 +228,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -244,24 +242,13 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
if len(x_dims_mapping) == len(out_dims_mapping) + 2:
if out_dims_mapping[0] != x_dims_mapping[0]:
return False
if x_dims_mapping[-1] != -1 or x_dims_mapping[-2] != -1:
return False
elif len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
for idx, item in enumerate(x_dims_mapping[:-2]):
for idx, item in enumerate(x_dims_mapping[:-1]):
if out_dims_mapping[idx] != item:
return False
if x_dims_mapping[-2] != out_dims_mapping[-1]:
return False
if x_shape_dims_mapping[0] != -1:
return False
......@@ -359,7 +346,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("reshape2",
......
......@@ -22,22 +22,20 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedSoftmax, self).__init__(op_type)
register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
register_distributed_operator_impl_container(DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
super(DistributedSoftmaxImpl, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
......@@ -48,8 +46,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
# if axis != -1 and axis != len(x_dims_mapping) - 1:
# return False
if is_dim_shard(x_dims_mapping[axis]):
return False
......@@ -63,8 +61,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(out_dims_mapping) - 1:
return False
# if axis != -1 and axis != len(out_dims_mapping) - 1:
# return False
if is_dim_shard(out_dims_mapping[axis]):
return False
......@@ -72,6 +70,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
......@@ -79,11 +81,8 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
# if axis != -1 and axis != len(x_dims_mapping) - 1:
# return False
if x_dims_mapping != out_dims_mapping:
return False
......@@ -107,9 +106,13 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
......
......@@ -22,22 +22,21 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from .dist_default import DistributedDefaultImpl0
class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedTranspose2, self).__init__(op_type)
register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
super(DistributedTranspose2Impl, self).__init__(name)
self._forward_implemented = False
self._backward_implemented = False
......@@ -48,6 +47,10 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
perm = op_desc.attr('axis')
......@@ -111,9 +114,13 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
pass
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
......
......@@ -20,18 +20,17 @@ from ..utils import set_dist_op_desc_original_id
class DistributedUpdateLossScaling(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedUpdateLossScaling, self).__init__()
self._name = name
def __init__(self, op_type):
super(DistributedUpdateLossScaling, self).__init__(op_type)
register_distributed_operator_impl_container(
"update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling"))
DistributedUpdateLossScaling("update_loss_scaling"))
class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedUpdateLossScalingImpl, self).__init__()
super(DistributedUpdateLossScalingImpl, self).__init__(name)
self._name = name
self._forward_implemented = False
self._backward_implemented = True
......@@ -46,6 +45,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
"DistributedUpdateLossScalingImpl's is_output_compatible should not be called !"
)
def is_auto_compatible(self, dist_op):
raise RuntimeError(
"DistributedUpdateLossScalingImpl's is_auto_compatible should not be called !"
)
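# update_loss_scaling is handled as a backward-only dist op, so these
# forward-side compatibility hooks are not expected to be reached and raise
# instead of returning a result.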
def update_dims_mapping(self, dist_op):
raise RuntimeError(
"DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !"
......
......@@ -63,7 +63,6 @@ class Partitioner(object):
def partition(self, serial_main_program, serial_startup_program,
params_grads):
if not isinstance(serial_main_program, (Program)):
raise TypeError(
"main_program be paddle.fluid.framework.program, got %s here" %
......@@ -87,7 +86,7 @@ class Partitioner(object):
serial_main_program, serial_startup_program)
dist_op_context.set_dst_startup_program(partitioned_startup_prog)
# partition main program
# partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
serial_main_program, params_grads)
......@@ -282,7 +281,7 @@ def _get_dist_shape(var, dist_attr):
def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape):
# NOTE hack to copied Parameter
# not initialized parameter, need to initialize it
# not initialized parameter, need to initialize it
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr
......@@ -371,19 +370,19 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
dist_op = get_distributed_operator_impl_container(forward_op.type)
# TODO backward should have its own impl_idx
if dist_op and forward_op_dist_attr.impl_idx >= 0 and dist_op.get_impl( \
forward_op_dist_attr.impl_idx)._backward_implemented:
return dist_op.get_impl(forward_op_dist_attr.impl_idx)
dist_op_impl_container = get_distributed_operator_impl_container(
forward_op_dist_attr.impl_type)
dist_op_impl = dist_op_impl_container.get_impl(
forward_op_dist_attr.impl_idx)
return dist_op_impl
# NOTE trick for dist ops that only have backward implement
# # NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
dist_op = get_distributed_operator_impl_container(backward_op.type)
if dist_op and op_dist_attr.impl_idx >= 0:
return dist_op.get_impl(op_dist_attr.impl_idx)
assert op_dist_attr.impl_idx >= 0
dist_op_impl = get_distributed_operator_impl_container(
backward_op.type).get_impl(op_dist_attr.impl_idx)
return dist_op_impl
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
......@@ -391,12 +390,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_op = get_distributed_operator_impl_container(forward_op.type)
if dist_op and dist_attr.impl_idx >= 0 and dist_op.get_impl(
dist_attr.impl_idx)._forward_implemented:
return dist_op.get_impl(dist_attr.impl_idx)
else:
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
dist_op_impl_container = get_distributed_operator_impl_container(
dist_attr.impl_type)
dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
return dist_op_impl
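# Both lookups now resolve the implementation through the dist attr's
# impl_type/impl_idx (the selected container and its impl index) instead of
# going through the serial op type and the *_implemented flags.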
......@@ -28,7 +28,7 @@ from .cost_model import estimate_cost
from .dist_op import DistributedOperator
from .process_group import _g_process_group_map
from .process_group import ProcessGroup, get_process_group
from .completion import is_elementwise_like_op
from .operators.common import is_elementwise_op
from .operators.common import get_distributed_operator_impl_container
from .utils import update_op_dims_mapping_by_default_dist_impl
from .utils import update_op_dims_mapping_by_elementwise_like_dist_impl
......@@ -237,7 +237,7 @@ class PlanSpace:
dist_op = DistributedOperator(op, op_dist_attr)
if dist_op_impl_container is None:
if is_elementwise_like_op(op.type):
if is_elementwise_op(op.type):
changed = True
valid = True
try:
......@@ -250,7 +250,8 @@ class PlanSpace:
op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -1
dist_op.dist_attr.impl_type = "elementwise"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
continue
else:
......@@ -266,16 +267,18 @@ class PlanSpace:
op, dist_op.dist_attr, vars
) and PlanFilter.check_dims_mapping_for_special_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_idx = -2
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
continue
# if op has distributed implements, find all valid dist attr of this op
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
for idx, impl in enumerate(impls):
if impl.is_auto_compatible(dist_op):
if PlanFilter.check_dims_mapping_for_op(
op, dist_op.dist_attr, vars):
dist_op.dist_attr.impl_type = dist_op.serial_op.type
dist_op.dist_attr.impl_idx = idx
op_valid_dist_attrs.append(dist_op.dist_attr)
......@@ -290,7 +293,8 @@ class PlanSpace:
for var_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(
vars[var_name], [-1 for i in vars[var_name].shape])
dist_op.dist_attr.impl_idx = -1
dist_op.dist_attr.impl_type = "default"
dist_op.dist_attr.impl_idx = 0
op_valid_dist_attrs.append(dist_op.dist_attr)
return op_valid_dist_attrs
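# After the refactor the fallback cases are recorded as impl_type
# "elementwise" or "default" with impl_idx 0, replacing the old sentinel
# indices -1/-2; the API tests below assert the new convention.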
......
......@@ -105,7 +105,7 @@ class TestAutoParallelAPI(unittest.TestCase):
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(process_mesh2))
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
......@@ -138,7 +138,7 @@ class TestAutoParallelAPI(unittest.TestCase):
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
......
......@@ -96,7 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
class TestCompatible(unittest.TestCase):
def test_matmulv2_matmul_2_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -123,7 +123,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -174,7 +174,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
self.assertTrue(impls[2].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
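# With the reworked is_auto_compatible, a fully replicated mapping is now
# accepted by the replicate-parallel matmul impl (impls[2]), hence the
# flipped expectation.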
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[2].is_auto_compatible(
......@@ -220,7 +220,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -261,7 +261,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, 1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, 1, -1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(Y, [0, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
......@@ -307,7 +307,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'matmul_v2' or op.type == 'matmul':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
X = op.input_arg_names[0]
Y = op.input_arg_names[1]
......@@ -362,7 +362,7 @@ class Testcompatible(unittest.TestCase):
op_dist_attr.set_input_dims_mapping(X, [-1, -1, -1, -1])
op_dist_attr.set_input_dims_mapping(Y, [-1, -1, -1, 1])
op_dist_attr.set_output_dims_mapping(out, [-1, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
self.assertTrue(impls[0].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(out, [0, -1, -1, 1])
self.assertFalse(impls[0].is_auto_compatible(
......
......@@ -96,24 +96,7 @@ def mlp_forward(train_program, start_program):
return loss, train_program, start_program
class Testcompatible(unittest.TestCase):
def test_raise_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'transpose2':
op_dist_attr = OperatorDistributedAttribute()
dist_op = DistributedOperator(op, op_dist_attr)
impls = DistributedOperatorImpl()
try:
impls.is_auto_compatible(dist_op)
except NotImplementedError:
e = False
self.assertTrue(e == False)
class TestCompatible(unittest.TestCase):
def test_reshape_remove_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -124,7 +107,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
......@@ -172,64 +155,6 @@ class Testcompatible(unittest.TestCase):
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_remove_two_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
startup_program = paddle.static.Program()
loss, program, start_program = mlp_forward(program, startup_program)
ops = program.global_block().ops
for idx, op in enumerate(ops):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1, -1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
[-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, -1])
dist_op = DistributedOperator(op, op_dist_attr)
self.assertTrue(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 0])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[0, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[1, -1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, 1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, -1, 1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, 1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_output_dims_mapping(op.output_arg_names[1],
[-1, -1, 1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[1, -1, -1])
self.assertFalse(impls[1].is_auto_compatible(
DistributedOperator(op, op_dist_attr)))
def test_reshape_add_compatible(self):
valid_op_dist_attr_list = []
program = paddle.static.Program()
......@@ -240,7 +165,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'reshape2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0], [-1])
op_dist_attr.set_output_dims_mapping(op.output_arg_names[0],
......@@ -298,7 +223,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'transpose2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......@@ -349,7 +274,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'softmax':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......@@ -379,7 +304,7 @@ class Testcompatible(unittest.TestCase):
if op.type == 'c_embedding' or op.type == 'lookup_table_v2':
dist_op_impl_container = get_distributed_operator_impl_container(
op.type)
impls = dist_op_impl_container.get_impls()
impls = dist_op_impl_container.impls
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping(op.input_arg_names[0],
[-1, -1])
......