Unverified commit a02532b5 authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the interface and the underlying mechanisms (#36617)

* default dist op

* add dist_attr for dist op

* add unittest

* update input name

* update function name

* add unittest

* update CMakeLists.txt for CI

* fix dist_matmul

* fix compile error

* update matmul to matmul_v2

* unify api

* unify api

* todo

* update distop forward func

* update distop forward func

* auto parallel backward

* update dist op

* autoparallel backward

* add backward for embedding

* temp1

* temp2

* temp3

* temp4

* backward done1

* backward done2

* backward done3

* dist embedding remove mp mode

* dist matmul remove mp mode

* update dist embedding

* dist op init1

* dist op init 2

* update unittest

* context remove parallel mode

* partitioner remove parallel mode

* update unittest

* a more general method to support varying mesh in pipeline parallel

* support varying mesh in pipeline parallel

* embedding support varying mesh in pipeline parallel

* matmul support varying mesh in pipeline parallel

* default dist op support varying mesh in pipeline parallel

* dist attribute for startup program

* default dist op support varying mesh in pipeline parallel 2

* partitioner support varying mesh in pipeline parallel

* revise logic for auto completion

* revise framework.py

* revise reshard unittest

* revise unittest for parallelize

* chmod

* fixed bug for dist embedding name mapping

* Improve the interface and the underlying mechanisms of auto parallel

* revise completion for backward

* revise completion for update

* revise completion for update

* update unittest

* chmod

* bugfix for grad_op output var's mesh

* Modify codes for pr 36744

* Remove unnecessary comments in framework.py

* Remove unnecessary comments in completion.py
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
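For orientation, a minimal sketch of the dict-style annotation the reworked interface accepts, based on the new shard_tensor(x, dist_attr=None) signature and the process_mesh/dims_mapping fields introduced in this diff; the mesh, shape, and tensor name below are illustrative and not taken from the PR.

import paddle
import paddle.distributed as dist

paddle.enable_static()

x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
# dims_mapping maps each tensor axis to a mesh axis (-1 means replicated);
# a plain nested list of ranks is accepted for process_mesh and converted
# to a ProcessMesh internally. The annotation is recorded in the default
# distributed context.
dist.shard_tensor(x, dist_attr={
    "process_mesh": [[0, 1], [2, 3]],
    "dims_mapping": [0, -1],
})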
Parent 2e40cfb5
...@@ -43,10 +43,6 @@ from .collective import wait # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import shard_tensor # noqa: F401
from .auto_parallel import set_shard_mask # noqa: F401
from .auto_parallel import set_offload_device # noqa: F401
from .auto_parallel import set_pipeline_stage # noqa: F401
from .auto_parallel import ProcessMesh # noqa: F401
from .fleet import BoxPSDataset # noqa: F401
...
...@@ -14,10 +14,11 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
This diff is collapsed.
...@@ -131,7 +131,7 @@ class TensorCostNode(CostNode):
        elif node.dtype == paddle.int64:
            self.dtype_factor *= 8
        else:
            raise NotImplementedError("{} not counted".format(v.node.dtype))
            raise NotImplementedError("{} not counted".format(node.dtype))
        self.batch_size = None
        if batch_size is not None:
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid.framework import Variable
from .process_mesh import ProcessMesh
_g_tensor_dist_attr_field_keys = [
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_input_suffix = "@input"
_g_op_output_suffix = "@output"
def get_tensor_dist_attr_field_keys():
global _g_tensor_dist_attr_field_keys
return _g_tensor_dist_attr_field_keys
def get_op_dist_attr_field_keys():
global _g_op_dist_attr_field_keys
return _g_op_dist_attr_field_keys
def append_op_input_suffix(name):
global _g_op_input_suffix
return name + _g_op_input_suffix
def append_op_output_suffix(name):
global _g_op_output_suffix
return name + _g_op_output_suffix
class TensorDistributedAttribute:
def __init__(self):
        # The process mesh of a distributed operator attribute must be the same as
        # the process meshes of all its input and output distributed attributes
self._process_mesh = None
self._dims_mapping = None
self._shard_sizes = None
self._device_placement = None
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
@property
def dims_mapping(self):
return self._dims_mapping
@dims_mapping.setter
def dims_mapping(self, dims_mapping):
if dims_mapping is not None:
assert isinstance(dims_mapping, list), \
"The type of dims_mapping must be list."
assert all(isinstance(x, int) for x in dims_mapping), \
("All elements of dims_mapping must be integer")
assert all(x >= -1 for x in dims_mapping), \
("All elements of dims_mapping must be greater than or equal to -1.")
self._dims_mapping = copy.deepcopy(dims_mapping)
@property
def shard_sizes(self):
return self._shard_sizes
@shard_sizes.setter
def shard_sizes(self, shard_sizes):
if shard_sizes is not None:
self._shard_sizes = copy.deepcopy(shard_sizes)
@property
def device_placement(self):
return self._device_placement
@device_placement.setter
def device_placement(self, device_placement):
if device_placement is not None:
self._device_placement = copy.deepcopy(device_placement)
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key in dist_attr.keys():
if key in get_tensor_dist_attr_field_keys():
self.mark_annotated(key)
elif isinstance(dist_attr, TensorDistributedAttribute):
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def clear_annotated(self):
self._is_annotated.clear()
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tdims_mapping ({}): {}".format(annotated_str,
self.dims_mapping)
str += "\n\t}"
return str
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
@property
def impl_type(self):
return self._impl_type
@impl_type.setter
def impl_type(self, impl_type):
if impl_type is not None:
self._impl_type = impl_type
@property
def impl_idx(self):
return self._impl_idx
@impl_idx.setter
def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx
@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
@property
def outputs_dist_attrs(self):
return self._outputs_dist_attrs
def get_input_dist_attr(self, name):
return self._inputs_dist_attrs.get(name, None)
def set_input_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
def set_output_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
dims_mapping = input_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_input_dims_mapping(self, name, dims_mapping):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
input_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._inputs_dist_attrs[name] = dist_attr
def get_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
dims_mapping = output_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_output_dims_mapping(self, name, dims_mapping):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
output_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._outputs_dist_attrs[name] = dist_attr
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init(value)
if dist_attr.get(append_op_input_suffix(key.name), False):
self.set_input_dist_attr(key.name, tensor_dist_attr)
if dist_attr.get(append_op_output_suffix(key.name), False):
self.set_output_dist_attr(key.name, tensor_dist_attr)
else:
if key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, OperatorDistributedAttribute):
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
self.set_input_dist_attr(
tensor_name, dist_attr.get_input_dist_attr(tensor_name))
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
self.set_output_dist_attr(
tensor_name, dist_attr.get_output_dist_attr(tensor_name))
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
for key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
        # Make sure all process_meshes in the dist op are the same
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
for tensor_dist_attr in self.outputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
shared_process_mesh = None
for process_mesh in process_meshes:
if process_mesh is not None:
if shared_process_mesh is None:
shared_process_mesh = process_mesh
else:
assert process_mesh == shared_process_mesh, \
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
            # Make sure process_mesh is annotated consistently
self._is_annotated[attr_name] = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
else:
self._is_annotated[attr_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
input_dist_attr = self.get_input_dist_attr(key.name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(value)
output_dist_attr = self.get_output_dist_attr(key.name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(value)
else:
if key in get_op_dist_attr_field_keys():
self.mark_annotated(key)
process_mesh_annotated = False
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.outputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
elif isinstance(dist_attr, OperatorDistributedAttribute):
process_mesh_annotated = False
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
input_dist_attr = self.get_input_dist_attr(tensor_name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(tensor_dist_attr)
if input_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
output_dist_attr = self.get_output_dist_attr(tensor_name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(tensor_dist_attr)
if output_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
def clear_annotated(self):
self._is_annotated.clear()
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
def is_annotated_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
return input_dist_attr.is_annotated("dims_mapping")
else:
return False
def is_annotated_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
return output_dist_attr.is_annotated("dims_mapping")
else:
return False
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\timpl type: {}, ".format(self._impl_type)
str += "impl idx: {}".format(self._impl_idx)
str += "\n\t}"
return str
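As an illustration of the classes above (not part of the commit), a small sketch of building a TensorDistributedAttribute from a plain dict; the module path paddle.distributed.auto_parallel.dist_attribute is inferred from the relative imports elsewhere in this diff, and the mesh/mapping values are made up.

from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute

tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init({
    "process_mesh": [0, 1],   # a list of ranks is converted to a ProcessMesh by the setter
    "dims_mapping": [0, -1],  # axis 0 sharded over mesh axis 0, axis 1 replicated
})
# mark_annotated_as only inspects the keys of a dict, not its values
tensor_dist_attr.mark_annotated_as({"dims_mapping": None})
assert tensor_dist_attr.is_annotated("dims_mapping")
assert not tensor_dist_attr.is_annotated("process_mesh")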
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None
def get_default_distributed_context():
global _g_default_distributed_context
if _g_default_distributed_context is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return _g_default_distributed_context
def set_default_distributed_context(dist_context):
global _g_default_distributed_context
_g_default_distributed_context = dist_context
class DistributedContext:
"""
    DistributedContext is used to collect related distributed information for the program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self, program=None):
self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
@property
def serial_program(self):
return self._serial_program
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
assert self._serial_program is None, \
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def process_meshes(self):
return self._process_meshes
@property
def dist_op_context(self):
return self._dist_op_context
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
            'The type of process_mesh must be ProcessMesh.'
if process_mesh not in self.process_meshes:
self._process_meshes.append(process_mesh)
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None)
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_ops_for_program.get(serial_tensor_id, None)
def get_dist_op_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_ops_for_graph.get(serial_tensor_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
# Copy the dist tensors and dist ops annotated by users from the default context
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
for block in self._serial_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
tensor)
if default_dist_tensor and default_ctx is not self:
self.add_dist_tensor_for_program(default_dist_tensor)
current_dist_tensor = self.get_dist_tensor_for_program(tensor)
if current_dist_tensor is None:
dist_tensor = DistributedTensor(tensor)
self.add_dist_tensor_for_program(dist_tensor)
for op in block.ops:
# Copy the distributed operators in the default context
default_dist_op = default_ctx.get_dist_op_for_program(op)
if default_dist_op and default_ctx is not self:
self.add_dist_op_for_program(default_dist_op)
current_dist_op = self.get_dist_op_for_program(op)
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id,
None)
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
dist_op = self._dist_ops_for_program.get(op_id, None)
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
self._dist_ops_for_program.clear()
def clear_dist_info_for_graph(self):
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
                # If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
node)
dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the dimension of tensor is less than the sharding dimension of process mesh,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized_for_program:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
                    dist_op.serial_op.type, dist_op.dist_attr)
return True
class DistributedOperatorContext:
"""
DistributedOperatorContext is used to create a dist op desc in Program.
    Every time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
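A rough usage sketch of the two contexts above, not part of the diff: annotations placed in the default context are copied into a per-run DistributedContext, which then fills in default attributes for every remaining tensor and op. The program contents and mesh below are made up, and the module path is inferred from the package layout.

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")

# User annotations land in the default context first
default_ctx = get_default_distributed_context()
default_ctx.set_tensor_dist_attr_for_program(
    x, {"process_mesh": [0, 1], "dims_mapping": [0, -1]})

# A per-run context copies those annotations and completes the rest
dist_ctx = DistributedContext(main_program)
dist_ctx.init_dist_attr_for_program()
assert dist_ctx.validate_dist_attr_for_program()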
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
self._serial_outputs = {}
self._dist_attr = None
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_op(self):
return self._serial_op
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = OperatorDistributedAttribute()
# Create new dist_attr related to current serial_op
dist_attr = self._filter_dist_attr(dist_attr)
# Append suffix to mark the inputs or outputs
if isinstance(dist_attr, dict):
# Copy the keys since we may add new ones
for key in list(dist_attr.keys()):
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names:
dist_attr[append_op_input_suffix(key.name)] = True
if key.name in self._serial_op.output_arg_names:
dist_attr[append_op_output_suffix(key.name)] = True
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def get_serial_input(self, name):
return self._serial_inputs.get(name, None)
def get_serial_output(self, name):
return self._serial_outputs.get(name, None)
def _init_default_dist_attr(self):
for tensor_name in self._serial_op.input_arg_names:
if self._serial_op.type == "create_py_reader":
tensor = None
else:
tensor = self._serial_op.block._var_recursive(tensor_name)
self._serial_inputs[tensor_name] = tensor
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
return None
new_dist_attr = None
if isinstance(dist_attr, dict):
new_dist_attr = {}
for key, value in dist_attr.items():
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \
or key.name in self._serial_op.output_arg_names:
new_dist_attr[key] = value
else:
new_dist_attr[key] = value
elif isinstance(dist_attr, OperatorDistributedAttribute):
new_dist_attr = copy.deepcopy(dist_attr)
new_dist_attr._inputs_dist_attrs.clear()
new_dist_attr._outputs_dist_attrs.clear()
for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name,
tensor_dist_attr)
for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name,
tensor_dist_attr)
else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type:
return True
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
return False
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(),
self.serial_op.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_input(arg_name) is not None:
if self.get_serial_input(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_output(arg_name) is not None:
if self.get_serial_output(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
return str
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops)
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable):
output = [output]
return list(output)
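For illustration only (not in the commit), a sketch of DistributedModule wrapping a serial callable so that every op it adds to the block is registered as a DistributedOperator in the default context; paddle.matmul and the shapes are arbitrary choices, and note that __call__ always returns a list.

import paddle
from paddle.distributed.auto_parallel.dist_op import DistributedModule

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
    w = paddle.static.data(name="w", shape=[1024, 1024], dtype="float32")
    dist_matmul = DistributedModule(paddle.matmul,
                                    dist_attr={"process_mesh": [0, 1]})
    # Runs the serial matmul and records a DistributedOperator for the new op
    out = dist_matmul(x, w)[0]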
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys
class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None):
self._serial_tensor = serial_tensor
self._dist_attr = None
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_tensor(self):
return self._serial_tensor
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = TensorDistributedAttribute()
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
return False
for i in range(len(self.dist_attr.dims_mapping)):
if self.dist_attr.dims_mapping[
i] < -1 or self.dist_attr.dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
str += ", is_parameter: {}".format(self.serial_tensor.is_parameter)
if self.dist_attr.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.dist_attr.dims_mapping)
if self.dist_attr.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str, None)
if self.dist_attr.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str, None)
return str
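And a matching sketch for DistributedTensor (illustrative values, module path inferred as above): without an explicit dist_attr every axis defaults to -1, and validate_dist_attr checks the mapping against the mesh topology.

import paddle
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor

paddle.enable_static()
x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")

dist_x = DistributedTensor(x)
assert dist_x.dist_attr.dims_mapping == [-1, -1]  # replicated by default

dist_x = DistributedTensor(x, {"process_mesh": [0, 1], "dims_mapping": [0, -1]})
assert dist_x.validate_dist_attr()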
...@@ -18,293 +18,34 @@ import paddle ...@@ -18,293 +18,34 @@ import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.framework import Variable from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from .dist_context import get_default_distributed_context
__all__ = [] from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
# a map from ProcessMesh ids to the ProcessMesh instances from .dist_attribute import TensorDistributedAttribute
_g_process_mesh_map = dict() from .dist_attribute import OperatorDistributedAttribute
# user defined map from logical process ids to physical ones
_user_defined_physical_map = None
def _append_attr_suffix(name):
"""
Append auto parallel suffix for distributed attribute name.
"""
return name + core.kAutoParallelSuffix()
def _remove_attr_suffix(name):
"""
Remove auto parallel suffix from distributed attribute name.
"""
return name.strip(core.kAutoParallelSuffix())
def _static_mode_check():
    if in_dygraph_mode():
        raise RuntimeError("Auto-parallel only supports static mode, "
                           "please use paddle.enable_static().")
        raise RuntimeError("Auto-parallel only supports static mode for now, "
                           "please use paddle.enable_static() first.")
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
class ProcessMesh(object):
r"""
    The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
    element of the N-dimensional array represents a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
        mesh (list): an N-dimensional array (nested list) describing the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
parent (ProcessMesh, optional): the parent ProcessMesh. None means
the ProcessMesh is the root one without parent ProcessMesh.
Default: None.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.parent is None
assert mesh.topology == [2, 3]
assert mesh.process_group == [2, 4, 5, 0, 1, 3]
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
def __init__(self, mesh, parent=None):
_static_mode_check()
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
self._topology = _get_nested_list_shape(mesh)
self._processes = _flatten_nested_list(mesh)
# Every element of mesh must be >= 0.
assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.')
unique_ids = set(self._processes)
assert len(unique_ids) == len(self._processes), (
'All elements of mesh must be unique.')
if parent is None:
# For root ProcessMesh, the ids of logical processes must be range
# from 0 to N-1, where N is the number of logical processes.
assert max(self._processes) == len(self._processes) - 1, (
'For root ProcessMesh, ids of logical processes must be range '
'from 0 to N-1, where N is the number of logical processes.')
parent_id = core.kNoneProcessMeshIndex()
assert len(_g_process_mesh_map.keys()) == 0, (
'The first ProcessMesh must be the root, which has no parent.')
else:
assert len(_g_process_mesh_map.keys()) > 0, (
'All ProcessMesh must have a parent except the root one.')
assert isinstance(parent, ProcessMesh), (
'parent must be an instance of ProcessMesh.')
parent_id = parent._desc.id
# All elements in mesh must belong to its parent
parent_ids = set(parent.process_group)
assert unique_ids <= parent_ids, (
'All elements in mesh must belong to its parent.')
self._desc = core.ProcessMeshDesc(self._topology, self._processes,
parent_id)
self._id = self._desc.id
self._parent_id = parent_id
assert self._id not in _g_process_mesh_map, (
"The ProcessMesh with id %d already exists." % self._id)
_g_process_mesh_map[self._id] = self
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of the `mesh` used to initialize this ProcessMesh.
"""
return self._topology
@property
def process_group(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def parent(self):
r"""
Get the parent ProcessMesh.
"""
if self._parent_id == core.kNoneProcessMeshIndex(): return None
assert self._parent_id in _g_process_mesh_map, (
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimensions of this ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
user defined order.
Args:
order (list): order of the physical process ids.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
assert self.parent is None, (
"This function can only be called by the root ProcessMesh.")
unique_ids = set(order)
assert isinstance(order, list)
assert len(unique_ids) == len(order), (
"All elements in order must be unique.")
assert min(order) == 0
assert max(order) == len(order) - 1, (
"All elements in order must be from 0 to N - 1, where N "
"is the number of physical processes.")
logical_order = self.process_group
global _user_defined_physical_map
assert _user_defined_physical_map is None, (
"This function can only be called once.")
_user_defined_physical_map = dict()
assert len(logical_order) == len(order)
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process meshes from _g_process_mesh_map, making it empty.
"""
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
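An illustrative sketch of how the class above is used, following only its own docstrings (it mirrors the interface being reworked by this change, so the `dist.ProcessMesh` import path may differ in newer code); the sub-mesh at the end is an assumed pipeline-style usage in which each stage takes a subset of the root mesh's processes:

import paddle
import paddle.distributed as dist

paddle.enable_static()

# The root mesh must use process ids 0..N-1 exactly.
root = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
assert root.parent is None
assert root.topology == [2, 3] and root.ndim == 2
assert root.process_group == [0, 1, 2, 3, 4, 5]

# Optionally pin logical process ids to physical ranks (root mesh only, once).
root.set_placement([0, 1, 2, 3, 4, 5])

# A child mesh must use a subset of its parent's processes,
# e.g. one pipeline stage per row of the root mesh.
stage0 = dist.ProcessMesh([0, 1, 2], parent=root)
assert stage0.parent is root and stage0.topology == [3]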
def _dim_mapping_checker(tensor, mesh, dim_mapping): def shard_tensor(x, dist_attr=None):
assert isinstance(mesh,
ProcessMesh), 'The type of mesh must be ProcessMesh.'
assert isinstance(dim_mapping,
list), 'The type of dim_mapping must be list.'
assert len(tensor.shape) == len(dim_mapping), (
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.')
mesh_dim = len(mesh.topology)
dim_set = set()
for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (
dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
'Each element in dim_mapping must be greater than or equal to '
'zero and less than the number of dimensions of the corresponding '
'process mesh, or it must be -1.')
if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i])
def shard_tensor(x, mesh, dim_mapping):
""" """
Add distributed attributes for a tensor. Add distributed attributes for a tensor.
Args: Args:
x (Tensor): the tensor to process. x (Tensor): the tensor to be sharded.
mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes. dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follows:
dim_mapping (list): a list to describe the mapping between `x` and `mesh`, "process_mesh": a nested list to describe the mesh topology of logical processes.
the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means "dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension
without partition along the corresponding dimension. `i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that tensor dimension is not split.
Both process_mesh and dims_mapping are optional and users can specify them as needed.
Returns: Returns:
Tensor: the tensor `x` itself. Tensor: the tensor `x` annotated with distributed attributes.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping): ...@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping):
paddle.enable_static() paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [0, -1])
"""
_static_mode_check()
_dim_mapping_checker(x, mesh, dim_mapping)
attr_name = _append_attr_suffix('mesh_id')
x._set_attr(attr_name, mesh._id)
attr_name = _append_attr_suffix('dim_mapping')
x._set_attr(attr_name, dim_mapping)
return x
def set_shard_mask(x, mask):
"""
Set the mask for a tensor, which masks the tensor out from some processes in its mesh.
Args:
x (Tensor): the tensor to process.
mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to
the tensor `x`. Every value of `mask` must be one or zero, where one means
the tensor `x` will be put on the corresponding logical process and zero means the tensor `x`
will not be put on the corresponding logical process.
For example, for a ProcessMesh represented by the 2-dimensional
array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the
2-dimensional [[1, 0, 1], [0, 1, 0]],
then the tensor `x` will only be put on logical processes 2, 5 and 1.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1]) dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
dist.set_shard_mask(x, mask) "dims_mapping": [0, -1]})
""" """
_static_mode_check() _static_mode_check()
assert isinstance(mask, list) assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
np_mask = numpy.array(mask) "The type of dist_attr must be None, dict or TensorDistributedAttribute."
min_ele = numpy.min(np_mask) dist_tensor = DistributedTensor(x, dist_attr)
max_ele = numpy.max(np_mask) dist_tensor.dist_attr.mark_annotated_as(dist_attr)
mesh_attr_name = _append_attr_suffix('mesh_id') default_dist_ctx = get_default_distributed_context()
assert x._has_attr(mesh_attr_name), \ default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
"Please set process mesh for the variable firstly."
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
return x return x
def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): def shard_op(op_fn, dist_attr=None):
""" """
Call a function and add distributed attributes for ops added by the function. Call a function and add distributed attributes for ops added by the function.
Args: Args:
op_fn (callable): a callable object of an API. op_fn (callable): a callable operator or module to be sharded.
mesh (ProcessMesh): an instance of ProcessMesh that specifies the topology of logical processes. dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping. two categories. The first category describes the distributed attributes shared by all inputs and
The dim_mapping is a list to describe the mapping between a tensor and `mesh`, outputs, and only `process_mesh` can be specified now. The second category describes distributed
the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`, attributes for inputs or outputs, same as the `dist_attr` of `shard_tensor`. All of them are
where -1 means without partition along the corresponding dimension. optional and users can specify them as needed. Note that `process_mesh` for operators must be the
kwargs (dict): a dict of parameters passed to the function `op_fn`. same as the process_meshes of their inputs and outputs.
Returns: Returns:
list: the outputs of the function `op_fn`. list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): ...@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
paddle.enable_static() paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
y = paddle.zeros([4, 6]) y = paddle.zeros([4, 6])
kwargs = {'x': x, 'y': y} dist_add = dist.shard_op(paddle.add,
dist.shard_op(paddle.add, mesh, None, **kwargs) dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
""" x: {"dims_mapping": [-1, 0]},
_static_mode_check() y: {"dims_mapping": [0, -1]}
main_prog = paddle.fluid.default_main_program() })
main_block = main_prog.global_block() dist_add(x, y)
op_size = len(main_block.ops)
output = op_fn(**kwargs)
new_op_size = len(main_block.ops)
if dim_mapping_dict is None:
dim_mapping_dict = dict()
else:
assert isinstance(dim_mapping_dict,
dict), 'The type of dim_mapping_dict must be dict.'
for var_name in dim_mapping_dict.keys():
dim_mapping = dim_mapping_dict[var_name]
tensor = main_block.var(var_name)
_dim_mapping_checker(tensor, mesh, dim_mapping)
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id')
op._set_attr(attr_name, mesh._id)
for var_name in dim_mapping_dict.keys():
assert var_name in op.output_arg_names + op.input_arg_names
attr_name = _append_attr_suffix(var_name)
if var_name in op.input_arg_names:
# we use the prefix "IN_" to indicates an input argument name
attr_name = "IN_" + attr_name
else:
# we use the prefix "OUT_" to indicates an input argument name
attr_name = "OUT_" + attr_name
op._set_attr(attr_name, dim_mapping_dict[var_name])
if isinstance(output, Variable):
output = [output]
return list(output)
def set_offload_device(x, device):
"""
Set the device that the tensor `x` will be put on.
Args:
x (tensor): the tensor to process.
device (str): the device that the tensor `x` will be put on, e.g., 'cpu'.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
x = paddle.ones([4, 6])
dist.set_offload_device(x, 'cpu')
"""
_static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
def set_pipeline_stage(stage):
"""
Set the pipeline stage of the following ops.
Args:
stage (int): the pipeline stage that the following ops belong to.
Returns:
None.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
dist.set_pipeline_stage(0)
""" """
from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check() _static_mode_check()
assert isinstance(stage, int), 'The type of stage must be int.' assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
_set_pipeline_stage(stage) "The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
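To make the interface change concrete, below is a before/after sketch assembled only from the two docstrings in this file: the removed calls took a ProcessMesh object plus dims_mapping arguments, while the new ones take a single `dist_attr` dict. The old-style lines are kept as comments and use the pre-change signatures.

import paddle
import paddle.distributed as dist

paddle.enable_static()
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])

# Old style (removed by this change):
#   mesh = dist.ProcessMesh([[0, 1], [2, 3]])
#   dist.shard_tensor(x, mesh, [0, -1])
#   dist.shard_op(paddle.add, mesh, {x.name: [-1, 0]}, x=x, y=y)

# New style: a single dist_attr dict per tensor / op.
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
                                "dims_mapping": [0, -1]})
dist_add = dist.shard_op(paddle.add,
                         dist_attr={"process_mesh": [[0, 1], [2, 3]],
                                    x: {"dims_mapping": [-1, 0]},
                                    y: {"dims_mapping": [0, -1]}})
dist_add(x, y)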
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding from . import dist_embedding
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
DISTRIBUTED_OPERATORS = {} _g_distributed_operator_impl_registries = {}
class DistributedOperator: class DistributedOperatorImplContainer:
def __init__(self): def __init__(self):
self._impls = [] self._impls = []
self._name = None self._name = None
...@@ -47,67 +47,60 @@ class DistributedOperatorImpl: ...@@ -47,67 +47,60 @@ class DistributedOperatorImpl:
def get_name(self): def get_name(self):
return self._name return self._name
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr): def is_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op): def register_distributed_operator_impl_container(name, dist_op_impl_container):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
DISTRIBUTED_OPERATORS[name] = dist_op _g_distributed_operator_impl_registries[name] = dist_op_impl_container
def get_distributed_operator(name): def get_distributed_operator_impl_container(name):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
return DISTRIBUTED_OPERATORS.get(name, None) return _g_distributed_operator_impl_registries.get(name, None)
def register_distributed_operator_impl(name, dist_impl): def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name) dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op is not None: if dist_op_impl_container is not None:
dist_op.register_impl(dist_impl) dist_op_impl_container.register_impl(dist_impl)
else: else:
assert False, "Must register distributed operator first." assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx): def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx) return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr, def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
fwd=True):
""" """
Here just return the first compatible implementation. Here just return the first compatible implementation.
This will be improved by cost model in the future. This will be improved by cost model in the future.
""" """
dist_op = get_distributed_operator(name) dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op is None: if dist_op_impl_container is None:
return None, -1 return None, -1
compatible_impls = [] compatible_impls = []
impls = dist_op.get_impls() impls = dist_op_impl_container.get_impls()
if fwd: if fwd:
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \ if impl.is_input_compatible(dist_op):
and impl.is_input_compatible(op_dist_attr):
compatible_impls.append((impl, idx)) compatible_impls.append((impl, idx))
else: else:
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \ if impl.is_output_compatible(dist_op):
and impl.is_output_compatible(op_dist_attr):
compatible_impls.append((impl, idx)) compatible_impls.append((impl, idx))
if compatible_impls: if compatible_impls:
...@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr, ...@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
return best_compatible_impl, idx return best_compatible_impl, idx
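A sketch of the renamed registry in use, relying only on the classes and functions defined above; "my_op", the impl name, and the always-compatible rules are placeholders rather than a real Paddle operator, and the import path is an assumption based on this file's location.

from paddle.distributed.auto_parallel.operators.common import (
    DistributedOperatorImplContainer, DistributedOperatorImpl,
    register_distributed_operator_impl_container,
    register_distributed_operator_impl,
    find_best_compatible_distributed_operator_impl)

class DistributedMyOp(DistributedOperatorImplContainer):
    def __init__(self, name):
        super(DistributedMyOp, self).__init__()
        self._name = name

class DistributedMyOpImpl0(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMyOpImpl0, self).__init__()
        self._name = name

    # Forward lookup checks inputs, backward lookup checks outputs.
    def is_input_compatible(self, dist_op):
        return True

    def is_output_compatible(self, dist_op):
        return True

    def update_dims_mapping(self, dist_op):
        return False  # nothing to change in this toy impl

# One container per op type, then one or more impls registered into it.
register_distributed_operator_impl_container("my_op", DistributedMyOp("my_op"))
register_distributed_operator_impl("my_op", DistributedMyOpImpl0("replicate"))

# Later, completion/partitioning code resolves an impl for a concrete dist_op:
#   impl, idx = find_best_compatible_distributed_operator_impl("my_op", dist_op, fwd=True)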
def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var): # def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var):
""" # """
copy src var's dist_attr to dst var # copy src var's dist_attr to dst var
""" # """
import copy # import copy
auto_paralle_context = src_op_dist_attr.get_owner_context() # auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy( # dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) # auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var # dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( # dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context # src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) # auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
"""
copy src var's dist_attr to dst var
"""
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
# """
# copy src op's dist_attr to dst dist op
# """
# from ..attribute import OperatorDistributedAttribute
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
# auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
# op_dist_attr)
# auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
# op_dist_attr)
# op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
# op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
# for input_varname in dist_op.desc.input_arg_names():
# input_var = dst_block.var(input_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# input_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
# for output_varname in dist_op.desc.output_arg_names():
# output_var = dst_block.var(output_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# output_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_output_dims_mapping(output_varname,
# tensor_dims_mapping)
def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
src_op_dist_attr):
""" """
copy src op's dist_attr to dst dist op copy src op's dist_attr to dst dist op
""" """
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
# need to check dist op attr and its inputs and outputs
auto_paralle_context = src_op_dist_attr.get_owner_context() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, op_dist_attr.impl_idx = src_op_dist_attr.impl_idx
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)
op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
for input_varname in dist_op.desc.input_arg_names(): for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname) input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var) input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
for output_varname in dist_op.desc.output_arg_names(): for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname) output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
output_var) output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping) dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -22,26 +22,27 @@ from ..utils import is_valid_list_index ...@@ -22,26 +22,27 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperator): class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedDefault, self).__init__() super(DistributedDefault, self).__init__()
self._name = name self._name = name
register_distributed_operator("default", DistributedDefault("default")) register_distributed_operator_impl_container("default",
DistributedDefault("default"))
# Replicated Default # Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl): class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__() super(DistributedDefaultImpl0, self).__init__()
...@@ -49,29 +50,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -49,29 +50,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
def is_output_compatible(self, op_dist_attr): def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
@staticmethod @staticmethod
def forward(ctx, *args, **kwargs): def forward(ctx, *args, **kwargs):
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_helper.get_varname_mapping() varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name)
...@@ -100,26 +98,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -100,26 +98,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for varname in dist_op_desc.input_arg_names(): for varname in dist_op_desc.input_arg_names():
if startup_block.has_var(varname) and startup_block.var( if startup_block.has_var(varname) and startup_block.var(
varname varname
).is_parameter and varname not in dist_op_helper.already_init_sync_vars: ).is_parameter and varname not in dist_op_context.already_init_sync_vars:
dist_op_helper.already_init_sync_vars.add(varname) dist_op_context.already_init_sync_vars.add(varname)
param = startup_block.var(varname) param = startup_block.var(varname)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program( param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
param) process_mesh = param_dist_attr.process_mesh
process_mesh = param_dist_attr.get_process_mesh() dims_mapping = param_dist_attr.dims_mapping
dims_mapping = param_dist_attr.get_dims_mapping()
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group: if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(process_mesh, rank_id) rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
# NOTE all non-split axes should be present in the mesh # NOTE all non-split axes should be present in the mesh
for axis, size in enumerate(process_mesh.topology): for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dims_mapping: if size <= 1 or axis in dims_mapping:
pass pass
else: else:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.process_group, process_mesh.topology, process_mesh.topology,
axis, rank_id) axis, rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
new_op = startup_block.append_op( new_op = startup_block.append_op(
...@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
}) })
# set distributed attribute # set distributed attribute
op_attr = OperatorDistributedAttribute(new_op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(param.name, op_attr.set_output_dims_mapping(param.name,
dims_mapping) dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping) op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(new_op, op_attr) ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
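The parameter-initialization sync above builds one broadcast group per mesh axis that the parameter is not sharded along. A pure-Python illustration of that group selection, written only for clarity under a row-major layout assumption; it is not the actual `_get_comm_group` from `..utils`.

import numpy

def comm_group_along_axis(processes, topology, axis, rank):
    # Arrange the flat process list into the mesh shape, locate `rank`,
    # and return the ranks that differ from it only along `axis`.
    mesh = numpy.array(processes).reshape(topology)
    coord = [int(c[0]) for c in numpy.where(mesh == rank)]
    index = [slice(None) if d == axis else coord[d] for d in range(len(topology))]
    return mesh[tuple(index)].flatten().tolist()

# On a 2 x 3 mesh [[0, 1, 2], [3, 4, 5]]:
procs, topo = [0, 1, 2, 3, 4, 5], [2, 3]
assert comm_group_along_axis(procs, topo, 0, 4) == [1, 4]     # same column as rank 4
assert comm_group_along_axis(procs, topo, 1, 4) == [3, 4, 5]  # same row as rank 4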
...@@ -147,16 +145,16 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -147,16 +145,16 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
# by now the backward function only insert the gradient allreduce for dist op itself # by now the backward function only insert the gradient allreduce for dist op itself
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op() backward_op = dist_op_context.get_cur_src_op()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op)) str(backward_op))
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check if need gradient allreduce # check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited, # if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output # we need insert gradient allreduce for the gradient of parameter in its output
need_gradient_allreduce = False need_gradient_allreduce = False
for input_name in backward_op.desc.input_names(): for input_name in backward_op.desc.input_names():
...@@ -165,20 +163,21 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -165,20 +163,21 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname).is_parameter: varname).is_parameter:
# NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op # NOTE input var's dim_mapping of backward op should be the same with input var instead of corresponding varname of forward op
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group: if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(process_mesh, rank_id) rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True need_gradient_allreduce = True
group_ranks = _get_comm_group( group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.process_group, process_mesh.topology, process_mesh.topology,
batch_size_axis, rank_id) batch_size_axis, rank_id)
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
break break
...@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
dims_mapping = ctx.get_tensor_distributed_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).get_dims_mapping() grad_var).dims_mapping
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping) dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name, op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping) dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp() main_block._sync_with_cpp()
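When the batch dimension of a non-parameter input is sharded, the backward path above appends a `c_allreduce_sum` on the parameter gradient followed by a `scale` by 1/dp_degree. A small numpy sketch of the numerical effect only; the real work is done by the inserted collective ops.

import numpy

def data_parallel_grad_sync(local_grads):
    # c_allreduce_sum: every rank ends up with the sum of all local gradients;
    # scale: dividing by the data-parallel degree yields the mean gradient.
    dp_degree = len(local_grads)
    return numpy.sum(local_grads, axis=0) / dp_degree

local_grads = [numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])]  # two dp ranks
assert numpy.allclose(data_parallel_grad_sync(local_grads), [2.0, 3.0])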
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op from .common import copy_distributed_attr_for_dist_op
...@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index ...@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperator): class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedEmbedding, self).__init__() super(DistributedEmbedding, self).__init__()
self._name = name self._name = name
register_distributed_operator("lookup_table_v2", register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding")) DistributedEmbedding("embedding"))
register_distributed_operator("c_embedding", DistributedEmbedding("embedding")) register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding"))
# RowParallel # RowParallel
...@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0] ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0] w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
...@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
...@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0] ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0] w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
...@@ -111,16 +111,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -111,16 +111,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# check validation of inputs / outputs # check validation of inputs / outputs
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids') assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W') assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out' in kwargs, "output [{}] is not given".format('Out') assert 'Out' in kwargs, "output [{}] is not given".format('Out')
...@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_var.name)[0] Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping) embedding_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group: if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# A generalized method to calculate embedding offset using cartesian product # A generalized method to calculate embedding offset using cartesian product
...@@ -162,7 +162,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -162,7 +162,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
per_part_size = Weight_var.shape[0] per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size relative_idx = relative_idx * per_part_size
# TODO calculate ring id # TODO calculate ring id
parallel_axis = embedding_row_dim_mapping parallel_axis = embedding_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id) parallel_axis, rank_id)
...@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
stop_gradient=Out_var.stop_gradient) stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr # copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
check_variable_and_dtype( check_variable_and_dtype(
Out_var, 'tensor', Out_var, 'tensor',
...@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}) })
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block,
op_dist_attr) op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr) op_dist_attr)
# param initialization sync # param initialization sync
assert Weight_var.name not in dist_op_helper.already_init_sync_vars assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_helper.already_init_sync_vars.add(Weight_var.name) dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name) param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh() process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.get_dims_mapping() dim_mapping = param_dist_attr.dims_mapping
# NOTE all non-split axes should be present in the mesh # NOTE all non-split axes should be present in the mesh
for axis, size in enumerate(process_mesh.topology): for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping: if size <= 1 or axis in dim_mapping:
pass pass
else: else:
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis, process_mesh.topology, axis,
rank_id) rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
...@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
# by now the backward function only insert the gradient allreduce for dist op itself # by now the backward function only insert the gradient allreduce for dist op itself
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op() backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op)) str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group: if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id) rank_id)
# check if need gradient allreduce # check if need gradient allreduce
...@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD']) kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0]) Ids_var = main_block.var(kwargs['Ids'][0])
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, process_mesh.topology,
batch_size_axis, rank_id) batch_size_axis, rank_id)
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
...@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward}) OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).get_dims_mapping() W_Grad_var).dims_mapping
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
register_distributed_operator_impl("lookup_table_v2", register_distributed_operator_impl("lookup_table_v2",
......
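The row-parallel embedding forward above shards the weight rows along one mesh axis, shifts the ids by each rank's row offset (the `c_embedding` op), and sums the partial results with `c_allreduce_sum`. A numpy sketch of that arithmetic, where `per_part` plays the role of `per_part_size` and out-of-range ids simply contribute zeros on a given rank:

import numpy

def row_parallel_embedding(ids, weight_shards):
    # weight_shards[k] holds rows [k*per_part, (k+1)*per_part) of the full table.
    per_part = weight_shards[0].shape[0]
    partial_outputs = []
    for k, shard in enumerate(weight_shards):
        relative = ids - k * per_part                    # shift into local row space
        mask = (relative >= 0) & (relative < per_part)   # rows owned by this rank
        local = numpy.zeros((len(ids), shard.shape[1]))
        local[mask] = shard[relative[mask]]              # c_embedding on one rank
        partial_outputs.append(local)
    return numpy.sum(partial_outputs, axis=0)            # c_allreduce_sum

full = numpy.arange(8.0).reshape(4, 2)        # 4-row table, embedding dim 2
shards = [full[:2], full[2:]]                 # two ranks, 2 rows each
ids = numpy.array([0, 3, 2])
assert numpy.allclose(row_parallel_embedding(ids, shards), full[ids])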
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard ...@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator): class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedReshape2, self).__init__() super(DistributedReshape2, self).__init__()
self._name = name self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2")) register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl): class DistributedReshapeImpl0(DistributedOperatorImpl):
...@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
...@@ -103,15 +103,15 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -103,15 +103,15 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name)
...@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
...@@ -231,15 +230,15 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -231,15 +230,15 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format( assert input_name in kwargs, "input [{}] is not given".format(
input_name) input_name)
...@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
......
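For reference, a minimal sketch (not part of this PR) of what a compatibility rule looks like under the new interface used above: each hook now receives a dist_op wrapper and pulls the serial op description and the distributed attribute from it, rather than receiving an op_dist_attr directly. Written as if it lived next to dist_reshape.py; the class name and the sharding rule are purely illustrative.

from .common import DistributedOperatorImpl
from ..utils import is_dim_shard


class IllustrativeReshapeImpl(DistributedOperatorImpl):
    def is_input_compatible(self, dist_op):
        op_desc = dist_op.serial_op.desc      # the serial op description
        op_dist_attr = dist_op.dist_attr      # its distributed attribute
        x_name = op_desc.input('X')[0]
        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
        # Illustrative rule only: reject a sharded last axis of X.
        return not is_dim_shard(x_dims_mapping[-1])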
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping ...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator): class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedSoftmax, self).__init__() super(DistributedSoftmax, self).__init__()
self._name = name self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax")) register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl): class DistributedSoftmaxImpl(DistributedOperatorImpl):
...@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis') axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis') axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
...@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......
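As a usage note, the renamed entry point above is how an operator's container is now published. A hedged sketch, using a hypothetical "gelu" operator purely for illustration and mirroring the softmax registration in this file:

from .common import DistributedOperatorImplContainer
from .common import register_distributed_operator_impl_container


class DistributedGelu(DistributedOperatorImplContainer):
    def __init__(self, name):
        super(DistributedGelu, self).__init__()
        self._name = name


# Publish the container under the operator type name.
register_distributed_operator_impl_container("gelu", DistributedGelu("gelu"))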
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping ...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator): class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedTranspose2, self).__init__() super(DistributedTranspose2, self).__init__()
self._name = name self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2")) register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl): class DistributedTranspose2Impl(DistributedOperatorImpl):
...@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """
return True return True
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
return True return True
def is_output_compatible(self, op_dist_attr): def update_dims_mapping(self, dist_op):
return True
def update_dims_mapping(self, op_dist_attr):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
......
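The compatibility and update hooks above all reason about dims mappings. A small illustration of the convention they rely on, assuming the usual encoding in which -1 marks a replicated tensor axis and a non-negative value names the process-mesh dimension the axis is sharded along (the values below are made up):

from ..utils import is_dim_shard
from ..utils import is_dim_replicate

# A 2-D tensor whose first axis is sharded along mesh dimension 0 and
# whose second axis is replicated on every process.
x_dims_mapping = [0, -1]

assert is_dim_shard(x_dims_mapping[0])
assert is_dim_replicate(x_dims_mapping[1])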
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
import paddle import paddle
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core import paddle.fluid.core as core
from .context import DistributedContext from .dist_context import DistributedContext
from .context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner from .partitioner import Partitioner
from .process import get_all_process_groups from .process_group import get_all_process_groups
from .utils import make_data_unshard from .utils import make_data_unshard
from .reshard import reshard from .reshard import reshard
...@@ -70,7 +70,6 @@ class AutoParallelizer: ...@@ -70,7 +70,6 @@ class AutoParallelizer:
# Annotation completion # Annotation completion
completed_main_program = complete_annotation( completed_main_program = complete_annotation(
self._original_main_program, self._dist_context) self._original_main_program, self._dist_context)
# Logical partition # Logical partition
rank = paddle.distributed.get_rank() rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
......
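The import block above reflects the module renames (context becomes dist_context, process becomes process_group). A hedged sketch of how downstream code would pick up the new entry points, assuming the package keeps its paddle.distributed.auto_parallel location:

from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_group import get_all_process_groups

# The default distributed context holds the dist attributes gathered
# during annotation completion and partitioning.
dist_context = get_default_distributed_context()

# Process groups created during parallelization can be enumerated,
# e.g. for instantiation or debugging.
for process_group in get_all_process_groups():
    print(process_group)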
...@@ -19,62 +19,32 @@ from ..collective import _new_ring_id ...@@ -19,62 +19,32 @@ from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None _g_process_group_map = {}
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
def get_all_process_groups(): def get_all_process_groups():
global PROCESS_GROUP_MAP global _g_process_group_map
return PROCESS_GROUP_MAP.values() return _g_process_group_map.values()
def new_process_group(ranks): def new_process_group(ranks):
global PROCESS_GROUP_MAP global _g_process_group_map
if not PROCESS_GROUP_MAP: if not _g_process_group_map:
genv = _get_global_env() genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup( _g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size))) 0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map # A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks))) key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP: if key not in _g_process_group_map:
num_groups = len(PROCESS_GROUP_MAP) num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation # Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id() # so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1 group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks) pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg _g_process_group_map[key] = pg
return pg return pg
else: else:
pg = PROCESS_GROUP_MAP[key] pg = _g_process_group_map[key]
return pg return pg
......
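Behaviorally, the renamed _g_process_group_map keeps the old deduplication scheme: groups are keyed by their sorted rank list, and an implicit "global_group" covering the whole world is created on first use. A sketch of the expected behavior in an already-initialized distributed environment (the ranks are illustrative):

pg_a = new_process_group([0, 1, 2, 3])
pg_b = new_process_group([3, 2, 1, 0])   # same ranks, different order

# Both calls resolve to the same entry because the key is built from the
# sorted ranks, so no duplicate group (or ring id) is allocated.
assert pg_a is pg_b

# The registry also holds the implicit "global_group" created on the
# first call to new_process_group.
assert len(list(get_all_process_groups())) >= 2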
(The diffs of two further files in this commit are collapsed and not shown here.)