Unverified commit 93d862b0, authored by Yulong Ao, committed by GitHub

Add auto completion module for auto parallel (#34813)

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* add dist

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update

* update

* delete unused proto

* restore op_desc

* restore type_defs

* update var_desc

* remove dims_mapping from proto_pybind

* update interface.py

* update framework.py

* update

* update

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* [WIP] Add the auto completion feature and related codes

* [WIP] Improve the auto completion and related codes

* [WIP] Make the auto completion to support data-parallel

* [WIP] Make the completion support mp and dp+mp

* [WIP] Refactor auto completion unit test for MLP

* [WIP] Refactor the implementation of DistributedOperatorImpl

* [WIP] Improve dims_mapping update rule and fix a bug

* [WIP] Support auto completion for one transformer decoder layer

* [WIP] Add a minor change

* [WIP] Fix a bug within the unit test

* Shard XShape tensor, add embedding completion and refactor code

* Add the distributed_operators dir to setup.py.in

* Improve the completion process and add the unittest for gpt

* fix process_mesh ut

* fix process_mesh ut

* update

* update, test=develop

* Add support for automatically completing distributed attrs of special ops

* update

* update

* update

* fix doc sample codes, test=develop

* improve coverage, test=develop

* add static_mode check, test=develop

* Model the cluster for cost model and physical mapping

* update, test=develop

* add set_placement, test=develop

* Add the check to make sure the candidate tensors' size is greater than zero

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update, test=develop

* Auto mark dist attrs annotated by user

* update ndarray to nested list, test=develop

* update, test=develop

* Add auto-completion module for auto-parallel (based on PR#33804)

* Remove unnecessary files

* Remove unrelated files for the auto completion pr

* Update the unit test to improve the coverage

* Modify codes based on reviews

* Minor changes for CI

* Improve some codes based on new comments

* Fix bugs caused by shallow copy in attributes.py
* Improve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments
Co-authored-by: sandyhouse <lilong12@baidu.com>
Parent commit: e8f146a9
@@ -353,6 +353,14 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
  need_update_ = true;
  // When creating a graph from a program, the creation of an op node will
  // create a new OpDesc instead of referring to the original one. To find the
  // original OpDesc of the op node, the id has to be copied to the new OpDesc.
  // The var node has the same situation, but the default copy constructor can
  // copy the id automatically.
id_ = op_desc.id_;
}
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <utility>
@@ -151,6 +152,18 @@ class OpDesc {
  const BlockDesc *Block() const { return this->block_; }
  // This thread-safe implementation seems to be redundant since the neural
  // networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> id{0};
return ++id;
}
  // Note: the identity is only used as a key for referring to its
  // distributed attribute now.
uint64_t Id() { return id_; }
 private:
  template <typename MapType>
  static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
@@ -173,6 +186,8 @@ class OpDesc {
  // need_update_ indicates there are some local changes not yet synchronized.
  // If local changes should be synchronized, need_update_ should be set to true.
  bool need_update_{false};
uint64_t id_ = GenerateId();
};
} // namespace framework
} // namespace paddle
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <atomic>
#include <string>
#include <vector>
@@ -150,6 +151,17 @@ class VarDesc {
  Attribute GetAttr(const std::string &name) const;
  // This thread-safe implementation seems to be redundant since the neural
  // networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
  // Note: the identity is only used as a key for referring to its
  // distributed attribute now.
uint64_t Id() { return id_; }
 private:
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
@@ -158,6 +170,7 @@ class VarDesc {
  proto::VarDesc desc_;
  AttributeMap attrs_;
uint64_t id_ = GenerateId();
};
bool operator==(const VarDesc &left, const VarDesc &right);
...
@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle {
@@ -202,6 +201,7 @@ void BindVarDsec(pybind11::module *m) {
      .def("attr_names", &pd::VarDesc::AttrNames)
      .def("_set_attr", &pd::VarDesc::SetAttr)
      .def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("id", &pd::VarDesc::Id)
.def("attr", &pd::VarDesc::GetAttr); .def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", ""); pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
...@@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) {
.def("serialize_to_string", SerializeMessage<pd::OpDesc>) .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", [](pd::OpDesc &self) { return self.Block(); }, .def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("id", &pd::OpDesc::Id)
.def("inputs", &pd::OpDesc::Inputs) .def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs); .def("outputs", &pd::OpDesc::Outputs);
} }
......
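The new id() bindings expose the C++-generated identities to Python, which is what the auto-completion code keys its attribute maps on. A minimal sketch of querying them, assuming a static-graph program has already been built (the layer shapes below are only illustrative):

import paddle

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
    y = paddle.static.nn.fc(x, size=4)

block = main_program.global_block()
# Every VarDesc and OpDesc now carries a unique id generated on the C++ side.
for var in block.vars.values():
    print(var.name, var.desc.id())
for op in block.ops:
    print(op.type, op.desc.id())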
@@ -57,7 +57,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv  # noqa: F401
from . import cloud_utils  # noqa: F401
from . import utils  # noqa: F401
__all__ = [ #noqa
__all__ = [ # noqa
"spawn", "spawn",
"scatter", "scatter",
"broadcast", "broadcast",
......
...@@ -18,5 +18,6 @@ from .interface import set_shard_mask # noqa: F401 ...@@ -18,5 +18,6 @@ from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401 from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401 from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401 from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
__all__ = [] __all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
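For orientation, a minimal sketch of how a TensorDistributedAttribute might be filled in by hand; the variable x and the DistributedContext ctx are assumed to exist, and in normal use initialize_distributed_attr_for_program builds these objects automatically:

# Hypothetical usage; normally created and populated by DistributedContext.
tensor_dist_attr = TensorDistributedAttribute(x, ctx)  # x: a static-graph Variable
tensor_dist_attr.set_shape(x.desc.shape())
# Shard the first dimension along mesh axis 0, replicate the second one.
tensor_dist_attr.set_dims_mapping([0, -1])
tensor_dist_attr.mark_as_annotated("dims_mapping")
print(tensor_dist_attr)  # __str__ reports the annotated/non-annotated state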
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
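Similarly, a hedged sketch of the per-argument bookkeeping on the operator side; op and ctx are assumed to exist, and the argument name "X" is only illustrative:

# Hypothetical usage; in practice DistributedContext fills these in.
op_dist_attr = OperatorDistributedAttribute(op, ctx)  # op: a static-graph Operator
op_dist_attr.set_input_shape("X", [8, 16])
op_dist_attr.set_input_dims_mapping("X", [0, -1])     # stored under the key "IN_X"
op_dist_attr.mark_as_annotated_input_dims_mapping("X")
assert op_dist_attr.get_input_dims_mapping("X") == [0, -1]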
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from paddle.fluid import core
from paddle.fluid import framework
from .utils import compute_compatible_process_mesh
from .utils import compute_compatible_dim_mapping
from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_distributed_attr
from .context import get_default_distributed_context
from .operators import find_best_compatible_distributed_operator_impl
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
def is_elementwise_like_op(op_type):
if op_type in ELEMENTWISE_LIKE_OP_LIST:
return True
else:
return False
def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
"""
Update tensor's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"):
return changed
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
if fwd:
inputs_process_meshes = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
else:
outputs_process_meshes = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
return changed
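The docstring's rule, that only equal process meshes are compatible for now, can be pictured with a small stand-in for compute_compatible_process_mesh from utils.py; this sketch only illustrates the stated behaviour and is not the actual helper:

def _compatible_process_mesh_sketch(process_mesh_list):
    # Return the common mesh if all non-None meshes are equal, otherwise None.
    result = None
    for mesh in process_mesh_list:
        if mesh is None:
            continue
        if result is None:
            result = mesh
        elif mesh != result:
            return None
    return result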
def update_op_node_process_mesh(dist_context, op_node, fwd=True):
"""
Update op's process mesh by using its predecessor's process mesh if in the forward direction,
and by using its successor's process mesh if in the backward direction. Note: only the equal
process meshes are compatible for now.
"""
changed = False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"):
return changed
op_process_mesh = op_dist_attr.get_process_mesh()
if fwd:
inputs_process_meshes = []
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
else:
outputs_process_meshes = []
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
return changed
def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
"""Each operator has a default distributed operator, only allowed to be sharded in batch dimension."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
    # The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
else:
            assert dims_mapping[0] == -1, \
                "{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension 0 is sharded by {} part."\
                    .format(op_desc.type(), dims_mapping[0])
if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
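A small worked example of the default rule above, written with plain lists; the one-liner used as the resolver is only a stand-in, under the assumption (suggested by its usage here) that compute_compatible_dim_mapping resolves -1 against a concrete mapping:

# Suppose a default (batch-only sharded) op has:
#   input  "X":   dims_mapping [0, -1, -1]   -> batch dim mapped to mesh axis 0
#   input  "Y":   dims_mapping [-1, -1, -1]  -> batch dim unresolved
#   output "Out": dims_mapping [-1, -1, -1]
# The collected batch dim mappings are [0, -1, -1]; the compatible value is 0,
# so after the update every non-XShape argument becomes [0, -1, -1].
batch_dim_mappings = [0, -1, -1]
compatible = next((m for m in batch_dim_mappings if m != -1), -1)
assert compatible == 0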
def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr):
"""Element-wise operator can be sharded in any way (but should take care of broadcasting)."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
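The broadcasting alignment above can be traced on a concrete case (comments only, no framework calls):

# elementwise_add with x: [B, S, H] and bias: [H]
#   x    dims_mapping: [0, -1, -1]   (3-D, batch sharded on mesh axis 0)
#   bias dims_mapping: [-1]          (1-D)
# bias is right-aligned to length 3 -> [-1, -1, -1]
# the compatible mapping over {x, bias, out} -> [0, -1, -1]
# written back: x stays [0, -1, -1], out becomes [0, -1, -1],
# and bias keeps only its trailing element -> [-1]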
def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
changed = False
if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False
tensor_desc = tensor_node.var()
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
if fwd:
dims_mapping_list = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
changed = True
else:
dims_mapping_list = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
changed = True
return changed
def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping(tensor_desc.name(),
compatible_dims_mapping)
changed = True
        # Find the most compatible implementation from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
if dim_changed:
changed = True
        # This statement will be replaced by a better way
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_output_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
        # Find the most compatible implementation from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
if dim_changed:
changed = True
        # This statement will be replaced by a better way
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
return changed
def complete_annotation(program, dist_context=None):
""" Complete annotation for the partial annotated program.
Arguments:
program: partial annotated program.
dist_context: the distributed context is used to store distributed attributes for program.
If not provided, the default one will be used.
Returns:
program: completed annotated program.
"""
# Use the default distribted context for completeion if there is no one
if dist_context is None:
dist_context = get_default_distributed_context()
# Initialize distributed attributes for all var and op node in program
dist_context.initialize_distributed_attr_for_program(program)
# print_program_with_distributed_attr(program, dist_context)
# Convert program to graph
graph = framework.IrGraph(core.Graph(program.desc))
# Initialize distributed attributes for all var and op node in graph
dist_context.initialize_distributed_attr_for_graph(graph)
    # Complete process mesh for each node
all_nodes = list(graph.all_nodes())
reach_fix_point = False
while not reach_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
reach_fix_point = False
else:
reach_fix_point = True
# Complete dims_mapping for each node
reach_fix_point = False
while not reach_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
reach_fix_point = False
else:
reach_fix_point = True
# Copy the corresponding distributed attribute from graph to program
dist_context.copy_distribute_attr_from_graph_to_program(graph, program)
dist_context.clear_distributed_attr_for_graph()
# Do the validation check and amend some completion
dist_context.amend_distributed_attr_for_program()
return program
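A hedged end-to-end sketch of calling the completion pass; the model code is illustrative, and the sharding annotations themselves would come from the shard_xx APIs in interface.py, which are omitted here:

import paddle
from paddle.distributed.auto_parallel import complete_annotation
from paddle.distributed.auto_parallel.context import get_default_distributed_context

paddle.enable_static()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data(name="x", shape=[8, 16], dtype="float32")
    out = paddle.static.nn.fc(x, size=4)

# Annotate some tensors/ops with process meshes and dims mappings via the
# interface APIs here (omitted), then propagate the partial annotations
# through the whole program.
dist_context = get_default_distributed_context()
completed_program = complete_annotation(main_program, dist_context)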
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
# There always exists a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
def get_default_distributed_context():
global DEFAULT_DISTRIBUTED_CONTEXT
if DEFAULT_DISTRIBUTED_CONTEXT is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return DEFAULT_DISTRIBUTED_CONTEXT
def set_default_distributed_context(dist_context):
global DEFAULT_DISTRIBUTED_CONTEXT
DEFAULT_DISTRIBUTED_CONTEXT = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self):
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._tensor_distributed_attr_map_for_program = {}
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
def is_initialized_for_program(self):
return self._is_initialized_for_program
def is_initialized_for_graph(self):
return self._is_initialized_for_graph
def get_tensor_distributed_attr_for_program(self, tensor):
tensor_id = tensor.desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get(
tensor_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr):
tensor_id = tensor.desc.id()
self._tensor_distributed_attr_map_for_program[
tensor_id] = tensor_dist_attr
def get_op_distributed_attr_for_program(self, op):
op_id = op.desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_program(self, op, op_dist_attr):
op_id = op.desc.id()
self._op_distributed_attr_map_for_program[op_id] = op_dist_attr
def get_tensor_distributed_attr_for_graph(self, tensor_node):
tensor_node_id = tensor_node.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get(
tensor_node_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_graph(self, tensor_node,
tensor_dist_attr):
tensor_node_id = tensor_node.id()
self._tensor_distributed_attr_map_for_graph[
tensor_node_id] = tensor_dist_attr
def get_op_distributed_attr_for_graph(self, op_node):
op_node_id = op_node.id()
op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
for block in program.blocks:
for tensor in block.vars.values():
# Since only tensors have distributed attributes, it's better to make sure var is a tensor
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is None:
tensor_dist_attr = TensorDistributedAttribute(tensor, self)
self._copy_distributed_attr_from_tensor_desc(
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
tensor_dist_attr.mark_as_annotated("dims_mapping")
if isinstance(tensor, framework.Parameter):
tensor_dist_attr.mark_as_parameter()
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
op_dist_attr = OperatorDistributedAttribute(op, self)
self._copy_distributed_attr_from_op_desc(op.desc,
op_dist_attr)
self.set_op_distributed_attr_for_program(op, op_dist_attr)
# Default distributed implementation for all operators
                    # This will be updated during the completion process
op_dist_attr.set_impl_idx(-2)
if op_dist_attr.get_process_mesh() is not None:
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_input_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_output_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
self._is_initialized_for_program = True
def finalize_distributed_attr_for_program(self, program):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before finalization."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is not None:
self._store_distributed_attr_to_tensor_desc(
tensor.desc, tensor_dist_attr)
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
self._store_distributed_attr_to_op_desc(op.desc,
op_dist_attr)
def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
attr_name = append_distributed_attr_suffix("dim_mapping")
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_dims_mapping(copied_dims_mapping)
attr_name = append_distributed_attr_suffix("mask")
if desc.has_attr(attr_name):
shard_mask = desc.attr(attr_name)
copied_shard_mask = copy.deepcopy(shard_mask)
dist_attr.set_shard_mask(copied_shard_mask)
attr_name = append_distributed_attr_suffix("offload_device")
if desc.has_attr(attr_name):
offload_device = desc.attr(attr_name)
copied_offload_device = copy.deepcopy(offload_device)
dist_attr.set_offload_device(copied_offload_device)
def _copy_distributed_attr_from_op_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
                dist_attr.set_output_dims_mapping(tensor_name,
                                                  copied_dims_mapping)
attr_name = append_distributed_attr_suffix("pipeline_stage")
if desc.has_attr(attr_name):
pipeline_stage = desc.attr(attr_name)
copied_pipeline_stage = copy.deepcopy(pipeline_stage)
dist_attr.set_pipeline_stage(copied_pipeline_stage)
def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
dims_mapping = dist_attr.get_dims_mapping()
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("dim_mapping")
desc._set_attr(attr_name, dims_mapping)
shard_mask = dist_attr.get_shard_mask()
if shard_mask is not None:
attr_name = append_distributed_attr_suffix("mask")
desc._set_attr(attr_name, shard_mask)
offload_device = dist_attr.get_offload_device()
if offload_device is not None:
attr_name = append_distributed_attr_suffix("offload_device")
desc._set_attr(attr_name, offload_device)
def _store_distributed_attr_to_op_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
for tensor_name in desc.input_arg_names():
dims_mapping = dist_attr.get_input_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
for tensor_name in desc.output_arg_names():
dims_mapping = dist_attr.get_output_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
pipeline_stage = dist_attr.get_pipeline_stage()
if pipeline_stage is not None:
attr_name = append_distributed_attr_suffix("pipeline_stage")
desc._set_attr(attr_name, pipeline_stage)
def initialize_distributed_attr_for_graph(self, graph):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before its graph."
if self._is_initialized_for_graph:
return
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program[
tensor_id]
assert tensor_dist_attr is not None, \
"Tensor must have a distributed attribute after the initialization for program."
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self.set_tensor_distributed_attr_for_graph(node,
new_tensor_dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program[op_id]
assert op_dist_attr is not None, \
"Operator must have a distributed attribute after the initialization for program."
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self.set_op_distributed_attr_for_graph(node, new_op_dist_attr)
self._is_initialized_for_graph = True
def clear_distributed_attr_for_program(self):
self._tensor_distributed_attr_map_for_program.clear()
self._op_distributed_attr_map_for_program.clear()
def clear_distributed_attr_for_graph(self):
self._tensor_distributed_attr_map_for_graph.clear()
self._op_distributed_attr_map_for_graph.clear()
def copy_distribute_attr_from_graph_to_program(self, graph, program):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"The distribute attributes must be initialized both in its program and graph"
updated_tensors = {}
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
                # If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr = self.get_tensor_distributed_attr_for_graph(
node)
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self._tensor_distributed_attr_map_for_program[
tensor_id] = new_tensor_dist_attr
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self.get_op_distributed_attr_for_graph(node)
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self._op_distributed_attr_map_for_program[
op_id] = new_op_dist_attr
def amend_distributed_attr_for_program(self):
for attr in self._tensor_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Tensor's distributed attribute {} is not valid".format(attr)
tensor_shape = attr.get_shape()
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
            # If the size of a tensor dimension is smaller than the size of the
            # process mesh dimension it is mapped to, amend the mapping to -1,
            # i.e. replicate along that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[
i]] > tensor_shape[i]:
dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Operator's distributed attribute {} is not valid".format(attr)
for arg_name in attr.get_owner_op().desc.input_arg_names():
tensor_shape = attr.get_input_shape(arg_name)
dims_mapping = attr.get_input_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
                # If the size of a tensor dimension is smaller than the size of the
                # process mesh dimension it is mapped to, amend the mapping to -1,
                # i.e. replicate along that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names():
tensor_shape = attr.get_output_shape(arg_name)
dims_mapping = attr.get_output_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
                # If the size of a tensor dimension is smaller than the size of the
                # process mesh dimension it is mapped to, amend the mapping to -1,
                # i.e. replicate along that dimension. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
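A tiny numeric illustration of the amendment rule implemented above:

# tensor_shape = [4, 1024], process_mesh.topology = [8, 2]
# dims_mapping = [0, 1]: dimension 0 of the tensor (size 4) is mapped to mesh
# axis 0 (size 8); since 8 > 4 the mapping is amended to -1 (replicated),
# while dimension 1 keeps its mapping, giving [-1, 1].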
@@ -13,8 +13,9 @@
# limitations under the License.
import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
@@ -237,6 +238,23 @@ class ProcessMesh(object):
    def __ne__(self, other):
        return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
            # No need to copy the underlying desc
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def _dim_mapping_checker(tensor, mesh, dim_mapping):
    assert len(tensor.shape) == len(dim_mapping)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
DISTRIBUTED_OPERATORS = {}
class DistributedOperator:
def __init__(self):
self._impls = []
self._name = None
def register_impl(self, dist_impl):
self._impls.append(dist_impl)
def get_impl(self, impl_idx):
return self._impls[impl_idx]
def get_impls(self):
return self._impls
class DistributedOperatorImpl:
def __init__(self):
self._name = None
def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
def backward(self, dist_ctx, *grad_outputs):
raise NotImplementedError("Please Implement this method in Subclass.")
def get_name(self):
return self._name
def is_process_mesh_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op):
global DISTRIBUTED_OPERATORS
DISTRIBUTED_OPERATORS[name] = dist_op
def get_distributed_operator(name):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS.get(name, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name)
if dist_op is not None:
dist_op.register_impl(dist_impl)
else:
assert False, "Must register distributed operator first."
def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
fwd=True):
"""
    For now, just return the first compatible implementation.
    This will be improved by the cost model in the future.
"""
dist_op = get_distributed_operator(name)
if dist_op is None:
return None, -1
compatible_impls = []
impls = dist_op.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_input_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
else:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_output_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
if compatible_impls:
best_compatible_impl, idx = compatible_impls[0]
else:
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
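Given the registry above, selecting an implementation is a lookup by op type plus a compatibility scan; a hedged usage sketch, where op_dist_attr is assumed to be an OperatorDistributedAttribute prepared by the completion pass:

from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl

impl, impl_idx = find_best_compatible_distributed_operator_impl(
    "lookup_table_v2", op_dist_attr, fwd=True)
if impl is not None and impl.is_compatible(op_dist_attr):
    op_dist_attr.set_impl_idx(impl_idx)  # e.g. the row-parallel embedding impl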
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedEmbedding(DistributedOperator):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
register_distributed_operator("lookup_table_v2",
DistributedEmbedding("embedding"))
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
        # Other dimensions must be replicated except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
        # Other dimensions must be replicated except the batch dimension
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(ids_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[ids_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[w_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
def _update_dims_mapping_for_matmul(op_dist_attr):
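    """
    Infer compatible dims mappings for X, Y and Out of a matmul-like op and update
    them in place, returning whether anything changed. A worked example (sketch):
    with X mapped as [0, -1, -1], Y as [-1, -1] and Out as [-1, -1, -1], the batch
    dimension mapping 0 of X is propagated to Out, which becomes [0, -1, -1], and
    the function returns True.
    """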
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
    # Add a dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
if y_dims_mapping_len == 1:
y_dims_mapping.insert(1, -1)
# Deal with dim > 2 and take care of broadcasting
if out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(out_dims_mapping_len - x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(out_dims_mapping_len - y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
x_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - y_dims_mapping_len)
if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
y_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(out_dims_mapping_len - 2):
if out_dims_mapping[i] != compatible_dims_mapping[i]:
out_dims_mapping[i] = compatible_dims_mapping[i]
changed = True
    # The following updates use negative indices, so they work both
    # when len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, y_dims_mapping], [-1, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [-2, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[y_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
    # Remove the padded dim mapping to make sure the length of dims_mapping matches its tensor
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len
return changed
class DistributedMatmul(DistributedOperator):
def __init__(self, name):
super(DistributedMatmul, self).__init__()
self._name = name
register_distributed_operator("matmul", DistributedMatmul("matmul"))
# ColumnParallel
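# The weight Y is sharded along its column (output) dimension: the checks below
# require Y's dim 0 to be replicated and dim 1 to be sharded, X's last dimension
# to be replicated, and Out's last dimension to be sharded accordingly.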
class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# RowParallel
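# The weight Y is sharded along its row (input) dimension: the checks below require
# the last dimension of X and dim -2 of Y to be sharded, while the last dimensions
# of Y and Out must be replicated.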
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
        # Other dimensions must be replicated except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
        # Other dimensions must be replicated except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# ReplicateParallel
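# Neither X nor Y is sharded on its two innermost dimensions, so the matmul can be
# computed locally on replicated data without extra communication.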
class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("matmul",
DistributedMatmulImpl0("column_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl1("row_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel"))
class DistributedMatmulV2(DistributedOperator):
def __init__(self, name):
super(DistributedMatmulV2, self).__init__()
self._name = name
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2"))
# ReplicateParallel
class DistributedMatmulV2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl("replicate_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedReshape2(DistributedOperator):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
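# Impl1 ("remove_one_dim_back"): the input has exactly one more dimension than the
# output, and the removed (last) input dimension must stay replicated.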
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back"))
register_distributed_operator_impl(
"reshape2", DistributedReshapeImpl1("remove_one_dim_back"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[axis]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"softmax", DistributedSoftmaxImpl("replicate_last_axis"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
return True
def is_output_compatible(self, op_dist_attr):
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
perm = op_desc.attr('axis')
assert len(x_dims_mapping) == len(perm)
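        # new_dims_mapping[i] is the mapping of the i-th output dimension, obtained
        # by permuting the input dims mapping with `perm`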
new_dims_mapping = [-1 for i in range(len(x_dims_mapping))]
for i in range(len(x_dims_mapping)):
new_dims_mapping[i] = x_dims_mapping[perm[i]]
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[new_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
if x_dims_mapping[perm[i]] != new_dims_mapping[i]:
x_dims_mapping[perm[i]] = new_dims_mapping[i]
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl(
"transpose2", DistributedTranspose2Impl("same_mapping_transpose"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import threading
import paddle.fluid.core as core
def is_valid_list_index(list, index):
if index >= -len(list) and index < len(list):
return True
else:
return False
def is_dim_shard(mapping):
if mapping != -1:
return True
else:
return False
def is_dim_replicate(mapping):
if mapping == -1:
return True
else:
return False
def compute_compatible_dim_mapping(dim_mappings):
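    """
    Compute a single dim mapping compatible with all the given mappings: -1
    (replicated) is compatible with anything, equal mappings are compatible, and
    any other conflict yields None. E.g. [-1, 0, -1] -> 0, while [0, 1] -> None.
    """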
if not dim_mappings:
return None
compatible_mapping = dim_mappings[0]
for mapping in dim_mappings:
if compatible_mapping == -1:
compatible_mapping = mapping
elif mapping == -1:
continue
elif compatible_mapping == mapping:
continue
else:
return None
return compatible_mapping
def compute_compatible_dims_mapping(dims_mapping_list):
if not dims_mapping_list:
return None
length = len(dims_mapping_list[0])
for dims_mapping in dims_mapping_list:
assert dims_mapping is not None, \
"Dims mapping must not be None for compatible computation"
assert len(dims_mapping) == length, \
"The length of dims_mapping in list must be same for compatible computation."
compatible_result = []
for dim_mappings in zip(*dims_mapping_list):
compatible_dim_mapping = compute_compatible_dim_mapping(
list(dim_mappings))
if compatible_dim_mapping is None:
return None
compatible_result.append(compatible_dim_mapping)
return compatible_result
def compute_compatible_process_mesh(process_mesh_list):
compatible_process_mesh = None
if not process_mesh_list:
return compatible_process_mesh
for process_mesh in process_mesh_list:
if process_mesh is not None:
if compatible_process_mesh is None:
compatible_process_mesh = process_mesh
else:
assert process_mesh == compatible_process_mesh, \
"There is no compatible process mesh."
return compatible_process_mesh
def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list):
assert len(dims_mapping_list) == len(index_list)
changed = False
dim_mappings = []
for i in range(len(dims_mapping_list)):
assert is_valid_list_index(dims_mapping_list[i], index_list[i])
dim_mappings.append(dims_mapping_list[i][index_list[i]])
compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings)
if compatible_dim_mapping is None:
return False
for i in range(len(dims_mapping_list)):
if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]:
dims_mapping_list[i][index_list[i]] = compatible_dim_mapping
changed = True
return changed
def append_distributed_attr_suffix(name):
    """
    Append the auto parallel suffix to a distributed attribute name.
    """
    return name + core.kAutoParallelSuffix()
def remove_distributed_attr_suffix(name):
    """
    Remove the auto parallel suffix from a distributed attribute name.
    """
    # str.strip would drop a set of characters rather than the trailing suffix itself
    suffix = core.kAutoParallelSuffix()
    return name[:len(name) - len(suffix)] if name.endswith(suffix) else name
def check_distributed_attr_for_program(program, dist_context=None):
from .context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor)
if (tensor_dist_attr is not None) and (
not tensor_dist_attr.is_valid()):
return False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if (op_dist_attr is not None) and (not op_dist_attr.is_valid()):
return False
return True
def print_program_with_distributed_attr(program, dist_context=None):
"""
This function reuses the original program output ability with a distributed context.
Using lock can avoid multiple threads change the default distributed context simultaneously.
"""
lock = threading.Lock()
lock.acquire()
from .context import get_default_distributed_context
from .context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
else:
original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context)
print(program)
set_default_distributed_context(original_default_context)
lock.release()
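# Illustrative usage sketch (assuming `train_program` is a static Program whose
# distributed attributes have been completed into `dist_context`):
#     print_program_with_distributed_attr(train_program, dist_context)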
...
@@ -1224,6 +1224,14 @@ class Variable(object):
        if self.persistable:
            var_str = "persist " + var_str
        from paddle.distributed.auto_parallel.context import get_default_distributed_context
        dist_context = get_default_distributed_context()
        var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
            self)
        if var_dist_attr is not None:
            var_str += ", {name} = {value}".format(
                name="dist_attr", value=var_dist_attr)
        return var_str
    def to_string(self, throw_on_error, with_details=False):
@@ -2384,6 +2392,13 @@ class Operator(object):
            if i != len(attr_names) - 1:
                attrs_str += ", "
        from paddle.distributed.auto_parallel.context import get_default_distributed_context
        dist_context = get_default_distributed_context()
        op_dist_attr = dist_context.get_op_distributed_attr_for_program(self)
        if op_dist_attr is not None:
            attrs_str += ", {name} = {value}".format(
                name="dist_attr", value=op_dist_attr)
        if outputs_str != "{}":
            op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
                format(outputs=outputs_str, op_type=self.type,
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import unittest.mock
from io import StringIO
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
paddle.enable_static()
_global_parallel_stratergy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
out = mlp(input)
return train_program, start_program
class TestMLPAutoCompletion(unittest.TestCase):
def test_mlp_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_misc(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_context.finalize_distributed_attr_for_program(
complete_train_program)
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
for block in complete_train_program.blocks:
for tensor in block.vars.values():
desc = tensor.desc
attr_name = append_distributed_attr_suffix("mesh_id")
self.assertIsNotNone(desc.has_attr(attr_name))
attr_name = append_distributed_attr_suffix("dim_mapping")
self.assertIsNotNone(desc.has_attr(attr_name))
for op in block.ops:
desc = op.desc
attr_name = append_distributed_attr_suffix("mesh_id")
self.assertIsNotNone(desc.has_attr(attr_name))
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" +
tensor_name)
self.assertIsNotNone(desc.has_attr(attr_name))
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" +
tensor_name)
self.assertIsNotNone(desc.has_attr(attr_name))
set_default_distributed_context(dist_context)
self.assertTrue("dist_attr" in str(complete_train_program))
with unittest.mock.patch(
"sys.stdout", new_callable=StringIO) as mock_stdout:
print_program_with_distributed_attr(complete_train_program)
self.assertIsNotNone(mock_stdout.getvalue())
class AttentionLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
sequence_len=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(AttentionLayer, self).__init__()
self.hidden_size = hidden_size
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(input)
v = self.v_proj(input)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
return out
def attn_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="query",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
attn = AttentionLayer(
hidden_size=hidden_size,
sequence_len=sequence_len,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = attn(input)
return train_program, start_program
class TestAttentionAutoCompletion(unittest.TestCase):
def test_attn_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
class DecoderLayer(nn.Layer):
def __init__(self,
vocab_size=32768,
hidden_size=1024,
sequence_len=512,
max_position_embeddings=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(DecoderLayer, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.word_embeddings = nn.Embedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
self.position_embeddings = nn.Embedding(
self.max_position_embeddings,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
intermediate_size = 4 * self.hidden_size
d_model = self.hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(self.dropout_ratio)
self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
# Pre-norm
target = self.norm(embeddings)
# The following is the attention part
q = self.q_proj(target)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(target)
v = self.v_proj(target)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
# Add residual
residual = embeddings + self.dropout2(out)
# Pre-norm
out0 = self.norm(residual)
# The following is the MLP part
out1 = self.linear0(out0)
out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
# Add residual
final = residual + self.dropout3(out3)
return final
def decoder_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
decoder = DecoderLayer(
vocab_size=32768,
hidden_size=hidden_size,
sequence_len=sequence_len,
max_position_embeddings=512,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = decoder(input_ids, position_ids)
return train_program, start_program
class TestDecoderLayerAutoCompletion(unittest.TestCase):
def test_decoder_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import collections
import math
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
import paddle.utils as utils
from paddle.fluid import layers
from paddle.fluid.framework import in_dygraph_mode
from paddle.nn.layer.transformer import _convert_param_attr_to_list
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.context import DistributedContext
paddle.enable_static()
_global_parallel_stratergy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MultiHeadAttention(nn.Layer):
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces.
"""
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
kdim=None,
vdim=None,
need_weights=False,
weight_attr=None,
bias_attr=None,
topo=None,
fuse=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if topo is None or topo.mp_info.size == 1:
if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
return q, k, v
def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
r"""
Prapares linear projected queries, keys and values for usage of subsequnt
multiple parallel attention. If `cache` is not None, using cached results
to reduce redundant calculations.
"""
q = self.q_proj(query)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference with cached static key/value
k, v = cache.k, cache.v
else:
k, v = self.compute_kv(key, value)
if isinstance(cache, self.Cache):
# for decoder self-attention in inference
k = tensor.concat([cache.k, k], axis=2)
v = tensor.concat([cache.v, v], axis=2)
if use_cache is True:
cache = self.Cache(k, v)
return (q, k, v) if use_cache is False else (q, k, v, cache)
def compute_kv(self, key, value):
r"""
Applies linear projection on input keys and values, then splits heads
(reshape and transpose) to get keys and values from different representation
subspaces. The results are used as key-values pairs for subsequent multiple
parallel attention.
It is part of calculations in multi-head attention, and is provided as
a method to pre-compute and prefetch these results, thus we can use them
to construct cache for inference.
"""
k = self.k_proj(key)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
v = self.v_proj(value)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
return k, v
def gen_cache(self, key, value=None, type=Cache):
"""
Generates cache for `forward` usage in inference accroding to arguments.
The generated cache is an instance of `MultiHeadAttention.Cache` or an
instance of `MultiHeadAttention.StaticCache`.
"""
if type == MultiHeadAttention.StaticCache: # static_kv
k, v = self.compute_kv(key, value)
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
return self.Cache(k, v)
else:
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)
def forward(self,
query,
key,
value,
attn_mask=None,
use_cache=False,
cache=None):
r"""
Applies multi-head attention to map queries and a set of key-value pairs
to outputs.
"""
key = query if key is None else key
value = query if value is None else value
        # compute q, k, v
if use_cache is False:
if self.fuse:
q, k, v = self._fuse_prepare_qkv(query)
else:
q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
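        # out_proj is annotated as row parallel (its input axis is split along the
        # model-parallel mesh axis). In Megatron-style tensor parallelism such a
        # projection is followed by an allreduce of the partial sums; here only the
        # sharding is annotated and the communication is expected to be derived by
        # the auto-parallel framework.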
outs = [out]
if self.need_weights:
outs.append(weights)
if use_cache:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
class TransformerDecoder(nn.Layer):
"""
TransformerDecoder is a stack of N decoder layers.
"""
def __init__(self,
decoder_layers,
num_layers,
norm=None,
hidden_size=None,
topo=None):
super(TransformerDecoder, self).__init__()
self.topo = topo
self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
if norm is "LayerNorm":
self.norm = nn.LayerNorm(hidden_size)
elif norm is not None:
raise ValueError("Only support LayerNorm")
self.checkpoints = []
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
use_cache=False,
cache=None):
r"""
Applies a stack of N Transformer decoder layers on inputs. If `norm` is
provided, also applies layer normalization on the output of last decoder
layer.
"""
output = tgt
new_caches = []
self.checkpoints = []
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
new_caches.append(new_cache)
else:
output = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
else:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache[i])
new_caches.append(new_cache)
self.checkpoints.append(output.name)
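            # Record the name of every layer's output; these names can later be used
            # as segment boundaries for recompute (gradient checkpointing).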
if self.norm is not None:
output = self.norm(output)
return output if use_cache is False else (output, new_caches)
def gen_cache(self, memory, do_zip=False):
r"""
Generates cache for `forward` usage. The generated cache is a list, and
each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
for more details. If `do_zip` is True, apply `zip` on these tuples to get
a list with two elements.
"""
cache = [layer.gen_cache(memory) for layer in self.layers]
if do_zip:
cache = list(zip(*cache))
return cache
class TransformerDecoderLayer(nn.Layer):
"""
The transformer decoder layer.
    It contains multi-head attention and the feed-forward linear layers.
"""
def __init__(self,
d_model,
nhead,
dim_feedforward,
dropout=0.1,
activation="gelu",
attn_dropout=None,
act_dropout=None,
normalize_before=True,
weight_attr=None,
bias_attr=None,
topo=None):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
super(TransformerDecoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
topo=topo)
if topo is None or topo.mp_info.size == 1:
self.linear1 = nn.Linear(
d_model,
dim_feedforward,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
dim_feedforward,
d_model,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
use_cache, cache)
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
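        # linear1 is annotated column parallel and linear2 row parallel, the usual
        # Megatron-style FFN sharding: the intermediate activation stays split along
        # the model-parallel mesh axis and only linear2's output needs to be reduced
        # across that axis (a step left to the auto-parallel framework here).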
# tgt = self.dropout2(
# self.linear2(F.gelu(
# self.linear1(tgt), approximate=True)))
tgt = self.linear1(tgt)
tgt = F.gelu(tgt, approximate=True)
tgt = self.dropout2(self.linear2(tgt))
tgt = residual + tgt
if not self.normalize_before:
tgt = self.norm2(tgt)
return tgt if use_cache is False else (tgt, incremental_cache)
def gen_cache(self, memory):
incremental_cache = self.self_attn.gen_cache(
memory, type=self.self_attn.Cache)
return incremental_cache
class GPTEmbeddings(nn.Layer):
"""
    Includes word, position and token_type embeddings.
"""
def __init__(self,
vocab_size,
hidden_size=768,
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
topo=None):
super(GPTEmbeddings, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.word_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.position_embeddings = nn.Embedding(
max_position_embeddings,
hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids, position_ids=None):
if position_ids is None:
ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)
position_ids = seq_length - ones
        input_embeddings = self.word_embeddings(input_ids)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
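        # Sharding the embedding table along its vocabulary axis (mesh axis 0 under
        # "mp", axis 1 under "dp_mp") yields a vocab-parallel embedding: each rank
        # holds only a slice of the vocabulary rows.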
position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class GPTModel(nn.Layer):
"""
    The base model of GPT.
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None):
super(GPTModel, self).__init__()
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
self.topo = topo
self.hidden_size = hidden_size
self.vocab_size = vocab_size
        self.pipeline_mode = topo is not None and topo.pp_info.size > 1
        if self.pipeline_mode:
self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size
self.embeddings = GPTEmbeddings(
vocab_size, hidden_size, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range,
topo)
decoder_layers = nn.LayerList()
for i in range(num_hidden_layers):
DecoderLayer = TransformerDecoderLayer
decoder_layers.append(
DecoderLayer(
d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=hidden_dropout_prob,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None,
topo=topo))
Decoder = TransformerDecoder
self.decoder = Decoder(
decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size,
topo=topo)
self.checkpoints = []
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
use_cache=False,
cache=None):
self.checkpoints = []
if attention_mask is None:
length = paddle.shape(input_ids)[1]
            # Build a lower-triangular (causal) mask with the same dtype as the embeddings
attention_mask = paddle.tensor.tril(
paddle.ones(
(length, length),
dtype=self.embeddings.word_embeddings.weight.dtype))
if position_ids is None:
past_length = 0
if cache is not None:
past_length = paddle.shape(cache[0].k)[-2]
position_ids = paddle.arange(
past_length,
paddle.shape(input_ids)[-1] + past_length,
dtype='int64')
position_ids = position_ids.unsqueeze(0)
# .expand_as(input_ids)
position_ids = paddle.fluid.layers.expand_as(position_ids,
input_ids)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids)
# TODO, use registered buffer
causal_mask = paddle.tensor.triu(
paddle.ones((paddle.shape(input_ids)[-1],
paddle.shape(input_ids)[-1])) * -1e9,
diagonal=1)
if attention_mask is not None:
attention_mask = attention_mask + causal_mask
else:
attention_mask = causal_mask
        # The tensor returned by triu is not tracked as a parameter in the static graph,
        # so stop its gradient explicitly.
attention_mask.stop_gradient = True
encoder_outputs = self.decoder(
embedding_output,
memory=None,
tgt_mask=attention_mask,
use_cache=use_cache,
cache=cache)
self.checkpoints.extend(self.decoder.checkpoints)
return encoder_outputs
class GPTForPretraining(nn.Layer):
"""
The pretraining model of GPT.
    It returns the logits and, when `use_cache` is True, the cached key/values.
"""
def __init__(self, gpt):
super(GPTForPretraining, self).__init__()
self.gpt = gpt
self.share_param = False
self.weight = self.gpt.embeddings.word_embeddings.weight
if not self.share_param:
self.weight = self.create_parameter(shape=self.weight.shape)
def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
if topo is not None and topo.mp_info.size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=None)
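            # _c_identity is an identity in the forward pass and an allreduce over the
            # model-parallel group in the backward pass; _c_concat gathers the
            # vocab-parallel logits along the last axis when a full-vocabulary output
            # is requested (parallel_output=False).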
logits = paddle.matmul(
input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(logits, group=None)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
masked_positions=None,
use_cache=False,
cache=None):
outputs = self.gpt(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
use_cache=use_cache,
cache=cache)
if use_cache:
encoder_outputs, cached_kvs = outputs[:2]
else:
encoder_outputs = outputs
logits = self.parallel_matmul(encoder_outputs, self.weight, True,
self.gpt.topo)
if use_cache:
return logits, cached_kvs
else:
return logits
class GPTPretrainingCriterion(nn.Layer):
"""
Criterion for GPT.
It calculates the final loss.
"""
def __init__(self, topo=None):
super(GPTPretrainingCriterion, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
else:
self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy
def forward(self, prediction_scores, masked_lm_labels, loss_mask):
masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2))
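        # Masked language-model loss: sum(per_token_loss * loss_mask) / sum(loss_mask),
        # i.e. the mean loss over the unmasked (non-padding) tokens.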
loss_mask = loss_mask.reshape([-1])
masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
loss = masked_lm_loss / loss_mask.sum()
return loss
def gpt_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 16
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float64')
labels = static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
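        # Under both "dp" and "dp_mp" the batch axis of input_ids is split along mesh
        # axis 0 (the data-parallel axis) while the sequence axis stays replicated.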
gpt = GPTModel(
vocab_size=32768,
hidden_size=1024,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None)
model = GPTForPretraining(gpt)
preds = model(input_ids, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
return train_program, start_program
class TestGPTAutoCompletion(unittest.TestCase):
def test_gpt_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
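        # A 1-D mesh of 4 ranks; with pure data parallelism only batch axes are
        # expected to be sharded after completion.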
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
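        # A 2 x 4 mesh: axis 0 (size 2) is used for data parallelism and axis 1
        # (size 4) for model parallelism, matching the dim_mapping values used in the
        # annotations above.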
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__":
unittest.main()
@@ -165,6 +165,7 @@ packages=['paddle',
                   'paddle.distributed.fleet.meta_parallel.pp_utils',
                   'paddle.distributed.fleet.meta_parallel.parallel_layers',
                   'paddle.distributed.auto_parallel',
                   'paddle.distributed.auto_parallel.operators',
                   'paddle.framework',
                   'paddle.jit',
                   'paddle.jit.dy2static',