未验证 提交 3f962e77 编写于 作者: L lilong12 提交者: GitHub

add the basic apis for auto_parallel (#33804)

* add auto_parallel apis
上级 88f2f4a4
......@@ -202,7 +202,7 @@ cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op
cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute shape_inference op_info operator glog version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc process_mesh_desc.cc DEPS attribute shape_inference op_info operator glog version)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
......
......@@ -38,6 +38,13 @@ enum AttrType {
FLOAT64S = 12;
}
message ProcessMeshDesc {
required int32 id = 1;
required int32 parent_id = 2;
repeated int32 topology = 3;
repeated int32 process_group = 4;
};
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
......@@ -167,6 +174,15 @@ message VarType {
}
message VarDesc {
message Attr {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional string s = 4;
repeated int32 ints = 5;
};
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
......@@ -175,6 +191,7 @@ message VarDesc {
optional bool need_check_feed = 4 [ default = false ];
optional bool is_parameter = 5 [ default = false ];
optional bool stop_gradient = 6 [ default = false ];
repeated Attr attrs = 7;
}
message BlockDesc {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/process_mesh_desc.h"
namespace paddle {
namespace framework {
int32_t ProcessMeshDesc::next_id = -1;
ProcessMeshDesc::ProcessMeshDesc(const std::vector<int32_t> &topo,
const std::vector<int32_t> &process_group,
int32_t parent_id) {
int32_t cur_id = ++next_id;
desc_.set_id(cur_id);
desc_.set_parent_id(parent_id);
for (size_t i = 0; i != topo.size(); ++i) {
desc_.add_topology(topo[i]);
}
for (size_t i = 0; i != process_group.size(); ++i) {
desc_.add_process_group(process_group[i]);
}
ProcessMeshDescMap::GetInstance().Insert(cur_id, this);
}
std::vector<int32_t> ProcessMeshDesc::Topology() const {
size_t size = desc_.topology_size();
std::vector<int32_t> ret(size);
for (auto i = 0; i != desc_.topology_size(); ++i) {
ret[i] = desc_.topology(i);
}
return ret;
}
std::vector<int32_t> ProcessMeshDesc::ProcessGroup() const {
size_t size = desc_.process_group_size();
std::vector<int32_t> ret(size);
for (auto i = 0; i != desc_.process_group_size(); ++i) {
ret[i] = desc_.process_group(i);
}
return ret;
}
ProcessMeshDescMap &ProcessMeshDescMap::GetInstance() {
static ProcessMeshDescMap g_process_mesh_desc_map;
return g_process_mesh_desc_map;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
class ProcessMeshDesc {
public:
ProcessMeshDesc(const std::vector<int32_t>& topo,
const std::vector<int32_t>& process_group, int32_t parent_id);
int32_t ID() const { return desc_.id(); }
int32_t Parent() const { return desc_.parent_id(); }
std::vector<int32_t> Topology() const;
std::vector<int32_t> ProcessGroup() const;
static int32_t next_id;
private:
proto::ProcessMeshDesc desc_; // not_own
};
class ProcessMeshDescMap {
public:
static ProcessMeshDescMap& GetInstance();
bool Has(int32_t index) const { return map_.find(index) != map_.end(); }
void Insert(int32_t index, ProcessMeshDesc* mesh) {
PADDLE_ENFORCE_NE(
Has(index), true,
platform::errors::AlreadyExists("Index (%d) has been used.", index));
map_.insert(std::make_pair(index, mesh));
}
private:
ProcessMeshDescMap() = default;
// Use raw pointer to avoid double free
std::unordered_map<int32_t, ProcessMeshDesc*> map_;
DISABLE_COPY_AND_ASSIGN(ProcessMeshDescMap);
};
} // namespace framework
} // namespace paddle
......@@ -22,5 +22,13 @@ constexpr int kRootBlockIndex = 0;
// The Parent Index of root Block, this block does not exist.
constexpr int kNoneBlockIndex = -1;
// The Parent Index of root ProcessMesh, this ProcessMesh does not exist.
constexpr int kNoneProcessMeshIndex = -1;
// If a attribute name has a certain suffix, it means that the
// atrribute is a distributed-related attribute for auto parallel.
// e.g., "mesh_id@PARALLEL".
constexpr char kAutoParallelSuffix[] = "@PARALLEL";
} // namespace framework
} // namespace paddle
......@@ -280,6 +280,46 @@ std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
}
}
std::vector<std::string> VarDesc::AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}
void VarDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); }
void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(sandyhouse): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
this->attrs_[name] = std::vector<int>();
return;
}
bool valid = attr_type == proto::AttrType::INT ||
attr_type == proto::AttrType::STRING ||
attr_type == proto::AttrType::INTS;
PADDLE_ENFORCE_EQ(valid, true, platform::errors::InvalidArgument(
"The value for attr (%s) must be "
"one of list or int or string.",
name));
this->attrs_[name] = v;
}
Attribute VarDesc::GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_NE(it, attrs_.end(), platform::errors::NotFound(
"Attribute %s is not found.", name));
return it->second;
}
bool operator==(const VarDesc &left, const VarDesc &right) {
return left.Proto()->SerializeAsString() ==
right.Proto()->SerializeAsString();
......
......@@ -19,7 +19,9 @@ limitations under the License. */
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/type_defs.h"
namespace paddle {
namespace framework {
......@@ -137,6 +139,17 @@ class VarDesc {
desc_.set_need_check_feed(need_check_feed);
}
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
std::vector<std::string> AttrNames() const;
void SetAttr(const std::string &name, const Attribute &v);
void RemoveAttr(const std::string &name);
Attribute GetAttr(const std::string &name) const;
private:
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
......@@ -144,6 +157,7 @@ class VarDesc {
std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
proto::VarDesc desc_;
AttributeMap attrs_;
};
bool operator==(const VarDesc &left, const VarDesc &right);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/proto_desc.h"
#if defined(PADDLE_WITH_DGC)
#include "paddle/fluid/framework/details/dgc_const_values.h"
......@@ -33,6 +34,9 @@ void BindConstValue(pybind11::module* m) {
m->def("kControlDepVarName",
[] { return framework::ir::Node::kControlDepVarName; });
m->def("kNewGradSuffix", [] { return framework::kNewGradSuffix; });
m->def("kAutoParallelSuffix", [] { return framework::kAutoParallelSuffix; });
m->def("kNoneProcessMeshIndex",
[] { return framework::kNoneProcessMeshIndex; });
auto op_proto_and_checker_maker =
m->def_submodule("op_proto_and_checker_maker");
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/process_mesh_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
......@@ -84,6 +85,17 @@ void BindProgramDesc(pybind11::module *m) {
[](pd::ProgramDesc &self) -> int64_t { return self.Version(); });
}
void BindProcessMeshDesc(pybind11::module *m) {
pybind11::class_<pd::ProcessMeshDesc>(*m, "ProcessMeshDesc", "")
.def(pybind11::init<const std::vector<int32_t> &,
const std::vector<int32_t> &, int32_t>())
.def_property_readonly("id", &pd::ProcessMeshDesc::ID)
.def_property_readonly("parent", &pd::ProcessMeshDesc::Parent)
.def_property_readonly("topology", &pd::ProcessMeshDesc::Topology)
.def_property_readonly("process_group",
&pd::ProcessMeshDesc::ProcessGroup);
}
void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
......@@ -184,7 +196,12 @@ void BindVarDsec(pybind11::module *m) {
.def("clear_stop_gradient", &pd::VarDesc::ClearStopGradient)
.def("has_stop_gradient", &pd::VarDesc::HasStopGradient)
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed)
.def("has_attr", &pd::VarDesc::HasAttr)
.def("attr_names", &pd::VarDesc::AttrNames)
.def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
......
......@@ -30,6 +30,7 @@ void BindProgramDesc(pybind11::module* m);
void BindBlockDesc(pybind11::module* m);
void BindVarDsec(pybind11::module* m);
void BindOpDesc(pybind11::module* m);
void BindProcessMeshDesc(pybind11::module* m);
} // namespace pybind
} // namespace paddle
......@@ -2054,6 +2054,7 @@ All parameter, weight, gradient are variables in Paddle.
BindOpDesc(&m);
BindConstValue(&m);
BindGlobalValueGetterSetter(&m);
BindProcessMeshDesc(&m);
py::class_<framework::LoDRankTable>(m, "LodRankTable")
.def("items", [](framework::LoDRankTable &table) {
......
......@@ -36,6 +36,13 @@ from .collective import get_group # noqa: F401
from .collective import send # noqa: F401
from .collective import wait # noqa: F401
from .auto_parallel import shard_tensor # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import set_shard_mask # noqa: F401
from .auto_parallel import set_offload_device # noqa: F401
from .auto_parallel import set_pipeline_stage # noqa: F401
from .auto_parallel import ProcessMesh # noqa: F401
from .fleet import BoxPSDataset # noqa: F401
from .entry_attr import ProbabilityEntry # noqa: F401
......@@ -69,5 +76,11 @@ __all__ = [ #noqa
"ReduceOp",
"wait",
"get_rank",
"ProbabilityEntry"
"ProbabilityEntry",
"shard_tensor",
"shard_op",
"set_shard_mask",
"set_offload_device",
"set_pipeline_stage",
"ProcessMesh",
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import paddle.fluid.core as core
import paddle
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
__all__ = []
# a map from ProcessMesh ids to the ProcessMesh instances
_g_process_mesh_map = dict()
# user defined map from logical process ids to physical ones
_user_defined_physical_map = None
def _append_attr_suffix(name):
"""
Append auto parallel suffix for distributed attribute name.
"""
return name + core.kAutoParallelSuffix()
def _remove_attr_suffix(name):
"""
Remove auto parallel suffix from distributed attribute name.
"""
return name.strip(core.kAutoParallelSuffix())
def _static_mode_check():
if in_dygraph_mode():
raise RuntimeError("Auto-parallel only supports static mode, "
"please use paddle.enable_static().")
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
class ProcessMesh(object):
r"""
The class `Processmesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) describes the toplogy
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
parent (ProcessMesh, optional): the parent ProcessMesh. None means
the ProcessMesh is the root one without parent ProcessMesh.
Default: None.
Returns:
None
Raises:
ValueError: If `mesh` is not an instance of list.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.parent is None
assert mesh.topology == [2, 3]
assert mesh.process_group == [2, 4, 5, 0, 1, 3]
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
def __init__(self, mesh, parent=None):
_static_mode_check()
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
self._topology = _get_nested_list_shape(mesh)
self._processes = _flatten_nested_list(mesh)
# Every element of mesh must be >= 0.
assert min(self._processes) >= 0, ('All elements of mesh must be >= 0.')
unique_ids = set(self._processes)
assert len(unique_ids) == len(self._processes), (
'All elements of mesh must be unique.')
if parent is None:
# For root ProcessMesh, the ids of logical processes must be range
# from 0 to N-1, where N is the number of logical processes.
assert max(self._processes) == len(self._processes) - 1, (
'For root ProcessMesh, ids of logical processes must be range '
'from 0 to N-1, where N is the number of logical processes.')
parent_id = core.kNoneProcessMeshIndex()
assert len(_g_process_mesh_map.keys()) == 0, (
'The first ProcessMesh must be the root, which has no parent.')
else:
assert len(_g_process_mesh_map.keys()) > 0, (
'All ProcessMesh must have a parent except the root one.')
assert isinstance(parent, ProcessMesh), (
'parent must be an instance of ProcessMesh.')
parent_id = parent._desc.id
# All elements in mesh must belong to its parent
parent_ids = set(parent.process_group)
assert unique_ids <= parent_ids, (
'All elements in mesh must belong to its parent.')
self._desc = core.ProcessMeshDesc(self._topology, self._processes,
parent_id)
self._id = self._desc.id
self._parent_id = parent_id
assert self._id not in _g_process_mesh_map, (
"The ProcessMesh with id %d already exists." % self._id)
_g_process_mesh_map[self._id] = self
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
"""
return self._topology
@property
def process_group(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
"""
return self._processes
@property
def parent(self):
r"""
Get the parent ProcessMesh.
"""
if self._parent_id == core.kNoneProcessMeshIndex(): return None
assert self._parent_id in _g_process_mesh_map, (
"parent with id %d does not exist." % self._parent_id)
return _g_process_mesh_map[self._parent_id]
def set_placement(self, order):
"""
Set the map from logical processes to physical ones using the
user defined order.
Args:
order (list): order of the physical process ids.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mesh.set_placement([0, 1, 2, 3, 4, 5])
"""
assert self.parent is None, (
"This function can only be called by the root ProcessMesh.")
unique_ids = set(order)
assert isinstance(order, list)
assert len(unique_ids) == len(order), (
"All elements in order must be unique.")
assert min(order) == 0
assert max(order) == len(order) - 1, (
"All elements in order must be from 0 to N - 1, where N "
"is the number of physical processes.")
logical_order = self.process_group
global _user_defined_physical_map
assert _user_defined_physical_map is None, (
"This function can only be called once.")
_user_defined_physical_map = dict()
assert len(logical_order) == len(order)
for idx, l_id in enumerate(logical_order):
_user_defined_physical_map[l_id] = order[idx]
def __eq__(self, other):
assert other and isinstance(other, ProcessMesh)
if self.topology != other.topology or self.process_group != other.process_group:
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert len(tensor.shape) == len(dim_mapping)
mesh_dim = len(mesh.topology)
dim_set = set()
for i in range(len(dim_mapping)):
assert dim_mapping[i] == -1 or (dim_mapping[i] < mesh_dim and
dim_mapping[i] >= 0)
if dim_mapping[i] >= 0:
assert dim_mapping[i] not in dim_set
dim_set.add(dim_mapping[i])
def shard_tensor(x, mesh, dim_mapping):
"""
Add distributed attributes for a tensors.
Args:
x (Tensor): the tensor to process.
mesh (ProcessMesh): an instance of ProcessMesh to describe the topology of logical processes.
dim_mapping (list): a list to describe the mapping between `x` and `mesh`,
the dimension `i` of `x` is split across the dimension `dims_mapping[i]`, where -1 means
without parition along the corresponding dimension.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
dist.shard_tensor(x, mesh, [0, -1])
"""
_static_mode_check()
_dim_mapping_checker(x, mesh, dim_mapping)
attr_name = _append_attr_suffix('mesh_id')
x._set_attr(attr_name, mesh._id)
attr_name = _append_attr_suffix('dim_mapping')
x._set_attr(attr_name, dim_mapping)
return x
def set_shard_mask(x, mask):
"""
Set the mask for a tensor which mask out the tensor from some processes in its mesh.
Args:
x (Tensor): the tensor to process.
mask (list): a nested list. The shape of `mask` must be the same as the ProcessMesh belonging to
the tensor `x`. Every value of `mask` must be one or zero, where one means
the tenor `x` will be put on the corresponding logical process and zero means the tensor `x`
will not be put on the corresponding logical process.
For example, for a ProcessMesh represented by the 2-dimensional
array [[2, 4, 5], [0, 1, 3]], and a `mask` given by the
2-dimensional [[1, 0, 1], [0, 1, 0]],
then the tensor `x` will only be put on logical processes 2, 5 and 1.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
mask = [[1, 0, 1], [0, 1, 0]]
x = paddle.ones([4, 6])
dist.set_shard_mask(x, mask)
"""
_static_mode_check()
assert isinstance(mask, list)
attr_name = _append_attr_suffix('mask')
x._set_attr(attr_name, _flatten_nested_list(mask))
return x
def shard_op(op_fn, mesh, dim_mapping_dict, **kwargs):
"""
Call a functioin and add distributed attributes for ops added by the function.
Args:
op_fn (callable): a callable object of an API.
mesh (ProcessMesh): an instance of ProcessMesh specifies the topology of logical processes.
dim_mapping_dict (dict): a mapping from tensor's name to its dims_mapping.
The dim_mapping is a list to describe the mapping between a tensor and `mesh`,
the dimension `i` of the tensor is split across the dimension `dim_mapping[i]`,
where -1 means without parition along the corresponding dimension.
kwargs (dict): a dict of parameter passed to the function `op_fn`.
Returns:
list: the outputs of the function `op_fn`.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
kwargs = {'x': x, 'y': y}
dist.shard_op(paddle.add, mesh, None, **kwargs)
"""
_static_mode_check()
main_prog = paddle.fluid.default_main_program()
main_block = main_prog.global_block()
op_size = len(main_block.ops)
output = op_fn(**kwargs)
new_op_size = len(main_block.ops)
if dim_mapping_dict is None: dim_mapping_dict = dict()
for idx in range(op_size, new_op_size):
op = main_block.ops[idx]
attr_name = _append_attr_suffix('mesh_id')
op._set_attr(attr_name, mesh._id)
for var_name in dim_mapping_dict.keys():
assert var_name in op.output_arg_names + op.input_arg_names
attr_name = _append_attr_suffix(var_name)
if var_name in op.input_arg_names:
# we use the prefix "IN_" to indicates an input argument name
attr_name = "IN_" + attr_name
else:
# we use the prefix "OUT_" to indicates an input argument name
attr_name = "OUT_" + attr_name
op._set_attr(attr_name, dim_mapping_dict[var_name])
if isinstance(output, Variable):
output = [output]
return list(output)
def set_offload_device(x, device):
"""
Set the device that the tensor `x` will be put on.
Args:
x (tensor): the tensor to process.
device (str): the device that the tensor `x` will be put on, e.g., 'cpu'.
Returns:
Tensor: the tensor `x` itself.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
x = paddle.ones([4, 6])
dist.set_offload_device(x, 'cpu')
"""
_static_mode_check()
attr_name = _append_attr_suffix("offload_device")
x._set_attr(attr_name, device)
return x
def set_pipeline_stage(stage):
"""
Set the pipeline stage of the following ops.
Args:
stage (int): the pipeline stage the following ops belonging to.
Returns:
None.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
dist.set_pipeline_stage(0)
"""
from paddle.fluid.framework import _set_pipeline_stage
_static_mode_check()
_set_pipeline_stage(stage)
......@@ -72,6 +72,7 @@ _dygraph_tracer_ = None
_global_expected_place_ = None
_current_device = None
global_prog_seed = 0
_current_pipeline_stage = None
_global_flags_ = core.globals()
......@@ -239,6 +240,11 @@ def _static_only_(func):
return __impl__
def _set_pipeline_stage(stage):
global _current_pipeline_stage
_current_pipeline_stage = stage
# NOTE(zhiqiu): This decorator is used for the APIs of Variable which is only
# used to make Variable and VarBase has same interfaces, like numpy. Since VarBase is not exposed in our
# official docments, logically, we want to keep VarBase and logically consistent. While, actually,
......@@ -1873,6 +1879,86 @@ class Variable(object):
type='size', inputs={'Input': [self]}, outputs={'Out': [output]})
return output
def _set_attr(self, name, val):
"""
Set the value of attribute by attribute's name.
Args:
name(str): the attribute name.
val(int|str|list): the value of the attribute.
"""
self._update_desc_attr(name, val)
def _has_attr(self, name):
"""
Whether this Variable has the attribute with the name `name` or not.
Args:
name(str): the attribute name.
Returns:
bool: True if has this attribute.
"""
return self.desc.has_attr(name)
def _remove_attr(self, name):
self.desc.remove_attr(name)
def _update_desc_attr(self, name, val):
"""
Update the value of desc's attribute by attribute's name.
Args:
name(str): the attribute name.
val(int|str|list): the value of the attribute.
"""
self.desc._set_attr(name, val)
@property
def attr_names(self):
"""Get the names of all attributes defined."""
return self.desc.attr_names()
def _get_attr(self, name):
"""
Get the attribute by name.
Args:
name(str): the attribute name.
Returns:
int|str|list: The attribute value. The return value
can be any valid attribute type.
"""
return self.desc.attr(name)
@property
def process_mesh(self):
"""
Get the process mesh belonging to this Variable.
"""
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
from paddle.distributed.auto_parallel.interface import ProcessMesh
mesh_attr_name = 'mesh_id' + core.kAutoParallelSuffix()
mesh_id = self.desc.attr(mesh_attr_name)
return _g_process_mesh_map[mesh_id]
@property
def shard_mask(self):
"""
Get shard_mask belonging to this Variable.
"""
mask_attr_name = 'mask' + core.kAutoParallelSuffix()
return self.desc.attr(mask_attr_name)
@property
def offload_device(self):
"""
Get the offload device of this Variable.
"""
offload_attr_name = 'offload_device' + core.kAutoParallelSuffix()
return self.desc.attr(offload_attr_name)
def get_all_op_protos():
"""
......@@ -2077,6 +2163,11 @@ class Operator(object):
"The Attr(force_cpu) of Op(%s) will be deprecated in the future, "
"please use 'device_guard' instead. 'device_guard' has higher priority when they are "
"used at the same time." % type)
if _current_pipeline_stage is not None:
pipeline_attr_name = 'pipeline_stage' + core.kAutoParallelSuffix(
)
self._update_desc_attr(pipeline_attr_name,
_current_pipeline_stage)
def find_name(var_list, name):
for var_name in var_list:
......@@ -2548,6 +2639,31 @@ class Operator(object):
return False
@property
def process_mesh(self):
"""
Get the process mesh belonging to this Operator.
"""
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
mesh_attr_name = 'mesh_id' + core.kAutoParallelSuffix()
mesh_id = self.attr(mesh_attr_name)
return _g_process_mesh_map[mesh_id]
def dims_mapping(self, name):
"""
Get the dims_mapping for the op's var named `name`.
"""
dims_mapping_attr_name = name + core.kAutoParallelSuffix()
return self.attr(dims_mapping_attr_name)
@property
def pipeline_stage(self):
"""
Get pipeline stage of the Operator.
"""
pipeline_stage_attr_name = 'pipeline_stage' + core.kAutoParallelSuffix()
return self.desc.attr(pipeline_stage_attr_name)
class Block(object):
"""
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import functools
import operator
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn
import paddle.distributed as dist
paddle.enable_static()
def _flatten_nested_list(nested_list):
result = functools.reduce(operator.iconcat, nested_list, [])
return result
def _append_attr_suffix(name):
return name + core.kAutoParallelSuffix()
LAST_PP_STAGE = 3
MASK = [[0, 1], [1, 0], [1, 1]]
MESH = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]])
class SimpleNet(nn.Layer):
def __init__(self, vocab_size=128, hidden_size=4):
super(SimpleNet, self).__init__()
self.mesh = MESH
self.mesh.set_placement([5, 4, 3, 2, 1, 0])
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.dense1 = nn.Linear(hidden_size, hidden_size)
self.dense2 = nn.Linear(hidden_size, hidden_size // 2)
def forward(self, x, y):
x = dist.shard_tensor(x, self.mesh, dim_mapping=[0, -1])
x = dist.set_shard_mask(x, MASK)
emb_out = self.word_embeddings(x)
dist.set_pipeline_stage(LAST_PP_STAGE)
y = dist.shard_tensor(y, self.mesh, dim_mapping=[0, -1])
dist.set_offload_device(y, "gpu:3")
linear1 = self.dense1(y)
out = self.dense2(linear1)
return x, y, self.mesh
class TestAutoParallelAPI(unittest.TestCase):
def test_api(self):
net = SimpleNet()
data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
data2 = fluid.layers.fill_constant(
shape=[2, 4], value=2, dtype="float32")
data3 = fluid.layers.fill_constant(
shape=[2, 4], value=4, dtype="float32")
x, y, mesh = net.forward(data1, data2)
mesh_attr = _append_attr_suffix('mesh_id')
x_mesh_id = x._get_attr(mesh_attr)
self.assertEqual(x_mesh_id, mesh._id)
x_mesh = x.process_mesh
allatts = x.attr_names
self.assertEqual(x_mesh, mesh)
shard_mask_attr = _append_attr_suffix('mask')
self.assertEqual(
x._get_attr(shard_mask_attr), _flatten_nested_list(MASK))
self.assertEqual(x.shard_mask, _flatten_nested_list(MASK))
offload_attr = _append_attr_suffix('offload_device')
self.assertEqual(y._get_attr(offload_attr), "gpu:3")
self.assertEqual(y.desc.has_attr(offload_attr), True)
self.assertEqual(y.offload_device, "gpu:3")
y._remove_attr(offload_attr)
self.assertEqual(y._has_attr(offload_attr), False)
ops = paddle.static.default_main_program().block(0).ops
first_op = ops[0]
last_op = ops[-1]
self.assertEqual(last_op.pipeline_stage, LAST_PP_STAGE)
DIMS_MAPPING1 = [0, 1, -1]
DIMS_MAPPING2 = [-1, 2, 0]
kwargs = {'x': data2, 'y': data3}
dist.shard_op(
paddle.add,
mesh=mesh,
dim_mapping_dict={
data2.name: DIMS_MAPPING1,
data3.name: DIMS_MAPPING2
},
**kwargs)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
self.assertEqual(last_op.process_mesh, mesh)
attr_name = "IN_" + data2.name
attr_name = _append_attr_suffix(attr_name)
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING1)
attr_name = "IN_" + data3.name
attr_name = _append_attr_suffix(attr_name)
self.assertEqual(last_op.attr(attr_name), DIMS_MAPPING2)
def test_process_mesh(self):
mesh1 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=MESH)
mesh2 = dist.ProcessMesh([[0, 1, 2], [3, 4, 5]], parent=mesh1)
mesh3 = dist.ProcessMesh([[0, 1], [2, 3]], parent=mesh1)
mesh4 = dist.ProcessMesh([[2, 3], [4, 5]], parent=mesh1)
self.assertEqual(MESH.parent, None)
self.assertEqual(mesh1.parent, MESH)
self.assertEqual(mesh1._desc.parent, MESH._id)
self.assertEqual(mesh3.parent, mesh1)
self.assertEqual(mesh4.parent, mesh1)
self.assertEqual(mesh1, mesh2)
self.assertNotEqual(mesh3, mesh4)
self.assertEqual(mesh2._id, mesh2._desc.id)
self.assertEqual(mesh3.topology, mesh3._desc.topology)
self.assertEqual(mesh3.topology, [2, 2])
self.assertEqual(mesh3.process_group, [0, 1, 2, 3])
self.assertEqual(mesh4.process_group, mesh4._desc.process_group)
if __name__ == '__main__':
unittest.main()
......@@ -164,6 +164,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel',
'paddle.distributed.fleet.meta_parallel.pp_utils',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel',
'paddle.framework',
'paddle.jit',
'paddle.jit.dy2static',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册