Unverified commit 93d862b0, authored by Yulong Ao, committed by GitHub

Add auto completion module for auto parallel (#34813)

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* add dist

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update, test=develop

* update

* update

* update

* update

* update

* update, test=develop

* update, test=develop

* update

* update

* delete unused proto

* restore op_desc

* restore type_defs

* update var_desc

* remove dims_mapping for proto_pybind

* update interface.py

* update framework.py

* update

* update

* add auto_parallel dir

* mv to paddle.distributed

* add shard_xx api

* add distributed attrs for var

* add ut, test=develop

* [WIP] Add the auto completion feature and related codes

* [WIP] Improve the auto completion and related codes

* [WIP] Make the auto completion to support data-parallel

* [WIP] Make the completion support mp and dp+mp

* [WIP] Refactor auto completion unit test for MLP

* [WIP] Refactor the implementation of DistributedOperatorImpl

* [WIP] Improve dims_mapping update rule and fix a bug

* [WIP] Support auto completion for one transformer decoder layer

* [WIP] Add a minor change

* [WIP] Fix a bug within the unit test

* Shard XShape tensor, add embedding completion and refactor code

* Add the distributed_operators dir to setup.py.in

* Improve the completion process and add the unittest for gpt

* fix process_mesh ut

* fix process_mesh ut

* update

* update, test=develop

* Add support for automatically completing distributed attrs of special ops

* update

* update

* update

* fix doc sample codes, test=develop

* improve coverage, test=develop

* add static_mode check, test=develop

* Model the cluster for cost model and physical mapping

* update, test=develop

* add set_placement, test=develop

* Add the check to make sure the candidate tensors' size is greater than zero

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update doc, test=develop

* update, test=develop

* Auto mark dist attrs annotated by user

* update ndarray to nested list, test=develop

* update, test=develop

* Add auto-completion module for auto-parallel (based on PR#33804)

* Remove unnecessary files

* Remove unrelated files for the auto completion pr

* Update the unit test to improve the coverage

* Modify codes based on reviews

* Minor changes for CI

* Improve some codes based on new comments

* Fix bugs caused by shallow copy in attributes.py
* Improve amend_distributed_attr_for_program in context.py
* Other changes for weihang's comments
Co-authored-by: sandyhouse <lilong12@baidu.com>
Parent e8f146a9
...@@ -353,6 +353,14 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_;
need_update_ = true;
// When creating a graph from a program, the creation of an op node will create
// a new OpDesc instead of referring to the original one. To find the original
// OpDesc of the op node, the id has to be copied to the new OpDesc. The var
// node has the same situation, but the default copy constructor can copy the
// id automatically.
id_ = op_desc.id_;
}
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
......
...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <atomic>
#include <string>
#include <unordered_map>
#include <utility>
...@@ -151,6 +152,18 @@ class OpDesc {
const BlockDesc *Block() const { return this->block_; }
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> id{0};
return ++id;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
uint64_t Id() { return id_; }
private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
...@@ -173,6 +186,8 @@ class OpDesc {
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false};
uint64_t id_ = GenerateId();
};
} // namespace framework
} // namespace paddle
...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <atomic>
#include <string>
#include <vector>
...@@ -150,6 +151,17 @@ class VarDesc {
Attribute GetAttr(const std::string &name) const;
// This thread-safe implementation seems to be redundant since the neural
// networks are usually constructed in a single thread.
static uint64_t GenerateId() {
static std::atomic<std::uint64_t> uid{0};
return ++uid;
}
// Note: the identity is only used as a key for referring to its
// distributed attribute now.
uint64_t Id() { return id_; }
private:
const proto::VarType::TensorDesc &tensor_desc() const;
std::vector<proto::VarType::TensorDesc> tensor_descs() const;
...@@ -158,6 +170,7 @@ class VarDesc {
proto::VarDesc desc_;
AttributeMap attrs_;
uint64_t id_ = GenerateId();
};
bool operator==(const VarDesc &left, const VarDesc &right);
......
...@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"
namespace paddle {
...@@ -202,6 +201,7 @@ void BindVarDsec(pybind11::module *m) {
.def("attr_names", &pd::VarDesc::AttrNames)
.def("_set_attr", &pd::VarDesc::SetAttr)
.def("remove_attr", &pd::VarDesc::RemoveAttr)
.def("id", &pd::VarDesc::Id)
.def("attr", &pd::VarDesc::GetAttr); .def("attr", &pd::VarDesc::GetAttr);
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", ""); pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
...@@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -294,6 +294,7 @@ void BindOpDesc(pybind11::module *m) {
.def("serialize_to_string", SerializeMessage<pd::OpDesc>) .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", [](pd::OpDesc &self) { return self.Block(); }, .def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("id", &pd::OpDesc::Id)
.def("inputs", &pd::OpDesc::Inputs) .def("inputs", &pd::OpDesc::Inputs)
.def("outputs", &pd::OpDesc::Outputs); .def("outputs", &pd::OpDesc::Outputs);
} }
......
...@@ -57,7 +57,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv # noqa: F401
from . import cloud_utils # noqa: F401
from . import utils # noqa: F401
__all__ = [ #noqa
__all__ = [ # noqa
"spawn", "spawn",
"scatter", "scatter",
"broadcast", "broadcast",
......
...@@ -18,5 +18,6 @@ from .interface import set_shard_mask # noqa: F401
from .interface import set_offload_device # noqa: F401
from .interface import set_pipeline_stage # noqa: F401
from .interface import ProcessMesh # noqa: F401
from .completion import complete_annotation # noqa: F401
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
class TensorDistributedAttribute:
def __init__(self, owner_tensor, owner_context):
self._owner_tensor = owner_tensor
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = None
self._shard_mask = None
self._offload_device = None
self._shape = None
self._is_annotated = {}
self._is_parameter = False
def get_owner_tensor(self):
return self._owner_tensor
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_dims_mapping(self):
return self._dims_mapping
def set_dims_mapping(self, dims_mapping):
self._dims_mapping = copy.deepcopy(dims_mapping)
def get_shard_mask(self):
return self._shard_mask
def set_shard_mask(self, shard_mask):
self._shard_mask = copy.deepcopy(shard_mask)
def get_offload_device(self):
return self._offload_device
def set_offload_device(self, offload_device):
self._offload_device = copy.deepcopy(offload_device)
def get_shape(self):
return self._shape
def set_shape(self, shape):
self._shape = copy.deepcopy(shape)
def is_annotated(self, dist_attr_name):
return self._is_annotated.get(dist_attr_name, False)
def mark_as_annotated(self, dist_attr_name):
self._is_annotated[dist_attr_name] = True
def is_parameter(self):
return self._is_parameter
def mark_as_parameter(self):
self._is_parameter = True
def is_valid(self):
tensor_shape = self.get_owner_tensor().desc.shape()
if len(tensor_shape) != len(self.get_dims_mapping()):
return False
for i in range(len(self.get_dims_mapping())):
if self.get_dims_mapping()[i] < -1 or self.get_dims_mapping()[
i] >= len(self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if self.get_dims_mapping().count(i) > 1:
return False
return True
def __str__(self):
str = "{{tensor name: {}, tensor id: {}".format(
self.get_owner_tensor().desc.name(),
self.get_owner_tensor().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
str += ", is_parameter: {}".format(self._is_parameter)
if self.is_annotated("dims_mapping"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", dims_mapping ({}): {}".format(annotated_str,
self.get_dims_mapping())
if self.is_annotated("shard_mask"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", shard_mask ({}): {}".format(annotated_str,
self.get_shard_mask())
if self.is_annotated("offload_device"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", offload_device ({}): {} }}".format(annotated_str,
self.get_offload_device())
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner tensor and context
if k == "_owner_tensor" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
class OperatorDistributedAttribute:
def __init__(self, owner_op, owner_context):
self._owner_op = owner_op
self._owner_context = owner_context
self._process_mesh = None
self._dims_mapping = {}
self._shapes = {}
self._is_annotated = {}
self._is_parameters = {}
self._pipeline_stage = None
self._impl_idx = None
def get_owner_op(self):
return self._owner_op
def get_owner_context(self):
return self._owner_context
def get_process_mesh(self):
return self._process_mesh
def set_process_mesh(self, process_mesh):
self._process_mesh = copy.deepcopy(process_mesh)
def get_input_dims_mapping(self, name):
return self._dims_mapping.get("IN_" + name, None)
def set_input_dims_mapping(self, name, dims_mapping):
self._dims_mapping["IN_" + name] = copy.deepcopy(dims_mapping)
def get_output_dims_mapping(self, name):
return self._dims_mapping.get("OUT_" + name, None)
def set_output_dims_mapping(self, name, dims_mapping):
self._dims_mapping["OUT_" + name] = copy.deepcopy(dims_mapping)
def get_impl_idx(self):
return self._impl_idx
def set_impl_idx(self, impl_idx):
self._impl_idx = impl_idx
def get_pipeline_stage(self):
return self._pipeline_stage
def set_pipeline_stage(self, pipeline_stage):
self._pipeline_stage = copy.deepcopy(pipeline_stage)
def get_input_shape(self, name):
return self._shapes.get("IN_" + name, None)
def set_input_shape(self, name, shape):
self._shapes["IN_" + name] = copy.deepcopy(shape)
def get_output_shape(self, name):
return self._shapes.get("OUT_" + name, None)
def set_output_shape(self, name, shape):
self._shapes["OUT_" + name] = copy.deepcopy(shape)
def is_annotated(self, attr_name):
return self._is_annotated.get(attr_name, False)
def mark_as_annotated(self, attr_name):
self._is_annotated[attr_name] = True
def is_annotated_input_dims_mapping(self, name):
return self._is_annotated.get("IN_" + name, False)
def mark_as_annotated_input_dims_mapping(self, name):
self._is_annotated["IN_" + name] = True
def is_annotated_output_dims_mapping(self, name):
return self._is_annotated.get("OUT_" + name, False)
def mark_as_annotated_output_dims_mapping(self, name):
self._is_annotated["OUT_" + name] = True
def is_parameter(self, name):
return self._is_parameters.get(name, False)
def mark_as_parameter(self, name):
self._is_parameters[name] = True
def is_valid(self):
for name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(name)
shape = self.get_input_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
for name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(name)
shape = self.get_output_shape(name)
if len(shape) != len(dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(
self.get_process_mesh().topology):
return False
for i in range(len(self.get_process_mesh().topology)):
if dims_mapping.count(i) > 1:
return False
return True
def __str__(self):
str = "{{op type: {}, op id: {}".format(self.get_owner_op().desc.type(),
self.get_owner_op().desc.id())
if self.is_annotated("process_mesh"):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
str += ", process_mesh ({}): {}".format(annotated_str,
self.get_process_mesh())
for arg_name in self.get_owner_op().desc.input_arg_names():
dims_mapping = self.get_input_dims_mapping(arg_name)
if self.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (input, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
for arg_name in self.get_owner_op().desc.output_arg_names():
dims_mapping = self.get_output_dims_mapping(arg_name)
if self.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated"
else:
annotated_str = "non-annotated"
if self.is_parameter(arg_name):
is_parameter_str = "parameter"
else:
is_parameter_str = "non-parameter"
str += ", {}'s dims_mapping (output, {}, {}): {}".format(
arg_name, annotated_str, is_parameter_str, dims_mapping)
str += ", pipeline stage: {}".format(self._pipeline_stage)
str += ", dist_impl idx: {} }}".format(self._impl_idx)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the owner op and context
if k == "_owner_op" or k == "_owner_context":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
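A minimal usage sketch (illustrative, not part of this diff) of how a tensor's distributed attribute is filled in; `tensor`, `process_mesh` and `dist_ctx` are assumed to be a framework Variable, a ProcessMesh from interface.py and a DistributedContext from context.py, respectively:
# Illustrative sketch only; `tensor`, `process_mesh` and `dist_ctx` are assumptions.
tensor_dist_attr = TensorDistributedAttribute(tensor, dist_ctx)
tensor_dist_attr.set_shape(tensor.desc.shape())
tensor_dist_attr.set_process_mesh(process_mesh)      # stored as a deep copy
tensor_dist_attr.set_dims_mapping([0, -1])           # shard dim 0 along mesh axis 0
tensor_dist_attr.mark_as_annotated("dims_mapping")   # record it as user-annotated
assert tensor_dist_attr.is_valid()                   # rank and mesh-axis checks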
This diff has been collapsed.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import copy
from collections import defaultdict
from paddle.fluid import framework
from .attribute import TensorDistributedAttribute
from .attribute import OperatorDistributedAttribute
from .utils import append_distributed_attr_suffix
# There is always a default context for the user, and the user can replace it with another one.
DEFAULT_DISTRIBUTED_CONTEXT = None
def get_default_distributed_context():
global DEFAULT_DISTRIBUTED_CONTEXT
if DEFAULT_DISTRIBUTED_CONTEXT is None:
dist_context = DistributedContext()
set_default_distributed_context(dist_context)
return DEFAULT_DISTRIBUTED_CONTEXT
def set_default_distributed_context(dist_context):
global DEFAULT_DISTRIBUTED_CONTEXT
DEFAULT_DISTRIBUTED_CONTEXT = dist_context
class DistributedContext:
"""
DistributedContext is used to collect related distributed information for program and graph.
One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
"""
def __init__(self):
self._is_initialized_for_program = False
self._is_initialized_for_graph = False
self._tensor_distributed_attr_map_for_program = {}
self._op_distributed_attr_map_for_program = {}
self._tensor_distributed_attr_map_for_graph = {}
self._op_distributed_attr_map_for_graph = {}
def is_initialized_for_program(self):
return self._is_initialized_for_program
def is_initialized_for_graph(self):
return self._is_initialized_for_graph
def get_tensor_distributed_attr_for_program(self, tensor):
tensor_id = tensor.desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program.get(
tensor_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_program(self, tensor, tensor_dist_attr):
tensor_id = tensor.desc.id()
self._tensor_distributed_attr_map_for_program[
tensor_id] = tensor_dist_attr
def get_op_distributed_attr_for_program(self, op):
op_id = op.desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program.get(op_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_program(self, op, op_dist_attr):
op_id = op.desc.id()
self._op_distributed_attr_map_for_program[op_id] = op_dist_attr
def get_tensor_distributed_attr_for_graph(self, tensor_node):
tensor_node_id = tensor_node.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_graph.get(
tensor_node_id, None)
return tensor_dist_attr
def set_tensor_distributed_attr_for_graph(self, tensor_node,
tensor_dist_attr):
tensor_node_id = tensor_node.id()
self._tensor_distributed_attr_map_for_graph[
tensor_node_id] = tensor_dist_attr
def get_op_distributed_attr_for_graph(self, op_node):
op_node_id = op_node.id()
op_dist_attr = self._op_distributed_attr_map_for_graph.get(op_node_id,
None)
return op_dist_attr
def set_op_distributed_attr_for_graph(self, op_node, op_dist_attr):
op_node_id = op_node.id()
self._op_distributed_attr_map_for_graph[op_node_id] = op_dist_attr
def initialize_distributed_attr_for_program(self, program):
if self._is_initialized_for_program:
return
for block in program.blocks:
for tensor in block.vars.values():
# Since only tensors have distributed attributes, it's better to make sure var is a tensor
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is None:
tensor_dist_attr = TensorDistributedAttribute(tensor, self)
self._copy_distributed_attr_from_tensor_desc(
tensor.desc, tensor_dist_attr)
self.set_tensor_distributed_attr_for_program(
tensor, tensor_dist_attr)
tensor_dist_attr.set_shape(tensor.desc.shape())
if tensor_dist_attr.get_process_mesh() is not None:
tensor_dist_attr.mark_as_annotated("process_mesh")
if tensor_dist_attr.get_dims_mapping() is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
tensor_dist_attr.set_dims_mapping(tensor_dims_mapping)
else:
tensor_dist_attr.mark_as_annotated("dims_mapping")
if isinstance(tensor, framework.Parameter):
tensor_dist_attr.mark_as_parameter()
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is None:
op_dist_attr = OperatorDistributedAttribute(op, self)
self._copy_distributed_attr_from_op_desc(op.desc,
op_dist_attr)
self.set_op_distributed_attr_for_program(op, op_dist_attr)
# Default distributed implementation for all operators
# This will be updated during the completion process
op_dist_attr.set_impl_idx(-2)
if op_dist_attr.get_process_mesh() is not None:
op_dist_attr.mark_as_annotated("process_mesh")
for tensor_name in op.input_arg_names:
# There may be a better way to find the tensor by name
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_input_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_input_dims_mapping(tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_input_dims_mapping(tensor_name,
tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_input_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
for tensor_name in op.output_arg_names:
tensor = op.block._var_recursive(tensor_name)
op_dist_attr.set_output_shape(tensor_name,
tensor.desc.shape())
if op_dist_attr.get_output_dims_mapping(
tensor_name) is None:
tensor_dims_mapping = [
-1 for _ in range(len(tensor.desc.shape()))
]
op_dist_attr.set_output_dims_mapping(
tensor_name, tensor_dims_mapping)
else:
op_dist_attr.mark_as_annotated_output_dims_mapping(
tensor_name)
if isinstance(tensor, framework.Parameter):
op_dist_attr.mark_as_parameter(tensor_name)
self._is_initialized_for_program = True
def finalize_distributed_attr_for_program(self, program):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before finalization."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = self.get_tensor_distributed_attr_for_program(
tensor)
if tensor_dist_attr is not None:
self._store_distributed_attr_to_tensor_desc(
tensor.desc, tensor_dist_attr)
for op in block.ops:
op_dist_attr = self.get_op_distributed_attr_for_program(op)
if op_dist_attr is not None:
self._store_distributed_attr_to_op_desc(op.desc,
op_dist_attr)
def _copy_distributed_attr_from_tensor_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
attr_name = append_distributed_attr_suffix("dim_mapping")
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_dims_mapping(copied_dims_mapping)
attr_name = append_distributed_attr_suffix("mask")
if desc.has_attr(attr_name):
shard_mask = desc.attr(attr_name)
copied_shard_mask = copy.deepcopy(shard_mask)
dist_attr.set_shard_mask(copied_shard_mask)
attr_name = append_distributed_attr_suffix("offload_device")
if desc.has_attr(attr_name):
offload_device = desc.attr(attr_name)
copied_offload_device = copy.deepcopy(offload_device)
dist_attr.set_offload_device(copied_offload_device)
def _copy_distributed_attr_from_op_desc(self, desc, dist_attr):
from paddle.distributed.auto_parallel.interface import _g_process_mesh_map
attr_name = append_distributed_attr_suffix("mesh_id")
if desc.has_attr(attr_name):
mesh_id = desc.attr(attr_name)
process_mesh = _g_process_mesh_map[mesh_id]
copied_process_mesh = copy.deepcopy(process_mesh)
dist_attr.set_process_mesh(copied_process_mesh)
for tensor_name in desc.input_arg_names():
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_input_dims_mapping(tensor_name,
copied_dims_mapping)
for tensor_name in desc.output_arg_names():
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
if desc.has_attr(attr_name):
dims_mapping = desc.attr(attr_name)
copied_dims_mapping = copy.deepcopy(dims_mapping)
dist_attr.set_output_dims_mapping(tensor_name,
copied_dims_mapping)
attr_name = append_distributed_attr_suffix("pipeline_stage")
if desc.has_attr(attr_name):
pipeline_stage = desc.attr(attr_name)
copied_pipeline_stage = copy.deepcopy(pipeline_stage)
dist_attr.set_pipeline_stage(copied_pipeline_stage)
def _store_distributed_attr_to_tensor_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
dims_mapping = dist_attr.get_dims_mapping()
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("dim_mapping")
desc._set_attr(attr_name, dims_mapping)
shard_mask = dist_attr.get_shard_mask()
if shard_mask is not None:
attr_name = append_distributed_attr_suffix("mask")
desc._set_attr(attr_name, shard_mask)
offload_device = dist_attr.get_offload_device()
if offload_device is not None:
attr_name = append_distributed_attr_suffix("offload_device")
desc._set_attr(attr_name, offload_device)
def _store_distributed_attr_to_op_desc(self, desc, dist_attr):
process_mesh = dist_attr.get_process_mesh()
if process_mesh is not None:
attr_name = append_distributed_attr_suffix("mesh_id")
desc._set_attr(attr_name, process_mesh._id)
for tensor_name in desc.input_arg_names():
dims_mapping = dist_attr.get_input_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("IN_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
for tensor_name in desc.output_arg_names():
dims_mapping = dist_attr.get_output_dims_mapping(tensor_name)
if dims_mapping is not None:
attr_name = append_distributed_attr_suffix("OUT_" + tensor_name)
desc._set_attr(attr_name, dims_mapping)
pipeline_stage = dist_attr.get_pipeline_stage()
if pipeline_stage is not None:
attr_name = append_distributed_attr_suffix("pipeline_stage")
desc._set_attr(attr_name, pipeline_stage)
def initialize_distributed_attr_for_graph(self, graph):
assert self._is_initialized_for_program, \
"The program must initialize its distributed attribute before its graph."
if self._is_initialized_for_graph:
return
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
tensor_dist_attr = self._tensor_distributed_attr_map_for_program[
tensor_id]
assert tensor_dist_attr is not None, \
"Tensor must have a distributed attribute after the initialization for program."
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self.set_tensor_distributed_attr_for_graph(node,
new_tensor_dist_attr)
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self._op_distributed_attr_map_for_program[op_id]
assert op_dist_attr is not None, \
"Operator must have a distributed attribute after the initialization for program."
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self.set_op_distributed_attr_for_graph(node, new_op_dist_attr)
self._is_initialized_for_graph = True
def clear_distributed_attr_for_program(self):
self._tensor_distributed_attr_map_for_program.clear()
self._op_distributed_attr_map_for_program.clear()
def clear_distributed_attr_for_graph(self):
self._tensor_distributed_attr_map_for_graph.clear()
self._op_distributed_attr_map_for_graph.clear()
def copy_distribute_attr_from_graph_to_program(self, graph, program):
assert self._is_initialized_for_program and self._is_initialized_for_graph, \
"The distribute attributes must be initialized both in its program and graph"
updated_tensors = {}
all_nodes = graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
tensor_desc = node.var()
tensor_id = tensor_desc.id()
updated = updated_tensors.get(tensor_desc.name(), False)
# If a var has multiple var nodes in the graph, only use the first one for now
if not updated:
tensor_dist_attr = self.get_tensor_distributed_attr_for_graph(
node)
new_tensor_dist_attr = copy.deepcopy(tensor_dist_attr)
self._tensor_distributed_attr_map_for_program[
tensor_id] = new_tensor_dist_attr
updated_tensors[tensor_desc.name()] = True
if node.is_op() and node.op() is not None:
op_desc = node.op()
op_id = op_desc.id()
op_dist_attr = self.get_op_distributed_attr_for_graph(node)
new_op_dist_attr = copy.deepcopy(op_dist_attr)
self._op_distributed_attr_map_for_program[
op_id] = new_op_dist_attr
def amend_distributed_attr_for_program(self):
for attr in self._tensor_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Tensor's distributed attribute {} is not valid".format(attr)
tensor_shape = attr.get_shape()
dims_mapping = attr.get_dims_mapping()
process_mesh_shape = attr.get_process_mesh().topology
# If the size of a tensor dimension is smaller than the number of processes on
# the mesh axis it is mapped to, amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[
i]] > tensor_shape[i]:
dims_mapping[i] = -1
for attr in self._op_distributed_attr_map_for_program.values():
assert attr.is_valid(), \
"Operator's distributed attribute {} is not valid".format(attr)
for arg_name in attr.get_owner_op().desc.input_arg_names():
tensor_shape = attr.get_input_shape(arg_name)
dims_mapping = attr.get_input_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If the size of a tensor dimension is smaller than the number of processes on
# the mesh axis it is mapped to, amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
for arg_name in attr.get_owner_op().desc.output_arg_names():
tensor_shape = attr.get_output_shape(arg_name)
dims_mapping = attr.get_output_dims_mapping(arg_name)
process_mesh_shape = attr.get_process_mesh().topology
# If the size of a tensor dimension is smaller than the number of processes on
# the mesh axis it is mapped to, amend the dimension mapping to -1. (Is this really OK?)
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and process_mesh_shape[
dims_mapping[i]] > tensor_shape[i]:
dims_mapping[i] = -1
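A small worked example (not part of this diff) of the amendment rule above: a tensor dimension keeps its mapping only if it has at least as many elements as there are processes on the mesh axis it is mapped to.
# Worked example of the amendment loop in amend_distributed_attr_for_program.
tensor_shape = [2, 1024]         # a tensor whose first dimension is very small
process_mesh_shape = [4, 2]      # topology of a 4 x 2 process mesh
dims_mapping = [0, 1]            # dim 0 -> mesh axis 0, dim 1 -> mesh axis 1
for i in range(len(tensor_shape)):
    if dims_mapping[i] != -1 and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
        dims_mapping[i] = -1
print(dims_mapping)              # [-1, 1]: 4 > 2 resets dim 0; 2 <= 1024 keeps dim 1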
...@@ -13,8 +13,9 @@
# limitations under the License.
import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import in_dygraph_mode
...@@ -237,6 +238,23 @@ class ProcessMesh(object):
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.process_group)
return str
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
# No need to copy the underlying process mesh description
if k == "_desc":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
return result
def _dim_mapping_checker(tensor, mesh, dim_mapping):
assert len(tensor.shape) == len(dim_mapping)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from .common import find_best_compatible_distributed_operator_impl
from . import dist_embedding
from . import dist_matmul
from . import dist_reshape
from . import dist_softmax
from . import dist_transpose
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
DISTRIBUTED_OPERATORS = {}
class DistributedOperator:
def __init__(self):
self._impls = []
self._name = None
def register_impl(self, dist_impl):
self._impls.append(dist_impl)
def get_impl(self, impl_idx):
return self._impls[impl_idx]
def get_impls(self):
return self._impls
class DistributedOperatorImpl:
def __init__(self):
self._name = None
def forward(self, dist_ctx, *args, **kwargs):
raise NotImplementedError("Please Implement this method in Subclass.")
def backward(self, dist_ctx, *grad_outputs):
raise NotImplementedError("Please Implement this method in Subclass.")
def get_name(self):
return self._name
def is_process_mesh_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_input_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_output_compatible(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def is_compatible(self, op_dist_attr):
return self.is_process_mesh_compatible(op_dist_attr) \
and self.is_input_compatible(op_dist_attr) \
and self.is_output_compatible(op_dist_attr)
def update_dims_mapping(self, op_dist_attr):
raise NotImplementedError("Please Implement this method in Subclass.")
def register_distributed_operator(name, dist_op):
global DISTRIBUTED_OPERATORS
DISTRIBUTED_OPERATORS[name] = dist_op
def get_distributed_operator(name):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS.get(name, None)
def register_distributed_operator_impl(name, dist_impl):
dist_op = get_distributed_operator(name)
if dist_op is not None:
dist_op.register_impl(dist_impl)
else:
assert False, "Must register distributed operator first."
def get_distributed_operator_impl(name, impl_idx):
global DISTRIBUTED_OPERATORS
return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx)
def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
fwd=True):
"""
Here we just return the first compatible implementation.
This will be improved by cost model in the future.
"""
dist_op = get_distributed_operator(name)
if dist_op is None:
return None, -1
compatible_impls = []
impls = dist_op.get_impls()
if fwd:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_input_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
else:
for idx, impl in enumerate(impls):
if impl.is_process_mesh_compatible(op_dist_attr) \
and impl.is_output_compatible(op_dist_attr):
compatible_impls.append((impl, idx))
if compatible_impls:
best_compatible_impl, idx = compatible_impls[0]
else:
best_compatible_impl, idx = None, -1
return best_compatible_impl, idx
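A minimal sketch (not part of this diff) of how this registry is used by the dist_* modules that follow; DistributedMyOp, DistributedMyOpImpl and "my_op" are hypothetical names, and op_dist_attr is assumed to be an OperatorDistributedAttribute produced by the completion pass.
# Hypothetical registration mirroring the pattern of the dist_* modules below.
class DistributedMyOp(DistributedOperator):
    def __init__(self, name):
        super(DistributedMyOp, self).__init__()
        self._name = name
class DistributedMyOpImpl(DistributedOperatorImpl):
    def __init__(self, name):
        super(DistributedMyOpImpl, self).__init__()
        self._name = name
    def is_process_mesh_compatible(self, op_dist_attr):
        return True
    def is_input_compatible(self, op_dist_attr):
        return True
    def is_output_compatible(self, op_dist_attr):
        return True
register_distributed_operator("my_op", DistributedMyOp("my_op"))
register_distributed_operator_impl("my_op", DistributedMyOpImpl("replicate"))
# The completion pass then picks the first compatible implementation:
impl, impl_idx = find_best_compatible_distributed_operator_impl(
    "my_op", op_dist_attr, fwd=True)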
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedEmbedding(DistributedOperator):
def __init__(self, name):
super(DistributedEmbedding, self).__init__()
self._name = name
register_distributed_operator("lookup_table_v2",
DistributedEmbedding("embedding"))
# RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
ids_name = op_desc.input('Ids')[0]
w_name = op_desc.input('W')[0]
out_name = op_desc.output('Out')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(ids_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[ids_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[w_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("lookup_table_v2",
DistributedEmbeddingImpl("row_parallel"))
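To make the checks above concrete, here is a dims_mapping pattern that this row-parallel implementation accepts (an illustrative sketch, not part of this diff), assuming Ids has shape [batch, seq], W has shape [vocab, hidden] and a 2-D process mesh:
# Accepted by DistributedEmbeddingImpl("row_parallel") (illustrative values):
#   Ids dims_mapping: [0, -1]       only the batch dimension may be sharded
#   W   dims_mapping: [1, -1]       vocabulary rows sharded, embedding dim replicated
#   Out dims_mapping: [0, -1, -1]   only the batch dimension may be sharded
# A column-sharded weight such as W = [-1, 1] fails is_input_compatible.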
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
def _update_dims_mapping_for_matmul(op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_dims_mapping_len = len(x_dims_mapping)
y_dims_mapping_len = len(y_dims_mapping)
out_dims_mapping_len = len(out_dims_mapping)
# print("before", x_dims_mapping, y_dims_mapping, out_dims_mapping)
# Add dim mapping to make sure the length of dims_mapping is at least 2
if x_dims_mapping_len == 1:
x_dims_mapping.insert(0, -1)
if y_dims_mapping_len == 1:
y_dims_mapping.insert(1, -1)
# Deal with dim > 2 and take care of broadcasting
if out_dims_mapping_len > 2:
broadcast_x_dims_mapping = []
broadcast_y_dims_mapping = []
broadcast_out_dims_mapping = []
for i in range(out_dims_mapping_len - x_dims_mapping_len):
broadcast_x_dims_mapping.append(out_dims_mapping[i])
for i in range(x_dims_mapping_len - 2):
broadcast_x_dims_mapping.append(x_dims_mapping[i])
for i in range(out_dims_mapping_len - y_dims_mapping_len):
broadcast_y_dims_mapping.append(out_dims_mapping[i])
for i in range(y_dims_mapping_len - 2):
broadcast_y_dims_mapping.append(y_dims_mapping[i])
for i in range(out_dims_mapping_len - 2):
broadcast_out_dims_mapping.append(out_dims_mapping[i])
compatible_dims_mapping = compute_compatible_dims_mapping([
broadcast_x_dims_mapping, broadcast_y_dims_mapping,
broadcast_out_dims_mapping
])
assert compatible_dims_mapping is not None, "There is no compatible dim mapping."
for i in range(x_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - x_dims_mapping_len)
if x_dims_mapping[i] != compatible_dims_mapping[new_idx]:
x_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(y_dims_mapping_len - 2):
new_idx = i + (out_dims_mapping_len - y_dims_mapping_len)
if y_dims_mapping[i] != compatible_dims_mapping[new_idx]:
y_dims_mapping[i] = compatible_dims_mapping[new_idx]
changed = True
for i in range(out_dims_mapping_len - 2):
if out_dims_mapping[i] != compatible_dims_mapping[i]:
out_dims_mapping[i] = compatible_dims_mapping[i]
changed = True
# The following, which uses negative indices, works both when
# len(out_dims_mapping) > 2 and when len(out_dims_mapping) <= 2
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, y_dims_mapping], [-1, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [-2, -2])
if dim_changed:
changed = True
dim_changed = compute_compatible_and_update_dim_mapping(
[y_dims_mapping, out_dims_mapping], [-1, -1])
if dim_changed:
changed = True
# Remove the unnecessary dim mapping to make sure the length of dims_mapping is the same as its tensor's rank
if x_dims_mapping_len == 1:
x_dims_mapping.pop(0)
if y_dims_mapping_len == 1:
y_dims_mapping.pop(1)
# print("after", x_dims_mapping, y_dims_mapping, out_dims_mapping)
assert len(x_dims_mapping) == x_dims_mapping_len
assert len(y_dims_mapping) == y_dims_mapping_len
assert len(out_dims_mapping) == out_dims_mapping_len
return changed
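A worked example (not part of this diff) of the update above for a column-parallel matmul, assuming the compute_compatible_* helpers in ..utils resolve a -1 mapping against a sharded mapping to the sharded one:
# Illustrative dims_mapping update for matmul X[b, m, k] x Y[k, n] = Out[b, m, n]:
#   before: X = [-1, -1, -1]   Y = [-1, 0]   Out = [-1, -1, -1]
# Aligning Y[-1] with Out[-1] propagates the column shard to the output:
#   after:  X = [-1, -1, -1]   Y = [-1, 0]   Out = [-1, -1, 0]
# and _update_dims_mapping_for_matmul returns True because Out changed.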
class DistributedMatmul(DistributedOperator):
def __init__(self, name):
super(DistributedMatmul, self).__init__()
self._name = name
register_distributed_operator("matmul", DistributedMatmul("matmul"))
# ColumnParallel
class DistributedMatmulImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_dim_shard(y_dims_mapping[0]) or is_dim_replicate(y_dims_mapping[
1]):
return False
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_replicate(out_dims_mapping[-1]):
return False
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# RowParallel
class DistributedMatmulImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_replicate(x_dims_mapping[-1]):
return False
if is_dim_replicate(y_dims_mapping[-2]) or is_dim_shard(y_dims_mapping[
-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
# Other dimensions must be replicate except the batch dimension
for mapping in out_dims_mapping[1:-1]:
if is_dim_shard(mapping):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
# ReplicateParallel
class DistributedMatmulImpl2(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulImpl2, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl("matmul",
DistributedMatmulImpl0("column_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl1("row_parallel"))
register_distributed_operator_impl("matmul",
DistributedMatmulImpl2("replicate_parallel"))
class DistributedMatmulV2(DistributedOperator):
def __init__(self, name):
super(DistributedMatmulV2, self).__init__()
self._name = name
register_distributed_operator("matmul_v2", DistributedMatmulV2("matmul_v2"))
# ReplicateParallel
class DistributedMatmulV2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedMatmulV2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
y_name = op_desc.input('Y')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
y_dims_mapping = op_dist_attr.get_input_dims_mapping(y_name)
if is_dim_shard(x_dims_mapping[-1]):
return False
if is_valid_list_index(x_dims_mapping,
-2) and is_dim_shard(x_dims_mapping[-2]):
return False
if is_dim_shard(y_dims_mapping[-1]):
return False
if is_valid_list_index(y_dims_mapping,
-2) and is_dim_shard(y_dims_mapping[-2]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if is_dim_shard(out_dims_mapping[-1]):
return False
if is_valid_list_index(out_dims_mapping,
-2) and is_dim_shard(out_dims_mapping[-2]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
dim_changed = _update_dims_mapping_for_matmul(op_dist_attr)
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"matmul_v2", DistributedMatmulV2Impl("replicate_parallel"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedReshape2(DistributedOperator):
def __init__(self, name):
super(DistributedReshape2, self).__init__()
self._name = name
register_distributed_operator("reshape2", DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[-1]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
if is_dim_shard(x_dims_mapping[-1]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if len(x_dims_mapping) != len(out_dims_mapping) + 1:
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl("reshape2",
DistributedReshapeImpl0("add_one_dim_back"))
register_distributed_operator_impl(
"reshape2", DistributedReshapeImpl1("remove_one_dim_back"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedSoftmax(DistributedOperator):
def __init__(self, name):
super(DistributedSoftmax, self).__init__()
self._name = name
register_distributed_operator("softmax", DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
axis = op_desc.attr('axis')
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
# print("softmax axis", axis)
if axis != -1 and axis != len(x_dims_mapping) - 1:
return False
if is_dim_shard(x_dims_mapping[axis]):
return False
return True
def is_output_compatible(self, op_dist_attr):
op_desc = op_dist_attr.get_owner_op().desc
out_name = op_desc.output('Out')[0]
axis = op_desc.attr('axis')
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if axis != -1 and axis != len(out_dims_mapping) - 1:
return False
if is_dim_shard(out_dims_mapping[axis]):
return False
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
register_distributed_operator_impl(
"softmax", DistributedSoftmaxImpl("replicate_last_axis"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from .common import DistributedOperator
from .common import DistributedOperatorImpl
from .common import register_distributed_operator
from .common import register_distributed_operator_impl
from ..utils import is_dim_shard
from ..utils import is_dim_replicate
from ..utils import is_valid_list_index
from ..utils import compute_compatible_dim_mapping
from ..utils import compute_compatible_dims_mapping
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedTranspose2(DistributedOperator):
def __init__(self, name):
super(DistributedTranspose2, self).__init__()
self._name = name
register_distributed_operator("transpose2", DistributedTranspose2("transpose2"))
class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__()
self._name = name
def is_process_mesh_compatible(self, op_dist_attr):
""" No restriction for now. """
return True
def is_input_compatible(self, op_dist_attr):
return True
def is_output_compatible(self, op_dist_attr):
return True
def update_dims_mapping(self, op_dist_attr):
changed = False
op_desc = op_dist_attr.get_owner_op().desc
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_shape_name = op_desc.output('XShape')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
x_shape_dims_mapping = op_dist_attr.get_output_dims_mapping(
x_shape_name)
perm = op_desc.attr('axis')
assert len(x_dims_mapping) == len(perm)
new_dims_mapping = [-1 for i in range(len(x_dims_mapping))]
for i in range(len(x_dims_mapping)):
new_dims_mapping[i] = x_dims_mapping[perm[i]]
for i in range(len(out_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[new_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
for i in range(len(x_dims_mapping)):
if x_dims_mapping[perm[i]] != new_dims_mapping[i]:
x_dims_mapping[perm[i]] = new_dims_mapping[i]
changed = True
for i in range(len(x_dims_mapping)):
x_shape_dims_mapping[i + 1] = x_dims_mapping[i]
return changed
register_distributed_operator_impl(
"transpose2", DistributedTranspose2Impl("same_mapping_transpose"))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import threading
import paddle.fluid.core as core
def is_valid_list_index(values, index):
    # Avoid shadowing the built-in name `list` and return the condition directly.
    return -len(values) <= index < len(values)
def is_dim_shard(mapping):
    return mapping != -1
def is_dim_replicate(mapping):
    return mapping == -1
def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings:
return None
compatible_mapping = dim_mappings[0]
for mapping in dim_mappings:
if compatible_mapping == -1:
compatible_mapping = mapping
elif mapping == -1:
continue
elif compatible_mapping == mapping:
continue
else:
return None
return compatible_mapping
def compute_compatible_dims_mapping(dims_mapping_list):
if not dims_mapping_list:
return None
length = len(dims_mapping_list[0])
for dims_mapping in dims_mapping_list:
assert dims_mapping is not None, \
"Dims mapping must not be None for compatible computation"
assert len(dims_mapping) == length, \
"The length of dims_mapping in list must be same for compatible computation."
compatible_result = []
for dim_mappings in zip(*dims_mapping_list):
compatible_dim_mapping = compute_compatible_dim_mapping(
list(dim_mappings))
if compatible_dim_mapping is None:
return None
compatible_result.append(compatible_dim_mapping)
return compatible_result
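# Editorial sketch, not part of this PR: -1 ("replicated") acts as a wildcard
# when merging dim mappings, while two different shard axes are incompatible.
# The calls assume the two helpers defined above are in scope.
print(compute_compatible_dim_mapping([-1, 0, -1]))          # 0
print(compute_compatible_dim_mapping([0, 1]))               # None
print(compute_compatible_dims_mapping([[0, -1], [-1, 1]]))  # [0, 1]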
def compute_compatible_process_mesh(process_mesh_list):
compatible_process_mesh = None
if not process_mesh_list:
return compatible_process_mesh
for process_mesh in process_mesh_list:
if process_mesh is not None:
if compatible_process_mesh is None:
compatible_process_mesh = process_mesh
else:
assert process_mesh == compatible_process_mesh, \
"There is no compatible process mesh."
return compatible_process_mesh
def compute_compatible_and_update_dim_mapping(dims_mapping_list, index_list):
assert len(dims_mapping_list) == len(index_list)
changed = False
dim_mappings = []
for i in range(len(dims_mapping_list)):
assert is_valid_list_index(dims_mapping_list[i], index_list[i])
dim_mappings.append(dims_mapping_list[i][index_list[i]])
compatible_dim_mapping = compute_compatible_dim_mapping(dim_mappings)
if compatible_dim_mapping is None:
return False
for i in range(len(dims_mapping_list)):
if compatible_dim_mapping != dims_mapping_list[i][index_list[i]]:
dims_mapping_list[i][index_list[i]] = compatible_dim_mapping
changed = True
return changed
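# Editorial sketch, not part of this PR: the helper updates the given nested
# lists in place and reports whether anything changed, the same convention the
# update_dims_mapping implementations above follow. The calls assume the
# function defined above is in scope.
a, b = [0, -1], [-1, -1]
print(compute_compatible_and_update_dim_mapping([a, b], [0, 0]), a, b)
# True [0, -1] [0, -1]
print(compute_compatible_and_update_dim_mapping([a, b], [1, 1]), a, b)
# False [0, -1] [0, -1]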
def append_distributed_attr_suffix(name):
    """
    Append the auto parallel suffix to a distributed attribute name.
    """
    return name + core.kAutoParallelSuffix()
def remove_distributed_attr_suffix(name):
    """
    Remove the auto parallel suffix from a distributed attribute name.
    """
    # str.strip() treats its argument as a character set, so it cannot be used
    # to drop a fixed suffix; remove the suffix explicitly instead.
    suffix = core.kAutoParallelSuffix()
    if name.endswith(suffix):
        return name[:len(name) - len(suffix)]
    return name
def check_distributed_attr_for_program(program, dist_context=None):
from .context import get_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
assert dist_context.is_initialized_for_program(), \
"Distributed attributes must be initialized before check."
for block in program.blocks:
for tensor in block.vars.values():
tensor_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
tensor)
if (tensor_dist_attr is not None) and (
not tensor_dist_attr.is_valid()):
return False
for op in block.ops:
op_dist_attr = dist_context.get_op_distributed_attr_for_program(op)
if (op_dist_attr is not None) and (not op_dist_attr.is_valid()):
return False
return True
def print_program_with_distributed_attr(program, dist_context=None):
"""
    Print a program together with the distributed attributes stored in the given
    distributed context, reusing the original program printing facility.
    A lock prevents multiple threads from changing the default distributed
    context at the same time.
"""
lock = threading.Lock()
lock.acquire()
from .context import get_default_distributed_context
from .context import set_default_distributed_context
if dist_context is None:
dist_context = get_default_distributed_context()
print(program)
else:
original_default_context = get_default_distributed_context()
set_default_distributed_context(dist_context)
print(program)
set_default_distributed_context(original_default_context)
lock.release()
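# Editorial usage sketch, not part of this PR: once the auto-completion pass
# has populated the default distributed context, the two helpers above can
# validate and print a static program. The import path is an assumption and
# check_distributed_attr_for_program requires an initialized context.
import paddle
from paddle.distributed.auto_parallel.utils import (
    check_distributed_attr_for_program, print_program_with_distributed_attr)

paddle.enable_static()
main_program = paddle.static.default_main_program()
if check_distributed_attr_for_program(main_program):
    print_program_with_distributed_attr(main_program)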
...@@ -1224,6 +1224,14 @@ class Variable(object):
        if self.persistable:
            var_str = "persist " + var_str

        from paddle.distributed.auto_parallel.context import get_default_distributed_context
        dist_context = get_default_distributed_context()
        var_dist_attr = dist_context.get_tensor_distributed_attr_for_program(
            self)
        if var_dist_attr is not None:
            var_str += ", {name} = {value}".format(
                name="dist_attr", value=var_dist_attr)

        return var_str

    def to_string(self, throw_on_error, with_details=False):
...@@ -2384,6 +2392,13 @@ class Operator(object):
            if i != len(attr_names) - 1:
                attrs_str += ", "

        from paddle.distributed.auto_parallel.context import get_default_distributed_context
        dist_context = get_default_distributed_context()
        op_dist_attr = dist_context.get_op_distributed_attr_for_program(self)
        if op_dist_attr is not None:
            attrs_str += ", {name} = {value}".format(
                name="dist_attr", value=op_dist_attr)

        if outputs_str != "{}":
            op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".\
                format(outputs=outputs_str, op_type=self.type,
......
...@@ -165,6 +165,7 @@ packages=['paddle',
          'paddle.distributed.fleet.meta_parallel.pp_utils',
          'paddle.distributed.fleet.meta_parallel.parallel_layers',
          'paddle.distributed.auto_parallel',
          'paddle.distributed.auto_parallel.operators',
          'paddle.framework',
          'paddle.jit',
          'paddle.jit.dy2static',
......