Unverified commit a02532b5, authored by Yulong Ao and committed by GitHub

[Auto Parallel] Improve the interface and the underlying mechanisms (#36617)

* default dist op

* add dist_attr for dist op

* add unit test

* update input name

* update function name

* add unit test

* update CMakeLists.txt for CI

* fix dis_matmul

* fix compile error

* update matmul to matmul_v2

* unify api

* unify api

* todo

* update distop forward func

* update distop forward func

* auto parallel backward

* update dist op

* autoparallel backward

* add backward for embedding

* temp1

* temp2

* temp3

* temp4

* backward done1

* backward done2

* backward done3

* dist embedding remove mp mode

* dist matmul remove mp mode

* update dist embedding

* dist op init1

* dist op init 2

* update unit test

* context remove parallel mode

* partitioner remove parallel mode

* update unit test

* a more general method to support varying mesh in pipeline parallel

* support varying mesh in pipeline parallel

* embedding support varying mesh in pipeline parallel

* matmul support varying mesh in pipeline parallel

* default dist op support varying mesh in pipeline parallel

* dist attribute for startup program

* default dist op support varying mesh in pipeline parallel 2

* partitioner support varying mesh in pipeline parallel

* revise logic for auto completion

* revise framework.py

* revise reshard unit test

* revise unit test for parallelize

* chmod

* fixed bug for dist embedding name mapping

* Improve the interface and the underlying mechanisms of auto parallel

* revise completion for backward

* revise completion for update

* revise completion for update

* update unit test

* chmod

* bugfix for grad_op output var's mesh

* Modify codes for pr 36744

* Remove unnecessary comments in framework.py

* Remove unnecessary comments in completion.py
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Parent 2e40cfb5
@@ -43,10 +43,6 @@ from .collective import wait # noqa: F401
 from .auto_parallel import shard_op # noqa: F401
 from .auto_parallel import shard_tensor # noqa: F401
-from .auto_parallel import set_shard_mask # noqa: F401
-from .auto_parallel import set_offload_device # noqa: F401
-from .auto_parallel import set_pipeline_stage # noqa: F401
-from .auto_parallel import ProcessMesh # noqa: F401
 from .fleet import BoxPSDataset # noqa: F401
...
@@ -14,10 +14,11 @@
 from .interface import shard_tensor # noqa: F401
 from .interface import shard_op # noqa: F401
-from .interface import set_shard_mask # noqa: F401
-from .interface import set_offload_device # noqa: F401
-from .interface import set_pipeline_stage # noqa: F401
-from .interface import ProcessMesh # noqa: F401
+from .process_mesh import ProcessMesh
+# from .interface import set_shard_mask # noqa: F401
+# from .interface import set_offload_device # noqa: F401
+# from .interface import set_pipeline_stage # noqa: F401
+# from .interface import ProcessMesh # noqa: F401
 from .completion import complete_annotation # noqa: F401
 from .completion import complete_backward_annotation # noqa: F401
 from .reshard import reshard # noqa: F401
...
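After this change the public entry points for these symbols live in the auto_parallel package rather than at the paddle.distributed top level. A minimal sketch of the resulting import surface, assuming a Paddle build that contains this commit (illustrative only; the exact set of exported symbols depends on the release):

# Illustrative imports; availability depends on the installed Paddle build.
from paddle.distributed.auto_parallel import ProcessMesh
from paddle.distributed.auto_parallel import shard_tensor, shard_op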
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import core
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
if self.get_owner_tensor().type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
if "read" in self.get_owner_op().type:
return True
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
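To make the is_valid rule above concrete: a dims_mapping holds one entry per tensor dimension, each entry is either -1 (replicated) or the index of a process-mesh axis, and no mesh axis may shard more than one tensor dimension. Below is a small standalone sketch of that check in plain Python, independent of Paddle; check_dims_mapping and the example shapes and mesh topology are hypothetical names used only for illustration.

def check_dims_mapping(tensor_shape, dims_mapping, mesh_topology):
    # Mirrors the validity rule in TensorDistributedAttribute.is_valid above.
    if len(tensor_shape) != len(dims_mapping):
        return False
    for dim in dims_mapping:
        # Each entry is -1 (replicated) or a valid mesh-axis index.
        if dim < -1 or dim >= len(mesh_topology):
            return False
    for axis in range(len(mesh_topology)):
        # A mesh axis may shard at most one tensor dimension.
        if dims_mapping.count(axis) > 1:
            return False
    return True

# A 2D tensor on a 2x4 process mesh: dim 0 sharded over mesh axis 1, dim 1 replicated.
assert check_dims_mapping([512, 1024], [1, -1], [2, 4])
# Invalid: mesh axis 0 would shard two tensor dimensions.
assert not check_dims_mapping([512, 1024], [0, 0], [2, 4])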
@@ -20,10 +20,13 @@ from paddle.fluid import framework
 from .utils import compute_compatible_process_mesh
 from .utils import compute_compatible_dim_mapping
 from .utils import compute_compatible_dims_mapping
-from .utils import print_program_with_distributed_attr
-from .context import get_default_distributed_context
+from .utils import print_program_with_dist_attr
 from .operators import find_best_compatible_distributed_operator_impl
-from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute
+from .dist_context import get_default_distributed_context
+from .dist_tensor import DistributedTensor
+from .dist_op import DistributedOperator
+from .dist_attribute import TensorDistributedAttribute
+from .dist_attribute import OperatorDistributedAttribute
 from paddle.distributed.fleet.meta_optimizers.common import OpRole

 ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
...@@ -43,36 +46,35 @@ def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): ...@@ -43,36 +46,35 @@ def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
process meshes are compatible for now. process meshes are compatible for now.
""" """
changed = False changed = False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"): if tensor_dist_attr.is_annotated("process_mesh"):
return changed return changed
tensor_process_mesh = tensor_dist_attr.get_process_mesh() tensor_process_mesh = tensor_dist_attr.process_mesh
if fwd: if fwd:
inputs_process_meshes = [] inputs_process_meshes = []
for pred_op_node in tensor_node.inputs: for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None: if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node) pred_op_node)
op_process_mesh = op_dist_attr.get_process_mesh() op_process_mesh = op_dist_attr.process_mesh
inputs_process_meshes.append(op_process_mesh) inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh( compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes) inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None: if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh) tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True changed = True
else: else:
outputs_process_meshes = [] outputs_process_meshes = []
for succ_op_node in tensor_node.outputs: for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None: if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node) succ_op_node)
op_process_mesh = op_dist_attr.get_process_mesh() op_process_mesh = op_dist_attr.process_mesh
outputs_process_meshes.append(op_process_mesh) outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh( compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes) outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None: if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh) tensor_dist_attr.process_mesh = compatible_process_mesh
changed = True changed = True
return changed return changed
...@@ -84,43 +86,47 @@ def update_op_node_process_mesh(dist_context, op_node, fwd=True): ...@@ -84,43 +86,47 @@ def update_op_node_process_mesh(dist_context, op_node, fwd=True):
process meshes are compatible for now. process meshes are compatible for now.
""" """
changed = False changed = False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"): if op_dist_attr.is_annotated("process_mesh"):
return changed return changed
op_process_mesh = op_dist_attr.get_process_mesh() op_process_mesh = op_dist_attr.process_mesh
if fwd: if fwd:
inputs_process_meshes = [] inputs_process_meshes = []
for tensor_node in op_node.inputs: for tensor_node in op_node.inputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh() tensor_process_mesh = tensor_dist_attr.process_mesh
inputs_process_meshes.append(tensor_process_mesh) inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh( compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes) inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None: if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh) op_dist_attr.process_mesh = compatible_process_mesh
changed = True changed = True
else: else:
outputs_process_meshes = [] outputs_process_meshes = []
for tensor_node in op_node.outputs: for tensor_node in op_node.outputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh() tensor_process_mesh = tensor_dist_attr.process_mesh
outputs_process_meshes.append(tensor_process_mesh) outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh( compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes) outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None: if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh) op_dist_attr.process_mesh = compatible_process_mesh
changed = True changed = True
return changed return changed
def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): def update_op_dims_mapping_by_default_dist_impl(dist_context, op_node):
"""Each operator has a default distributed operator, only allowed to be sharded in batch dimension.""" """Each operator has a default distributed operator, only allowed to be sharded in batch dimension."""
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
dist_op = dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
# The following statement will be replaced by a more elegent way # The following statement will be replaced by a more elegent way
if op_desc.type() == "shape" or op_desc.type() == "slice": if op_desc.type() == "shape" or op_desc.type() == "slice":
return False return False
...@@ -130,7 +136,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): ...@@ -130,7 +136,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
xshape_arg_names = op_desc.output("XShape") xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = [] batch_dim_mappings = []
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name): serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1: if len(dims_mapping) > 1:
...@@ -140,7 +147,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): ...@@ -140,7 +147,8 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
.format(op_desc.type(), idx, mapping) .format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0]) batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name): serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
...@@ -164,14 +172,16 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): ...@@ -164,14 +172,16 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping." assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names(): for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name): serial_tensor = dist_op.get_serial_input(arg_name)
if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]: if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping dims_mapping[0] = compatible_dim_mapping
changed = True changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name): serial_tensor = dist_op.get_serial_output(arg_name)
if serial_tensor.is_parameter:
continue continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names: if arg_name not in xshape_arg_names:
...@@ -186,10 +196,13 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): ...@@ -186,10 +196,13 @@ def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
return changed return changed
def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr): def update_op_dims_mapping_by_elementwise_like_dist_impl(dist_context, op_node):
"""Element-wise operator can be sharded in any way (but should take care of broadcasting).""" """Element-wise operator can be sharded in any way (but should take care of broadcasting)."""
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
op_dist_attr = dist_context.get_op_dist_attr_for_graph(op_node)
input_arg_names = op_desc.input_arg_names() input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {} input_dims_mapping_dict = {}
...@@ -258,12 +271,11 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -258,12 +271,11 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
# Skip reader tensor # Skip reader tensor
if tensor_desc.type() == core.VarDesc.VarType.READER: if tensor_desc.type() == core.VarDesc.VarType.READER:
return False return False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(tensor_node)
tensor_node)
assert tensor_dist_attr is not None assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"): if tensor_dist_attr.is_annotated("dims_mapping"):
return False return False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dims_mapping = tensor_dist_attr.dims_mapping
if fwd: if fwd:
dims_mapping_list = [] dims_mapping_list = []
for pred_op_node in tensor_node.inputs: for pred_op_node in tensor_node.inputs:
...@@ -272,7 +284,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -272,7 +284,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or pred_op_node.op().type() == "create_double_buffer_reader" \ or pred_op_node.op().type() == "create_double_buffer_reader" \
or pred_op_node.op().type() == "read": or pred_op_node.op().type() == "read":
continue continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
pred_op_node) pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping( op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name()) tensor_desc.name())
...@@ -282,7 +294,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -282,7 +294,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list) dims_mapping_list)
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping): (compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) tensor_dist_attr.dims_mapping = compatible_dims_mapping
changed = True changed = True
else: else:
dims_mapping_list = [] dims_mapping_list = []
...@@ -292,7 +304,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -292,7 +304,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
or succ_op_node.op().type() == "create_double_buffer_reader" \ or succ_op_node.op().type() == "create_double_buffer_reader" \
or succ_op_node.op().type() == "read": or succ_op_node.op().type() == "read":
continue continue
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
succ_op_node) succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping( op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name()) tensor_desc.name())
...@@ -302,7 +314,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): ...@@ -302,7 +314,7 @@ def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
dims_mapping_list) dims_mapping_list)
if (compatible_dims_mapping is not None) and \ if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping): (compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) tensor_dist_attr.dims_mapping = compatible_dims_mapping
changed = True changed = True
return changed return changed
...@@ -317,7 +329,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -317,7 +329,8 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
or op_desc.type() == "create_double_buffer_reader" \ or op_desc.type() == "create_double_buffer_reader" \
or op_desc.type() == "read": or op_desc.type() == "read":
return False return False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) dist_op = dist_context.get_dist_op_for_graph(op_node)
op_dist_attr = dist_op.dist_attr
if fwd: if fwd:
for tensor_node in op_node.inputs: for tensor_node in op_node.inputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
...@@ -327,9 +340,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -327,9 +340,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_input_dims_mapping( if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()): tensor_desc.name()):
continue continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_input_dims_mapping( op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
...@@ -341,26 +354,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -341,26 +354,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=True) op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None: if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
# This statement will be replaced by a good way # This statement will be replaced by a good way
if op_dist_impl.is_compatible(op_dist_attr): if op_dist_impl.is_compatible(dist_op):
op_dist_attr.set_impl_idx(op_dist_impl_idx) op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()): elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr) dist_context, op_node)
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_impl_idx(-1) op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else: else:
dim_changed = update_op_dims_mapping_by_default_dist_impl( dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr) dist_context, op_node)
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_impl_idx(-2) op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
else: else:
for tensor_node in op_node.outputs: for tensor_node in op_node.outputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
...@@ -370,9 +386,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -370,9 +386,9 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
if op_dist_attr.is_annotated_output_dims_mapping( if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()): tensor_desc.name()):
continue continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dims_mapping = tensor_dist_attr.dims_mapping
op_dims_mapping = op_dist_attr.get_output_dims_mapping( op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name()) tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping( compatible_dims_mapping = compute_compatible_dims_mapping(
...@@ -384,26 +400,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True): ...@@ -384,26 +400,29 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=False) op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None: if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
# This statement will be replaced by a good way # This statement will be replaced by a good way
if op_dist_impl.is_compatible(op_dist_attr): if op_dist_impl.is_compatible(dist_op):
op_dist_attr.set_impl_idx(op_dist_impl_idx) op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()): elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr) dist_context, op_node)
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_impl_idx(-1) op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else: else:
dim_changed = update_op_dims_mapping_by_default_dist_impl( dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr) dist_context, op_node)
if dim_changed: if dim_changed:
changed = True changed = True
op_dist_attr.set_impl_idx(-2) op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
return changed return changed
...@@ -421,18 +440,20 @@ def complete_annotation(program, dist_context=None): ...@@ -421,18 +440,20 @@ def complete_annotation(program, dist_context=None):
# Use the default distribted context for completeion if there is no one # Use the default distribted context for completeion if there is no one
if dist_context is None: if dist_context is None:
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
dist_context.serial_program = program
else:
dist_context.serial_program = program
# Initialize distributed attributes for all var and op node in program # print_program_with_dist_attr(program, dist_context)
dist_context.initialize_distributed_attr_for_program(program)
# Convert program to graph # Initialize distributed attributes for all var and op node in program
graph = framework.IrGraph(core.Graph(program.desc)) dist_context.init_dist_attr_for_program()
# Initialize distributed attributes for all var and op node in graph # Initialize distributed attributes for all var and op node in graph
dist_context.initialize_distributed_attr_for_graph(graph) dist_context.init_dist_attr_for_graph()
# Complete process mesh for each node # Complete process mesh for each node
all_nodes = list(graph.all_nodes()) all_nodes = list(dist_context.serial_graph.all_nodes())
def sort_key_fun(node): def sort_key_fun(node):
first = -1 first = -1
...@@ -498,27 +519,27 @@ def complete_annotation(program, dist_context=None): ...@@ -498,27 +519,27 @@ def complete_annotation(program, dist_context=None):
is_wrong = False is_wrong = False
for node in all_nodes: for node in all_nodes:
if node.is_var() and node.var() is not None: if node.is_var() and node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
node) node)
if tensor_dist_attr.get_process_mesh() is None: if tensor_dist_attr.process_mesh is None:
msg_str = "" msg_str = ""
for op_node in node.inputs: for op_node in node.inputs:
if op_node.op() is not None: if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node) op_node)
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
op_node.op().type(), op_node.op().type(),
op_dist_attr.get_process_mesh()) op_dist_attr.process_mesh)
else: else:
msg_str += "{} [{}], ".format(op_node.name(), msg_str += "{} [{}], ".format(op_node.name(),
None) None)
for op_node in node.outputs: for op_node in node.outputs:
if op_node.op() is not None: if op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(
op_node) op_node)
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
op_node.op().type(), op_node.op().type(),
op_dist_attr.get_process_mesh()) op_dist_attr.process_mesh)
else: else:
msg_str += "{} [{}], ".format(op_node.name(), msg_str += "{} [{}], ".format(op_node.name(),
None) None)
...@@ -527,27 +548,26 @@ def complete_annotation(program, dist_context=None): ...@@ -527,27 +548,26 @@ def complete_annotation(program, dist_context=None):
is_wrong = True is_wrong = True
print(msg_str) print(msg_str)
if node.is_op() and node.op() is not None: if node.is_op() and node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph( op_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
node) if op_dist_attr.process_mesh is None:
if op_dist_attr.get_process_mesh() is None:
msg_str = "" msg_str = ""
for tensor_node in node.inputs: for tensor_node in node.inputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
tensor_node.var().name(), tensor_node.var().name(),
tensor_dist_attr.get_process_mesh()) tensor_dist_attr.process_mesh)
else: else:
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
tensor_node.name(), None) tensor_node.name(), None)
for tensor_node in node.outputs: for tensor_node in node.outputs:
if tensor_node.var() is not None: if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_graph(
tensor_node) tensor_node)
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
tensor_node.var().name(), tensor_node.var().name(),
tensor_dist_attr.get_process_mesh()) tensor_dist_attr.process_mesh)
else: else:
msg_str += "{} [{}], ".format( msg_str += "{} [{}], ".format(
tensor_node.name(), None) tensor_node.name(), None)
...@@ -592,11 +612,14 @@ def complete_annotation(program, dist_context=None): ...@@ -592,11 +612,14 @@ def complete_annotation(program, dist_context=None):
reach_fix_point = True reach_fix_point = True
# Copy the corresponding distributed attribute from graph to program # Copy the corresponding distributed attribute from graph to program
dist_context.copy_distribute_attr_from_graph_to_program(graph, program) dist_context.copy_dist_attr_from_graph_to_program()
dist_context.clear_distributed_attr_for_graph() dist_context.clear_dist_info_for_graph()
# Do the validation check and amend some completion # Do the validation check and amend some completion
dist_context.amend_distributed_attr_for_program() dist_context.amend_dist_attr_for_program()
# print_program_with_dist_attr(program, dist_context)
dist_context.validate_dist_attr_for_program()
return program return program
...@@ -636,7 +659,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -636,7 +659,7 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ops = list(auto_parallel_main_prog.global_block().ops) ops = list(auto_parallel_main_prog.global_block().ops)
vars = auto_parallel_main_prog.global_block().vars vars = auto_parallel_main_prog.global_block().vars
dist_op_helper = dist_context.get_dist_op_helper() dist_op_context = dist_context.dist_op_context
for idx in range(first_backward_op_idx, len(ops)): for idx in range(first_backward_op_idx, len(ops)):
...@@ -658,45 +681,42 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -658,45 +681,42 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
forward_var = vars[forward_var_name] forward_var = vars[forward_var_name]
# TODO complete other attribte for grad var # TODO complete other attribte for grad var
tensor_attr = TensorDistributedAttribute(grad_var, dist_context) tensor_dist_attr = TensorDistributedAttribute()
process_mesh = dist_context.get_tensor_distributed_attr_for_program( process_mesh = dist_context.get_tensor_dist_attr_for_program(
forward_var).get_process_mesh() forward_var).process_mesh
dims_mapping = dist_context.get_tensor_distributed_attr_for_program( dims_mapping = dist_context.get_tensor_dist_attr_for_program(
forward_var).get_dims_mapping() forward_var).dims_mapping
tensor_attr.set_dims_mapping(dims_mapping) tensor_dist_attr.dims_mapping = dims_mapping
tensor_attr.set_process_mesh(process_mesh) tensor_dist_attr.process_mesh = process_mesh
dist_context.set_tensor_distributed_attr_for_program(grad_var, dist_context.set_tensor_dist_attr_for_program(grad_var,
tensor_attr) tensor_dist_attr)
op_attr = OperatorDistributedAttribute(ops[idx], dist_context) op_dist_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_dist_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, dims_mapping) op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
dist_context.set_op_distributed_attr_for_program(ops[idx], op_attr) dist_context.set_op_dist_attr_for_program(ops[idx], op_dist_attr)
continue continue
# complete the annotation of grad op (xxx_grad op or sum op) # complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in gradopidx2opidx # xxx_grad op will have a corresponding forward op in gradopidx2opidx
grad_op = ops[idx] grad_op = ops[idx]
if grad_op.desc.id() in dist_op_helper.gradopidx2opidx: if grad_op.desc.id() in dist_op_context.gradopidx2opidx:
# TODO support the case where one forward op corresponding to multiple xxx_grad op # TODO support the case where one forward op corresponding to multiple xxx_grad op
forward_op = _get_op_by_id( forward_op = _get_op_by_id(
ops[:first_backward_op_idx], ops[:first_backward_op_idx],
dist_op_helper.gradopidx2opidx[grad_op.desc.id()]) dist_op_context.gradopidx2opidx[grad_op.desc.id()])
assert forward_op is not None assert forward_op is not None
# op dist attr # op dist attr
forward_op_attr = dist_context.get_op_distributed_attr_for_program( forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op) forward_op)
forward_op_process_mesh = forward_op_attr.get_process_mesh() forward_op_process_mesh = forward_op_dist_attr.process_mesh
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_attr.set_process_mesh(forward_op_process_mesh) grad_op_dist_attr.process_mesh = forward_op_process_mesh
# var # var
for output_name in grad_op.desc.output_names(): for output_name in grad_op.desc.output_names():
assert len(grad_op.desc.output(output_name)) in [0, 1] assert len(grad_op.desc.output(output_name)) in [0, 1]
# if grad_op.type == "cast":
# input_name = "X"
# else:
if _is_grad_var_name(output_name): if _is_grad_var_name(output_name):
input_name = _get_forward_varname_from_grad_varname( input_name = _get_forward_varname_from_grad_varname(
output_name) output_name)
...@@ -711,39 +731,38 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -711,39 +731,38 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
if len(grad_op.desc.output(output_name)) == 1: if len(grad_op.desc.output(output_name)) == 1:
assert len(forward_op.desc.input(input_name)) == 1 assert len(forward_op.desc.input(input_name)) == 1
input_var = vars[forward_op.desc.input(input_name)[0]] input_var = vars[forward_op.desc.input(input_name)[0]]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var) input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name) input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping() ref_dims_mapping = input_var_dist_attr.dims_mapping
# tensor dist attr # tensor dist attr
output_var = vars[grad_op.desc.output(output_name)[0]] output_var = vars[grad_op.desc.output(output_name)[0]]
output_var_attr = TensorDistributedAttribute(output_var, output_var_dist_attr = TensorDistributedAttribute()
dist_context) output_var_dist_attr.dims_mapping = ref_dims_mapping
output_var_attr.set_dims_mapping(ref_dims_mapping) output_var_dist_attr.process_mesh = forward_op_process_mesh
output_var_attr.set_process_mesh(forward_op_process_mesh) dist_context.set_tensor_dist_attr_for_program(
dist_context.set_tensor_distributed_attr_for_program( output_var, output_var_dist_attr)
output_var, output_var_attr)
# op dist attr # op dist attr
grad_op_attr.set_output_dims_mapping(output_var.name, grad_op_dist_attr.set_output_dims_mapping(output_var.name,
ref_dims_mapping) ref_dims_mapping)
for input_name in grad_op.input_arg_names: for input_name in grad_op.input_arg_names:
input_var = vars[input_name] input_var = vars[input_name]
input_var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( input_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var) input_var)
assert input_var_dist_attr is not None, "[{}] has not dist attribute".format( assert input_var_dist_attr is not None, "[{}] has not dist attribute".format(
input_var.name) input_var.name)
ref_dims_mapping = input_var_dist_attr.get_dims_mapping() ref_dims_mapping = input_var_dist_attr.dims_mapping
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_var.name) input_var.name)
grad_op_attr.set_input_dims_mapping(input_name, grad_op_dist_attr.set_input_dims_mapping(input_name,
ref_dims_mapping) ref_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op, dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_attr) grad_op_dist_attr)
# only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx # only sum op for merge mutiple version grad has no a corresponding mapping in gradopidx2opidx
else: else:
...@@ -755,32 +774,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None): ...@@ -755,32 +774,31 @@ def complete_backward_annotation(auto_parallel_main_prog, dist_context=None):
ref_forward_var_name = _get_forward_varname_from_grad_varname( ref_forward_var_name = _get_forward_varname_from_grad_varname(
grad_op.output_arg_names[0]) grad_op.output_arg_names[0])
forward_var = vars[ref_forward_var_name] forward_var = vars[ref_forward_var_name]
ref_forward_var_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( ref_forward_var_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
forward_var).get_dims_mapping() forward_var).dims_mapping
ref_forward_var_process_mesh = dist_context.get_tensor_distributed_attr_for_program( ref_forward_var_process_mesh = dist_context.get_tensor_dist_attr_for_program(
forward_var).get_process_mesh() forward_var).process_mesh
# output # output
tensor_attr = TensorDistributedAttribute( tensor_dist_attr = TensorDistributedAttribute()
vars[grad_op.output_arg_names[0]], dist_context) tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
tensor_attr.set_dims_mapping(ref_forward_var_dims_mapping) tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
tensor_attr.set_process_mesh(ref_forward_var_process_mesh) dist_context.set_tensor_dist_attr_for_program(
dist_context.set_tensor_distributed_attr_for_program( vars[grad_op.output_arg_names[0]], tensor_dist_attr)
vars[grad_op.output_arg_names[0]], tensor_attr)
# op # op
grad_op_attr = OperatorDistributedAttribute(grad_op, dist_context) grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_attr.set_process_mesh(ref_forward_var_process_mesh) grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
for var_name in grad_op.input_arg_names: for var_name in grad_op.input_arg_names:
assert _get_forward_varname_from_grad_varname( assert _get_forward_varname_from_grad_varname(
var_name) == ref_forward_var_name var_name) == ref_forward_var_name
grad_op_attr.set_input_dims_mapping( grad_op_dist_attr.set_input_dims_mapping(
var_name, ref_forward_var_dims_mapping) var_name, ref_forward_var_dims_mapping)
grad_op_attr.set_output_dims_mapping(grad_op.output_arg_names[0], grad_op_dist_attr.set_output_dims_mapping(
ref_forward_var_dims_mapping) grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
dist_context.set_op_distributed_attr_for_program(grad_op, dist_context.set_op_dist_attr_for_program(grad_op,
grad_op_attr) grad_op_dist_attr)
def complete_update_annotation(auto_parallel_main_prog, dist_context): def complete_update_annotation(auto_parallel_main_prog, dist_context):
...@@ -808,39 +826,40 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ...@@ -808,39 +826,40 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
param = vars[op.input("Param")[0]] param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]] grad_var = vars[op.input("Grad")[0]]
param_dist_attr = dist_context.get_tensor_distributed_attr_for_program( param_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param) param)
grad_dist_attr = dist_context.get_tensor_distributed_attr_for_program( grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
grad_var) grad_var)
assert param_dist_attr is not None assert param_dist_attr is not None
assert grad_dist_attr is not None assert grad_dist_attr is not None
assert param_dist_attr.get_dims_mapping( assert param_dist_attr.dims_mapping == grad_dist_attr.dims_mapping
) == grad_dist_attr.get_dims_mapping()
ref_process_mesh = dist_context.get_tensor_distributed_attr_for_program( ref_process_mesh = dist_context.get_tensor_dist_attr_for_program(
param).get_process_mesh() param).process_mesh
assert ref_process_mesh is not None assert ref_process_mesh is not None
ref_dims_mapping = dist_context.get_tensor_distributed_attr_for_program( ref_dims_mapping = dist_context.get_tensor_dist_attr_for_program(
param).get_dims_mapping() param).dims_mapping
assert ref_dims_mapping is not None assert ref_dims_mapping is not None
op_attr = OperatorDistributedAttribute(op, dist_context) op_dist_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(ref_process_mesh) op_dist_attr.process_mesh = ref_process_mesh
op_attr.set_input_dims_mapping(grad_var.name, ref_dims_mapping) op_dist_attr.set_input_dims_mapping(grad_var.name,
op_attr.set_input_dims_mapping(param.name, ref_dims_mapping) ref_dims_mapping)
op_attr.set_output_dims_mapping(param.name, ref_dims_mapping) op_dist_attr.set_input_dims_mapping(param.name,
ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(param.name,
ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]] learning_var = vars[op.input("LearningRate")[0]]
op_attr.set_input_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_attr.set_output_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_output_dims_mapping(learning_var.name, [-1])
if not learning_rate_completed: if not learning_rate_completed:
learning_rate_completed = True learning_rate_completed = True
var_dist_attr = TensorDistributedAttribute(learning_var, var_dist_attr = TensorDistributedAttribute()
dist_context) var_dist_attr.process_mesh = ref_process_mesh
var_dist_attr.set_process_mesh(ref_process_mesh) var_dist_attr.dims_mapping = [-1]
var_dist_attr.set_dims_mapping([-1]) dist_context.set_tensor_dist_attr_for_program(learning_var,
dist_context.set_tensor_distributed_attr_for_program( var_dist_attr)
learning_var, var_dist_attr)
for input_name in op.desc.input_names(): for input_name in op.desc.input_names():
...@@ -853,24 +872,25 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context): ...@@ -853,24 +872,25 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
assert len(op.desc.input(input_name)) == 1 assert len(op.desc.input(input_name)) == 1
input_var = vars[op.desc.input(input_name)[0]] input_var = vars[op.desc.input(input_name)[0]]
input_var_attr = TensorDistributedAttribute(input_var, input_var_attr = TensorDistributedAttribute()
dist_context)
if "Beta1Pow" in input_name or "Beta2Pow" in input_name: if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
input_var_attr.set_dims_mapping([-1]) input_var_attr.dims_mapping = [-1]
op_attr.set_input_dims_mapping(input_var.name, [-1]) op_dist_attr.set_input_dims_mapping(input_var.name,
op_attr.set_output_dims_mapping(input_var.name, [-1]) [-1])
op_dist_attr.set_output_dims_mapping(input_var.name,
[-1])
else: else:
assert "Moment" in input_name assert "Moment" in input_name
input_var_attr.set_dims_mapping(ref_dims_mapping) input_var_attr.dims_mapping = ref_dims_mapping
op_attr.set_input_dims_mapping(input_var.name, op_dist_attr.set_input_dims_mapping(input_var.name,
ref_dims_mapping) ref_dims_mapping)
op_attr.set_output_dims_mapping(input_var.name, op_dist_attr.set_output_dims_mapping(input_var.name,
ref_dims_mapping) ref_dims_mapping)
input_var_attr.set_process_mesh(ref_process_mesh) input_var_attr.process_mesh = ref_process_mesh
dist_context.set_tensor_distributed_attr_for_program( dist_context.set_tensor_dist_attr_for_program(
input_var, input_var_attr) input_var, input_var_attr)
dist_context.set_op_distributed_attr_for_program(op, op_attr) dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
continue continue
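The completion pass in the diff above propagates process meshes and dims mappings to a fixed point: nodes are swept in forward order and then in reverse, each node is updated from its neighbours, and the loop stops once a full pass changes nothing (the reach_fix_point flag visible above). The following is a schematic of that control flow only, not the actual Paddle implementation; complete_to_fixed_point and update_node are hypothetical stand-ins for the update_*_node_* helpers shown in the diff.

def complete_to_fixed_point(nodes, update_node):
    # Repeat forward/backward sweeps until no node's distributed
    # attributes change during an entire pass.
    reach_fix_point = False
    while not reach_fix_point:
        changed = False
        for node in nodes:  # forward sweep
            if update_node(node, fwd=True):
                changed = True
        for node in reversed(nodes):  # backward sweep
            if update_node(node, fwd=False):
                changed = True
        reach_fix_point = not changed
    return nodes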
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
from .interface import _g_process_mesh_map
# There always exists a default context for user. And user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
def get_default_distributed_context():
global DEFAULT_DISTRIBUTED_CONTEXT
if DEFAULT_DISTRIBUTED_CONTEXT is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return DEFAULT_DISTRIBUTED_CONTEXT
def set_default_distributed_context(dist_context):
global DEFAULT_DISTRIBUTED_CONTEXT
DEFAULT_DISTRIBUTED_CONTEXT = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
One auto-parallel run should use its own DistributedContext to avoid interfering other run.
"""
def __init__(self):
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._tensor_distributed_attr_map_for_program = {}
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
self._get_dist_op_helper = DistOpHelper()
self._process_mesh = _g_process_mesh_map.get(0, None)
def is_initialized_for_program(self):
return self._is_initialized_for_program
def is_initialized_for_graph(self):
return self._is_initialized_for_graph
def get_tensor_distributed_attr_for_program(self, tensor):
tensor_id = tensor.desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get(
tensor_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr):
tensor_id = tensor.desc.id()
self._tensor_distributed_attr_map_for_program[
tensor_id] = tensor_dist_attr
def get_op_distributed_attr_for_program(self, op):
op_id = op.desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_program(self, op, op_dist_attr):
op_id = op.desc.id()
self._op_distributed_attr_map_for_program[op_id] = op_dist_attr
def get_tensor_distributed_attr_for_graph(self, tensor_node):
tensor_node_id = tensor_node.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get(
tensor_node_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_graph(self, tensor_node,
tensor_dist_attr):
tensor_node_id = tensor_node.id()
self._tensor_distributed_attr_map_for_graph[
tensor_node_id] = tensor_dist_attr
def get_op_distributed_attr_for_graph(self, op_node):
op_node_id = op_node.id()
op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def set_process_mesh(self, process_mesh):
self._process_mesh = process_mesh
def get_dist_op_helper(self):
return self._get_dist_op_helper
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
for block in program.blocks:
for tensor in block.vars.values():
# Since only tensors have distributed attributes, it's better to make sure var is a tensor
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is None:
tensor_dist_attr = TensorDistributedAttribute(tensor, self)
self._copy_distributed_attr_from_tensor_desc(
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
if tensor.type == core.VarDesc.VarType.READER:
tensor_dist_attr.set_shape([])
else:
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor_dist_attr.get_shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
tensor_dist_attr.mark_as_annotated("dims_mapping")
if isinstance(tensor, framework.Parameter):
tensor_dist_attr.mark_as_parameter()
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
op_dist_attr = OperatorDistributedAttribute(op, self)
self._copy_distributed_attr_from_op_desc(op.desc,
op_dist_attr)
self.set_op_distributed_attr_for_program(op, op_dist_attr)
# Default distributed implementation for all operators
# This will be updated during the completion process
op_dist_attr.set_impl_idx(-2)
if op_dist_attr.get_process_mesh() is not None:
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
if op.type == "create_py_reader" \
or tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_input_shape(tensor_name, [])
else:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1
for _ in range(
len(op_dist_attr.get_input_shape(tensor_name)))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_input_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
op_dist_attr.set_output_shape(tensor_name, [])
else:
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1
for _ in range(
len(
op_dist_attr.get_output_shape(tensor_name)))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_output_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
self._is_initialized_for_program = True
def finalize_distributed_attr_for_program(self, program):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before finalization."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is not None:
self._store_distributed_attr_to_tensor_desc(
tensor.desc, tensor_dist_attr)
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
self._store_distributed_attr_to_op_desc(op.desc,
op_dist_attr)
def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
attr_name = append_distributed_attr_suffix("dim_mapping")
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_dims_mapping(copied_dims_mapping)
attr_name = append_distributed_attr_suffix("mask")
if desc.has_attr(attr_name):
shard_mask = desc.attr(attr_name)
copied_shard_mask = copy.deepcopy(shard_mask)
dist_attr.set_shard_mask(copied_shard_mask)
attr_name = append_distributed_attr_suffix("offload_device")
if desc.has_attr(attr_name):
offload_device = desc.attr(attr_name)
copied_offload_device = copy.deepcopy(offload_device)
dist_attr.set_offload_device(copied_offload_device)
def _copy_distributed_attr_from_op_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_output_dims_mapping(tensor_name,
copied_dims_mapping)
attr_name = append_distributed_attr_suffix("pipeline_stage")
if desc.has_attr(attr_name):
pipeline_stage = desc.attr(attr_name)
copied_pipeline_stage = copy.deepcopy(pipeline_stage)
dist_attr.set_pipeline_stage(copied_pipeline_stage)
def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
dims_mapping = dist_attr.get_dims_mapping()
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("dim_mapping")
desc._set_attr(attr_name, dims_mapping)
shard_mask = dist_attr.get_shard_mask()
if shard_mask is not None:
attr_name = append_distributed_attr_suffix("mask")
desc._set_attr(attr_name, shard_mask)
offload_device = dist_attr.get_offload_device()
if offload_device is not None:
attr_name = append_distributed_attr_suffix("offload_device")
desc._set_attr(attr_name, offload_device)
def _store_distributed_attr_to_op_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
for tensor_name in desc.input_arg_names():
dims_mapping = dist_attr.get_input_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
for tensor_name in desc.output_arg_names():
dims_mapping = dist_attr.get_output_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
pipeline_stage = dist_attr.get_pipeline_stage()
if pipeline_stage is not None:
attr_name = append_distributed_attr_suffix("pipeline_stage")
desc._set_attr(attr_name, pipeline_stage)
def initialize_distributed_attr_for_graph(self, graph):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before its graph."
if self._is_initialized_for_graph:
return
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program[
tensor_id]
assert tensor_dist_attr is not None, \
"Tensor must have a distributed attribute after the initialization for program."
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self.set_tensor_distributed_attr_for_graph(node,
new_tensor_dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program[op_id]
assert op_dist_attr is not None, \
"Operator must have a distributed attribute after the initialization for program."
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self.set_op_distributed_attr_for_graph(node, new_op_dist_attr)
self._is_initialized_for_graph = True
def clear_distributed_attr_for_program(self):
self._tensor_distributed_attr_map_for_program.clear()
self._op_distributed_attr_map_for_program.clear()
def clear_distributed_attr_for_graph(self):
self._tensor_distributed_attr_map_for_graph.clear()
self._op_distributed_attr_map_for_graph.clear()
def copy_distribute_attr_from_graph_to_program(self, graph, program):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"The distribute attributes must be initialized both in its program and graph"
updated_tensors = {}
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr = self.get_tensor_distributed_attr_for_graph(
node)
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self._tensor_distributed_attr_map_for_program[
tensor_id] = new_tensor_dist_attr
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self.get_op_distributed_attr_for_graph(node)
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self._op_distributed_attr_map_for_program[
op_id] = new_op_dist_attr
def amend_distributed_attr_for_program(self):
for attr in self._tensor_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Tensor's distributed attribute {} is not valid".format(attr)
tensor_shape = attr.get_shape()
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Operator's distributed attribute {} is not valid".format(attr)
for arg_name in attr.get_owner_op().desc.input_arg_names():
tensor_shape = attr.get_input_shape(arg_name)
dims_mapping = attr.get_input_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names():
tensor_shape = attr.get_output_shape(arg_name)
dims_mapping = attr.get_output_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
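# A worked example (illustrative only, not part of the class above) of the
# amendment rule used in amend_distributed_attr_for_program: with a process
# mesh of topology [4] and a tensor of shape [2, 8], the mapping [0, -1] would
# try to shard a length-2 dimension across 4 processes, so it is reset to -1.
def _amend_demo():
    tensor_shape = [2, 8]
    process_mesh_shape = [4]
    dims_mapping = [0, -1]
    for i in range(len(tensor_shape)):
        if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
            dims_mapping[i] = -1
    assert dims_mapping == [-1, -1]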
class DistOpHelper:
"""
DistOpHelper is used to create a dist op desc in a Program.
Each time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
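# A hedged usage sketch of DistOpHelper as a partitioner might drive it; the
# names dist_main_prog, dist_startup_prog, name_map and serial_op are
# hypothetical objects supplied by the caller, not defined in this file.
def _dist_op_helper_demo(dist_main_prog, dist_startup_prog, name_map, serial_op):
    helper = DistOpHelper()
    helper.set_dst_main_program(dist_main_prog)
    helper.set_dst_startup_program(dist_startup_prog)
    helper.set_varname_mapping(name_map)  # serial var name -> partitioned var name
    helper.set_rank_id(0)
    # Build the remapped input/output name dicts for the dist op implementation
    kinputs, koutputs = helper.prepare_forward_context(serial_op)
    return kinputs, koutputs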
...@@ -131,7 +131,7 @@ class TensorCostNode(CostNode):
        elif node.dtype == paddle.int64:
            self.dtype_factor *= 8
        else:
            raise NotImplementedError("{} not counted".format(node.dtype))
        self.batch_size = None
        if batch_size is not None:
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid.framework import Variable
from .process_mesh import ProcessMesh
_g_tensor_dist_attr_field_keys = [
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
_g_op_input_suffix = "@input"
_g_op_output_suffix = "@output"
def get_tensor_dist_attr_field_keys():
global _g_tensor_dist_attr_field_keys
return _g_tensor_dist_attr_field_keys
def get_op_dist_attr_field_keys():
global _g_op_dist_attr_field_keys
return _g_op_dist_attr_field_keys
def append_op_input_suffix(name):
global _g_op_input_suffix
return name + _g_op_input_suffix
def append_op_output_suffix(name):
global _g_op_output_suffix
return name + _g_op_output_suffix
class TensorDistributedAttribute:
def __init__(self):
# The process mesh of a distributed operator attribute must be the same as
# the process meshes of all its input and output distributed attributes
self._process_mesh = None
self._dims_mapping = None
self._shard_sizes = None
self._device_placement = None
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
@property
def dims_mapping(self):
return self._dims_mapping
@dims_mapping.setter
def dims_mapping(self, dims_mapping):
if dims_mapping is not None:
assert isinstance(dims_mapping, list), \
"The type of dims_mapping must be list."
assert all(isinstance(x, int) for x in dims_mapping), \
("All elements of dims_mapping must be integer")
assert all(x >= -1 for x in dims_mapping), \
("All elements of dims_mapping must be greater than or equal to -1.")
self._dims_mapping = copy.deepcopy(dims_mapping)
@property
def shard_sizes(self):
return self._shard_sizes
@shard_sizes.setter
def shard_sizes(self, shard_sizes):
if shard_sizes is not None:
self._shard_sizes = copy.deepcopy(shard_sizes)
@property
def device_placement(self):
return self._device_placement
@device_placement.setter
def device_placement(self, device_placement):
if device_placement is not None:
self._device_placement = copy.deepcopy(device_placement)
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def is_annotated(self, dist_attr_field_name):
return self._is_annotated.get(dist_attr_field_name, False)
def mark_annotated(self, dist_attr_field_name):
self._is_annotated[dist_attr_field_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be dict or TensorDistributedAttribute."
if isinstance(dist_attr, dict):
for key in dist_attr.keys():
if key in get_tensor_dist_attr_field_keys():
self.mark_annotated(key)
elif isinstance(dist_attr, TensorDistributedAttribute):
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
def clear_annotated(self):
self._is_annotated.clear()
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tdims_mapping ({}): {}".format(annotated_str,
self.dims_mapping)
str += "\n\t}"
return str
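# A minimal sketch of how TensorDistributedAttribute is expected to be used,
# assuming ProcessMesh accepts a nested list of process ids (the property
# setter above converts the list form automatically).
def _tensor_dist_attr_demo():
    tensor_dist_attr = TensorDistributedAttribute()
    tensor_dist_attr.init({
        "process_mesh": [[0, 1], [2, 3]],  # 2 x 2 mesh of logical processes
        "dims_mapping": [0, -1]  # shard dim 0 along mesh axis 0, replicate dim 1
    })
    tensor_dist_attr.mark_annotated("dims_mapping")
    assert tensor_dist_attr.is_annotated("dims_mapping")
    return tensor_dist_attr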
class OperatorDistributedAttribute:
def __init__(self):
self._process_mesh = None
self._impl_type = None
self._impl_idx = None
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
@property
def process_mesh(self):
return self._process_mesh
@process_mesh.setter
def process_mesh(self, process_mesh):
if process_mesh is not None:
assert isinstance(process_mesh, (list, ProcessMesh)), \
"The type of process_mesh must be list or ProcessMesh."
if isinstance(process_mesh, list):
process_mesh = ProcessMesh(process_mesh)
self._process_mesh = copy.deepcopy(process_mesh)
for dist_attr in self._inputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
for dist_attr in self._outputs_dist_attrs.values():
dist_attr.process_mesh = process_mesh
@property
def impl_type(self):
return self._impl_type
@impl_type.setter
def impl_type(self, impl_type):
if impl_type is not None:
self._impl_type = impl_type
@property
def impl_idx(self):
return self._impl_idx
@impl_idx.setter
def impl_idx(self, impl_idx):
if impl_idx is not None:
self._impl_idx = impl_idx
@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
@property
def outputs_dist_attrs(self):
return self._outputs_dist_attrs
def get_input_dist_attr(self, name):
return self._inputs_dist_attrs.get(name, None)
def set_input_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._inputs_dist_attrs[name] = dist_attr_object
def get_output_dist_attr(self, name):
return self._outputs_dist_attrs.get(name, None)
def set_output_dist_attr(self, name, dist_attr):
dist_attr_object = TensorDistributedAttribute()
dist_attr_object.init(dist_attr)
self._outputs_dist_attrs[name] = dist_attr_object
def get_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
dims_mapping = input_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_input_dims_mapping(self, name, dims_mapping):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
input_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._inputs_dist_attrs[name] = dist_attr
def get_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
dims_mapping = output_dist_attr.dims_mapping
else:
dims_mapping = None
return dims_mapping
def set_output_dims_mapping(self, name, dims_mapping):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
output_dist_attr.dims_mapping = dims_mapping
else:
dist_attr = TensorDistributedAttribute()
dist_attr.dims_mapping = dims_mapping
self._outputs_dist_attrs[name] = dist_attr
def init(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.init(value)
if dist_attr.get(append_op_input_suffix(key.name), False):
self.set_input_dist_attr(key.name, tensor_dist_attr)
if dist_attr.get(append_op_output_suffix(key.name), False):
self.set_output_dist_attr(key.name, tensor_dist_attr)
else:
if key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(
key, None)
if field_property:
field_property.fset(self, value)
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
elif isinstance(dist_attr, OperatorDistributedAttribute):
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
self.set_input_dist_attr(
tensor_name, dist_attr.get_input_dist_attr(tensor_name))
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
self.set_output_dist_attr(
tensor_name, dist_attr.get_output_dist_attr(tensor_name))
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
for key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(key,
None)
if field_property:
field_property.fset(self, field_property.fget(dist_attr))
else:
assert False, "No setter for {} in args {}.".format(
key, dist_attr)
# Make sure all process meshes in the dist op are the same
process_meshes = []
process_meshes.append(self.process_mesh)
for tensor_dist_attr in self.inputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
for tensor_dist_attr in self.outputs_dist_attrs.values():
process_meshes.append(tensor_dist_attr.process_mesh)
shared_process_mesh = None
for process_mesh in process_meshes:
if process_mesh is not None:
if shared_process_mesh is None:
shared_process_mesh = process_mesh
else:
assert process_mesh == shared_process_mesh, \
"ProcessMeshes in DistributedOperator must be the same."
self.process_mesh = shared_process_mesh
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_annotated(self, attr_name):
if attr_name == "process_mesh":
# Make sure process_mesh is annotated consistently
self._is_annotated[attr_name] = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.mark_annotated(attr_name)
else:
self._is_annotated[attr_name] = True
def mark_annotated_as(self, dist_attr):
if dist_attr is None:
return
assert isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
if isinstance(dist_attr, dict):
for key, value in dist_attr.items():
if isinstance(key, Variable):
input_dist_attr = self.get_input_dist_attr(key.name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(value)
output_dist_attr = self.get_output_dist_attr(key.name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(value)
else:
if key in get_op_dist_attr_field_keys():
self.mark_annotated(key)
process_mesh_annotated = False
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.inputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_dist_attr in self.outputs_dist_attrs.values():
if tensor_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
elif isinstance(dist_attr, OperatorDistributedAttribute):
process_mesh_annotated = False
self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
if self.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.inputs_dist_attrs.items(
):
input_dist_attr = self.get_input_dist_attr(tensor_name)
if input_dist_attr is not None:
input_dist_attr.mark_annotated_as(tensor_dist_attr)
if input_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
for tensor_name, tensor_dist_attr in dist_attr.outputs_dist_attrs.items(
):
output_dist_attr = self.get_output_dist_attr(tensor_name)
if output_dist_attr is not None:
output_dist_attr.mark_annotated_as(tensor_dist_attr)
if output_dist_attr.is_annotated("process_mesh"):
process_mesh_annotated = True
if process_mesh_annotated:
self.mark_annotated("process_mesh")
def clear_annotated(self):
self._is_annotated.clear()
for tensor_dist_attr in self.inputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
for tensor_dist_attr in self.outputs_dist_attrs.values():
tensor_dist_attr.clear_annotated()
def is_annotated_input_dims_mapping(self, name):
input_dist_attr = self.get_input_dist_attr(name)
if input_dist_attr:
return input_dist_attr.is_annotated("dims_mapping")
else:
return False
def is_annotated_output_dims_mapping(self, name):
output_dist_attr = self.get_output_dist_attr(name)
if output_dist_attr:
return output_dist_attr.is_annotated("dims_mapping")
else:
return False
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += "\n\t\tprocess_mesh ({}): {},".format(annotated_str,
self.process_mesh)
for arg_name, tensor_dist_attr in self.inputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
for arg_name, tensor_dist_attr in self.outputs_dist_attrs.items():
str += "\n\t\t{}'s: {},".format(arg_name, tensor_dist_attr)
str += "\n\t\timpl type: {}, ".format(self._impl_type)
str += "impl idx: {}".format(self._impl_idx)
str += "\n\t}"
return str
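# A companion sketch for OperatorDistributedAttribute, again assuming
# ProcessMesh can be built from a nested list; "X" and "Out" are hypothetical
# operator argument names used only for illustration.
def _op_dist_attr_demo():
    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.init({
        "process_mesh": [[0, 1], [2, 3]],
        "impl_type": "default",
        "impl_idx": -2
    })
    op_dist_attr.set_input_dims_mapping("X", [0, -1])
    op_dist_attr.set_output_dims_mapping("Out", [0, -1])
    assert op_dist_attr.get_input_dims_mapping("X") == [0, -1]
    return op_dist_attr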
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None
def get_default_distributed_context():
global _g_default_distributed_context
if _g_default_distributed_context is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return _g_default_distributed_context
def set_default_distributed_context(dist_context):
global _g_default_distributed_context
_g_default_distributed_context = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for the program and graph.
One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self, program=None):
self._serial_program = program
self._serial_graph = None
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._dist_tensors_for_program = {}
self._dist_ops_for_program = {}
self._dist_tensors_for_graph = {}
self._dist_ops_for_graph = {}
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
@property
def serial_program(self):
return self._serial_program
@property
def serial_graph(self):
return self._serial_graph
@serial_program.setter
def serial_program(self, program):
assert self._serial_program is None, \
"This distributed context has already been realted to a serial program"
self._serial_program = program
@property
def process_meshes(self):
return self._process_meshes
@property
def dist_op_context(self):
return self._dist_op_context
def add_process_mesh(self, process_mesh):
assert isinstance(process_mesh, ProcessMesh), \
'The type of process_mesh must be ProcessMesh.'
if process_mesh not in self.process_meshes:
self._process_meshes.append(process_mesh)
def add_dist_tensor_for_program(self, dist_tensor):
inner_serial_tensor = dist_tensor.serial_tensor
inner_serial_tensor_id = inner_serial_tensor.desc.id()
self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor
def add_dist_op_for_program(self, dist_op):
inner_serial_op = dist_op.serial_op
inner_serial_op_id = inner_serial_op.desc.id()
self._dist_ops_for_program[inner_serial_op_id] = dist_op
def get_dist_tensor_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
return self._dist_tensors_for_program.get(serial_tensor_id, None)
def get_dist_tensor_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)
def get_dist_op_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
return self._dist_ops_for_program.get(serial_op_id, None)
def get_dist_op_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
return self._dist_ops_for_graph.get(serial_op_node_id, None)
def get_tensor_dist_attr_for_program(self, serial_tensor):
serial_tensor_id = serial_tensor.desc.id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
serial_tensor_node_id = serial_tensor_node.id()
dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
assert serial_tensor_node.is_var() and \
serial_tensor_node.var() is not None
serial_tensor_id = serial_tensor_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
assert dist_tensor is not None, \
"The distributed tensor of the program has not been added to this context."
serial_tensor_node_id = serial_tensor_node.id()
new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
dist_attr)
self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor
def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
def get_op_dist_attr_for_graph(self, serial_op_node):
serial_op_node_id = serial_op_node.id()
dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
assert serial_op_node.is_op() and \
serial_op_node.op() is not None
serial_op_id = serial_op_node.op().id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
assert dist_op is not None, \
"The distributed operator of the program has not been added to this context."
serial_op_node_id = serial_op_node.id()
new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
self._dist_ops_for_graph[serial_op_node_id] = new_dist_op
def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
if self._is_initialized_for_program:
return
# Copy the dist tensors and dist ops annotated by users from the default context
default_ctx = get_default_distributed_context()
self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
for block in self._serial_program.blocks:
for tensor in block.vars.values():
# Copy the distributed tensors in the default context
default_dist_tensor = default_ctx.get_dist_tensor_for_program(
tensor)
if default_dist_tensor and default_ctx is not self:
self.add_dist_tensor_for_program(default_dist_tensor)
current_dist_tensor = self.get_dist_tensor_for_program(tensor)
if current_dist_tensor is None:
dist_tensor = DistributedTensor(tensor)
self.add_dist_tensor_for_program(dist_tensor)
for op in block.ops:
# Copy the distributed operators in the default context
default_dist_op = default_ctx.get_dist_op_for_program(op)
if default_dist_op and default_ctx is not self:
self.add_dist_op_for_program(default_dist_op)
current_dist_op = self.get_dist_op_for_program(op)
if current_dist_op is None:
dist_op = DistributedOperator(op)
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True
def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
if self._is_initialized_for_graph:
return
# Convert program to graph
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
dist_tensor = self._dist_tensors_for_program.get(tensor_id,
None)
assert dist_tensor is not None, \
"Tensor must have a distributed tensor after the initialization for program."
self.set_tensor_dist_attr_for_graph(node, dist_tensor.dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
dist_op = self._dist_ops_for_program.get(op_id, None)
assert dist_op is not None, \
"Operator must have a distributed operator after the initialization for program."
self.set_op_dist_attr_for_graph(node, dist_op.dist_attr)
self._is_initialized_for_graph = True
def clear_dist_info_for_program(self):
self._dist_tensors_for_program.clear()
self._dist_ops_for_program.clear()
def clear_dist_info_for_graph(self):
self._dist_tensors_for_graph.clear()
self._dist_ops_for_graph.clear()
def copy_dist_attr_from_graph_to_program(self):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"Both program and graph must be initialized."
updated_tensors = {}
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
node)
dist_tensor_for_program = self._dist_tensors_for_program[
tensor_id]
dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
dist_op_for_program = self._dist_ops_for_program[op_id]
dist_op_for_program.dist_attr = op_dist_attr_for_graph
def amend_dist_attr_for_program(self):
for dist_tensor in self._dist_tensors_for_program.values():
serial_tensor = dist_tensor.serial_tensor
dist_attr = dist_tensor.dist_attr
if serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
dims_mapping = dist_attr.dims_mapping
process_mesh_shape = dist_attr.process_mesh.topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for dist_op in self._dist_ops_for_program.values():
serial_op = dist_op.serial_op
dist_attr = dist_op.dist_attr
for arg_name in serial_op.input_arg_names:
if dist_op.get_serial_input(arg_name) is None:
tensor_shape = []
else:
if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
or dist_op.serial_op.type == "create_py_reader":
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_input(arg_name).shape
dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in serial_op.output_arg_names:
if dist_op.get_serial_output(
arg_name).type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = dist_op.get_serial_output(arg_name).shape
dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
process_mesh_shape = dist_attr.process_mesh.topology
# If a tensor dimension is smaller than the process-mesh dimension it is mapped to,
# amend its dims_mapping entry to -1, i.e. replicate that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
def validate_dist_attr_for_program(self):
if not self._is_initialized_for_program:
assert False, \
"Program must be initialized before validating its distributed attributes"
for block in self.serial_program.blocks:
for tensor in block.vars.values():
dist_tensor = self.get_dist_tensor_for_program(tensor)
if (dist_tensor is not None) and (
not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.dist_attr)
for op in block.ops:
dist_op = self.get_dist_op_for_program(op)
if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} has a wrong distributed attributes {}.".format(
dist_op.serial_op.type, dist_tensor.dist_attr)
return True
class DistributedOperatorContext:
"""
DistributedOperatorContext is used to create a dist op desc in a Program.
Each time a new dist op is created, the context should be updated accordingly.
"""
def __init__(self):
self._dst_main_program = None
self._dst_startup_program = None
self._varname_mapping = None
self._rank_id = None
self._cur_src_op = None
self._cur_dist_attr = None
self.gradopidx2opidx = {}
self.already_init_sync_vars = set()
def set_dst_main_program(self, prog):
self._dst_main_program = prog
def get_dst_main_program(self):
return self._dst_main_program
def set_dst_startup_program(self, prog):
self._dst_startup_program = prog
def get_dst_startup_program(self):
return self._dst_startup_program
def set_varname_mapping(self, mapping):
self._varname_mapping = mapping
def get_varname_mapping(self):
return self._varname_mapping
def set_rank_id(self, rank_id):
self._rank_id = rank_id
def get_rank_id(self):
return self._rank_id
def set_cur_src_op(self, cur_src_op):
self._cur_src_op = cur_src_op
def get_cur_src_op(self):
return self._cur_src_op
def prepare_forward_context(self, src_op):
self.set_cur_src_op(src_op)
# build input varname mapping
kinputs = {}
for input_name in src_op.desc.input_names():
varnames = []
for varname in src_op.desc.input(input_name):
varnames.append(self._varname_mapping[varname])
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in src_op.desc.output_names():
varnames = []
for varname in src_op.desc.output(output_name):
varnames.append(self._varname_mapping[varname])
koutputs[output_name] = varnames
return kinputs, koutputs
def prepare_backward_context(self, backward_op):
self.set_cur_src_op(backward_op)
# build input varname mapping
kinputs = {}
for input_name in backward_op.desc.input_names():
varnames = []
for varname in backward_op.desc.input(input_name):
varnames.append(varname)
kinputs[input_name] = varnames
# build output varname mapping
koutputs = {}
for output_name in backward_op.desc.output_names():
varnames = []
for varname in backward_op.desc.output(output_name):
varnames.append(varname)
koutputs[output_name] = varnames
return kinputs, koutputs
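# A hedged end-to-end sketch of the intended DistributedContext workflow;
# serial_program is assumed to be a fully built static Program whose tensors
# and ops were annotated through the default context beforehand.
def _dist_context_demo(serial_program):
    dist_context = DistributedContext(serial_program)
    dist_context.init_dist_attr_for_program()  # pull in user annotations
    dist_context.init_dist_attr_for_graph()  # mirror them onto the IR graph
    # ... a completion pass would refine the graph-side attributes here ...
    # (amend/validate assume every tensor and op has a process mesh by now)
    dist_context.copy_dist_attr_from_graph_to_program()
    dist_context.amend_dist_attr_for_program()  # reset undersized dims to -1
    assert dist_context.validate_dist_attr_for_program()
    return dist_context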
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
self._serial_inputs = {}
self._serial_outputs = {}
self._dist_attr = None
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_op(self):
return self._serial_op
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = OperatorDistributedAttribute()
# Create new dist_attr related to current serial_op
dist_attr = self._filter_dist_attr(dist_attr)
# Append suffix to mark the inputs or outputs
if isinstance(dist_attr, dict):
# Copy the keys since we may add new ones
for key in list(dist_attr.keys()):
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names:
dist_attr[append_op_input_suffix(key.name)] = True
if key.name in self._serial_op.output_arg_names:
dist_attr[append_op_output_suffix(key.name)] = True
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def get_serial_input(self, name):
return self._serial_inputs.get(name, None)
def get_serial_output(self, name):
return self._serial_outputs.get(name, None)
def _init_default_dist_attr(self):
for tensor_name in self._serial_op.input_arg_names:
if self._serial_op.type == "create_py_reader":
tensor = None
else:
tensor = self._serial_op.block._var_recursive(tensor_name)
self._serial_inputs[tensor_name] = tensor
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
for tensor_name in self._serial_op.output_arg_names:
tensor = self._serial_op.block._var_recursive(tensor_name)
if tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = tensor.shape
self._serial_outputs[tensor_name] = tensor
if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.set_output_dims_mapping(tensor_name,
tensor_dims_mapping)
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
return None
new_dist_attr = None
if isinstance(dist_attr, dict):
new_dist_attr = {}
for key, value in dist_attr.items():
if isinstance(key, Variable):
if key.name in self._serial_op.input_arg_names \
or key.name in self._serial_op.output_arg_names:
new_dist_attr[key] = value
else:
new_dist_attr[key] = value
elif isinstance(dist_attr, OperatorDistributedAttribute):
new_dist_attr = copy.deepcopy(dist_attr)
new_dist_attr._inputs_dist_attrs.clear()
new_dist_attr._outputs_dist_attrs.clear()
for tensor_name in self._serial_op.input_arg_names:
tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_input_dist_attr(tensor_name,
tensor_dist_attr)
for tensor_name in self._serial_op.output_arg_names:
tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
if tensor_dist_attr:
new_dist_attr.set_output_dist_attr(tensor_name,
tensor_dist_attr)
else:
assert False, "Cannot recognize the {} parameter.".format(dist_attr)
return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type:
return True
for name in self.serial_op.input_arg_names:
input_dist_attr = self.dist_attr.get_input_dist_attr(name)
dims_mapping = input_dist_attr.dims_mapping
shape = self.get_serial_input(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != input_dist_attr.process_mesh:
return False
for name in self.serial_op.output_arg_names:
output_dist_attr = self.dist_attr.get_output_dist_attr(name)
dims_mapping = output_dist_attr.dims_mapping
shape = self.get_serial_output(name).shape
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if dims_mapping.count(i) > 1:
return False
if self.dist_attr.process_mesh != output_dist_attr.process_mesh:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(),
self.serial_op.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_input(arg_name) is not None:
if self.get_serial_input(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.get_serial_output(arg_name) is not None:
if self.get_serial_output(arg_name).is_parameter:
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(None)
str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
return str
class DistributedModule:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = self._serial_module(*args, **kwargs)
new_op_size = len(main_block.ops)
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
default_dist_ctx.add_dist_op_for_program(dist_op)
if isinstance(output, Variable):
output = [output]
return list(output)
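# A usage sketch for DistributedModule resembling what a shard_op-style
# interface is expected to do; it assumes static mode, that paddle.nn.Linear
# can be called on a static Variable, and that ProcessMesh accepts a flat list
# of process ids.
def _distributed_module_demo():
    import paddle
    paddle.enable_static()
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    linear = paddle.nn.Linear(8, 8)
    dist_linear = DistributedModule(linear, dist_attr={"process_mesh": [0, 1]})
    # Calling the wrapper runs the serial layer and registers a
    # DistributedOperator for every newly added op in the default context.
    out = dist_linear(x)
    return out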
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from paddle.fluid import core
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import get_tensor_dist_attr_field_keys
class DistributedTensor:
def __init__(self, serial_tensor, dist_attr=None):
self._serial_tensor = serial_tensor
self._dist_attr = None
self._batch_dim = 0
# Reuse the dist_attr setter to initialize _dist_attr
self.dist_attr = dist_attr
@property
def serial_tensor(self):
return self._serial_tensor
@property
def dist_attr(self):
return self._dist_attr
@dist_attr.setter
def dist_attr(self, dist_attr):
if self._dist_attr is None:
self._dist_attr = TensorDistributedAttribute()
self._dist_attr.init(dist_attr)
self._init_default_dist_attr()
def _init_default_dist_attr(self):
if self._dist_attr.dims_mapping is None:
if self.serial_tensor.type == core.VarDesc.VarType.READER:
tensor_shape = []
else:
tensor_shape = self._serial_tensor.shape
tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
self._dist_attr.dims_mapping = tensor_dims_mapping
def validate_dist_attr(self):
if self.serial_tensor.type == core.VarDesc.VarType.READER:
return True
tensor_shape = self.serial_tensor.shape
if len(tensor_shape) != len(self.dist_attr.dims_mapping):
return False
for i in range(len(self.dist_attr.dims_mapping)):
if self.dist_attr.dims_mapping[
i] < -1 or self.dist_attr.dims_mapping[i] >= len(
self.dist_attr.process_mesh.topology):
return False
for i in range(len(self.dist_attr.process_mesh.topology)):
if self.dist_attr.dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.serial_tensor.desc.name(), self.serial_tensor.desc.id())
# str += ", {}".format(self.dist_attr)
# return str
if self.dist_attr.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.dist_attr.process_mesh)
str += ", is_parameter: {}".format(self.serial_tensor.is_parameter)
if self.dist_attr.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.dist_attr.dims_mapping)
if self.dist_attr.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str, None)
if self.dist_attr.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str, None)
return str
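# An illustrative sketch for DistributedTensor, assuming static mode and that
# ProcessMesh accepts a nested list of process ids; x is a hypothetical static
# Variable of shape [4, 8].
def _distributed_tensor_demo():
    import paddle
    paddle.enable_static()
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    dist_x = DistributedTensor(
        x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
                      "dims_mapping": [0, -1]})
    # Shapes, mappings and the mesh topology are consistent, so this holds.
    assert dist_x.validate_dist_attr()
    return dist_x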
...@@ -18,293 +18,34 @@ import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
__all__ = []
# a map from ProcessMesh ids to the ProcessMesh instances
_g_process_mesh_map = dict()
# user defined map from logical process ids to physical ones
_user_defined_physical_map = None
def _append_attr_suffix(name):
    """
    Append auto parallel suffix for distributed attribute name.
    """
    return name + core.kAutoParallelSuffix()
def _remove_attr_suffix(name):
    """
    Remove auto parallel suffix from distributed attribute name.
    """
    return name.strip(core.kAutoParallelSuffix())
def _static_mode_check():
    if in_dygraph_mode():
        raise RuntimeError("Auto-parallel only supports static mode for now, "
                           "please use paddle.enable_static() first.")
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
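# A quick illustration of the two helpers above, using the same example mesh as
# the ProcessMesh docstring below:
#
#     _get_nested_list_shape([[2, 4, 5], [0, 1, 3]])   # -> [2, 3]
#     _flatten_nested_list([[2, 4, 5], [0, 1, 3]])     # -> [2, 4, 5, 0, 1, 3]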
class ProcessMesh(object):
r"""
The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) that describes the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
parent (ProcessMesh, optional): the parent ProcessMesh. None means
the ProcessMesh is the root one without parent ProcessMesh.
Default: None.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.parent is None
assert mesh.topology == [2, 3]
assert mesh.process_group == [2, 4, 5, 0, 1, 3]
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
def __init__(self, mesh, parent=None):
_static_mode_check()
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
self._topology = _get_nested_list_shape(mesh)
self._processes = _flatten_nested_list(mesh)
# Every element of mesh must be >= 0.
assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.')
unique_ids = set(self._processes)
assert len(unique_ids) == len(self._processes), (
'All elements of mesh must be unique.')
if parent is None:
# For root ProcessMesh, the ids of logical processes must be range
# from 0 to N-1, where N is the number of logical processes.
assert max(self._processes) == len(self._processes) - 1, (
'For root ProcessMesh, ids of logical processes must be range '
'from 0 to N-1, where N is the number of logical processes.')
parent_id = core.kNoneProcessMeshIndex()
assert len(_g_process_mesh_map.keys()) == 0, (
'The first ProcessMesh must be the root, which has no parent.')
else:
assert len(_g_process_mesh_map.keys()) > 0, (
'All ProcessMesh must have a parent except the root one.')
assert isinstance(parent, ProcessMesh), (
'parent must be an instance of ProcessMesh.')
parent_id = parent._desc.id
# All elements in mesh must belong to its parent
parent_ids = set(parent.process_group)
assert unique_ids <= parent_ids, (
'All elements in mesh must belong to its parent.')
self._desc = core.ProcessMeshDesc(self._topology, self._processes,
parent_id)
self._id = self._desc.id
self._parent_id = parent_id
assert self._id not in _g_process_mesh_map, (
"The ProcessMesh with id %d already exists." % self._id)
_g_process_mesh_map[self._id] = self
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
"""
return self._topology
@property
def process_group(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def parent(self):
r"""
Get the parent ProcessMesh.
"""
if self._parent_id == core.kNoneProcessMeshIndex(): return None
assert self._parent_id in _g_process_mesh_map, (
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
@property
def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
"""
return len(self._topology)
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
user defined order.
Args:
order (list): order of the physical process ids.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
assert self.parent is None, (
"This function can only be called by the root ProcessMesh.")
unique_ids = set(order)
assert isinstance(order, list)
assert len(unique_ids) == len(order), (
"All elements in order must be unique.")
assert min(order) == 0
assert max(order) == len(order) - 1, (
"All elements in order must be from 0 to N - 1, where N "
"is the number of physical processes.")
logical_order = self.process_group
global _user_defined_physical_map
assert _user_defined_physical_map is None, (
"This function can only be called once.")
_user_defined_physical_map = dict()
assert len(logical_order) == len(order) def shard_tensor(x, dist_attr=None):
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def _reset_global_process_mesh_map(self):
"""
Remove all process mesh in _g_process_mesh_map, make it empty.
"""
_g_process_mesh_map = dict()
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert isinstance(mesh,
ProcessMesh), 'The type of mesh must be ProcessMesh.'
assert isinstance(dim_mapping,
list), 'The type of dim_mapping must be list.'
assert len(tensor.shape) == len(dim_mapping), (
'The number of dimensions '
'of tensor must be the same as the length of its corresponding '
'dim_mapping.')
mesh_dim = len(mesh.topology)
dim_set = set()
for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (
dim_mapping[i] < mesh_dim and dim_mapping[i] >= 0), (
'Each element '
'in dim_mapping must be greater than zero and less than the '
'length of its corresponding topology, or it must be -1.')
if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i])
def shard_tensor(x, mesh, dim_mapping):
""" """
Add distributed attributes for a tensor. Add distributed attributes for a tensor.
Args: Args:
x (Tensor): the tensor to process. x (Tensor): the tensor to be sharded.
mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes. dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow:
dim_mapping (list): a list to describe the mapping between `x` and `mesh`, "process_mesh": a nested list to describe the mesh topology of logical processes.
the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means "dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension
without partition along the corresponding dimension. `i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that tensor dimension is not split.
Both process_mesh and dims_mapping are optional and users can specify them as needed.
Returns: Returns:
Tensor: the tensor `x` itself. Tensor: the tensor `x` annotated with distributed attributes.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping): ...@@ -314,87 +55,36 @@ def shard_tensor(x, mesh, dim_mapping):
paddle.enable_static() paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [0, -1]) dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
"dims_mapping": [0, -1]})
""" """
_static_mode_check() _static_mode_check()
_dim_mapping_checker(x, mesh, dim_mapping) assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
attr_name = _append_attr_suffix('mesh_id') "The type of dist_attr must be None, dict or TensorDistributedAttribute."
x._set_attr(attr_name, mesh._id) dist_tensor = DistributedTensor(x, dist_attr)
attr_name = _append_attr_suffix('dim_mapping') dist_tensor.dist_attr.mark_annotated_as(dist_attr)
x._set_attr(attr_name, dim_mapping) default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
return x return x
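# A brief sketch of partial annotation with the interface above (process ids and
# shapes are illustrative). Since both keys of dist_attr are optional, a tensor
# can be annotated with only a process_mesh and keep the default dims_mapping of
# all -1 (fully replicated):
#
#     import paddle
#     import paddle.distributed as dist
#
#     paddle.enable_static()
#     w = paddle.static.create_parameter(shape=[512, 512], dtype='float32')
#     # pin only the process mesh; the dims_mapping is left to be inferred
#     w = dist.shard_tensor(w, dist_attr={"process_mesh": [[0, 1], [2, 3]]})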
def set_shard_mask(x, mask): def shard_op(op_fn, dist_attr=None):
"""
Set the mask for a tensor which mask out the tensor from some processes in its mesh.
Args:
x (Tensor): the tensor to process.
mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to
the tensor `x`. Every value of `mask` must be one or zero, where one means
the tensor `x` will be put on the corresponding logical process and zero means the tensor `x`
will not be put on the corresponding logical process.
For example, for a ProcessMesh represented by the 2-dimensional
array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the
2-dimensional [[1, 0, 1], [0, 1, 0]],
then the tensor `x` will only be put on logical processes 2, 5 and 1.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [-1, 1])
dist.set_shard_mask(x, mask)
"""
_static_mode_check()
assert isinstance(mask, list)
np_mask = numpy.array(mask)
min_ele = numpy.min(np_mask)
max_ele = numpy.max(np_mask)
mesh_attr_name = _append_attr_suffix('mesh_id')
assert x._has_attr(mesh_attr_name), \
"Please set process mesh for the variable firstly."
assert min_ele >= 0 and max_ele <= 1, "Elements in mask must be 0 or 1."
x_mesh = x.process_mesh
assert x_mesh, "Please set process mesh for the variable firstly."
assert x_mesh.topology == list(np_mask.shape), (
"The shape of mask "
"must be the same as the shape of its Process Mesh.")
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
return x
def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
""" """
Call a function and add distributed attributes for the ops added by the function. Call a function and add distributed attributes for the ops added by the function.
Args: Args:
op_fn (callable): a callable object of an API. op_fn (callable): a callable operator or module to be sharded.
mesh (ProcessMesh): an instance of ProcessMesh specifies the topology of logical processes. dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping. two categories. The first category describes the distributed attributes shared by all inputs and
The dim_mapping is a list to describe the mapping between a tensor and `mesh`, outputs, and only `process_mesh` can be specified now. The second category describes distributed
the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`, attributes for inputs or outputs, in the same format as the `dist_attr` of `shard_tensor`. All of them are
where -1 means without partition along the corresponding dimension. optional and users can specify them as needed. Note that `process_mesh` for operators must be the
kwargs (dict): a dict of parameters passed to the function `op_fn`. same as the process_mesh of its inputs and outputs.
Returns: Returns:
list: the outputs of the function `op_fn`. list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs): ...@@ -404,100 +94,19 @@ def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
paddle.enable_static() paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6]) x = paddle.ones([4, 6])
y = paddle.zeros([4, 6]) y = paddle.zeros([4, 6])
kwargs = {'x': x, 'y': y} dist_add = dist.shard_op(paddle.add,
dist.shard_op(paddle.add, mesh, None, **kwargs) dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
""" x: {"dims_mapping": [-1, 0]},
_static_mode_check() y: {"dims_mapping": [0, -1]}
main_prog = paddle.fluid.default_main_program() })
main_block = main_prog.global_block() dist_add(x, y)
op_size = len(main_block.ops)
output = op_fn(**kwargs)
new_op_size = len(main_block.ops)
if dim_mapping_dict is None:
dim_mapping_dict = dict()
else:
assert isinstance(dim_mapping_dict,
dict), 'The type of dim_mapping_dict must be dict.'
for var_name in dim_mapping_dict.keys():
dim_mapping = dim_mapping_dict[var_name]
tensor = main_block.var(var_name)
_dim_mapping_checker(tensor, mesh, dim_mapping)
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id')
op._set_attr(attr_name, mesh._id)
for var_name in dim_mapping_dict.keys():
assert var_name in op.output_arg_names + op.input_arg_names
attr_name = _append_attr_suffix(var_name)
if var_name in op.input_arg_names:
# we use the prefix "IN_" to indicates an input argument name
attr_name = "IN_" + attr_name
else:
# we use the prefix "OUT_" to indicates an input argument name
attr_name = "OUT_" + attr_name
op._set_attr(attr_name, dim_mapping_dict[var_name])
if isinstance(output, Variable):
output = [output]
return list(output)
def set_offload_device(x, device):
"""
Set the device that the tensor `x` will be put on.
Args:
x (tensor): the tensor to process.
device (str): the device that the tensor `x` will be put on, e.g., 'cpu'.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
x = paddle.ones([4, 6])
dist.set_offload_device(x, 'cpu')
"""
_static_mode_check()
assert device == "cpu", "Only 'cpu' is supported for destination device."
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
def set_pipeline_stage(stage):
"""
Set the pipeline stage of the following ops.
Args:
stage (int): the pipeline stage the following ops belonging to.
Returns:
None.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
dist.set_pipeline_stage(0)
""" """
from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check() _static_mode_check()
assert isinstance(stage, int), 'The type of stage must be int.' assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
_set_pipeline_stage(stage) "The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
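# A compact usage sketch of the interface above (mesh and shapes illustrative).
# shard_op wraps op_fn in a DistributedModule and returns it; the distributed
# attributes are expected to be attached to the ops created when the returned
# module is called:
#
#     import paddle
#     import paddle.distributed as dist
#
#     paddle.enable_static()
#     x = paddle.ones([4, 6])
#     y = paddle.zeros([4, 6])
#     dist_add = dist.shard_op(paddle.add,
#                              dist_attr={"process_mesh": [[0, 1], [2, 3]]})
#     out = dist_add(x, y)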
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding from . import dist_embedding
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
DISTRIBUTED_OPERATORS = {} _g_distributed_operator_impl_registries = {}
class DistributedOperator: class DistributedOperatorImplContainer:
def __init__(self): def __init__(self):
self._impls = [] self._impls = []
self._name = None self._name = None
...@@ -47,67 +47,60 @@ class DistributedOperatorImpl: ...@@ -47,67 +47,60 @@ class DistributedOperatorImpl:
def get_name(self): def get_name(self):
return self._name return self._name
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr): def is_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") return self.is_input_compatible(dist_op) and \
self.is_output_compatible(dist_op)
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method in Subclass.") raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op): def register_distributed_operator_impl_container(name, dist_op_impl_container):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
DISTRIBUTED_OPERATORS[name] = dist_op _g_distributed_operator_impl_registries[name] = dist_op_impl_container
def get_distributed_operator(name): def get_distributed_operator_impl_container(name):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
return DISTRIBUTED_OPERATORS.get(name, None) return _g_distributed_operator_impl_registries.get(name, None)
def register_distributed_operator_impl(name, dist_impl): def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name) dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op is not None: if dist_op_impl_container is not None:
dist_op.register_impl(dist_impl) dist_op_impl_container.register_impl(dist_impl)
else: else:
assert False, "Must register distributed operator first." assert False, "Must register distributed operator registry first."
def get_distributed_operator_impl(name, impl_idx): def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS global _g_distributed_operator_impl_registries
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx) return _g_distributed_operator_impl_registries[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr, def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
fwd=True):
""" """
Here just return the first compatible implementation. Here just return the first compatible implementation.
This will be improved by cost model in the future. This will be improved by cost model in the future.
""" """
dist_op = get_distributed_operator(name) dist_op_impl_container = get_distributed_operator_impl_container(name)
if dist_op is None: if dist_op_impl_container is None:
return None, -1 return None, -1
compatible_impls = [] compatible_impls = []
impls = dist_op.get_impls() impls = dist_op_impl_container.get_impls()
if fwd: if fwd:
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \ if impl.is_input_compatible(dist_op):
and impl.is_input_compatible(op_dist_attr):
compatible_impls.append((impl, idx)) compatible_impls.append((impl, idx))
else: else:
for idx, impl in enumerate(impls): for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \ if impl.is_output_compatible(dist_op):
and impl.is_output_compatible(op_dist_attr):
compatible_impls.append((impl, idx)) compatible_impls.append((impl, idx))
if compatible_impls: if compatible_impls:
...@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr, ...@@ -118,48 +111,84 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
return best_compatible_impl, idx return best_compatible_impl, idx
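# A hedged sketch of the registration pattern used by the concrete operators in
# this package ("foo", DistributedFoo and DistributedFooImpl are placeholders,
# not real operators):
#
#     class DistributedFoo(DistributedOperatorImplContainer):
#         def __init__(self, name):
#             super(DistributedFoo, self).__init__()
#             self._name = name
#
#     class DistributedFooImpl(DistributedOperatorImpl):
#         def __init__(self, name):
#             super(DistributedFooImpl, self).__init__()
#             self._name = name
#
#         def is_input_compatible(self, dist_op):
#             return True
#
#         def is_output_compatible(self, dist_op):
#             return True
#
#     register_distributed_operator_impl_container("foo", DistributedFoo("foo"))
#     register_distributed_operator_impl("foo", DistributedFooImpl("replicate"))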
def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var): # def copy_distributed_attr_for_var(src_op_dist_attr, dst_var, src_var):
""" # """
copy src var's dist_attr to dst var # copy src var's dist_attr to dst var
""" # """
import copy # import copy
auto_paralle_context = src_op_dist_attr.get_owner_context() # auto_paralle_context = src_op_dist_attr.get_owner_context()
dist_attr = copy.deepcopy( # dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) # auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
dist_attr._owner_tensor = var # dist_attr._owner_tensor = var
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( # dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context # src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) # auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): def copy_distributed_attr_for_var(dist_context, dst_var, src_var):
"""
copy src var's dist_attr to dst var
"""
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
dist_context.set_tensor_dist_attr_for_program(dst_var, dist_attr)
# def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
# """
# copy src op's dist_attr to dst dist op
# """
# from ..attribute import OperatorDistributedAttribute
# auto_paralle_context = src_op_dist_attr.get_owner_context()
# op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
# auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
# op_dist_attr)
# auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
# op_dist_attr)
# op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
# op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
# for input_varname in dist_op.desc.input_arg_names():
# input_var = dst_block.var(input_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# input_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
# for output_varname in dist_op.desc.output_arg_names():
# output_var = dst_block.var(output_varname)
# tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
# output_var)
# tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
# op_dist_attr.set_output_dims_mapping(output_varname,
# tensor_dims_mapping)
def copy_distributed_attr_for_dist_op(dist_context, dist_op, dst_block,
src_op_dist_attr):
""" """
copy src op's dist_attr to dst dist op copy src op's dist_attr to dst dist op
""" """
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
# need to check the dist op attr and its inputs and outputs
auto_paralle_context = src_op_dist_attr.get_owner_context() op_dist_attr = OperatorDistributedAttribute()
op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) op_dist_attr.process_mesh = src_op_dist_attr.process_mesh
auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, op_dist_attr.impl_idx = src_op_dist_attr.impl_idx
op_dist_attr)
auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
op_dist_attr)
op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
for input_varname in dist_op.desc.input_arg_names(): for input_varname in dist_op.desc.input_arg_names():
input_var = dst_block.var(input_varname) input_var = dst_block.var(input_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
input_var) input_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
for output_varname in dist_op.desc.output_arg_names(): for output_varname in dist_op.desc.output_arg_names():
output_var = dst_block.var(output_varname) output_var = dst_block.var(output_varname)
tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
output_var) output_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
op_dist_attr.set_output_dims_mapping(output_varname,
tensor_dims_mapping) dist_context.set_op_dist_attr_for_program(dist_op, op_dist_attr)
op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_op)
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -22,23 +22,24 @@ from ..utils import is_valid_list_index ...@@ -22,23 +22,24 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedDefault(DistributedOperator): class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedDefault, self).__init__() super(DistributedDefault, self).__init__()
self._name = name self._name = name
register_distributed_operator("default", DistributedDefault("default")) register_distributed_operator_impl_container("default",
DistributedDefault("default"))
# Replicated Default # Replicated Default
...@@ -49,27 +50,24 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -49,27 +50,24 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
def is_output_compatible(self, op_dist_attr): def update_dims_mapping(self, dist_op):
raise NotImplementedError("Please Implement this method.")
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method.") raise NotImplementedError("Please Implement this method.")
@staticmethod @staticmethod
def forward(ctx, *args, **kwargs): def forward(ctx, *args, **kwargs):
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
varname_mapping = dist_op_helper.get_varname_mapping() varname_mapping = dist_op_context.get_varname_mapping()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check validation of inputs / outputs # check validation of inputs / outputs
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -100,25 +98,25 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -100,25 +98,25 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for varname in dist_op_desc.input_arg_names(): for varname in dist_op_desc.input_arg_names():
if startup_block.has_var(varname) and startup_block.var( if startup_block.has_var(varname) and startup_block.var(
varname varname
).is_parameter and varname not in dist_op_helper.already_init_sync_vars: ).is_parameter and varname not in dist_op_context.already_init_sync_vars:
dist_op_helper.already_init_sync_vars.add(varname) dist_op_context.already_init_sync_vars.add(varname)
param = startup_block.var(varname) param = startup_block.var(varname)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program( param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
param) process_mesh = param_dist_attr.process_mesh
process_mesh = param_dist_attr.get_process_mesh() dims_mapping = param_dist_attr.dims_mapping
dims_mapping = param_dist_attr.get_dims_mapping()
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group: if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(process_mesh, rank_id) rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
# NOTE the parameter must be synced (broadcast) along every mesh axis on which it is not split # NOTE the parameter must be synced (broadcast) along every mesh axis on which it is not split
for axis, size in enumerate(process_mesh.topology): for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dims_mapping: if size <= 1 or axis in dims_mapping:
pass pass
else: else:
group_ranks = _get_comm_group( group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.process_group, process_mesh.topology, process_mesh.topology,
axis, rank_id) axis, rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
...@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -134,12 +132,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
}) })
# set distributed attribute # set distributed attribute
op_attr = OperatorDistributedAttribute(new_op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(param.name, op_attr.set_output_dims_mapping(param.name,
dims_mapping) dims_mapping)
op_attr.set_input_dims_mapping(param.name, dims_mapping) op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(new_op, op_attr) ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
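# What the broadcast groups built above look like in a concrete case, as a
# hedged, pure-Python sketch that is independent of the real _get_comm_group
# signature: for a mesh with topology [2, 3] and processes [0, 1, 2, 3, 4, 5],
# the group along axis 0 that contains rank 4 consists of the ranks that differ
# only in their axis-0 coordinate:
#
#     import numpy as np
#
#     mesh = np.array([0, 1, 2, 3, 4, 5]).reshape([2, 3])
#     row, col = np.argwhere(mesh == 4)[0]     # rank 4 sits at coordinate (1, 1)
#     group_ranks = mesh[:, col].tolist()      # -> [1, 4]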
...@@ -147,13 +145,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -147,13 +145,13 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself # for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op() backward_op = dist_op_context.get_cur_src_op()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op)) str(backward_op))
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
# check if need gradient allreduce # check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is split, # if there is a non-gradient & non-parameter input and its batch dimension is split,
...@@ -165,19 +163,20 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -165,19 +163,20 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
varname).is_parameter: varname).is_parameter:
# NOTE the backward op's input dims_mapping should be taken from its own input var, not from the corresponding varname of the forward op # NOTE the backward op's input dims_mapping should be taken from its own input var, not from the corresponding varname of the forward op
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(varname) var_dim_mapping = dist_attr.get_input_dims_mapping(varname)
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.process_group: if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(process_mesh, rank_id) rank_id = _get_corresponding_rank(ctx, process_mesh,
rank_id)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True need_gradient_allreduce = True
group_ranks = _get_comm_group( group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.process_group, process_mesh.topology, process_mesh.topology,
batch_size_axis, rank_id) batch_size_axis, rank_id)
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
dp_group = new_process_group(group_ranks) dp_group = new_process_group(group_ranks)
...@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -228,17 +227,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
dims_mapping = ctx.get_tensor_distributed_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
grad_var).get_dims_mapping() grad_var).dims_mapping
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(grad_var.name, op_attr.set_output_dims_mapping(grad_var.name,
dims_mapping) dims_mapping)
op_attr.set_input_dims_mapping(grad_var.name, op_attr.set_input_dims_mapping(grad_var.name,
dims_mapping) dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp() main_block._sync_with_cpp()
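# The intended effect of the allreduce + scale pair inserted above, in a toy
# sketch (values illustrative; the scale factor is assumed to be 1/dp_degree):
# with dp_degree data-parallel ranks, every rank ends up with the average of
# the per-rank gradients.
#
#     dp_degree = 2
#     local_grads = [4.0, 6.0]                 # one gradient value per rank
#     summed = sum(local_grads)                # c_allreduce_sum -> 10.0 on every rank
#     averaged = summed * (1.0 / dp_degree)    # scale -> 5.0 on every rank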
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op from .common import copy_distributed_attr_for_dist_op
...@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index ...@@ -24,25 +24,26 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperator): class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedEmbedding, self).__init__() super(DistributedEmbedding, self).__init__()
self._name = name self._name = name
register_distributed_operator("lookup_table_v2", register_distributed_operator_impl_container("lookup_table_v2",
DistributedEmbedding("embedding"))
register_distributed_operator_impl_container("c_embedding",
DistributedEmbedding("embedding")) DistributedEmbedding("embedding"))
register_distributed_operator("c_embedding", DistributedEmbedding("embedding"))
# RowParallel # RowParallel
...@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -53,12 +54,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0] ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0] w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
...@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -72,8 +70,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
...@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -82,9 +81,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
ids_name = op_desc.input('Ids')[0] ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0] w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
...@@ -111,12 +111,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -111,12 +111,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
...@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -147,12 +147,12 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
Weight_var.name)[0] Weight_var.name)[0]
assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format( assert embedding_row_dim_mapping >= 0, "row_parallel_embedding's row should be divided by a specific mesh axis, but got [{}]".format(
embedding_row_dim_mapping) embedding_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh_group: if rank_id not in process_mesh_group:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# A generalized method to calculate the embedding offset using the cartesian product # A generalized method to calculate the embedding offset using the cartesian product
...@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -182,7 +182,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
stop_gradient=Out_var.stop_gradient) stop_gradient=Out_var.stop_gradient)
# copy Out_var's dist_attr to intermediate_var_0's dist_attr # copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
check_variable_and_dtype( check_variable_and_dtype(
Out_var, 'tensor', Out_var, 'tensor',
...@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -208,25 +208,25 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
}) })
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_embedding_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_embedding_op, main_block,
op_dist_attr) op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr) op_dist_attr)
# param initialization sync # param initialization sync
assert Weight_var.name not in dist_op_helper.already_init_sync_vars assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_helper.already_init_sync_vars.add(Weight_var.name) dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name) param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh() process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.get_dims_mapping() dim_mapping = param_dist_attr.dims_mapping
# NOTE the parameter must be synced (broadcast) along every mesh axis on which it is not split # NOTE the parameter must be synced (broadcast) along every mesh axis on which it is not split
for axis, size in enumerate(process_mesh.topology): for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping: if size <= 1 or axis in dim_mapping:
pass pass
else: else:
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis, process_mesh.topology, axis,
rank_id) rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
...@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -247,17 +247,17 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself # for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op() backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op)) str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group: if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh,
rank_id) rank_id)
# check if need gradient allreduce # check if need gradient allreduce
...@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -286,14 +286,14 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
kwargs['W@GRAD']) kwargs['W@GRAD'])
Ids_var = main_block.var(kwargs['Ids'][0]) Ids_var = main_block.var(kwargs['Ids'][0])
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name) var_dim_mapping = dist_attr.get_input_dims_mapping(Ids_var.name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, process_mesh.topology,
batch_size_axis, rank_id) batch_size_axis, rank_id)
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
...@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -318,15 +318,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
OP_ROLE_KEY: OpRole.Backward}) OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
W_Grad_var).get_dims_mapping() W_Grad_var).dims_mapping
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(W_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(W_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
register_distributed_operator_impl("lookup_table_v2", register_distributed_operator_impl("lookup_table_v2",
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from .common import copy_distributed_attr_for_var from .common import copy_distributed_attr_for_var
from .common import copy_distributed_attr_for_dist_op from .common import copy_distributed_attr_for_dist_op
...@@ -24,19 +24,20 @@ from ..utils import is_valid_list_index ...@@ -24,19 +24,20 @@ from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
from ..attribute import OperatorDistributedAttribute from ..dist_attribute import OperatorDistributedAttribute
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from ..process import new_process_group from ..process_group import new_process_group
from ..utils import _get_comm_group, _get_corresponding_rank from ..utils import _get_comm_group, _get_corresponding_rank
def _update_dims_mapping_for_matmul(op_dist_attr): def _update_dims_mapping_for_matmul(dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
...@@ -112,7 +113,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr): ...@@ -112,7 +113,7 @@ def _update_dims_mapping_for_matmul(op_dist_attr):
if dim_changed: if dim_changed:
changed = True changed = True
# Remove unnecessary dim mapping to make sure the lenght of dims_mapping is same as its tensor # Remove unnecessary dim mapping to make sure the length of dims_mapping is same as its tensor
if x_dims_mapping_len == 1: if x_dims_mapping_len == 1:
x_dims_mapping.pop(0) x_dims_mapping.pop(0)
if y_dims_mapping_len == 1: if y_dims_mapping_len == 1:
...@@ -129,17 +130,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -129,17 +130,17 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
# for now the backward function only inserts the gradient allreduce for the dist op itself # for now the backward function only inserts the gradient allreduce for the dist op itself
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
backward_op = dist_op_helper.get_cur_src_op() backward_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
dist_attr = ctx.get_op_distributed_attr_for_program(backward_op) dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(backward_op)) str(backward_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in dist_attr.get_process_mesh().process_group: if rank_id not in dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(dist_attr.get_process_mesh(), rank_id) rank_id = _get_corresponding_rank(ctx, dist_attr.process_mesh, rank_id)
# check if need gradient allreduce # check if need gradient allreduce
need_gradient_allreduce = False need_gradient_allreduce = False
...@@ -175,13 +176,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -175,13 +176,13 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format( assert not X_var.is_parameter, "left operand(X) [{}] of dist matmul should not be parameter".format(
X_var.name) X_var.name)
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name) var_dim_mapping = dist_attr.get_input_dims_mapping(X_var.name)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
need_gradient_allreduce = True need_gradient_allreduce = True
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, batch_size_axis, process_mesh.topology, batch_size_axis,
rank_id) rank_id)
dp_degree = len(group_ranks) dp_degree = len(group_ranks)
...@@ -207,32 +208,32 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -207,32 +208,32 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
OP_ROLE_KEY: OpRole.Backward}) OP_ROLE_KEY: OpRole.Backward})
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_distributed_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
Y_Grad_var).get_dims_mapping() Y_Grad_var).dims_mapping
process_mesh = dist_attr.get_process_mesh() process_mesh = dist_attr.process_mesh
for op in [allreduce_op, scale_op]: for op in [allreduce_op, scale_op]:
op_attr = OperatorDistributedAttribute(op, ctx) op_attr = OperatorDistributedAttribute()
op_attr.set_process_mesh(process_mesh) op_attr.process_mesh = process_mesh
op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping) op_attr.set_output_dims_mapping(Y_Grad_var.name, dims_mapping)
op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping) op_attr.set_input_dims_mapping(Y_Grad_var.name, dims_mapping)
ctx.set_op_distributed_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
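A standalone sketch of the data-parallel decision made above: when the batch dimension of X maps to a mesh axis whose size is greater than 1, the ranks that differ only along that axis hold the same weight shard, so Y's gradient is allreduced inside that group and scaled by 1/dp_degree. comm_group_along_axis is a hypothetical stand-in for _get_comm_group, written with numpy purely for illustration.

import numpy as np

def comm_group_along_axis(processes, topology, axis, rank):
    mesh = np.array(processes).reshape(topology)          # e.g. 8 ranks as a 2x4 mesh
    coord = [int(c[0]) for c in np.where(mesh == rank)]   # this rank's mesh coordinate
    coord[axis] = slice(None)                             # vary only the chosen axis
    return mesh[tuple(coord)].tolist()

# Batch dimension on mesh axis 0 of a [2, 4] mesh: rank 5 averages grads with rank 1.
group_ranks = comm_group_along_axis(list(range(8)), [2, 4], 0, 5)
print(group_ranks, 1.0 / len(group_ranks))  # [1, 5] 0.5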
def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id): def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
assert Weight_var.name not in dist_op_helper.already_init_sync_vars assert Weight_var.name not in dist_op_context.already_init_sync_vars
assert startup_block.has_var(Weight_var.name) assert startup_block.has_var(Weight_var.name)
dist_op_helper.already_init_sync_vars.add(Weight_var.name) dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name) param = startup_block.var(Weight_var.name)
param_dist_attr = ctx.get_tensor_distributed_attr_for_program(param) param_dist_attr = ctx.get_tensor_dist_attr_for_program(param)
process_mesh = param_dist_attr.get_process_mesh() process_mesh = param_dist_attr.process_mesh
dim_mapping = param_dist_attr.get_dims_mapping() dim_mapping = param_dist_attr.dims_mapping
for axis, size in enumerate(process_mesh.topology): for axis, size in enumerate(process_mesh.topology):
if size <= 1 or axis in dim_mapping: if size <= 1 or axis in dim_mapping:
pass pass
else: else:
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(process_mesh.processes,
process_mesh.topology, axis, rank_id) process_mesh.topology, axis, rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
...@@ -249,13 +250,14 @@ def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id): ...@@ -249,13 +250,14 @@ def _init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, rank_id):
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
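The axis loop in _init_param_sync above boils down to a small rule, sketched here outside of Paddle (broadcast_axes and the values are illustrative): a parameter is broadcast along every mesh axis that has more than one process and does not appear in its dims_mapping, because the ranks along such an axis hold replicas that must start from identical values.

def broadcast_axes(topology, dims_mapping):
    axes = []
    for axis, size in enumerate(topology):
        if size <= 1 or axis in dims_mapping:
            continue          # trivial axis, or the axis already shards the parameter
        axes.append(axis)     # replicated along this axis: broadcast from one root
    return axes

# Column-parallel weight on a [2, 4] mesh with dims_mapping [-1, 1]: sync along axis 0 only.
print(broadcast_axes([2, 4], [-1, 1]))  # [0]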
class DistributedMatmul(DistributedOperator): class DistributedMatmul(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmul, self).__init__() super(DistributedMatmul, self).__init__()
self._name = name self._name = name
register_distributed_operator("matmul", DistributedMatmul("matmul")) register_distributed_operator_impl_container("matmul",
DistributedMatmul("matmul"))
# ColumnParallel # ColumnParallel
...@@ -266,12 +268,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -266,12 +268,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -286,8 +285,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -286,8 +285,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]): if is_dim_replicate(out_dims_mapping[-1]):
...@@ -297,9 +297,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -297,9 +297,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
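For concreteness, one dims_mapping assignment (values invented for illustration; -1 marks a replicated dimension, a non-negative value names the mesh axis that shards it) that fits the column-parallel layout these compatibility checks describe:

x_dims_mapping = [0, -1, -1]   # X: batch sharded on mesh axis 0, reduction dim replicated
y_dims_mapping = [-1, 1]       # Y: rows (K) replicated, columns (N) sharded on mesh axis 1
out_dims_mapping = [0, -1, 1]  # Out: batch like X, last dim sharded like Y's columns

# The properties the checks above look for in the column-parallel case:
assert x_dims_mapping[-1] == -1                               # X's last dim replicated
assert y_dims_mapping[-2] == -1 and y_dims_mapping[-1] != -1  # Y sharded only by columns
assert out_dims_mapping[-1] != -1                             # Out's last dim sharded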
...@@ -310,18 +310,18 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -310,18 +310,18 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# check validation of inputs / outputs # check validation of inputs / outputs
...@@ -348,8 +348,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -348,8 +348,8 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
Weight_var.name)[1] Weight_var.name)[1]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping) matmul_col_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
...@@ -365,7 +365,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -365,7 +365,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
persistable=False, persistable=False,
stop_gradient=X_var.stop_gradient) stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr # copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var)
check_variable_and_dtype( check_variable_and_dtype(
X_var, 'tensor', X_var, 'tensor',
...@@ -395,13 +395,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl): ...@@ -395,13 +395,14 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs) type='matmul', inputs=inputs, outputs={'Out': Out_var}, attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block,
op_dist_attr)
copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block,
op_dist_attr) op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter: if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id) rank_id)
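A minimal numerical sketch (numpy, illustrative only) of the column-parallel dataflow assembled above: c_identity leaves the replicated X untouched on every rank, each rank multiplies X by its own column shard of the weight, and the per-rank outputs are disjoint slices of Out along the last axis, so the forward pass needs no collective.

import numpy as np

np.random.seed(0)
X = np.random.rand(4, 8)            # replicated input
W = np.random.rand(8, 6)            # full weight, split by columns across 2 ranks
W_shards = np.split(W, 2, axis=1)

local_outs = [X @ Wi for Wi in W_shards]                      # each rank computes (4, 3)
assert np.allclose(np.concatenate(local_outs, axis=1), X @ W)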
@staticmethod @staticmethod
...@@ -417,12 +418,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -417,12 +418,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -438,8 +436,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -438,8 +436,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]): if is_dim_shard(out_dims_mapping[-1]):
...@@ -450,9 +449,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -450,9 +449,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
...@@ -463,18 +462,18 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -463,18 +462,18 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# check validation of inputs / outputs # check validation of inputs / outputs
...@@ -501,8 +500,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -501,8 +500,8 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
Weight_var.name)[0] Weight_var.name)[0]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping) matmul_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
...@@ -528,7 +527,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -528,7 +527,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
is_data=False, is_data=False,
need_check_feed=Out_var.desc.need_check_feed()) need_check_feed=Out_var.desc.need_check_feed())
# copy Out_var's dist_attr to intermediate_var_0's dist_attr # copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
matmul_op = main_block.append_op( matmul_op = main_block.append_op(
type='matmul', type='matmul',
...@@ -547,13 +546,14 @@ class DistributedMatmulImpl1(DistributedOperatorImpl): ...@@ -547,13 +546,14 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
}) })
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_op, main_block, op_dist_attr) copy_distributed_attr_for_dist_op(ctx, matmul_op, main_block,
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, op_dist_attr)
copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr) op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter: if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id) rank_id)
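The row-parallel impl above is the mirror image; a companion numpy sketch (again illustrative, not PR code): X arrives sharded along its last axis, W along its rows, each rank produces a partial Out of full shape, and c_allreduce_sum adds the partials into the final result.

import numpy as np

np.random.seed(0)
X = np.random.rand(4, 8)
W = np.random.rand(8, 6)
X_shards = np.split(X, 2, axis=1)   # reduction dim of X sharded across 2 ranks
W_shards = np.split(W, 2, axis=0)   # rows of W sharded the same way

partials = [Xi @ Wi for Xi, Wi in zip(X_shards, W_shards)]  # each rank holds a full (4, 6)
assert np.allclose(sum(partials), X @ W)                    # what c_allreduce_sum restores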
@staticmethod @staticmethod
...@@ -567,12 +567,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -567,12 +567,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
super(DistributedMatmulImpl2, self).__init__() super(DistributedMatmulImpl2, self).__init__()
self._name = name self._name = name
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -592,8 +589,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -592,8 +589,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
...@@ -605,9 +603,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl): ...@@ -605,9 +603,9 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
...@@ -625,13 +623,14 @@ register_distributed_operator_impl("matmul", ...@@ -625,13 +623,14 @@ register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel")) DistributedMatmulImpl2("replicate_parallel"))
class DistributedMatmulV2(DistributedOperator): class DistributedMatmulV2(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedMatmulV2, self).__init__() super(DistributedMatmulV2, self).__init__()
self._name = name self._name = name
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2")) register_distributed_operator_impl_container("matmul_v2",
DistributedMatmulV2("matmul_v2"))
# ColumnParallel # ColumnParallel
...@@ -642,12 +641,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -642,12 +641,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -662,8 +658,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -662,8 +658,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]): if is_dim_replicate(out_dims_mapping[-1]):
...@@ -673,9 +670,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -673,9 +670,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
...@@ -686,18 +683,18 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -686,18 +683,18 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# check validation of inputs / outputs # check validation of inputs / outputs
...@@ -724,8 +721,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -724,8 +721,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
Weight_var.name)[1] Weight_var.name)[1]
assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( assert matmul_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_col_dim_mapping) matmul_col_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_col_dim_mapping parallel_axis = matmul_col_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
...@@ -741,7 +738,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -741,7 +738,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
persistable=False, persistable=False,
stop_gradient=X_var.stop_gradient) stop_gradient=X_var.stop_gradient)
# copy X_var's dist_attr to intermediate_var_0's dist_attr # copy X_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, X_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, X_var)
check_variable_and_dtype( check_variable_and_dtype(
X_var, 'tensor', X_var, 'tensor',
...@@ -770,14 +767,14 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl): ...@@ -770,14 +767,14 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
attrs=attrs) attrs=attrs)
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(c_identity_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_identity_op, main_block,
op_dist_attr) op_dist_attr)
copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block,
op_dist_attr) op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter: if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id) rank_id)
@staticmethod @staticmethod
...@@ -793,12 +790,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -793,12 +790,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -814,8 +808,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -814,8 +808,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return False return False
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]): if is_dim_shard(out_dims_mapping[-1]):
...@@ -826,9 +821,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -826,9 +821,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
return False return False
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
...@@ -839,18 +834,18 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -839,18 +834,18 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
startup_block = dist_op_helper.get_dst_startup_program().global_block() startup_block = dist_op_context.get_dst_startup_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in op_dist_attr.get_process_mesh().process_group: if rank_id not in op_dist_attr.process_mesh.processes:
rank_id = _get_corresponding_rank(op_dist_attr.get_process_mesh(), rank_id = _get_corresponding_rank(ctx, op_dist_attr.process_mesh,
rank_id) rank_id)
# check validation of inputs / outputs # check validation of inputs / outputs
...@@ -877,8 +872,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -877,8 +872,8 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
Weight_var.name)[0] Weight_var.name)[0]
assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( assert matmul_row_dim_mapping >= 0, "row_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
matmul_row_dim_mapping) matmul_row_dim_mapping)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
process_mesh_group = op_dist_attr.get_process_mesh().process_group process_mesh_group = op_dist_attr.process_mesh.processes
parallel_axis = matmul_row_dim_mapping parallel_axis = matmul_row_dim_mapping
group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape, group_ranks = _get_comm_group(process_mesh_group, process_mesh_shape,
...@@ -900,7 +895,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -900,7 +895,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
is_data=False, is_data=False,
need_check_feed=Out_var.desc.need_check_feed()) need_check_feed=Out_var.desc.need_check_feed())
# copy Out_var's dist_attr to intermediate_var_0's dist_attr # copy Out_var's dist_attr to intermediate_var_0's dist_attr
copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0, Out_var) copy_distributed_attr_for_var(ctx, intermediate_var_0, Out_var)
matmul_v2_op = main_block.append_op( matmul_v2_op = main_block.append_op(
type='matmul_v2', type='matmul_v2',
...@@ -919,14 +914,14 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl): ...@@ -919,14 +914,14 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
}) })
# copy serial op's dist_attr to dist op's dist_attr # copy serial op's dist_attr to dist op's dist_attr
copy_distributed_attr_for_dist_op(matmul_v2_op, main_block, copy_distributed_attr_for_dist_op(ctx, matmul_v2_op, main_block,
op_dist_attr) op_dist_attr)
copy_distributed_attr_for_dist_op(c_allreduce_sum_op, main_block, copy_distributed_attr_for_dist_op(ctx, c_allreduce_sum_op, main_block,
op_dist_attr) op_dist_attr)
# init param sync # init param sync
if Weight_var.is_parameter: if Weight_var.is_parameter:
_init_param_sync(Weight_var, dist_op_helper, startup_block, ctx, _init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id) rank_id)
@staticmethod @staticmethod
...@@ -940,12 +935,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -940,12 +935,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
super(DistributedMatmulV2Impl2, self).__init__() super(DistributedMatmulV2Impl2, self).__init__()
self._name = name self._name = name
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0] y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -965,8 +957,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -965,8 +957,11 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr

out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
...@@ -978,9 +973,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl): ...@@ -978,9 +973,9 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) dim_changed = _update_dims_mapping_for_matmul(dist_op)
if dim_changed: if dim_changed:
changed = True changed = True
return changed return changed
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard ...@@ -28,13 +28,14 @@ from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedReshape2(DistributedOperator): class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedReshape2, self).__init__() super(DistributedReshape2, self).__init__()
self._name = name self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2")) register_distributed_operator_impl_container("reshape2",
DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl): class DistributedReshapeImpl0(DistributedOperatorImpl):
...@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -44,12 +45,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -60,8 +58,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -75,9 +74,10 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
...@@ -103,11 +103,11 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -103,11 +103,11 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
...@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -139,7 +139,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
...@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -172,12 +172,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
self._forward_implemented = True self._forward_implemented = True
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -191,8 +188,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -203,9 +201,10 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
...@@ -231,11 +230,11 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -231,11 +230,11 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
kwargs: inputname_mapping & outputname_mapping kwargs: inputname_mapping & outputname_mapping
""" """
dist_op_helper = ctx.get_dist_op_helper() dist_op_context = ctx.dist_op_context
main_block = dist_op_helper.get_dst_main_program().global_block() main_block = dist_op_context.get_dst_main_program().global_block()
src_op = dist_op_helper.get_cur_src_op() src_op = dist_op_context.get_cur_src_op()
rank_id = dist_op_helper.get_rank_id() rank_id = dist_op_context.get_rank_id()
op_dist_attr = ctx.get_op_distributed_attr_for_program(src_op) op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format( assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
str(src_op)) str(src_op))
...@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -267,7 +266,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
# got dist attribute info # got dist attribute info
dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name)
process_mesh_shape = op_dist_attr.get_process_mesh().topology process_mesh_shape = op_dist_attr.process_mesh.topology
# modify target shape # modify target shape
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
......
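Both reshape impls end by rewriting the target shape for the local tensor. A sketch of that arithmetic with made-up numbers (local_target_shape is a hypothetical helper, not the PR's code): every output dimension mapped to a mesh axis is divided by that axis' size, so the local reshape matches the local shard.

def local_target_shape(target_shape, dim_mapping, mesh_shape):
    shape = list(target_shape)
    for idx, axis in enumerate(dim_mapping):
        if axis >= 0:                                  # sharded on mesh axis `axis`
            shape[idx] = shape[idx] // mesh_shape[axis]
    return shape

print(local_target_shape([16, 1024, 768], [0, -1, 1], [2, 4]))  # [8, 1024, 192]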
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping ...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator): class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedSoftmax, self).__init__() super(DistributedSoftmax, self).__init__()
self._name = name self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax")) register_distributed_operator_impl_container("softmax",
DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl): class DistributedSoftmaxImpl(DistributedOperatorImpl):
...@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -40,12 +41,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """ op_desc = dist_op.serial_op.desc
return True op_dist_attr = dist_op.dist_attr
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis') axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
...@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -58,8 +56,9 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True return True
def is_output_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis') axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
...@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl): ...@@ -72,9 +71,10 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
return True return True
def update_dims_mapping(self, op_dist_attr): def update_dims_mapping(self, dist_op):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License # limitations under the License
from .common import DistributedOperator from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl from .common import DistributedOperatorImpl
from .common import register_distributed_operator from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl from .common import register_distributed_operator_impl
from ..utils import is_dim_shard from ..utils import is_dim_shard
from ..utils import is_dim_replicate from ..utils import is_dim_replicate
...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping ...@@ -24,13 +24,14 @@ from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator): class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, name): def __init__(self, name):
super(DistributedTranspose2, self).__init__() super(DistributedTranspose2, self).__init__()
self._name = name self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2")) register_distributed_operator_impl_container(
"transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl): class DistributedTranspose2Impl(DistributedOperatorImpl):
...@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl): ...@@ -40,19 +41,16 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
self._forward_implemented = False self._forward_implemented = False
self._backward_implemented = True self._backward_implemented = True
def is_process_mesh_compatible(self, op_dist_attr): def is_input_compatible(self, dist_op):
""" No restriction for now. """
return True return True
def is_input_compatible(self, op_dist_attr): def is_output_compatible(self, dist_op):
return True return True
def is_output_compatible(self, op_dist_attr): def update_dims_mapping(self, dist_op):
return True
def update_dims_mapping(self, op_dist_attr):
changed = False changed = False
op_desc = op_dist_attr.get_owner_op().desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0] x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0] out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0] x_shape_name = op_desc.output('XShape')[0]
......
...@@ -15,11 +15,11 @@ ...@@ -15,11 +15,11 @@
import paddle import paddle
from paddle.distributed.fleet import cloud_utils from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core import paddle.fluid.core as core
from .context import DistributedContext from .dist_context import DistributedContext
from .context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .completion import complete_annotation, complete_backward_annotation from .completion import complete_annotation, complete_backward_annotation
from .partitioner import Partitioner from .partitioner import Partitioner
from .process import get_all_process_groups from .process_group import get_all_process_groups
from .utils import make_data_unshard from .utils import make_data_unshard
from .reshard import reshard from .reshard import reshard
...@@ -70,7 +70,6 @@ class AutoParallelizer: ...@@ -70,7 +70,6 @@ class AutoParallelizer:
# Annotation completion # Annotation completion
completed_main_program = complete_annotation( completed_main_program = complete_annotation(
self._original_main_program, self._dist_context) self._original_main_program, self._dist_context)
# Logical partition # Logical partition
rank = paddle.distributed.get_rank() rank = paddle.distributed.get_rank()
partitioner = Partitioner(self._dist_strategy, self._dist_context, rank) partitioner = Partitioner(self._dist_strategy, self._dist_context, rank)
......
...@@ -22,15 +22,15 @@ from paddle.fluid import core, unique_name ...@@ -22,15 +22,15 @@ from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_ from paddle.fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm from paddle.fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops, ClipGradByGlobalNorm
from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy from paddle.distributed.fleet.base.distributed_strategy import DistributedStrategy
from paddle.distributed.auto_parallel.context import DistributedContext, DistOpHelper from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from .process import new_process_group from .dist_attribute import OperatorDistributedAttribute
from .interface import _g_process_mesh_map from .process_group import new_process_group
from .attribute import OperatorDistributedAttribute from .utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation from paddle.distributed.auto_parallel.completion import complete_backward_annotation, complete_update_annotation
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
...@@ -68,14 +68,14 @@ class Partitioner(object): ...@@ -68,14 +68,14 @@ class Partitioner(object):
# auto completion # auto completion
auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7]) auto.ProcessMesh(shape=[2, 4], process_group=[0, 1, 2, 3, 4, 5, 6, 7])
annotated_main_program = auto.complete_annotation(serial_main_program) annotated_main_program = auto.complete_annotation(serial_main_program)
auto_paralle_context = get_default_distributed_context() dist_context = get_default_distributed_context()
# distributed strategy & rank info # distributed strategy & rank info
rank_id = paddle.distributed.get_rank() rank_id = paddle.distributed.get_rank()
dist_strategy = fleet.DistributedStrategy() dist_strategy = fleet.DistributedStrategy()
# create partitioner # create partitioner
Partitioner = Partitioner(dist_strategy, auto_paralle_context, rank_id) Partitioner = Partitioner(dist_strategy, dist_context, rank_id)
# create dist program with forward only # create dist program with forward only
# for distributed inference, using partitioned_main_prog from here # for distributed inference, using partitioned_main_prog from here
...@@ -93,11 +93,11 @@ class Partitioner(object): ...@@ -93,11 +93,11 @@ class Partitioner(object):
opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog) opt_ops = Partitioner.apply_optimize(optimizer, dist_params_grads, partitioned_main_prog, partitioned_startup_prog)
""" """
def __init__(self, dist_strategy, auto_parallel_context, rank_id=0): def __init__(self, dist_strategy, dist_context, rank_id=0):
""" """
Args: Args:
dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy. dist_strategy (paddle.fleet.distributed_strategy): used to determine the user defined distributed strategy.
auto_parallel_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario. dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of var & op, every Partitioner object could maintain its own DistributedContext member, and partition program base on that shard scenario.
rank_id (int): global rank id to which the partitioned distributed program belong. rank_id (int): global rank id to which the partitioned distributed program belong.
""" """
...@@ -106,13 +106,13 @@ class Partitioner(object): ...@@ -106,13 +106,13 @@ class Partitioner(object):
"dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here" "dist_strategy be paddle.fleet.base.DistributedStrategy, got %s here"
% type(dist_strategy)) % type(dist_strategy))
if not isinstance(auto_parallel_context, DistributedContext): if not isinstance(dist_context, DistributedContext):
raise TypeError( raise TypeError(
"auto_parallel_context be paddle.fluid.DistributedContext, got %s here" "dist_context be paddle.fluid.DistributedContext, got %s here" %
% type(auto_parallel_context)) type(dist_context))
self._dist_strategy = dist_strategy self._dist_strategy = dist_strategy
self._auto_parallel_context = auto_parallel_context self._dist_context = dist_context
self._rank_id = rank_id self._rank_id = rank_id
self._serial2dist_varname_mapping = {} self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = "" self._dist_varname_suffix = ""
...@@ -218,8 +218,8 @@ class Partitioner(object): ...@@ -218,8 +218,8 @@ class Partitioner(object):
if not isinstance(startup_program, (Program)): if not isinstance(startup_program, (Program)):
raise TypeError( raise TypeError(
"auto_parallel_context be paddle.fluid.framework.program, got %s here" "dist_context be paddle.fluid.framework.program, got %s here" %
% type(startup_program)) type(startup_program))
# check if shard annotated serial program valid # check if shard annotated serial program valid
if not self._is_valid_annotated_program(main_program): if not self._is_valid_annotated_program(main_program):
...@@ -310,13 +310,12 @@ class Partitioner(object): ...@@ -310,13 +310,12 @@ class Partitioner(object):
if isinstance(var, Parameter): if isinstance(var, Parameter):
# TODO if var not belong to this rank, should be filtered # TODO if var not belong to this rank, should be filtered
serial_main_var = serial_main_block.var(var.name) serial_main_var = serial_main_block.var(var.name)
dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
serial_main_var) serial_main_var)
target_shape = _get_dist_shape(serial_main_var, dist_attr) target_shape = _get_dist_shape(serial_main_var, dist_attr)
new_name = var.name + self._dist_varname_suffix new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name temp_varname_map[var.name] = new_name
_partition_parameter(self._auto_parallel_context, _partition_parameter(self._dist_context, serial_main_var,
serial_main_var,
partitioned_startup_global_block, partitioned_startup_global_block,
new_name, target_shape) new_name, target_shape)
param2shape[new_name] = target_shape param2shape[new_name] = target_shape
...@@ -346,24 +345,22 @@ class Partitioner(object): ...@@ -346,24 +345,22 @@ class Partitioner(object):
assert new_op.desc == new_op_desc assert new_op.desc == new_op_desc
output_var = partitioned_startup_global_block.var(output_vars[ output_var = partitioned_startup_global_block.var(output_vars[
0]) 0])
output_var_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( output_var_attr = self._dist_context.get_tensor_dist_attr_for_program(
output_var) output_var)
op_attr = OperatorDistributedAttribute( op_attr = OperatorDistributedAttribute()
new_op, self._auto_parallel_context) op_attr.process_mesh = output_var_attr.process_mesh
op_attr.set_process_mesh(output_var_attr.get_process_mesh()) op_attr.set_output_dims_mapping(output_var.name,
op_attr.set_output_dims_mapping( output_var_attr.dims_mapping)
output_var.name, output_var_attr.get_dims_mapping()) op_attr.set_input_dims_mapping(output_var.name,
op_attr.set_input_dims_mapping( output_var_attr.dims_mapping)
output_var.name, output_var_attr.get_dims_mapping()) self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)
self._auto_parallel_context.set_op_distributed_attr_for_program(
new_op, op_attr)
# TODO move helper init to a comm place # TODO move helper init to a comm place
dist_op_helper = self._auto_parallel_context.get_dist_op_helper() dist_op_context = self._dist_context.dist_op_context
dist_op_helper.set_dst_main_program(partitioned_main_prog) dist_op_context.set_dst_main_program(partitioned_main_prog)
dist_op_helper.set_dst_startup_program(partitioned_startup_prog) dist_op_context.set_dst_startup_program(partitioned_startup_prog)
dist_op_helper.set_varname_mapping(self._serial2dist_varname_mapping) dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
dist_op_helper.set_rank_id(self._rank_id) dist_op_context.set_rank_id(self._rank_id)
# transpile main program # transpile main program
for op in serial_ops: for op in serial_ops:
...@@ -373,8 +370,7 @@ class Partitioner(object): ...@@ -373,8 +370,7 @@ class Partitioner(object):
if serial_input_varname not in self._serial2dist_varname_mapping: if serial_input_varname not in self._serial2dist_varname_mapping:
new_varname = serial_input_varname + self._dist_varname_suffix new_varname = serial_input_varname + self._dist_varname_suffix
if serial_main_block.has_var(serial_input_varname): if serial_main_block.has_var(serial_input_varname):
_partition_var(self._auto_parallel_context, _partition_var(self._dist_context, serial_main_block,
serial_main_block,
partitioned_global_block, partitioned_global_block,
serial_input_varname, new_varname) serial_input_varname, new_varname)
else: else:
...@@ -387,28 +383,25 @@ class Partitioner(object): ...@@ -387,28 +383,25 @@ class Partitioner(object):
for serial_output_varname in op.desc.output_arg_names(): for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping: if serial_output_varname not in self._serial2dist_varname_mapping:
new_varname = serial_output_varname + self._dist_varname_suffix new_varname = serial_output_varname + self._dist_varname_suffix
_partition_var(self._auto_parallel_context, _partition_var(self._dist_context, serial_main_block,
serial_main_block, partitioned_global_block, partitioned_global_block,
serial_output_varname, new_varname) serial_output_varname, new_varname)
self._serial2dist_varname_mapping[ self._serial2dist_varname_mapping[
serial_output_varname] = new_varname serial_output_varname] = new_varname
# partition op # partition op
kinputs, koutputs = dist_op_helper.prepare_forward_context(op) kinputs, koutputs = dist_op_context.prepare_forward_context(op)
dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
op) if _is_dist_op_forward_implement(self._dist_context, op):
if _is_dist_op_forward_implement(self._auto_parallel_context, op): dist_ops = get_distributed_operator_impl_container(op.type)
dist_ops = get_distributed_operator(op.type) dist_op_impl = dist_ops.get_impl(dist_attr.impl_idx)
dist_op_impl = dist_ops.get_impl(dist_attr.get_impl_idx()) dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
dist_op_impl.forward(self._auto_parallel_context, **kinputs,
**koutputs)
else: else:
# replicate op # replicate op
dist_ops = get_distributed_operator("default") dist_ops = get_distributed_operator_impl_container("default")
dist_op_impl = dist_ops.get_impl(0) dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.forward(self._auto_parallel_context, **kinputs, dist_op_impl.forward(self._dist_context, **kinputs, **koutputs)
**koutputs)
return partitioned_main_prog, partitioned_startup_prog return partitioned_main_prog, partitioned_startup_prog
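
The transpile loop above is essentially a lookup-and-dispatch: when the serial op has a distributed implementation (found via get_distributed_operator_impl_container and selected by dist_attr.impl_idx), that implementation's forward is called; otherwise the op falls back to the "default" (replicate) container. The self-contained sketch below mirrors only that control flow; DistOpImpl, DistOpContainer, _registry and dispatch_forward are illustrative stand-ins, not the Paddle classes.

# Minimal stand-ins that mimic the dispatch pattern of the forward
# transpilation loop; none of these classes are the real Paddle ones.
class DistOpImpl:
    def __init__(self, name):
        self._name = name

    def forward(self, dist_context, **kwargs):
        print("running impl:", self._name, "with", sorted(kwargs))


class DistOpContainer:
    def __init__(self, impls):
        self._impls = impls

    def get_impl(self, idx):
        return self._impls[idx]


# registry keyed by op type, with a "default" (replicate) fallback
_registry = {
    "matmul_v2": DistOpContainer([DistOpImpl("matmul_col_parallel")]),
    "default": DistOpContainer([DistOpImpl("replicate")]),
}


def dispatch_forward(op_type, impl_idx, dist_context, **kinputs):
    container = _registry.get(op_type)
    if container is not None and impl_idx >= 0:
        impl = container.get_impl(impl_idx)
    else:
        # replicate op: fall back to the default implementation
        impl = _registry["default"].get_impl(0)
    impl.forward(dist_context, **kinputs)


dispatch_forward("matmul_v2", 0, dist_context=None, X=["x"], Y=["w"])
dispatch_forward("elementwise_add", -1, dist_context=None, X=["a"], Y=["b"])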
...@@ -453,18 +446,18 @@ class Partitioner(object): ...@@ -453,18 +446,18 @@ class Partitioner(object):
for param in no_grad_set for param in no_grad_set
] ]
dist_op_helper = self._auto_parallel_context.get_dist_op_helper() dist_op_context = self._dist_context.dist_op_context
params_and_grads = _auto_backward( params_and_grads = _auto_backward(
dist_loss, dist_loss,
dist_startup_program, dist_startup_program,
parameter_list=parameter_list, parameter_list=parameter_list,
no_grad_set=no_grad_set, no_grad_set=no_grad_set,
callbacks=callbacks, callbacks=callbacks,
distop_context=dist_op_helper) distop_context=dist_op_context)
# backward completion # backward completion
complete_backward_annotation( complete_backward_annotation(
dist_main_program, dist_context=self._auto_parallel_context) dist_main_program, dist_context=self._dist_context)
# transpiler backward for dist op # transpiler backward for dist op
# get backward ops # get backward ops
...@@ -485,31 +478,33 @@ class Partitioner(object): ...@@ -485,31 +478,33 @@ class Partitioner(object):
backward_ops = ops[first_backward_op_idx:] backward_ops = ops[first_backward_op_idx:]
for backward_op in backward_ops: for backward_op in backward_ops:
# if the backward op has a corresponding forward op # if the backward op has a corresponding forward op
if backward_op.desc.id() in dist_op_helper.gradopidx2opidx: if backward_op.desc.id() in dist_op_context.gradopidx2opidx:
forward_op_id = dist_op_helper.gradopidx2opidx[ forward_op_id = dist_op_context.gradopidx2opidx[
backward_op.desc.id()] backward_op.desc.id()]
forward_op = forward_op_id2forward_op[forward_op_id] forward_op = forward_op_id2forward_op[forward_op_id]
# TODO backward attr should have _impl_idx # TODO backward attr should have _impl_idx
forward_op_dist_attr = self._auto_parallel_context.get_op_distributed_attr_for_program( forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
forward_op) forward_op)
# TODO use the backward op itself to find the dist op # TODO use the backward op itself to find the dist op
dist_ops = get_distributed_operator(forward_op.type) dist_ops = get_distributed_operator_impl_container(
kinputs, koutputs = dist_op_helper.prepare_backward_context( forward_op.type)
kinputs, koutputs = dist_op_context.prepare_backward_context(
backward_op) backward_op)
# TODO use backward op itself to determine impl idx # TODO use backward op itself to determine impl idx
if _is_dist_op_backward_implement( if _is_dist_op_backward_implement(self._dist_context,
self._auto_parallel_context, forward_op): forward_op):
dist_op_impl = dist_ops.get_impl( dist_op_impl = dist_ops.get_impl(
forward_op_dist_attr.get_impl_idx()) forward_op_dist_attr.impl_idx)
dist_op_impl.backward(self._auto_parallel_context, dist_op_impl.backward(self._dist_context, **kinputs,
**kinputs, **koutputs) **koutputs)
else: else:
# replicate op # replicate op
dist_ops = get_distributed_operator("default") dist_ops = get_distributed_operator_impl_container(
"default")
dist_op_impl = dist_ops.get_impl(0) dist_op_impl = dist_ops.get_impl(0)
dist_op_impl.backward(self._auto_parallel_context, dist_op_impl.backward(self._dist_context, **kinputs,
**kinputs, **koutputs) **koutputs)
return params_and_grads return params_and_grads
# replace dist grad ops # replace dist grad ops
...@@ -524,7 +519,7 @@ class Partitioner(object): ...@@ -524,7 +519,7 @@ class Partitioner(object):
# update completion # update completion
complete_update_annotation( complete_update_annotation(
main_program, dist_context=self._auto_parallel_context) main_program, dist_context=self._dist_context)
return optimize_ops return optimize_ops
...@@ -534,12 +529,11 @@ class Partitioner(object): ...@@ -534,12 +529,11 @@ class Partitioner(object):
ops = program.global_block().ops ops = program.global_block().ops
vars_ = program.list_vars() vars_ = program.list_vars()
op_dist_attrs = [ op_dist_attrs = [
self._auto_parallel_context.get_op_distributed_attr_for_program(op) self._dist_context.get_op_dist_attr_for_program(op) for op in ops
for op in ops
] ]
var_dist_attrs = [ var_dist_attrs = [
self._auto_parallel_context.get_tensor_distributed_attr_for_program( self._dist_context.get_tensor_dist_attr_for_program(var)
var) for var in vars_ for var in vars_
] ]
all_ops_annotated = all(dist_attr is not None all_ops_annotated = all(dist_attr is not None
...@@ -563,8 +557,7 @@ class Partitioner(object): ...@@ -563,8 +557,7 @@ class Partitioner(object):
def _is_var_distributed(self, var): def _is_var_distributed(self, var):
dist_attr = self._auto_parallel_context.get_tensor_distributed_attr_for_program( dist_attr = self._dist_context.get_tensor_dist_attr_for_program(var)
var)
assert dist_attr is not None, "dist_attr of var [{}] is None".format( assert dist_attr is not None, "dist_attr of var [{}] is None".format(
var.name) var.name)
return _is_distributed(dist_attr) return _is_distributed(dist_attr)
...@@ -637,20 +630,20 @@ def _get_no_grad_set(loss, no_grad_set=None): ...@@ -637,20 +630,20 @@ def _get_no_grad_set(loss, no_grad_set=None):
return no_grad_set return no_grad_set
def _is_dist_op_forward_implement(auto_paralle_context, op): def _is_dist_op_forward_implement(dist_context, op):
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator(op.type) dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.get_impl_idx())._forward_implemented dist_attr.impl_idx)._forward_implemented
def _is_dist_op_backward_implement(auto_paralle_context, op): def _is_dist_op_backward_implement(dist_context, op):
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(op) dist_attr = dist_context.get_op_dist_attr_for_program(op)
dist_ops = get_distributed_operator(op.type) dist_ops = get_distributed_operator_impl_container(op.type)
return dist_ops and dist_attr.get_impl_idx() >= 0 and dist_ops.get_impl( \ return dist_ops and dist_attr.impl_idx >= 0 and dist_ops.get_impl( \
dist_attr.get_impl_idx())._backward_implemented dist_attr.impl_idx)._backward_implemented
def _auto_backward(loss, def _auto_backward(loss,
...@@ -690,8 +683,8 @@ def _auto_backward(loss, ...@@ -690,8 +683,8 @@ def _auto_backward(loss,
def _is_distributed(dist_attr): def _is_distributed(dist_attr):
mapping = dist_attr.get_dims_mapping() mapping = dist_attr.dims_mapping
mesh = dist_attr.get_process_mesh().topology mesh = dist_attr.process_mesh.topology
for idx in range(len(mapping)): for idx in range(len(mapping)):
if mapping[idx] >= 0 and mesh[mapping[idx]] > 1: if mapping[idx] >= 0 and mesh[mapping[idx]] > 1:
return True return True
...@@ -702,8 +695,8 @@ def _is_distributed(dist_attr): ...@@ -702,8 +695,8 @@ def _is_distributed(dist_attr):
def _get_dist_shape(var, dist_attr): def _get_dist_shape(var, dist_attr):
var_shape = var.shape var_shape = var.shape
mapping = dist_attr.get_dims_mapping() mapping = dist_attr.dims_mapping
mesh = dist_attr.get_process_mesh().topology mesh = dist_attr.process_mesh.topology
assert len(var_shape) == len( assert len(var_shape) == len(
mapping mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
...@@ -721,7 +714,7 @@ def _get_dist_shape(var, dist_attr): ...@@ -721,7 +714,7 @@ def _get_dist_shape(var, dist_attr):
return new_shape return new_shape
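
_get_dist_shape derives the local shard shape from the tensor's dims_mapping and the mesh topology: a dimension mapped to -1 keeps its full size, while a dimension mapped to mesh axis i is divided by topology[i]. A minimal re-implementation of that rule, assuming even divisibility (local_shard_shape is a hypothetical name, not part of the module above):

def local_shard_shape(global_shape, dims_mapping, mesh_topology):
    assert len(global_shape) == len(dims_mapping)
    new_shape = []
    for size, mapping in zip(global_shape, dims_mapping):
        if mapping == -1:
            # dimension is replicated, keep the full size
            new_shape.append(size)
        else:
            # dimension is split across mesh axis `mapping`
            assert size % mesh_topology[mapping] == 0
            new_shape.append(size // mesh_topology[mapping])
    return new_shape


# e.g. an [8, 8] weight with dims_mapping [-1, 0] on a mesh of topology [2]
# is stored locally as [8, 4] on each of the two ranks.
assert local_shard_shape([8, 8], [-1, 0], [2]) == [8, 4]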
def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape): dst_shape):
# NOTE hack to copied Parameter # NOTE hack to copied Parameter
# not initialized parameter, need to initialize it # not initialized parameter, need to initialize it
...@@ -749,17 +742,13 @@ def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname, ...@@ -749,17 +742,13 @@ def _partition_parameter(auto_paralle_context, src_var, dst_block, dst_varname,
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid() # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# param.desc.set_distributed_attr_uid(distributed_attr_uid) # param.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy( dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None assert dist_attr is not None
dist_attr._owner_tensor = param dist_context.set_tensor_dist_attr_for_program(param, dist_attr)
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(param,
dist_attr)
def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
dst_varname, dst_shape): dst_shape):
var = dst_block.create_var( var = dst_block.create_var(
type=src_var.type, type=src_var.type,
name=dst_varname, name=dst_varname,
...@@ -776,15 +765,12 @@ def _partition_intermediate_var(auto_paralle_context, src_var, dst_block, ...@@ -776,15 +765,12 @@ def _partition_intermediate_var(auto_paralle_context, src_var, dst_block,
# distributed_attr_uid = src_var.desc.get_distributed_attr_uid() # distributed_attr_uid = src_var.desc.get_distributed_attr_uid()
# var.desc.set_distributed_attr_uid(distributed_attr_uid) # var.desc.set_distributed_attr_uid(distributed_attr_uid)
dist_attr = copy.deepcopy( dist_attr = copy.deepcopy(
auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) dist_context.get_tensor_dist_attr_for_program(src_var))
assert dist_attr is not None assert dist_attr is not None
dist_attr._owner_tensor = var dist_context.set_tensor_dist_attr_for_program(var, dist_attr)
dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
src_var)._owner_context
auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, def _partition_var(dist_context, src_block, dst_block, src_varname,
dst_varname): dst_varname):
""" """
partitioning includes: split + replicate partitioning includes: split + replicate
...@@ -798,16 +784,15 @@ def _partition_var(auto_paralle_context, src_block, dst_block, src_varname, ...@@ -798,16 +784,15 @@ def _partition_var(auto_paralle_context, src_block, dst_block, src_varname,
persistable=True, persistable=True,
stop_gradient=True) stop_gradient=True)
else: else:
dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
src_var)
target_shape = _get_dist_shape(src_var, dist_attr) target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter): if isinstance(src_var, Parameter):
_partition_parameter(auto_paralle_context, src_var, dst_block, _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_varname, target_shape) target_shape)
else: else:
_partition_intermediate_var(auto_paralle_context, src_var, _partition_intermediate_var(dist_context, src_var, dst_block,
dst_block, dst_varname, target_shape) dst_varname, target_shape)
def _insert_src_op(src_op, dst_block, varname_mapping): def _insert_src_op(src_op, dst_block, varname_mapping):
...@@ -822,8 +807,7 @@ def _insert_src_op(src_op, dst_block, varname_mapping): ...@@ -822,8 +807,7 @@ def _insert_src_op(src_op, dst_block, varname_mapping):
dst_block._sync_with_cpp() dst_block._sync_with_cpp()
def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, def _insert_dist_op(src_op, dst_block, varname_mapping, dist_context, rank_id):
rank_id):
# build input varname mapping # build input varname mapping
input_mapping = {} input_mapping = {}
...@@ -842,10 +826,9 @@ def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context, ...@@ -842,10 +826,9 @@ def _insert_dist_op(src_op, dst_block, varname_mapping, auto_paralle_context,
output_mapping[output_name] = varnames output_mapping[output_name] = varnames
# append dist op # append dist op
dist_attr = auto_paralle_context.get_op_distributed_attr_for_program(src_op) dist_attr = dist_context.get_op_dist_attr_for_program(src_op)
dist_ops = get_distributed_operator(src_op.type) dist_ops = get_distributed_operator_impl_container(src_op.type)
append_op_handle = dist_ops.get_impl(dist_attr.get_impl_idx()).forward( append_op_handle = dist_ops.get_impl(dist_attr.impl_idx).forward(src_op)
src_op)
append_op_handle( append_op_handle(
dst_block, dst_block,
src_op, src_op,
......
...@@ -19,62 +19,32 @@ from ..collective import _new_ring_id ...@@ -19,62 +19,32 @@ from ..collective import _new_ring_id
from ...fluid.framework import in_dygraph_mode from ...fluid.framework import in_dygraph_mode
from ...fluid.layers.tensor import fill_constant from ...fluid.layers.tensor import fill_constant
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = None _g_process_group_map = {}
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = None
def get_all_logical_process_set():
from .interface import _g_process_mesh_map
all_logical_process_set = set(_g_process_mesh_map[0].process_group)
return all_logical_process_set
def get_logical_process_to_physical_process_map():
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
return LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
def set_logical_process_to_physical_process_map(mapping):
global LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP
LOGICAL_PROCESS_TO_PHYSICAL_PROCESS_MAP = mapping
def get_processor_to_physical_process_map():
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
return PROCESSOR_TO_PHYSICAL_PROCESS_MAP
def set_processor_to_physical_process_map(mapping):
global PROCESSOR_TO_PHYSICAL_PROCESS_MAP
PROCESSOR_TO_PHYSICAL_PROCESS_MAP = mapping
PROCESS_GROUP_MAP = {}
def get_all_process_groups(): def get_all_process_groups():
global PROCESS_GROUP_MAP global _g_process_group_map
return PROCESS_GROUP_MAP.values() return _g_process_group_map.values()
def new_process_group(ranks): def new_process_group(ranks):
global PROCESS_GROUP_MAP global _g_process_group_map
if not PROCESS_GROUP_MAP: if not _g_process_group_map:
genv = _get_global_env() genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup( _g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size))) 0, list(range(genv.world_size)))
# A key constructed from ranks is used in the global process group map # A key constructed from ranks is used in the global process group map
key = ''.join(map(str, sorted(ranks))) key = ''.join(map(str, sorted(ranks)))
if key not in PROCESS_GROUP_MAP: if key not in _g_process_group_map:
num_groups = len(PROCESS_GROUP_MAP) num_groups = len(_g_process_group_map)
# Note: our process group may interfere with the original implementation # Note: our process group may interfere with the original implementation
# so the created group id should start from the original _new_ring_id() # so the created group id should start from the original _new_ring_id()
group_id = _new_ring_id() + num_groups + 1 group_id = _new_ring_id() + num_groups + 1
pg = ProcessGroup(group_id, ranks) pg = ProcessGroup(group_id, ranks)
PROCESS_GROUP_MAP[key] = pg _g_process_group_map[key] = pg
return pg return pg
else: else:
pg = PROCESS_GROUP_MAP[key] pg = _g_process_group_map[key]
return pg return pg
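
The renamed _g_process_group_map keeps one ProcessGroup per distinct rank set: the key is built from the sorted ranks, so requesting the same set twice returns the cached group, and new group ids are offset past the original _new_ring_id(). A minimal, self-contained sketch of that caching pattern (ProcessGroup and get_or_create_group below are plain stand-ins, not the Paddle classes):

class ProcessGroup:
    def __init__(self, group_id, ranks):
        self.id = group_id
        self.ranks = ranks


_group_map = {}


def get_or_create_group(ranks, base_ring_id=0):
    key = ''.join(map(str, sorted(ranks)))
    if key not in _group_map:
        # new ids start after the ids already handed out
        group_id = base_ring_id + len(_group_map) + 1
        _group_map[key] = ProcessGroup(group_id, sorted(ranks))
    return _group_map[key]


g1 = get_or_create_group([3, 0, 1])
g2 = get_or_create_group([0, 1, 3])
assert g1 is g2  # same rank set -> same cached group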
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import copy
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
class ProcessMesh(object):
r"""
The class `ProcessMesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represents a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) that describes the topology
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.topology == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3]
"""
def __init__(self, mesh):
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
processes = _flatten_nested_list(mesh)
assert all(isinstance(p, int) for p in processes), \
("All elements of mesh must be integer")
assert min(processes) >= 0, ('All elements of mesh must be >= 0.')
unique_processes = set(processes)
assert len(unique_processes) == len(processes), (
'All elements of mesh must be unique.')
self._topology = _get_nested_list_shape(mesh)
self._processes = processes
from .dist_context import get_default_distributed_context
default_dist_cxt = get_default_distributed_context()
default_dist_cxt.add_process_mesh(self)
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialize this ProcessMesh.
"""
return self._topology
@property
def processes(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def ndim(self):
r"""
Get the number of dimensions of this ProcessMesh.
"""
return len(self._topology)
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
if self.topology != other.topology or self.processes != other.processes:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.processes)
return str
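
As a quick illustration of the two module-level helpers used by __init__, the topology is just the shape of the nested `mesh` list and the process list is its row-major flattening; the snippet below reproduces that with plain numpy (which this module already imports):

import numpy

mesh = [[2, 4, 5], [0, 1, 3]]
arr = numpy.array(mesh)
topology = list(arr.shape)           # same result as _get_nested_list_shape(mesh)
processes = arr.flatten().tolist()   # same result as _flatten_nested_list(mesh)
assert topology == [2, 3]
assert processes == [2, 4, 5, 0, 1, 3]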
...@@ -22,9 +22,9 @@ from paddle.fluid.layer_helper import LayerHelper ...@@ -22,9 +22,9 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Program, OpProtoHolder from paddle.fluid.framework import Program, OpProtoHolder
import paddle.fluid.layers.utils as utils import paddle.fluid.layers.utils as utils
from ..collective import _get_global_env from ..collective import _get_global_env
from .context import DistributedContext from .dist_context import DistributedContext
from .attribute import OperatorDistributedAttribute, TensorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process import new_process_group, ProcessGroup, PROCESS_GROUP_MAP from .process_group import new_process_group, ProcessGroup, _g_process_group_map
class AllGatherOpDesc: class AllGatherOpDesc:
...@@ -276,20 +276,22 @@ def _is_overlapped(shape_x, shape_y): ...@@ -276,20 +276,22 @@ def _is_overlapped(shape_x, shape_y):
return overlapped return overlapped
def _need_reshard(tensor_dist_attr, op_dist_attr): def _need_reshard(dist_tensor, dist_op):
"""Judge the tensor whether needs to be resharded.""" """Judge the tensor whether needs to be resharded."""
is_reshard = False is_reshard = False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dist_attr = dist_tensor.dist_attr
tensor_process_mesh = tensor_dist_attr.get_process_mesh() tensor_name = dist_tensor.serial_tensor.name
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( tensor_dims_mapping = tensor_dist_attr.dims_mapping
tensor_dist_attr.get_owner_tensor().name) tensor_process_mesh = tensor_dist_attr.process_mesh
op_process_mesh = op_dist_attr.get_process_mesh() op_dist_attr = dist_op.dist_attr
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
op_process_mesh = op_dist_attr.process_mesh
if all( if all(
map(lambda x: x is not None, [ map(lambda x: x is not None, [
tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping, tensor_dims_mapping, tensor_process_mesh, op_input_dims_mapping,
op_process_mesh op_process_mesh
])): ])):
if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh._id != op_process_mesh._id: if tensor_dims_mapping != op_input_dims_mapping or tensor_process_mesh != op_process_mesh:
is_reshard = True is_reshard = True
return is_reshard return is_reshard
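
_need_reshard reduces to comparing the producer-side placement of the tensor with the placement the consuming op expects for that input: a mismatch in either dims_mapping or process_mesh triggers a reshard. The standalone sketch below captures that comparison with plain tuples instead of the DistributedTensor/DistributedOperator wrappers (needs_reshard is a hypothetical name):

def needs_reshard(tensor_placement, op_input_placement):
    t_dims_mapping, t_process_mesh = tensor_placement
    o_dims_mapping, o_process_mesh = op_input_placement
    # incomplete annotations are left alone, mirroring the guard above
    if None in (t_dims_mapping, t_process_mesh, o_dims_mapping, o_process_mesh):
        return False
    return t_dims_mapping != o_dims_mapping or t_process_mesh != o_process_mesh


# tensor stored column-sharded, op wants it replicated on the same mesh
assert needs_reshard(([-1, 0], [0, 1]), ([-1, -1], [0, 1]))
# placements agree, nothing to do
assert not needs_reshard(([0, -1], [0, 1]), ([0, -1], [0, 1]))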
...@@ -305,28 +307,30 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping): ...@@ -305,28 +307,30 @@ def _compute_complete_shape(slice_shape, process_shape, dims_mapping):
return complete_shape return complete_shape
def find_op_desc_seq(source_tensor, tensor_dist_attr, op_dist_attr): def find_op_desc_seq(dist_tensor, dist_op):
""" """
Find the op description sequence to reshard the source tensor so that it matches the op requirement. Find the op description sequence to reshard the source tensor so that it matches the op requirement.
Args: Args:
source_tensor (Variable): A tensor with distributed attribute. dist_tensor (DistributedTensor): A distributed tensor.
tensor_dist_attr (TensorDistributedAttribute): The distributed attribute of tensor. dist_op (DistributedOperator): A distributed operator.
op_dist_attr (OperatorDistributedAttribute): The distributed attribute of operator.
Returns: Returns:
Dict, the dict represents the required op description sequence corresponding to processes. The key of the dict is Dict, the dict represents the required op description sequence corresponding to processes. The key of the dict is
a process and the value is a list containing op descriptions. a process and the value is a list containing op descriptions.
""" """
source_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dist_attr = dist_tensor.dist_attr
source_process_mesh = tensor_dist_attr.get_process_mesh() source_tensor = dist_tensor.serial_tensor
source_process_group = source_process_mesh.process_group tensor_name = source_tensor.name
source_dims_mapping = tensor_dist_attr.dims_mapping
source_process_mesh = tensor_dist_attr.process_mesh
source_process_group = source_process_mesh.processes
source_process_shape = source_process_mesh.topology source_process_shape = source_process_mesh.topology
target_process_mesh = op_dist_attr.get_process_mesh() op_dist_attr = dist_op.dist_attr
target_dims_mapping = op_dist_attr.get_input_dims_mapping( target_process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.get_owner_tensor().name) target_dims_mapping = op_dist_attr.get_input_dims_mapping(tensor_name)
target_process_group = target_process_mesh.process_group target_process_group = target_process_mesh.processes
target_process_shape = target_process_mesh.topology target_process_shape = target_process_mesh.topology
complete_shape = _compute_complete_shape( complete_shape = _compute_complete_shape(
...@@ -662,11 +666,11 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index, ...@@ -662,11 +666,11 @@ def _concat_partitions_with_op(partition_tensor_list, tensor, partition_index,
def _init_comm_for_send_recv(): def _init_comm_for_send_recv():
if not PROCESS_GROUP_MAP: if not _g_process_group_map:
genv = _get_global_env() genv = _get_global_env()
PROCESS_GROUP_MAP["global_group"] = ProcessGroup( _g_process_group_map["global_group"] = ProcessGroup(
0, list(range(genv.world_size))) 0, list(range(genv.world_size)))
PROCESS_GROUP_MAP["global_group"].instantiate() _g_process_group_map["global_group"].instantiate()
HAS_SENT = {} HAS_SENT = {}
...@@ -773,31 +777,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op, ...@@ -773,31 +777,29 @@ def parse_op_desc(program, rank_id, op_desc_seq, var_name, reshard_op,
axes=op_desc.axes, axes=op_desc.axes,
new_var_name=new_name) new_var_name=new_name)
tensor_attr = TensorDistributedAttribute(target_tensor, tensor_attr = TensorDistributedAttribute()
dist_context) process_mesh = dist_context.get_op_dist_attr_for_program(
process_mesh = dist_context.get_op_distributed_attr_for_program( matched_op).process_mesh
matched_op).get_process_mesh() dims_mapping = dist_context.get_op_dist_attr_for_program(
dims_mapping = dist_context.get_op_distributed_attr_for_program(
matched_op).get_input_dims_mapping(var_name) matched_op).get_input_dims_mapping(var_name)
tensor_attr.set_dims_mapping(dims_mapping) tensor_attr.dims_mapping = dims_mapping
tensor_attr.set_process_mesh(process_mesh) tensor_attr.process_mesh = process_mesh
dist_context.set_tensor_distributed_attr_for_program(target_tensor, dist_context.set_tensor_dist_attr_for_program(target_tensor,
tensor_attr) tensor_attr)
# rename op input name according to new name # rename op input name according to new name
for op in block.ops: for op in block.ops:
for name in op.input_arg_names: for name in op.input_arg_names:
op_dist_attr = dist_context.get_op_distributed_attr_for_program( op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
op)
if name == var_name and op_dist_attr is not None: if name == var_name and op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh() op_process_mesh = op_dist_attr.process_mesh
op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( op_input_dims_mapping = op_dist_attr.get_input_dims_mapping(
var_name) var_name)
if op_process_mesh._id == process_mesh._id and op_input_dims_mapping == dims_mapping: if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping:
op.desc._rename_input(name, target_tensor.name) op.desc._rename_input(name, target_tensor.name)
op_dist_attr.set_input_dims_mapping( op_dist_attr.set_input_dims_mapping(
target_tensor.name, dims_mapping) target_tensor.name, dims_mapping)
op_dist_attr._dims_mapping.pop(name, None) op_dist_attr.set_input_dist_attr(name, None)
def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
...@@ -825,9 +827,9 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): ...@@ -825,9 +827,9 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
if op.type == "c_sync_comm_stream": if op.type == "c_sync_comm_stream":
need_save = [] need_save = []
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
process_mesh = dist_context.get_tensor_distributed_attr_for_program( process_mesh = dist_context.get_tensor_dist_attr_for_program(
vars[var_name]).get_process_mesh() vars[var_name]).process_mesh
if rank_id in process_mesh.process_group: if rank_id in process_mesh.processes:
need_save.append(var_name) need_save.append(var_name)
if not need_save: if not need_save:
remove_op_idx.append(idx) remove_op_idx.append(idx)
...@@ -839,10 +841,10 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): ...@@ -839,10 +841,10 @@ def _remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id):
continue continue
# judge the other op whether should be removed. # judge the other op whether should be removed.
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is not None: if op_dist_attr is not None:
op_process_mesh = op_dist_attr.get_process_mesh() op_process_mesh = op_dist_attr.process_mesh
if rank_id not in op_process_mesh.process_group and op.type not in not_remove_op_ref: if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref:
remove_op_idx.append(idx) remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]: for idx in remove_op_idx[::-1]:
...@@ -974,20 +976,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id, ...@@ -974,20 +976,18 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
while idx < len(block.ops): while idx < len(block.ops):
pre_op_count = len(block.ops) pre_op_count = len(block.ops)
op = block.ops[idx] op = block.ops[idx]
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
if op_dist_attr is not None: if dist_op is not None:
idx_offset = 0 idx_offset = 0
for var_name in op.input_arg_names: for var_name in op.input_arg_names:
# skip lod_tensor_blocking_queue_0 # skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0": if var_name == "lod_tensor_blocking_queue_0":
continue continue
var = block.vars[var_name] var = block.vars[var_name]
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( dist_tensor = dist_context.get_dist_tensor_for_program(var)
var) if dist_tensor is not None and _need_reshard(dist_tensor,
if tensor_dist_attr is not None and _need_reshard( dist_op):
tensor_dist_attr, op_dist_attr): reshard_op_desc = find_op_desc_seq(dist_tensor, dist_op)
reshard_op_desc = find_op_desc_seq(var, tensor_dist_attr,
op_dist_attr)
parse_op_desc(auto_parallel_main_prog, rank_id, parse_op_desc(auto_parallel_main_prog, rank_id,
reshard_op_desc, var_name, op, dist_context) reshard_op_desc, var_name, op, dist_context)
cur_op_count = len(block.ops) cur_op_count = len(block.ops)
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
import threading import threading
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
from .interface import _g_process_mesh_map
def is_valid_list_index(list, index): def is_valid_list_index(list, index):
...@@ -119,34 +118,35 @@ def remove_distributed_attr_suffix(name): ...@@ -119,34 +118,35 @@ def remove_distributed_attr_suffix(name):
def check_distributed_attr_for_program(program, dist_context=None): def check_distributed_attr_for_program(program, dist_context=None):
from .context import get_default_distributed_context from .dist_context import get_default_distributed_context
if dist_context is None: if dist_context is None:
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \ assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check." "Distributed attributes must be initialized before check."
for block in program.blocks: for block in program.blocks:
for tensor in block.vars.values(): for tensor in block.vars.values():
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( dist_tensor = dist_context.get_dist_tensor_for_graph(tensor)
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
tensor) tensor)
if (tensor_dist_attr is not None) and ( if (tensor_dist_attr is not None) and (not dist_tensor.is_valid()):
not tensor_dist_attr.is_valid()):
return False return False
for op in block.ops: for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) dist_op = dist_context.get_dist_op_for_graph(tensor)
if (op_dist_attr is not None) and (not op_dist_attr.is_valid()): op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if (op_dist_attr is not None) and (not dist_op.is_valid()):
return False return False
return True return True
def print_program_with_distributed_attr(program, dist_context=None): def print_program_with_dist_attr(program, dist_context=None):
""" """
This function reuses the original program output ability with a distributed context. This function reuses the original program output ability with a distributed context.
Using a lock avoids multiple threads changing the default distributed context simultaneously. Using a lock avoids multiple threads changing the default distributed context simultaneously.
""" """
lock = threading.Lock() lock = threading.Lock()
lock.acquire() lock.acquire()
from .context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .context import set_default_distributed_context from .dist_context import set_default_distributed_context
if dist_context is None: if dist_context is None:
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
print(program) print(program)
...@@ -301,31 +301,29 @@ def _linear_idx2coordinate(mesh_shape, linear_idx): ...@@ -301,31 +301,29 @@ def _linear_idx2coordinate(mesh_shape, linear_idx):
return coordinate return coordinate
def _get_corresponding_rank(target_mesh, rank): def _get_corresponding_rank(dist_context, target_mesh, rank):
# TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case. # TODO(JZ-LIANG) a hack method to support varying mesh in Pipeline parallelism case.
# we assume that all meshes are evenly divided from a parent mesh and should have the same size. # we assume that all meshes are evenly divided from a parent mesh and should have the same size.
# to revise this in future. # to revise this in future.
coordinate = None coordinate = None
for key, mesh in _g_process_mesh_map.items(): for mesh in dist_context.process_meshes:
if key == 0: if rank in mesh.processes and mesh.topology == target_mesh.topology:
continue
if rank in mesh.process_group and mesh.topology == target_mesh.topology:
coordinate = _linear_idx2coordinate(mesh.topology, coordinate = _linear_idx2coordinate(mesh.topology,
mesh.process_group.index(rank)) mesh.processes.index(rank))
break break
assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format( assert coordinate is not None, "could NOT found rank [{}] in any registered mesh".format(
rank) rank)
return target_mesh.process_group[_coordinate2linear_idx(mesh.topology, return target_mesh.processes[_coordinate2linear_idx(mesh.topology,
coordinate)] coordinate)]
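
_get_corresponding_rank relies on the row-major conversions _coordinate2linear_idx and _linear_idx2coordinate: the rank's coordinate is located in whichever registered mesh contains it, and the process at the same coordinate in the target mesh is returned. The sketch below reimplements that idea with hypothetical helper names, assuming the two meshes share a topology:

def coordinate_to_linear_idx(shape, coordinate):
    # row-major linearization
    idx = 0
    for dim_size, coord in zip(shape, coordinate):
        idx = idx * dim_size + coord
    return idx


def linear_idx_to_coordinate(shape, linear_idx):
    # inverse of the row-major linearization
    coordinate = []
    for dim_size in reversed(shape):
        coordinate.append(linear_idx % dim_size)
        linear_idx //= dim_size
    return list(reversed(coordinate))


def corresponding_rank(source_processes, target_processes, topology, rank):
    # find where `rank` sits in the source mesh and pick the process at the
    # same coordinate in the target mesh
    coord = linear_idx_to_coordinate(topology, source_processes.index(rank))
    return target_processes[coordinate_to_linear_idx(topology, coord)]


# pipeline stage 0 uses processes [0, 1], stage 1 uses [2, 3]; rank 1 in the
# first mesh corresponds to rank 3 in the second one.
assert corresponding_rank([0, 1], [2, 3], [2], 1) == 3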
def _get_unshard_dist_shape(var, dist_attr): def _get_unshard_dist_shape(var, dist_attr):
var_shape = var.shape var_shape = var.shape
mapping = dist_attr.get_dims_mapping() mapping = dist_attr.dims_mapping
mesh = dist_attr.get_process_mesh().topology mesh = dist_attr.process_mesh.topology
assert len(var_shape) == len( assert len(var_shape) == len(
mapping mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
...@@ -341,19 +339,16 @@ def _get_unshard_dist_shape(var, dist_attr): ...@@ -341,19 +339,16 @@ def _get_unshard_dist_shape(var, dist_attr):
def make_data_unshard(dist_main_prog, dist_startup_prog): def make_data_unshard(dist_main_prog, dist_startup_prog):
from .context import get_default_distributed_context from .dist_context import get_default_distributed_context
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
for var in dist_main_prog.list_vars(): for var in dist_main_prog.list_vars():
if var.is_data: if var.is_data:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
var) var)
inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr) inverse_shape = _get_unshard_dist_shape(var, tensor_dist_attr)
var.desc.set_shape(inverse_shape) var.desc.set_shape(inverse_shape)
dim_mapping = tensor_dist_attr.get_dims_mapping() dim_mapping = tensor_dist_attr.dims_mapping
dim_mapping = [-1] * len(dim_mapping) dim_mapping = [-1] * len(dim_mapping)
tensor_dist_attr.set_dims_mapping(dim_mapping) tensor_dist_attr.dims_mapping = dim_mapping
dist_context.set_tensor_distributed_attr_for_program( dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr)
var, tensor_dist_attr)
var._set_attr('dim_mapping' + core.kAutoParallelSuffix(),
dim_mapping)
...@@ -1308,13 +1308,12 @@ class Variable(object): ...@@ -1308,13 +1308,12 @@ class Variable(object):
if self.persistable: if self.persistable:
var_str = "persist " + var_str var_str = "persist " + var_str
from paddle.distributed.auto_parallel.context import get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( dist_tensor = dist_context.get_dist_tensor_for_program(self)
self) if dist_tensor is not None:
if var_dist_attr is not None:
var_str += ", {name} = {value}".format( var_str += ", {name} = {value}".format(
name="dist_attr", value=var_dist_attr) name="dist_attr", value=dist_tensor)
return var_str return var_str
...@@ -2529,12 +2528,12 @@ class Operator(object): ...@@ -2529,12 +2528,12 @@ class Operator(object):
if i != len(attr_names) - 1: if i != len(attr_names) - 1:
attrs_str += ", " attrs_str += ", "
from paddle.distributed.auto_parallel.context import get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
dist_context = get_default_distributed_context() dist_context = get_default_distributed_context()
op_dist_attr = dist_context.get_op_distributed_attr_for_program(self) dist_op = dist_context.get_dist_op_for_program(self)
if op_dist_attr is not None: if dist_op is not None:
attrs_str += ", {name} = {value}".format( attrs_str += ", {name} = {value}".format(
name="dist_attr", value=op_dist_attr) name="dist_attr", value=dist_op)
if outputs_str != "{}": if outputs_str != "{}":
op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\ op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
......
...@@ -36,8 +36,7 @@ class TestDataUnshard(unittest.TestCase): ...@@ -36,8 +36,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program): def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program): with paddle.static.program_guard(train_program, start_program):
ROOT_MESH = auto.ProcessMesh([0, 1]) MESH_0 = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
input = paddle.static.data(name='input', shape=[2, 8]) input = paddle.static.data(name='input', shape=[2, 8])
label = paddle.static.data(name='label', shape=[2, 8]) label = paddle.static.data(name='label', shape=[2, 8])
...@@ -47,10 +46,30 @@ class TestDataUnshard(unittest.TestCase): ...@@ -47,10 +46,30 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr) linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, MESH_0, dim_mapping=[0, -1]) auto.shard_tensor(
auto.shard_tensor(label, MESH_0, dim_mapping=[0, -1]) input,
auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, -1]) dist_attr={
auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[-1, -1]) "process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(
linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
linear0_out = linear0(input) linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out) gelu_out = F.gelu(linear0_out)
...@@ -105,8 +124,7 @@ class TestDataUnshard(unittest.TestCase): ...@@ -105,8 +124,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program): def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program): with paddle.static.program_guard(train_program, start_program):
ROOT_MESH = auto.ProcessMesh([0, 1]) MESH_0 = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], ROOT_MESH)
input = paddle.static.data(name='input', shape=[8, 8]) input = paddle.static.data(name='input', shape=[8, 8])
label = paddle.static.data(name='label', shape=[8, 8]) label = paddle.static.data(name='label', shape=[8, 8])
...@@ -116,11 +134,31 @@ class TestDataUnshard(unittest.TestCase): ...@@ -116,11 +134,31 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr) linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input, MESH_0, dim_mapping=[-1, -1]) auto.shard_tensor(
auto.shard_tensor(label, MESH_0, dim_mapping=[-1, -1]) input,
dist_attr={
auto.shard_tensor(linear0.weight, MESH_0, dim_mapping=[-1, 0]) "process_mesh": MESH_0,
auto.shard_tensor(linear1.weight, MESH_0, dim_mapping=[0, -1]) "dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
linear0_out = linear0(input) linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out) gelu_out = F.gelu(linear0_out)
......
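
The two test cases above exercise the reworked annotation interface: shard_tensor and shard_op now take a single dist_attr dict instead of the old positional mesh and dim_mapping arguments. A condensed before/after distilled from those tests (the executor, optimizer, and parallelizer setup is omitted):

import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
main_prog, start_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    mesh = auto.ProcessMesh([0, 1])
    x = paddle.static.data(name='x', shape=[2, 8], dtype='float32')
    y = paddle.static.data(name='y', shape=[2, 8], dtype='float32')

    # old interface (removed by this change):
    #   auto.shard_tensor(x, mesh, dim_mapping=[0, -1])
    #   dist.shard_op(paddle.add, mesh=mesh, dim_mapping_dict={...}, x=x, y=y)

    # new interface: one dist_attr dict per tensor / per op input
    dist.shard_tensor(
        x, dist_attr={"process_mesh": mesh,
                      "dims_mapping": [0, -1]})
    dist_add = dist.shard_op(
        paddle.add,
        dist_attr={
            x: {"process_mesh": mesh,
                "dims_mapping": [0, -1]},
            y: {"dims_mapping": [-1, 0]}
        })
    out = dist_add(x, y)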
...@@ -24,13 +24,12 @@ import paddle.utils as utils ...@@ -24,13 +24,12 @@ import paddle.utils as utils
from paddle.fluid import layers from paddle.fluid import layers
from paddle.distributed import fleet from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
import paddle.fluid.core as core import paddle.fluid.core as core
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -78,8 +77,12 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -78,8 +77,12 @@ def mlp_pretrain_forward(train_program, start_program):
label = static.data( label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32') name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1, -1]) auto.shard_tensor(
auto.set_pipeline_stage(1) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -99,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase): ...@@ -99,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self): def test_mlp_serial(self):
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy() dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False dist_strategy.amp = False
...@@ -131,7 +134,7 @@ class TestMLPAutoParallelizer(unittest.TestCase): ...@@ -131,7 +134,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
for op in block.ops: for op in block.ops:
for attr_name in op.attr_names: for attr_name in op.attr_names:
self.assertTrue(suffix not in attr_name) self.assertTrue(suffix not in attr_name)
# print_program_with_distributed_attr(distributed_main_program) # print_program_with_dist_attr(distributed_main_program)
self.assertIsNotNone(distributed_startup_program) self.assertIsNotNone(distributed_startup_program)
self.assertIsNotNone(distributed_main_program) self.assertIsNotNone(distributed_main_program)
......
...@@ -15,128 +15,153 @@ ...@@ -15,128 +15,153 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import functools
import operator
import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn import paddle.nn as nn
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
paddle.enable_static() paddle.enable_static()
process_mesh1 = [0, 1, 2, 3]
def _flatten_nested_list(nested_list): process_mesh2 = [[0, 1, 2], [3, 4, 5]]
result = functools.reduce(operator.iconcat, nested_list, [])
return result
def _append_attr_suffix(name):
return name + core.kAutoParallelSuffix()
LAST_PP_STAGE = 3
MASK = [[0, 1, 1], [0, 1, 1]]
MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
class SimpleNet(nn.Layer): class SimpleNet(nn.Layer):
def __init__(self, vocab_size=128, hidden_size=4): def __init__(self, vocab_size=128, hidden_size=4):
super(SimpleNet, self).__init__() super(SimpleNet, self).__init__()
self.mesh = MESH
self.mesh.set_placement([5, 4, 3, 2, 1, 0])
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.dense1 = nn.Linear(hidden_size, hidden_size) self.dense1 = nn.Linear(hidden_size, hidden_size)
self.dense2 = nn.Linear(hidden_size, hidden_size // 2) self.dense2 = nn.Linear(hidden_size, hidden_size // 2)
def forward(self, x, y): def forward(self, x, y):
x = dist.shard_tensor(x, self.mesh, dim_mapping=[0, -1]) # Test shard_tensor interface with dist_attr arg
x = dist.set_shard_mask(x, MASK) x = dist.shard_tensor(
x,
dist_attr={"process_mesh": process_mesh1,
"dims_mapping": [0, -1]})
emb_out = self.word_embeddings(x) emb_out = self.word_embeddings(x)
# Test shard_tensor interface with no dist_attr arg
dist.set_pipeline_stage(LAST_PP_STAGE) y = dist.shard_tensor(y)
y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1])
dist.set_offload_device(y, "cpu")
linear1 = self.dense1(y) linear1 = self.dense1(y)
out = self.dense2(linear1) out = self.dense2(linear1)
return x, y, self.mesh return x, y
class TestAutoParallelAPI(unittest.TestCase): class TestAutoParallelAPI(unittest.TestCase):
def test_api(self): def test_api(self):
dist_context = get_default_distributed_context()
net = SimpleNet() net = SimpleNet()
data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64") data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
data2 = fluid.layers.fill_constant( data2 = fluid.layers.fill_constant(
shape=[2, 4], value=2, dtype="float32") shape=[2, 4], value=2, dtype="float32")
data3 = fluid.layers.fill_constant( data3 = fluid.layers.fill_constant(
shape=[2, 4], value=4, dtype="float32") shape=[2, 4], value=4, dtype="float32")
x, y, mesh = net.forward(data1, data2)
mesh_attr = _append_attr_suffix('mesh_id')
x_mesh_id = x._get_attr(mesh_attr)
self.assertEqual(x_mesh_id, mesh._id)
x_mesh = x.process_mesh
allatts = x.attr_names
self.assertEqual(x_mesh, mesh)
shard_mask_attr = _append_attr_suffix('mask')
self.assertEqual(
x._get_attr(shard_mask_attr), _flatten_nested_list(MASK))
self.assertEqual(x.shard_mask, _flatten_nested_list(MASK))
offload_attr = _append_attr_suffix('offload_device')
self.assertEqual(y._get_attr(offload_attr), "cpu")
self.assertEqual(y.desc.has_attr(offload_attr), True)
self.assertEqual(y.offload_device, "cpu")
y._remove_attr(offload_attr)
self.assertEqual(y._has_attr(offload_attr), False)
ops = paddle.static.default_main_program().block(0).ops
first_op = ops[0]
last_op = ops[-1]
self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE) x, y = net.forward(data1, data2)
DIMS_MAPPING1 = [0, 1] dist_x = dist_context.get_dist_tensor_for_program(x)
DIMS_MAPPING2 = [-1, 0] self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1)
kwargs = {'x': data2, 'y': data3} self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1])
dist.shard_op( self.assertEqual(dist_x.dist_attr.shard_sizes, None)
self.assertEqual(dist_x.dist_attr.device_placement, None)
self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_x.dist_attr.is_annotated("device_placement"))
dist_y = dist_context.get_dist_tensor_for_program(y)
self.assertEqual(dist_y.dist_attr.process_mesh, None)
self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1])
self.assertEqual(dist_y.dist_attr.shard_sizes, None)
self.assertEqual(dist_y.dist_attr.device_placement, None)
self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh"))
self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_y.dist_attr.is_annotated("device_placement"))
# Test shard_op interface with dist_attr
dims_mapping1 = [0, 1]
dims_mapping2 = [-1, 0]
dist_add = dist.shard_op(
paddle.add, paddle.add,
mesh=mesh, dist_attr={
dim_mapping_dict={ data2: {
data2.name: DIMS_MAPPING1, "process_mesh": process_mesh2,
data3.name: DIMS_MAPPING2 "dims_mapping": dims_mapping1
}, },
**kwargs) data3: {
"dims_mapping": dims_mapping2
}
})
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1] last_op = ops[-1]
self.assertEqual(last_op.process_mesh, mesh) dist_op = dist_context.get_dist_op_for_program(last_op)
attr_name = "IN_" + data2.name self.assertEqual(dist_op.dist_attr.process_mesh,
attr_name = _append_attr_suffix(attr_name) ProcessMesh(process_mesh2))
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING1) self.assertEqual(dist_op.dist_attr.impl_type, "default")
attr_name = "IN_" + data3.name self.assertEqual(dist_op.dist_attr.impl_idx, -2)
attr_name = _append_attr_suffix(attr_name) self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING2)
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
def test_process_mesh(self): self.assertEqual(data2_dist_attr.process_mesh,
mesh1 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=MESH) dist_op.dist_attr.process_mesh)
mesh2 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=mesh1) self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1)
mesh3 = dist.ProcessMesh([[0, 1], [2, 3]], parent=mesh1) self.assertEqual(data2_dist_attr.shard_sizes, None)
mesh4 = dist.ProcessMesh([[2, 3], [4, 5]], parent=mesh1) self.assertEqual(data2_dist_attr.device_placement, None)
self.assertTrue(data2_dist_attr.is_annotated("process_mesh"))
self.assertEqual(MESH.parent, None) self.assertTrue(data2_dist_attr.is_annotated("dims_mapping"))
self.assertEqual(mesh1.parent, MESH) self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertEqual(mesh1._desc.parent, MESH._id) self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
self.assertEqual(mesh3.parent, mesh1)
self.assertEqual(mesh4.parent, mesh1) data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(mesh1, mesh2) self.assertEqual(data3_dist_attr.process_mesh,
self.assertNotEqual(mesh3, mesh4) dist_op.dist_attr.process_mesh)
self.assertEqual(mesh2._id, mesh2._desc.id) self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2)
self.assertEqual(mesh3.topology, mesh3._desc.topology) self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(mesh3.topology, [2, 2]) self.assertEqual(data3_dist_attr.device_placement, None)
self.assertEqual(mesh3.process_group, [0, 1, 2, 3]) self.assertTrue(data3_dist_attr.is_annotated("process_mesh"))
self.assertEqual(mesh4.process_group, mesh4._desc.process_group) self.assertTrue(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
# Test shard_op interface without dist_attr
dist_add = dist.shard_op(paddle.add)
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, -2)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
self.assertEqual(data2_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data2_dist_attr.shard_sizes, None)
self.assertEqual(data2_dist_attr.device_placement, None)
self.assertFalse(data2_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data2_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(data3_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(data3_dist_attr.device_placement, None)
self.assertFalse(data3_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
if __name__ == '__main__': if __name__ == '__main__':
......
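For reference, the new annotation interface exercised by the test above can be reduced to the following minimal sketch; the tensor names, shapes, and mesh values here are illustrative assumptions, not part of this diff.

import paddle
import paddle.static as static
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

# Illustrative sketch of the dist_attr-based interface; names and shapes are assumed.
paddle.enable_static()
main_prog, start_prog = static.Program(), static.Program()
with static.program_guard(main_prog, start_prog):
    a = static.data(name="a", shape=[4, 8], dtype="float32")
    b = static.data(name="b", shape=[4, 8], dtype="float32")
    # Annotate a tensor: shard dim 0 of `a` along mesh axis 0, replicate dim 1.
    auto.shard_tensor(
        a,
        dist_attr={
            "process_mesh": auto.ProcessMesh([[0, 1], [2, 3]]),
            "dims_mapping": [0, -1]
        })
    # Annotate an op: wrap it first, then call the wrapper with its inputs.
    dist_add = dist.shard_op(
        paddle.add,
        dist_attr={
            a: {"process_mesh": [[0, 1], [2, 3]],
                "dims_mapping": [0, -1]},
            b: {"dims_mapping": [-1, -1]}
        })
    out = dist_add(a, b)

The test above then reads these annotations back with dist_context.get_dist_tensor_for_program(...) and dist_context.get_dist_op_for_program(...).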
...@@ -28,15 +28,14 @@ from paddle.fluid import layers ...@@ -28,15 +28,14 @@ from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
_global_process_mesh2 = None _global_process_mesh2 = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -62,20 +61,43 @@ class MLPLayer(nn.Layer): ...@@ -62,20 +61,43 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight,
auto.shard_tensor( dist_attr={
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight,
auto.shard_tensor( dist_attr={
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "pp": elif _global_parallel_strategy == "pp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh2, self.linear1.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh2,
"dims_mapping": [1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -99,10 +121,18 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -99,10 +121,18 @@ def mlp_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -118,8 +148,7 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -118,8 +148,7 @@ class TestMLPAutoCompletion(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
...@@ -127,18 +156,15 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -127,18 +156,15 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_mp(self): def test_mlp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -147,18 +173,16 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -147,18 +173,16 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_dp_mp(self): def test_mlp_dp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -167,61 +191,59 @@ class TestMLPAutoCompletion(unittest.TestCase): ...@@ -167,61 +191,59 @@ class TestMLPAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context)) # def test_mlp_misc(self):
# # import pdb
def test_mlp_misc(self): # global _global_parallel_strategy
# import pdb # _global_parallel_strategy = "pp"
global _global_parallel_strategy # global _global_process_mesh
_global_parallel_strategy = "pp" # _global_process_mesh = auto.ProcessMesh(
global _global_process_mesh # mesh=[[0, 1], [2, 3]])
_global_process_mesh = auto.ProcessMesh( # global _global_process_mesh2
mesh=[[0, 1], [2, 3]], parent=ROOT_MESH) # _global_process_mesh2 = auto.ProcessMesh(
global _global_process_mesh2 # mesh=[[4, 5], [6, 7]])
_global_process_mesh2 = auto.ProcessMesh(
mesh=[[4, 5], [6, 7]], parent=ROOT_MESH) # train_program = static.Program()
# start_program = static.Program()
train_program = static.Program() # dist_context = DistributedContext()
start_program = static.Program() # train_program, start_program = mlp_pretrain_forward(train_program,
dist_context = DistributedContext() # start_program)
train_program, start_program = mlp_pretrain_forward(train_program, # # pdb.set_trace()
start_program) # complete_train_program = auto.complete_annotation(train_program,
# pdb.set_trace()
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context) # dist_context)
dist_context.finalize_distributed_attr_for_program( # # print_program_with_dist_attr(complete_train_program,
complete_train_program) # # dist_context)
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map # dist_context.finalize_distributed_attr_for_program(
for block in complete_train_program.blocks: # complete_train_program)
for tensor in block.vars.values(): # from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
desc = tensor.desc # for block in complete_train_program.blocks:
attr_name = append_distributed_attr_suffix("mesh_id") # for tensor in block.vars.values():
self.assertIsNotNone(desc.has_attr(attr_name)) # desc = tensor.desc
attr_name = append_distributed_attr_suffix("dim_mapping") # attr_name = append_distributed_attr_suffix("mesh_id")
self.assertIsNotNone(desc.has_attr(attr_name)) # self.assertIsNotNone(desc.has_attr(attr_name))
for op in block.ops: # attr_name = append_distributed_attr_suffix("dims_mapping")
desc = op.desc # self.assertIsNotNone(desc.has_attr(attr_name))
attr_name = append_distributed_attr_suffix("mesh_id") # for op in block.ops:
self.assertIsNotNone(desc.has_attr(attr_name)) # desc = op.desc
for tensor_name in desc.input_arg_names(): # attr_name = append_distributed_attr_suffix("mesh_id")
attr_name = append_distributed_attr_suffix("IN_" + # self.assertIsNotNone(desc.has_attr(attr_name))
tensor_name) # for tensor_name in desc.input_arg_names():
self.assertIsNotNone(desc.has_attr(attr_name)) # attr_name = append_distributed_attr_suffix("IN_" +
for tensor_name in desc.output_arg_names(): # tensor_name)
attr_name = append_distributed_attr_suffix("OUT_" + # self.assertIsNotNone(desc.has_attr(attr_name))
tensor_name) # for tensor_name in desc.output_arg_names():
self.assertIsNotNone(desc.has_attr(attr_name)) # attr_name = append_distributed_attr_suffix("OUT_" +
set_default_distributed_context(dist_context) # tensor_name)
self.assertTrue("dist_attr" in str(complete_train_program)) # self.assertIsNotNone(desc.has_attr(attr_name))
with unittest.mock.patch( # set_default_distributed_context(dist_context)
"sys.stdout", new_callable=StringIO) as mock_stdout: # self.assertTrue("dist_attr" in str(complete_train_program))
print_program_with_distributed_attr(complete_train_program) # with unittest.mock.patch(
self.assertIsNotNone(mock_stdout.getvalue()) # "sys.stdout", new_callable=StringIO) as mock_stdout:
# print_program_with_dist_attr(complete_train_program)
# self.assertIsNotNone(mock_stdout.getvalue())
class AttentionLayer(nn.Layer): class AttentionLayer(nn.Layer):
...@@ -262,10 +284,18 @@ class AttentionLayer(nn.Layer): ...@@ -262,10 +284,18 @@ class AttentionLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
q = self.q_proj(input) q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -276,18 +306,42 @@ class AttentionLayer(nn.Layer): ...@@ -276,18 +306,42 @@ class AttentionLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
auto.shard_tensor( dist_attr={
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) "process_mesh": _global_process_mesh,
auto.shard_tensor( "dims_mapping": [-1, 0]
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) })
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -320,12 +374,18 @@ class AttentionLayer(nn.Layer): ...@@ -320,12 +374,18 @@ class AttentionLayer(nn.Layer):
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
return out return out
...@@ -357,8 +417,7 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -357,8 +417,7 @@ class TestAttentionAutoCompletion(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
...@@ -366,18 +425,15 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -366,18 +425,15 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_mp(self): def test_attn_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -386,18 +442,16 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -386,18 +442,16 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_dp_mp(self): def test_attn_dp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -406,11 +460,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): ...@@ -406,11 +460,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
class DecoderLayer(nn.Layer): class DecoderLayer(nn.Layer):
...@@ -486,10 +538,18 @@ class DecoderLayer(nn.Layer): ...@@ -486,10 +538,18 @@ class DecoderLayer(nn.Layer):
def forward(self, input_ids, position_ids): def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
input_embeddings = self.word_embeddings(input_ids) input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
...@@ -497,13 +557,17 @@ class DecoderLayer(nn.Layer): ...@@ -497,13 +557,17 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
embeddings = input_embeddings + position_embeddings embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings) embeddings = self.dropout1(embeddings)
...@@ -521,18 +585,42 @@ class DecoderLayer(nn.Layer): ...@@ -521,18 +585,42 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
auto.shard_tensor( dist_attr={
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) "process_mesh": _global_process_mesh,
auto.shard_tensor( "dims_mapping": [-1, 0]
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) })
auto.shard_tensor(
self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -566,12 +654,18 @@ class DecoderLayer(nn.Layer): ...@@ -566,12 +654,18 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual # Add residual
residual = embeddings + self.dropout2(out) residual = embeddings + self.dropout2(out)
...@@ -586,14 +680,30 @@ class DecoderLayer(nn.Layer): ...@@ -586,14 +680,30 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight,
auto.shard_tensor( dist_attr={
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual # Add residual
final = residual + self.dropout3(out3) final = residual + self.dropout3(out3)
...@@ -631,8 +741,7 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -631,8 +741,7 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
...@@ -640,18 +749,15 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -640,18 +749,15 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_mp(self): def test_decoder_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -660,18 +766,16 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -660,18 +766,16 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_dp_mp(self): def test_decoder_dp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -680,11 +784,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): ...@@ -680,11 +784,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__": if __name__ == "__main__":
......
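All of the completion tests above follow the same build, complete, validate flow; condensed into a minimal, self-contained sketch (the program contents and mesh are illustrative assumptions), it looks like this.

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

# Illustrative sketch: a single annotated matmul instead of the full MLP/attention models.
paddle.enable_static()
train_program, start_program = static.Program(), static.Program()
with static.program_guard(train_program, start_program):
    x = static.data(name="x", shape=[8, 16], dtype="float32")
    w = static.create_parameter(shape=[16, 32], dtype="float32")
    # Column-parallel style annotation on the weight, as in the "mp" branches above.
    auto.shard_tensor(
        w,
        dist_attr={
            "process_mesh": auto.ProcessMesh(mesh=[0, 1, 2, 3]),
            "dims_mapping": [-1, 0]
        })
    y = paddle.matmul(x, w)

dist_context = DistributedContext()
# Propagate the partial annotations to every tensor and op in the program.
complete_train_program = auto.complete_annotation(train_program, dist_context)
# print_program_with_dist_attr(complete_train_program, dist_context)
assert dist_context.validate_dist_attr_for_program()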
...@@ -32,13 +32,12 @@ from paddle.distributed.fleet import fleet ...@@ -32,13 +32,12 @@ from paddle.distributed.fleet import fleet
import paddle.static as static import paddle.static as static
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MultiHeadAttention(nn.Layer): class MultiHeadAttention(nn.Layer):
...@@ -108,10 +107,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -108,10 +107,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
...@@ -145,19 +152,35 @@ class MultiHeadAttention(nn.Layer): ...@@ -145,19 +152,35 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -238,12 +261,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -238,12 +261,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
outs = [out] outs = [out]
if self.need_weights: if self.need_weights:
...@@ -411,17 +440,33 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -411,17 +440,33 @@ class TransformerDecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# tgt = self.dropout2( # tgt = self.dropout2(
# self.linear2(F.gelu( # self.linear2(F.gelu(
...@@ -485,13 +530,17 @@ class GPTEmbeddings(nn.Layer): ...@@ -485,13 +530,17 @@ class GPTEmbeddings(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings embeddings = input_embedings + position_embeddings
...@@ -717,10 +766,18 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -717,10 +766,18 @@ def gpt_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
gpt = GPTModel( gpt = GPTModel(
vocab_size=32768, vocab_size=32768,
...@@ -753,8 +810,7 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -753,8 +810,7 @@ class TestGPTAutoCompletion(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -763,18 +819,15 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -763,18 +819,15 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_mp(self): def test_gpt_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -783,18 +836,16 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -783,18 +836,16 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_dp_mp(self): def test_gpt_dp_mp(self):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
...@@ -803,11 +854,9 @@ class TestGPTAutoCompletion(unittest.TestCase): ...@@ -803,11 +854,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
# print_program_with_distributed_attr(complete_train_program, # print_program_with_dist_attr(complete_train_program,
# dist_context) # dist_context)
self.assertTrue( self.assertTrue(dist_context.validate_dist_attr_for_program())
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__": if __name__ == "__main__":
......
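As a quick reference for the dims_mapping values used throughout these tests: entry i names the process-mesh axis that tensor dimension i is split across, and -1 means that dimension is replicated. A small illustrative helper (an assumption of this edit, not an API in the diff) makes the resulting local shard shape explicit.

def local_shard_shape(global_shape, dims_mapping, topology):
    # dims_mapping[i] == -1: dimension i is replicated on every process.
    # dims_mapping[i] == k : dimension i is split evenly across mesh axis k.
    shape = []
    for size, axis in zip(global_shape, dims_mapping):
        shape.append(size if axis == -1 else size // topology[axis])
    return shape

# A [1024, 4096] weight with dims_mapping [-1, 1] on the "dp_mp" mesh
# [[0, 1, 2, 3], [4, 5, 6, 7]] (topology [2, 4]) keeps dim 0 whole and
# splits dim 1 four ways, giving a local shard of [1024, 1024].
print(local_shard_shape([1024, 4096], [-1, 1], [2, 4]))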
...@@ -23,21 +23,19 @@ import paddle.static as static ...@@ -23,21 +23,19 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.completion import complete_backward_annotation from paddle.distributed.auto_parallel.completion import complete_backward_annotation
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp" _global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
_global_process_mesh = auto.ProcessMesh( PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
NUM_RANKS = 8 NUM_RANKS = 8
STAGE_0_CNT = 5 STAGE_0_CNT = 5
STAGE_1_CNT = 10 STAGE_1_CNT = 10
...@@ -70,9 +68,13 @@ class MLPLayer(nn.Layer): ...@@ -70,9 +68,13 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if self.is_distributed: if self.is_distributed:
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -120,8 +122,14 @@ def mlp_forward(train_program, start_program, is_distributed=True): ...@@ -120,8 +122,14 @@ def mlp_forward(train_program, start_program, is_distributed=True):
name="label", shape=[batch_size, 1], dtype='float32') name="label", shape=[batch_size, 1], dtype='float32')
if is_distributed: if is_distributed:
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) auto.shard_tensor(
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) input,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -137,8 +145,6 @@ def mlp_forward(train_program, start_program, is_distributed=True): ...@@ -137,8 +145,6 @@ def mlp_forward(train_program, start_program, is_distributed=True):
def get_dist_prog(train_program, startup_program, dist_context, rank_id): def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
......
...@@ -29,19 +29,17 @@ from paddle.fluid import layers ...@@ -29,19 +29,17 @@ from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
def get_programs(annotated_func): def get_programs(annotated_func):
...@@ -49,7 +47,7 @@ def get_programs(annotated_func): ...@@ -49,7 +47,7 @@ def get_programs(annotated_func):
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
global _global_process_mesh global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh) dist_context.process_mesh = _global_process_mesh
train_program, start_program = annotated_func(train_program, start_program) train_program, start_program = annotated_func(train_program, start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
dist_context) dist_context)
...@@ -95,9 +93,8 @@ def initialization_check(mode, dist_context, dist_startup_prog, ...@@ -95,9 +93,8 @@ def initialization_check(mode, dist_context, dist_startup_prog,
serial_startup_prog, var_need_broadcast, process_mesh, serial_startup_prog, var_need_broadcast, process_mesh,
mp_parallel_axis, dp_parallel_axis): mp_parallel_axis, dp_parallel_axis):
if 'mp' in mode: if 'mp' in mode:
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(
process_mesh.topology, mp_parallel_axis, process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
3)
mp_ring_id = new_process_group(group_ranks).id mp_ring_id = new_process_group(group_ranks).id
broadcast_ops = [ broadcast_ops = [
op for op in dist_startup_prog.global_block().ops op for op in dist_startup_prog.global_block().ops
...@@ -110,9 +107,8 @@ def initialization_check(mode, dist_context, dist_startup_prog, ...@@ -110,9 +107,8 @@ def initialization_check(mode, dist_context, dist_startup_prog,
return False return False
if 'dp' in mode: if 'dp' in mode:
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(
process_mesh.topology, dp_parallel_axis, process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
3)
dp_ring_id = new_process_group(group_ranks).id dp_ring_id = new_process_group(group_ranks).id
nparam = len(serial_startup_prog.all_parameters()) nparam = len(serial_startup_prog.all_parameters())
nbroadcast_dp = len([ nbroadcast_dp = len([
...@@ -137,22 +133,21 @@ def initialization_check(mode, dist_context, dist_startup_prog, ...@@ -137,22 +133,21 @@ def initialization_check(mode, dist_context, dist_startup_prog,
def get_input_var_dist_attr(op, main_program, dist_context): def get_input_var_dist_attr(op, main_program, dist_context):
varname = op.desc.input_arg_names() varname = op.desc.input_arg_names()
var = main_program.global_block().var(varname[0]) var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
return dist_attr return dist_attr
def get_output_var_dist_attr(op, main_program, dist_context): def get_output_var_dist_attr(op, main_program, dist_context):
varname = op.desc.output_arg_names() varname = op.desc.output_arg_names()
var = main_program.global_block().var(varname[0]) var = main_program.global_block().var(varname[0])
dist_attr = dist_context.get_tensor_distributed_attr_for_program(var) dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
return dist_attr return dist_attr
def check_equal_var_dist_attr(serial_dist_attr, dist_attr): def check_equal_var_dist_attr(serial_dist_attr, dist_attr):
equal = True equal = True
if serial_dist_attr.get_process_mesh() != dist_attr.get_process_mesh() or \ if serial_dist_attr.process_mesh != dist_attr.process_mesh or \
serial_dist_attr.is_parameter() != dist_attr.is_parameter() or \ serial_dist_attr.dims_mapping != dist_attr.dims_mapping:
serial_dist_attr.get_dims_mapping() != dist_attr.get_dims_mapping():
equal = False equal = False
return equal return equal
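The helper functions in this file read distributed attributes through the renamed DistributedContext accessors (get_tensor_dist_attr_for_program / get_op_dist_attr_for_program) and through plain properties rather than the old get_*() methods; the sketch below illustrates that query pattern (the function name is an assumption):

def summarize_dist_attrs(main_prog, dist_context):
    # Walk the program and print the annotations recorded in the context.
    for block in main_prog.blocks:
        for var in block.vars.values():
            tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(var)
            if tensor_dist_attr is not None:
                print(var.name, tensor_dist_attr.process_mesh,
                      tensor_dist_attr.dims_mapping)
        for op in block.ops:
            op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
            if op_dist_attr is not None:
                print(op.type, op_dist_attr.process_mesh, op_dist_attr.impl_idx)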
...@@ -161,36 +156,33 @@ def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops, ...@@ -161,36 +156,33 @@ def check_equal_dist_op_attr(dist_context, dist_main_prog, serial_op, dist_ops,
dist_op_idx): dist_op_idx):
equal = True equal = True
# get serial op's process_mesh and impl_idx # get serial op's process_mesh and impl_idx
serial_op_dist_attr = dist_context.get_op_distributed_attr_for_program( serial_op_dist_attr = dist_context.get_op_dist_attr_for_program(serial_op)
serial_op) serial_process_mesh = serial_op_dist_attr.process_mesh
serial_process_mesh = serial_op_dist_attr.get_process_mesh() serial_impl_idx = serial_op_dist_attr.impl_idx
serial_impl_idx = serial_op_dist_attr.get_impl_idx()
# check dist_attr between serial op and dist op # check dist_attr between serial op and dist op
for i in dist_op_idx: for i in dist_op_idx:
op_dist_attr = dist_context.get_op_distributed_attr_for_program( op_dist_attr = dist_context.get_op_dist_attr_for_program(dist_ops[i])
dist_ops[i])
for in_varname in dist_ops[i].desc.input_arg_names(): for in_varname in dist_ops[i].desc.input_arg_names():
in_var = dist_main_prog.global_block().var(in_varname) in_var = dist_main_prog.global_block().var(in_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var) in_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dims_mapping = tensor_dist_attr.dims_mapping
in_var_dims_mapping = op_dist_attr.get_input_dims_mapping( in_var_dims_mapping = op_dist_attr.get_input_dims_mapping(
in_varname) in_varname)
if tensor_dims_mapping != in_var_dims_mapping: if tensor_dims_mapping != in_var_dims_mapping:
equal = False equal = False
for out_varname in dist_ops[i].desc.output_arg_names(): for out_varname in dist_ops[i].desc.output_arg_names():
out_var = dist_main_prog.global_block().var(out_varname) out_var = dist_main_prog.global_block().var(out_varname)
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var) out_var)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() tensor_dims_mapping = tensor_dist_attr.dims_mapping
out_var_dims_mapping = op_dist_attr.get_output_dims_mapping( out_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
out_varname) out_varname)
if tensor_dims_mapping != out_var_dims_mapping: if tensor_dims_mapping != out_var_dims_mapping:
equal = False equal = False
dist_op_process_mesh = op_dist_attr.process_mesh
dist_op_process_mesh = op_dist_attr.get_process_mesh() dist_op_impl_idx = op_dist_attr.impl_idx
dist_op_impl_idx = op_dist_attr.get_impl_idx()
if serial_op.desc.id() == dist_ops[i].desc.id() or \ if serial_op.desc.id() == dist_ops[i].desc.id() or \
serial_process_mesh != dist_op_process_mesh or \ serial_process_mesh != dist_op_process_mesh or \
serial_impl_idx != dist_op_impl_idx: serial_impl_idx != dist_op_impl_idx:
...@@ -242,13 +234,13 @@ def distributed_attr_check_for_program(dist_main_prog, dist_context): ...@@ -242,13 +234,13 @@ def distributed_attr_check_for_program(dist_main_prog, dist_context):
have_dist_attr = True have_dist_attr = True
for block in dist_main_prog.blocks: for block in dist_main_prog.blocks:
for tensor in block.vars.values(): for tensor in block.vars.values():
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
tensor) tensor)
if var_dist_attr is None: if var_dist_attr is None:
have_dist_attr = False have_dist_attr = False
for op in block.ops: for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) op_dist_attr = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attr is None: if op_dist_attr is None:
have_dist_attr = False have_dist_attr = False
...@@ -278,21 +270,43 @@ class MLPLayer(nn.Layer): ...@@ -278,21 +270,43 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
else: else:
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, self.linear0.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, self.linear1.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -316,10 +330,18 @@ def mlp_pretrain_forward(train_program, start_program): ...@@ -316,10 +330,18 @@ def mlp_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -335,8 +357,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -335,8 +357,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward) mlp_pretrain_forward)
...@@ -372,8 +393,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -372,8 +393,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward) mlp_pretrain_forward)
...@@ -437,7 +457,7 @@ class TestMLPAutoPartitioner(unittest.TestCase): ...@@ -437,7 +457,7 @@ class TestMLPAutoPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward) mlp_pretrain_forward)
...@@ -535,10 +555,18 @@ class AttentionLayer(nn.Layer): ...@@ -535,10 +555,18 @@ class AttentionLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1]) input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
q = self.q_proj(input) q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
...@@ -549,18 +577,42 @@ class AttentionLayer(nn.Layer): ...@@ -549,18 +577,42 @@ class AttentionLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -593,12 +645,18 @@ class AttentionLayer(nn.Layer): ...@@ -593,12 +645,18 @@ class AttentionLayer(nn.Layer):
out = self.out_proj(out) out = self.out_proj(out)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
return out return out
...@@ -630,8 +688,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -630,8 +688,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward) attn_pretrain_forward)
...@@ -666,8 +723,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -666,8 +723,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "mp" _global_parallel_strategy = "mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward) attn_pretrain_forward)
...@@ -735,7 +791,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase): ...@@ -735,7 +791,7 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward) attn_pretrain_forward)
...@@ -871,10 +927,18 @@ class DecoderLayer(nn.Layer): ...@@ -871,10 +927,18 @@ class DecoderLayer(nn.Layer):
def forward(self, input_ids, position_ids): def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
input_embeddings = self.word_embeddings(input_ids) input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
...@@ -882,13 +946,17 @@ class DecoderLayer(nn.Layer): ...@@ -882,13 +946,17 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
embeddings = input_embeddings + position_embeddings embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings) embeddings = self.dropout1(embeddings)
...@@ -906,18 +974,42 @@ class DecoderLayer(nn.Layer): ...@@ -906,18 +974,42 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -951,17 +1043,25 @@ class DecoderLayer(nn.Layer): ...@@ -951,17 +1043,25 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
else: else:
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, self.out_proj.weight,
_global_process_mesh, dist_attr={
dim_mapping=[-1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
# Add residual # Add residual
residual = embeddings + self.dropout2(out) residual = embeddings + self.dropout2(out)
...@@ -976,14 +1076,30 @@ class DecoderLayer(nn.Layer): ...@@ -976,14 +1076,30 @@ class DecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# Add residual # Add residual
final = residual + self.dropout3(out3) final = residual + self.dropout3(out3)
...@@ -1022,7 +1138,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase): ...@@ -1022,7 +1138,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp" _global_parallel_strategy = "dp_mp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward) decoder_pretrain_forward)
...@@ -1105,7 +1221,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase): ...@@ -1105,7 +1221,7 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
_global_parallel_strategy = "None" _global_parallel_strategy = "None"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward) decoder_pretrain_forward)
......
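The hunks above replace the positional `shard_tensor(tensor, mesh, dim_mapping=...)` calls with a single `dist_attr` dict and drop the `parent`/`ROOT_MESH` argument from `ProcessMesh`. A minimal sketch of the new annotation style on a toy static program (the tensor names here are illustrative, not part of the diff):

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

# New-style mesh: just the process ids, no parent argument.
mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[8, 16], dtype="float32")
    # Old style (removed): auto.shard_tensor(x, mesh, dim_mapping=[0, -1])
    # New style: every distributed attribute goes through one dist_attr dict.
    auto.shard_tensor(
        x,
        dist_attr={"process_mesh": mesh,
                   "dims_mapping": [0, -1]})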
...@@ -32,14 +32,13 @@ from paddle.distributed import fleet ...@@ -32,14 +32,13 @@ from paddle.distributed import fleet
import paddle.static as static import paddle.static as static
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import _get_comm_group from paddle.distributed.auto_parallel.utils import _get_comm_group
from paddle.distributed.auto_parallel.process import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static() paddle.enable_static()
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
...@@ -61,24 +60,27 @@ def is_valid_completed_program(dist_context, program): ...@@ -61,24 +60,27 @@ def is_valid_completed_program(dist_context, program):
ops = program.global_block().ops ops = program.global_block().ops
vars_ = program.list_vars() vars_ = program.list_vars()
for op in ops: for op in ops:
op_dist_attrs = dist_context.get_op_distributed_attr_for_program(op) op_dist_attrs = dist_context.get_op_dist_attr_for_program(op)
if op_dist_attrs == None: if op_dist_attrs == None:
return False return False
if op_dist_attrs.get_process_mesh == None: if op_dist_attrs.process_mesh == None:
return False return False
if None in op_dist_attrs._dims_mapping.values(): for tensor_dist_attr in op_dist_attrs.inputs_dist_attrs.values():
if None == tensor_dist_attr.dims_mapping:
return False
for tensor_dist_attr in op_dist_attrs.outputs_dist_attrs.values():
if None == tensor_dist_attr.dims_mapping:
return False return False
for var in vars_: for var in vars_:
var_dist_attrs = dist_context.get_tensor_distributed_attr_for_program( var_dist_attrs = dist_context.get_tensor_dist_attr_for_program(var)
var)
if var_dist_attrs == None: if var_dist_attrs == None:
return False return False
elif var_dist_attrs.get_process_mesh == None: elif var_dist_attrs.process_mesh == None:
return False return False
elif var_dist_attrs.get_dims_mapping == None: elif var_dist_attrs.dims_mapping == None:
return False return False
return True return True
...@@ -151,10 +153,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -151,10 +153,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
...@@ -188,19 +198,35 @@ class MultiHeadAttention(nn.Layer): ...@@ -188,19 +198,35 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value) v = self.v_proj(value)
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
...@@ -281,12 +307,18 @@ class MultiHeadAttention(nn.Layer): ...@@ -281,12 +307,18 @@ class MultiHeadAttention(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[0, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.out_proj.weight, _global_process_mesh, self.out_proj.weight,
dim_mapping=[1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
outs = [out] outs = [out]
if self.need_weights: if self.need_weights:
...@@ -454,17 +486,33 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -454,17 +486,33 @@ class TransformerDecoderLayer(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
# tgt = self.dropout2( # tgt = self.dropout2(
# self.linear2(F.gelu( # self.linear2(F.gelu(
...@@ -528,13 +576,17 @@ class GPTEmbeddings(nn.Layer): ...@@ -528,13 +576,17 @@ class GPTEmbeddings(nn.Layer):
if _global_parallel_strategy == "mp": if _global_parallel_strategy == "mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[0, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, self.word_embeddings.weight,
_global_process_mesh, dist_attr={
dim_mapping=[1, -1]) "process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings embeddings = input_embedings + position_embeddings
...@@ -760,10 +812,18 @@ def gpt_pretrain_forward(train_program, start_program): ...@@ -760,10 +812,18 @@ def gpt_pretrain_forward(train_program, start_program):
if _global_parallel_strategy == "dp": if _global_parallel_strategy == "dp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp": elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor( auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1]) input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
gpt = GPTModel( gpt = GPTModel(
vocab_size=32768, vocab_size=32768,
...@@ -798,12 +858,12 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -798,12 +858,12 @@ class TestGPTPartitioner(unittest.TestCase):
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh( _global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
train_program = static.Program() train_program = static.Program()
start_program = static.Program() start_program = static.Program()
dist_context = DistributedContext() dist_context = DistributedContext()
dist_context.set_process_mesh(_global_process_mesh) dist_context.process_mesh = _global_process_mesh
train_program, start_program, loss = gpt_pretrain_forward(train_program, train_program, start_program, loss = gpt_pretrain_forward(train_program,
start_program) start_program)
complete_train_program = auto.complete_annotation(train_program, complete_train_program = auto.complete_annotation(train_program,
...@@ -833,7 +893,7 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -833,7 +893,7 @@ class TestGPTPartitioner(unittest.TestCase):
opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads, opt_ops = partitioner.apply_optimize(optimizer, dist_params_grads,
auto_parallel_main_prog, auto_parallel_main_prog,
auto_parallel_startup_prog) auto_parallel_startup_prog)
from paddle.distributed.auto_parallel.context import set_default_distributed_context from paddle.distributed.auto_parallel.dist_context import set_default_distributed_context
set_default_distributed_context(dist_context) set_default_distributed_context(dist_context)
with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw: with open("./test_auto_parallel_partitioner_main_new.txt1", "w") as fw:
fw.write(str(auto_parallel_main_prog)) fw.write(str(auto_parallel_main_prog))
...@@ -877,14 +937,12 @@ class TestGPTPartitioner(unittest.TestCase): ...@@ -877,14 +937,12 @@ class TestGPTPartitioner(unittest.TestCase):
mp_parallel_axis = 1 mp_parallel_axis = 1
dp_parallel_axis = 0 dp_parallel_axis = 0
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(
process_mesh.topology, mp_parallel_axis, process_mesh.processes, process_mesh.topology, mp_parallel_axis, 3)
3)
mp_ring_id = new_process_group(group_ranks).id mp_ring_id = new_process_group(group_ranks).id
group_ranks = _get_comm_group(process_mesh.process_group, group_ranks = _get_comm_group(
process_mesh.topology, dp_parallel_axis, process_mesh.processes, process_mesh.topology, dp_parallel_axis, 3)
3)
dp_ring_id = new_process_group(group_ranks).id dp_ring_id = new_process_group(group_ranks).id
tensor_parallel_allreduce_vars = sorted([ tensor_parallel_allreduce_vars = sorted([
......
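The GPT partitioner test now passes `process_mesh.processes` (a flat rank list) together with `process_mesh.topology` to `_get_comm_group` instead of `process_mesh.process_group`. The helper itself is internal, but the group it has to produce is easy to state: the ranks that share every mesh coordinate with the given rank except along the chosen axis. A standalone illustration (not Paddle's implementation) for the 2 x 4 mesh used above:

import numpy as np

def comm_group(processes, topology, axis, rank):
    # Illustrative only: ranks sharing every mesh coordinate with `rank`
    # except along `axis` -- the set a comm-group helper has to return.
    mesh = np.array(processes).reshape(topology)
    coord = list(zip(*np.where(mesh == rank)))[0]
    index = [slice(None) if d == axis else c for d, c in enumerate(coord)]
    return mesh[tuple(index)].tolist()

procs, topo = [0, 1, 2, 3, 4, 5, 6, 7], [2, 4]   # the 2 x 4 mesh used above
assert comm_group(procs, topo, axis=1, rank=3) == [0, 1, 2, 3]   # mp group of rank 3
assert comm_group(procs, topo, axis=0, rank=3) == [3, 7]         # dp group of rank 3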
...@@ -22,16 +22,16 @@ import paddle.static as static ...@@ -22,16 +22,16 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import PROCESS_GROUP_MAP from paddle.distributed.auto_parallel.process_group import _g_process_group_map
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0, 1])
PP_MESH_0 = None PP_MESH_0 = None
PP_MESH_1 = None PP_MESH_1 = None
...@@ -57,16 +57,30 @@ class MLPLayer(nn.Layer): ...@@ -57,16 +57,30 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
else: else:
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, self.linear0.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, self.linear1.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -88,12 +102,32 @@ def mlp_forward(train_program, start_program): ...@@ -88,12 +102,32 @@ def mlp_forward(train_program, start_program):
name="label", shape=[batch_size, 1], dtype='float32') name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) auto.shard_tensor(
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "dp": elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
else: else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -108,8 +142,6 @@ def mlp_forward(train_program, start_program): ...@@ -108,8 +142,6 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id): def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh)
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
...@@ -136,22 +168,21 @@ def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): ...@@ -136,22 +168,21 @@ def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check):
has_dist_attr = True has_dist_attr = True
vars = dist_main_prog.global_block().vars vars = dist_main_prog.global_block().vars
op_dist_attr = dist_context.get_op_distributed_attr_for_program( op_dist_attr = dist_context.get_op_dist_attr_for_program(op_need_check)
op_need_check) if not op_dist_attr or not op_dist_attr.process_mesh:
if not op_dist_attr or not op_dist_attr.get_process_mesh():
has_dist_attr = False has_dist_attr = False
for var_name in op_need_check.input_arg_names: for var_name in op_need_check.input_arg_names:
if not op_dist_attr.get_input_dims_mapping(var_name) or \ if not op_dist_attr.get_input_dims_mapping(var_name) or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh:
has_dist_attr = False has_dist_attr = False
break break
if has_dist_attr: if has_dist_attr:
for var_name in op_need_check.output_arg_names: for var_name in op_need_check.output_arg_names:
if not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_dims_mapping() or \ if not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \
not dist_context.get_tensor_distributed_attr_for_program(vars[var_name]).get_process_mesh(): not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh:
has_dist_attr = False has_dist_attr = False
break break
...@@ -162,6 +193,7 @@ def check_send_recv_result(dist_main_prog, rank_id): ...@@ -162,6 +193,7 @@ def check_send_recv_result(dist_main_prog, rank_id):
send_result = False send_result = False
recv_result = False recv_result = False
ops = dist_main_prog.global_block().ops ops = dist_main_prog.global_block().ops
if rank_id == 0: if rank_id == 0:
for idx, op in enumerate(ops): for idx, op in enumerate(ops):
if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names:
...@@ -217,7 +249,7 @@ def check_initialization_for_dp(dist_startup_prog): ...@@ -217,7 +249,7 @@ def check_initialization_for_dp(dist_startup_prog):
class TestMLPReshard(unittest.TestCase): class TestMLPReshard(unittest.TestCase):
def test_complete_backward_annotation(self): def test_complete_backward_annotation(self):
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
...@@ -231,6 +263,7 @@ class TestMLPReshard(unittest.TestCase): ...@@ -231,6 +263,7 @@ class TestMLPReshard(unittest.TestCase):
if op.type == "gelu_grad": if op.type == "gelu_grad":
op_need_check = op op_need_check = op
break break
# print_program_with_dist_attr(dist_main_prog, dist_context)
# grad op should have dist attr # grad op should have dist attr
self.assertTrue( self.assertTrue(
...@@ -241,11 +274,11 @@ class TestMLPReshard(unittest.TestCase): ...@@ -241,11 +274,11 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "pp" _global_parallel_strategy = "pp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
global PP_MESH_0 global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) PP_MESH_0 = auto.ProcessMesh(mesh=[0])
global PP_MESH_1 global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1], parent=ROOT_MESH) PP_MESH_1 = auto.ProcessMesh(mesh=[1])
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
...@@ -253,9 +286,10 @@ class TestMLPReshard(unittest.TestCase): ...@@ -253,9 +286,10 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 1 rank_id = 1
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
for key in list(PROCESS_GROUP_MAP.keys()): for key in list(_g_process_group_map.keys()):
del PROCESS_GROUP_MAP[key] del _g_process_group_map[key]
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
# print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
...@@ -267,7 +301,7 @@ class TestMLPReshard(unittest.TestCase): ...@@ -267,7 +301,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = "dp" _global_parallel_strategy = "dp"
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], parent=ROOT_MESH) _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
......
...@@ -22,18 +22,17 @@ import paddle.static as static ...@@ -22,18 +22,17 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp" _global_parallel_strategy = "dp_mp_pp"
ROOT_MESH = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) _global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
_global_process_mesh = auto.ProcessMesh( PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
[[[0, 1], [4, 5]], [[2, 3], [6, 7]]], parent=ROOT_MESH) PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], parent=ROOT_MESH)
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], parent=ROOT_MESH)
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -55,8 +54,14 @@ class MLPLayer(nn.Layer): ...@@ -55,8 +54,14 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input): def forward(self, input):
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 1]) auto.shard_tensor(
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[1, -1]) self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -77,8 +82,14 @@ def mlp_forward(train_program, start_program): ...@@ -77,8 +82,14 @@ def mlp_forward(train_program, start_program):
label = static.data( label = static.data(
name="label", shape=[batch_size, 1], dtype='float32') name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[0, -1]) auto.shard_tensor(
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[0, -1]) input,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -94,7 +105,7 @@ def mlp_forward(train_program, start_program): ...@@ -94,7 +105,7 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id): def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh) dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
...@@ -156,10 +167,8 @@ class TestMLPReshard(unittest.TestCase): ...@@ -156,10 +167,8 @@ class TestMLPReshard(unittest.TestCase):
rank_id = 2 rank_id = 2
dist_main_prog, dist_startup_prog = get_dist_prog( dist_main_prog, dist_startup_prog = get_dist_prog(
train_program, startup_program, dist_context, rank_id) train_program, startup_program, dist_context, rank_id)
print(dist_main_prog)
reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context) reshard(dist_main_prog, dist_startup_prog, rank_id, dist_context)
print(dist_main_prog) # print_program_with_dist_attr(dist_main_prog, dist_context)
print(dist_startup_prog)
# check send and recv result # check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......
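The dp_mp_pp test shards tensors over a mesh with topology [2, 2, 2] using `dims_mapping` values such as `[0, -1]` and `[-1, 1]`. Assuming the usual convention that `dims_mapping[i]` names the mesh axis that splits tensor dimension `i`, with `-1` meaning that dimension is replicated, the per-rank shard shape follows directly; a small sketch:

def local_shape(global_shape, topology, dims_mapping):
    # dims_mapping[i] is the mesh axis splitting tensor dim i; -1 = replicated.
    return [
        dim if axis == -1 else dim // topology[axis]
        for dim, axis in zip(global_shape, dims_mapping)
    ]

# Mesh [[[0, 1], [4, 5]], [[2, 3], [6, 7]]] has topology [2, 2, 2].
assert local_shape([8, 16], [2, 2, 2], [0, -1]) == [4, 16]  # dim 0 split along mesh axis 0
assert local_shape([8, 16], [2, 2, 2], [-1, 1]) == [8, 8]   # dim 1 split along mesh axis 1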
...@@ -22,17 +22,17 @@ import paddle.static as static ...@@ -22,17 +22,17 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import DistributedContext from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = "mp_pp" _global_parallel_strategy = "mp_pp"
ROOT_MESH = auto.ProcessMesh([[0, 1], [2, 3]]) _global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]])
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], parent=ROOT_MESH) PP_MESH_0 = auto.ProcessMesh([0, 1])
PP_MESH_0 = auto.ProcessMesh([0, 1], parent=ROOT_MESH) PP_MESH_1 = auto.ProcessMesh([2, 3])
PP_MESH_1 = auto.ProcessMesh([2, 3], parent=ROOT_MESH)
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -64,10 +64,21 @@ class MLPLayer(nn.Layer): ...@@ -64,10 +64,21 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
auto.shard_tensor( auto.shard_tensor(
self.word_embeddings.weight, PP_MESH_0, dim_mapping=[0, -1]) self.word_embeddings.weight,
auto.shard_tensor(self.linear0.weight, PP_MESH_0, dim_mapping=[-1, 0]) dist_attr={"process_mesh": PP_MESH_0,
auto.shard_tensor(self.linear1.weight, PP_MESH_1, dim_mapping=[0, -1]) "dims_mapping": [0, -1]})
auto.shard_tensor(self.linear2.weight, PP_MESH_1, dim_mapping=[0, -1]) auto.shard_tensor(
self.linear0.weight,
dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 0]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
auto.shard_tensor(
self.linear2.weight,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]})
w_out = self.word_embeddings(input) w_out = self.word_embeddings(input)
out = self.linear0(w_out) out = self.linear0(w_out)
gelu_out = F.gelu(out, approximate=True) gelu_out = F.gelu(out, approximate=True)
...@@ -88,8 +99,13 @@ def mlp_forward(train_program, start_program): ...@@ -88,8 +99,13 @@ def mlp_forward(train_program, start_program):
label = static.data( label = static.data(
name="label", shape=[batch_size, 1], dtype='float32') name="label", shape=[batch_size, 1], dtype='float32')
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1]) auto.shard_tensor(
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) input, dist_attr={"process_mesh": PP_MESH_0,
"dims_mapping": [-1]})
auto.shard_tensor(
label,
dist_attr={"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -105,7 +121,7 @@ def mlp_forward(train_program, start_program): ...@@ -105,7 +121,7 @@ def mlp_forward(train_program, start_program):
def get_dist_prog(train_program, startup_program, dist_context, rank_id): def get_dist_prog(train_program, startup_program, dist_context, rank_id):
global _global_process_mesh global _global_process_mesh
dist_context.set_process_mesh(_global_process_mesh) dist_context.process_mesh = _global_process_mesh
loss, train_program, startup_program = mlp_forward(train_program, loss, train_program, startup_program = mlp_forward(train_program,
startup_program) startup_program)
...@@ -198,19 +214,41 @@ class TestMLPReshard(unittest.TestCase): ...@@ -198,19 +214,41 @@ class TestMLPReshard(unittest.TestCase):
def test_allgather(self): def test_allgather(self):
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 3], parent=ROOT_MESH) process_mesh = auto.ProcessMesh(mesh=[0, 3])
with static.program_guard(train_program, startup_program): with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x, process_mesh, dim_mapping=[0, -1]) x = auto.shard_tensor(
x,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [0, -1]
})
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w, process_mesh, dim_mapping=[-1, -1]) w = auto.shard_tensor(
w,
y = paddle.distributed.shard_op(paddle.matmul, process_mesh, { dist_attr={
x.name: [-1, -1], "process_mesh": process_mesh,
w.name: [-1, -1] "dims_mapping": [-1, -1]
}, **{"x": x, })
"y": w})[0]
# y = paddle.distributed.shard_op(paddle.matmul, process_mesh, {
# x.name: [-1, -1],
# w.name: [-1, -1]
# }, **{"x": x,
# "y": w})[0]
y = paddle.distributed.shard_op(
paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
x: {
"dims_mapping": [-1, -1]
},
w: {
"dims_mapping": [-1, -1]
}
})(x, w)[0]
rank_id = 0 rank_id = 0
dist_context = DistributedContext() dist_context = DistributedContext()
......
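The mp_pp reshard test also migrates `paddle.distributed.shard_op` to the same style: instead of a mesh plus a name-keyed dims_mapping dict and keyword inputs, the call now takes one `dist_attr` dict (keyed by the tensors themselves) and is then applied to the inputs. A condensed sketch of just that call, with the surrounding program setup assumed:

import paddle
import paddle.static as static
import paddle.distributed.auto_parallel as auto

paddle.enable_static()
mesh = auto.ProcessMesh(mesh=[0, 3])

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[4, 4], dtype="float32")
    w = static.data(name="w", shape=[4, 4], dtype="float32")
    # Old: shard_op(paddle.matmul, mesh, {x.name: [-1, -1], w.name: [-1, -1]}, x=x, y=w)[0]
    # New: dist_attr carries the mesh and per-tensor dims_mapping, then call with inputs.
    y = paddle.distributed.shard_op(
        paddle.matmul,
        dist_attr={
            "process_mesh": mesh,
            x: {"dims_mapping": [-1, -1]},
            w: {"dims_mapping": [-1, -1]},
        })(x, w)[0]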
...@@ -26,16 +26,15 @@ import paddle.static as static ...@@ -26,16 +26,15 @@ import paddle.static as static
import paddle.nn.functional as F import paddle.nn.functional as F
import paddle.utils as utils import paddle.utils as utils
import paddle.distributed.auto_parallel as auto import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.context import get_default_distributed_context from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.reshard import reshard from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.process import new_process_group from paddle.distributed.auto_parallel.process_group import new_process_group
paddle.enable_static() paddle.enable_static()
_global_parallel_strategy = None _global_parallel_strategy = None
_global_process_mesh = None _global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([0])
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
...@@ -59,16 +58,30 @@ class MLPLayer(nn.Layer): ...@@ -59,16 +58,30 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, PP_MESH_0, dim_mapping=[-1, -1]) self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, PP_MESH_1, dim_mapping=[-1, -1]) self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
else: else:
auto.shard_tensor( auto.shard_tensor(
self.linear0.weight, _global_process_mesh, self.linear0.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor( auto.shard_tensor(
self.linear1.weight, _global_process_mesh, self.linear1.weight,
dim_mapping=[-1, -1]) dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
out = self.norm(input) out = self.norm(input)
out = self.linear0(out) out = self.linear0(out)
...@@ -90,12 +103,32 @@ def mlp_forward(train_program, start_program): ...@@ -90,12 +103,32 @@ def mlp_forward(train_program, start_program):
name="label", shape=[batch_size, 1], dtype='float32') name="label", shape=[batch_size, 1], dtype='float32')
if _global_parallel_strategy == "pp": if _global_parallel_strategy == "pp":
auto.shard_tensor(input, PP_MESH_0, dim_mapping=[-1, -1]) auto.shard_tensor(
auto.shard_tensor(label, PP_MESH_1, dim_mapping=[-1, -1]) input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
elif _global_parallel_strategy == "dp": elif _global_parallel_strategy == "dp":
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[0, -1]) auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
else: else:
auto.shard_tensor(input, _global_process_mesh, dim_mapping=[-1, -1]) auto.shard_tensor(
input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
mlp = MLPLayer( mlp = MLPLayer(
hidden_size=hidden_size, hidden_size=hidden_size,
...@@ -168,7 +201,7 @@ class TestMLPReshard(unittest.TestCase): ...@@ -168,7 +201,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy global _global_parallel_strategy
_global_parallel_strategy = None _global_parallel_strategy = None
global _global_process_mesh global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0], parent=ROOT_MESH) _global_process_mesh = auto.ProcessMesh(mesh=[0])
train_program = paddle.static.Program() train_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
......