diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 519bf8c633a013fedab4f529dad014a71ad2d594..1b4d8adeb574fd11040c519e30037b322377392c 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -353,6 +353,14 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) { outputs_ = op_desc.outputs_; attrs_ = op_desc.attrs_; need_update_ = true; + // When creating graph from program, the creation of op node will create a new + // OpDesc instead of + // referring to the original one. To find the original OpDesc of the op node, + // the id have to be + // copied to the new OpDesc. The var node has the same situation, but the + // default copy constructor + // can copy the id automatically. + id_ = op_desc.id_; } OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block) diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 31339b4d620b5164e1b2eeac78cdbf6c935f77d1..6b5969f412218463bdb566ae175ebfb9b7fdc35f 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -151,6 +152,18 @@ class OpDesc { const BlockDesc *Block() const { return this->block_; } + // This thread-safe implementation seems to be redudent since the neural + // networks + // are usually constructed in a single thread + static uint64_t GenerateId() { + static std::atomic id{0}; + return ++id; + } + + // Note: the identity only used as a key for referring to its + // distributed attribute now. + uint64_t Id() { return id_; } + private: template static std::vector MapKeys(const MapType &map) { @@ -173,6 +186,8 @@ class OpDesc { // need_update_ indicate there some local changes not be synchronized. If // local changes should be synchronized, need_update_ should be set to true. bool need_update_{false}; + + uint64_t id_ = GenerateId(); }; } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index 6821165692d2a4606824c5fc61e0013a065ac532..d1a1757d5309b6f28991f4116fc5a9bb6264077f 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include #include @@ -150,6 +151,17 @@ class VarDesc { Attribute GetAttr(const std::string &name) const; + // This thread-safe implementation seems to be redudent since the neural + // networks are usually constructed in a single thread. + static uint64_t GenerateId() { + static std::atomic uid{0}; + return ++uid; + } + + // Note: the identity only used as a key for referring to its + // distributed attribute now. + uint64_t Id() { return id_; } + private: const proto::VarType::TensorDesc &tensor_desc() const; std::vector tensor_descs() const; @@ -158,6 +170,7 @@ class VarDesc { proto::VarDesc desc_; AttributeMap attrs_; + uint64_t id_ = GenerateId(); }; bool operator==(const VarDesc &left, const VarDesc &right); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 7caa2494dc014487b2be96624b817ec95f559b20..596bd004e1387a6849f6963f1198ea47da8805a5 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -24,7 +24,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/version.h" - #include "paddle/fluid/pybind/pybind_boost_headers.h" namespace paddle { @@ -202,6 +201,7 @@ void BindVarDsec(pybind11::module *m) { .def("attr_names", &pd::VarDesc::AttrNames) .def("_set_attr", &pd::VarDesc::SetAttr) .def("remove_attr", &pd::VarDesc::RemoveAttr) + .def("id", &pd::VarDesc::Id) .def("attr", &pd::VarDesc::GetAttr); pybind11::enum_ vartype(var_desc, "VarType", ""); @@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) { .def("serialize_to_string", SerializeMessage) .def("block", [](pd::OpDesc &self) { return self.Block(); }, pybind11::return_value_policy::reference) + .def("id", &pd::OpDesc::Id) .def("inputs", &pd::OpDesc::Inputs) .def("outputs", &pd::OpDesc::Outputs); } diff --git a/python/paddle/distributed/__init__.py b/python/paddle/distributed/__init__.py index 02a65e846df263b114a7ae813d3a5b678eedd204..e3d1d3c597f9b07b87173f794b1a2479fb0e7cd5 100644 --- a/python/paddle/distributed/__init__.py +++ b/python/paddle/distributed/__init__.py @@ -57,7 +57,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401 from . import cloud_utils # noqa: F401 from . import utils # noqa: F401 -__all__ = [ #noqa + +__all__ = [ # noqa "spawn", "scatter", "broadcast", diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index afe8d5652cfa733d0d527e38f207e353e91a85c8..5b0fdc1f1f166540358c82c231730a90a08fab48 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -18,5 +18,6 @@ from .interface import set_shard_mask # noqa: F401 from .interface import set_offload_device # noqa: F401 from .interface import set_pipeline_stage # noqa: F401 from .interface import ProcessMesh # noqa: F401 +from .completion import complete_annotation # noqa: F401 __all__ = [] diff --git a/python/paddle/distributed/auto_parallel/attribute.py b/python/paddle/distributed/auto_parallel/attribute.py new file mode 100644 index 0000000000000000000000000000000000000000..0ca1b7e9444d0f36bfad88a03183acbe359df84d --- /dev/null +++ b/python/paddle/distributed/auto_parallel/attribute.py @@ -0,0 +1,304 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict + + +class TensorDistributedAttribute: + def __init__(self, owner_tensor, owner_context): + self._owner_tensor = owner_tensor + self._owner_context = owner_context + self._process_mesh = None + self._dims_mapping = None + self._shard_mask = None + self._offload_device = None + self._shape = None + self._is_annotated = {} + self._is_parameter = False + + def get_owner_tensor(self): + return self._owner_tensor + + def get_owner_context(self): + return self._owner_context + + def get_process_mesh(self): + return self._process_mesh + + def set_process_mesh(self, process_mesh): + self._process_mesh = copy.deepcopy(process_mesh) + + def get_dims_mapping(self): + return self._dims_mapping + + def set_dims_mapping(self, dims_mapping): + self._dims_mapping = copy.deepcopy(dims_mapping) + + def get_shard_mask(self): + return self._shard_mask + + def set_shard_mask(self, shard_mask): + self._shard_mask = copy.deepcopy(shard_mask) + + def get_offload_device(self): + return self._offload_device + + def set_offload_device(self, offload_device): + self._offload_device = copy.deepcopy(offload_device) + + def get_shape(self): + return self._shape + + def set_shape(self, shape): + self._shape = copy.deepcopy(shape) + + def is_annotated(self, dist_attr_name): + return self._is_annotated.get(dist_attr_name, False) + + def mark_as_annotated(self, dist_attr_name): + self._is_annotated[dist_attr_name] = True + + def is_parameter(self): + return self._is_parameter + + def mark_as_parameter(self): + self._is_parameter = True + + def is_valid(self): + tensor_shape = self.get_owner_tensor().desc.shape() + if len(tensor_shape) != len(self.get_dims_mapping()): + return False + for i in range(len(self.get_dims_mapping())): + if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[ + i] >= len(self.get_process_mesh().topology): + return False + for i in range(len(self.get_process_mesh().topology)): + if self.get_dims_mapping().count(i) > 1: + return False + return True + + def __str__(self): + str = "{{tensor name: {}, tensor id: {}".format( + self.get_owner_tensor().desc.name(), + self.get_owner_tensor().desc.id()) + if self.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", process_mesh ({}): {}".format(annotated_str, + self.get_process_mesh()) + + str += ", is_parameter: {}".format(self._is_parameter) + + if self.is_annotated("dims_mapping"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", dims_mapping ({}): {}".format(annotated_str, + self.get_dims_mapping()) + + if self.is_annotated("shard_mask"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", shard_mask ({}): {}".format(annotated_str, + self.get_shard_mask()) + + if self.is_annotated("offload_device"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", offload_device ({}): {} }}".format(annotated_str, + self.get_offload_device()) + return str + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + # No need to copy the owner tensor and context + if k == "_owner_tensor" or k == "_owner_context": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + + +class OperatorDistributedAttribute: + def __init__(self, 
owner_op, owner_context): + self._owner_op = owner_op + self._owner_context = owner_context + self._process_mesh = None + self._dims_mapping = {} + self._shapes = {} + self._is_annotated = {} + self._is_parameters = {} + self._pipeline_stage = None + self._impl_idx = None + + def get_owner_op(self): + return self._owner_op + + def get_owner_context(self): + return self._owner_context + + def get_process_mesh(self): + return self._process_mesh + + def set_process_mesh(self, process_mesh): + self._process_mesh = copy.deepcopy(process_mesh) + + def get_input_dims_mapping(self, name): + return self._dims_mapping.get("IN_" + name, None) + + def set_input_dims_mapping(self, name, dims_mapping): + self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping) + + def get_output_dims_mapping(self, name): + return self._dims_mapping.get("OUT_" + name, None) + + def set_output_dims_mapping(self, name, dims_mapping): + self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping) + + def get_impl_idx(self): + return self._impl_idx + + def set_impl_idx(self, impl_idx): + self._impl_idx = impl_idx + + def get_pipeline_stage(self): + return self._pipeline_stage + + def set_pipeline_stage(self, pipeline_stage): + self._pipeline_stage = copy.deepcopy(pipeline_stage) + + def get_input_shape(self, name): + return self._shapes.get("IN_" + name, None) + + def set_input_shape(self, name, shape): + self._shapes["IN_" + name] = copy.deepcopy(shape) + + def get_output_shape(self, name): + return self._shapes.get("OUT_" + name, None) + + def set_output_shape(self, name, shape): + self._shapes["OUT_" + name] = copy.deepcopy(shape) + + def is_annotated(self, attr_name): + return self._is_annotated.get(attr_name, False) + + def mark_as_annotated(self, attr_name): + self._is_annotated[attr_name] = True + + def is_annotated_input_dims_mapping(self, name): + return self._is_annotated.get("IN_" + name, False) + + def mark_as_annotated_input_dims_mapping(self, name): + self._is_annotated["IN_" + name] = True + + def is_annotated_output_dims_mapping(self, name): + return self._is_annotated.get("OUT_" + name, False) + + def mark_as_annotated_output_dims_mapping(self, name): + self._is_annotated["OUT_" + name] = True + + def is_parameter(self, name): + return self._is_parameters.get(name, False) + + def mark_as_parameter(self, name): + self._is_parameters[name] = True + + def is_valid(self): + for name in self.get_owner_op().desc.input_arg_names(): + dims_mapping = self.get_input_dims_mapping(name) + shape = self.get_input_shape(name) + if len(shape) != len(dims_mapping): + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len( + self.get_process_mesh().topology): + return False + for i in range(len(self.get_process_mesh().topology)): + if dims_mapping.count(i) > 1: + return False + for name in self.get_owner_op().desc.output_arg_names(): + dims_mapping = self.get_output_dims_mapping(name) + shape = self.get_output_shape(name) + if len(shape) != len(dims_mapping): + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len( + self.get_process_mesh().topology): + return False + for i in range(len(self.get_process_mesh().topology)): + if dims_mapping.count(i) > 1: + return False + return True + + def __str__(self): + str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(), + self.get_owner_op().desc.id()) + + if self.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + 
str += ", process_mesh ({}): {}".format(annotated_str, + self.get_process_mesh()) + + for arg_name in self.get_owner_op().desc.input_arg_names(): + dims_mapping = self.get_input_dims_mapping(arg_name) + if self.is_annotated_input_dims_mapping(arg_name): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + if self.is_parameter(arg_name): + is_parameter_str = "parameter" + else: + is_parameter_str = "non-parameter" + str += ", {}'s dims_mapping (input, {}, {}): {}".format( + arg_name, annotated_str, is_parameter_str, dims_mapping) + + for arg_name in self.get_owner_op().desc.output_arg_names(): + dims_mapping = self.get_output_dims_mapping(arg_name) + if self.is_annotated_output_dims_mapping(arg_name): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + if self.is_parameter(arg_name): + is_parameter_str = "parameter" + else: + is_parameter_str = "non-parameter" + str += ", {}'s dims_mapping (output, {}, {}): {}".format( + arg_name, annotated_str, is_parameter_str, dims_mapping) + + str += ", pipeline stage: {}".format(self._pipeline_stage) + + str += ", dist_impl idx: {} }}".format(self._impl_idx) + + return str + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + # No need to copy the owner op and context + if k == "_owner_op" or k == "_owner_context": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py new file mode 100644 index 0000000000000000000000000000000000000000..72af14af2c3958f7eb952163711133f42ec32238 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -0,0 +1,483 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy + +from paddle.fluid import core +from paddle.fluid import framework + +from .utils import compute_compatible_process_mesh +from .utils import compute_compatible_dim_mapping +from .utils import compute_compatible_dims_mapping +from .utils import print_program_with_distributed_attr +from .context import get_default_distributed_context +from .operators import find_best_compatible_distributed_operator_impl + +ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"] + + +def is_elementwise_like_op(op_type): + if op_type in ELEMENTWISE_LIKE_OP_LIST: + return True + else: + return False + + +def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True): + """ + Update tensor's process mesh by using its predecessor's process mesh if in the forward direction, + and by using its successor's process mesh if in the backward direction. Note: only the equal + process meshes are compatible for now. 
+ """ + changed = False + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + if tensor_dist_attr.is_annotated("process_mesh"): + return changed + tensor_process_mesh = tensor_dist_attr.get_process_mesh() + if fwd: + inputs_process_meshes = [] + for pred_op_node in tensor_node.inputs: + if pred_op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + pred_op_node) + op_process_mesh = op_dist_attr.get_process_mesh() + inputs_process_meshes.append(op_process_mesh) + compatible_process_mesh = compute_compatible_process_mesh( + inputs_process_meshes) + if compatible_process_mesh is not None and tensor_process_mesh is None: + tensor_dist_attr.set_process_mesh(compatible_process_mesh) + changed = True + else: + outputs_process_meshes = [] + for succ_op_node in tensor_node.outputs: + if succ_op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + succ_op_node) + op_process_mesh = op_dist_attr.get_process_mesh() + outputs_process_meshes.append(op_process_mesh) + compatible_process_mesh = compute_compatible_process_mesh( + outputs_process_meshes) + if compatible_process_mesh is not None and tensor_process_mesh is None: + tensor_dist_attr.set_process_mesh(compatible_process_mesh) + changed = True + return changed + + +def update_op_node_process_mesh(dist_context, op_node, fwd=True): + """ + Update op's process mesh by using its predecessor's process mesh if in the forward direction, + and by using its successor's process mesh if in the backward direction. Note: only the equal + process meshes are compatible for now. + """ + changed = False + op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) + if op_dist_attr.is_annotated("process_mesh"): + return changed + op_process_mesh = op_dist_attr.get_process_mesh() + if fwd: + inputs_process_meshes = [] + for tensor_node in op_node.inputs: + if tensor_node.var() is not None: + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + tensor_process_mesh = tensor_dist_attr.get_process_mesh() + inputs_process_meshes.append(tensor_process_mesh) + compatible_process_mesh = compute_compatible_process_mesh( + inputs_process_meshes) + if compatible_process_mesh is not None and op_process_mesh is None: + op_dist_attr.set_process_mesh(compatible_process_mesh) + changed = True + else: + outputs_process_meshes = [] + for tensor_node in op_node.outputs: + if tensor_node.var() is not None: + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + tensor_process_mesh = tensor_dist_attr.get_process_mesh() + outputs_process_meshes.append(tensor_process_mesh) + compatible_process_mesh = compute_compatible_process_mesh( + outputs_process_meshes) + if compatible_process_mesh is not None and op_process_mesh is None: + op_dist_attr.set_process_mesh(compatible_process_mesh) + changed = True + return changed + + +def update_op_dims_mapping_by_default_dist_impl(op_dist_attr): + """Each operator has a default distributed operator, only allowed to be sharded in batch dimension.""" + changed = False + op_desc = op_dist_attr.get_owner_op().desc + # The following statement will be replaced by a more elegent way + if op_desc.type() == "shape" or op_desc.type() == "slice": + return False + output_names = op_desc.output_names() + xshape_arg_names = [] + if "XShape" in output_names: + xshape_arg_names = op_desc.output("XShape") + batch_dim_mappings = [] + for arg_name in op_desc.input_arg_names(): + if 
op_dist_attr.is_parameter(arg_name): + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if len(dims_mapping) > 1: + for idx, mapping in enumerate(dims_mapping[1:]): + assert mapping == -1, \ + "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[0]) + for arg_name in op_desc.output_arg_names(): + if op_dist_attr.is_parameter(arg_name): + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if len(dims_mapping) > 1: + for idx, mapping in enumerate(dims_mapping[1:]): + assert mapping == -1, \ + "{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[0]) + else: + assert dims_mapping[0] == -1, \ + "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\ + .format(op_desc.type(), mapping) + if len(dims_mapping) > 2: + for idx, mapping in enumerate(dims_mapping[2:]): + assert mapping == -1, \ + "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\ + .format(op_desc.type(), idx, mapping) + batch_dim_mappings.append(dims_mapping[1]) + + compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings) + assert compatible_dim_mapping is not None, "There is no compatible dim mapping." + for arg_name in op_desc.input_arg_names(): + if op_dist_attr.is_parameter(arg_name): + continue + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + for arg_name in op_desc.output_arg_names(): + if op_dist_attr.is_parameter(arg_name): + continue + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if arg_name not in xshape_arg_names: + if compatible_dim_mapping != dims_mapping[0]: + dims_mapping[0] = compatible_dim_mapping + changed = True + else: + if compatible_dim_mapping != dims_mapping[1]: + dims_mapping[1] = compatible_dim_mapping + changed = True + + return changed + + +def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr): + """Element-wise operator can be sharded in any way (but should take care of broadcasting).""" + changed = False + op_desc = op_dist_attr.get_owner_op().desc + + input_arg_names = op_desc.input_arg_names() + input_dims_mapping_dict = {} + input_dims_mapping_lens = {} + max_dims_mapping_len = -1 + for arg_name in input_arg_names: + dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) + if max_dims_mapping_len < len(dims_mapping): + max_dims_mapping_len = len(dims_mapping) + input_dims_mapping_dict[arg_name] = dims_mapping + input_dims_mapping_lens[arg_name] = len(dims_mapping) + + dims_mapping_list = [] + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i] + dims_mapping_list.append(new_dims_mapping) + else: + dims_mapping_list.append(input_dims_mapping_dict[arg_name]) + output_arg_names = op_desc.output_arg_names() + for arg_name in output_arg_names: + dims_mapping = 
op_dist_attr.get_output_dims_mapping(arg_name) + assert len(dims_mapping) == max_dims_mapping_len + dims_mapping_list.append(dims_mapping) + + compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list) + assert compatible_dims_mapping is not None, "There is no compatible dim mapping." + + for arg_name in input_arg_names: + if input_dims_mapping_lens[arg_name] < max_dims_mapping_len: + new_dims_mapping = [ + -1 for _ in range(input_dims_mapping_lens[arg_name]) + ] + for i in range(input_dims_mapping_lens[arg_name]): + new_idx = (max_dims_mapping_len - + input_dims_mapping_lens[arg_name]) + i + new_dims_mapping[i] = compatible_dims_mapping[new_idx] + if new_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping) + changed = True + else: + if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: + op_dist_attr.set_input_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + for arg_name in output_arg_names: + dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) + if compatible_dims_mapping != dims_mapping: + op_dist_attr.set_output_dims_mapping(arg_name, + compatible_dims_mapping) + changed = True + + return changed + + +def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True): + changed = False + if (not tensor_node.is_var()) or (tensor_node.var() is None): + return False + tensor_desc = tensor_node.var() + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + assert tensor_dist_attr is not None + if tensor_dist_attr.is_annotated("dims_mapping"): + return False + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + if fwd: + dims_mapping_list = [] + for pred_op_node in tensor_node.inputs: + if pred_op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + pred_op_node) + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) + changed = True + else: + dims_mapping_list = [] + for succ_op_node in tensor_node.outputs: + if succ_op_node.op() is not None: + op_dist_attr = dist_context.get_op_distributed_attr_for_graph( + succ_op_node) + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + dims_mapping_list.append(op_dims_mapping) + dims_mapping_list.append(tensor_dims_mapping) + compatible_dims_mapping = compute_compatible_dims_mapping( + dims_mapping_list) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != tensor_dims_mapping): + tensor_dist_attr.set_dims_mapping(compatible_dims_mapping) + changed = True + return changed + + +def update_op_node_dims_mapping(dist_context, op_node, fwd=True): + changed = False + if (not op_node.is_op()) or (op_node.op() is None): + return False + op_desc = op_node.op() + op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node) + if fwd: + for tensor_node in op_node.inputs: + if tensor_node.var() is not None: + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_input_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + tensor_dims_mapping = 
tensor_dist_attr.get_dims_mapping() + op_dims_mapping = op_dist_attr.get_input_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_input_dims_mapping(tensor_desc.name(), + compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( + op_desc.type(), op_dist_attr, fwd=True) + if op_dist_impl is not None: + dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) + if dim_changed: + changed = True + # This statement will be replaced by a good way + if op_dist_impl.is_compatible(op_dist_attr): + op_dist_attr.set_impl_idx(op_dist_impl_idx) + elif is_elementwise_like_op(op_desc.type()): + dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( + op_dist_attr) + if dim_changed: + changed = True + op_dist_attr.set_impl_idx(-1) + else: + dim_changed = update_op_dims_mapping_by_default_dist_impl( + op_dist_attr) + if dim_changed: + changed = True + op_dist_attr.set_impl_idx(-2) + else: + for tensor_node in op_node.outputs: + if tensor_node.var() is not None: + tensor_desc = tensor_node.var() + if op_dist_attr.is_annotated_output_dims_mapping( + tensor_desc.name()): + continue + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph( + tensor_node) + tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() + op_dims_mapping = op_dist_attr.get_output_dims_mapping( + tensor_desc.name()) + compatible_dims_mapping = compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping]) + if (compatible_dims_mapping is not None) and \ + (compatible_dims_mapping != op_dims_mapping): + op_dist_attr.set_output_dims_mapping( + tensor_desc.name(), compatible_dims_mapping) + changed = True + # Find the most compatible implemenetations from the distributed operator + op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl( + op_desc.type(), op_dist_attr, fwd=False) + if op_dist_impl is not None: + dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr) + if dim_changed: + changed = True + # This statement will be replaced by a good way + if op_dist_impl.is_compatible(op_dist_attr): + op_dist_attr.set_impl_idx(op_dist_impl_idx) + elif is_elementwise_like_op(op_desc.type()): + dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl( + op_dist_attr) + if dim_changed: + changed = True + op_dist_attr.set_impl_idx(-1) + else: + dim_changed = update_op_dims_mapping_by_default_dist_impl( + op_dist_attr) + if dim_changed: + changed = True + op_dist_attr.set_impl_idx(-2) + return changed + + +def complete_annotation(program, dist_context=None): + """ Complete annotation for the partial annotated program. + + Arguments: + program: partial annotated program. + dist_context: the distributed context is used to store distributed attributes for program. + If not provided, the default one will be used. + Returns: + program: completed annotated program. 
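+
+    A minimal usage sketch (illustrative only; `main_program` is assumed to
+    already carry a few user annotations added through the `.interface` helpers):
+
+        from paddle.distributed.auto_parallel import complete_annotation
+
+        completed_program = complete_annotation(main_program)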
+ """ + + # Use the default distribted context for completeion if there is no one + if dist_context is None: + dist_context = get_default_distributed_context() + + # Initialize distributed attributes for all var and op node in program + dist_context.initialize_distributed_attr_for_program(program) + # print_program_with_distributed_attr(program, dist_context) + + # Convert program to graph + graph = framework.IrGraph(core.Graph(program.desc)) + + # Initialize distributed attributes for all var and op node in graph + dist_context.initialize_distributed_attr_for_graph(graph) + + # # Complete process mesh for each node + all_nodes = list(graph.all_nodes()) + reach_fix_point = False + while not reach_fix_point: + changed = False + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_process_mesh( + dist_context, node, fwd=True) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_process_mesh( + dist_context, node, fwd=True) + if op_changed: + changed = True + for node in reversed(all_nodes): + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_process_mesh( + dist_context, node, fwd=False) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_process_mesh( + dist_context, node, fwd=False) + if op_changed: + changed = True + if changed: + reach_fix_point = False + else: + reach_fix_point = True + + # Complete dims_mapping for each node + reach_fix_point = False + while not reach_fix_point: + changed = False + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_dims_mapping( + dist_context, node, fwd=True) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_dims_mapping( + dist_context, node, fwd=True) + if op_changed: + changed = True + for node in reversed(all_nodes): + if node.is_var() and node.var() is not None: + tensor_changed = update_tensor_node_dims_mapping( + dist_context, node, fwd=False) + if tensor_changed: + changed = True + if node.is_op() and node.op() is not None: + op_changed = update_op_node_dims_mapping( + dist_context, node, fwd=False) + if op_changed: + changed = True + if changed: + reach_fix_point = False + else: + reach_fix_point = True + + # Copy the corresponding distributed attribute from graph to program + dist_context.copy_distribute_attr_from_graph_to_program(graph, program) + dist_context.clear_distributed_attr_for_graph() + + # Do the validation check and amend some completion + dist_context.amend_distributed_attr_for_program() + + return program diff --git a/python/paddle/distributed/auto_parallel/context.py b/python/paddle/distributed/auto_parallel/context.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2adc7eacf91957532b4ba6b24c7dc79400eb1f --- /dev/null +++ b/python/paddle/distributed/auto_parallel/context.py @@ -0,0 +1,379 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict +from paddle.fluid import framework +from .attribute import TensorDistributedAttribute +from .attribute import OperatorDistributedAttribute +from .utils import append_distributed_attr_suffix + +# There always exists a default context for user. And user can set it to another one. +DEFAULT_DISTRIBUTED_CONTEXT = None + + +def get_default_distributed_context(): + global DEFAULT_DISTRIBUTED_CONTEXT + if DEFAULT_DISTRIBUTED_CONTEXT is None: + dist_context = DistributedContext() + set_default_distributed_context(dist_context) + return DEFAULT_DISTRIBUTED_CONTEXT + + +def set_default_distributed_context(dist_context): + global DEFAULT_DISTRIBUTED_CONTEXT + DEFAULT_DISTRIBUTED_CONTEXT = dist_context + + +class DistributedContext: + """ + DistributedContext is used to collect related distributed information for program and graph. + One auto-parallel run should use its own DistributedContext to avoid interfering other run. + """ + + def __init__(self): + self._is_initialized_for_program = False + self._is_initialized_for_graph = False + self._tensor_distributed_attr_map_for_program = {} + self._op_distributed_attr_map_for_program = {} + self._tensor_distributed_attr_map_for_graph = {} + self._op_distributed_attr_map_for_graph = {} + + def is_initialized_for_program(self): + return self._is_initialized_for_program + + def is_initialized_for_graph(self): + return self._is_initialized_for_graph + + def get_tensor_distributed_attr_for_program(self, tensor): + tensor_id = tensor.desc.id() + tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get( + tensor_id, None) + return tensor_dist_attr + + def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr): + tensor_id = tensor.desc.id() + self._tensor_distributed_attr_map_for_program[ + tensor_id] = tensor_dist_attr + + def get_op_distributed_attr_for_program(self, op): + op_id = op.desc.id() + op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id, + None) + return op_dist_attr + + def set_op_distributed_attr_for_program(self, op, op_dist_attr): + op_id = op.desc.id() + self._op_distributed_attr_map_for_program[op_id] = op_dist_attr + + def get_tensor_distributed_attr_for_graph(self, tensor_node): + tensor_node_id = tensor_node.id() + tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get( + tensor_node_id, None) + return tensor_dist_attr + + def set_tensor_distributed_attr_for_graph(self, tensor_node, + tensor_dist_attr): + tensor_node_id = tensor_node.id() + self._tensor_distributed_attr_map_for_graph[ + tensor_node_id] = tensor_dist_attr + + def get_op_distributed_attr_for_graph(self, op_node): + op_node_id = op_node.id() + op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id, + None) + return op_dist_attr + + def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr): + op_node_id = op_node.id() + self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr + + def initialize_distributed_attr_for_program(self, program): + if self._is_initialized_for_program: + return + for block in program.blocks: + for tensor 
in block.vars.values(): + # Since only tensors have distributed attributes, it's better to make sure var is a tensor + tensor_dist_attr = self.get_tensor_distributed_attr_for_program( + tensor) + if tensor_dist_attr is None: + tensor_dist_attr = TensorDistributedAttribute(tensor, self) + self._copy_distributed_attr_from_tensor_desc( + tensor.desc, tensor_dist_attr) + self.set_tensor_distributed_attr_for_program( + tensor, tensor_dist_attr) + tensor_dist_attr.set_shape(tensor.desc.shape()) + if tensor_dist_attr.get_process_mesh() is not None: + tensor_dist_attr.mark_as_annotated("process_mesh") + if tensor_dist_attr.get_dims_mapping() is None: + tensor_dims_mapping = [ + -1 for _ in range(len(tensor.desc.shape())) + ] + tensor_dist_attr.set_dims_mapping(tensor_dims_mapping) + else: + tensor_dist_attr.mark_as_annotated("dims_mapping") + if isinstance(tensor, framework.Parameter): + tensor_dist_attr.mark_as_parameter() + for op in block.ops: + op_dist_attr = self.get_op_distributed_attr_for_program(op) + if op_dist_attr is None: + op_dist_attr = OperatorDistributedAttribute(op, self) + self._copy_distributed_attr_from_op_desc(op.desc, + op_dist_attr) + self.set_op_distributed_attr_for_program(op, op_dist_attr) + # Default distributed implementation for all operators + # This will be updated during the completion prcess + op_dist_attr.set_impl_idx(-2) + if op_dist_attr.get_process_mesh() is not None: + op_dist_attr.mark_as_annotated("process_mesh") + for tensor_name in op.input_arg_names: + # There may be a better way to find the tensor by name + tensor = op.block._var_recursive(tensor_name) + op_dist_attr.set_input_shape(tensor_name, + tensor.desc.shape()) + if op_dist_attr.get_input_dims_mapping(tensor_name) is None: + tensor_dims_mapping = [ + -1 for _ in range(len(tensor.desc.shape())) + ] + op_dist_attr.set_input_dims_mapping(tensor_name, + tensor_dims_mapping) + else: + op_dist_attr.mark_as_annotated_input_dims_mapping( + tensor_name) + if isinstance(tensor, framework.Parameter): + op_dist_attr.mark_as_parameter(tensor_name) + for tensor_name in op.output_arg_names: + tensor = op.block._var_recursive(tensor_name) + op_dist_attr.set_output_shape(tensor_name, + tensor.desc.shape()) + if op_dist_attr.get_output_dims_mapping( + tensor_name) is None: + tensor_dims_mapping = [ + -1 for _ in range(len(tensor.desc.shape())) + ] + op_dist_attr.set_output_dims_mapping( + tensor_name, tensor_dims_mapping) + else: + op_dist_attr.mark_as_annotated_output_dims_mapping( + tensor_name) + if isinstance(tensor, framework.Parameter): + op_dist_attr.mark_as_parameter(tensor_name) + self._is_initialized_for_program = True + + def finalize_distributed_attr_for_program(self, program): + assert self._is_initialized_for_program, \ + "The program must initialize its distributed attribute before finalization." 
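+        # Write the (possibly completed) distributed attributes back into the
+        # tensor and op descs so that they travel with the serialized program.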
+ for block in program.blocks: + for tensor in block.vars.values(): + tensor_dist_attr = self.get_tensor_distributed_attr_for_program( + tensor) + if tensor_dist_attr is not None: + self._store_distributed_attr_to_tensor_desc( + tensor.desc, tensor_dist_attr) + for op in block.ops: + op_dist_attr = self.get_op_distributed_attr_for_program(op) + if op_dist_attr is not None: + self._store_distributed_attr_to_op_desc(op.desc, + op_dist_attr) + + def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr): + from paddle.distributed.auto_parallel.interface import _g_process_mesh_map + attr_name = append_distributed_attr_suffix("mesh_id") + if desc.has_attr(attr_name): + mesh_id = desc.attr(attr_name) + process_mesh = _g_process_mesh_map[mesh_id] + copied_process_mesh = copy.deepcopy(process_mesh) + dist_attr.set_process_mesh(copied_process_mesh) + attr_name = append_distributed_attr_suffix("dim_mapping") + if desc.has_attr(attr_name): + dims_mapping = desc.attr(attr_name) + copied_dims_mapping = copy.deepcopy(dims_mapping) + dist_attr.set_dims_mapping(copied_dims_mapping) + attr_name = append_distributed_attr_suffix("mask") + if desc.has_attr(attr_name): + shard_mask = desc.attr(attr_name) + copied_shard_mask = copy.deepcopy(shard_mask) + dist_attr.set_shard_mask(copied_shard_mask) + attr_name = append_distributed_attr_suffix("offload_device") + if desc.has_attr(attr_name): + offload_device = desc.attr(attr_name) + copied_offload_device = copy.deepcopy(offload_device) + dist_attr.set_offload_device(copied_offload_device) + + def _copy_distributed_attr_from_op_desc(self, desc, dist_attr): + from paddle.distributed.auto_parallel.interface import _g_process_mesh_map + attr_name = append_distributed_attr_suffix("mesh_id") + if desc.has_attr(attr_name): + mesh_id = desc.attr(attr_name) + process_mesh = _g_process_mesh_map[mesh_id] + copied_process_mesh = copy.deepcopy(process_mesh) + dist_attr.set_process_mesh(copied_process_mesh) + for tensor_name in desc.input_arg_names(): + attr_name = append_distributed_attr_suffix("IN_" + tensor_name) + if desc.has_attr(attr_name): + dims_mapping = desc.attr(attr_name) + copied_dims_mapping = copy.deepcopy(dims_mapping) + dist_attr.set_input_dims_mapping(tensor_name, + copied_dims_mapping) + for tensor_name in desc.output_arg_names(): + attr_name = append_distributed_attr_suffix("OUT_" + tensor_name) + if desc.has_attr(attr_name): + dims_mapping = desc.attr(attr_name) + copied_dims_mapping = copy.deepcopy(dims_mapping) + dist_attr.set_input_dims_mapping(tensor_name, + copied_dims_mapping) + attr_name = append_distributed_attr_suffix("pipeline_stage") + if desc.has_attr(attr_name): + pipeline_stage = desc.attr(attr_name) + copied_pipeline_stage = copy.deepcopy(pipeline_stage) + dist_attr.set_pipeline_stage(copied_pipeline_stage) + + def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr): + process_mesh = dist_attr.get_process_mesh() + if process_mesh is not None: + attr_name = append_distributed_attr_suffix("mesh_id") + desc._set_attr(attr_name, process_mesh._id) + dims_mapping = dist_attr.get_dims_mapping() + if dims_mapping is not None: + attr_name = append_distributed_attr_suffix("dim_mapping") + desc._set_attr(attr_name, dims_mapping) + shard_mask = dist_attr.get_shard_mask() + if shard_mask is not None: + attr_name = append_distributed_attr_suffix("mask") + desc._set_attr(attr_name, shard_mask) + offload_device = dist_attr.get_offload_device() + if offload_device is not None: + attr_name = append_distributed_attr_suffix("offload_device") + 
desc._set_attr(attr_name, offload_device) + + def _store_distributed_attr_to_op_desc(self, desc, dist_attr): + process_mesh = dist_attr.get_process_mesh() + if process_mesh is not None: + attr_name = append_distributed_attr_suffix("mesh_id") + desc._set_attr(attr_name, process_mesh._id) + for tensor_name in desc.input_arg_names(): + dims_mapping = dist_attr.get_input_dims_mapping(tensor_name) + if dims_mapping is not None: + attr_name = append_distributed_attr_suffix("IN_" + tensor_name) + desc._set_attr(attr_name, dims_mapping) + for tensor_name in desc.output_arg_names(): + dims_mapping = dist_attr.get_output_dims_mapping(tensor_name) + if dims_mapping is not None: + attr_name = append_distributed_attr_suffix("OUT_" + tensor_name) + desc._set_attr(attr_name, dims_mapping) + pipeline_stage = dist_attr.get_pipeline_stage() + if pipeline_stage is not None: + attr_name = append_distributed_attr_suffix("pipeline_stage") + desc._set_attr(attr_name, pipeline_stage) + + def initialize_distributed_attr_for_graph(self, graph): + assert self._is_initialized_for_program, \ + "The program must initialize its distributed attribute before its graph." + if self._is_initialized_for_graph: + return + all_nodes = graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_desc = node.var() + tensor_id = tensor_desc.id() + tensor_dist_attr = self._tensor_distributed_attr_map_for_program[ + tensor_id] + assert tensor_dist_attr is not None, \ + "Tensor must have a distributed attribute after the initialization for program." + new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr) + self.set_tensor_distributed_attr_for_graph(node, + new_tensor_dist_attr) + + if node.is_op() and node.op() is not None: + op_desc = node.op() + op_id = op_desc.id() + op_dist_attr = self._op_distributed_attr_map_for_program[op_id] + assert op_dist_attr is not None, \ + "Operator must have a distributed attribute after the initialization for program." 
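+                # Each graph node works on its own deep copy so that updates made
+                # during completion do not touch the program-side attributes until
+                # they are explicitly copied back afterwards.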
+ new_op_dist_attr = copy.deepcopy(op_dist_attr) + self.set_op_distributed_attr_for_graph(node, new_op_dist_attr) + self._is_initialized_for_graph = True + + def clear_distributed_attr_for_program(self): + self._tensor_distributed_attr_map_for_program.clear() + self._op_distributed_attr_map_for_program.clear() + + def clear_distributed_attr_for_graph(self): + self._tensor_distributed_attr_map_for_graph.clear() + self._op_distributed_attr_map_for_graph.clear() + + def copy_distribute_attr_from_graph_to_program(self, graph, program): + assert self._is_initialized_for_program and self._is_initialized_for_graph, \ + "The distribute attributes must be initialized both in its program and graph" + updated_tensors = {} + all_nodes = graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_desc = node.var() + tensor_id = tensor_desc.id() + updated = updated_tensors.get(tensor_desc.name(), False) + # If a var has multiples var nodes in graph, only use the first one for now + if not updated: + tensor_dist_attr = self.get_tensor_distributed_attr_for_graph( + node) + new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr) + self._tensor_distributed_attr_map_for_program[ + tensor_id] = new_tensor_dist_attr + updated_tensors[tensor_desc.name()] = True + if node.is_op() and node.op() is not None: + op_desc = node.op() + op_id = op_desc.id() + op_dist_attr = self.get_op_distributed_attr_for_graph(node) + new_op_dist_attr = copy.deepcopy(op_dist_attr) + self._op_distributed_attr_map_for_program[ + op_id] = new_op_dist_attr + + def amend_distributed_attr_for_program(self): + for attr in self._tensor_distributed_attr_map_for_program.values(): + assert attr.is_valid(), \ + "Tensor's distributed attribute {} is not valid".format(attr) + tensor_shape = attr.get_shape() + dims_mapping = attr.get_dims_mapping() + process_mesh_shape = attr.get_process_mesh().topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[ + i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + for attr in self._op_distributed_attr_map_for_program.values(): + assert attr.is_valid(), \ + "Operator's distributed attribute {} is not valid".format(attr) + for arg_name in attr.get_owner_op().desc.input_arg_names(): + tensor_shape = attr.get_input_shape(arg_name) + dims_mapping = attr.get_input_dims_mapping(arg_name) + process_mesh_shape = attr.get_process_mesh().topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and process_mesh_shape[ + dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + for arg_name in attr.get_owner_op().desc.output_arg_names(): + tensor_shape = attr.get_output_shape(arg_name) + dims_mapping = attr.get_output_dims_mapping(arg_name) + process_mesh_shape = attr.get_process_mesh().topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) 
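+                # For example, a dimension of size 2 cannot be sharded across a
+                # mesh axis of size 4, so that dims_mapping entry falls back to -1.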
+ for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and process_mesh_shape[ + dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index f98cc30131457c31a986d10fdc041dc3f6c17820..1d5b94afbaa4d2c3b4274d67965b5b68592db49d 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -13,8 +13,9 @@ # limitations under the License. import numpy -import paddle.fluid.core as core +import copy import paddle +import paddle.fluid.core as core from paddle.fluid.framework import Variable from paddle.fluid.framework import in_dygraph_mode @@ -237,6 +238,23 @@ class ProcessMesh(object): def __ne__(self, other): return not self.__eq__(other) + def __str__(self): + str = "shape {} and process group {}".format(self.topology, + self.process_group) + return str + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + # No need to copy the owner tensor and context + if k == "_desc": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + def _dim_mapping_checker(tensor, mesh, dim_mapping): assert len(tensor.shape) == len(dim_mapping) diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..14ded477cb70925c44d64cb019ba4b58e49a9d76 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from .common import find_best_compatible_distributed_operator_impl +from . import dist_embedding +from . import dist_matmul +from . import dist_reshape +from . import dist_softmax +from . import dist_transpose diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e253c0e0b178aab63ed2dac783dc2b4d402baa --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -0,0 +1,114 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +DISTRIBUTED_OPERATORS = {} + + +class DistributedOperator: + def __init__(self): + self._impls = [] + self._name = None + + def register_impl(self, dist_impl): + self._impls.append(dist_impl) + + def get_impl(self, impl_idx): + return self._impls[impl_idx] + + def get_impls(self): + return self._impls + + +class DistributedOperatorImpl: + def __init__(self): + self._name = None + + def forward(self, dist_ctx, *args, **kwargs): + raise NotImplementedError("Please Implement this method in Subclass.") + + def backward(self, dist_ctx, *grad_outputs): + raise NotImplementedError("Please Implement this method in Subclass.") + + def get_name(self): + return self._name + + def is_process_mesh_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method in Subclass.") + + def is_input_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method in Subclass.") + + def is_output_compatible(self, op_dist_attr): + raise NotImplementedError("Please Implement this method in Subclass.") + + def is_compatible(self, op_dist_attr): + return self.is_process_mesh_compatible(op_dist_attr) \ + and self.is_input_compatible(op_dist_attr) \ + and self.is_output_compatible(op_dist_attr) + + def update_dims_mapping(self, op_dist_attr): + raise NotImplementedError("Please Implement this method in Subclass.") + + +def register_distributed_operator(name, dist_op): + global DISTRIBUTED_OPERATORS + DISTRIBUTED_OPERATORS[name] = dist_op + + +def get_distributed_operator(name): + global DISTRIBUTED_OPERATORS + return DISTRIBUTED_OPERATORS.get(name, None) + + +def register_distributed_operator_impl(name, dist_impl): + dist_op = get_distributed_operator(name) + if dist_op is not None: + dist_op.register_impl(dist_impl) + else: + assert False, "Must register distributed operator first." + + +def get_distributed_operator_impl(name, impl_idx): + global DISTRIBUTED_OPERATORS + return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx) + + +def find_best_compatible_distributed_operator_impl(name, op_dist_attr, + fwd=True): + """ + Here just return the first compatible implemention. + This will be improved by cost model in the future. 
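+    Returns a (impl, idx) pair; (None, -1) means that no registered
+    implementation is compatible with the given distributed attribute.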
+ """ + dist_op = get_distributed_operator(name) + if dist_op is None: + return None, -1 + compatible_impls = [] + impls = dist_op.get_impls() + if fwd: + for idx, impl in enumerate(impls): + if impl.is_process_mesh_compatible(op_dist_attr) \ + and impl.is_input_compatible(op_dist_attr): + compatible_impls.append((impl, idx)) + else: + for idx, impl in enumerate(impls): + if impl.is_process_mesh_compatible(op_dist_attr) \ + and impl.is_output_compatible(op_dist_attr): + compatible_impls.append((impl, idx)) + + if compatible_impls: + best_compatible_impl, idx = compatible_impls[0] + else: + best_compatible_impl, idx = None, -1 + + return best_compatible_impl, idx diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..81d3925bb5dcc1be0a2deee091caf560d4ba4522 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py @@ -0,0 +1,97 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping + + +class DistributedEmbedding(DistributedOperator): + def __init__(self, name): + super(DistributedEmbedding, self).__init__() + self._name = name + + +register_distributed_operator("lookup_table_v2", + DistributedEmbedding("embedding")) + + +# RowParallel +class DistributedEmbeddingImpl(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedEmbeddingImpl, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + ids_name = op_desc.input('Ids')[0] + w_name = op_desc.input('W')[0] + ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) + w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) + if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[ + -1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in ids_dims_mapping[1:]: + if is_dim_shard(mapping): + return False + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + # Other dimensions must be replicate except the batch dimension + for mapping in out_dims_mapping[1:]: + if is_dim_shard(mapping): + return False + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + ids_name = op_desc.input('Ids')[0] + w_name = op_desc.input('W')[0] + out_name = op_desc.output('Out')[0] + ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) + w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + for i in range(len(ids_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [ids_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + dim_changed = compute_compatible_and_update_dim_mapping( + [w_dims_mapping, out_dims_mapping], [-1, -1]) + if dim_changed: + changed = True + + return changed + + +register_distributed_operator_impl("lookup_table_v2", + DistributedEmbeddingImpl("row_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..fbeb0edd41897b78410d366564a7d4aef0205c83 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py @@ -0,0 +1,343 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping + + +def _update_dims_mapping_for_matmul(op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + x_dims_mapping_len = len(x_dims_mapping) + y_dims_mapping_len = len(y_dims_mapping) + out_dims_mapping_len = len(out_dims_mapping) + + # print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping) + # Add dim mapping to Make sure the length dims_mapping be at least 2 + if x_dims_mapping_len == 1: + x_dims_mapping.insert(0, -1) + if y_dims_mapping_len == 1: + y_dims_mapping.insert(1, -1) + + # Deal with dim > 2 and take care of broadcasting + if out_dims_mapping_len > 2: + broadcast_x_dims_mapping = [] + broadcast_y_dims_mapping = [] + broadcast_out_dims_mapping = [] + + for i in range(out_dims_mapping_len - x_dims_mapping_len): + broadcast_x_dims_mapping.append(out_dims_mapping[i]) + for i in range(x_dims_mapping_len - 2): + broadcast_x_dims_mapping.append(x_dims_mapping[i]) + + for i in range(out_dims_mapping_len - y_dims_mapping_len): + broadcast_y_dims_mapping.append(out_dims_mapping[i]) + for i in range(y_dims_mapping_len - 2): + broadcast_y_dims_mapping.append(y_dims_mapping[i]) + + for i in range(out_dims_mapping_len - 2): + broadcast_out_dims_mapping.append(out_dims_mapping[i]) + + compatible_dims_mapping = compute_compatible_dims_mapping([ + broadcast_x_dims_mapping, broadcast_y_dims_mapping, + broadcast_out_dims_mapping + ]) + assert compatible_dims_mapping is not None, "There is no compatible dim mapping." 
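# Worked example of the broadcast handling above (values assumed): a 3-D x
# multiplied by a 2-D y, with x's batch dim sharded on mesh dim 0.
#   x_dims_mapping   = [0, -1, -1]
#   y_dims_mapping   = [-1, -1]
#   out_dims_mapping = [-1, -1, -1]
# broadcast_x = [0], broadcast_y = [-1], broadcast_out = [-1], so
# compute_compatible_dims_mapping([[0], [-1], [-1]]) returns [0] and the update
# loops below propagate 0 into out_dims_mapping[0].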
+ + for i in range(x_dims_mapping_len - 2): + new_idx = i + (out_dims_mapping_len - x_dims_mapping_len) + if x_dims_mapping[i] != compatible_dims_mapping[new_idx]: + x_dims_mapping[i] = compatible_dims_mapping[new_idx] + changed = True + + for i in range(y_dims_mapping_len - 2): + new_idx = i + (out_dims_mapping_len - y_dims_mapping_len) + if y_dims_mapping[i] != compatible_dims_mapping[new_idx]: + y_dims_mapping[i] = compatible_dims_mapping[new_idx] + changed = True + + for i in range(out_dims_mapping_len - 2): + if out_dims_mapping[i] != compatible_dims_mapping[i]: + out_dims_mapping[i] = compatible_dims_mapping[i] + changed = True + + # The following which uses negative index can be work + # when len(out_dims_mapping) > 2 and len(out_dims_mapping) <=2 + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, y_dims_mapping], [-1, -2]) + if dim_changed: + changed = True + + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [-2, -2]) + if dim_changed: + changed = True + + dim_changed = compute_compatible_and_update_dim_mapping( + [y_dims_mapping, out_dims_mapping], [-1, -1]) + if dim_changed: + changed = True + + # Remove unnecessary dim mapping to make sure the lenght of dims_mapping is same as its tensor + if x_dims_mapping_len == 1: + x_dims_mapping.pop(0) + if y_dims_mapping_len == 1: + y_dims_mapping.pop(1) + + # print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping) + assert len(x_dims_mapping) == x_dims_mapping_len + assert len(y_dims_mapping) == y_dims_mapping_len + assert len(out_dims_mapping) == out_dims_mapping_len + + return changed + + +class DistributedMatmul(DistributedOperator): + def __init__(self, name): + super(DistributedMatmul, self).__init__() + self._name = name + + +register_distributed_operator("matmul", DistributedMatmul("matmul")) + + +# ColumnParallel +class DistributedMatmulImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulImpl0, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. """ + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + if is_dim_shard(x_dims_mapping[-1]): + return False + if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[ + 1]): + return False + for mapping in x_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if is_dim_replicate(out_dims_mapping[-1]): + return False + for mapping in out_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + +# RowParallel +class DistributedMatmulImpl1(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulImpl1, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
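# Illustrative dims mappings for the two sharded matmul impls (1-D mesh assumed):
#   column-parallel (Impl0 above): X = [-1, -1], Y = [-1, 0], Out = [-1, 0]
#   row-parallel    (Impl1 below): X = [-1, 0],  Y = [0, -1], Out = [-1, -1]
# Impl0 requires X's last dim to be replicated and Y sharded by columns; Impl1
# requires X and Y sharded along the contracted dim and a replicated output
# last dim.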
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + if is_dim_replicate(x_dims_mapping[-1]): + return False + if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[ + -1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in x_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + if is_dim_shard(out_dims_mapping[-1]): + return False + # Other dimensions must be replicate except the batch dimension + for mapping in out_dims_mapping[1:-1]: + if is_dim_shard(mapping): + return False + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + +# ReplicateParallel +class DistributedMatmulImpl2(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulImpl2, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. """ + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + + if is_dim_shard(x_dims_mapping[-1]): + return False + if is_valid_list_index(x_dims_mapping, + -2) and is_dim_shard(x_dims_mapping[-2]): + return False + + if is_dim_shard(y_dims_mapping[-1]): + return False + if is_valid_list_index(y_dims_mapping, + -2) and is_dim_shard(y_dims_mapping[-2]): + return False + + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if is_dim_shard(out_dims_mapping[-1]): + return False + if is_valid_list_index(out_dims_mapping, + -2) and is_dim_shard(out_dims_mapping[-2]): + return False + + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + +register_distributed_operator_impl("matmul", + DistributedMatmulImpl0("column_parallel")) +register_distributed_operator_impl("matmul", + DistributedMatmulImpl1("row_parallel")) +register_distributed_operator_impl("matmul", + DistributedMatmulImpl2("replicate_parallel")) + + +class DistributedMatmulV2(DistributedOperator): + def __init__(self, name): + super(DistributedMatmulV2, self).__init__() + self._name = name + + +register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2")) + + +# ReplicateParallel +class DistributedMatmulV2Impl(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedMatmulV2Impl, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + y_name = op_desc.input('Y')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name) + + if is_dim_shard(x_dims_mapping[-1]): + return False + if is_valid_list_index(x_dims_mapping, + -2) and is_dim_shard(x_dims_mapping[-2]): + return False + + if is_dim_shard(y_dims_mapping[-1]): + return False + if is_valid_list_index(y_dims_mapping, + -2) and is_dim_shard(y_dims_mapping[-2]): + return False + + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if is_dim_shard(out_dims_mapping[-1]): + return False + if is_valid_list_index(out_dims_mapping, + -2) and is_dim_shard(out_dims_mapping[-2]): + return False + + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + dim_changed = _update_dims_mapping_for_matmul(op_dist_attr) + if dim_changed: + changed = True + return changed + + +register_distributed_operator_impl( + "matmul_v2", DistributedMatmulV2Impl("replicate_parallel")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..40da0e2f6093f2964cf66ad55e7e5fb1cac99a26 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping + + +class DistributedReshape2(DistributedOperator): + def __init__(self, name): + super(DistributedReshape2, self).__init__() + self._name = name + + +register_distributed_operator("reshape2", DistributedReshape2("reshape2")) + + +class DistributedReshapeImpl0(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedReshapeImpl0, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if len(x_dims_mapping) != len(out_dims_mapping) - 1: + return False + + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if len(x_dims_mapping) != len(out_dims_mapping) - 1: + return False + + if is_dim_shard(out_dims_mapping[-1]): + return False + + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = op_desc.output('XShape')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + + for i in range(len(x_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + for i in range(len(x_dims_mapping)): + x_shape_dims_mapping[i + 1] = x_dims_mapping[i] + + return changed + + +class DistributedReshapeImpl1(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedReshapeImpl1, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. 
""" + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if len(x_dims_mapping) != len(out_dims_mapping) + 1: + return False + + if is_dim_shard(x_dims_mapping[-1]): + return False + + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if len(x_dims_mapping) != len(out_dims_mapping) + 1: + return False + + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = op_desc.output('XShape')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + + for i in range(len(out_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + for i in range(len(x_dims_mapping)): + x_shape_dims_mapping[i + 1] = x_dims_mapping[i] + + return changed + + +register_distributed_operator_impl("reshape2", + DistributedReshapeImpl0("add_one_dim_back")) +register_distributed_operator_impl( + "reshape2", DistributedReshapeImpl1("remove_one_dim_back")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..fad11aadf8020f290487874a506c6f2d3384fd99 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping + + +class DistributedSoftmax(DistributedOperator): + def __init__(self, name): + super(DistributedSoftmax, self).__init__() + self._name = name + + +register_distributed_operator("softmax", DistributedSoftmax("softmax")) + + +class DistributedSoftmaxImpl(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedSoftmaxImpl, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. """ + return True + + def is_input_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + axis = op_desc.attr('axis') + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + # print("softmax axis", axis) + + if axis != -1 and axis != len(x_dims_mapping) - 1: + return False + + if is_dim_shard(x_dims_mapping[axis]): + return False + + return True + + def is_output_compatible(self, op_dist_attr): + op_desc = op_dist_attr.get_owner_op().desc + out_name = op_desc.output('Out')[0] + axis = op_desc.attr('axis') + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + if axis != -1 and axis != len(out_dims_mapping) - 1: + return False + + if is_dim_shard(out_dims_mapping[axis]): + return False + + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + + for i in range(len(x_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [x_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + return changed + + +register_distributed_operator_impl( + "softmax", DistributedSoftmaxImpl("replicate_last_axis")) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..c2ca4d85fdf106760a5c16fe33bfbe5c26a7a8f2 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
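# The softmax impl above only allows softmax over the last axis, and that axis
# must be replicated (illustrative values):
#   X/Out dims_mapping = [0, -1, -1, -1], axis = -1  -> accepted
#   X/Out dims_mapping = [0, -1, -1, 1],  axis = -1  -> rejected (last axis sharded)
#   axis = 1 on a 4-D tensor                         -> rejected (not the last axis)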
+# See the License for the specific language governing permissions and +# limitations under the License + +from .common import DistributedOperator +from .common import DistributedOperatorImpl +from .common import register_distributed_operator +from .common import register_distributed_operator_impl +from ..utils import is_dim_shard +from ..utils import is_dim_replicate +from ..utils import is_valid_list_index +from ..utils import compute_compatible_dim_mapping +from ..utils import compute_compatible_dims_mapping +from ..utils import compute_compatible_and_update_dim_mapping + + +class DistributedTranspose2(DistributedOperator): + def __init__(self, name): + super(DistributedTranspose2, self).__init__() + self._name = name + + +register_distributed_operator("transpose2", DistributedTranspose2("transpose2")) + + +class DistributedTranspose2Impl(DistributedOperatorImpl): + def __init__(self, name): + super(DistributedTranspose2Impl, self).__init__() + self._name = name + + def is_process_mesh_compatible(self, op_dist_attr): + """ No restriction for now. """ + return True + + def is_input_compatible(self, op_dist_attr): + return True + + def is_output_compatible(self, op_dist_attr): + return True + + def update_dims_mapping(self, op_dist_attr): + changed = False + op_desc = op_dist_attr.get_owner_op().desc + x_name = op_desc.input('X')[0] + out_name = op_desc.output('Out')[0] + x_shape_name = op_desc.output('XShape')[0] + x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name) + out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) + x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping( + x_shape_name) + perm = op_desc.attr('axis') + + assert len(x_dims_mapping) == len(perm) + + new_dims_mapping = [-1 for i in range(len(x_dims_mapping))] + for i in range(len(x_dims_mapping)): + new_dims_mapping[i] = x_dims_mapping[perm[i]] + + for i in range(len(out_dims_mapping)): + dim_changed = compute_compatible_and_update_dim_mapping( + [new_dims_mapping, out_dims_mapping], [i, i]) + if dim_changed: + changed = True + + for i in range(len(x_dims_mapping)): + if x_dims_mapping[perm[i]] != new_dims_mapping[i]: + x_dims_mapping[perm[i]] = new_dims_mapping[i] + changed = True + + for i in range(len(x_dims_mapping)): + x_shape_dims_mapping[i + 1] = x_dims_mapping[i] + + return changed + + +register_distributed_operator_impl( + "transpose2", DistributedTranspose2Impl("same_mapping_transpose")) diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a73ae5c0a64d67f6f7caf395c3d12bb425b5c5 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
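# Worked example for the transpose permutation above (values assumed):
#   perm           = [0, 2, 1, 3]
#   x_dims_mapping = [0, -1, 1, -1]
#   new_dims_mapping[i] = x_dims_mapping[perm[i]]  ->  [0, 1, -1, -1]
# out_dims_mapping is then made compatible with [0, 1, -1, -1] dim by dim, and
# x_shape_dims_mapping[i + 1] is set to x_dims_mapping[i] to account for the
# extra leading dim of XShape.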
+# See the License for the specific language governing permissions and +# limitations under the License + +import threading +import paddle.fluid.core as core + + +def is_valid_list_index(list, index): + if index >= -len(list) and index < len(list): + return True + else: + return False + + +def is_dim_shard(mapping): + if mapping != -1: + return True + else: + return False + + +def is_dim_replicate(mapping): + if mapping == -1: + return True + else: + return False + + +def compute_compatible_dim_mapping(dim_mappings): + if not dim_mappings: + return None + compatible_mapping = dim_mappings[0] + for mapping in dim_mappings: + if compatible_mapping == -1: + compatible_mapping = mapping + elif mapping == -1: + continue + elif compatible_mapping == mapping: + continue + else: + return None + return compatible_mapping + + +def compute_compatible_dims_mapping(dims_mapping_list): + if not dims_mapping_list: + return None + length = len(dims_mapping_list[0]) + for dims_mapping in dims_mapping_list: + assert dims_mapping is not None, \ + "Dims mapping must not be None for compatible computation" + assert len(dims_mapping) == length, \ + "The length of dims_mapping in list must be same for compatible computation." + compatible_result = [] + for dim_mappings in zip(*dims_mapping_list): + compatible_dim_mapping = compute_compatible_dim_mapping( + list(dim_mappings)) + if compatible_dim_mapping is None: + return None + compatible_result.append(compatible_dim_mapping) + return compatible_result + + +def compute_compatible_process_mesh(process_mesh_list): + compatible_process_mesh = None + if not process_mesh_list: + return compatible_process_mesh + for process_mesh in process_mesh_list: + if process_mesh is not None: + if compatible_process_mesh is None: + compatible_process_mesh = process_mesh + else: + assert process_mesh == compatible_process_mesh, \ + "There is no compatible process mesh." + return compatible_process_mesh + + +def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list): + assert len(dims_mapping_list) == len(index_list) + changed = False + dim_mappings = [] + for i in range(len(dims_mapping_list)): + assert is_valid_list_index(dims_mapping_list[i], index_list[i]) + dim_mappings.append(dims_mapping_list[i][index_list[i]]) + compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings) + if compatible_dim_mapping is None: + return False + for i in range(len(dims_mapping_list)): + if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]: + dims_mapping_list[i][index_list[i]] = compatible_dim_mapping + changed = True + return changed + + +def append_distributed_attr_suffix(name): + """ + Append auto parallel suffix for distributed attribute name. + """ + return name + core.kAutoParallelSuffix() + + +def remove_distributed_attr_suffix(name): + """ + Remove auto parallel suffix from distributed attribute name. + """ + return name.strip(core.kAutoParallelSuffix()) + + +def check_distributed_attr_for_program(program, dist_context=None): + from .context import get_default_distributed_context + if dist_context is None: + dist_context = get_default_distributed_context() + assert dist_context.is_initialized_for_program(), \ + "Distributed attributes must be initialized before check." 
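# Quick sketch of the compatibility rule implemented above: -1 (replicated) acts
# as a wildcard, while two different shard dims conflict. Values are assumed.
#   compute_compatible_dim_mapping([0, -1, 0])  -> 0
#   compute_compatible_dim_mapping([-1, -1])    -> -1
#   compute_compatible_dim_mapping([0, 1])      -> None (conflict)
#   dims_a, dims_b = [0, -1], [-1, 1]
#   compute_compatible_and_update_dim_mapping([dims_a, dims_b], [0, 0])
#   -> returns True and updates dims_b to [0, 1]; index 1 is left untouched.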
+ for block in program.blocks: + for tensor in block.vars.values(): + tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + tensor) + if (tensor_dist_attr is not None) and ( + not tensor_dist_attr.is_valid()): + return False + for op in block.ops: + op_dist_attr = dist_context.get_op_distributed_attr_for_program(op) + if (op_dist_attr is not None) and (not op_dist_attr.is_valid()): + return False + return True + + +def print_program_with_distributed_attr(program, dist_context=None): + """ + This function reuses the original program output ability with a distributed context. + Using lock can avoid multiple threads change the default distributed context simultaneously. + """ + lock = threading.Lock() + lock.acquire() + from .context import get_default_distributed_context + from .context import set_default_distributed_context + if dist_context is None: + dist_context = get_default_distributed_context() + print(program) + else: + original_default_context = get_default_distributed_context() + set_default_distributed_context(dist_context) + print(program) + set_default_distributed_context(original_default_context) + lock.release() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 12aa0c9391019ad7b5a2bcb0ed617769f7f389f6..13477fd3422007565856dffc1a511579045adfcd 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1224,6 +1224,14 @@ class Variable(object): if self.persistable: var_str = "persist " + var_str + from paddle.distributed.auto_parallel.context import get_default_distributed_context + dist_context = get_default_distributed_context() + var_dist_attr = dist_context.get_tensor_distributed_attr_for_program( + self) + if var_dist_attr is not None: + var_str += ", {name} = {value}".format( + name="dist_attr", value=var_dist_attr) + return var_str def to_string(self, throw_on_error, with_details=False): @@ -2384,6 +2392,13 @@ class Operator(object): if i != len(attr_names) - 1: attrs_str += ", " + from paddle.distributed.auto_parallel.context import get_default_distributed_context + dist_context = get_default_distributed_context() + op_dist_attr = dist_context.get_op_distributed_attr_for_program(self) + if op_dist_attr is not None: + attrs_str += ", {name} = {value}".format( + name="dist_attr", value=op_dist_attr) + if outputs_str != "{}": op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\ format(outputs=outputs_str, op_type=self.type, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1943ce6c60f20474518fdda0847ed4a1b7659b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -0,0 +1,676 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
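# The tests below all follow the same flow: annotate a few tensors with
# auto.shard_tensor, run auto.complete_annotation to propagate distributed
# attributes, then validate them. A condensed sketch (the program, the model and
# the annotated weight are assumed to be set up as in the tests):
#
#   mesh = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
#   auto.shard_tensor(weight, mesh, dim_mapping=[-1, 1])  # -1: replicated dim,
#                                                         #  1: sharded on mesh dim 1
#   dist_context = DistributedContext()
#   completed = auto.complete_annotation(train_program, dist_context)
#   assert check_distributed_attr_for_program(completed, dist_context)
#   print_program_with_distributed_attr(completed, dist_context)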
+ +from __future__ import print_function + +import unittest +import unittest.mock +from io import StringIO + +import paddle +import paddle.nn as nn +import paddle.static as static +import paddle.nn.functional as F +import paddle.utils as utils +import paddle.tensor as tensor +from paddle.fluid import layers +from paddle.nn.layer.transformer import _convert_param_attr_to_list +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix +from paddle.distributed.auto_parallel.context import DistributedContext +from paddle.distributed.auto_parallel.context import set_default_distributed_context +paddle.enable_static() +_global_parallel_stratergy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + out = self.dropout(out) + + return out + + +def mlp_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + out = mlp(input) + return train_program, start_program + + +class TestMLPAutoCompletion(unittest.TestCase): + def test_mlp_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, 
+ dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_mlp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_mlp_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_mlp_misc(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = mlp_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + dist_context.finalize_distributed_attr_for_program( + complete_train_program) + from paddle.distributed.auto_parallel.interface import _g_process_mesh_map + for block in complete_train_program.blocks: + for tensor in block.vars.values(): + desc = tensor.desc + attr_name = append_distributed_attr_suffix("mesh_id") + self.assertIsNotNone(desc.has_attr(attr_name)) + attr_name = append_distributed_attr_suffix("dim_mapping") + self.assertIsNotNone(desc.has_attr(attr_name)) + for op in block.ops: + desc = op.desc + attr_name = append_distributed_attr_suffix("mesh_id") + self.assertIsNotNone(desc.has_attr(attr_name)) + for tensor_name in desc.input_arg_names(): + attr_name = append_distributed_attr_suffix("IN_" + + tensor_name) + self.assertIsNotNone(desc.has_attr(attr_name)) + for tensor_name in desc.output_arg_names(): + attr_name = append_distributed_attr_suffix("OUT_" + + tensor_name) + self.assertIsNotNone(desc.has_attr(attr_name)) + set_default_distributed_context(dist_context) + self.assertTrue("dist_attr" in str(complete_train_program)) + with unittest.mock.patch( + "sys.stdout", new_callable=StringIO) as mock_stdout: + print_program_with_distributed_attr(complete_train_program) + self.assertIsNotNone(mock_stdout.getvalue()) + + +class AttentionLayer(nn.Layer): + def __init__(self, + hidden_size=1024, + sequence_len=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(AttentionLayer, self).__init__() + 
self.hidden_size = hidden_size + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range)) + bias_attr = None + + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + def forward(self, input): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input, _global_process_mesh, dim_mapping=[0, -1, -1]) + + q = self.q_proj(input) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(input) + v = self.v_proj(input) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + return out + + +def attn_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input = static.data( + name="query", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + attn = AttentionLayer( + hidden_size=hidden_size, + sequence_len=sequence_len, + 
intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = attn(input) + + return train_program, start_program + + +class TestAttentionAutoCompletion(unittest.TestCase): + def test_attn_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = attn_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_attn_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = attn_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_attn_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = attn_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + +class DecoderLayer(nn.Layer): + def __init__(self, + vocab_size=32768, + hidden_size=1024, + sequence_len=512, + max_position_embeddings=512, + intermediate_size=4 * 1024, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02): + super(DecoderLayer, self).__init__() + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.max_position_embeddings = max_position_embeddings + self.sequence_len = sequence_len + self.embed_dim = self.hidden_size + self.kdim = self.embed_dim + self.vdim = self.embed_dim + self.num_heads = num_heads + self.dropout_ratio = dropout_ratio + self.initializer_range = initializer_range + self.training = True + self.attn_mask = None + + self.head_dim = self.embed_dim // self.num_heads + assert self.head_dim * self.num_heads == self.embed_dim, \ + "embed_dim must be divisible by num_heads" + self.word_embeddings = nn.Embedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + self.position_embeddings = nn.Embedding( + self.max_position_embeddings, + self.hidden_size, + weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range))) + + 
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.q_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr) + + intermediate_size = 4 * self.hidden_size + d_model = self.hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)) + bias_attr = None + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(self.dropout_ratio) + self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") + + def forward(self, input_ids, position_ids): + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + input_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + embeddings = input_embeddings + position_embeddings + embeddings = self.dropout1(embeddings) + + # Pre-norm + target = self.norm(embeddings) + + # The following is the attention part + q = self.q_proj(target) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + k = self.k_proj(target) + v = self.v_proj(target) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if self.attn_mask is not None: + product = product + self.attn_mask + + weights = F.softmax(product) + + if self.dropout_ratio: + weights = F.dropout( + weights, + self.dropout_ratio, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + 
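# Shape trace for the combine-heads step around this point (B = batch, S = seq_len):
#   attention output:        [B, num_heads, S, head_dim]
#   after transpose:         [B, S, num_heads, head_dim]
#   after the reshape below: [B, S, num_heads * head_dim] == [B, S, hidden_size]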
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + # Add residual + residual = embeddings + self.dropout2(out) + + # Pre-norm + out0 = self.norm(residual) + + # The following is the MLP part + out1 = self.linear0(out0) + out2 = F.gelu(out1, approximate=True) + out3 = self.linear1(out2) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1]) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # Add residual + final = residual + self.dropout3(out3) + return final + + +def decoder_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 4 + hidden_size = 1024 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + decoder = DecoderLayer( + vocab_size=32768, + hidden_size=hidden_size, + sequence_len=sequence_len, + max_position_embeddings=512, + intermediate_size=4 * hidden_size, + num_heads=16, + dropout_ratio=0.1, + initializer_range=0.02) + out = decoder(input_ids, position_ids) + + return train_program, start_program + + +class TestDecoderLayerAutoCompletion(unittest.TestCase): + def test_decoder_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = decoder_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_decoder_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = decoder_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_decoder_dp_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = 
static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = decoder_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py new file mode 100644 index 0000000000000000000000000000000000000000..204e8910e05104dad379531928a10b45e871682c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py @@ -0,0 +1,814 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import collections +import math +import unittest + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.tensor as tensor +import paddle.utils as utils +from paddle.fluid import layers +from paddle.fluid.framework import in_dygraph_mode +from paddle.nn.layer.transformer import _convert_param_attr_to_list +from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.distributed.fleet import fleet +import paddle.static as static +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program +from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr +from paddle.distributed.auto_parallel.context import DistributedContext + +paddle.enable_static() +_global_parallel_stratergy = None +_global_process_mesh = None +ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]]) + + +class MultiHeadAttention(nn.Layer): + """ + Attention mapps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attention to jointly attending + to information from different representation subspaces. 
+ """ + + Cache = collections.namedtuple("Cache", ["k", "v"]) + StaticCache = collections.namedtuple("StaticCache", ["k", "v"]) + + def __init__(self, + embed_dim, + num_heads, + dropout=0., + kdim=None, + vdim=None, + need_weights=False, + weight_attr=None, + bias_attr=None, + topo=None, + fuse=False): + super(MultiHeadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.need_weights = need_weights + self.fuse = fuse + + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if topo is None or topo.mp_info.size == 1: + if self.fuse: + assert self.kdim == embed_dim + assert self.vdim == embed_dim + self.qkv_proj = nn.Linear( + embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr) + else: + self.q_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + self.k_proj = nn.Linear( + self.kdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.v_proj = nn.Linear( + self.vdim, embed_dim, weight_attr, bias_attr=bias_attr) + self.out_proj = nn.Linear( + embed_dim, embed_dim, weight_attr, bias_attr=bias_attr) + + def _fuse_prepare_qkv(self, query): + mix_layer = self.qkv_proj(query) + mix_layer = paddle.reshape_(mix_layer, + [0, 0, self.num_heads, 3 * self.head_dim]) + mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3]) + q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1) + return q, k, v + + def _prepare_qkv(self, query, key, value, use_cache=False, cache=None): + r""" + Prapares linear projected queries, keys and values for usage of subsequnt + multiple parallel attention. If `cache` is not None, using cached results + to reduce redundant calculations. + """ + q = self.q_proj(query) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + + if isinstance(cache, self.StaticCache): + # for encoder-decoder attention in inference and has cached + k, v = cache.k, cache.v + else: + k, v = self.compute_kv(key, value) + + if isinstance(cache, self.Cache): + # for decoder self-attention in inference + k = tensor.concat([cache.k, k], axis=2) + v = tensor.concat([cache.v, v], axis=2) + if use_cache is True: + cache = self.Cache(k, v) + + return (q, k, v) if use_cache is False else (q, k, v, cache) + + def compute_kv(self, key, value): + r""" + Applies linear projection on input keys and values, then splits heads + (reshape and transpose) to get keys and values from different representation + subspaces. The results are used as key-values pairs for subsequent multiple + parallel attention. + It is part of calculations in multi-head attention, and is provided as + a method to pre-compute and prefetch these results, thus we can use them + to construct cache for inference. 
+ """ + k = self.k_proj(key) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + v = self.v_proj(value) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + return k, v + + def gen_cache(self, key, value=None, type=Cache): + """ + Generates cache for `forward` usage in inference accroding to arguments. + The generated cache is an instance of `MultiHeadAttention.Cache` or an + instance of `MultiHeadAttention.StaticCache`. + """ + if type == MultiHeadAttention.StaticCache: # static_kv + k, v = self.compute_kv(key, value) + return self.StaticCache(k, v) + elif value is None: # incremental_state + k = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + v = layers.fill_constant_batch_size_like( + input=key, + shape=[-1, self.num_heads, 0, self.head_dim], + dtype=key.dtype, + value=0) + return self.Cache(k, v) + else: + # incremental_state with initial value, mainly for usage like UniLM + return self.Cache(key, value) + + def forward(self, + query, + key, + value, + attn_mask=None, + use_cache=False, + cache=None): + r""" + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. + """ + key = query if key is None else key + value = query if value is None else value + # compute q ,k ,v + if use_cache is False: + if self.fuse: + q, k, v = self._fuse_prepare_qkv(query) + else: + q, k, v = self._prepare_qkv(query, key, value, use_cache, cache) + else: + q, k, v, cache = self._prepare_qkv(query, key, value, use_cache, + cache) + # scale dot product attention + product = layers.matmul( + x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5) + + if attn_mask is not None: + product = product + attn_mask + + weights = F.softmax(product) + if self.dropout: + weights = F.dropout( + weights, + self.dropout, + training=self.training, + mode="upscale_in_train") + + out = tensor.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.out_proj(out) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.out_proj.weight, _global_process_mesh, + dim_mapping=[1, -1]) + + outs = [out] + if self.need_weights: + outs.append(weights) + if use_cache: + outs.append(cache) + return out if len(outs) == 1 else tuple(outs) + + +class TransformerDecoder(nn.Layer): + """ + TransformerDecoder is a stack of N decoder layers. 
+ """ + + def __init__(self, + decoder_layers, + num_layers, + norm=None, + hidden_size=None, + topo=None): + super(TransformerDecoder, self).__init__() + + self.topo = topo + self.num_layers = num_layers + self.layers = decoder_layers + self.norm = norm + if norm is "LayerNorm": + self.norm = nn.LayerNorm(hidden_size) + elif norm is not None: + raise ValueError("Only support LayerNorm") + self.checkpoints = [] + + def forward(self, + tgt, + memory, + tgt_mask=None, + memory_mask=None, + use_cache=False, + cache=None): + r""" + Applies a stack of N Transformer decoder layers on inputs. If `norm` is + provided, also applies layer normalization on the output of last decoder + layer. + """ + output = tgt + new_caches = [] + self.checkpoints = [] + + for i, mod in enumerate(self.layers): + if cache is None: + if use_cache: + output, new_cache = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache) + new_caches.append(new_cache) + else: + output = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache) + + else: + output, new_cache = mod(output, + memory, + tgt_mask=tgt_mask, + use_cache=use_cache, + cache=cache[i]) + new_caches.append(new_cache) + self.checkpoints.append(output.name) + + if self.norm is not None: + output = self.norm(output) + return output if use_cache is False else (output, new_caches) + + def gen_cache(self, memory, do_zip=False): + r""" + Generates cache for `forward` usage. The generated cache is a list, and + each element in it is a tuple( :code:`(incremental_cache, static_cache)` ) + produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache` + for more details. If `do_zip` is True, apply `zip` on these tuples to get + a list with two elements. + """ + cache = [layer.gen_cache(memory) for layer in self.layers] + if do_zip: + cache = list(zip(*cache)) + return cache + + +class TransformerDecoderLayer(nn.Layer): + """ + The transformer decoder layer. + It contains multiheadattention and some linear layers. 
+ """ + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout=0.1, + activation="gelu", + attn_dropout=None, + act_dropout=None, + normalize_before=True, + weight_attr=None, + bias_attr=None, + topo=None): + self._config = locals() + self._config.pop("self") + self._config.pop("__class__", None) # py3 + + super(TransformerDecoderLayer, self).__init__() + attn_dropout = dropout if attn_dropout is None else attn_dropout + act_dropout = dropout if act_dropout is None else act_dropout + self.normalize_before = normalize_before + + weight_attrs = _convert_param_attr_to_list(weight_attr, 3) + bias_attrs = _convert_param_attr_to_list(bias_attr, 3) + + self.self_attn = MultiHeadAttention( + d_model, + nhead, + dropout=attn_dropout, + weight_attr=weight_attrs[0], + bias_attr=bias_attrs[0], + topo=topo) + if topo is None or topo.mp_info.size == 1: + self.linear1 = nn.Linear( + d_model, + dim_feedforward, + weight_attrs[2], + bias_attr=bias_attrs[2]) + self.linear2 = nn.Linear( + dim_feedforward, + d_model, + weight_attrs[2], + bias_attr=bias_attrs[2]) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train") + self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train") + self.activation = getattr(F, activation) + + def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None): + residual = tgt + + if self.normalize_before: + tgt = self.norm1(tgt) + + if use_cache is False: + tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache) + else: + tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask, + use_cache, cache) + tgt = residual + self.dropout1(tgt) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1]) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1]) + + # tgt = self.dropout2( + # self.linear2(F.gelu( + # self.linear1(tgt), approximate=True))) + tgt = self.linear1(tgt) + tgt = F.gelu(tgt, approximate=True) + tgt = self.dropout2(self.linear2(tgt)) + tgt = residual + tgt + + if not self.normalize_before: + tgt = self.norm2(tgt) + + return tgt if use_cache is False else (tgt, incremental_cache) + + def gen_cache(self, memory): + incremental_cache = self.self_attn.gen_cache( + memory, type=self.self_attn.Cache) + return incremental_cache + + +class GPTEmbeddings(nn.Layer): + """ + Include embeddings from word, position and token_type embeddings + """ + + def __init__(self, + vocab_size, + hidden_size=768, + hidden_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + topo=None): + super(GPTEmbeddings, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.word_embeddings = nn.Embedding( + vocab_size, + hidden_size, + weight_attr=paddle.ParamAttr( + name="word_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + self.position_embeddings = nn.Embedding( + max_position_embeddings, + hidden_size, + 
weight_attr=paddle.ParamAttr( + name="pos_embeddings", + initializer=nn.initializer.Normal( + mean=0.0, std=initializer_range))) + + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, position_ids=None): + if position_ids is None: + ones = paddle.ones_like(input_ids, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + position_ids = seq_length - ones + + input_embedings = self.word_embeddings(input_ids) + + if _global_parallel_stratergy == "mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + self.word_embeddings.weight, + _global_process_mesh, + dim_mapping=[1, -1]) + + position_embeddings = self.position_embeddings(position_ids) + embeddings = input_embedings + position_embeddings + embeddings = self.dropout(embeddings) + return embeddings + + +class GPTModel(nn.Layer): + """ + The base model of gpt. + """ + + def __init__(self, + vocab_size, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None): + super(GPTModel, self).__init__() + + self.pad_token_id = pad_token_id + self.initializer_range = initializer_range + self.topo = topo + self.hidden_size = hidden_size + self.vocab_size = vocab_size + + self.pipline_mode = topo is not None and topo.pp_info.size > 1 + if self.pipline_mode: + self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size + + self.embeddings = GPTEmbeddings( + vocab_size, hidden_size, hidden_dropout_prob, + max_position_embeddings, type_vocab_size, self.initializer_range, + topo) + + decoder_layers = nn.LayerList() + for i in range(num_hidden_layers): + DecoderLayer = TransformerDecoderLayer + decoder_layers.append( + DecoderLayer( + d_model=hidden_size, + nhead=num_attention_heads, + dim_feedforward=intermediate_size, + dropout=hidden_dropout_prob, + activation=hidden_act, + attn_dropout=attention_probs_dropout_prob, + act_dropout=hidden_dropout_prob, + weight_attr=paddle.ParamAttr( + initializer=nn.initializer.Normal( + mean=0.0, std=self.initializer_range)), + bias_attr=None, + topo=topo)) + + Decoder = TransformerDecoder + + self.decoder = Decoder( + decoder_layers, + num_hidden_layers, + norm="LayerNorm", + hidden_size=hidden_size, + topo=topo) + + self.checkpoints = [] + + def forward(self, + input_ids, + position_ids=None, + attention_mask=None, + use_cache=False, + cache=None): + self.checkpoints = [] + if attention_mask is None: + length = paddle.shape(input_ids)[1] + # Use bool mask + attention_mask = paddle.tensor.tril( + paddle.ones( + (length, length), + dtype=self.embeddings.word_embeddings.weight.dtype)) + if position_ids is None: + past_length = 0 + if cache is not None: + past_length = paddle.shape(cache[0].k)[-2] + position_ids = paddle.arange( + past_length, + paddle.shape(input_ids)[-1] + past_length, + dtype='int64') + position_ids = position_ids.unsqueeze(0) + # .expand_as(input_ids) + position_ids = paddle.fluid.layers.expand_as(position_ids, + input_ids) + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids) + + # TODO, use registered buffer + causal_mask = paddle.tensor.triu( + paddle.ones((paddle.shape(input_ids)[-1], + paddle.shape(input_ids)[-1])) * -1e9, + diagonal=1) + + if attention_mask is not None: + 
            attention_mask = attention_mask + causal_mask
+        else:
+            attention_mask = causal_mask
+
+        # The tensor returned by triu is not in the static graph.
+        attention_mask.stop_gradient = True
+
+        encoder_outputs = self.decoder(
+            embedding_output,
+            memory=None,
+            tgt_mask=attention_mask,
+            use_cache=use_cache,
+            cache=cache)
+        self.checkpoints.extend(self.decoder.checkpoints)
+        return encoder_outputs
+
+
+class GPTForPretraining(nn.Layer):
+    """
+    The pretraining model of GPT.
+    It returns the logits and the cached_kvs.
+    """
+
+    def __init__(self, gpt):
+        super(GPTForPretraining, self).__init__()
+        self.gpt = gpt
+        self.share_param = False
+        self.weight = self.gpt.embeddings.word_embeddings.weight
+        if not self.share_param:
+            self.weight = self.create_parameter(shape=self.weight.shape)
+
+    def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
+        if topo is not None and topo.mp_info.size > 1:
+            input_parallel = paddle.distributed.collective._c_identity(
+                lm_output, group=None)
+
+            logits = paddle.matmul(
+                input_parallel, logit_weights, transpose_y=True)
+
+            if parallel_output:
+                return logits
+
+            return paddle.distributed.collective._c_concat(logits, group=None)
+        else:
+            logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
+            return logits
+
+    def forward(self,
+                input_ids,
+                position_ids=None,
+                attention_mask=None,
+                masked_positions=None,
+                use_cache=False,
+                cache=None):
+        outputs = self.gpt(input_ids,
+                           position_ids=position_ids,
+                           attention_mask=attention_mask,
+                           use_cache=use_cache,
+                           cache=cache)
+        if use_cache:
+            encoder_outputs, cached_kvs = outputs[:2]
+        else:
+            encoder_outputs = outputs
+        logits = self.parallel_matmul(encoder_outputs, self.weight, True,
+                                      self.gpt.topo)
+
+        if use_cache:
+            return logits, cached_kvs
+        else:
+            return logits
+
+
+class GPTPretrainingCriterion(nn.Layer):
+    """
+    Criterion for GPT.
+    It calculates the final loss.
+ """ + + def __init__(self, topo=None): + super(GPTPretrainingCriterion, self).__init__() + if topo is None or topo.mp_info.size == 1: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") + else: + self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy + + def forward(self, prediction_scores, masked_lm_labels, loss_mask): + masked_lm_loss = self.loss_func(prediction_scores, + masked_lm_labels.unsqueeze(2)) + + loss_mask = loss_mask.reshape([-1]) + masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) + loss = masked_lm_loss / loss_mask.sum() + return loss + + +def gpt_pretrain_forward(train_program, start_program): + with static.program_guard(train_program, + start_program), utils.unique_name.guard(): + batch_size = 16 + sequence_len = 512 + input_ids = static.data( + name="input_ids", shape=[batch_size, sequence_len], dtype='int64') + position_ids = static.data( + name="position_ids", + shape=[batch_size, sequence_len], + dtype='int64') + attention_mask = static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float64') + labels = static.data( + name="labels", shape=[batch_size, sequence_len], dtype='int64') + loss_mask = static.data( + name="loss_mask", shape=[batch_size, sequence_len], dtype='float64') + + if _global_parallel_stratergy == "dp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + elif _global_parallel_stratergy == "dp_mp": + auto.shard_tensor( + input_ids, _global_process_mesh, dim_mapping=[0, -1]) + + gpt = GPTModel( + vocab_size=32768, + hidden_size=1024, + num_hidden_layers=2, + num_attention_heads=16, + intermediate_size=4096, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=1024, + type_vocab_size=16, + initializer_range=0.02, + pad_token_id=0, + topo=None) + + model = GPTForPretraining(gpt) + + preds = model(input_ids, position_ids, attention_mask) + + criterion = GPTPretrainingCriterion() + + loss = criterion(preds, labels, loss_mask) + + return train_program, start_program + + +class TestGPTAutoCompletion(unittest.TestCase): + def test_gpt_dp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "dp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = gpt_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_gpt_mp(self): + global _global_parallel_stratergy + _global_parallel_stratergy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[0, 1, 2, 3], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = gpt_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + def test_gpt_dp_mp(self): + global 
_global_parallel_stratergy + _global_parallel_stratergy = "dp_mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH) + + train_program = static.Program() + start_program = static.Program() + dist_context = DistributedContext() + train_program, start_program = gpt_pretrain_forward(train_program, + start_program) + complete_train_program = auto.complete_annotation(train_program, + dist_context) + # print_program_with_distributed_attr(complete_train_program, + # dist_context) + self.assertTrue( + check_distributed_attr_for_program(complete_train_program, + dist_context)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 07cf4c3a252df8ebf35d4645bab19aea78c5d3e4..499054492694881237dff771b7ad49310931a48b 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -165,6 +165,7 @@ packages=['paddle', 'paddle.distributed.fleet.meta_parallel.pp_utils', 'paddle.distributed.fleet.meta_parallel.parallel_layers', 'paddle.distributed.auto_parallel', + 'paddle.distributed.auto_parallel.operators', 'paddle.framework', 'paddle.jit', 'paddle.jit.dy2static',