Commit 4ff41808 authored by lirongzhen, committed by lirongzhen1

enable/disable allreduce_fusion

Add an enable_all_reduce_fusion switch to ParallelContext, expose it through the pybind11 bindings and the Python _AutoParallelContext wrapper, gate StepAllreduceFusion on it, preserve it across checkpoint/restore of the auto-parallel context, and have the affected tests opt in explicitly.

Parent 9edc69af
@@ -31,10 +31,11 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
   MS_EXCEPTION_IF_NULL(optimizer);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
   std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode();
+  bool enable_all_reduce_fusion = ParallelContext::GetInstance()->enable_all_reduce_fusion();
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) ||
       (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
     return changes;
   }
......
@@ -55,6 +55,7 @@ void ParallelContext::Reset() {
   parallel_mode_ = STAND_ALONE;
   parameter_broadcast_ = false;
   parameter_broadcast_is_set_ = false;
+  enable_all_reduce_fusion_ = false;
 }

 void ParallelContext::set_device_num(int32_t device_num) {
......
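With Reset() initializing the new field to false, AllReduce fusion becomes opt-in: StepAllreduceFusion returns early (changes stays false) unless the flag is set and the graph runs in auto or semi-auto parallel mode. A minimal sketch of flipping the switch from Python, using the wrapper added later in this commit:

```python
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Fusion is off by default after Reset(); enable it before compiling the
# graph so StepAllreduceFusion actually rewrites the AllReduce operators.
auto_parallel_context().set_enable_all_reduce_fusion(True)
```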
@@ -80,6 +80,10 @@ class ParallelContext {
   const std::vector<uint32_t> all_reduce_fusion_split_indices() const;
   void set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes);
   const std::vector<uint32_t> all_reduce_fusion_split_sizes() const;
+  void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) {
+    enable_all_reduce_fusion_ = enable_all_reduce_fusion;
+  }
+  bool enable_all_reduce_fusion() const { return enable_all_reduce_fusion_; }

   void Reset();
@@ -98,6 +102,7 @@ class ParallelContext {
   bool device_num_is_set_;
   bool global_rank_is_set_;
   bool parameter_broadcast_is_set_;
+  bool enable_all_reduce_fusion_;
   std::vector<uint32_t> all_reduce_fusion_split_indices_;
   std::vector<uint32_t> all_reduce_fusion_split_sizes_;
 };
......
@@ -183,6 +183,10 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set all reduce fusion split sizes.")
     .def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes,
          "Get all reduce fusion split sizes.")
+    .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
+         "Set enable/disable all reduce fusion.")
+    .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
+         "Get enable/disable all reduce fusion.")
     .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
     .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
          "Get parameter broadcast is set.")
......
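These .def lines hang off the ParallelContext instance exported to Python. A sketch of hitting the raw bindings directly; the exported class name and its get_instance() accessor are assumptions, since this hunk only shows the method registrations (normal code goes through _AutoParallelContext instead):

```python
from mindspore._c_expression import AutoParallelContext  # class name assumed

handle = AutoParallelContext.get_instance()  # accessor assumed, not shown in this hunk
handle.set_enable_all_reduce_fusion(True)
assert handle.get_enable_all_reduce_fusion()
```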
@@ -259,6 +259,23 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_all_reduce_fusion_split_sizes()

+    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
+        """
+        Set enable/disable all reduce fusion.
+
+        Args:
+            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
+        """
+        self.check_context_handle()
+        if not isinstance(enable_all_reduce_fusion, bool):
+            raise TypeError('enable_all_reduce_fusion is invalid type')
+        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
+
+    def get_enable_all_reduce_fusion(self):
+        """Get all reduce fusion flag."""
+        self.check_context_handle()
+        return self._context_handle.get_enable_all_reduce_fusion()
+
     def get_device_num_is_set(self):
         """Get device number is set or not."""
         self.check_context_handle()
......
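A short usage sketch of the wrapper methods added above; note the setter accepts only a real bool, so anything else raises TypeError:

```python
from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_enable_all_reduce_fusion(True)
assert ctx.get_enable_all_reduce_fusion() is True

try:
    ctx.set_enable_all_reduce_fusion("yes")   # not a bool
except TypeError:
    pass  # 'enable_all_reduce_fusion is invalid type'
```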
@@ -117,6 +117,7 @@ _cast_before_mirror = None
 _loss_repeated_mean = None
 _communication_backend = None
 _has_checkpointed = False
+_enable_all_reduce_fusion = None


 def _checkpoint_auto_parallel_context():
@@ -133,6 +134,7 @@ def _checkpoint_auto_parallel_context():
     global _cast_before_mirror
     global _loss_repeated_mean
     global _communication_backend
+    global _enable_all_reduce_fusion
     _parallel_mode = auto_parallel_context().get_parallel_mode()
     _device_num = _get_device_num()
     _global_rank = _get_global_rank()
@@ -141,6 +143,7 @@ def _checkpoint_auto_parallel_context():
     _cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
     _loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
     _communication_backend = auto_parallel_context().get_communication_backend()
+    _enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()

     _has_checkpointed = True
@@ -154,10 +157,12 @@ def _restore_auto_parallel_context():
     global _cast_before_mirror
     global _loss_repeated_mean
     global _communication_backend
+    global _enable_all_reduce_fusion
     _set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
                                parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
                                cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
     auto_parallel_context().set_communication_backend(_communication_backend)
+    auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)


 def _reset_checkpoint_auto_parallel_context():
......
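The checkpoint/restore hunks matter because context.reset_auto_parallel_context() calls Reset(), which now clears the flag. A hedged sketch of the round-trip these module-private helpers cover; the import path is an assumption, as the diff does not show which file they live in:

```python
from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
# Hypothetical import path for the private helpers patched above.
from mindspore.parallel._utils import (_checkpoint_auto_parallel_context,
                                       _restore_auto_parallel_context)

auto_parallel_context().set_enable_all_reduce_fusion(True)
_checkpoint_auto_parallel_context()    # records the flag in _enable_all_reduce_fusion
context.reset_auto_parallel_context() # Reset() drops it back to False
_restore_auto_parallel_context()      # reapplies the recorded value
assert auto_parallel_context().get_enable_all_reduce_fusion()
```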
@@ -13,10 +13,12 @@
 # limitations under the License.

 import mindspore.context as context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.parallel._utils import _reset_op_id


 def setup_module(module):
+    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
     _reset_op_id()
......
@@ -23,7 +23,7 @@ from tests.dataset_mock import MindData
 from mindspore import context
 from mindspore.common.api import _executor
 from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context


 class Dataset(MindData):
@@ -105,6 +105,7 @@ def train_common(net):
     epoch_size = 2
     device_num=4
     context.reset_auto_parallel_context()
+    auto_parallel_context().set_enable_all_reduce_fusion(enable_all_reduce_fusion=True)
     context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_num, parameter_broadcast=False)
     context.set_context(mode=context.GRAPH_MODE)
......