提交 362bbacf 编写于 作者: K kswang

add group for allreduce fusion

上级 4ecc9389
......@@ -113,20 +113,28 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}
void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) {
all_reduce_fusion_split_indices_ = indices;
void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
all_reduce_fusion_split_indices_[group] = indices;
}
const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_indices() const {
return all_reduce_fusion_split_indices_;
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const {
auto iter = all_reduce_fusion_split_indices_.find(group);
if (iter != all_reduce_fusion_split_indices_.end()) {
return iter->second;
}
return {};
}
void ParallelContext::set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes) {
all_reduce_fusion_split_sizes_ = sizes;
void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) {
all_reduce_fusion_split_sizes_[group] = sizes;
}
const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_sizes() const {
return all_reduce_fusion_split_sizes_;
const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const {
auto iter = all_reduce_fusion_split_sizes_.find(group);
if (iter != all_reduce_fusion_split_sizes_.end()) {
return iter->second;
}
return {};
}
} // namespace parallel
} // namespace mindspore
......@@ -19,6 +19,7 @@
#include <cstdint>
#include <memory>
#include <map>
#include <string>
#include <vector>
......@@ -76,10 +77,10 @@ class ParallelContext {
bool global_rank_is_set() const { return global_rank_is_set_; }
bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; }
void set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices);
const std::vector<uint32_t> all_reduce_fusion_split_indices() const;
void set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes);
const std::vector<uint32_t> all_reduce_fusion_split_sizes() const;
void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const;
void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group);
const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const;
void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
enable_all_reduce_fusion_ = enable_all_reduce_fusion;
}
......@@ -108,8 +109,8 @@ class ParallelContext {
bool global_rank_is_set_;
bool parameter_broadcast_is_set_;
bool enable_all_reduce_fusion_;
std::vector<uint32_t> all_reduce_fusion_split_indices_;
std::vector<uint32_t> all_reduce_fusion_split_sizes_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_;
std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
std::string strategy_ckpt_load_file_;
std::string strategy_ckpt_save_file_;
};
......
......@@ -159,13 +159,13 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
.def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
.def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
.def("set_all_reduce_fusion_split_indices", &ParallelContext::set_all_reduce_fusion_split_indices,
.def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
"Set all reduce fusion split indices.")
.def("get_all_reduce_fusion_split_indices", &ParallelContext::all_reduce_fusion_split_indices,
.def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
"Get all reduce fusion split indices.")
.def("set_all_reduce_fusion_split_sizes", &ParallelContext::set_all_reduce_fusion_split_sizes,
.def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
"Set all reduce fusion split sizes.")
.def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes,
.def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
"Get all reduce fusion split sizes.")
.def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
"Set enable/disable all reduce fusion.")
......
......@@ -92,7 +92,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
} // namespace
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
std::vector<size_t> *segment_index) const {
std::vector<size_t> *segment_index, const std::string &group) const {
MS_EXCEPTION_IF_NULL(segment_num);
MS_EXCEPTION_IF_NULL(segment_index);
size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
......@@ -100,7 +100,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
const std::vector<uint32_t> split_indices = parallel_context->all_reduce_fusion_split_indices();
const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
size_t segments = 0;
if (split_indices.size() != 0) {
......@@ -255,7 +255,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
}
size_t segment_num = 0;
std::vector<size_t> segment_index;
if (GetSplitSegments(it.second, &segment_num, &segment_index)) {
if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) {
if (DoFusion(func_graph, it.second, segment_num, segment_index)) {
changed = true;
}
......
......@@ -46,7 +46,7 @@ class CommunicationOpFusion : public Pass {
const CommunicationOpInfo &communication_op_info, size_t start_index,
size_t end_index) const;
bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
std::vector<size_t> *segment_index) const;
std::vector<size_t> *segment_index, const std::string &group) const;
std::string op_name_;
size_t groups_ = 1;
};
......
......@@ -19,6 +19,8 @@ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx,
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check
_MAX_GROUP_NAME_LEN = 127
class _AutoParallelContext:
"""
......@@ -243,51 +245,117 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_parameter_broadcast_is_set()
def set_all_reduce_fusion_split_indices(self, indices):
def set_all_reduce_fusion_split_indices(self, indices, group=""):
"""
Set allreduce fusion strategy by parameters indices.
Args:
indices (list): Indices list.
group (str): The hccl communication group.
Raises:
TypeError: If type of indices item is not int.
TypeError: If group is not a python str.
"""
self.check_context_handle()
for index in indices:
if not isinstance(index, int):
raise TypeError('indices has invalid value')
self._context_handle.set_all_reduce_fusion_split_indices(indices)
if isinstance(indices, (list)):
for index in indices:
if not isinstance(index, int):
raise TypeError('indices has invalid value')
else:
raise TypeError('indices must be a python list')
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_idx(indices)
if group == "":
_set_fusion_strategy_by_idx(indices)
else:
_set_fusion_strategy_by_idx(indices, group)
def get_all_reduce_fusion_split_indices(self, group=""):
"""
Get allreduce fusion split indices.
Args:
group (str): The hccl communication group.
Returns:
Return split sizes list according to the group.
def get_all_reduce_fusion_split_indices(self):
"""Get allreduce fusion split indices."""
Raises:
TypeError: If group is not a python str.
"""
self.check_context_handle()
return self._context_handle.get_all_reduce_fusion_split_indices()
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
return self._context_handle.get_all_reduce_fusion_split_indices(group)
def set_all_reduce_fusion_split_sizes(self, sizes):
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
"""
Set allreduce fusion strategy by parameters data sizes.
Args:
sizes (list): Sizes list.
group (str): The hccl communication group.
Raises:
TypeError: If type of sizes item is not int.
TypeError: If group is not a python str.
"""
self.check_context_handle()
for size in sizes:
if not isinstance(size, int):
raise TypeError('sizes has invalid value')
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
if isinstance(sizes, (list)):
for size in sizes:
if not isinstance(size, int):
raise TypeError('sizes has invalid value')
else:
raise TypeError('sizes must be a python list')
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes)
if group == "":
_set_fusion_strategy_by_size(sizes)
else:
_set_fusion_strategy_by_size(sizes, group)
def get_all_reduce_fusion_split_sizes(self):
"""Get allreduce fusion split sizes."""
def get_all_reduce_fusion_split_sizes(self, group=""):
"""
Get allreduce fusion split sizes.
Args:
group (str): The hccl communication group.
Returns:
Return split sizes list according to the group.
Raises:
TypeError: If group is not a python str.
"""
self.check_context_handle()
return self._context_handle.get_all_reduce_fusion_split_sizes()
if isinstance(group, (str)):
group_len = len(group)
if group_len > _MAX_GROUP_NAME_LEN:
raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
else:
raise TypeError('Group must be a python str')
return self._context_handle.get_all_reduce_fusion_split_sizes(group)
def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册