Unverified commit 93d862b0, authored by Yulong Ao, committed by GitHub

Add auto completion module for auto parallel (#34813)

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* add dist

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update

* update

* delete unused proto

* restore op_desc

* restore type_defs

* update var_desc

* remove dims_mapping for proto_pybind

* update interface.py

* update framework.py

* update

* update

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* [WIP] Add the auto completion feature and related codes

* [WIP] Improve the auto completion and related codes

* [WIP] Make the auto completion to support data-parallel

* [WIP] Make the completion support mp and dp+mp

* [WIP] Refactor auto completion unit test for MLP

* [WIP] Refactor the implementation of DistributedOperatorImpl

* [WIP] Improve dims_mapping update rule and fix a bug

* [WIP] Support auto completion for one transformer decoder layer

* [WIP] Add a minor change

* [WIP] Fix a bug within the unit test

* Shard XShape tensor, add embedding completion and refactor code

* Add the distributed_operators dir to setup.py.in

* Improve the completion process and add the unittest for gpt

* fix process_mesh ut

* fix process_mesh ut

* update

* update, test=develop

* Add support for automatically completing distributed attrs of special ops

* update

* update

* update

* fix doc sample codes, test=develop

* improve coverage, test=develop

* add static_mode check, test=develop

* Model the cluster for cost model and physical mapping

* update, test=develop

* add set_placement, test=develop

* Add the check to make sure the candidate tensors' size is greater than zero

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update, test=develop

* Auto mark dist attrs annotated by user

* update ndarray to nested list, test=develop

* update, test=develop

* Add auto-completion module for auto-parallel (based on PR#33804)

* Remove unnecessary files

* Remove unrelated files for the auto completion pr

* Update the unit test to improve the coverage

* Modify codes based on reviews

* Minor changes for CI

* Improve some codes based on new comments

* Fix bugs caused by shallow copy in attributes.py
* Improve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments
Co-authored-by: sandyhouse <lilong12@baidu.com>
Parent e8f146a9
......@@ -353,6 +353,14 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_;
need_update_ = true;
// When creating a graph from a program, the creation of an op node will
// create a new OpDesc instead of referring to the original one. To find the
// original OpDesc of the op node, the id has to be copied to the new OpDesc.
// The var node has the same situation, but the default copy constructor can
// copy the id automatically.
id_ = op_desc.id_;
}
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <utility>
......@@ -151,6 +152,18 @@ class OpDesc {
const BlockDesc *Block() const { return this->block_; }
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> id{0};
return ++id;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
uint64_t Id() { return id_; }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......@@ -173,6 +186,8 @@ class OpDesc {
// need_update_ indicates that there are some local changes not yet synchronized.
// If local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false};
uint64_t id_ = GenerateId();
};
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <atomic>
#include <string>
#include <vector>
......@@ -150,6 +151,17 @@ class VarDesc {
Attribute GetAttr(const std::string &name) const;
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
uint64_t Id() { return id_; }
private:
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
......@@ -158,6 +170,7 @@ class VarDesc {
proto::VarDesc desc_;
AttributeMap attrs_;
uint64_t id_ = GenerateId();
};
bool operator==(const VarDesc &left, const VarDesc &right);
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle {
......@@ -202,6 +201,7 @@ void BindVarDsec(pybind11::module *m) {
.def("attr_names", &pd::VarDesc::AttrNames)
.def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("id", &pd::VarDesc::Id)
.def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
......@@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) {
.def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference)
.def("id", &pd::OpDesc::Id)
.def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs);
}
......
......@@ -57,7 +57,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401
from . import cloud_utils # noqa: F401
from . import utils # noqa: F401
__all__ = [ # noqa
"spawn",
"scatter",
"broadcast",
......
......@@ -18,5 +18,6 @@ from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
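The validity rules in is_valid() above encode the core data model: a dims_mapping has one entry per tensor dimension, each entry is either -1 (the dimension is replicated) or the index of a process-mesh dimension, and no mesh dimension may be used twice. As a minimal sketch of what a valid mapping implies for the per-process shard shape (plain Python, independent of Paddle; the helper name is illustrative only):

# Illustrative only: the local (per-process) shape implied by a dims_mapping,
# under the validity rules checked in is_valid() above.
def local_shape(tensor_shape, dims_mapping, mesh_topology):
    shape = []
    for size, mapping in zip(tensor_shape, dims_mapping):
        if mapping == -1:
            shape.append(size)  # replicated along all mesh dimensions
        else:
            shape.append(size // mesh_topology[mapping])  # sharded
    return shape

# A [6, 12] tensor on a 2x3 mesh with dims_mapping [0, -1]:
# rows are sharded across mesh dimension 0, columns are replicated.
assert local_shape([6, 12], [0, -1], [2, 3]) == [3, 12]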
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from paddle.fluid import core
from paddle.fluid import framework
from .utils import compute_compatible_process_mesh
from .utils import compute_compatible_dim_mapping
from .utils import compute_compatible_dims_mapping
from .utils import print_program_with_distributed_attr
from .context import get_default_distributed_context
from .operators import find_best_compatible_distributed_operator_impl
ELEMENTWISE_LIKE_OP_LIST = ["elementwise_add", "gelu", "dropout", "cast"]
def is_elementwise_like_op(op_type):
return op_type in ELEMENTWISE_LIKE_OP_LIST
def update_tensor_node_process_mesh(dist_context, tensor_node, fwd=True):
"""
Update a tensor's process mesh using its predecessors' process meshes in the forward
direction, and its successors' process meshes in the backward direction. Note: only
equal process meshes are considered compatible for now.
"""
changed = False
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
if tensor_dist_attr.is_annotated("process_mesh"):
return changed
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
if fwd:
inputs_process_meshes = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
inputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
else:
outputs_process_meshes = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node)
op_process_mesh = op_dist_attr.get_process_mesh()
outputs_process_meshes.append(op_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and tensor_process_mesh is None:
tensor_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
return changed
def update_op_node_process_mesh(dist_context, op_node, fwd=True):
"""
Update an op's process mesh using its predecessors' process meshes in the forward
direction, and its successors' process meshes in the backward direction. Note: only
equal process meshes are considered compatible for now.
"""
changed = False
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if op_dist_attr.is_annotated("process_mesh"):
return changed
op_process_mesh = op_dist_attr.get_process_mesh()
if fwd:
inputs_process_meshes = []
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
inputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
inputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
else:
outputs_process_meshes = []
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_process_mesh = tensor_dist_attr.get_process_mesh()
outputs_process_meshes.append(tensor_process_mesh)
compatible_process_mesh = compute_compatible_process_mesh(
outputs_process_meshes)
if compatible_process_mesh is not None and op_process_mesh is None:
op_dist_attr.set_process_mesh(compatible_process_mesh)
changed = True
return changed
def update_op_dims_mapping_by_default_dist_impl(op_dist_attr):
"""Each operator has a default distributed operator, only allowed to be sharded in batch dimension."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
# The following statement will be replaced by a more elegant way
if op_desc.type() == "shape" or op_desc.type() == "slice":
return False
output_names = op_desc.output_names()
xshape_arg_names = []
if "XShape" in output_names:
xshape_arg_names = op_desc.output("XShape")
batch_dim_mappings = []
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if len(dims_mapping) > 1:
for idx, mapping in enumerate(dims_mapping[1:]):
assert mapping == -1, \
"{} only the batch dimension (0-dim) can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[0])
else:
assert dims_mapping[0] == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but dimension 0 is mapped to mesh dimension {}."\
.format(op_desc.type(), dims_mapping[0])
if len(dims_mapping) > 2:
for idx, mapping in enumerate(dims_mapping[2:]):
assert mapping == -1, \
"{} only the batch dimension (1-dim) of XShape can be sharded, but the dimension {} is sharded by {} part."\
.format(op_desc.type(), idx, mapping)
batch_dim_mappings.append(dims_mapping[1])
compatible_dim_mapping = compute_compatible_dim_mapping(batch_dim_mappings)
assert compatible_dim_mapping is not None, "There is no compatible dim mapping."
for arg_name in op_desc.input_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
for arg_name in op_desc.output_arg_names():
if op_dist_attr.is_parameter(arg_name):
continue
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if arg_name not in xshape_arg_names:
if compatible_dim_mapping != dims_mapping[0]:
dims_mapping[0] = compatible_dim_mapping
changed = True
else:
if compatible_dim_mapping != dims_mapping[1]:
dims_mapping[1] = compatible_dim_mapping
changed = True
return changed
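To make the default rule concrete, a toy illustration (the values below are ours, not from the code):

# Toy illustration: under the default rule only the batch dimension may be sharded.
ok_mapping        = [0, -1, -1]      # dim 0 sharded -> accepted
bad_mapping       = [-1, 1, -1]      # dim 1 sharded -> assertion error above
ok_xshape_mapping = [-1, 0, -1, -1]  # XShape carries an extra leading dim,
                                     # so its batch dimension is index 1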
def update_op_dims_mapping_by_elementwise_like_dist_impl(op_dist_attr):
"""Element-wise operator can be sharded in any way (but should take care of broadcasting)."""
changed = False
op_desc = op_dist_attr.get_owner_op().desc
input_arg_names = op_desc.input_arg_names()
input_dims_mapping_dict = {}
input_dims_mapping_lens = {}
max_dims_mapping_len = -1
for arg_name in input_arg_names:
dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name)
if max_dims_mapping_len < len(dims_mapping):
max_dims_mapping_len = len(dims_mapping)
input_dims_mapping_dict[arg_name] = dims_mapping
input_dims_mapping_lens[arg_name] = len(dims_mapping)
dims_mapping_list = []
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [-1 for _ in range(max_dims_mapping_len)]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[new_idx] = input_dims_mapping_dict[arg_name][i]
dims_mapping_list.append(new_dims_mapping)
else:
dims_mapping_list.append(input_dims_mapping_dict[arg_name])
output_arg_names = op_desc.output_arg_names()
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
assert len(dims_mapping) == max_dims_mapping_len
dims_mapping_list.append(dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(dims_mapping_list)
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for arg_name in input_arg_names:
if input_dims_mapping_lens[arg_name] < max_dims_mapping_len:
new_dims_mapping = [
-1 for _ in range(input_dims_mapping_lens[arg_name])
]
for i in range(input_dims_mapping_lens[arg_name]):
new_idx = (max_dims_mapping_len -
input_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)
changed = True
else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
for arg_name in output_arg_names:
dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name)
if compatible_dims_mapping != dims_mapping:
op_dist_attr.set_output_dims_mapping(arg_name,
compatible_dims_mapping)
changed = True
return changed
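The right-alignment above mirrors NumPy-style broadcasting. A toy illustration (values are ours; we assume, as the completion rules elsewhere suggest, that -1 is compatible with any concrete mapping):

x_dims_mapping    = [0, -1, 1]   # x: [b, m, n]
bias_dims_mapping = [1]          # bias: [n]
padded_bias       = [-1, -1] + bias_dims_mapping   # left-pad -> [-1, -1, 1]
# [0, -1, 1] and [-1, -1, 1] are compatible, so the output keeps [0, -1, 1].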
def update_tensor_node_dims_mapping(dist_context, tensor_node, fwd=True):
changed = False
if (not tensor_node.is_var()) or (tensor_node.var() is None):
return False
tensor_desc = tensor_node.var()
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
assert tensor_dist_attr is not None
if tensor_dist_attr.is_annotated("dims_mapping"):
return False
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
if fwd:
dims_mapping_list = []
for pred_op_node in tensor_node.inputs:
if pred_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
pred_op_node)
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
changed = True
else:
dims_mapping_list = []
for succ_op_node in tensor_node.outputs:
if succ_op_node.op() is not None:
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(
succ_op_node)
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
dims_mapping_list.append(op_dims_mapping)
dims_mapping_list.append(tensor_dims_mapping)
compatible_dims_mapping = compute_compatible_dims_mapping(
dims_mapping_list)
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != tensor_dims_mapping):
tensor_dist_attr.set_dims_mapping(compatible_dims_mapping)
changed = True
return changed
def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
changed = False
if (not op_node.is_op()) or (op_node.op() is None):
return False
op_desc = op_node.op()
op_dist_attr = dist_context.get_op_distributed_attr_for_graph(op_node)
if fwd:
for tensor_node in op_node.inputs:
if tensor_node.var() is not None:
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_input_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_input_dims_mapping(tensor_desc.name(),
compatible_dims_mapping)
changed = True
# Find the most compatible implementation of the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
if dim_changed:
changed = True
# This check will be replaced by a more robust mechanism
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
tensor_desc = tensor_node.var()
if op_dist_attr.is_annotated_output_dims_mapping(
tensor_desc.name()):
continue
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_graph(
tensor_node)
tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
op_dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_desc.name())
compatible_dims_mapping = compute_compatible_dims_mapping(
[op_dims_mapping, tensor_dims_mapping])
if (compatible_dims_mapping is not None) and \
(compatible_dims_mapping != op_dims_mapping):
op_dist_attr.set_output_dims_mapping(
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementation of the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), op_dist_attr, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(op_dist_attr)
if dim_changed:
changed = True
# This check will be replaced by a more robust mechanism
if op_dist_impl.is_compatible(op_dist_attr):
op_dist_attr.set_impl_idx(op_dist_impl_idx)
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-1)
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
op_dist_attr)
if dim_changed:
changed = True
op_dist_attr.set_impl_idx(-2)
return changed
def complete_annotation(program, dist_context=None):
""" Complete annotation for the partial annotated program.
Arguments:
program: partial annotated program.
dist_context: the distributed context is used to store distributed attributes for program.
If not provided, the default one will be used.
Returns:
program: completed annotated program.
"""
# Use the default distributed context for completion if none is provided
if dist_context is None:
dist_context = get_default_distributed_context()
# Initialize distributed attributes for all vars and ops in the program
dist_context.initialize_distributed_attr_for_program(program)
# print_program_with_distributed_attr(program, dist_context)
# Convert program to graph
graph = framework.IrGraph(core.Graph(program.desc))
# Initialize distributed attributes for all var and op nodes in the graph
dist_context.initialize_distributed_attr_for_graph(graph)
# Complete the process mesh for each node
all_nodes = list(graph.all_nodes())
reach_fix_point = False
while not reach_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_process_mesh(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_process_mesh(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
reach_fix_point = False
else:
reach_fix_point = True
# Complete dims_mapping for each node
reach_fix_point = False
while not reach_fix_point:
changed = False
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=True)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=True)
if op_changed:
changed = True
for node in reversed(all_nodes):
if node.is_var() and node.var() is not None:
tensor_changed = update_tensor_node_dims_mapping(
dist_context, node, fwd=False)
if tensor_changed:
changed = True
if node.is_op() and node.op() is not None:
op_changed = update_op_node_dims_mapping(
dist_context, node, fwd=False)
if op_changed:
changed = True
if changed:
reach_fix_point = False
else:
reach_fix_point = True
# Copy the corresponding distributed attribute from graph to program
dist_context.copy_distribute_attr_from_graph_to_program(graph, program)
dist_context.clear_distributed_attr_for_graph()
# Do the validation check and amend the completion results
dist_context.amend_distributed_attr_for_program()
return program
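For reference, a minimal usage sketch of the entry point above (the construction of the partially annotated program is elided; we assume a static-mode Program whose tensors were annotated beforehand through the shard_xx interface APIs added in this series):

import paddle
from paddle.distributed.auto_parallel import complete_annotation

paddle.enable_static()
# ... build `main_program` and annotate a few tensors/ops beforehand ...
completed_program = complete_annotation(main_program)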
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
# There always exists a default context for the user, and the user can set it to another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
def get_default_distributed_context():
global DEFAULT_DISTRIBUTED_CONTEXT
if DEFAULT_DISTRIBUTED_CONTEXT is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return DEFAULT_DISTRIBUTED_CONTEXT
def set_default_distributed_context(dist_context):
global DEFAULT_DISTRIBUTED_CONTEXT
DEFAULT_DISTRIBUTED_CONTEXT = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for the program and graph.
Each auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self):
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._tensor_distributed_attr_map_for_program = {}
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
def is_initialized_for_program(self):
return self._is_initialized_for_program
def is_initialized_for_graph(self):
return self._is_initialized_for_graph
def get_tensor_distributed_attr_for_program(self, tensor):
tensor_id = tensor.desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get(
tensor_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr):
tensor_id = tensor.desc.id()
self._tensor_distributed_attr_map_for_program[
tensor_id] = tensor_dist_attr
def get_op_distributed_attr_for_program(self, op):
op_id = op.desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_program(self, op, op_dist_attr):
op_id = op.desc.id()
self._op_distributed_attr_map_for_program[op_id] = op_dist_attr
def get_tensor_distributed_attr_for_graph(self, tensor_node):
tensor_node_id = tensor_node.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get(
tensor_node_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_graph(self, tensor_node,
tensor_dist_attr):
tensor_node_id = tensor_node.id()
self._tensor_distributed_attr_map_for_graph[
tensor_node_id] = tensor_dist_attr
def get_op_distributed_attr_for_graph(self, op_node):
op_node_id = op_node.id()
op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
for block in program.blocks:
for tensor in block.vars.values():
# Since only tensors have distributed attributes, it's better to make sure var is a tensor
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is None:
tensor_dist_attr = TensorDistributedAttribute(tensor, self)
self._copy_distributed_attr_from_tensor_desc(
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
tensor_dist_attr.mark_as_annotated("dims_mapping")
if isinstance(tensor, framework.Parameter):
tensor_dist_attr.mark_as_parameter()
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
op_dist_attr = OperatorDistributedAttribute(op, self)
self._copy_distributed_attr_from_op_desc(op.desc,
op_dist_attr)
self.set_op_distributed_attr_for_program(op, op_dist_attr)
# Default distributed implementation for all operators
# This will be updated during the completion process
op_dist_attr.set_impl_idx(-2)
if op_dist_attr.get_process_mesh() is not None:
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_input_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_output_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
self._is_initialized_for_program = True
def finalize_distributed_attr_for_program(self, program):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before finalization."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is not None:
self._store_distributed_attr_to_tensor_desc(
tensor.desc, tensor_dist_attr)
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
self._store_distributed_attr_to_op_desc(op.desc,
op_dist_attr)
def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
attr_name = append_distributed_attr_suffix("dim_mapping")
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_dims_mapping(copied_dims_mapping)
attr_name = append_distributed_attr_suffix("mask")
if desc.has_attr(attr_name):
shard_mask = desc.attr(attr_name)
copied_shard_mask = copy.deepcopy(shard_mask)
dist_attr.set_shard_mask(copied_shard_mask)
attr_name = append_distributed_attr_suffix("offload_device")
if desc.has_attr(attr_name):
offload_device = desc.attr(attr_name)
copied_offload_device = copy.deepcopy(offload_device)
dist_attr.set_offload_device(copied_offload_device)
def _copy_distributed_attr_from_op_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_output_dims_mapping(tensor_name,
copied_dims_mapping)
attr_name = append_distributed_attr_suffix("pipeline_stage")
if desc.has_attr(attr_name):
pipeline_stage = desc.attr(attr_name)
copied_pipeline_stage = copy.deepcopy(pipeline_stage)
dist_attr.set_pipeline_stage(copied_pipeline_stage)
def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
dims_mapping = dist_attr.get_dims_mapping()
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("dim_mapping")
desc._set_attr(attr_name, dims_mapping)
shard_mask = dist_attr.get_shard_mask()
if shard_mask is not None:
attr_name = append_distributed_attr_suffix("mask")
desc._set_attr(attr_name, shard_mask)
offload_device = dist_attr.get_offload_device()
if offload_device is not None:
attr_name = append_distributed_attr_suffix("offload_device")
desc._set_attr(attr_name, offload_device)
def _store_distributed_attr_to_op_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
for tensor_name in desc.input_arg_names():
dims_mapping = dist_attr.get_input_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
for tensor_name in desc.output_arg_names():
dims_mapping = dist_attr.get_output_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
pipeline_stage = dist_attr.get_pipeline_stage()
if pipeline_stage is not None:
attr_name = append_distributed_attr_suffix("pipeline_stage")
desc._set_attr(attr_name, pipeline_stage)
def initialize_distributed_attr_for_graph(self, graph):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before its graph."
if self._is_initialized_for_graph:
return
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program[
tensor_id]
assert tensor_dist_attr is not None, \
"Tensor must have a distributed attribute after the initialization for program."
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self.set_tensor_distributed_attr_for_graph(node,
new_tensor_dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program[op_id]
assert op_dist_attr is not None, \
"Operator must have a distributed attribute after the initialization for program."
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self.set_op_distributed_attr_for_graph(node, new_op_dist_attr)
self._is_initialized_for_graph = True
def clear_distributed_attr_for_program(self):
self._tensor_distributed_attr_map_for_program.clear()
self._op_distributed_attr_map_for_program.clear()
def clear_distributed_attr_for_graph(self):
self._tensor_distributed_attr_map_for_graph.clear()
self._op_distributed_attr_map_for_graph.clear()
def copy_distribute_attr_from_graph_to_program(self, graph, program):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"The distribute attributes must be initialized both in its program and graph"
updated_tensors = {}
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr = self.get_tensor_distributed_attr_for_graph(
node)
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self._tensor_distributed_attr_map_for_program[
tensor_id] = new_tensor_dist_attr
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self.get_op_distributed_attr_for_graph(node)
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self._op_distributed_attr_map_for_program[
op_id] = new_op_dist_attr
def amend_distributed_attr_for_program(self):
for attr in self._tensor_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Tensor's distributed attribute {} is not valid".format(attr)
tensor_shape = attr.get_shape()
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is
# mapped to, we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[
i]] > tensor_shape[i]:
dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Operator's distributed attribute {} is not valid".format(attr)
for arg_name in attr.get_owner_op().desc.input_arg_names():
tensor_shape = attr.get_input_shape(arg_name)
dims_mapping = attr.get_input_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is
# mapped to, we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names():
tensor_shape = attr.get_output_shape(arg_name)
dims_mapping = attr.get_output_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If a tensor dimension is smaller than the process-mesh dimension it is
# mapped to, we just amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
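The amendment rule above deserves a concrete case: when a mesh dimension has more processes than the size of the tensor dimension mapped to it, the mapping is reset to -1 (replicated). A toy illustration (values are ours):

tensor_shape       = [2, 1024]
dims_mapping       = [0, -1]   # dim 0 mapped to mesh dimension 0
process_mesh_shape = [4, 2]    # mesh dimension 0 has 4 processes

# 4 processes cannot shard a dimension of size 2, so the mapping is amended:
for i in range(len(tensor_shape)):
    if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
        dims_mapping[i] = -1
assert dims_mapping == [-1, -1]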
......@@ -13,8 +13,9 @@
# limitations under the License.
import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
......@@ -237,6 +238,23 @@ class ProcessMesh(object):
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the underlying C++ desc
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert len(tensor.shape) == len(dim_mapping)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
DISTRIBUTED_OPERATORS = {}
class DistributedOperator:
def __init__(self):
self._impls = []
self._name = None
def register_impl(self, dist_impl):
self._impls.append(dist_impl)
def get_impl(self, impl_idx):
return self._impls[impl_idx]
def get_impls(self):
return self._impls
class DistributedOperatorImpl:
def __init__(self):
self._name = None
def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
def backward(self, dist_ctx, *grad_outputs):
raise NotImplementedError("Please Implement this method in Subclass.")
def get_name(self):
return self._name
def is_process_mesh_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op):
global DISTRIBUTED_OPERATORS
DISTRIBUTED_OPERATORS[name] = dist_op
def get_distributed_operator(name):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS.get(name, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name)
if dist_op is not None:
dist_op.register_impl(dist_impl)
else:
assert False, "Must register distributed operator first."
def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
fwd=True):
"""
For now, just return the first compatible implementation.
This will be improved by the cost model in the future.
"""
dist_op = get_distributed_operator(name)
if dist_op is None:
return None, -1
compatible_impls = []
impls = dist_op.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_input_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
else:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_output_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
if compatible_impls:
best_compatible_impl, idx = compatible_impls[0]
else:
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
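The registration pattern above is what the concrete operators below follow: register the distributed operator once, then register one implementation per parallel strategy. A skeletal sketch, as it might appear in a sibling module of this package (the operator name "foo" and both classes are hypothetical):

from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl

class DistributedFoo(DistributedOperator):
    def __init__(self, name):
        super(DistributedFoo, self).__init__()
        self._name = name

class DistributedFooImpl(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedFooImpl, self).__init__()
        self._name = name

    def is_process_mesh_compatible(self, op_dist_attr):
        return True

    def is_input_compatible(self, op_dist_attr):
        return True   # accept any input sharding in this sketch

    def is_output_compatible(self, op_dist_attr):
        return True

    def update_dims_mapping(self, op_dist_attr):
        return False  # nothing to propagate in this sketch

# The operator must be registered before any of its implementations.
register_distributed_operator("foo", DistributedFoo("foo"))
register_distributed_operator_impl("foo", DistributedFooImpl("replicated"))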
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedEmbedding(DistributedOperator):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
register_distributed_operator("lookup_table_v2",
DistributedEmbedding("embedding"))
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# All dimensions except the batch dimension must be replicated
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# All dimensions except the batch dimension must be replicated
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(ids_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[ids_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[w_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
def _update_dims_mapping_for_matmul(op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
    # Add dim mappings to make sure the length of each dims_mapping is at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
if y_dims_mapping_len == 1:
y_dims_mapping.insert(1, -1)
    # Deal with ranks > 2 and take care of broadcasting
if out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(out_dims_mapping_len - x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(out_dims_mapping_len - y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
x_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - y_dims_mapping_len)
if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
y_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(out_dims_mapping_len - 2):
if out_dims_mapping[i] != compatible_dims_mapping[i]:
out_dims_mapping[i] = compatible_dims_mapping[i]
changed = True
    # The following uses negative indices so that it works both when
    # len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, y_dims_mapping], [-1, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [-2, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[y_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
    # Remove the temporary dim mappings so each dims_mapping has the same length as its tensor
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len
return changed
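# A minimal sketch of one alignment step performed above; the mappings are
# hypothetical. -1 means replicated, a non-negative value names a mesh axis.
_x_dims_mapping = [0, -1, -1]    # X: [batch, M, K]; batch sharded on mesh axis 0
_y_dims_mapping = [-1, 1]        # Y: [K, N]; N sharded on mesh axis 1 (column parallel)
_out_dims_mapping = [0, -1, -1]  # Out: [batch, M, N]; N not sharded yet
# Align Y's last dim with Out's last dim, as _update_dims_mapping_for_matmul does:
_changed = compute_compatible_and_update_dim_mapping(
    [_y_dims_mapping, _out_dims_mapping], [-1, -1])
assert _changed and _out_dims_mapping == [0, -1, 1]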
class DistributedMatmul(DistributedOperator):
def __init__(self, name):
super(DistributedMatmul, self).__init__()
self._name = name
register_distributed_operator("matmul", DistributedMatmul("matmul"))
# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
        # Other dimensions must be replicated except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
        # Other dimensions must be replicated except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("matmul",
DistributedMatmulImpl0("column_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl1("row_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel"))
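# A minimal sketch of the 2-D weight (Y) mappings each impl above accepts;
# using mesh axis 0 here is a hypothetical choice.
_column_parallel_y = [-1, 0]  # Impl0: shard N (output columns)
_row_parallel_y = [0, -1]     # Impl1: shard K (reduction rows); X's K must be sharded too
_replicate_y = [-1, -1]       # Impl2: fully replicated matmul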
class DistributedMatmulV2(DistributedOperator):
def __init__(self, name):
super(DistributedMatmulV2, self).__init__()
self._name = name
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2"))
# ReplicateParallel
class DistributedMatmulV2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl("replicate_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedReshape2(DistributedOperator):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back"))
register_distributed_operator_impl(
"reshape2", DistributedReshapeImpl1("remove_one_dim_back"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[axis]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"softmax", DistributedSoftmaxImpl("replicate_last_axis"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
return True
def is_output_compatible(self, op_dist_attr):
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
perm = op_desc.attr('axis')
assert len(x_dims_mapping) == len(perm)
new_dims_mapping = [-1 for i in range(len(x_dims_mapping))]
for i in range(len(x_dims_mapping)):
new_dims_mapping[i] = x_dims_mapping[perm[i]]
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[new_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
if x_dims_mapping[perm[i]] != new_dims_mapping[i]:
x_dims_mapping[perm[i]] = new_dims_mapping[i]
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl(
"transpose2", DistributedTranspose2Impl("same_mapping_transpose"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import threading
import paddle.fluid.core as core
def is_valid_list_index(lst, index):
    # Avoid shadowing the built-in list; a valid index lies in [-len, len).
    return -len(lst) <= index < len(lst)
def is_dim_shard(mapping):
    return mapping != -1
def is_dim_replicate(mapping):
    return mapping == -1
def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings:
return None
compatible_mapping = dim_mappings[0]
for mapping in dim_mappings:
if compatible_mapping == -1:
compatible_mapping = mapping
elif mapping == -1:
continue
elif compatible_mapping == mapping:
continue
else:
return None
return compatible_mapping
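# A few minimal examples of the compatibility rule above: replication (-1) is
# compatible with anything, while two different shard axes conflict.
assert compute_compatible_dim_mapping([-1, 0, -1]) == 0  # sharding wins over replication
assert compute_compatible_dim_mapping([0, 0]) == 0       # identical shardings agree
assert compute_compatible_dim_mapping([0, 1]) is None    # conflicting shardings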
def compute_compatible_dims_mapping(dims_mapping_list):
if not dims_mapping_list:
return None
length = len(dims_mapping_list[0])
for dims_mapping in dims_mapping_list:
assert dims_mapping is not None, \
"Dims mapping must not be None for compatible computation"
assert len(dims_mapping) == length, \
"The length of dims_mapping in list must be same for compatible computation."
compatible_result = []
for dim_mappings in zip(*dims_mapping_list):
compatible_dim_mapping = compute_compatible_dim_mapping(
list(dim_mappings))
if compatible_dim_mapping is None:
return None
compatible_result.append(compatible_dim_mapping)
return compatible_result
def compute_compatible_process_mesh(process_mesh_list):
compatible_process_mesh = None
if not process_mesh_list:
return compatible_process_mesh
for process_mesh in process_mesh_list:
if process_mesh is not None:
if compatible_process_mesh is None:
compatible_process_mesh = process_mesh
else:
assert process_mesh == compatible_process_mesh, \
"There is no compatible process mesh."
return compatible_process_mesh
def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list):
assert len(dims_mapping_list) == len(index_list)
changed = False
dim_mappings = []
for i in range(len(dims_mapping_list)):
assert is_valid_list_index(dims_mapping_list[i], index_list[i])
dim_mappings.append(dims_mapping_list[i][index_list[i]])
compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings)
if compatible_dim_mapping is None:
return False
for i in range(len(dims_mapping_list)):
if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]:
dims_mapping_list[i][index_list[i]] = compatible_dim_mapping
changed = True
return changed
def append_distributed_attr_suffix(name):
"""
    Append the auto parallel suffix to a distributed attribute name.
"""
return name + core.kAutoParallelSuffix()
def remove_distributed_attr_suffix(name):
    """
    Remove the auto parallel suffix from a distributed attribute name.
    """
    suffix = core.kAutoParallelSuffix()
    # str.strip() would remove characters, not a suffix, so slice it off instead.
    if name.endswith(suffix):
        return name[:-len(suffix)]
    return name
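# A minimal round-trip sketch; it assumes the name does not already end with
# core.kAutoParallelSuffix():
_name = append_distributed_attr_suffix("mesh_id")
assert remove_distributed_attr_suffix(_name) == "mesh_id"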
def check_distributed_attr_for_program(program, dist_context=None):
from .context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor)
if (tensor_dist_attr is not None) and (
not tensor_dist_attr.is_valid()):
return False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if (op_dist_attr is not None) and (not op_dist_attr.is_valid()):
return False
return True
def print_program_with_distributed_attr(program, dist_context=None):
"""
    This function reuses the original program printing ability with a given
    distributed context. A lock is used to prevent multiple threads from
    changing the default distributed context simultaneously.
"""
lock = threading.Lock()
lock.acquire()
from .context import get_default_distributed_context
from .context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
else:
original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context)
print(program)
set_default_distributed_context(original_default_context)
lock.release()
......@@ -1224,6 +1224,14 @@ class Variable(object):
if self.persistable:
var_str = "persist " + var_str
from paddle.distributed.auto_parallel.context import get_default_distributed_context
dist_context = get_default_distributed_context()
var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
self)
if var_dist_attr is not None:
var_str += ", {name} = {value}".format(
name="dist_attr", value=var_dist_attr)
return var_str
def to_string(self, throw_on_error, with_details=False):
......@@ -2384,6 +2392,13 @@ class Operator(object):
if i != len(attr_names) - 1:
attrs_str += ", "
from paddle.distributed.auto_parallel.context import get_default_distributed_context
dist_context = get_default_distributed_context()
op_dist_attr = dist_context.get_op_distributed_attr_for_program(self)
if op_dist_attr is not None:
attrs_str += ", {name} = {value}".format(
name="dist_attr", value=op_dist_attr)
if outputs_str != "{}":
op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
format(outputs=outputs_str, op_type=self.type,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import unittest.mock
from io import StringIO
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.tensor as tensor
from paddle.fluid import layers
from paddle.nn.layer.transformer import _convert_param_attr_to_list
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.utils import append_distributed_attr_suffix
from paddle.distributed.auto_parallel.context import DistributedContext
from paddle.distributed.auto_parallel.context import set_default_distributed_context
paddle.enable_static()
_global_parallel_stratergy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
out = self.dropout(out)
return out
def mlp_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
mlp = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
out = mlp(input)
return train_program, start_program
class TestMLPAutoCompletion(unittest.TestCase):
def test_mlp_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_mlp_misc(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = mlp_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
dist_context.finalize_distributed_attr_for_program(
complete_train_program)
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
for block in complete_train_program.blocks:
for tensor in block.vars.values():
desc = tensor.desc
attr_name = append_distributed_attr_suffix("mesh_id")
                self.assertTrue(desc.has_attr(attr_name))
attr_name = append_distributed_attr_suffix("dim_mapping")
                self.assertTrue(desc.has_attr(attr_name))
for op in block.ops:
desc = op.desc
attr_name = append_distributed_attr_suffix("mesh_id")
                self.assertTrue(desc.has_attr(attr_name))
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" +
tensor_name)
                    self.assertTrue(desc.has_attr(attr_name))
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" +
tensor_name)
                    self.assertTrue(desc.has_attr(attr_name))
set_default_distributed_context(dist_context)
self.assertTrue("dist_attr" in str(complete_train_program))
with unittest.mock.patch(
"sys.stdout", new_callable=StringIO) as mock_stdout:
print_program_with_distributed_attr(complete_train_program)
self.assertIsNotNone(mock_stdout.getvalue())
class AttentionLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
sequence_len=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(AttentionLayer, self).__init__()
self.hidden_size = hidden_size
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input, _global_process_mesh, dim_mapping=[0, -1, -1])
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(input)
v = self.v_proj(input)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
return out
def attn_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input = static.data(
name="query",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
attn = AttentionLayer(
hidden_size=hidden_size,
sequence_len=sequence_len,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = attn(input)
return train_program, start_program
class TestAttentionAutoCompletion(unittest.TestCase):
def test_attn_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_attn_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = attn_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
class DecoderLayer(nn.Layer):
def __init__(self,
vocab_size=32768,
hidden_size=1024,
sequence_len=512,
max_position_embeddings=512,
intermediate_size=4 * 1024,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02):
super(DecoderLayer, self).__init__()
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.max_position_embeddings = max_position_embeddings
self.sequence_len = sequence_len
self.embed_dim = self.hidden_size
self.kdim = self.embed_dim
self.vdim = self.embed_dim
self.num_heads = num_heads
self.dropout_ratio = dropout_ratio
self.initializer_range = initializer_range
self.training = True
self.attn_mask = None
self.head_dim = self.embed_dim // self.num_heads
assert self.head_dim * self.num_heads == self.embed_dim, \
"embed_dim must be divisible by num_heads"
self.word_embeddings = nn.Embedding(
self.vocab_size,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
self.position_embeddings = nn.Embedding(
self.max_position_embeddings,
self.hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)))
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.q_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, self.embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
self.embed_dim, self.embed_dim, weight_attr, bias_attr=bias_attr)
intermediate_size = 4 * self.hidden_size
d_model = self.hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range))
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(self.dropout_ratio)
self.dropout2 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids):
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
# Pre-norm
target = self.norm(embeddings)
# The following is the attention part
q = self.q_proj(target)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(target)
v = self.v_proj(target)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if self.attn_mask is not None:
product = product + self.attn_mask
weights = F.softmax(product)
if self.dropout_ratio:
weights = F.dropout(
weights,
self.dropout_ratio,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
# Add residual
residual = embeddings + self.dropout2(out)
# Pre-norm
out0 = self.norm(residual)
# The following is the MLP part
out1 = self.linear0(out0)
out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 0])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear0.weight, _global_process_mesh, dim_mapping=[-1, 1])
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[1, -1])
# Add residual
final = residual + self.dropout3(out3)
return final
def decoder_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 4
hidden_size = 1024
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
decoder = DecoderLayer(
vocab_size=32768,
hidden_size=hidden_size,
sequence_len=sequence_len,
max_position_embeddings=512,
intermediate_size=4 * hidden_size,
num_heads=16,
dropout_ratio=0.1,
initializer_range=0.02)
out = decoder(input_ids, position_ids)
return train_program, start_program
class TestDecoderLayerAutoCompletion(unittest.TestCase):
def test_decoder_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_decoder_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = decoder_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import collections
import math
import unittest
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.tensor as tensor
import paddle.utils as utils
from paddle.fluid import layers
from paddle.fluid.framework import in_dygraph_mode
from paddle.nn.layer.transformer import _convert_param_attr_to_list
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.distributed.fleet import fleet
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.utils import check_distributed_attr_for_program
from paddle.distributed.auto_parallel.utils import print_program_with_distributed_attr
from paddle.distributed.auto_parallel.context import DistributedContext
paddle.enable_static()
_global_parallel_stratergy = None
_global_process_mesh = None
ROOT_MESH = auto.ProcessMesh([[0, 1, 2, 3], [4, 5, 6, 7]])
class MultiHeadAttention(nn.Layer):
"""
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention passes to jointly
    attend to information from different representation subspaces.
"""
Cache = collections.namedtuple("Cache", ["k", "v"])
StaticCache = collections.namedtuple("StaticCache", ["k", "v"])
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
kdim=None,
vdim=None,
need_weights=False,
weight_attr=None,
bias_attr=None,
topo=None,
fuse=False):
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.need_weights = need_weights
self.fuse = fuse
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if topo is None or topo.mp_info.size == 1:
if self.fuse:
assert self.kdim == embed_dim
assert self.vdim == embed_dim
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, weight_attr, bias_attr=bias_attr)
else:
self.q_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
self.k_proj = nn.Linear(
self.kdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.v_proj = nn.Linear(
self.vdim, embed_dim, weight_attr, bias_attr=bias_attr)
self.out_proj = nn.Linear(
embed_dim, embed_dim, weight_attr, bias_attr=bias_attr)
def _fuse_prepare_qkv(self, query):
mix_layer = self.qkv_proj(query)
mix_layer = paddle.reshape_(mix_layer,
[0, 0, self.num_heads, 3 * self.head_dim])
mix_layer = paddle.transpose(mix_layer, [0, 2, 1, 3])
q, k, v = paddle.split(mix_layer, num_or_sections=3, axis=-1)
return q, k, v
def _prepare_qkv(self, query, key, value, use_cache=False, cache=None):
r"""
        Prepares linearly projected queries, keys and values for use in the
        subsequent multiple parallel attention. If `cache` is not None, cached
        results are used to reduce redundant calculations.
"""
q = self.q_proj(query)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.q_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache):
            # for encoder-decoder attention in inference, with cached static k, v
k, v = cache.k, cache.v
else:
k, v = self.compute_kv(key, value)
if isinstance(cache, self.Cache):
# for decoder self-attention in inference
k = tensor.concat([cache.k, k], axis=2)
v = tensor.concat([cache.v, v], axis=2)
if use_cache is True:
cache = self.Cache(k, v)
return (q, k, v) if use_cache is False else (q, k, v, cache)
def compute_kv(self, key, value):
r"""
Applies linear projection on input keys and values, then splits heads
(reshape and transpose) to get keys and values from different representation
        subspaces. The results are used as key-value pairs for the subsequent
        multiple parallel attention.
        It is part of the calculation in multi-head attention, and is provided
        as a method to pre-compute and prefetch these results, so they can be
        used to construct the cache for inference.
"""
k = self.k_proj(key)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.k_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
v = self.v_proj(value)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.v_proj.weight, _global_process_mesh, dim_mapping=[-1, 1])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v = tensor.transpose(x=v, perm=[0, 2, 1, 3])
return k, v
def gen_cache(self, key, value=None, type=Cache):
"""
        Generates cache for `forward` usage in inference according to arguments.
The generated cache is an instance of `MultiHeadAttention.Cache` or an
instance of `MultiHeadAttention.StaticCache`.
"""
if type == MultiHeadAttention.StaticCache: # static_kv
k, v = self.compute_kv(key, value)
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
v = layers.fill_constant_batch_size_like(
input=key,
shape=[-1, self.num_heads, 0, self.head_dim],
dtype=key.dtype,
value=0)
return self.Cache(k, v)
else:
# incremental_state with initial value, mainly for usage like UniLM
return self.Cache(key, value)
def forward(self,
query,
key,
value,
attn_mask=None,
use_cache=False,
cache=None):
r"""
Applies multi-head attention to map queries and a set of key-value pairs
to outputs.
"""
key = query if key is None else key
value = query if value is None else value
        # compute q, k, v
if use_cache is False:
if self.fuse:
q, k, v = self._fuse_prepare_qkv(query)
else:
q, k, v = self._prepare_qkv(query, key, value, use_cache, cache)
else:
q, k, v, cache = self._prepare_qkv(query, key, value, use_cache,
cache)
        # scaled dot-product attention
product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
if self.dropout:
weights = F.dropout(
weights,
self.dropout,
training=self.training,
mode="upscale_in_train")
out = tensor.matmul(weights, v)
# combine heads
out = tensor.transpose(out, perm=[0, 2, 1, 3])
out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
# project to output
out = self.out_proj(out)
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.out_proj.weight, _global_process_mesh,
dim_mapping=[1, -1])
outs = [out]
if self.need_weights:
outs.append(weights)
if use_cache:
outs.append(cache)
return out if len(outs) == 1 else tuple(outs)
class TransformerDecoder(nn.Layer):
"""
TransformerDecoder is a stack of N decoder layers.
"""
def __init__(self,
decoder_layers,
num_layers,
norm=None,
hidden_size=None,
topo=None):
super(TransformerDecoder, self).__init__()
self.topo = topo
self.num_layers = num_layers
self.layers = decoder_layers
self.norm = norm
        if norm == "LayerNorm":
self.norm = nn.LayerNorm(hidden_size)
elif norm is not None:
raise ValueError("Only support LayerNorm")
self.checkpoints = []
def forward(self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
use_cache=False,
cache=None):
r"""
Applies a stack of N Transformer decoder layers on inputs. If `norm` is
        provided, also applies layer normalization on the output of the last decoder
layer.
"""
output = tgt
new_caches = []
self.checkpoints = []
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
new_caches.append(new_cache)
else:
output = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache)
else:
output, new_cache = mod(output,
memory,
tgt_mask=tgt_mask,
use_cache=use_cache,
cache=cache[i])
new_caches.append(new_cache)
self.checkpoints.append(output.name)
if self.norm is not None:
output = self.norm(output)
return output if use_cache is False else (output, new_caches)
def gen_cache(self, memory, do_zip=False):
r"""
Generates cache for `forward` usage. The generated cache is a list, and
each element in it is a tuple( :code:`(incremental_cache, static_cache)` )
produced by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache`
for more details. If `do_zip` is True, apply `zip` on these tuples to get
a list with two elements.
"""
cache = [layer.gen_cache(memory) for layer in self.layers]
if do_zip:
cache = list(zip(*cache))
return cache
class TransformerDecoderLayer(nn.Layer):
"""
The transformer decoder layer.
It contains multiheadattention and some linear layers.
"""
def __init__(self,
d_model,
nhead,
dim_feedforward,
dropout=0.1,
activation="gelu",
attn_dropout=None,
act_dropout=None,
normalize_before=True,
weight_attr=None,
bias_attr=None,
topo=None):
self._config = locals()
self._config.pop("self")
self._config.pop("__class__", None) # py3
super(TransformerDecoderLayer, self).__init__()
attn_dropout = dropout if attn_dropout is None else attn_dropout
act_dropout = dropout if act_dropout is None else act_dropout
self.normalize_before = normalize_before
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0],
topo=topo)
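        # Only the serial branch is built in this test; model-parallel sharding
        # is annotated declaratively with auto.shard_tensor in forward rather
        # than with parallel layer implementations.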
if topo is None or topo.mp_info.size == 1:
self.linear1 = nn.Linear(
d_model,
dim_feedforward,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.linear2 = nn.Linear(
dim_feedforward,
d_model,
weight_attrs[2],
bias_attr=bias_attrs[2])
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
self.norm2 = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout1 = nn.Dropout(dropout, mode="upscale_in_train")
self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
self.activation = getattr(F, activation)
def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
if use_cache is False:
tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
else:
tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
use_cache, cache)
tgt = residual + self.dropout1(tgt)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
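        # Megatron-style FFN partitioning: linear1 is column-parallel (output
        # features sharded), linear2 is row-parallel (input features sharded),
        # so the partial results only need to be reduced after linear2.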
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 0])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear1.weight, _global_process_mesh, dim_mapping=[-1, 1])
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.linear2.weight, _global_process_mesh, dim_mapping=[1, -1])
# tgt = self.dropout2(
# self.linear2(F.gelu(
# self.linear1(tgt), approximate=True)))
tgt = self.linear1(tgt)
tgt = F.gelu(tgt, approximate=True)
tgt = self.dropout2(self.linear2(tgt))
tgt = residual + tgt
if not self.normalize_before:
tgt = self.norm2(tgt)
return tgt if use_cache is False else (tgt, incremental_cache)
def gen_cache(self, memory):
incremental_cache = self.self_attn.gen_cache(
memory, type=self.self_attn.Cache)
return incremental_cache
class GPTEmbeddings(nn.Layer):
"""
Include embeddings from word, position and token_type embeddings
"""
def __init__(self,
vocab_size,
hidden_size=768,
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
topo=None):
super(GPTEmbeddings, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.word_embeddings = nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.ParamAttr(
name="word_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.position_embeddings = nn.Embedding(
max_position_embeddings,
hidden_size,
weight_attr=paddle.ParamAttr(
name="pos_embeddings",
initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)))
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids, position_ids=None):
if position_ids is None:
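            # Derive default positions [0..seq_len-1] per row: cumsum over ones
            # gives [1..seq_len] and subtracting the ones shifts to zero-based.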
ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)
position_ids = seq_length - ones
        input_embeddings = self.word_embeddings(input_ids)
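        # The embedding table ([vocab_size, hidden]) is sharded along its vocab
        # dim on the model-parallel mesh axis; lookups then yield partial
        # results that are expected to be combined by the framework.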
if _global_parallel_stratergy == "mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
self.word_embeddings.weight,
_global_process_mesh,
dim_mapping=[1, -1])
position_embeddings = self.position_embeddings(position_ids)
        embeddings = input_embeddings + position_embeddings
embeddings = self.dropout(embeddings)
return embeddings
class GPTModel(nn.Layer):
"""
The base model of gpt.
"""
def __init__(self,
vocab_size,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None):
super(GPTModel, self).__init__()
self.pad_token_id = pad_token_id
self.initializer_range = initializer_range
self.topo = topo
self.hidden_size = hidden_size
self.vocab_size = vocab_size
        self.pipeline_mode = topo is not None and topo.pp_info.size > 1
        if self.pipeline_mode:
self.layer_per_stage = num_hidden_layers // self.topo.pp_info.size
self.embeddings = GPTEmbeddings(
vocab_size, hidden_size, hidden_dropout_prob,
max_position_embeddings, type_vocab_size, self.initializer_range,
topo)
decoder_layers = nn.LayerList()
for i in range(num_hidden_layers):
DecoderLayer = TransformerDecoderLayer
decoder_layers.append(
DecoderLayer(
d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
dropout=hidden_dropout_prob,
activation=hidden_act,
attn_dropout=attention_probs_dropout_prob,
act_dropout=hidden_dropout_prob,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
mean=0.0, std=self.initializer_range)),
bias_attr=None,
topo=topo))
Decoder = TransformerDecoder
self.decoder = Decoder(
decoder_layers,
num_hidden_layers,
norm="LayerNorm",
hidden_size=hidden_size,
topo=topo)
self.checkpoints = []
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
use_cache=False,
cache=None):
self.checkpoints = []
if attention_mask is None:
length = paddle.shape(input_ids)[1]
# Use bool mask
attention_mask = paddle.tensor.tril(
paddle.ones(
(length, length),
dtype=self.embeddings.word_embeddings.weight.dtype))
if position_ids is None:
past_length = 0
if cache is not None:
past_length = paddle.shape(cache[0].k)[-2]
position_ids = paddle.arange(
past_length,
paddle.shape(input_ids)[-1] + past_length,
dtype='int64')
position_ids = position_ids.unsqueeze(0)
            # Broadcast positions over the batch; the functional expand_as is
            # used here in place of the .expand_as(input_ids) method form.
            position_ids = paddle.fluid.layers.expand_as(position_ids,
                                                         input_ids)
embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids)
# TODO, use registered buffer
causal_mask = paddle.tensor.triu(
paddle.ones((paddle.shape(input_ids)[-1],
paddle.shape(input_ids)[-1])) * -1e9,
diagonal=1)
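        # triu with -1e9 above the diagonal is an additive causal mask: adding
        # it to the attention logits blocks attention to future positions.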
if attention_mask is not None:
attention_mask = attention_mask + causal_mask
else:
attention_mask = causal_mask
        # The mask built by triu is a freshly created tensor, not a graph input,
        # so stop its gradient explicitly in the static graph.
attention_mask.stop_gradient = True
encoder_outputs = self.decoder(
embedding_output,
memory=None,
tgt_mask=attention_mask,
use_cache=use_cache,
cache=cache)
self.checkpoints.extend(self.decoder.checkpoints)
return encoder_outputs
class GPTForPretraining(nn.Layer):
"""
The pretraining model of GPT.
It returns some logits and cached_kvs.
"""
def __init__(self, gpt):
super(GPTForPretraining, self).__init__()
self.gpt = gpt
self.share_param = False
self.weight = self.gpt.embeddings.word_embeddings.weight
if not self.share_param:
self.weight = self.create_parameter(shape=self.weight.shape)
def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
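        # With model parallelism the lm head weight is vocab-sharded: _c_identity
        # is an identity in forward that allreduces the gradient in backward, and
        # _c_concat gathers the vocab-sharded logits along the last axis when a
        # full output is requested (parallel_output=False).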
if topo is not None and topo.mp_info.size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=None)
logits = paddle.matmul(
input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(logits, group=None)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
def forward(self,
input_ids,
position_ids=None,
attention_mask=None,
masked_positions=None,
use_cache=False,
cache=None):
outputs = self.gpt(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
use_cache=use_cache,
cache=cache)
if use_cache:
encoder_outputs, cached_kvs = outputs[:2]
else:
encoder_outputs = outputs
logits = self.parallel_matmul(encoder_outputs, self.weight, True,
self.gpt.topo)
if use_cache:
return logits, cached_kvs
else:
return logits
class GPTPretrainingCriterion(nn.Layer):
"""
Criterion for GPT.
It calculates the final loss.
"""
def __init__(self, topo=None):
super(GPTPretrainingCriterion, self).__init__()
if topo is None or topo.mp_info.size == 1:
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
else:
self.loss_func = paddle.distributed.collective._c_softmax_with_cross_entropy
def forward(self, prediction_scores, masked_lm_labels, loss_mask):
masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2))
loss_mask = loss_mask.reshape([-1])
masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
loss = masked_lm_loss / loss_mask.sum()
return loss
def gpt_pretrain_forward(train_program, start_program):
with static.program_guard(train_program,
start_program), utils.unique_name.guard():
batch_size = 16
sequence_len = 512
input_ids = static.data(
name="input_ids", shape=[batch_size, sequence_len], dtype='int64')
position_ids = static.data(
name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = static.data(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float64')
labels = static.data(
name="labels", shape=[batch_size, sequence_len], dtype='int64')
loss_mask = static.data(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float64')
if _global_parallel_stratergy == "dp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
elif _global_parallel_stratergy == "dp_mp":
auto.shard_tensor(
input_ids, _global_process_mesh, dim_mapping=[0, -1])
gpt = GPTModel(
vocab_size=32768,
hidden_size=1024,
num_hidden_layers=2,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=1024,
type_vocab_size=16,
initializer_range=0.02,
pad_token_id=0,
topo=None)
model = GPTForPretraining(gpt)
preds = model(input_ids, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
return train_program, start_program
class TestGPTAutoCompletion(unittest.TestCase):
def test_gpt_dp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
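        # complete_annotation infers distributed attributes for every tensor and
        # op from the few user annotations and records them in dist_context.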
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[0, 1, 2, 3], parent=ROOT_MESH)
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
def test_gpt_dp_mp(self):
global _global_parallel_stratergy
_global_parallel_stratergy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], parent=ROOT_MESH)
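        # 2-D mesh: axis 0 (size 2) is the data-parallel axis and axis 1 (size
        # 4) is the model-parallel axis referenced by the dim_mapping annotations.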
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
train_program, start_program = gpt_pretrain_forward(train_program,
start_program)
complete_train_program = auto.complete_annotation(train_program,
dist_context)
# print_program_with_distributed_attr(complete_train_program,
# dist_context)
self.assertTrue(
check_distributed_attr_for_program(complete_train_program,
dist_context))
if __name__ == "__main__":
unittest.main()
......@@ -165,6 +165,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel.pp_utils',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators',
'paddle.framework',
'paddle.jit',
'paddle.jit.dy2static',
......