Unverified commit a02532b5 authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the interface and the underlying mechanisms (#36617)

* default dist op

* add dist_attr for dist op

* add unitest

* update inputname

* update function name

* add unitest

* update CMakeLists.txt for CI

* fix dis_matmul

* fix compile error

* update matmul to matmul_v2

* unify api

* unify api

* todo

* update distop forward func

* update distop forward func

* auto parallel backward

* update dist op

* autoparallel backward

* add backward for embedding

* temp1

* temp2

* temp3

* temp4

* backward done1

* backward done2

* backward done3

* dist embedding remove mp mode

* dist matmul remove mp mode

* update dist embedding

* dist op init1

* dist op init 2

* update unitest

* context remove parallel mode

* partitioner remove parallel mode

* update unitest

* a more general method to support varying mesh in pipeline parallel

* support varying mesh in pipeline parallel

* embedding support varying mesh in pipeline parallel

* matmul support varying mesh in pipeline parallel

* default dist op support varying mesh in pipeline parallel

* dist attribute for startup program

* default dist op support varying mesh in pipeline parallel 2

* partitioner support varying mesh in pipeline parallel

* revise logic for auto completion

* revise framework.py

* revise reshard unitest

* revise unitest for parallelize

* chmod

* fixed bug for dist embedding name mapping

* Improve the interface and the underlying mechanisms of auto parallel

* revise completion for backward

* revise completion for update

* revise completion for update

* update unitest

* chmod

* bugfix for grad_op output var's mesh

* Modify codes for pr 36744

* Remove unnecessary comments in framework.py

* Remove unnecessary comments in completion.py
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Parent 2e40cfb5
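
The patch drops the old annotation helpers (set_shard_mask, set_offload_device and set_pipeline_stage disappear from the public exports below) and centers the interface on a ProcessMesh plus a per-tensor dims_mapping, which maps every tensor dimension either to a mesh dimension or to -1 for "replicated". The following standalone sketch is not part of the patch; it only illustrates, under an even-divisibility assumption, what such a mapping implies for the per-rank shard shape.

def local_shard_shape(global_shape, mesh_topology, dims_mapping):
    # Illustrative helper (not from the patch): shard shape implied by a dims_mapping.
    local_shape = []
    for dim_size, mapping in zip(global_shape, dims_mapping):
        if mapping == -1:                     # -1 means replicated along this dimension
            local_shape.append(dim_size)
        else:                                 # sharded across mesh dimension `mapping`
            assert dim_size % mesh_topology[mapping] == 0, "assumes an even split"
            local_shape.append(dim_size // mesh_topology[mapping])
    return local_shape

# A [8, 16] tensor on a 2x4 mesh, rows split over mesh dim 0, columns replicated.
print(local_shard_shape([8, 16], [2, 4], [0, -1]))    # -> [4, 16]
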
@@ -43,10 +43,6 @@ from .collective import wait # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import shard_tensor # noqa: F401
from .auto_parallel import set_shard_mask # noqa: F401
from .auto_parallel import set_offload_device # noqa: F401
from .auto_parallel import set_pipeline_stage # noqa: F401
from .auto_parallel import ProcessMesh # noqa: F401
from .fleet import BoxPSDataset # noqa: F401
......
@@ -14,10 +14,11 @@
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .process_mesh import ProcessMesh
# from .interface import set_shard_mask # noqa: F401
# from .interface import set_offload_device # noqa: F401
# from .interface import set_pipeline_stage # noqa: F401
# from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
from .completion import complete_backward_annotation # noqa: F401
from .reshard import reshard # noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
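
The is_valid method above boils down to three constraints on a dims_mapping: it has one entry per tensor dimension, every entry lies in [-1, len(process_mesh.topology) - 1], and no mesh dimension is claimed by more than one tensor dimension. The last two, restated as a hypothetical standalone helper with the same logic:

def dims_mapping_is_valid(dims_mapping, mesh_topology):
    # Same rule as TensorDistributedAttribute.is_valid, without the owner plumbing.
    n_mesh_dims = len(mesh_topology)
    for m in dims_mapping:
        if m < -1 or m >= n_mesh_dims:        # each entry within [-1, n_mesh_dims)
            return False
    for mesh_dim in range(n_mesh_dims):       # each mesh dim used at most once
        if dims_mapping.count(mesh_dim) > 1:
            return False
    return True

assert dims_mapping_is_valid([0, -1], [2, 4])
assert not dims_mapping_is_valid([0, 0], [2, 4])      # mesh dim 0 claimed twice
assert not dims_mapping_is_valid([2, -1], [2, 4])     # 2-D mesh has no dim 2
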
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
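
Both attribute classes override __deepcopy__ so that the owning tensor/op and the owning context are shared rather than copied, while every other field goes through copy.deepcopy with the memo dict. A self-contained sketch of the same pattern, using hypothetical names:

import copy

class Owned:
    # Deep copies share the back-reference to the owner but duplicate everything else.
    def __init__(self, owner, payload):
        self._owner = owner
        self._payload = payload

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result               # register first so cycles resolve to `result`
        for k, v in self.__dict__.items():
            if k == "_owner":                 # share, do not copy, the owner reference
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

owner = object()
a = Owned(owner, {"dims_mapping": [0, -1]})
b = copy.deepcopy(a)
assert b._owner is a._owner and b._payload is not a._payload
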
@@ -20,10 +20,13 @@ from paddle.fluid import framework
from .utils import compute_compatible_process_mesh
from .utils import compute_compatible_dim_mapping
from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_distributed_attr
from .context import get_default_distributed_context
from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from paddle.distributed.fleet.meta_optimizers.common import OpRole
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
@@ -43,36 +46,35 @@ def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
process meshes are compatible for now.
"""
changed = False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"):
return changed
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
tensor_process_mesh = tensor_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
op_process_mesh = op_dist_attr.process_mesh
inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
op_process_mesh = op_dist_attr.process_mesh
outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
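
compute_compatible_process_mesh, imported at the top of this file, is not part of this diff. A plausible stand-in, under the simplifying assumption that "compatible" just means all annotated meshes agree, is sketched below; the real utility may implement a richer rule.

def compute_compatible_process_mesh(process_mesh_list):
    # Assumed semantics: ignore None entries, require the rest to be identical.
    compatible = None
    for mesh in process_mesh_list:
        if mesh is None:
            continue
        if compatible is None:
            compatible = mesh
        elif mesh != compatible:              # conflicting annotations: no compatible mesh
            return None
    return compatible

assert compute_compatible_process_mesh([None, [0, 1], [0, 1]]) == [0, 1]
assert compute_compatible_process_mesh([[0, 1], [2, 3]]) is None
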
@@ -84,43 +86,47 @@ def update_op_node_process_mesh(dist_context, op_node, fwd=True):
process meshes are compatible for now.
"""
changed = False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"):
return changed
op_process_mesh = op_dist_attr.get_process_mesh()
op_process_mesh = op_dist_attr.process_mesh
if fwd:
inputs_process_meshes = []
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
tensor_process_mesh = tensor_dist_attr.process_mesh
inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
else:
outputs_process_meshes = []
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
tensor_process_mesh = tensor_dist_attr.process_mesh
outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
op_dist_attr.process_mesh = compatible_process_mesh
changed = True
return changed
def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node):
"""Each operator has a default distributed operator, only allowed to be sharded in batch dimension."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
dist_op = dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
@@ -130,7 +136,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
@@ -140,7 +147,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
@@ -164,14 +172,16 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
@@ -186,10 +196,13 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
return changed
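
The default implementation handled above allows sharding only along the batch (first) dimension: the dims_mapping[0] of every non-parameter input and output is collected, unified, and written back, while all other dimensions stay replicated. A condensed, hypothetical restatement of that rule:

def unify_batch_dim(arg_dims_mappings):
    # arg_dims_mappings: dims_mapping lists of the non-parameter inputs/outputs.
    batch_dims = [m[0] for m in arg_dims_mappings]
    sharded = {d for d in batch_dims if d != -1}
    assert len(sharded) <= 1, "conflicting batch-dimension shardings"
    unified = sharded.pop() if sharded else -1
    changed = False
    for m in arg_dims_mappings:
        assert all(d == -1 for d in m[1:]), "default impl keeps non-batch dims replicated"
        if m[0] != unified:
            m[0] = unified
            changed = True
    return changed

mappings = [[0, -1], [-1, -1]]
assert unify_batch_dim(mappings) and mappings == [[0, -1], [0, -1]]
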
def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr):
def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node):
"""Element-wise operator can be sharded in any way (but should take care of broadcasting)."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
@@ -258,12 +271,11 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
# Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER:
return False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_dims_mapping = tensor_dist_attr.dims_mapping
if fwd:
dims_mapping_list = []
for pred_op_node in tensor_node.inputs:
@@ -272,7 +284,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
@@ -282,7 +294,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
tensor_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
else:
dims_mapping_list = []
@@ -292,7 +304,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read":
continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
@@ -302,7 +314,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
tensor_dist_attr.dims_mapping = compatible_dims_mapping
changed = True
return changed
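
The propagation above and below leans on compute_compatible_dims_mapping, whose body is not shown in this diff. The sketch here assumes the natural merge rule: -1 acts as a wildcard, two different concrete mesh dimensions conflict, and a conflict (or a length mismatch) yields None.

def compute_compatible_dims_mapping(dims_mapping_list):
    # Assumed semantics of the utility used above; the real one may differ in details.
    dims_mapping_list = [m for m in dims_mapping_list if m is not None]
    if not dims_mapping_list:
        return None
    length = len(dims_mapping_list[0])
    if any(len(m) != length for m in dims_mapping_list):
        return None
    merged = []
    for dim_values in zip(*dims_mapping_list):
        concrete = {v for v in dim_values if v != -1}
        if len(concrete) > 1:                 # mapped to two different mesh dims
            return None
        merged.append(concrete.pop() if concrete else -1)
    return merged

assert compute_compatible_dims_mapping([[0, -1], [-1, -1]]) == [0, -1]
assert compute_compatible_dims_mapping([[0, -1], [1, -1]]) is None
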
@@ -317,7 +329,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read":
return False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
dist_op = dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
@@ -327,9 +340,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
@@ -341,26 +354,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = True
# Find the most compatible implementation from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=True)
op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
@@ -370,9 +386,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
@@ -384,26 +400,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = True
# Find the most compatible implementation from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=False)
op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
return changed
@@ -421,18 +440,20 @@ def complete_annotation(program, dist_context=None):
# Use the default distributed context for completion if none is given
if dist_context is None:
dist_context = get_default_distributed_context()
dist_context.serial_program = program
else:
dist_context.serial_program = program
# Initialize distributed attributes for all var and op node in program
dist_context.initialize_distributed_attr_for_program(program)
# print_program_with_dist_attr(program, dist_context)
# Convert program to graph
graph = framework.IrGraph(core.Graph(program.desc))
# Initialize distributed attributes for all var and op node in program
dist_context.init_dist_attr_for_program()
# Initialize distributed attributes for all var and op node in graph
dist_context.initialize_distributed_attr_for_graph(graph)
dist_context.init_dist_attr_for_graph()
# Complete process mesh for each node
all_nodes = list(graph.all_nodes())
all_nodes = list(dist_context.serial_graph.all_nodes())
def sort_key_fun(node):
first = -1
@@ -498,27 +519,27 @@ def complete_annotation(program, dist_context=None):
is_wrong = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
node)
if tensor_dist_attr.get_process_mesh() is None:
if tensor_dist_attr.process_mesh is None:
msg_str = ""
for op_node in node.inputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
op_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
for op_node in node.outputs:
if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node)
msg_str += "{} [{}], ".format(
op_node.op().type(),
op_dist_attr.get_process_mesh())
op_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(op_node.name(),
None)
@@ -527,27 +548,26 @@ def complete_annotation(program, dist_context=None):
is_wrong = True
print(msg_str)
if node.is_op() and node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
node)
if op_dist_attr.get_process_mesh() is None:
op_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
if op_dist_attr.process_mesh is None:
msg_str = ""
for tensor_node in node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
tensor_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
for tensor_node in node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
msg_str += "{} [{}], ".format(
tensor_node.var().name(),
tensor_dist_attr.get_process_mesh())
tensor_dist_attr.process_mesh)
else:
msg_str += "{} [{}], ".format(
tensor_node.name(), None)
@@ -592,11 +612,14 @@ def complete_annotation(program, dist_context=None):
reach_fix_point = True
# Copy the corresponding distributed attribute from graph to program
dist_context.copy_distribute_attr_from_graph_to_program(graph, program)
dist_context.clear_distributed_attr_for_graph()
dist_context.copy_dist_attr_from_graph_to_program()
dist_context.clear_dist_info_for_graph()
# Do the validation check and amend some completion
dist_context.amend_distributed_attr_for_program()
dist_context.amend_dist_attr_for_program()
# print_program_with_dist_attr(program, dist_context)
dist_context.validate_dist_attr_for_program()
return program
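
complete_annotation drives all of the update_* routines above inside a fixed-point loop: forward and backward sweeps alternate, and the pass ends once a whole round leaves every process mesh and dims_mapping unchanged (the reach_fix_point flag). Stripped of the graph details, the control flow is roughly the following sketch, not the literal implementation.

def run_to_fixed_point(passes, nodes):
    # passes: callables taking (nodes, fwd) and returning True if anything changed.
    reach_fix_point = False
    while not reach_fix_point:
        changed = False
        for fwd in (True, False):             # forward sweep, then backward sweep
            for do_pass in passes:
                if do_pass(nodes, fwd):
                    changed = True
        reach_fix_point = not changed         # stop when a full round changes nothing
    return nodes

# Toy demo: repeatedly propagate the maximum until the state stops changing.
def spread_max(nodes, fwd):
    target = max(nodes)
    changed = any(v != target for v in nodes)
    nodes[:] = [target] * len(nodes)
    return changed

print(run_to_fixed_point([spread_max], [1, 3, 2]))    # -> [3, 3, 3]
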
@@ -636,7 +659,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars
dist_op_helper = dist_context.get_dist_op_helper()
dist_op_context = dist_context.dist_op_context
for idx in range(first_backward_op_idx, len(ops)):
@@ -658,45 +681,42 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
forward_var = vars[forward_var_name]
# TODO: complete other attributes for the grad var
tensor_attr = TensorDistributedAttribute(grad_var, dist_context)
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_dims_mapping()
tensor_attr.set_dims_mapping(dims_mapping)
tensor_attr.set_process_mesh(process_mesh)
dist_context.set_tensor_distributed_attr_for_program(grad_var,
tensor_attr)
op_attr = OperatorDistributedAttribute(ops[idx], dist_context)
op_attr.set_process_mesh(process_mesh)
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr)
tensor_dist_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
dims_mapping = dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(grad_var,
tensor_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = process_mesh
op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr)
continue
# complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in gradopidx2opidx
grad_op = ops[idx]
if grad_op.desc.id() in dist_op_helper.gradopidx2opidx:
if grad_op.desc.id() in dist_op_context.gradopidx2opidx:
# TODO: support the case where one forward op corresponds to multiple xxx_grad ops
forward_op = _get_op_by_id(
ops[:first_backward_op_idx],
dist_op_helper.gradopidx2opidx[grad_op.desc.id()])
dist_op_context.gradopidx2opidx[grad_op.desc.id()])
assert forward_op is not None
# op dist attr
forward_op_attr = dist_context.get_op_distributed_attr_for_program(
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
forward_op_process_mesh = forward_op_attr.get_process_mesh()
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
grad_op_attr.set_process_mesh(forward_op_process_mesh)
forward_op_process_mesh = forward_op_dist_attr.process_mesh
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = forward_op_process_mesh
# var
for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1]
# if grad_op.type == "cast":
# input_name = "X"
# else:
if _is_grad_var_name(output_name):
input_name = _get_forward_varname_from_grad_varname(
output_name)
@@ -711,39 +731,38 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
if len(grad_op.desc.output(output_name)) == 1:
assert len(forward_op.desc.input(input_name)) == 1
input_var = vars[forward_op.desc.input(input_name)[0]]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
ref_dims_mapping = input_var_dist_attr.dims_mapping
# tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]]
output_var_attr = TensorDistributedAttribute(output_var,
dist_context)
output_var_attr.set_dims_mapping(ref_dims_mapping)
output_var_attr.set_process_mesh(forward_op_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
output_var, output_var_attr)
output_var_dist_attr = TensorDistributedAttribute()
output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_dist_attr.process_mesh = forward_op_process_mesh
dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr)
# op dist attr
grad_op_attr.set_output_dims_mapping(output_var.name,
grad_op_dist_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping)
for input_name in grad_op.input_arg_names:
input_var = vars[input_name]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping()
ref_dims_mapping = input_var_dist_attr.dims_mapping
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name)
grad_op_attr.set_input_dims_mapping(input_name,
grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
# only the sum op that merges multiple versions of a grad var has no corresponding mapping in gradopidx2opidx
else:
@@ -755,32 +774,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_forward_var_name = _get_forward_varname_from_grad_varname(
grad_op.output_arg_names[0])
forward_var = vars[ref_forward_var_name]
ref_forward_var_dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_dims_mapping()
ref_forward_var_process_mesh = dist_context.get_tensor_distributed_attr_for_program(
forward_var).get_process_mesh()
ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
forward_var).dims_mapping
ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program(
forward_var).process_mesh
# output
tensor_attr = TensorDistributedAttribute(
vars[grad_op.output_arg_names[0]], dist_context)
tensor_attr.set_dims_mapping(ref_forward_var_dims_mapping)
tensor_attr.set_process_mesh(ref_forward_var_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
vars[grad_op.output_arg_names[0]], tensor_attr)
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
dist_context.set_tensor_dist_attr_for_program(
vars[grad_op.output_arg_names[0]], tensor_dist_attr)
# op
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context)
grad_op_attr.set_process_mesh(ref_forward_var_process_mesh)
grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
for var_name in grad_op.input_arg_names:
assert _get_forward_varname_from_grad_varname(
var_name) == ref_forward_var_name
grad_op_attr.set_input_dims_mapping(
grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_forward_var_dims_mapping)
grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0],
ref_forward_var_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op,
grad_op_attr)
grad_op_dist_attr.set_output_dims_mapping(
grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_dist_attr)
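
Everything in this backward pass hinges on the naming convention between a forward variable and its gradient; in Paddle a gradient variable carries the "@GRAD" suffix, which is what helpers such as _is_grad_var_name and _get_forward_varname_from_grad_varname rely on. A minimal sketch of that convention (the real helpers may handle extra corner cases such as renamed grad versions):

GRAD_SUFFIX = "@GRAD"

def is_grad_var_name(name):
    return name.endswith(GRAD_SUFFIX)

def get_forward_varname_from_grad_varname(grad_name):
    assert is_grad_var_name(grad_name), "expected a gradient variable name"
    return grad_name[:-len(GRAD_SUFFIX)]

assert get_forward_varname_from_grad_varname("linear_0.w_0@GRAD") == "linear_0.w_0"
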
def complete_update_annotation(auto_parallel_main_prog, dist_context):
@@ -808,39 +826,40 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]]
param_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
param_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param)
grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
grad_var)
assert param_dist_attr is not None
assert grad_dist_attr is not None
assert param_dist_attr.get_dims_mapping(
) == grad_dist_attr.get_dims_mapping()
assert param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping
ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program(
param).get_process_mesh()
ref_process_mesh = dist_context.get_tensor_dist_attr_for_program(
param).process_mesh
assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program(
param).get_dims_mapping()
ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
param).dims_mapping
assert ref_dims_mapping is not None
op_attr = OperatorDistributedAttribute(op, dist_context)
op_attr.set_process_mesh(ref_process_mesh)
op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping)
op_attr.set_input_dims_mapping(param.name, ref_dims_mapping)
op_attr.set_output_dims_mapping(param.name, ref_dims_mapping)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dims_mapping(grad_var.name,
ref_dims_mapping)
op_dist_attr.set_input_dims_mapping(param.name,
ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(param.name,
ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]]
op_attr.set_input_dims_mapping(learning_var.name, [-1])
op_attr.set_output_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_output_dims_mapping(learning_var.name, [-1])
if not learning_rate_completed:
learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute(learning_var,
dist_context)
var_dist_attr.set_process_mesh(ref_process_mesh)
var_dist_attr.set_dims_mapping([-1])
dist_context.set_tensor_distributed_attr_for_program(
learning_var, var_dist_attr)
var_dist_attr = TensorDistributedAttribute()
var_dist_attr.process_mesh = ref_process_mesh
var_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(learning_var,
var_dist_attr)
for input_name in op.desc.input_names():
@@ -853,24 +872,25 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistributedAttribute(input_var,
dist_context)
input_var_attr = TensorDistributedAttribute()
if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
input_var_attr.set_dims_mapping([-1])
op_attr.set_input_dims_mapping(input_var.name, [-1])
op_attr.set_output_dims_mapping(input_var.name, [-1])
input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping(input_var.name,
[-1])
op_dist_attr.set_output_dims_mapping(input_var.name,
[-1])
else:
assert "Moment" in input_name
input_var_attr.set_dims_mapping(ref_dims_mapping)
op_attr.set_input_dims_mapping(input_var.name,
input_var_attr.dims_mapping = ref_dims_mapping
op_dist_attr.set_input_dims_mapping(input_var.name,
ref_dims_mapping)
op_attr.set_output_dims_mapping(input_var.name,
op_dist_attr.set_output_dims_mapping(input_var.name,
ref_dims_mapping)
input_var_attr.set_process_mesh(ref_process_mesh)
dist_context.set_tensor_distributed_attr_for_program(
input_var_attr.process_mesh = ref_process_mesh
dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr)
dist_context.set_op_distributed_attr_for_program(op, op_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
continue
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map
# There always exists a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
def get_default_distributed_context():
global DEFAULT_DISTRIBUTED_CONTEXT
if DEFAULT_DISTRIBUTED_CONTEXT is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return DEFAULT_DISTRIBUTED_CONTEXT
def set_default_distributed_context(dist_context):
global DEFAULT_DISTRIBUTED_CONTEXT
DEFAULT_DISTRIBUTED_CONTEXT = dist_context
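
These two accessors implement a lazily created module-level singleton: any caller that does not supply its own DistributedContext shares the default one, while a run that needs isolation builds a private context and either passes it around explicitly or installs it as the new default. A minimal usage sketch, assuming the definitions above are in scope:

default_ctx = get_default_distributed_context()       # created lazily on first access
assert get_default_distributed_context() is default_ctx

private_ctx = DistributedContext()                     # isolated context for one run
set_default_distributed_context(private_ctx)
assert get_default_distributed_context() is private_ctx
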
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self):
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._tensor_distributed_attr_map_for_program = {}
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
self._get_dist_op_helper = DistOpHelper()
self._process_mesh = _g_process_mesh_map.get(0, None)
def is_initialized_for_program(self):
return self._is_initialized_for_program
def is_initialized_for_graph(self):
return self._is_initialized_for_graph
def get_tensor_distributed_attr_for_program(self, tensor):
tensor_id = tensor.desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get(
tensor_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr):
tensor_id = tensor.desc.id()
self._tensor_distributed_attr_map_for_program[
tensor_id] = tensor_dist_attr
def get_op_distributed_attr_for_program(self, op):
op_id = op.desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_program(self, op, op_dist_attr):
op_id = op.desc.id()
self._op_distributed_attr_map_for_program[op_id] = op_dist_attr
def get_tensor_distributed_attr_for_graph(self, tensor_node):
tensor_node_id = tensor_node.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get(
tensor_node_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_graph(self, tensor_node,
tensor_dist_attr):
tensor_node_id = tensor_node.id()
self._tensor_distributed_attr_map_for_graph[
tensor_node_id] = tensor_dist_attr
def get_op_distributed_attr_for_graph(self, op_node):
op_node_id = op_node.id()
op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
def get_dist_op_helper(self):
return self._get_dist_op_helper
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
for block in program.blocks:
for tensor in block.vars.values():
# Since only tensors have distributed attributes, it's better to make sure var is a tensor
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is None:
tensor_dist_attr = TensorDistributedAttribute(tensor, self)
self._copy_distributed_attr_from_tensor_desc(
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
if tensor.type == core.VarDesc.VarType.READER:
tensor_dist_attr.set_shape([])
else:
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor_dist_attr.get_shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
tensor_dist_attr.mark_as_annotated("dims_mapping")
if isinstance(tensor, framework.Parameter):
tensor_dist_attr.mark_as_parameter()
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
op_dist_attr = OperatorDistributedAttribute(op, self)
self._copy_distributed_attr_from_op_desc(op.desc,
op_dist_attr)
self.set_op_distributed_attr_for_program(op, op_dist_attr)
# Default distributed implementation for all operators
# This will be updated during the completion process
op_dist_attr.set_impl_idx(-2)
if op_dist_attr.get_process_mesh() is not None:
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
if op.type == "create_py_reader" \
or tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_input_shape(tensor_name, [])
else:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1
for _ in range(
len(op_dist_attr.get_input_shape(tensor_name)))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_input_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_output_shape(tensor_name, [])
else:
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1
for _ in range(
len(
op_dist_attr.get_output_shape(tensor_name)))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_output_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
self._is_initialized_for_program = True
def finalize_distributed_attr_for_program(self, program):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before finalization."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is not None:
self._store_distributed_attr_to_tensor_desc(
tensor.desc, tensor_dist_attr)
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
self._store_distributed_attr_to_op_desc(op.desc,
op_dist_attr)
def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
attr_name = append_distributed_attr_suffix("dim_mapping")
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_dims_mapping(copied_dims_mapping)
attr_name = append_distributed_attr_suffix("mask")
if desc.has_attr(attr_name):
shard_mask = desc.attr(attr_name)
copied_shard_mask = copy.deepcopy(shard_mask)
dist_attr.set_shard_mask(copied_shard_mask)
attr_name = append_distributed_attr_suffix("offload_device")
if desc.has_attr(attr_name):
offload_device = desc.attr(attr_name)
copied_offload_device = copy.deepcopy(offload_device)
dist_attr.set_offload_device(copied_offload_device)
def _copy_distributed_attr_from_op_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
attr_name = append_distributed_attr_suffix("pipeline_stage")
if desc.has_attr(attr_name):
pipeline_stage = desc.attr(attr_name)
copied_pipeline_stage = copy.deepcopy(pipeline_stage)
dist_attr.set_pipeline_stage(copied_pipeline_stage)
def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
dims_mapping = dist_attr.get_dims_mapping()
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("dim_mapping")
desc._set_attr(attr_name, dims_mapping)
shard_mask = dist_attr.get_shard_mask()
if shard_mask is not None:
attr_name = append_distributed_attr_suffix("mask")
desc._set_attr(attr_name, shard_mask)
offload_device = dist_attr.get_offload_device()
if offload_device is not None:
attr_name = append_distributed_attr_suffix("offload_device")
desc._set_attr(attr_name, offload_device)
def _store_distributed_attr_to_op_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
for tensor_name in desc.input_arg_names():
dims_mapping = dist_attr.get_input_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
for tensor_name in desc.output_arg_names():
dims_mapping = dist_attr.get_output_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
pipeline_stage = dist_attr.get_pipeline_stage()
if pipeline_stage is not None:
attr_name = append_distributed_attr_suffix("pipeline_stage")
desc._set_attr(attr_name, pipeline_stage)
def initialize_distributed_attr_for_graph(self, graph):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before its graph."
if self._is_initialized_for_graph:
return
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program[
tensor_id]
assert tensor_dist_attr is not None, \
"Tensor must have a distributed attribute after the initialization for program."
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self.set_tensor_distributed_attr_for_graph(node,
new_tensor_dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program[op_id]
assert op_dist_attr is not None, \
"Operator must have a distributed attribute after the initialization for program."
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self.set_op_distributed_attr_for_graph(node, new_op_dist_attr)
self._is_initialized_for_graph = True
def clear_distributed_attr_for_program(self):
self._tensor_distributed_attr_map_for_program.clear()
self._op_distributed_attr_map_for_program.clear()
def clear_distributed_attr_for_graph(self):
self._tensor_distributed_attr_map_for_graph.clear()
self._op_distributed_attr_map_for_graph.clear()
def copy_distribute_attr_from_graph_to_program(self, graph, program):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"The distribute attributes must be initialized both in its program and graph"
updated_tensors = {}
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr = self.get_tensor_distributed_attr_for_graph(
node)
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self._tensor_distributed_attr_map_for_program[
tensor_id] = new_tensor_dist_attr
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self.get_op_distributed_attr_for_graph(node)
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self._op_distributed_attr_map_for_program[
op_id] = new_op_dist_attr
def amend_distributed_attr_for_program(self):
for attr in self._tensor_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Tensor's distributed attribute {} is not valid".format(attr)
tensor_shape = attr.get_shape()
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is mapped to a process-mesh dimension whose size is larger
# than that tensor dimension, we just amend the mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Operator's distributed attribute {} is not valid".format(attr)
for arg_name in attr.get_owner_op().desc.input_arg_names():
tensor_shape = attr.get_input_shape(arg_name)
dims_mapping = attr.get_input_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is mapped to a process-mesh dimension whose size is larger
# than that tensor dimension, we just amend the mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names():
tensor_shape = attr.get_output_shape(arg_name)
dims_mapping = attr.get_output_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is mapped to a process-mesh dimension whose size is larger
# than that tensor dimension, we just amend the mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
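
amend_distributed_attr_for_program applies one correction in three places: whenever a tensor dimension is mapped to a process-mesh dimension whose size exceeds that tensor dimension, the mapping is reset to -1, i.e. the dimension is replicated instead of sharded. Isolated into a hypothetical standalone helper, the rule is simply:

def amend_dims_mapping(tensor_shape, dims_mapping, process_mesh_shape):
    # Same amendment rule as above, detached from the attribute bookkeeping.
    for i in range(len(tensor_shape)):
        if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
            dims_mapping[i] = -1              # mesh dim larger than tensor dim: replicate
    return dims_mapping

# A length-2 dimension cannot be split across a 4-way mesh dimension.
assert amend_dims_mapping([2, 16], [1, 0], [2, 4]) == [-1, 0]
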
class DistOpHelper:
"""
DistOpHelper is used to create a dist op desc in a Program.
Each time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
......@@ -131,7 +131,7 @@ class TensorCostNode(CostNode):
elif node.dtype == paddle.int64:
self.dtype_factor *= 8
else:
raise NotImplementedError("{} not counted".format(v.node.dtype))
raise NotImplementedError("{} not counted".format(node.dtype))
self.batch_size = None
if batch_size is not None:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid.framework import Variable
from .process_mesh import ProcessMesh
_g_tensor_dist_attr_field_keys = [
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_input_suffix = "@input"
_g_op_output_suffix = "@output"
def get_tensor_dist_attr_field_keys():
global _g_tensor_dist_attr_field_keys
return _g_tensor_dist_attr_field_keys
def get_op_dist_attr_field_keys():
global _g_op_dist_attr_field_keys
return _g_op_dist_attr_field_keys
def append_op_input_suffix(name):
global _g_op_input_suffix
return name + _g_op_input_suffix
def append_op_output_suffix(name):
global _g_op_output_suffix
return name + _g_op_output_suffix
class TensorDistributedAttribute:
def __init__(self):
# The process mesh of a distributed operator attribute must be the same as
# the process meshes of all its input and output distributed attributes
self._process_mesh = None
self._dims_mapping = None
self._shard_sizes = None
self._device_placement = None
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
@property
def dims_mapping(self):
return self._dims_mapping
@dims_mapping.setter
def dims_mapping(self, dims_mapping):
if dims_mapping is not None:
assert isinstance(dims_mapping, list), \
"The type of dims_mapping must be list."
assert all(isinstance(x, int) for x in dims_mapping), \
("All elements of dims_mapping must be integer")
assert all(x >= -1 for x in dims_mapping), \
("All elements of dims_mapping must be greater than or equal to -1.")
self._dims_mapping = copy.deepcopy(dims_mapping)
@property
def shard_sizes(self):
return self._shard_sizes
@shard_sizes.setter
def shard_sizes(self, shard_sizes):
if shard_sizes is not None:
self._shard_sizes = copy.deepcopy(shard_sizes)
@property
def device_placement(self):
return self._device_placement
@device_placement.setter
def device_placement(self, device_placement):
if device_placement is not None:
self._device_placement = copy.deepcopy(device_placement)
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key in dist_attr.keys():
if key in get_tensor_dist_attr_field_keys():
self.mark_annotated(key)
elif isinstance(dist_attr, TensorDistributedAttribute):
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def clear_annotated(self):
self._is_annotated.clear()
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tdims_mapping ({}): {}".format(annotated_str,
self.dims_mapping)
str += "\n\t}"
return str
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
@property
def impl_type(self):
return self._impl_type
@impl_type.setter
def impl_type(self, impl_type):
if impl_type is not None:
self._impl_type = impl_type
@property
def impl_idx(self):
return self._impl_idx
@impl_idx.setter
def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx
@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
@property
def outputs_dist_attrs(self):
return self._outputs_dist_attrs
def get_input_dist_attr(self, name):
return self._inputs_dist_attrs.get(name, None)
def set_input_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
def set_output_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
dims_mapping = input_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_input_dims_mapping(self, name, dims_mapping):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
input_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._inputs_dist_attrs[name] = dist_attr
def get_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
dims_mapping = output_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_output_dims_mapping(self, name, dims_mapping):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
output_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._outputs_dist_attrs[name] = dist_attr
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init(value)
if dist_attr.get(append_op_input_suffix(key.name), False):
self.set_input_dist_attr(key.name, tensor_dist_attr)
if dist_attr.get(append_op_output_suffix(key.name), False):
self.set_output_dist_attr(key.name, tensor_dist_attr)
else:
if key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, OperatorDistributedAttribute):
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
self.set_input_dist_attr(
tensor_name, dist_attr.get_input_dist_attr(tensor_name))
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
self.set_output_dist_attr(
tensor_name, dist_attr.get_output_dist_attr(tensor_name))
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
for key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
# Make sure the process meshes in a dist op are the same
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
for tensor_dist_attr in self.outputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
shared_process_mesh = None
for process_mesh in process_meshes:
if process_mesh is not None:
if shared_process_mesh is None:
shared_process_mesh = process_mesh
else:
assert process_mesh == shared_process_mesh, \
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure process_mesh is annotated consistently
self._is_annotated[attr_name] = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
else:
self._is_annotated[attr_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
input_dist_attr = self.get_input_dist_attr(key.name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(value)
output_dist_attr = self.get_output_dist_attr(key.name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(value)
else:
if key in get_op_dist_attr_field_keys():
self.mark_annotated(key)
process_mesh_annotated = False
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.outputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
elif isinstance(dist_attr, OperatorDistributedAttribute):
process_mesh_annotated = False
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
input_dist_attr = self.get_input_dist_attr(tensor_name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(tensor_dist_attr)
if input_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
output_dist_attr = self.get_output_dist_attr(tensor_name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(tensor_dist_attr)
if output_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
def clear_annotated(self):
self._is_annotated.clear()
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
def is_annotated_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
return input_dist_attr.is_annotated("dims_mapping")
else:
return False
def is_annotated_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
return output_dist_attr.is_annotated("dims_mapping")
else:
return False
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\timpl type: {}, ".format(self._impl_type)
str += "impl idx: {}".format(self._impl_idx)
str += "\n\t}"
return str
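# Illustrative usage sketch (not part of this patch); the mesh and mapping values
# below are hypothetical. It shows how a dist_attr is built from a plain dict and
# how the annotated fields are recorded.
def _example_tensor_dist_attr():
    # Only keys returned by get_tensor_dist_attr_field_keys() are accepted by init();
    # mark_annotated_as() records which of those fields were explicitly specified.
    user_spec = {"process_mesh": [[0, 1], [2, 3]], "dims_mapping": [0, -1]}
    dist_attr = TensorDistributedAttribute()
    dist_attr.init(user_spec)
    dist_attr.mark_annotated_as(user_spec)
    assert dist_attr.is_annotated("process_mesh")
    assert dist_attr.is_annotated("dims_mapping")
    return dist_attr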
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
# There always exists a default context for users, and they can set it to another one.
_g_default_distributed_context = None
def get_default_distributed_context():
global _g_default_distributed_context
if _g_default_distributed_context is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return _g_default_distributed_context
def set_default_distributed_context(dist_context):
global _g_default_distributed_context
_g_default_distributed_context = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for the program and its graph.
One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self, program=None):
self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
@property
def serial_program(self):
return self._serial_program
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
assert self._serial_program is None, \
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def process_meshes(self):
return self._process_meshes
@property
def dist_op_context(self):
return self._dist_op_context
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of process_mesh must be ProcessMesh.'
if process_mesh not in self.process_meshes:
self._process_meshes.append(process_mesh)
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None)
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_ops_for_program.get(serial_tensor_id, None)
def get_dist_op_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_ops_for_graph.get(serial_tensor_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
# Copy the dist tensors and dist ops annotated by users from the default context
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
for block in self._serial_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
tensor)
if default_dist_tensor and default_ctx is not self:
self.add_dist_tensor_for_program(default_dist_tensor)
current_dist_tensor = self.get_dist_tensor_for_program(tensor)
if current_dist_tensor is None:
dist_tensor = DistributedTensor(tensor)
self.add_dist_tensor_for_program(dist_tensor)
for op in block.ops:
# Copy the distributed operators in the default context
default_dist_op = default_ctx.get_dist_op_for_program(op)
if default_dist_op and default_ctx is not self:
self.add_dist_op_for_program(default_dist_op)
current_dist_op = self.get_dist_op_for_program(op)
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id,
None)
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
dist_op = self._dist_ops_for_program.get(op_id, None)
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
self._dist_ops_for_program.clear()
def clear_dist_info_for_graph(self):
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
node)
dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the process-mesh
# dimension it is mapped to, amend that dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the process-mesh
# dimension it is mapped to, amend that dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If the size of a tensor dimension is smaller than the size of the process-mesh
# dimension it is mapped to, amend that dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized_for_program:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
return True
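# Illustrative sketch (not part of this patch) of the expected call order: a per-run
# context first pulls the user annotations from the default context into the program,
# then mirrors them onto the IR graph. The train_program argument is hypothetical.
def _example_init_dist_context(train_program):
    dist_context = DistributedContext(train_program)
    dist_context.init_dist_attr_for_program()
    dist_context.init_dist_attr_for_graph()
    return dist_context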
class DistributedOperatorContext:
"""
DistributedOperatorContext is used to create a dist op desc in a Program.
Each time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
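# Illustrative sketch (not part of this patch): prepare_forward_context renames every
# input/output argument of the serial op through the partitioner's varname mapping,
# while prepare_backward_context keeps the original names. The arguments below are
# hypothetical; varname_mapping must cover every variable name used by src_op.
def _example_forward_context(dist_op_context, src_op, varname_mapping):
    dist_op_context.set_varname_mapping(varname_mapping)  # e.g. {"x": "x@RANK_0"}
    kinputs, koutputs = dist_op_context.prepare_forward_context(src_op)
    # kinputs/koutputs map each input/output slot name of src_op to the renamed names
    return kinputs, koutputs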
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
self._serial_outputs = {}
self._dist_attr = None
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_op(self):
return self._serial_op
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = OperatorDistributedAttribute()
# Create new dist_attr related to current serial_op
dist_attr = self._filter_dist_attr(dist_attr)
# Append suffix to mark the inputs or outputs
if isinstance(dist_attr, dict):
# Copy the keys since we may add new ones
for key in list(dist_attr.keys()):
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names:
dist_attr[append_op_input_suffix(key.name)] = True
if key.name in self._serial_op.output_arg_names:
dist_attr[append_op_output_suffix(key.name)] = True
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def get_serial_input(self, name):
return self._serial_inputs.get(name, None)
def get_serial_output(self, name):
return self._serial_outputs.get(name, None)
def _init_default_dist_attr(self):
for tensor_name in self._serial_op.input_arg_names:
if self._serial_op.type == "create_py_reader":
tensor = None
else:
tensor = self._serial_op.block._var_recursive(tensor_name)
self._serial_inputs[tensor_name] = tensor
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
return None
new_dist_attr = None
if isinstance(dist_attr, dict):
new_dist_attr = {}
for key, value in dist_attr.items():
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \
or key.name in self._serial_op.output_arg_names:
new_dist_attr[key] = value
else:
new_dist_attr[key] = value
elif isinstance(dist_attr, OperatorDistributedAttribute):
new_dist_attr = copy.deepcopy(dist_attr)
new_dist_attr._inputs_dist_attrs.clear()
new_dist_attr._outputs_dist_attrs.clear()
for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name,
tensor_dist_attr)
for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name,
tensor_dist_attr)
else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type:
return True
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
return False
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(),
self.serial_op.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_input(arg_name) is not None:
if self.get_serial_input(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_output(arg_name) is not None:
if self.get_serial_output(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
return str
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops)
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable):
output = [output]
return list(output)
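# Illustrative usage sketch (not part of this patch); tensor names, shapes and the
# process mesh are hypothetical. Every op appended by the wrapped callable gets a
# DistributedOperator carrying the given dist_attr in the default distributed context.
def _example_distributed_module():
    paddle.enable_static()
    x = paddle.static.data(name="dm_x", shape=[4, 8], dtype="float32")
    y = paddle.static.data(name="dm_y", shape=[8, 4], dtype="float32")
    dist_matmul = DistributedModule(
        paddle.matmul, dist_attr={"process_mesh": [[0, 1], [2, 3]]})
    return dist_matmul(x, y)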
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys
class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None):
self._serial_tensor = serial_tensor
self._dist_attr = None
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_tensor(self):
return self._serial_tensor
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = TensorDistributedAttribute()
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
return False
for i in range(len(self.dist_attr.dims_mapping)):
if self.dist_attr.dims_mapping[
i] < -1 or self.dist_attr.dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
str += ", is_parameter: {}".format(self.serial_tensor.is_parameter)
if self.dist_attr.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.dist_attr.dims_mapping)
if self.dist_attr.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str, None)
if self.dist_attr.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str, None)
return str
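# Illustrative usage sketch (not part of this patch); the tensor and mesh below are
# hypothetical. When no dims_mapping is given, the default is [-1, -1, ...], i.e. the
# tensor is replicated along every mesh dimension.
def _example_distributed_tensor():
    import paddle
    paddle.enable_static()
    x = paddle.static.data(name="dt_x", shape=[4, 6], dtype="float32")
    dist_x = DistributedTensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]]})
    assert dist_x.dist_attr.dims_mapping == [-1, -1]
    return dist_x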
......@@ -18,293 +18,34 @@ import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
__all__ = []
# a map from ProcessMesh ids to the ProcessMesh instances
_g_process_mesh_map = dict()
# user defined map from logical process ids to physical ones
_user_defined_physical_map = None
def _append_attr_suffix(name):
"""
Append auto parallel suffix for distributed attribute name.
"""
return name + core.kAutoParallelSuffix()
def _remove_attr_suffix(name):
"""
Remove auto parallel suffix from distributed attribute name.
"""
return name.strip(core.kAutoParallelSuffix())
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
def _static_mode_check():
if in_dygraph_mode():
raise RuntimeError("Auto-parallel only supports static mode, "
"please use paddle.enable_static().")
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
class ProcessMesh(object):
r"""
The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) that describes the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
parent (ProcessMesh, optional): the parent ProcessMesh. None means
the ProcessMesh is the root one without parent ProcessMesh.
Default: None.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.parent is None
assert mesh.topology == [2, 3]
assert mesh.process_group == [2, 4, 5, 0, 1, 3]
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
def __init__(self, mesh, parent=None):
_static_mode_check()
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
self._topology = _get_nested_list_shape(mesh)
self._processes = _flatten_nested_list(mesh)
# Every element of mesh must be >= 0.
assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.')
unique_ids = set(self._processes)
assert len(unique_ids) == len(self._processes), (
'All elements of mesh must be unique.')
if parent is None:
# For the root ProcessMesh, the ids of logical processes must range
# from 0 to N-1, where N is the number of logical processes.
assert max(self._processes) == len(self._processes) - 1, (
'For the root ProcessMesh, ids of logical processes must range '
'from 0 to N-1, where N is the number of logical processes.')
parent_id = core.kNoneProcessMeshIndex()
assert len(_g_process_mesh_map.keys()) == 0, (
'The first ProcessMesh must be the root, which has no parent.')
else:
assert len(_g_process_mesh_map.keys()) > 0, (
'All ProcessMesh must have a parent except the root one.')
assert isinstance(parent, ProcessMesh), (
'parent must be an instance of ProcessMesh.')
parent_id = parent._desc.id
# All elements in mesh must belong to its parent
parent_ids = set(parent.process_group)
assert unique_ids <= parent_ids, (
'All elements in mesh must belong to its parent.')
self._desc = core.ProcessMeshDesc(self._topology, self._processes,
parent_id)
self._id = self._desc.id
self._parent_id = parent_id
assert self._id not in _g_process_mesh_map, (
"The ProcessMesh with id %d already exists." % self._id)
_g_process_mesh_map[self._id] = self
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
"""
return self._topology
@property
def process_group(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def parent(self):
r"""
Get the parent ProcessMesh.
"""
if self._parent_id == core.kNoneProcessMeshIndex(): return None
assert self._parent_id in _g_process_mesh_map, (
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
user defined order.
Args:
order (list): order of the physical process ids.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
assert self.parent is None, (
"This function can only be called by the root ProcessMesh.")
unique_ids = set(order)
assert isinstance(order, list)
assert len(unique_ids) == len(order), (
"All elements in order must be unique.")
assert min(order) == 0
assert max(order) == len(order) - 1, (
"All elements in order must be from 0 to N - 1, where N "
"is the number of physical processes.")
raise RuntimeError("Auto-parallel only supports static mode for now, "
"please use paddle.enable_static() first.")
logical_order = self.process_group
global _user_defined_physical_map
assert _user_defined_physical_map is None, (
"This function can only be called once.")
_user_defined_physical_map = dict()
assert len(logical_order) == len(order)
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process mesh in _g_process_mesh_map, make it empty.
"""
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert isinstance(mesh,
ProcessMesh), 'The type of mesh must be ProcessMesh.'
assert isinstance(dim_mapping,
list), 'The type of dim_mapping must be list.'
assert len(tensor.shape) == len(dim_mapping), (
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.')
mesh_dim = len(mesh.topology)
dim_set = set()
for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (
dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
'Each element '
'in dim_mapping must be greater than or equal to zero and less than the '
'length of its corresponding topology, or it must be -1.')
if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i])
def shard_tensor(x, mesh, dim_mapping):
def shard_tensor(x, dist_attr=None):
"""
Add distributed attributes for a tensor.
Args:
x (Tensor): the tensor to process.
mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes.
dim_mapping (list): a list to describe the mapping between `x` and `mesh`,
the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means
without partition along the corresponding dimension.
x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follows:
"process_mesh": a nested list to describe the mesh topology of logical processes.
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`; the dimension
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that the tensor dimension is not split.
Both process_mesh and dims_mapping are optional, and users can specify them as needed.
Returns:
Tensor: the tensor `x` itself.
Tensor: the tensor `x` annotated with distributed attributes.
Examples:
.. code-block:: python
......@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping):
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [0, -1])
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
"dims_mapping": [0, -1]})
"""
_static_mode_check()
_dim_mapping_checker(x, mesh, dim_mapping)
attr_name = _append_attr_suffix('mesh_id')
x._set_attr(attr_name, mesh._id)
attr_name = _append_attr_suffix('dim_mapping')
x._set_attr(attr_name, dim_mapping)
assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be None, dict or TensorDistributedAttribute."
dist_tensor = DistributedTensor(x, dist_attr)
dist_tensor.dist_attr.mark_annotated_as(dist_attr)
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
return x
def set_shard_mask(x, mask):
"""
Set the mask for a tensor, which masks out the tensor from some processes in its mesh.
Args:
x (Tensor): the tensor to process.
mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to
the tensor `x`. Every value of `mask` must be one or zero, where one means
the tensor `x` will be put on the corresponding logical process and zero means the tensor `x`
will not be put on the corresponding logical process.
For example, for a ProcessMesh represented by the 2-dimensional
array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the
2-dimensional [[1, 0, 1], [0, 1, 0]],
then the tensor `x` will only be put on logical processes 2, 5 and 1.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1])
dist.set_shard_mask(x, mask)
"""
_static_mode_check()
assert isinstance(mask, list)
np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask)
mesh_attr_name = _append_attr_suffix('mesh_id')
assert x._has_attr(mesh_attr_name), \
"Please set process mesh for the variable firstly."
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
return x
def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
def shard_op(op_fn, dist_attr=None):
"""
Call a function and add distributed attributes to the ops added by the function.
Args:
op_fn (callable): a callable object of an API.
mesh (ProcessMesh): an instance of ProcessMesh specifies the topology of logical processes.
dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping.
The dim_mapping is a list to describe the mapping between a tensor and `mesh`,
the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`,
where -1 means without parition along the corresponding dimension.
kwargs (dict): a dict of parameter passed to the function `op_fn`.
op_fn (callable): a callable operator or module to be sharded.
dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
two categories. The first category describes the distributed attributes shared by all inputs and
outputs, and only `process_mesh` can be specified now. The second category describes the distributed
attributes of individual inputs or outputs, in the same form as the `dist_attr` of `shard_tensor`. All of
them are optional, and users can specify them as needed. Note that the `process_mesh` of an operator must
be the same as the process meshes of its inputs and outputs.
Returns:
list: the outputs of the function `op_fn`.
list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Examples:
.. code-block:: python
......@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
kwargs = {'x': x, 'y': y}
dist.shard_op(paddle.add, mesh, None, **kwargs)
"""
_static_mode_check()
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = op_fn(**kwargs)
new_op_size = len(main_block.ops)
if dim_mapping_dict is None:
dim_mapping_dict = dict()
else:
assert isinstance(dim_mapping_dict,
dict), 'The type of dim_mapping_dict must be dict.'
for var_name in dim_mapping_dict.keys():
dim_mapping = dim_mapping_dict[var_name]
tensor = main_block.var(var_name)
_dim_mapping_checker(tensor, mesh, dim_mapping)
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id')
op._set_attr(attr_name, mesh._id)
for var_name in dim_mapping_dict.keys():
assert var_name in op.output_arg_names + op.input_arg_names
attr_name = _append_attr_suffix(var_name)
if var_name in op.input_arg_names:
# we use the prefix "IN_" to indicates an input argument name
attr_name = "IN_" + attr_name
else:
# we use the prefix "OUT_" to indicates an input argument name
attr_name = "OUT_" + attr_name
op._set_attr(attr_name, dim_mapping_dict[var_name])
if isinstance(output, Variable):
output = [output]
return list(output)
def set_offload_device(x, device):
"""
Set the device that the tensor `x` will be put on.
Args:
x (tensor): the tensor to process.
device (str): the device that the tensor `x` will be put on, e.g., 'cpu'.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
x = paddle.ones([4, 6])
dist.set_offload_device(x, 'cpu')
"""
_static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
def set_pipeline_stage(stage):
"""
Set the pipeline stage of the following ops.
Args:
stage (int): the pipeline stage the following ops belonging to.
Returns:
None.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
dist.set_pipeline_stage(0)
dist_add = dist.shard_op(paddle.add,
dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
x: {"dims_mapping": [-1, 0]},
y: {"dims_mapping": [0, -1]}
})
dist_add(x, y)
"""
from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check()
assert isinstance(stage, int), 'The type of stage must be int.'
_set_pipeline_stage(stage)
assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
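# The returned DistributedModule behaves like `op_fn` itself: calling it (as in the
# docstring example above) runs `op_fn` and attaches the given dist_attr to the ops
# and tensors created by that call.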
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding
......
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License
DISTRIBUTED_OPERATORS = {}
_g_distributed_operator_impl_registries = {}
class DistributedOperator:
class DistributedOperatorImplContainer:
def __init__(self):
self._impls = []
self._name = None
......@@ -47,67 +47,60 @@ class DistributedOperatorImpl:
def get_name(self):
return self._name
def is_process_mesh_compatible(self, op_dist_attr):
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def is_compatible(self, dist_op):
return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op):
global DISTRIBUTED_OPERATORS
DISTRIBUTED_OPERATORS[name] = dist_op
def register_distributed_operator_impl_container(name, dist_op_impl_container):
global _g_distributed_operator_impl_registries
_g_distributed_operator_impl_registries[name] = dist_op_impl_container
def get_distributed_operator(name):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS.get(name, None)
def get_distributed_operator_impl_container(name):
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries.get(name, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name)
if dist_op is not None:
dist_op.register_impl(dist_impl)
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is not None:
dist_op_impl_container.register_impl(dist_impl)
else:
assert False, "Must register distributed operator first."
assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx)
global _g_distributed_operator_impl_registries
return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
fwd=True):
def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
"""
For now, just return the first compatible implementation.
This will be improved by a cost model in the future.
"""
dist_op = get_distributed_operator(name)
if dist_op is None:
dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op_impl_container is None:
return None, -1
compatible_impls = []
impls = dist_op.get_impls()
impls = dist_op_impl_container.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_input_compatible(op_dist_attr):
if impl.is_input_compatible(dist_op):
compatible_impls.append((impl, idx))
else:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_output_compatible(op_dist_attr):
if impl.is_output_compatible(dist_op):
compatible_impls.append((impl, idx))
if compatible_impls:
......@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
return best_compatible_impl, idx
def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
"""
copy src var's dist_attr to dst var
"""
import copy
# def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var):
# """
# copy src var's dist_attr to dst var
# """
# import copy
auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# dist_attr = copy.deepcopy(
# auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
# dist_attr._owner_tensor = var
# dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
# src_var)._owner_context
# auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
"""
copy src var's dist_attr to dst var
"""
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
# """
# copy src op's dist_attr to dst dist op
# """
# from ..attribute import OperatorDistributedAttribute
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
# auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
# op_dist_attr)
# auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
# op_dist_attr)
# op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
# op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
# for input_varname in dist_op.desc.input_arg_names():
# input_var = dst_block.var(input_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# input_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
# for output_varname in dist_op.desc.output_arg_names():
# output_var = dst_block.var(output_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# output_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_output_dims_mapping(output_varname,
# tensor_dims_mapping)
def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
src_op_dist_attr):
"""
copy src op's dist_attr to dst dist op
"""
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
# need to check the dist attr of the op and the dist attrs of its inputs and outputs
auto_paralle_context = src_op_dist_attr.get_owner_context()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)
op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
op_dist_attr.impl_idx = src_op_dist_attr.impl_idx
for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping)
op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -22,23 +22,24 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperator):
class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedDefault, self).__init__()
self._name = name
register_distributed_operator("default", DistributedDefault("default"))
register_distributed_operator_impl_container("default",
DistributedDefault("default"))
# Replicated Default
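# The default container handles ops without a specialized dist impl: the serial op is
# simply replicated on every rank of its process mesh, and the backward pass only adds
# a gradient allreduce (data parallelism) when the batch dimension of an input is sharded.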
......@@ -49,27 +50,24 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.")
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
varname_mapping = dist_op_helper.get_varname_mapping()
rank_id = dist_op_helper.get_rank_id()
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs
for input_name in src_op.desc.input_names():
......@@ -100,25 +98,25 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for varname in dist_op_desc.input_arg_names():
if startup_block.has_var(varname) and startup_block.var(
varname
).is_parameter and varname not in dist_op_helper.already_init_sync_vars:
dist_op_helper.already_init_sync_vars.add(varname)
).is_parameter and varname not in dist_op_context.already_init_sync_vars:
dist_op_context.already_init_sync_vars.add(varname)
param = startup_block.var(varname)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(
param)
process_mesh = param_dist_attr.get_process_mesh()
dims_mapping = param_dist_attr.get_dims_mapping()
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dims_mapping = param_dist_attr.dims_mapping
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
# NOTE: sync is needed along every mesh axis that the parameter is not split on
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dims_mapping:
pass
else:
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
axis, rank_id)
sync_group = new_process_group(group_ranks)
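# Worked example (a sketch, assuming _get_comm_group groups the ranks that differ only
# along the given axis): with processes [2, 4, 5, 0, 1, 3] on a [2, 3] mesh, rank 4 sits
# at mesh coordinate (0, 1), so its group along axis 0 is [4, 1]; the parameter is then
# synchronized inside that group.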
......@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
})
# set distributed attribute
op_attr = OperatorDistributedAttribute(new_op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(param.name,
dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(new_op, op_attr)
ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
......@@ -147,13 +145,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] doesn't have a dist attribute!".format(
str(backward_op))
rank_id = dist_op_helper.get_rank_id()
rank_id = dist_op_context.get_rank_id()
# check whether gradient allreduce is needed
# if there is a non-gradient & non-parameter input and its batch dimension is split,
......@@ -165,19 +163,20 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname).is_parameter:
# NOTE the dims_mapping of the backward op's input should be taken from the input var itself rather than from the corresponding varname of the forward op
process_mesh = dist_attr.get_process_mesh()
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group:
rank_id = _get_corresponding_rank(process_mesh, rank_id)
if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(
process_mesh.process_group, process_mesh.topology,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks)
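# The gradients produced by this op are then allreduced over dp_group and scaled by
# 1 / dp_degree, i.e. averaged across the data-parallel replicas.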
......@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward
})
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp()
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
......@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperator):
class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
register_distributed_operator("lookup_table_v2",
register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding"))
register_distributed_operator("c_embedding", DistributedEmbedding("embedding"))
# RowParallel
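# Row-parallel embedding: the table W is sharded along its row (vocabulary) dimension
# across one mesh axis, each rank looks up only the ids falling into its local vocabulary
# slice (c_embedding with a per-rank offset), and a c_allreduce_sum merges the partial
# results into the full output.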
......@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
......@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# Other dimensions must be replicate except the batch dimension
......@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
......@@ -111,12 +111,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "forward op [{}] doesn't have a dist attribute!".format(
str(src_op))
......@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# A generalized method to calculate the embedding offset using a Cartesian product
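# Sketch (assuming an even vocabulary split): with per_part_size = vocab_size // num_parts
# and idx being this rank's index along the sharding mesh axis, this rank owns rows
# [idx * per_part_size, (idx + 1) * per_part_size), and idx * per_part_size is the start
# offset used by the local embedding lookup below.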
......@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
check_variable_and_dtype(
Out_var, 'tensor',
......@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr)
# param initialization sync
assert Weight_var.name not in dist_op_helper.already_init_sync_vars
dist_op_helper.already_init_sync_vars.add(Weight_var.name)
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh()
dim_mapping = param_dist_attr.get_dims_mapping()
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping
# NOTE: sync is needed along every mesh axis that the parameter is not split on
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis,
rank_id)
sync_group = new_process_group(group_ranks)
......@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] doesn't have a dist attribute!".format(
str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(),
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id)
# check whether gradient allreduce is needed
......@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0])
process_mesh = dist_attr.get_process_mesh()
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology,
batch_size_axis, rank_id)
dp_degree = len(group_ranks)
......@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
W_Grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
ctx.set_op_dist_attr_for_program(op, op_attr)
register_distributed_operator_impl("lookup_table_v2",
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op
......@@ -24,19 +24,20 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute
from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group
from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank
def _update_dims_mapping_for_matmul(op_dist_attr):
def _update_dims_mapping_for_matmul(dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
......@@ -112,7 +113,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
if dim_changed:
changed = True
# Remove unnecessary dim mapping to make sure the lenght of dims_mapping is same as its tensor
# Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
if y_dims_mapping_len == 1:
......@@ -129,17 +130,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] doesn't have a dist attribute!".format(
str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id)
if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
# check whether gradient allreduce is needed
need_gradient_allreduce = False
......@@ -175,13 +176,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name)
process_mesh = dist_attr.get_process_mesh()
process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, batch_size_axis,
rank_id)
dp_degree = len(group_ranks)
......@@ -207,32 +208,32 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program(
Y_Grad_var).get_dims_mapping()
process_mesh = dist_attr.get_process_mesh()
dims_mapping = ctx.get_tensor_dist_attr_for_program(
Y_Grad_var).dims_mapping
process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx)
op_attr.set_process_mesh(process_mesh)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr)
ctx.set_op_dist_attr_for_program(op, op_attr)
def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id):
def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
assert Weight_var.name not in dist_op_helper.already_init_sync_vars
assert Weight_var.name not in dist_op_context.already_init_sync_vars
assert startup_block.has_var(Weight_var.name)
dist_op_helper.already_init_sync_vars.add(Weight_var.name)
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh()
dim_mapping = param_dist_attr.get_dims_mapping()
param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.dims_mapping
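# The loop below synchronizes the parameter along every mesh axis it is NOT sharded on,
# so that all replicated copies of the parameter start from identical initial values.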
for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping:
pass
else:
group_ranks = _get_comm_group(process_mesh.process_group,
group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis, rank_id)
sync_group = new_process_group(group_ranks)
......@@ -249,13 +250,14 @@ def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id):
startup_block._sync_with_cpp()
class DistributedMatmul(DistributedOperator):
class DistributedMatmul(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedMatmul, self).__init__()
self._name = name
register_distributed_operator("matmul", DistributedMatmul("matmul"))
register_distributed_operator_impl_container("matmul",
DistributedMatmul("matmul"))
# ColumnParallel
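# Column-parallel matmul: the weight Y is sharded along its column dimension, so each
# rank computes a column slice of Out from the full X; the c_identity op inserted before
# the local matmul keeps X replicated in the forward pass and (in the usual column-parallel
# scheme) gives the backward pass a place to allreduce the gradient of X across the group.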
......@@ -266,12 +268,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -286,8 +285,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
......@@ -297,9 +297,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......@@ -310,18 +310,18 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "forward op [{}] doesn't have a dist attribute!".format(
str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# check validation of inputs / outputs
......@@ -348,8 +348,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
Weight_var.name)[1]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's column should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
......@@ -365,7 +365,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var)
check_variable_and_dtype(
X_var, 'tensor',
......@@ -395,13 +395,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr)
# init param sync
if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx,
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
......@@ -417,12 +418,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -438,8 +436,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
......@@ -450,9 +449,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......@@ -463,18 +462,18 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "forward op [{}] doesn't have a dist attribute!".format(
str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# check validation of inputs / outputs
......@@ -501,8 +500,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
Weight_var.name)[0]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
......@@ -528,7 +527,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
matmul_op = main_block.append_op(
type='matmul',
......@@ -547,13 +546,14 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block,
copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr)
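# Row-parallel matmul: Y is sharded along its row dimension (and X along its last
# dimension to match), so each rank produces a partial Out; the c_allreduce_sum above
# sums the partial results into the complete output.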
# init param sync
if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx,
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
......@@ -567,12 +567,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
super(DistributedMatmulImpl2, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -592,8 +589,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......@@ -605,9 +603,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......@@ -625,13 +623,14 @@ register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel"))
class DistributedMatmulV2(DistributedOperator):
class DistributedMatmulV2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedMatmulV2, self).__init__()
self._name = name
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2"))
register_distributed_operator_impl_container("matmul_v2",
DistributedMatmulV2("matmul_v2"))
# ColumnParallel
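# The matmul_v2 implementations below mirror the matmul ones above: Impl0 is column
# parallel (c_identity followed by a local matmul_v2), Impl1 is row parallel (local
# matmul_v2 followed by c_allreduce_sum), and Impl2 covers the replicated case.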
......@@ -642,12 +641,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -662,8 +658,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
......@@ -673,9 +670,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......@@ -686,18 +683,18 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "forward op [{}] doesn't have a dist attribute!".format(
str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# check validation of inputs / outputs
......@@ -724,8 +721,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
Weight_var.name)[1]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's column should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
......@@ -741,7 +738,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
persistable=False,
stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var)
check_variable_and_dtype(
X_var, 'tensor',
......@@ -770,14 +767,14 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_v2_op, main_block,
copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block,
op_dist_attr)
# init param sync
if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx,
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
......@@ -793,12 +790,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -814,8 +808,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
......@@ -826,9 +821,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return False
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......@@ -839,18 +834,18 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "forward op [{}] doesn't have a dist attribute!".format(
str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(),
if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id)
# check validation of inputs / outputs
......@@ -877,8 +872,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
Weight_var.name)[0]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group
process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
......@@ -900,7 +895,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
is_data=False,
need_check_feed=Out_var.desc.need_check_feed())
# copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var)
copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
matmul_v2_op = main_block.append_op(
type='matmul_v2',
......@@ -919,14 +914,14 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
})
# copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_v2_op, main_block,
copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block,
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr)
# init param sync
if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx,
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
@staticmethod
......@@ -940,12 +935,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
super(DistributedMatmulV2Impl2, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -965,8 +957,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......@@ -978,9 +973,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed:
changed = True
return changed
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator):
class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
......@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......@@ -103,11 +103,11 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
......@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
......@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......@@ -231,11 +230,11 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping
"""
dist_op_helper = ctx.get_dist_op_helper()
main_block = dist_op_helper.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op)
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op))
......@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology
process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape
for idx, axis in enumerate(dim_mapping):
......
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator):
class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax"))
register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
......@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_input_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
......@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......
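The softmax compatibility checks above reduce to one rule: the dimension softmax is computed over must not be sharded. A self-contained sketch of that check, with a local is_dim_shard stand-in for the helper imported from ..utils:

def is_dim_shard(mapping):
    # Assumption: -1 means replicated, any value >= 0 means sharded.
    return mapping != -1

def softmax_axis_compatible(dims_mapping, axis):
    # Return False when the softmax axis is sharded across the mesh.
    if axis < 0:
        axis += len(dims_mapping)
    return not is_dim_shard(dims_mapping[axis])

assert softmax_axis_compatible([0, -1, -1], axis=-1) is True
assert softmax_axis_compatible([-1, -1, 0], axis=-1) is False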
......@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
......@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator):
class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2"))
register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
......@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
self._forward_implemented = False
self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
def is_input_compatible(self, dist_op):
return True
def is_input_compatible(self, op_dist_attr):
def is_output_compatible(self, dist_op):
return True
def is_output_compatible(self, op_dist_attr):
return True
def update_dims_mapping(self, op_dist_attr):
def update_dims_mapping(self, dist_op):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
......
......@@ -15,11 +15,11 @@
import paddle
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from .context import DistributedContext
from .context import get_default_distributed_context
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner
from .process import get_all_process_groups
from .process_group import get_all_process_groups
from .utils import make_data_unshard
from .reshard import reshard
......@@ -70,7 +70,6 @@ class AutoParallelizer:
# Annotation completion
completed_main_program = complete_annotation(
self._original_main_program, self._dist_context)
# Logical partition
rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
......
......@@ -22,15 +22,15 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
from paddle.distributed.auto_parallel.context import DistributedContext, DistOpHelper
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from .process import new_process_group
from .interface import _g_process_mesh_map
from .attribute import OperatorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
......@@ -68,14 +68,14 @@ class Partitioner(object):
# auto completion
auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
annotated_main_program = auto.complete_annotation(serial_main_program)
auto_paralle_context = get_default_distributed_context()
dist_context = get_default_distributed_context()
# distributed strategy & rank info
rank_id = paddle.distributed.get_rank()
dist_strategy = fleet.DistributedStrategy()
# create partitioner
Partitioner = Partitioner(dist_strategy, auto_paralle_context, rank_id)
Partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# create dist program with forward only
# for distributed inference, using partitioned_main_prog from here
......@@ -93,11 +93,11 @@ class Partitioner(object):
opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
"""
def __init__(self, dist_strategy, auto_parallel_context, rank_id=0):
def __init__(self, dist_strategy, dist_context, rank_id=0):
"""
Args:
dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy.
auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario.
dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op; every Partitioner object could maintain its own DistributedContext member and partition the program based on that shard scenario.
rank_id (int): global rank id to which the partitioned distributed program belongs.
"""
......@@ -106,13 +106,13 @@ class Partitioner(object):
"dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here"
% type(dist_strategy))
if not isinstance(auto_parallel_context, DistributedContext):
if not isinstance(dist_context, DistributedContext):
raise TypeError(
"auto_parallel_context be paddle.fluid.DistributedContext, got %s here"
% type(auto_parallel_context))
"dist_context be paddle.fluid.DistributedContext, got %s here" %
type(dist_context))
self._dist_strategy = dist_strategy
self._auto_parallel_context = auto_parallel_context
self._dist_context = dist_context
self._rank_id = rank_id
self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = ""
......@@ -218,8 +218,8 @@ class Partitioner(object):
if not isinstance(startup_program, (Program)):
raise TypeError(
"auto_parallel_context be paddle.fluid.framework.program, got %s here"
% type(startup_program))
"dist_context be paddle.fluid.framework.program, got %s here" %
type(startup_program))
# check if shard annotated serial program valid
if not self._is_valid_annotated_program(main_program):
......@@ -310,13 +310,12 @@ class Partitioner(object):
if isinstance(var, Parameter):
# TODO if var not belong to this rank, should be filtered
serial_main_var = serial_main_block.var(var.name)
dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_main_var)
target_shape = _get_dist_shape(serial_main_var, dist_attr)
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
_partition_parameter(self._auto_parallel_context,
serial_main_var,
_partition_parameter(self._dist_context, serial_main_var,
partitioned_startup_global_block,
new_name, target_shape)
param2shape[new_name] = target_shape
......@@ -346,24 +345,22 @@ class Partitioner(object):
assert new_op.desc == new_op_desc
output_var = partitioned_startup_global_block.var(output_vars[
0])
output_var_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
output_var_attr = self._dist_context.get_tensor_dist_attr_for_program(
output_var)
op_attr = OperatorDistributedAttribute(
new_op, self._auto_parallel_context)
op_attr.set_process_mesh(output_var_attr.get_process_mesh())
op_attr.set_output_dims_mapping(
output_var.name, output_var_attr.get_dims_mapping())
op_attr.set_input_dims_mapping(
output_var.name, output_var_attr.get_dims_mapping())
self._auto_parallel_context.set_op_distributed_attr_for_program(
new_op, op_attr)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = output_var_attr.process_mesh
op_attr.set_output_dims_mapping(output_var.name,
output_var_attr.dims_mapping)
op_attr.set_input_dims_mapping(output_var.name,
output_var_attr.dims_mapping)
self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)
# TODO: move helper init to a common place
dist_op_helper = self._auto_parallel_context.get_dist_op_helper()
dist_op_helper.set_dst_main_program(partitioned_main_prog)
dist_op_helper.set_dst_startup_program(partitioned_startup_prog)
dist_op_helper.set_varname_mapping(self._serial2dist_varname_mapping)
dist_op_helper.set_rank_id(self._rank_id)
dist_op_context = self._dist_context.dist_op_context
dist_op_context.set_dst_main_program(partitioned_main_prog)
dist_op_context.set_dst_startup_program(partitioned_startup_prog)
dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
dist_op_context.set_rank_id(self._rank_id)
# transpile main program
for op in serial_ops:
......@@ -373,8 +370,7 @@ class Partitioner(object):
if serial_input_varname not in self._serial2dist_varname_mapping:
new_varname = serial_input_varname + self._dist_varname_suffix
if serial_main_block.has_var(serial_input_varname):
_partition_var(self._auto_parallel_context,
serial_main_block,
_partition_var(self._dist_context, serial_main_block,
partitioned_global_block,
serial_input_varname, new_varname)
else:
......@@ -387,28 +383,25 @@ class Partitioner(object):
for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping:
new_varname = serial_output_varname + self._dist_varname_suffix
_partition_var(self._auto_parallel_context,
serial_main_block, partitioned_global_block,
_partition_var(self._dist_context, serial_main_block,
partitioned_global_block,
serial_output_varname, new_varname)
self._serial2dist_varname_mapping[
serial_output_varname] = new_varname
# partition op
kinputs, koutputs = dist_op_helper.prepare_forward_context(op)
dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program(
op)
if _is_dist_op_forward_implement(self._auto_parallel_context, op):
dist_ops = get_distributed_operator(op.type)
dist_op_impl = dist_ops.get_impl(dist_attr.get_impl_idx())
dist_op_impl.forward(self._auto_parallel_context, **kinputs,
**koutputs)
kinputs, koutputs = dist_op_context.prepare_forward_context(op)
dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if _is_dist_op_forward_implement(self._dist_context, op):
dist_ops = get_distributed_operator_impl_container(op.type)
dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx)
dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
else:
# replicate op
dist_ops = get_distributed_operator("default")
dist_ops = get_distributed_operator_impl_container("default")
dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.forward(self._auto_parallel_context, **kinputs,
**koutputs)
dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
return partitioned_main_prog, partitioned_startup_prog
......@@ -453,18 +446,18 @@ class Partitioner(object):
for param in no_grad_set
]
dist_op_helper = self._auto_parallel_context.get_dist_op_helper()
dist_op_context = self._dist_context.dist_op_context
params_and_grads = _auto_backward(
dist_loss,
dist_startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
callbacks=callbacks,
distop_context=dist_op_helper)
distop_context=dist_op_context)
# backward completion
complete_backward_annotation(
dist_main_program, dist_context=self._auto_parallel_context)
dist_main_program, dist_context=self._dist_context)
# transpiler backward for dist op
# get backward ops
......@@ -485,31 +478,33 @@ class Partitioner(object):
backward_ops = ops[first_backward_op_idx:]
for backward_op in backward_ops:
# if the backward op has a corresponding forward op
if backward_op.desc.id() in dist_op_helper.gradopidx2opidx:
forward_op_id = dist_op_helper.gradopidx2opidx[
if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
forward_op_id = dist_op_context.gradopidx2opidx[
backward_op.desc.id()]
forward_op = forward_op_id2forward_op[forward_op_id]
# TODO backward attr should have _impl_idx
forward_op_dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program(
forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op)
# TODO use the backward op itself to find the dist op
dist_ops = get_distributed_operator(forward_op.type)
kinputs, koutputs = dist_op_helper.prepare_backward_context(
dist_ops = get_distributed_operator_impl_container(
forward_op.type)
kinputs, koutputs = dist_op_context.prepare_backward_context(
backward_op)
# TODO use backward op itself to determine impl idx
if _is_dist_op_backward_implement(
self._auto_parallel_context, forward_op):
if _is_dist_op_backward_implement(self._dist_context,
forward_op):
dist_op_impl = dist_ops.get_impl(
forward_op_dist_attr.get_impl_idx())
dist_op_impl.backward(self._auto_parallel_context,
**kinputs, **koutputs)
forward_op_dist_attr.impl_idx)
dist_op_impl.backward(self._dist_context, **kinputs,
**koutputs)
else:
# replicate op
dist_ops = get_distributed_operator("default")
dist_ops = get_distributed_operator_impl_container(
"default")
dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.backward(self._auto_parallel_context,
**kinputs, **koutputs)
dist_op_impl.backward(self._dist_context, **kinputs,
**koutputs)
return params_and_grads
# replace dist grad ops
......@@ -524,7 +519,7 @@ class Partitioner(object):
# update completion
complete_update_annotation(
main_program, dist_context=self._auto_parallel_context)
main_program, dist_context=self._dist_context)
return optimize_ops
......@@ -534,12 +529,11 @@ class Partitioner(object):
ops = program.global_block().ops
vars_ = program.list_vars()
op_dist_attrs = [
self._auto_parallel_context.get_op_distributed_attr_for_program(op)
for op in ops
self._dist_context.get_op_dist_attr_for_program(op) for op in ops
]
var_dist_attrs = [
self._auto_parallel_context.get_tensor_distributed_attr_for_program(
var) for var in vars_
self._dist_context.get_tensor_dist_attr_for_program(var)
for var in vars_
]
all_ops_annotated = all(dist_attr is not None
......@@ -563,8 +557,7 @@ class Partitioner(object):
def _is_var_distributed(self, var):
dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program(
var)
dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var)
assert dist_attr is not None, "dist_attr of var [{}] is None".format(
var.name)
return _is_distributed(dist_attr)
......@@ -637,20 +630,20 @@ def _get_no_grad_set(loss, no_grad_set=None):
return no_grad_set
def _is_dist_op_forward_implement(auto_paralle_context, op):
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op)
dist_ops = get_distributed_operator(op.type)
def _is_dist_op_forward_implement(dist_context, op):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \
dist_attr.get_impl_idx())._forward_implemented
return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.impl_idx)._forward_implemented
def _is_dist_op_backward_implement(auto_paralle_context, op):
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op)
dist_ops = get_distributed_operator(op.type)
def _is_dist_op_backward_implement(dist_context, op):
dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \
dist_attr.get_impl_idx())._backward_implemented
return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.impl_idx)._backward_implemented
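_is_dist_op_forward_implement and _is_dist_op_backward_implement decide whether an op is handled by its specialized distributed implementation or falls back to the "default" (replicate) container. The sketch below shows the shape of that dispatch with hypothetical stand-ins, not the real containers:

class FakeImpl:
    # Hypothetical stand-in for a DistributedOperatorImpl.
    def __init__(self, forward_implemented):
        self._forward_implemented = forward_implemented

    def forward(self, ctx, **kwargs):
        return "specialized" if self._forward_implemented else "replicated"

containers = {
    "matmul_v2": [FakeImpl(True)],
    "default": [FakeImpl(False)],
}

def dispatch_forward(op_type, impl_idx, ctx=None, **kwargs):
    impls = containers.get(op_type)
    if impls and impl_idx >= 0 and impls[impl_idx]._forward_implemented:
        return impls[impl_idx].forward(ctx, **kwargs)
    return containers["default"][0].forward(ctx, **kwargs)

assert dispatch_forward("matmul_v2", 0) == "specialized"
assert dispatch_forward("reshape2", 0) == "replicated"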
def _auto_backward(loss,
......@@ -690,8 +683,8 @@ def _auto_backward(loss,
def _is_distributed(dist_attr):
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
for idx in range(len(mapping)):
if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
return True
......@@ -702,8 +695,8 @@ def _is_distributed(dist_attr):
def _get_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
......@@ -721,7 +714,7 @@ def _get_dist_shape(var, dist_attr):
return new_shape
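_get_dist_shape derives the local (per-rank) shape from the serial shape, the dims mapping, and the mesh topology. A minimal pure-Python rendering of the same arithmetic, under the assumption that every sharded dimension divides evenly:

def local_shape(serial_shape, dims_mapping, mesh_topology):
    # Divide each sharded dimension by the size of the mesh axis it maps to.
    assert len(serial_shape) == len(dims_mapping)
    new_shape = []
    for size, axis in zip(serial_shape, dims_mapping):
        if axis == -1:                     # replicated dimension
            new_shape.append(size)
        else:                              # sharded over mesh axis `axis`
            assert size % mesh_topology[axis] == 0
            new_shape.append(size // mesh_topology[axis])
    return new_shape

# A [8, 8] weight sharded column-wise over a 2-way mesh becomes [8, 4] locally.
assert local_shape([8, 8], [-1, 0], [2]) == [8, 4]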
def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname,
def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape):
# NOTE hack to copied Parameter
# not initialized parameter, need to initialize it
......@@ -749,17 +742,13 @@ def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname,
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# param.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None
dist_attr._owner_tensor = param
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(param,
dist_attr)
dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
def _partition_intermediate_var(auto_paralle_context, src_var, dst_block,
dst_varname, dst_shape):
def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
dst_shape):
var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
......@@ -776,15 +765,12 @@ def _partition_intermediate_var(auto_paralle_context, src_var, dst_block,
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# var.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None
dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
def _partition_var(auto_paralle_context, src_block, dst_block, src_varname,
def _partition_var(dist_context, src_block, dst_block, src_varname,
dst_varname):
"""
partition includes: split + replicate
......@@ -798,16 +784,15 @@ def _partition_var(auto_paralle_context, src_block, dst_block, src_varname,
persistable=True,
stop_gradient=True)
else:
dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter):
_partition_parameter(auto_paralle_context, src_var, dst_block,
dst_varname, target_shape)
_partition_parameter(dist_context, src_var, dst_block, dst_varname,
target_shape)
else:
_partition_intermediate_var(auto_paralle_context, src_var,
dst_block, dst_varname, target_shape)
_partition_intermediate_var(dist_context, src_var, dst_block,
dst_varname, target_shape)
def _insert_src_op(src_op, dst_block, varname_mapping):
......@@ -822,8 +807,7 @@ def _insert_src_op(src_op, dst_block, varname_mapping):
dst_block._sync_with_cpp()
def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context,
rank_id):
def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id):
# build input varname mapping
input_mapping = {}
......@@ -842,10 +826,9 @@ def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context,
output_mapping[output_name] = varnames
# append dist op
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op)
dist_ops = get_distributed_operator(src_op.type)
append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward(
src_op)
dist_attr = dist_context.get_op_dist_attr_for_program(src_op)
dist_ops = get_distributed_operator_impl_container(src_op.type)
append_op_handle = dist_ops.get_impl(dist_attr.impl_idx).forward(src_op)
append_op_handle(
dst_block,
src_op,
......
......@@ -19,62 +19,32 @@ from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
_g_process_group_map = {}
def get_all_process_groups():
global PROCESS_GROUP_MAP
return PROCESS_GROUP_MAP.values()
global _g_process_group_map
return _g_process_group_map.values()
def new_process_group(ranks):
global PROCESS_GROUP_MAP
if not PROCESS_GROUP_MAP:
global _g_process_group_map
if not _g_process_group_map:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP:
num_groups = len(PROCESS_GROUP_MAP)
if key not in _g_process_group_map:
num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg
_g_process_group_map[key] = pg
return pg
else:
pg = PROCESS_GROUP_MAP[key]
pg = _g_process_group_map[key]
return pg
......
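new_process_group keys _g_process_group_map by the sorted ranks, so two requests with the same members reuse one group regardless of order. A self-contained sketch of just that keying scheme, with a plain dict standing in for the real ProcessGroup machinery:

_groups = {}          # hypothetical stand-in for _g_process_group_map
_next_group_id = 1

def get_or_create_group(ranks):
    global _next_group_id
    key = ''.join(map(str, sorted(ranks)))   # same key construction as above
    if key not in _groups:
        _groups[key] = {"id": _next_group_id, "ranks": sorted(ranks)}
        _next_group_id += 1
    return _groups[key]

g1 = get_or_create_group([1, 0])
g2 = get_or_create_group([0, 1])
assert g1 is g2   # same members -> same group, regardless of order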
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import copy
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
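A quick usage check of the two helpers above (this assumes they and the numpy import are in scope); the mesh matches the example used later in the ProcessMesh docstring:

mesh = [[2, 4, 5], [0, 1, 3]]
assert _get_nested_list_shape(mesh) == [2, 3]
assert _flatten_nested_list(mesh) == [2, 4, 5, 0, 1, 3]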
class ProcessMesh(object):
r"""
The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represents a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized with the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups: the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
The first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) that describes the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.topology == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3]
"""
def __init__(self, mesh):
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
processes = _flatten_nested_list(mesh)
assert all(isinstance(p, int) for p in processes), \
("All elements of mesh must be integer")
assert min(processes) >= 0, ('All elements of mesh must be >= 0.')
unique_processes = set(processes)
assert len(unique_processes) == len(processes), (
'All elements of mesh must be unique.')
self._topology = _get_nested_list_shape(mesh)
self._processes = processes
from .dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self)
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialize this ProcessMesh.
"""
return self._topology
@property
def processes(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def ndim(self):
r"""
Get the number of dimensions of this ProcessMesh.
"""
return len(self._topology)
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
if self.topology != other.topology or self.processes != other.processes:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.processes)
return str
......@@ -22,9 +22,9 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Program, OpProtoHolder
import paddle.fluid.layers.utils as utils
from ..collective import _get_global_env
from .context import DistributedContext
from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process import new_process_group, ProcessGroup, PROCESS_GROUP_MAP
from .dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map
class AllGatherOpDesc:
......@@ -276,20 +276,22 @@ def _is_overlapped(shape_x, shape_y):
return overlapped
def _need_reshard(tensor_dist_attr, op_dist_attr):
def _need_reshard(dist_tensor, dist_op):
"""Judge the tensor whether needs to be resharded."""
is_reshard = False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_dist_attr.get_owner_tensor().name)
op_process_mesh = op_dist_attr.get_process_mesh()
tensor_dist_attr = dist_tensor.dist_attr
tensor_name = dist_tensor.serial_tensor.name
tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_process_mesh = tensor_dist_attr.process_mesh
op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = op_dist_attr.process_mesh
if all(
map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping,
op_process_mesh
])):
if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh._id != op_process_mesh._id:
if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
is_reshard = True
return is_reshard
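The rule in _need_reshard is simply: reshard when the tensor's annotated dims_mapping or process_mesh differs from what the consuming op expects for that input. A tiny standalone illustration, with plain lists standing in for the real dist_attr objects:

def needs_reshard(tensor_dims_mapping, tensor_mesh,
                  op_input_dims_mapping, op_mesh):
    # True when either the sharding or the mesh disagrees.
    attrs = [tensor_dims_mapping, tensor_mesh, op_input_dims_mapping, op_mesh]
    if any(a is None for a in attrs):
        return False
    return (tensor_dims_mapping != op_input_dims_mapping
            or tensor_mesh != op_mesh)

# Same mesh, but the op wants the input replicated while the tensor is sharded.
assert needs_reshard([0, -1], [0, 1], [-1, -1], [0, 1]) is True
assert needs_reshard([0, -1], [0, 1], [0, -1], [0, 1]) is False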
......@@ -305,28 +307,30 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
return complete_shape
def find_op_desc_seq(source_tensor, tensor_dist_attr, op_dist_attr):
def find_op_desc_seq(dist_tensor, dist_op):
"""
Find the op description sequence that reshards the source tensor to match the op requirement.
Args:
source_tensor (Variable): A tensor with distributed attribute.
tensor_dist_attr (TensorDistributedAttribute): The distributed attribute of tensor.
op_dist_attr (OperatorDistributedAttribute): The distributed attribute of operator.
dist_tensor (DistributedTensor): A distributed tensor.
dist_op (DistributedOperator): A distributed operator.
Returns:
Dict, the required op description sequences keyed by process: the key of the dict is
a process and the value is a list of op descriptions.
"""
source_dims_mapping = tensor_dist_attr.get_dims_mapping()
source_process_mesh = tensor_dist_attr.get_process_mesh()
source_process_group = source_process_mesh.process_group
tensor_dist_attr = dist_tensor.dist_attr
source_tensor = dist_tensor.serial_tensor
tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology
target_process_mesh = op_dist_attr.get_process_mesh()
target_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_dist_attr.get_owner_tensor().name)
target_process_group = target_process_mesh.process_group
op_dist_attr = dist_op.dist_attr
target_process_mesh = op_dist_attr.process_mesh
target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology
complete_shape = _compute_complete_shape(
......@@ -662,11 +666,11 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
def _init_comm_for_send_recv():
if not PROCESS_GROUP_MAP:
if not _g_process_group_map:
genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup(
_g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size)))
PROCESS_GROUP_MAP["global_group"].instantiate()
_g_process_group_map["global_group"].instantiate()
HAS_SENT = {}
......@@ -773,31 +777,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
axes=op_desc.axes,
new_var_name=new_name)
tensor_attr = TensorDistributedAttribute(target_tensor,
dist_context)
process_mesh = dist_context.get_op_distributed_attr_for_program(
matched_op).get_process_mesh()
dims_mapping = dist_context.get_op_distributed_attr_for_program(
tensor_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_op_dist_attr_for_program(
matched_op).process_mesh
dims_mapping = dist_context.get_op_dist_attr_for_program(
matched_op).get_input_dims_mapping(var_name)
tensor_attr.set_dims_mapping(dims_mapping)
tensor_attr.set_process_mesh(process_mesh)
dist_context.set_tensor_distributed_attr_for_program(target_tensor,
tensor_attr.dims_mapping = dims_mapping
tensor_attr.process_mesh = process_mesh
dist_context.set_tensor_dist_attr_for_program(target_tensor,
tensor_attr)
# rename op input name according to new name
for op in block.ops:
for name in op.input_arg_names:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
op)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if name == var_name and op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh()
op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name)
if op_process_mesh._id == process_mesh._id and op_input_dims_mapping == dims_mapping:
if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping)
op_dist_attr._dims_mapping.pop(name, None)
op_dist_attr.set_input_dist_attr(name, None)
def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
......@@ -825,9 +827,9 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
if op.type == "c_sync_comm_stream":
need_save = []
for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_distributed_attr_for_program(
vars[var_name]).get_process_mesh()
if rank_id in process_mesh.process_group:
process_mesh = dist_context.get_tensor_dist_attr_for_program(
vars[var_name]).process_mesh
if rank_id in process_mesh.processes:
need_save.append(var_name)
if not need_save:
remove_op_idx.append(idx)
......@@ -839,10 +841,10 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
continue
# judge the other op whether should be removed.
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh()
if rank_id not in op_process_mesh.process_group and op.type not in not_remove_op_ref:
op_process_mesh = op_dist_attr.process_mesh
if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
......@@ -974,20 +976,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
idx_offset = 0
for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
continue
var = block.vars[var_name]
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
var)
if tensor_dist_attr is not None and _need_reshard(
tensor_dist_attr, op_dist_attr):
reshard_op_desc = find_op_desc_seq(var, tensor_dist_attr,
op_dist_attr)
dist_tensor = dist_context.get_dist_tensor_for_program(var)
if dist_tensor is not None and _need_reshard(dist_tensor,
dist_op):
reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op)
parse_op_desc(auto_parallel_main_prog, rank_id,
reshard_op_desc, var_name, op, dist_context)
cur_op_count = len(block.ops)
......
......@@ -15,7 +15,6 @@
import threading
import paddle.fluid.core as core
import numpy as np
from .interface import _g_process_mesh_map
def is_valid_list_index(list, index):
......@@ -119,34 +118,35 @@ def remove_distributed_attr_suffix(name):
def check_distributed_attr_for_program(program, dist_context=None):
from .context import get_default_distributed_context
from .dist_context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
dist_tensor = dist_context.get_dist_tensor_for_graph(tensor)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
tensor)
if (tensor_dist_attr is not None) and (
not tensor_dist_attr.is_valid()):
if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()):
return False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if (op_dist_attr is not None) and (not op_dist_attr.is_valid()):
dist_op = dist_context.get_dist_op_for_graph(op)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if (op_dist_attr is not None) and (not dist_op.is_valid()):
return False
return True
def print_program_with_distributed_attr(program, dist_context=None):
def print_program_with_dist_attr(program, dist_context=None):
"""
This function reuses the original program's printing ability with a distributed context.
Using a lock avoids multiple threads changing the default distributed context simultaneously.
"""
lock = threading.Lock()
lock.acquire()
from .context import get_default_distributed_context
from .context import set_default_distributed_context
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
......@@ -301,31 +301,29 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
return coordinate
def _get_corresponding_rank(target_mesh, rank):
def _get_corresponding_rank(dist_context, target_mesh, rank):
# TODO(JZ-LIANG) a hack method to support varying meshes in the pipeline parallelism case.
# we assume that all meshes are evenly divided from a parent mesh and have the same size.
# to be revised in the future.
coordinate = None
for key, mesh in _g_process_mesh_map.items():
if key == 0:
continue
if rank in mesh.process_group and mesh.topology == target_mesh.topology:
for mesh in dist_context.process_meshes:
if rank in mesh.processes and mesh.topology == target_mesh.topology:
coordinate = _linear_idx2coordinate(mesh.topology,
mesh.process_group.index(rank))
mesh.processes.index(rank))
break
assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
rank)
return target_mesh.process_group[_coordinate2linear_idx(mesh.topology,
return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)]
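_get_corresponding_rank converts a rank's position in one mesh to a coordinate and then back to a linear index in a target mesh with the same topology. Below is a self-contained row-major sketch of the two conversions; coord_to_linear and linear_to_coord are hypothetical helpers that only mirror _coordinate2linear_idx and _linear_idx2coordinate:

def coord_to_linear(shape, coord):
    # Row-major coordinate -> linear index.
    idx = 0
    for size, c in zip(shape, coord):
        idx = idx * size + c
    return idx

def linear_to_coord(shape, idx):
    # Row-major linear index -> coordinate.
    coord = []
    for size in reversed(shape):
        coord.append(idx % size)
        idx //= size
    return list(reversed(coord))

# The rank at linear position 5 in a [2, 3] mesh sits at coordinate [1, 2].
assert linear_to_coord([2, 3], 5) == [1, 2]
assert coord_to_linear([2, 3], [1, 2]) == 5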
def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape
mapping = dist_attr.get_dims_mapping()
mesh = dist_attr.get_process_mesh().topology
mapping = dist_attr.dims_mapping
mesh = dist_attr.process_mesh.topology
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
......@@ -341,19 +339,16 @@ def _get_unshard_dist_shape(var, dist_attr):
def make_data_unshard(dist_main_prog, dist_startup_prog):
from .context import get_default_distributed_context
from .dist_context import get_default_distributed_context
dist_context = get_default_distributed_context()
for var in dist_main_prog.list_vars():
if var.is_data:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
var)
inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
var.desc.set_shape(inverse_shape)
dim_mapping = tensor_dist_attr.get_dims_mapping()
dim_mapping = tensor_dist_attr.dims_mapping
dim_mapping = [-1] * len(dim_mapping)
tensor_dist_attr.set_dims_mapping(dim_mapping)
dist_context.set_tensor_distributed_attr_for_program(
var, tensor_dist_attr)
var._set_attr('dim_mapping' + core.kAutoParallelSuffix(),
dim_mapping)
tensor_dist_attr.dims_mapping = dim_mapping
dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
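make_data_unshard restores the full data shape and then marks every dimension as replicated. The inverse-shape step is just the multiplication below, sketched with plain lists and assuming even sharding:

def unshard_shape(local_shape, dims_mapping, mesh_topology):
    # Multiply each sharded dimension back by the size of its mesh axis.
    full = []
    for size, axis in zip(local_shape, dims_mapping):
        full.append(size if axis == -1 else size * mesh_topology[axis])
    return full, [-1] * len(dims_mapping)   # full shape, all-replicated mapping

assert unshard_shape([1, 8], [0, -1], [2]) == ([2, 8], [-1, -1])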
......@@ -1308,13 +1308,12 @@ class Variable(object):
if self.persistable:
var_str = "persist " + var_str
from paddle.distributed.auto_parallel.context import get_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
dist_context = get_default_distributed_context()
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
self)
if var_dist_attr is not None:
dist_tensor = dist_context.get_dist_tensor_for_program(self)
if dist_tensor is not None:
var_str += ", {name} = {value}".format(
name="dist_attr", value=var_dist_attr)
name="dist_attr", value=dist_tensor)
return var_str
......@@ -2529,12 +2528,12 @@ class Operator(object):
if i != len(attr_names) - 1:
attrs_str += ", "
from paddle.distributed.auto_parallel.context import get_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
dist_context = get_default_distributed_context()
op_dist_attr = dist_context.get_op_distributed_attr_for_program(self)
if op_dist_attr is not None:
dist_op = dist_context.get_dist_op_for_program(self)
if dist_op is not None:
attrs_str += ", {name} = {value}".format(
name="dist_attr", value=op_dist_attr)
name="dist_attr", value=dist_op)
if outputs_str != "{}":
op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
......
......@@ -36,8 +36,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program):
ROOT_MESH = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
MESH_0 = auto.ProcessMesh([0, 1])
input = paddle.static.data(name='input', shape=[2, 8])
label = paddle.static.data(name='label', shape=[2, 8])
......@@ -47,10 +46,30 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(label, MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(
linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out)
......@@ -105,8 +124,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program):
ROOT_MESH = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
MESH_0 = auto.ProcessMesh([0, 1])
input = paddle.static.data(name='input', shape=[8, 8])
label = paddle.static.data(name='label', shape=[8, 8])
......@@ -116,11 +134,31 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(label, MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, 0])
auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out)
......
......@@ -24,13 +24,12 @@ import paddle.utils as utils
from paddle.fluid import layers
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
import paddle.fluid.core as core
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
class MLPLayer(nn.Layer):
......@@ -78,8 +77,12 @@ def mlp_pretrain_forward(train_program, start_program):
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1])
auto.set_pipeline_stage(1)
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -99,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
......@@ -131,7 +134,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops:
for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name)
# print_program_with_distributed_attr(distributed_main_program)
# print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program)
......
......@@ -15,128 +15,153 @@
from __future__ import print_function
import unittest
import functools
import operator
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
paddle.enable_static()
def _flatten_nested_list(nested_list):
result = functools.reduce(operator.iconcat, nested_list, [])
return result
def _append_attr_suffix(name):
return name + core.kAutoParallelSuffix()
LAST_PP_STAGE = 3
MASK = [[0, 1, 1], [0, 1, 1]]
MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
process_mesh1 = [0, 1, 2, 3]
process_mesh2 = [[0, 1, 2], [3, 4, 5]]
class SimpleNet(nn.Layer):
def __init__(self, vocab_size=128, hidden_size=4):
super(SimpleNet, self).__init__()
self.mesh = MESH
self.mesh.set_placement([5, 4, 3, 2, 1, 0])
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.dense1 = nn.Linear(hidden_size, hidden_size)
self.dense2 = nn.Linear(hidden_size, hidden_size // 2)
def forward(self, x, y):
x = dist.shard_tensor(x, self.mesh, dim_mapping=[0, -1])
x = dist.set_shard_mask(x, MASK)
# Test shard_tensor interface with dist_attr arg
x = dist.shard_tensor(
x,
dist_attr={"process_mesh": process_mesh1,
"dims_mapping": [0, -1]})
emb_out = self.word_embeddings(x)
dist.set_pipeline_stage(LAST_PP_STAGE)
y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1])
dist.set_offload_device(y, "cpu")
# Test shard_tensor interface with no dist_attr arg
y = dist.shard_tensor(y)
linear1 = self.dense1(y)
out = self.dense2(linear1)
return x, y, self.mesh
return x, y
class TestAutoParallelAPI(unittest.TestCase):
def test_api(self):
dist_context = get_default_distributed_context()
net = SimpleNet()
data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
data2 = fluid.layers.fill_constant(
shape=[2, 4], value=2, dtype="float32")
data3 = fluid.layers.fill_constant(
shape=[2, 4], value=4, dtype="float32")
x, y, mesh = net.forward(data1, data2)
mesh_attr = _append_attr_suffix('mesh_id')
x_mesh_id = x._get_attr(mesh_attr)
self.assertEqual(x_mesh_id, mesh._id)
x_mesh = x.process_mesh
allatts = x.attr_names
self.assertEqual(x_mesh, mesh)
shard_mask_attr = _append_attr_suffix('mask')
self.assertEqual(
x._get_attr(shard_mask_attr), _flatten_nested_list(MASK))
self.assertEqual(x.shard_mask, _flatten_nested_list(MASK))
offload_attr = _append_attr_suffix('offload_device')
self.assertEqual(y._get_attr(offload_attr), "cpu")
self.assertEqual(y.desc.has_attr(offload_attr), True)
self.assertEqual(y.offload_device, "cpu")
y._remove_attr(offload_attr)
self.assertEqual(y._has_attr(offload_attr), False)
ops = paddle.static.default_main_program().block(0).ops
first_op = ops[0]
last_op = ops[-1]
self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE)
DIMS_MAPPING1 = [0, 1]
DIMS_MAPPING2 = [-1, 0]
kwargs = {'x': data2, 'y': data3}
dist.shard_op(
x, y = net.forward(data1, data2)
dist_x = dist_context.get_dist_tensor_for_program(x)
self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1)
self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1])
self.assertEqual(dist_x.dist_attr.shard_sizes, None)
self.assertEqual(dist_x.dist_attr.device_placement, None)
self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_x.dist_attr.is_annotated("device_placement"))
dist_y = dist_context.get_dist_tensor_for_program(y)
self.assertEqual(dist_y.dist_attr.process_mesh, None)
self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1])
self.assertEqual(dist_y.dist_attr.shard_sizes, None)
self.assertEqual(dist_y.dist_attr.device_placement, None)
self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh"))
self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_y.dist_attr.is_annotated("device_placement"))
# Test shard_op interface with dist_attr
dims_mapping1 = [0, 1]
dims_mapping2 = [-1, 0]
dist_add = dist.shard_op(
paddle.add,
mesh=mesh,
dim_mapping_dict={
data2.name: DIMS_MAPPING1,
data3.name: DIMS_MAPPING2
dist_attr={
data2: {
"process_mesh": process_mesh2,
"dims_mapping": dims_mapping1
},
**kwargs)
data3: {
"dims_mapping": dims_mapping2
}
})
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
self.assertEqual(last_op.process_mesh, mesh)
attr_name = "IN_" + data2.name
attr_name = _append_attr_suffix(attr_name)
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING1)
attr_name = "IN_" + data3.name
attr_name = _append_attr_suffix(attr_name)
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING2)
def test_process_mesh(self):
mesh1 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=MESH)
mesh2 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=mesh1)
mesh3 = dist.ProcessMesh([[0, 1], [2, 3]], parent=mesh1)
mesh4 = dist.ProcessMesh([[2, 3], [4, 5]], parent=mesh1)
self.assertEqual(MESH.parent, None)
self.assertEqual(mesh1.parent, MESH)
self.assertEqual(mesh1._desc.parent, MESH._id)
self.assertEqual(mesh3.parent, mesh1)
self.assertEqual(mesh4.parent, mesh1)
self.assertEqual(mesh1, mesh2)
self.assertNotEqual(mesh3, mesh4)
self.assertEqual(mesh2._id, mesh2._desc.id)
self.assertEqual(mesh3.topology, mesh3._desc.topology)
self.assertEqual(mesh3.topology, [2, 2])
self.assertEqual(mesh3.process_group, [0, 1, 2, 3])
self.assertEqual(mesh4.process_group, mesh4._desc.process_group)
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(process_mesh2))
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
self.assertEqual(data2_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1)
self.assertEqual(data2_dist_attr.shard_sizes, None)
self.assertEqual(data2_dist_attr.device_placement, None)
self.assertTrue(data2_dist_attr.is_annotated("process_mesh"))
self.assertTrue(data2_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(data3_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2)
self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(data3_dist_attr.device_placement, None)
self.assertTrue(data3_dist_attr.is_annotated("process_mesh"))
self.assertTrue(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
# Test shard_op interface with dist_attr
dist_add = dist.shard_op(paddle.add)
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
self.assertEqual(data2_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data2_dist_attr.shard_sizes, None)
self.assertEqual(data2_dist_attr.device_placement, None)
self.assertFalse(data2_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data2_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(data3_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(data3_dist_attr.device_placement, None)
self.assertFalse(data3_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
if __name__ == '__main__':
......
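The hunks above are the user-facing side of this change: shard_tensor and shard_op now take a single dist_attr dict instead of positional mesh / dim_mapping arguments. A minimal sketch of the new usage, assuming the paddle build from this PR (tensor names, shapes, and mapping values here are illustrative, not taken from the test):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])  # no parent mesh anymore

main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    x = static.data(name="x", shape=[8, 1024], dtype="float32")
    y = static.data(name="y", shape=[8, 1024], dtype="float32")
    # Tensor annotation: dims_mapping maps each tensor dim to a mesh axis;
    # -1 keeps that dim replicated.
    auto.shard_tensor(
        x, dist_attr={"process_mesh": mesh,
                      "dims_mapping": [0, -1]})
    # Operator annotation: per-input dist_attr keyed by the variable itself.
    dist_add = auto.shard_op(
        paddle.add,
        dist_attr={
            x: {"process_mesh": mesh, "dims_mapping": [0, -1]},
            y: {"dims_mapping": [0, -1]}
        })
    results = dist_add(x, y)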
......@@ -28,15 +28,14 @@ from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh2 = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MLPLayer(nn.Layer):
......@@ -62,20 +61,43 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh2,
dim_mapping=[1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh2,
"dims_mapping": [1, -1]
})
out = self.norm(input)
out = self.linear0(out)
......@@ -99,10 +121,18 @@ def mlp_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -118,8 +148,7 @@ class TestMLPAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -127,18 +156,15 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
......@@ -147,18 +173,16 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_mlp_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
......@@ -167,61 +191,59 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_misc(self):
# import pdb
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1], [2, 3]], parent=ROOT_MESH)
global _global_process_mesh2
_global_process_mesh2 = auto.ProcessMesh(
mesh=[[4, 5], [6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
# pdb.set_trace()
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
self.assertTrue(dist_context.validate_dist_attr_for_program())
# def test_mlp_misc(self):
# # import pdb
# global _global_parallel_strategy
# _global_parallel_strategy = "pp"
# global _global_process_mesh
# _global_process_mesh = auto.ProcessMesh(
# mesh=[[0, 1], [2, 3]])
# global _global_process_mesh2
# _global_process_mesh2 = auto.ProcessMesh(
# mesh=[[4, 5], [6, 7]])
# train_program = static.Program()
# start_program = static.Program()
# dist_context = DistributedContext()
# train_program, start_program = mlp_pretrain_forward(train_program,
# start_program)
# # pdb.set_trace()
# complete_train_program = auto.complete_annotation(train_program,
# dist_context)
dist_context.finalize_distributed_attr_for_program(
complete_train_program)
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
for block in complete_train_program.blocks:
for tensor in block.vars.values():
desc = tensor.desc
attr_name = append_distributed_attr_suffix("mesh_id")
self.assertIsNotNone(desc.has_attr(attr_name))
attr_name = append_distributed_attr_suffix("dim_mapping")
self.assertIsNotNone(desc.has_attr(attr_name))
for op in block.ops:
desc = op.desc
attr_name = append_distributed_attr_suffix("mesh_id")
self.assertIsNotNone(desc.has_attr(attr_name))
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" +
tensor_name)
self.assertIsNotNone(desc.has_attr(attr_name))
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" +
tensor_name)
self.assertIsNotNone(desc.has_attr(attr_name))
set_default_distributed_context(dist_context)
self.assertTrue("dist_attr" in str(complete_train_program))
with unittest.mock.patch(
"sys.stdout", new_callable=StringIO) as mock_stdout:
print_program_with_distributed_attr(complete_train_program)
self.assertIsNotNone(mock_stdout.getvalue())
# # print_program_with_dist_attr(complete_train_program,
# # dist_context)
# dist_context.finalize_distributed_attr_for_program(
# complete_train_program)
# from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
# for block in complete_train_program.blocks:
# for tensor in block.vars.values():
# desc = tensor.desc
# attr_name = append_distributed_attr_suffix("mesh_id")
# self.assertIsNotNone(desc.has_attr(attr_name))
# attr_name = append_distributed_attr_suffix("dims_mapping")
# self.assertIsNotNone(desc.has_attr(attr_name))
# for op in block.ops:
# desc = op.desc
# attr_name = append_distributed_attr_suffix("mesh_id")
# self.assertIsNotNone(desc.has_attr(attr_name))
# for tensor_name in desc.input_arg_names():
# attr_name = append_distributed_attr_suffix("IN_" +
# tensor_name)
# self.assertIsNotNone(desc.has_attr(attr_name))
# for tensor_name in desc.output_arg_names():
# attr_name = append_distributed_attr_suffix("OUT_" +
# tensor_name)
# self.assertIsNotNone(desc.has_attr(attr_name))
# set_default_distributed_context(dist_context)
# self.assertTrue("dist_attr" in str(complete_train_program))
# with unittest.mock.patch(
# "sys.stdout", new_callable=StringIO) as mock_stdout:
# print_program_with_dist_attr(complete_train_program)
# self.assertIsNotNone(mock_stdout.getvalue())
class AttentionLayer(nn.Layer):
......@@ -262,10 +284,18 @@ class AttentionLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -276,18 +306,42 @@ class AttentionLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -320,12 +374,18 @@ class AttentionLayer(nn.Layer):
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
return out
......@@ -357,8 +417,7 @@ class TestAttentionAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -366,18 +425,15 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
......@@ -386,18 +442,16 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
......@@ -406,11 +460,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
class DecoderLayer(nn.Layer):
......@@ -486,10 +538,18 @@ class DecoderLayer(nn.Layer):
def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
......@@ -497,13 +557,17 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
......@@ -521,18 +585,42 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -566,12 +654,18 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual
residual = embeddings + self.dropout2(out)
......@@ -586,14 +680,30 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual
final = residual + self.dropout3(out3)
......@@ -631,8 +741,7 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -640,18 +749,15 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
......@@ -660,18 +766,16 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_decoder_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
......@@ -680,11 +784,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
if __name__ == "__main__":
......
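Every test case in this file follows the same three-step flow: build and annotate a serial program, run auto.complete_annotation, then validate the completed attributes. A condensed, self-contained sketch of that flow under the same assumptions as above (the tiny matmul program is illustrative; the real tests build the MLP/attention/decoder networks shown in the hunks):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
mesh = auto.ProcessMesh([0, 1, 2, 3])

train_program, start_program = static.Program(), static.Program()
with static.program_guard(train_program, start_program):
    x = static.data(name="x", shape=[4, 1024], dtype="float32")
    w = static.create_parameter(shape=[1024, 1024], dtype="float32")
    # Annotate only the input; completion is expected to fill in the rest.
    auto.shard_tensor(
        x, dist_attr={"process_mesh": mesh, "dims_mapping": [0, -1]})
    out = paddle.matmul(x, w)

dist_context = DistributedContext()
complete_train_program = auto.complete_annotation(train_program, dist_context)
# Every tensor and op should now carry a consistent dist_attr.
assert dist_context.validate_dist_attr_for_program()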
......@@ -32,13 +32,12 @@ from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MultiHeadAttention(nn.Layer):
......@@ -108,10 +107,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
......@@ -145,19 +152,35 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -238,12 +261,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
outs = [out]
if self.need_weights:
......@@ -411,17 +440,33 @@ class TransformerDecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# tgt = self.dropout2(
# self.linear2(F.gelu(
......@@ -485,13 +530,17 @@ class GPTEmbeddings(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
......@@ -717,10 +766,18 @@ def gpt_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
gpt = GPTModel(
vocab_size=32768,
......@@ -753,8 +810,7 @@ class TestGPTAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
......@@ -763,18 +819,15 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
train_program = static.Program()
start_program = static.Program()
......@@ -783,18 +836,16 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_gpt_dp_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
......@@ -803,11 +854,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
self.assertTrue(dist_context.validate_dist_attr_for_program())
if __name__ == "__main__":
......
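A change that recurs in every test file of this PR: ProcessMesh no longer takes a parent argument, so the ROOT_MESH globals are gone and each mesh is declared on its own. A hedged sketch of the new construction and the two properties the tests read from it (the printed values are what the mesh definition implies, not output captured from a run):

import paddle.distributed.auto_parallel as auto

# Before: mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
# After:
mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
print(mesh.topology)   # expected [2, 4] -- the shape of the mesh
print(mesh.processes)  # expected [0, 1, 2, 3, 4, 5, 6, 7] -- flat rank list (was process_group)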
......@@ -23,21 +23,19 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
_global_process_mesh = auto.ProcessMesh(
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
......@@ -70,9 +68,13 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if self.is_distributed:
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1])
self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]})
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1])
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]})
out = self.norm(input)
out = self.linear0(out)
......@@ -120,8 +122,14 @@ def mlp_forward(train_program, start_program, is_distributed=True):
name="label", shape=[batch_size, 1], dtype='float32')
if is_distributed:
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1])
auto.shard_tensor(
input,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -137,8 +145,6 @@ def mlp_forward(train_program, start_program, is_distributed=True):
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
......
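The pipeline tests above simply annotate different tensors with different stage meshes; each dims_mapping entry maps a tensor dimension to a mesh axis, with -1 meaning that dimension stays replicated. A condensed, self-contained sketch of the mlp_forward annotation above (shapes are illustrative):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])

main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    input = static.data(name="input", shape=[4, 1024], dtype="float32")
    label = static.data(name="label", shape=[4, 1], dtype="float32")
    # Stage 0 owns the input, stage 1 owns the label; the batch dimension is
    # split over mesh axis 0, the remaining dimension is replicated (-1).
    auto.shard_tensor(
        input, dist_attr={"process_mesh": PP_MESH_0, "dims_mapping": [0, -1]})
    auto.shard_tensor(
        label, dist_attr={"process_mesh": PP_MESH_1, "dims_mapping": [0, -1]})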
......@@ -29,19 +29,17 @@ from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group
from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
def get_programs(annotated_func):
......@@ -49,7 +47,7 @@ def get_programs(annotated_func):
start_program = static.Program()
dist_context = DistributedContext()
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
dist_context.process_mesh = _global_process_mesh
train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
......@@ -95,9 +93,8 @@ def initialization_check(mode, dist_context, dist_startup_prog,
serial_startup_prog, var_need_broadcast, process_mesh,
mp_parallel_axis, dp_parallel_axis):
if 'mp' in mode:
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis,
3)
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
mp_ring_id = new_process_group(group_ranks).id
broadcast_ops = [
op for op in dist_startup_prog.global_block().ops
......@@ -110,9 +107,8 @@ def initialization_check(mode, dist_context, dist_startup_prog,
return False
if 'dp' in mode:
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis,
3)
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
dp_ring_id = new_process_group(group_ranks).id
nparam = len(serial_startup_prog.all_parameters())
nbroadcast_dp = len([
......@@ -137,22 +133,21 @@ def initialization_check(mode, dist_context, dist_startup_prog,
def get_input_var_dist_attr(op, main_program, dist_context):
varname = op.desc.input_arg_names()
var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var)
dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
return dist_attr
def get_output_var_dist_attr(op, main_program, dist_context):
varname = op.desc.output_arg_names()
var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var)
dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
return dist_attr
def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
equal = True
if serial_dist_attr.get_process_mesh() != dist_attr.get_process_mesh() or \
serial_dist_attr.is_parameter() != dist_attr.is_parameter() or \
serial_dist_attr.get_dims_mapping() != dist_attr.get_dims_mapping():
if serial_dist_attr.process_mesh != dist_attr.process_mesh or \
serial_dist_attr.dims_mapping != dist_attr.dims_mapping:
equal = False
return equal
......@@ -161,36 +156,33 @@ def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
dist_op_idx):
equal = True
# get serial op's process_mesh and impl_idx
serial_op_dist_attr = dist_context.get_op_distributed_attr_for_program(
serial_op)
serial_process_mesh = serial_op_dist_attr.get_process_mesh()
serial_impl_idx = serial_op_dist_attr.get_impl_idx()
serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op)
serial_process_mesh = serial_op_dist_attr.process_mesh
serial_impl_idx = serial_op_dist_attr.impl_idx
# check dist_attr between serial op and dist op
for i in dist_op_idx:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
dist_ops[i])
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_ops[i])
for in_varname in dist_ops[i].desc.input_arg_names():
in_var = dist_main_prog.global_block().var(in_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_dims_mapping = tensor_dist_attr.dims_mapping
in_var_dims_mapping = op_dist_attr.get_input_dims_mapping(
in_varname)
if tensor_dims_mapping != in_var_dims_mapping:
equal = False
for out_varname in dist_ops[i].desc.output_arg_names():
out_var = dist_main_prog.global_block().var(out_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
tensor_dims_mapping = tensor_dist_attr.dims_mapping
out_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_varname)
if tensor_dims_mapping != out_var_dims_mapping:
equal = False
dist_op_process_mesh = op_dist_attr.get_process_mesh()
dist_op_impl_idx = op_dist_attr.get_impl_idx()
dist_op_process_mesh = op_dist_attr.process_mesh
dist_op_impl_idx = op_dist_attr.impl_idx
if serial_op.desc.id() == dist_ops[i].desc.id() or \
serial_process_mesh != dist_op_process_mesh or \
serial_impl_idx != dist_op_impl_idx:
......@@ -242,13 +234,13 @@ def distributed_attr_check_for_program(dist_main_prog, dist_context):
have_dist_attr = True
for block in dist_main_prog.blocks:
for tensor in block.vars.values():
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
tensor)
if var_dist_attr is None:
have_dist_attr = False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is None:
have_dist_attr = False
......@@ -278,21 +270,43 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input)
out = self.linear0(out)
......@@ -316,10 +330,18 @@ def mlp_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -335,8 +357,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -372,8 +393,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -437,7 +457,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -535,10 +555,18 @@ class AttentionLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -549,18 +577,42 @@ class AttentionLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -593,12 +645,18 @@ class AttentionLayer(nn.Layer):
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
return out
......@@ -630,8 +688,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -666,8 +723,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -735,7 +791,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -871,10 +927,18 @@ class DecoderLayer(nn.Layer):
def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
......@@ -882,13 +946,17 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
......@@ -906,18 +974,42 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -951,17 +1043,25 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
else:
auto.shard_tensor(
self.out_proj.weight,
_global_process_mesh,
dim_mapping=[-1, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
# Add residual
residual = embeddings + self.dropout2(out)
......@@ -976,14 +1076,30 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual
final = residual + self.dropout3(out3)
......@@ -1022,7 +1138,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
......@@ -1105,7 +1221,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
_global_parallel_strategy = "None"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
......
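The helper functions in this file also track the accessor renames of this PR: get_*_distributed_attr_for_program becomes get_*_dist_attr_for_program, and getter methods such as get_process_mesh, get_dims_mapping, and get_impl_idx become plain properties. A hedged, self-contained sketch of querying those attributes after completion (the tiny matmul program is illustrative):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext

paddle.enable_static()
mesh = auto.ProcessMesh([0, 1, 2, 3])
main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    x = static.data(name="x", shape=[4, 1024], dtype="float32")
    w = static.create_parameter(shape=[1024, 1024], dtype="float32")
    auto.shard_tensor(
        x, dist_attr={"process_mesh": mesh, "dims_mapping": [0, -1]})
    out = paddle.matmul(x, w)

dist_context = DistributedContext()
complete_prog = auto.complete_annotation(main_prog, dist_context)

x_var = complete_prog.global_block().var("x")
x_dist_attr = dist_context.get_tensor_dist_attr_for_program(x_var)
print(x_dist_attr.process_mesh, x_dist_attr.dims_mapping)  # properties now
last_op = complete_prog.global_block().ops[-1]
op_dist_attr = dist_context.get_op_dist_attr_for_program(last_op)
print(op_dist_attr.impl_idx, op_dist_attr.get_input_dims_mapping("x"))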
......@@ -32,14 +32,13 @@ from paddle.distributed import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group
from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static()
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
_global_parallel_strategy = None
_global_process_mesh = None
......@@ -61,24 +60,27 @@ def is_valid_completed_program(dist_context, program):
ops = program.global_block().ops
vars_ = program.list_vars()
for op in ops:
op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op)
op_dist_attrs = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attrs == None:
return False
if op_dist_attrs.get_process_mesh == None:
if op_dist_attrs.process_mesh == None:
return False
if None in op_dist_attrs._dims_mapping.values():
for tensor_dist_attr in op_dist_attrs.inputs_dist_attrs.values():
if None == tensor_dist_attr.dims_mapping:
return False
for tensor_dist_attr in op_dist_attrs.outputs_dist_attrs.values():
if None == tensor_dist_attr.dims_mapping:
return False
for var in vars_:
var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program(
var)
var_dist_attrs = dist_context.get_tensor_dist_attr_for_program(var)
if var_dist_attrs == None:
return False
elif var_dist_attrs.get_process_mesh == None:
elif var_dist_attrs.process_mesh == None:
return False
elif var_dist_attrs.get_dims_mapping == None:
elif var_dist_attrs.dims_mapping == None:
return False
return True
......@@ -151,10 +153,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
......@@ -188,19 +198,35 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -281,12 +307,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
outs = [out]
if self.need_weights:
......@@ -454,17 +486,33 @@ class TransformerDecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# tgt = self.dropout2(
# self.linear2(F.gelu(
......@@ -528,13 +576,17 @@ class GPTEmbeddings(nn.Layer):
if _global_parallel_strategy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
......@@ -760,10 +812,18 @@ def gpt_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
gpt = GPTModel(
vocab_size=32768,
......@@ -798,12 +858,12 @@ class TestGPTPartitioner(unittest.TestCase):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
dist_context.set_process_mesh(_global_process_mesh)
dist_context.process_mesh = _global_process_mesh
train_program, start_program, loss = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
......@@ -833,7 +893,7 @@ class TestGPTPartitioner(unittest.TestCase):
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog,
auto_parallel_startup_prog)
from paddle.distributed.auto_parallel.context import set_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
fw.write(str(auto_parallel_main_prog))
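The hunks above also record the changes on the context side: the `paddle.distributed.auto_parallel.context` module becomes `dist_context`, and the mesh is attached through a plain property assignment instead of `set_process_mesh`. A short sketch of the renamed pieces, assuming the program to be partitioned is built elsewhere; only the import paths, the property assignment, and `set_default_distributed_context` are taken from this diff.

import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import (
    DistributedContext, set_default_distributed_context)

mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

dist_context = DistributedContext()
# dist_context.set_process_mesh(mesh) was replaced by a property assignment.
dist_context.process_mesh = mesh
# Register this context as the default one used by later passes.
set_default_distributed_context(dist_context)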
......@@ -877,14 +937,12 @@ class TestGPTPartitioner(unittest.TestCase):
mp_parallel_axis = 1
dp_parallel_axis = 0
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, mp_parallel_axis,
3)
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
mp_ring_id = new_process_group(group_ranks).id
group_ranks = _get_comm_group(process_mesh.process_group,
process_mesh.topology, dp_parallel_axis,
3)
group_ranks = _get_comm_group(
process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
dp_ring_id = new_process_group(group_ranks).id
tensor_parallel_allreduce_vars = sorted([
......
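One more rename is visible in the last hunk of this file: the list of ranks in a ProcessMesh is now `process_mesh.processes` rather than `process_mesh.process_group`. The sketch below rebuilds the communication groups for rank 3 the same way the test does. The import path for `_get_comm_group` is an assumption (the test's import lines are not part of this hunk); `new_process_group` comes from `paddle.distributed.auto_parallel.process_group`, as a later hunk shows.

import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.process_group import new_process_group
# Assumed import location for the helper used in the test above.
from paddle.distributed.auto_parallel.utils import _get_comm_group

process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
mp_parallel_axis = 1
dp_parallel_axis = 0

# Ranks that rank 3 communicates with along the model-parallel axis of the 2x4 mesh ...
group_ranks = _get_comm_group(
    process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
mp_ring_id = new_process_group(group_ranks).id

# ... and along the data-parallel axis.
group_ranks = _get_comm_group(
    process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
dp_ring_id = new_process_group(group_ranks).id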
......@@ -22,16 +22,16 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP
from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
PP_MESH_0 = None
PP_MESH_1 = None
......@@ -57,16 +57,30 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input)
out = self.linear0(out)
......@@ -88,12 +102,32 @@ def mlp_forward(train_program, start_program):
name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -108,8 +142,6 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
......@@ -136,22 +168,21 @@ def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
has_dist_attr = True
vars = dist_main_prog.global_block().vars
op_dist_attr = dist_context.get_op_distributed_attr_for_program(
op_need_check)
if not op_dist_attr or not op_dist_attr.get_process_mesh():
op_dist_attr = dist_context.get_op_dist_attr_for_program(op_need_check)
if not op_dist_attr or not op_dist_attr.process_mesh:
has_dist_attr = False
for var_name in op_need_check.input_arg_names:
if not op_dist_attr.get_input_dims_mapping(var_name) or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh():
not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \
not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh:
has_dist_attr = False
break
if has_dist_attr:
for var_name in op_need_check.output_arg_names:
if not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh():
if not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \
not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh:
has_dist_attr = False
break
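The same hunk captures the accessor renaming on the distributed-attribute objects: the lookups drop the `distributed` spelling and the getter methods become properties. The helper below is purely illustrative (its name and signature are invented for this sketch); it assumes a DistributedContext already populated by completion, plus an op and a variable from that program.

def read_new_style_dist_attrs(dist_context, op, var):
    # was: dist_context.get_op_distributed_attr_for_program(op)
    op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
    # was: op_dist_attr.get_process_mesh()
    mesh = op_dist_attr.process_mesh
    # was: dist_context.get_tensor_distributed_attr_for_program(var).get_dims_mapping()
    mapping = dist_context.get_tensor_dist_attr_for_program(var).dims_mapping
    # Per-input mappings are still read through a method, as in the test above.
    in_mapping = op_dist_attr.get_input_dims_mapping(var.name)
    return mesh, mapping, in_mapping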
......@@ -162,6 +193,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
send_result = False
recv_result = False
ops = dist_main_prog.global_block().ops
if rank_id == 0:
for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
......@@ -217,7 +249,7 @@ def check_initialization_for_dp(dist_startup_prog):
class TestMLPReshard(unittest.TestCase):
def test_complete_backward_annotation(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -231,6 +263,7 @@ class TestMLPReshard(unittest.TestCase):
if op.type == "gelu_grad":
op_need_check = op
break
# print_program_with_dist_attr(dist_main_prog, dist_context)
# grad op should have dist attr
self.assertTrue(
......@@ -241,11 +274,11 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -253,9 +286,10 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 1
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
for key in list(PROCESS_GROUP_MAP.keys()):
del PROCESS_GROUP_MAP[key]
for key in list(_g_process_group_map.keys()):
del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
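One more rename shows up here: the module-level registry of process groups is now `_g_process_group_map` (previously `PROCESS_GROUP_MAP`), imported from `paddle.distributed.auto_parallel.process_group`. A small sketch of the reshard step as the test performs it; the wrapper function name is made up for illustration.

from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.reshard import reshard

def reshard_for_rank(dist_main_prog, dist_startup_prog, rank_id, dist_context):
    # Clear the renamed global registry so each test case starts clean.
    for key in list(_g_process_group_map.keys()):
        del _g_process_group_map[key]
    # Insert the communication ops (e.g. send_v2/recv_v2) this rank needs.
    reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
    return dist_main_prog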
......@@ -267,7 +301,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......
......@@ -22,18 +22,17 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
_global_process_mesh = auto.ProcessMesh(
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
class MLPLayer(nn.Layer):
......@@ -55,8 +54,14 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1])
auto.shard_tensor(
self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]})
out = self.norm(input)
out = self.linear0(out)
......@@ -77,8 +82,14 @@ def mlp_forward(train_program, start_program):
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1])
auto.shard_tensor(
input,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -94,7 +105,7 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
......@@ -156,10 +167,8 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2
dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id)
print(dist_main_prog)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
print(dist_main_prog)
print(dist_startup_prog)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
......@@ -22,17 +22,17 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
_global_parallel_strategy = "mp_pp"
ROOT_MESH = auto.ProcessMesh([[0, 1], [2, 3]])
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([2, 3], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]])
PP_MESH_0 = auto.ProcessMesh([0, 1])
PP_MESH_1 = auto.ProcessMesh([2, 3])
class MLPLayer(nn.Layer):
......@@ -64,10 +64,21 @@ class MLPLayer(nn.Layer):
def forward(self, input):
auto.shard_tensor(
self.word_embeddings.weight, PP_MESH_0, dim_mapping=[0, -1])
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 0])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[0, -1])
auto.shard_tensor(self.linear2.weight, PP_MESH_1, dim_mapping=[0, -1])
self.word_embeddings.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 0]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
auto.shard_tensor(
self.linear2.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
w_out = self.word_embeddings(input)
out = self.linear0(w_out)
gelu_out = F.gelu(out, approximate=True)
......@@ -88,8 +99,13 @@ def mlp_forward(train_program, start_program):
label = static.data(
name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
auto.shard_tensor(
input, dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -105,7 +121,7 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(train_program,
startup_program)
......@@ -198,19 +214,41 @@ class TestMLPReshard(unittest.TestCase):
def test_allgather(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 3], parent=ROOT_MESH)
process_mesh = auto.ProcessMesh(mesh=[0, 3])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x, process_mesh, dim_mapping=[0, -1])
x = auto.shard_tensor(
x,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [0, -1]
})
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w, process_mesh, dim_mapping=[-1, -1])
y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
x.name: [-1, -1],
w.name: [-1, -1]
}, **{"x": x,
"y": w})[0]
w = auto.shard_tensor(
w,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [-1, -1]
})
# y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
# x.name: [-1, -1],
# w.name: [-1, -1]
# }, **{"x": x,
# "y": w})[0]
y = paddle.distributed.shard_op(
paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
x: {
"dims_mapping": [-1, -1]
},
w: {
"dims_mapping": [-1, -1]
}
})(x, w)[0]
rank_id = 0
dist_context = DistributedContext()
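`shard_op` follows the same pattern as `shard_tensor`: the mesh and the name-keyed mapping dict are folded into a single `dist_attr` argument (keyed by the tensors themselves), and the call now returns a callable that is applied to the inputs. A self-contained sketch of just that change, using the same 2-rank mesh and 4x4 shapes as `test_allgather` above:

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
process_mesh = auto.ProcessMesh(mesh=[0, 3])

with static.program_guard(static.Program(), static.Program()):
    x = static.data(name="x", shape=[4, 4], dtype="float32")
    w = static.data(name="w", shape=[4, 4], dtype="float32")
    # New interface: dist_attr carries the mesh plus a per-tensor dims_mapping,
    # and the returned callable is applied to (x, w); [0] picks the matmul output.
    y = paddle.distributed.shard_op(
        paddle.matmul,
        dist_attr={
            "process_mesh": process_mesh,
            x: {"dims_mapping": [-1, -1]},
            w: {"dims_mapping": [-1, -1]}
        })(x, w)[0]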
......
......@@ -26,16 +26,15 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import get_default_distributed_context
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import new_process_group
from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0])
class MLPLayer(nn.Layer):
......@@ -59,16 +58,30 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
else:
auto.shard_tensor(
self.linear0.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
self.linear1.weight, _global_process_mesh,
dim_mapping=[-1, -1])
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input)
out = self.linear0(out)
......@@ -90,12 +103,32 @@ def mlp_forward(train_program, start_program):
name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1])
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1])
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer(
hidden_size=hidden_size,
......@@ -168,7 +201,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = None
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH)
_global_process_mesh = auto.ProcessMesh(mesh=[0])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
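Finally, a change that recurs in every file touched here: `auto.ProcessMesh` no longer takes a `parent` argument, so the ROOT_MESH definitions disappear and a sub-mesh is simply a mesh over a subset of the ranks. A brief sketch of the new-style mesh definitions, using shapes that appear in these tests:

import paddle.distributed.auto_parallel as auto

# Old style (removed): PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH)
# New style: stand-alone meshes; a pipeline stage is just a smaller mesh.
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]])  # 2x2 mesh for mp_pp
PP_MESH_0 = auto.ProcessMesh([0, 1])  # first pipeline stage
PP_MESH_1 = auto.ProcessMesh([2, 3])  # second pipeline stage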
......