Unverified commit a02532b5, authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the interface and the underlying mechanisms (#36617)

* default dist op

* add dist_attr for dist op

* add unitest

* update inputname

* update function name

* add unitest

* update CMakeLists.txt for CI

* fix dis_matmul

* fix compile error

* update matmul to matmul_v2

* unify api

* unify api

* todo

* update distop forward func

* update distop forward func

* auto parallel backward

* update dist op

* autoparallel backward

* add backward for embedding

* temp1

* temp2

* temp3

* temp4

* backward done1

* backward done2

* backward done3

* dist embedding remove mp mode

* dist matmul remove mp mode

* update dist embedding

* dist op init1

* dist op init 2

* update unitest

* context remove parallel mode

* partitioner remove parallel mode

* update unitest

* a more general method to support varying mesh in pipeline parallel

* support varying mesh in pipeline parallel

* embedding support varying mesh in pipeline parallel

* matmul support varying mesh in pipeline parallel

* default dist op support varying mesh in pipeline parallel

* dist attribute for startup program

* default dist op support varying mesh in pipeline parallel 2

* partitioner support varying mesh in pipeline parallel

* revise logic for auto completion

* revise framework.py

* revise reshard unitest

* revise unitest for parallelize

* chmod

* fixed bug for dist embedding name mapping

* Improve the interface and the underlying mechanisms of auto parallel

* revise completion for backward

* revise completion for update

* revise completion for update

* update unitest

* chmod

* bugfix for grad_op output var's mesh

* Modify codes for pr 36744

* Remove unnecessary comments in framework.py

* Remove unnecessary comments in completion.py
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Parent 2e40cfb5
......@@ -43,10 +43,6 @@ from .collective import wait # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import shard_tensor # noqa: F401
from .auto_parallel import set_shard_mask # noqa: F401
from .auto_parallel import set_offload_device # noqa: F401
from .auto_parallel import set_pipeline_stage # noqa: F401
from .auto_parallel import ProcessMesh # noqa: F401
from .fleet import BoxPSDataset # noqa: F401
......
......@@ -14,10 +14,11 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
This diff has been collapsed.
......@@ -131,7 +131,7 @@ class TensorCostNode(CostNode):
elif node.dtype == paddle.int64:
self.dtype_factor *= 8
else:
raise NotImplementedError("{} not counted".format(v.node.dtype))
raise NotImplementedError("{} not counted".format(node.dtype))
self.batch_size = None
if batch_size is not None:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid.framework import Variable
from .process_mesh import ProcessMesh
_g_tensor_dist_attr_field_keys = [
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_input_suffix = "@input"
_g_op_output_suffix = "@output"
def get_tensor_dist_attr_field_keys():
global _g_tensor_dist_attr_field_keys
return _g_tensor_dist_attr_field_keys
def get_op_dist_attr_field_keys():
global _g_op_dist_attr_field_keys
return _g_op_dist_attr_field_keys
def append_op_input_suffix(name):
global _g_op_input_suffix
return name + _g_op_input_suffix
def append_op_output_suffix(name):
global _g_op_output_suffix
return name + _g_op_output_suffix
class TensorDistributedAttribute:
def __init__(self):
# The process mesh of distributed operator attribute must be the same as
# the process meshes of all input and output distributed attributes
self._process_mesh = None
self._dims_mapping = None
self._shard_sizes = None
self._device_placement = None
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
@property
def dims_mapping(self):
return self._dims_mapping
@dims_mapping.setter
def dims_mapping(self, dims_mapping):
if dims_mapping is not None:
assert isinstance(dims_mapping, list), \
"The type of dims_mapping must be list."
assert all(isinstance(x, int) for x in dims_mapping), \
("All elements of dims_mapping must be integer")
assert all(x >= -1 for x in dims_mapping), \
("All elements of dims_mapping must be greater than or equal to -1.")
self._dims_mapping = copy.deepcopy(dims_mapping)
@property
def shard_sizes(self):
return self._shard_sizes
@shard_sizes.setter
def shard_sizes(self, shard_sizes):
if shard_sizes is not None:
self._shard_sizes = copy.deepcopy(shard_sizes)
@property
def device_placement(self):
return self._device_placement
@device_placement.setter
def device_placement(self, device_placement):
if device_placement is not None:
self._device_placement = copy.deepcopy(device_placement)
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key in dist_attr.keys():
if key in get_tensor_dist_attr_field_keys():
self.mark_annotated(key)
elif isinstance(dist_attr, TensorDistributedAttribute):
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def clear_annotated(self):
self._is_annotated.clear()
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tdims_mapping ({}): {}".format(annotated_str,
self.dims_mapping)
str += "\n\t}"
return str
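A minimal usage sketch for the new dict-based tensor attribute (not part of this diff; the import path and the nested-list process mesh are assumptions inferred from the surrounding code):

from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute

# Build a tensor dist_attr from a plain dict whose keys come from
# get_tensor_dist_attr_field_keys(), then record which fields the user provided.
tensor_dist_attr = TensorDistributedAttribute()
spec = {"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]}
tensor_dist_attr.init(spec)
tensor_dist_attr.mark_annotated_as(spec)
assert tensor_dist_attr.is_annotated("dims_mapping")
print(tensor_dist_attr)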
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
@property
def impl_type(self):
return self._impl_type
@impl_type.setter
def impl_type(self, impl_type):
if impl_type is not None:
self._impl_type = impl_type
@property
def impl_idx(self):
return self._impl_idx
@impl_idx.setter
def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx
@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
@property
def outputs_dist_attrs(self):
return self._outputs_dist_attrs
def get_input_dist_attr(self, name):
return self._inputs_dist_attrs.get(name, None)
def set_input_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
def set_output_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
dims_mapping = input_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_input_dims_mapping(self, name, dims_mapping):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
input_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._inputs_dist_attrs[name] = dist_attr
def get_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
dims_mapping = output_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_output_dims_mapping(self, name, dims_mapping):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
output_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._outputs_dist_attrs[name] = dist_attr
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init(value)
if dist_attr.get(append_op_input_suffix(key.name), False):
self.set_input_dist_attr(key.name, tensor_dist_attr)
if dist_attr.get(append_op_output_suffix(key.name), False):
self.set_output_dist_attr(key.name, tensor_dist_attr)
else:
if key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, OperatorDistributedAttribute):
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
self.set_input_dist_attr(
tensor_name, dist_attr.get_input_dist_attr(tensor_name))
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
self.set_output_dist_attr(
tensor_name, dist_attr.get_output_dist_attr(tensor_name))
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
for key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
# Make sure the process_meshes in the dist op are the same
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
for tensor_dist_attr in self.outputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
shared_process_mesh = None
for process_mesh in process_meshes:
if process_mesh is not None:
if shared_process_mesh is None:
shared_process_mesh = process_mesh
else:
assert process_mesh == shared_process_mesh, \
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure process_mesh is annotated consistently
self._is_annotated[attr_name] = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
else:
self._is_annotated[attr_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
input_dist_attr = self.get_input_dist_attr(key.name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(value)
output_dist_attr = self.get_output_dist_attr(key.name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(value)
else:
if key in get_op_dist_attr_field_keys():
self.mark_annotated(key)
process_mesh_annotated = False
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.outputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
elif isinstance(dist_attr, OperatorDistributedAttribute):
process_mesh_annotated = False
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
input_dist_attr = self.get_input_dist_attr(tensor_name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(tensor_dist_attr)
if input_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
output_dist_attr = self.get_output_dist_attr(tensor_name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(tensor_dist_attr)
if output_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
def clear_annotated(self):
self._is_annotated.clear()
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
def is_annotated_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
return input_dist_attr.is_annotated("dims_mapping")
else:
return False
def is_annotated_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
return output_dist_attr.is_annotated("dims_mapping")
else:
return False
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\timpl type: {}, ".format(self._impl_type)
str += "impl idx: {}".format(self._impl_idx)
str += "\n\t}"
return str
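A hedged sketch of how the operator-level attribute composes per-tensor attributes (assumed import path; the argument names "X" and "Out" are placeholders only):

from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute

op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.set_input_dims_mapping("X", [0, -1])     # creates an input TensorDistributedAttribute
op_dist_attr.set_output_dims_mapping("Out", [0, -1])  # creates an output TensorDistributedAttribute
# Setting the operator-level mesh propagates it to every input/output attribute.
op_dist_attr.process_mesh = [[0, 1], [2, 3]]
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
print(op_dist_attr)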
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None
def get_default_distributed_context():
global _g_default_distributed_context
if _g_default_distributed_context is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return _g_default_distributed_context
def set_default_distributed_context(dist_context):
global _g_default_distributed_context
_g_default_distributed_context = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self, program=None):
self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
@property
def serial_program(self):
return self._serial_program
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
assert self._serial_program is None, \
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def process_meshes(self):
return self._process_meshes
@property
def dist_op_context(self):
return self._dist_op_context
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of process_mesh must be ProcessMesh.'
if process_mesh not in self.process_meshes:
self._process_meshes.append(process_mesh)
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None)
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_ops_for_program.get(serial_tensor_id, None)
def get_dist_op_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_ops_for_graph.get(serial_tensor_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
# Copy the dist tensors and dist ops annotated by users from the default context
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
for block in self._serial_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
tensor)
if default_dist_tensor and default_ctx is not self:
self.add_dist_tensor_for_program(default_dist_tensor)
current_dist_tensor = self.get_dist_tensor_for_program(tensor)
if current_dist_tensor is None:
dist_tensor = DistributedTensor(tensor)
self.add_dist_tensor_for_program(dist_tensor)
for op in block.ops:
# Copy the distributed operators in the default context
default_dist_op = default_ctx.get_dist_op_for_program(op)
if default_dist_op and default_ctx is not self:
self.add_dist_op_for_program(default_dist_op)
current_dist_op = self.get_dist_op_for_program(op)
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id,
None)
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
dist_op = self._dist_ops_for_program.get(op_id, None)
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
self._dist_ops_for_program.clear()
def clear_dist_info_for_graph(self):
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
node)
dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the mesh dimension it is mapped to,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the mesh dimension it is mapped to,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the mesh dimension it is mapped to,
# we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized_for_program:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
return True
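A hedged end-to-end sketch of the flow described in the docstring above (assumed import path, not part of this diff): user annotations live in the default context and are copied into a per-run context when it is initialized for a program.

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

paddle.enable_static()
x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
y = paddle.static.nn.fc(x, size=4)
main_prog = paddle.static.default_main_program()

# Annotate one tensor in the default context.
default_ctx = get_default_distributed_context()
default_ctx.set_tensor_dist_attr_for_program(
    x, {"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]})

# A per-run context copies the annotated dist tensors/ops during its initialization.
dist_ctx = DistributedContext(main_prog)
dist_ctx.init_dist_attr_for_program()
print(dist_ctx.get_tensor_dist_attr_for_program(x))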
class DistributedOperatorContext:
"""
DistributedOperatorContext is used to create a dist op desc in Program.
Every time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
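A hedged sketch of driving the operator context for a single serial op (assumed import path; the identity varname mapping is for illustration only, while in the real pipeline the partitioner is expected to supply the serial-to-distributed name mapping):

import paddle
from paddle.distributed.auto_parallel.dist_context import DistributedOperatorContext

paddle.enable_static()
x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
y = paddle.static.nn.fc(x, size=8)
block = paddle.static.default_main_program().global_block()

dist_op_ctx = DistributedOperatorContext()
dist_op_ctx.set_dst_main_program(paddle.static.default_main_program())
dist_op_ctx.set_rank_id(0)
# Map every serial var name to itself; a partitioner would map to renamed dist vars.
dist_op_ctx.set_varname_mapping({name: name for name in block.vars})
kinputs, koutputs = dist_op_ctx.prepare_forward_context(block.ops[0])
print(kinputs, koutputs)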
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
self._serial_outputs = {}
self._dist_attr = None
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_op(self):
return self._serial_op
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = OperatorDistributedAttribute()
# Create new dist_attr related to current serial_op
dist_attr = self._filter_dist_attr(dist_attr)
# Append suffix to mark the inputs or outputs
if isinstance(dist_attr, dict):
# Copy the keys since we may add new ones
for key in list(dist_attr.keys()):
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names:
dist_attr[append_op_input_suffix(key.name)] = True
if key.name in self._serial_op.output_arg_names:
dist_attr[append_op_output_suffix(key.name)] = True
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def get_serial_input(self, name):
return self._serial_inputs.get(name, None)
def get_serial_output(self, name):
return self._serial_outputs.get(name, None)
def _init_default_dist_attr(self):
for tensor_name in self._serial_op.input_arg_names:
if self._serial_op.type == "create_py_reader":
tensor = None
else:
tensor = self._serial_op.block._var_recursive(tensor_name)
self._serial_inputs[tensor_name] = tensor
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
return None
new_dist_attr = None
if isinstance(dist_attr, dict):
new_dist_attr = {}
for key, value in dist_attr.items():
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \
or key.name in self._serial_op.output_arg_names:
new_dist_attr[key] = value
else:
new_dist_attr[key] = value
elif isinstance(dist_attr, OperatorDistributedAttribute):
new_dist_attr = copy.deepcopy(dist_attr)
new_dist_attr._inputs_dist_attrs.clear()
new_dist_attr._outputs_dist_attrs.clear()
for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name,
tensor_dist_attr)
for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name,
tensor_dist_attr)
else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type:
return True
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
return False
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(),
self.serial_op.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_input(arg_name) is not None:
if self.get_serial_input(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_output(arg_name) is not None:
if self.get_serial_output(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
return str
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops)
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable):
output = [output]
return list(output)
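A hedged sketch of wrapping a serial op in a DistributedOperator with a dict-style dist_attr keyed by Variable, which the setter above supports (assumed import path, not part of this diff):

import paddle
from paddle.distributed.auto_parallel.dist_op import DistributedOperator

paddle.enable_static()
x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
y = paddle.matmul(x, x, transpose_y=True)
op = paddle.static.default_main_program().global_block().ops[0]

# Variable keys carry per-tensor attributes; string keys carry op-level fields.
dist_op = DistributedOperator(op, {
    "process_mesh": [[0, 1], [2, 3]],
    x: {"dims_mapping": [0, -1]},
})
print(dist_op.dist_attr)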
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys
class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None):
self._serial_tensor = serial_tensor
self._dist_attr = None
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_tensor(self):
return self._serial_tensor
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = TensorDistributedAttribute()
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
return False
for i in range(len(self.dist_attr.dims_mapping)):
if self.dist_attr.dims_mapping[
i] < -1 or self.dist_attr.dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
str += ", is_parameter: {}".format(self.serial_tensor.is_parameter)
if self.dist_attr.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.dist_attr.dims_mapping)
if self.dist_attr.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str, None)
if self.dist_attr.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str, None)
return str
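A hedged sketch (assumed import path, not part of this diff) of annotating a serial tensor and validating the resulting mapping, assuming the new ProcessMesh exposes a [2, 2] topology for this mesh as the validation code above expects:

import paddle
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor

paddle.enable_static()
x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")

# Shard dim 0 of x across dim 0 of a 2x2 process mesh; dim 1 stays replicated.
dist_tensor = DistributedTensor(
    x, {"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]})
assert dist_tensor.validate_dist_attr()
print(dist_tensor.dist_attr)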
......@@ -18,293 +18,34 @@ import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
__all__ = []
# a map from ProcessMesh ids to the ProcessMesh instances
_g_process_mesh_map = dict()
# user defined map from logical process ids to physical ones
_user_defined_physical_map = None
def _append_attr_suffix(name):
"""
Append auto parallel suffix for distributed attribute name.
"""
return name + core.kAutoParallelSuffix()
def _remove_attr_suffix(name):
"""
Remove auto parallel suffix from distributed attribute name.
"""
return name.strip(core.kAutoParallelSuffix())
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
def _static_mode_check():
if in_dygraph_mode():
raise RuntimeError("Auto-parallel only supports static mode, "
"please use paddle.enable_static().")
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
class ProcessMesh(object):
r"""
The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represents a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) that describes the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
parent (ProcessMesh, optional): the parent ProcessMesh. None means
the ProcessMesh is the root one without parent ProcessMesh.
Default: None.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.parent is None
assert mesh.topology == [2, 3]
assert mesh.process_group == [2, 4, 5, 0, 1, 3]
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
def __init__(self, mesh, parent=None):
_static_mode_check()
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
self._topology = _get_nested_list_shape(mesh)
self._processes = _flatten_nested_list(mesh)
# Every element of mesh must be >= 0.
assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.')
unique_ids = set(self._processes)
assert len(unique_ids) == len(self._processes), (
'All elements of mesh must be unique.')
if parent is None:
# For root ProcessMesh, the ids of logical processes must be range
# from 0 to N-1, where N is the number of logical processes.
assert max(self._processes) == len(self._processes) - 1, (
'For root ProcessMesh, ids of logical processes must be range '
'from 0 to N-1, where N is the number of logical processes.')
parent_id = core.kNoneProcessMeshIndex()
assert len(_g_process_mesh_map.keys()) == 0, (
'The first ProcessMesh must be the root, which has no parent.')
else:
assert len(_g_process_mesh_map.keys()) > 0, (
'All ProcessMesh must have a parent except the root one.')
assert isinstance(parent, ProcessMesh), (
'parent must be an instance of ProcessMesh.')
parent_id = parent._desc.id
# All elements in mesh must belong to its parent
parent_ids = set(parent.process_group)
assert unique_ids <= parent_ids, (
'All elements in mesh must belong to its parent.')
self._desc = core.ProcessMeshDesc(self._topology, self._processes,
parent_id)
self._id = self._desc.id
self._parent_id = parent_id
assert self._id not in _g_process_mesh_map, (
"The ProcessMesh with id %d already exists." % self._id)
_g_process_mesh_map[self._id] = self
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
"""
return self._topology
@property
def process_group(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def parent(self):
r"""
Get the parent ProcessMesh.
"""
if self._parent_id == core.kNoneProcessMeshIndex(): return None
assert self._parent_id in _g_process_mesh_map, (
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimensions of this ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the mapping from logical processes to physical ones using the
user-defined order.
Args:
order (list): order of the physical process ids.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
assert self.parent is None, (
"This function can only be called by the root ProcessMesh.")
unique_ids = set(order)
assert isinstance(order, list)
assert len(unique_ids) == len(order), (
"All elements in order must be unique.")
assert min(order) == 0
assert max(order) == len(order) - 1, (
"All elements in order must be from 0 to N - 1, where N "
"is the number of physical processes.")
logical_order = self.process_group
global _user_defined_physical_map
assert _user_defined_physical_map is None, (
"This function can only be called once.")
_user_defined_physical_map = dict()
assert len(logical_order) == len(order)
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process meshes in _g_process_mesh_map, making it empty.
"""
global _g_process_mesh_map
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
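# A minimal sketch (assuming static mode is enabled and no other mesh has been
# created yet) of a root mesh and a sub-mesh built from the class above: a
# sub-mesh's ids must be a subset of its parent's ids, while only the root mesh
# must use the contiguous ids 0..N-1.
root_mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]])
sub_mesh = ProcessMesh([0, 1, 2], parent=root_mesh)
assert sub_mesh.parent is root_mesh and sub_mesh.ndim == 1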
raise RuntimeError("Auto-parallel only supports static mode for now, "
"please use paddle.enable_static() first.")
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert isinstance(mesh,
ProcessMesh), 'The type of mesh must be ProcessMesh.'
assert isinstance(dim_mapping,
list), 'The type of dim_mapping must be list.'
assert len(tensor.shape) == len(dim_mapping), (
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.')
mesh_dim = len(mesh.topology)
dim_set = set()
for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (
dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
'Each element '
'in dim_mapping must be greater than or equal to zero and less '
'than the number of dimensions of the mesh, or it must be -1.')
if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i])
def shard_tensor(x, mesh, dim_mapping):
def shard_tensor(x, dist_attr=None):
"""
Add distributed attributes for a tensor.
Args:
x (Tensor): the tensor to process.
mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes.
dim_mapping (list): a list to describe the mapping between `x` and `mesh`,
the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means
without partition along the corresponding dimension.
x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow:
"process_mesh": a nested list an to describe the mesh topology of logical processes.
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that tensor dimension is not split.
Both process_mesh and dims_mapping are optional and users can specify them as needed.
Returns:
Tensor: the tensor `x` itself.
Tensor: the tensor `x` annotated with distributed attributes.
Examples:
.. code-block:: python
......@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping):
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [0, -1])
"""
_static_mode_check()
_dim_mapping_checker(x, mesh, dim_mapping)
attr_name = _append_attr_suffix('mesh_id')
x._set_attr(attr_name, mesh._id)
attr_name = _append_attr_suffix('dim_mapping')
x._set_attr(attr_name, dim_mapping)
return x
def set_shard_mask(x, mask):
"""
Set the mask for a tensor, which masks out the tensor from some processes in its mesh.
Args:
x (Tensor): the tensor to process.
mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to
the tensor `x`. Every value of `mask` must be one or zero, where one means
the tensor `x` will be put on the corresponding logical process and zero means the tensor `x`
will not be put on the corresponding logical process.
For example, for a ProcessMesh represented by the 2-dimensional
array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the
2-dimensional [[1, 0, 1], [0, 1, 0]],
then the tensor `x` will only be put on logical processes 2, 5 and 1.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1])
dist.set_shard_mask(x, mask)
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
"dims_mapping": [0, -1]})
"""
_static_mode_check()
assert isinstance(mask, list)
np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask)
mesh_attr_name = _append_attr_suffix('mesh_id')
assert x._has_attr(mesh_attr_name), \
"Please set process mesh for the variable firstly."
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be None, dict or TensorDistributedAttribute."
dist_tensor = DistributedTensor(x, dist_attr)
dist_tensor.dist_attr.mark_annotated_as(dist_attr)
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
return x
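# A minimal sketch (static mode assumed, mirroring the docstring example above)
# of annotating a tensor with only part of the distributed attributes; since
# both "process_mesh" and "dims_mapping" are optional, the unspecified one is
# left to be filled in later by the completion pass. The tensor name `w` is an
# illustrative assumption.
import paddle
import paddle.distributed as dist
paddle.enable_static()
w = paddle.ones([6, 8])
dist.shard_tensor(w, dist_attr={"dims_mapping": [-1, 0]})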
def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
def shard_op(op_fn, dist_attr=None):
"""
Call a function and add distributed attributes for the ops added by the function.
Args:
op_fn (callable): a callable object of an API.
mesh (ProcessMesh): an instance of ProcessMesh specifies the topology of logical processes.
dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping.
The dim_mapping is a list to describe the mapping between a tensor and `mesh`,
the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`,
where -1 means without partition along the corresponding dimension.
kwargs (dict): a dict of parameter passed to the function `op_fn`.
op_fn (callable): a callable operator or module to be sharded.
dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
two categories. The first category describes the distributed attributes shared by all inputs and
outputs, and only `process_mesh` can be specified now. The second category describes the distributed
attributes of individual inputs or outputs, in the same form as the `dist_attr` of `shard_tensor`. All of them are
optional and users can specify them as needed. Note that the `process_mesh` of an operator must be the
same as the process_mesh of its inputs and outputs.
Returns:
list: the outputs of the function `op_fn`.
list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Examples:
.. code-block:: python
......@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
kwargs = {'x': x, 'y': y}
dist.shard_op(paddle.add, mesh, None, **kwargs)
"""
_static_mode_check()
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = op_fn(**kwargs)
new_op_size = len(main_block.ops)
if dim_mapping_dict is None:
dim_mapping_dict = dict()
else:
assert isinstance(dim_mapping_dict,
dict), 'The type of dim_mapping_dict must be dict.'
for var_name in dim_mapping_dict.keys():
dim_mapping = dim_mapping_dict[var_name]
tensor = main_block.var(var_name)
_dim_mapping_checker(tensor, mesh, dim_mapping)
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id')
op._set_attr(attr_name, mesh._id)
for var_name in dim_mapping_dict.keys():
assert var_name in op.output_arg_names + op.input_arg_names
attr_name = _append_attr_suffix(var_name)
if var_name in op.input_arg_names:
# we use the prefix "IN_" to indicates an input argument name
attr_name = "IN_" + attr_name
else:
# we use the prefix "OUT_" to indicates an input argument name
attr_name = "OUT_" + attr_name
op._set_attr(attr_name, dim_mapping_dict[var_name])
if isinstance(output, Variable):
output = [output]
return list(output)
def set_offload_device(x, device):
"""
Set the device that the tensor `x` will be put on.
Args:
x (tensor): the tensor to process.
device (str): the device that the tensor `x` will be put on, e.g., 'cpu'.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
x = paddle.ones([4, 6])
dist.set_offload_device(x, 'cpu')
"""
_static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
def set_pipeline_stage(stage):
"""
Set the pipeline stage of the following ops.
Args:
stage (int): the pipeline stage that the following ops belong to.
Returns:
None.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
dist.set_pipeline_stage(0)
dist_add = dist.shard_op(paddle.add,
dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
x: {"dims_mapping": [-1, 0]},
y: {"dims_mapping": [0, -1]}
})
dist_add(x, y)
"""
from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check()
assert isinstance(stage, int), 'The type of stage must be int.'
_set_pipeline_stage(stage)
assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding
......
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License
DISTRIBUTED_OPERATORS = {}
_g_distributed_operator_impl_registries = {}
class DistributedOperator:
class DistributedOperatorImplContainer:
def __init__(self):
self._impls = []
self._name = None
......@@ -47,67 +47,60 @@ class DistributedOperatorImpl:
def get_name(self):
return self._name
def is_process_mesh_compatible(self, op_dist_attr):
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def is_compatible(self, dist_op):
return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op):
global DISTRIBUTED_OPERATORS
DISTRIBUTED_OPERATORS[name] = dist_op
def register_distributed_operator_impl_container(name, dist_op_impl_container):
global _g_distributed_operator_impl_registries
_g_distributed_operator_impl_registries[name] = dist_op_impl_container
def get_distributed_operator(name):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS.get(name, None)
def get_distributed_operator_impl_container(name):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries.get(name, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name)
if dist_op is not None:
dist_op.register_impl(dist_impl)
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is not None:
dist_op_impl_container.register_impl(dist_impl)
else:
assert False, "Must register distributed operator first."
assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx)
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
fwd=True):
def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
"""
For now, just return the first compatible implementation.
This will be improved by a cost model in the future.
"""
dist_op = get_distributed_operator(name)
if dist_op is None:
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is None:
return None, -1
compatible_impls = []
impls = dist_op.get_impls()
impls = dist_op_impl_container.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_input_compatible(op_dist_attr):
if impl.is_input_compatible(dist_op):
compatible_impls.append((impl, idx))
else:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_output_compatible(op_dist_attr):
if impl.is_output_compatible(dist_op):
compatible_impls.append((impl, idx))
if compatible_impls:
......@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
return best_compatible_impl, idx
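# A minimal sketch of how the registry above is typically used (the container
# and impl classes are registered elsewhere in this package; the impl name
# string and the dist_op object are assumptions for illustration, where dist_op
# stands for the wrapper holding a serial op together with its dist_attr):
#
#   register_distributed_operator_impl_container("softmax", DistributedSoftmax("softmax"))
#   register_distributed_operator_impl("softmax", DistributedSoftmaxImpl("replicate_softmax"))
#   impl, impl_idx = find_best_compatible_distributed_operator_impl("softmax", dist_op, fwd=True)
#
# When no container is registered for the op type, (None, -1) is returned.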
def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
"""
copy src var's dist_attr to dst var
"""
import copy
# def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var):
# """
# copy src var's dist_attr to dst var
# """
# import copy
auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# dist_attr = copy.deepcopy(
# auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
# dist_attr._owner_tensor = var
# dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
# src_var)._owner_context
# auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
"""
copy src var's dist_attr to dst var
"""
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
# """
# copy src op's dist_attr to dst dist op
# """
# from ..attribute import OperatorDistributedAttribute
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
# auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
# op_dist_attr)
# auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
# op_dist_attr)
# op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
# op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
# for input_varname in dist_op.desc.input_arg_names():
# input_var = dst_block.var(input_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# input_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
# for output_varname in dist_op.desc.output_arg_names():
# output_var = dst_block.var(output_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# output_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_output_dims_mapping(output_varname,
# tensor_dims_mapping)
def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
src_op_dist_attr):
"""
copy src op's dist_attr to dst dist op
"""
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
# need to check the dist op attr and its inputs and outputs
auto_paralle_context = src_op_dist_attr.get_owner_context()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)
op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
op_dist_attr.impl_idx = src_op_dist_attr.impl_idx
for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping)
op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -22,26 +22,27 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperator):
class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedDefault, self).__init__()
self._name = name
register_distributed_operator("default", DistributedDefault("default"))
register_distributed_operator_impl_container("default",
DistributedDefault("default"))
# Replicated Default
# Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__()
......@@ -49,29 +50,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
varname_mapping = dist_op_helper.get_varname_mapping()
rank_id = dist_op_helper.get_rank_id()
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
......@@ -100,26 +98,26 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for varname in dist_op_desc.input_arg_names():
if startup_block.has_var(varname) and startup_block.var(
varname
).is_parameter and varname not in dist_op_helper.already_init_sync_vars:
dist_op_helper.already_init_sync_vars.add(varname)
).is_parameter and varname not in dist_op_context.already_init_sync_vars:
dist_op_context.already_init_sync_vars.add(varname)
param = startup_block.var(varname)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(
param)
process_mesh = param_dist_attr.get_process_mesh()
dims_mapping = param_dist_attr.get_dims_mapping()
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dims_mapping = param_dist_attr.dims_mapping
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
# NOTE all not splited axis should be presented in mesh
# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dims_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
axis, rank_id)
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
axis, rank_id)
sync_group = new_process_group(group_ranks)
new_op = startup_block.append_op(
......@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
})
# set distributed attribute
op_attr = OperatorDistributedAttribute(new_op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(param.name,
dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(new_op, op_attr)
ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
......@@ -147,16 +145,16 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs):
# by now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
rank_id = dist_op_helper.get_rank_id()
rank_id = dist_op_context.get_rank_id()
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
# if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output
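# An illustrative reading of the check below (numbers are not from the source):
# with process_mesh.topology == [2, 3] and a non-parameter input whose
# dims_mapping is [0, -1], batch_size_axis is 0 and mesh_shape[0] == 2 > 1, so
# the parameter gradients produced by this op need an allreduce over a
# data-parallel group of 2 ranks.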
need_gradient_allreduce = False
for input_name in backward_op.desc.input_names():
......@@ -165,20 +163,21 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname).is_parameter:
# NOTE the input var's dim_mapping of the backward op should be the same as that of the input var itself, instead of that of the corresponding varname of the forward op
process_mesh = dist_attr.get_process_mesh()
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
batch_size_axis, rank_id)
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
break
......@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward
})
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp()
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
......@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperator):
class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
register_distributed_operator("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator("c_embedding", DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding"))
# RowParallel
......@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
......@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# Other dimensions must be replicate except the batch dimension
......@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
......@@ -111,16 +111,16 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
# check validation of inputs / outputs
# check validation of inputs / outputs
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
assert 'W' in kwargs, "input [{}] is not given".format('W')
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
......@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# A generalized method to calculate the embedding offset using cartesian product
......@@ -162,7 +162,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
per_part_size = Weight_var.shape[0]
relative_idx = relative_idx * per_part_size
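# Illustrative numbers (not from the source): if the embedding table is split
# row-wise into parts of per_part_size == 1000 rows and this rank holds part
# index 2 along the sharded mesh axis, relative_idx becomes 2000, i.e. the
# starting row offset used for the local sharded lookup.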
# TODO calculate ring id
# TODO calculate ring id
parallel_axis = embedding_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
parallel_axis, rank_id)
......@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
check_variable_and_dtype(
Out_var, 'tensor',
......@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr)
# param initialization sync
assert Weight_var.name not in dist_op_helper.already_init_sync_vars
dist_op_helper.already_init_sync_vars.add(Weight_var.name)
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh()
dim_mapping = param_dist_attr.get_dims_mapping()
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping
# NOTE all not splited axis should be presented in mesh
# NOTE all not splited axis should be presented in mesh
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)
......@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs):
# by now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(),
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id)
# check if need gradient allreduce
......@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0])
process_mesh = dist_attr.get_process_mesh()
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
......@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
W_Grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
ctx.set_op_dist_attr_for_program(op, op_attr)
register_distributed_operator_impl("lookup_table_v2",
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator):
class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
......@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......@@ -103,15 +103,15 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
# check validation of inputs / outputs
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
......@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......@@ -231,15 +230,15 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
# check validation of inputs / outputs
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
assert input_name in kwargs, "input [{}] is not given".format(
input_name)
......@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator):
class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax"))
register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
......@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator):
class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2"))
register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
......@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
def is_input_compatible(self, dist_op):
return True
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
return True
def is_output_compatible(self, op_dist_attr):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......
......@@ -15,11 +15,11 @@
import paddle
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from .context import DistributedContext
from .context import get_default_distributed_context
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process import get_all_process_groups
from .process_group import get_all_process_groups
from .utils import make_data_unshard
from .reshard import reshard
......@@ -70,7 +70,6 @@ class AutoParallelizer:
# Annotation completion
completed_main_program = complete_annotation(
self._original_main_program, self._dist_context)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
......
......@@ -19,62 +19,32 @@ from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
_g_process_group_map = {}
def get_all_process_groups():
global PROCESS_GROUP_MAP
return PROCESS_GROUP_MAP.values()
global _g_process_group_map
return _g_process_group_map.values()
def new_process_group(ranks):
global PROCESS_GROUP_MAP
if not PROCESS_GROUP_MAP:
global _g_process_group_map
if not _g_process_group_map:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP:
num_groups = len(PROCESS_GROUP_MAP)
if key not in _g_process_group_map:
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg
_g_process_group_map[key] = pg
return pg
else:
pg = PROCESS_GROUP_MAP[key]
pg = _g_process_group_map[key]
return pg
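# A minimal sketch (assuming the global distributed env has been initialized)
# of the deduplication above: ranks are sorted before building the key, so the
# same set of ranks always maps to the same ProcessGroup instance.
pg_a = new_process_group([2, 0, 1])
pg_b = new_process_group([0, 1, 2])
assert pg_a is pg_b   # both calls build the key '012'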
......