Commit 4ff41808 authored by lirongzhen, committed by lirongzhen1

enable/disable allreduce_fusion

Parent 9edc69af
...@@ -31,10 +31,11 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
  MS_EXCEPTION_IF_NULL(optimizer);
  MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
  std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
  bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
  // assume no change to graph
  bool changes = false;
  // control whether use model_parallel mode
  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) ||
      (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
    return changes;
  }
...
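With this hunk, the fusion pass is gated by the new context flag in addition to the parallel mode: StepAllreduceFusion only proceeds when the mode is AUTO_PARALLEL or SEMI_AUTO_PARALLEL, enable_all_reduce_fusion is true, and the pass has not already run for this graph; otherwise it returns without touching the graph. A minimal Python-side sketch of a configuration that satisfies the guard (the string form of parallel_mode is assumed; the flag call is the one added in this commit):

    import mindspore.context as context
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    # Both guard conditions must hold for the fusion pass to run:
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4)  # or "auto_parallel"
    auto_parallel_context().set_enable_all_reduce_fusion(True)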
...@@ -55,6 +55,7 @@ void ParallelContext::Reset() {
  parallel_mode_ = STAND_ALONE;
  parameter_broadcast_ = false;
  parameter_broadcast_is_set_ = false;
  enable_all_reduce_fusion_ = false;
}

void ParallelContext::set_device_num(int32_t device_num) {
...
...@@ -80,6 +80,10 @@ class ParallelContext {
  const std::vector<uint32_t> all_reduce_fusion_split_indices() const;
  void set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes);
  const std::vector<uint32_t> all_reduce_fusion_split_sizes() const;
  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
  }
  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }
  void Reset();
...@@ -98,6 +102,7 @@ class ParallelContext {
  bool device_num_is_set_;
  bool global_rank_is_set_;
  bool parameter_broadcast_is_set_;
  bool enable_all_reduce_fusion_;
  std::vector<uint32_t> all_reduce_fusion_split_indices_;
  std::vector<uint32_t> all_reduce_fusion_split_sizes_;
};
...
...@@ -183,6 +183,10 @@ PYBIND11_MODULE(_c_expression, m) {
         "Set all reduce fusion split sizes.")
    .def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes,
         "Get all reduce fusion split sizes.")
    .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
         "Set enable/disable all reduce fusion.")
    .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
         "Get enable/disable all reduce fusion.")
    .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
    .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
         "Get parameter broadcast is set.")
...
...@@ -259,6 +259,23 @@ class _AutoParallelContext:
        self.check_context_handle()
        return self._context_handle.get_all_reduce_fusion_split_sizes()

    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
        """
        Set enable/disable all reduce fusion.

        Args:
            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
        """
        self.check_context_handle()
        if not isinstance(enable_all_reduce_fusion, bool):
            raise TypeError('enable_all_reduce_fusion is invalid type')
        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

    def get_enable_all_reduce_fusion(self):
        """Get all reduce fusion flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_all_reduce_fusion()

    def get_device_num_is_set(self):
        """Get device number is set or not."""
        self.check_context_handle()
...
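The new wrapper methods above validate the argument and forward it to the C++ ParallelContext through _context_handle. A minimal sketch of the round trip (the import path is the one used by the tests in this commit):

    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    ctx = auto_parallel_context()
    ctx.set_enable_all_reduce_fusion(True)
    assert ctx.get_enable_all_reduce_fusion()

    # Only bool is accepted; anything else is rejected by the wrapper.
    try:
        ctx.set_enable_all_reduce_fusion("yes")
    except TypeError:
        pass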
...@@ -117,6 +117,7 @@ _cast_before_mirror = None
_loss_repeated_mean = None
_communication_backend = None
_has_checkpointed = False
_enable_all_reduce_fusion = None


def _checkpoint_auto_parallel_context():
...@@ -133,6 +134,7 @@ def _checkpoint_auto_parallel_context():
    global _cast_before_mirror
    global _loss_repeated_mean
    global _communication_backend
    global _enable_all_reduce_fusion
    _parallel_mode = auto_parallel_context().get_parallel_mode()
    _device_num = _get_device_num()
    _global_rank = _get_global_rank()
...@@ -141,6 +143,7 @@ def _checkpoint_auto_parallel_context():
    _cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
    _loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
    _communication_backend = auto_parallel_context().get_communication_backend()
    _enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()
    _has_checkpointed = True
...@@ -154,10 +157,12 @@ def _restore_auto_parallel_context():
    global _cast_before_mirror
    global _loss_repeated_mean
    global _communication_backend
    global _enable_all_reduce_fusion
    _set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
                               parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
                               cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
    auto_parallel_context().set_communication_backend(_communication_backend)
    auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)


def _reset_checkpoint_auto_parallel_context():
...
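With the two additions above, the enable flag is captured by _checkpoint_auto_parallel_context() and re-applied by _restore_auto_parallel_context(), so it survives a context reset between training phases. A rough sketch of that round trip; the helpers are private and their module path is assumed here, not shown in this diff:

    import mindspore.context as context
    from mindspore.parallel._auto_parallel_context import auto_parallel_context
    from mindspore.parallel._utils import _checkpoint_auto_parallel_context, _restore_auto_parallel_context  # assumed path

    auto_parallel_context().set_enable_all_reduce_fusion(True)
    _checkpoint_auto_parallel_context()    # records the flag into _enable_all_reduce_fusion with the other settings
    context.reset_auto_parallel_context()  # clears the live parallel context
    _restore_auto_parallel_context()       # re-applies the recorded flag
    assert auto_parallel_context().get_enable_all_reduce_fusion()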
...@@ -13,10 +13,12 @@
# limitations under the License.

import mindspore.context as context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.parallel._utils import _reset_op_id


def setup_module(module):
    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    _reset_op_id()
...
...@@ -23,7 +23,7 @@ from tests.dataset_mock import MindData
from mindspore import context
from mindspore.common.api import _executor
from mindspore.parallel import _cost_model_context as cost_model_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context


class Dataset(MindData):
...@@ -105,6 +105,7 @@ def train_common(net):
    epoch_size = 2
    device_num=4
    context.reset_auto_parallel_context()
    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num, parameter_broadcast=False)
    context.set_context(mode=context.GRAPH_MODE)
...