Commit d4cfe55c authored by yao_yf

rename mirror_mean to gradients_mean

Parent commit bc4c5afc
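For downstream scripts the change is purely a rename of the flag: the `mirror_mean` keyword and query key become `gradients_mean`, with unchanged semantics (average the gradients after the mirror all-reduce). A minimal usage sketch against the renamed Python API; the device number and parallel mode below are illustrative values taken from the docstring examples in this diff:

import mindspore.context as context
from mindspore.context import ParallelMode

# Formerly: context.set_auto_parallel_context(mirror_mean=True)
context.set_auto_parallel_context(device_num=8,
                                  parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)

# The query key used by the training cells follows the same rename.
print(context.get_auto_parallel_context("gradients_mean"))  # True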
@@ -45,7 +45,7 @@ std::shared_ptr<ParallelContext> ParallelContext::GetInstance() {
 ParallelContext::ParallelContext() { Reset(); }
 void ParallelContext::Reset() {
-mirror_mean_ = false;
+gradients_mean_ = false;
 full_batch_ = false;
 gradient_fp32_sync_ = true;
 loss_repeated_mean_ = true;
@@ -74,7 +74,7 @@ void ParallelContext::set_global_rank(int32_t global_rank) {
 global_rank_is_set_ = true;
 }
-void ParallelContext::set_mirror_mean(bool mirror_mean) { mirror_mean_ = mirror_mean; }
+void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_ = gradients_mean; }
 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }
......
@@ -52,8 +52,8 @@ class ParallelContext {
 static std::shared_ptr<ParallelContext> GetInstance();
-void set_mirror_mean(bool mirror_mean);
-bool mirror_mean() const { return mirror_mean_; }
+void set_gradients_mean(bool gradients_mean);
+bool gradients_mean() const { return gradients_mean_; }
 void set_full_batch(bool full_batch);
 bool full_batch() const { return full_batch_; }
@@ -107,7 +107,7 @@ class ParallelContext {
 private:
 ParallelContext();
 static std::shared_ptr<ParallelContext> inst_context_;
-bool mirror_mean_;
+bool gradients_mean_;
 bool full_batch_;
 bool gradient_fp32_sync_;
 bool loss_repeated_mean_;
......
@@ -251,7 +251,7 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
 MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num;
 }
 OperatorVector op_for_weight;
-bool mean_flag = ParallelContext::GetInstance()->mirror_mean();
+bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
 OperatorName operator_name = MIRROR_OPERATOR;
 ValuePtr attr0_value = MakeValue(group_name);
......
@@ -2488,7 +2488,7 @@ Status ParallelInit() {
 }
 MS_LOG(INFO) << "The parallel context: dev num: " << device_num << ", global rank: " << global_rank
-<< ", backend: " << backend << ", mirror_mean: " << ParallelContext::GetInstance()->mirror_mean()
+<< ", backend: " << backend << ", gradients_mean: " << ParallelContext::GetInstance()->gradients_mean()
 << ", gradient_fp32_sync: " << ParallelContext::GetInstance()->gradient_fp32_sync();
 return SUCCESS;
 }
......
@@ -113,8 +113,8 @@ PYBIND11_MODULE(_c_expression, m) {
 .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
 .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
 .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
-.def("get_mirror_mean", &ParallelContext::mirror_mean, "Get mirror mean.")
-.def("set_mirror_mean", &ParallelContext::set_mirror_mean, "Set mirror mean.")
+.def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
+.def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
 .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
 .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
 .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
......
@@ -323,7 +323,7 @@ def _context():
 return _k_context
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
+@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
 def set_auto_parallel_context(**kwargs):
@@ -341,8 +341,8 @@ def set_auto_parallel_context(**kwargs):
 Args:
 device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
 global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror.
-"stand_alone" do not support mirror_mean. Default: False.
+gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror.
+"stand_alone" do not support gradients_mean. Default: False.
 gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True..
 "stand_alone", "data_parallel" and "hybrid_parallel" do not support
 gradient_fp32_sync. Default: True.
@@ -380,7 +380,7 @@ def set_auto_parallel_context(**kwargs):
 Examples:
 >>> context.set_auto_parallel_context(device_num=8)
 >>> context.set_auto_parallel_context(global_rank=0)
->>> context.set_auto_parallel_context(mirror_mean=True)
+>>> context.set_auto_parallel_context(gradients_mean=True)
 >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
 >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
 >>> context.set_auto_parallel_context(parameter_broadcast=False)
@@ -412,7 +412,7 @@ def reset_auto_parallel_context():
 - device_num: 1.
 - global_rank: 0.
-- mirror_mean: False.
+- gradients_mean: False.
 - gradient_fp32_sync: True.
 - parallel_mode: "stand_alone".
 - parameter_broadcast: False.
......
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Cell_wrapper."""
-from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
+from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
 _get_parallel_mode)
 from mindspore.context import ParallelMode
 from ...common import dtype as mstype
@@ -190,7 +190,7 @@ class TrainOneStepCell(Cell):
 if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
 self.reducer_flag = True
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
@@ -279,7 +279,7 @@ class DistributedGradReducer(Cell):
 >>> ParallelMode.HYBRID_PARALLEL]:
 >>> self.reducer_flag = True
 >>> if self.reducer_flag:
->>> mean = context.get_auto_parallel_context("mirror_mean")
+>>> mean = context.get_auto_parallel_context("gradients_mean")
 >>> if mean.get_device_num_is_set():
 >>> degree = context.get_auto_parallel_context("device_num")
 >>> else:
......
@@ -16,7 +16,7 @@
 import mindspore.context as context
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from ..cell import Cell
 from ...common import Tensor, RowTensor
 from ...common.parameter import Parameter
@@ -231,7 +231,7 @@ class TrainOneStepWithLossScaleCell(Cell):
 self.grad_reducer = F.identity
 self.reducer_flag = self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE
......
@@ -95,23 +95,23 @@ class _AutoParallelContext:
 self.check_context_handle()
 return self._context_handle.get_global_rank()
-def set_mirror_mean(self, mirror_mean):
+def set_gradients_mean(self, gradients_mean):
 """
-Set mirror_mean flag.
+Set gradients_mean flag.
 Note:
-If mirror_mean is true, it will insert a div operator after parameter gradients allreduce.
+If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
 Args:
-mirror_mean (bool): The mirror_mean flag.
+gradients_mean (bool): The gradients_mean flag.
 """
 self.check_context_handle()
-self._context_handle.set_mirror_mean(mirror_mean)
+self._context_handle.set_gradients_mean(gradients_mean)
-def get_mirror_mean(self):
-"""Get mirror_mean flag."""
+def get_gradients_mean(self):
+"""Get gradients_mean flag."""
 self.check_context_handle()
-return self._context_handle.get_mirror_mean()
+return self._context_handle.get_gradients_mean()
 def set_gradient_fp32_sync(self, gradient_fp32_sync):
 """
@@ -453,7 +453,7 @@ def auto_parallel_context():
 _set_auto_parallel_context_func_map = {
 "device_num": auto_parallel_context().set_device_num,
 "global_rank": auto_parallel_context().set_global_rank,
-"mirror_mean": auto_parallel_context().set_mirror_mean,
+"gradients_mean": auto_parallel_context().set_gradients_mean,
 "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
 "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
 "parallel_mode": auto_parallel_context().set_parallel_mode,
@@ -468,7 +468,7 @@ _set_auto_parallel_context_func_map = {
 _get_auto_parallel_context_func_map = {
 "device_num": auto_parallel_context().get_device_num,
 "global_rank": auto_parallel_context().get_global_rank,
-"mirror_mean": auto_parallel_context().get_mirror_mean,
+"gradients_mean": auto_parallel_context().get_gradients_mean,
 "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
 "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
 "parallel_mode": auto_parallel_context().get_parallel_mode,
@@ -480,7 +480,7 @@ _get_auto_parallel_context_func_map = {
 "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer}
-@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, gradient_fp32_sync=bool,
+@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
 parameter_broadcast=bool, strategy_ckpt_load_file=str,
 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
@@ -495,7 +495,7 @@ def _set_auto_parallel_context(**kwargs):
 Args:
 device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
 global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
-mirror_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
+gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
 loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
 calculations. Default: True.
 gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True.
@@ -562,7 +562,7 @@ def _reset_auto_parallel_context():
 - device_num: 1.
 - global_rank: 0.
-- mirror_mean: False.
+- gradients_mean: False.
 - gradient_fp32_sync: True.
 - parallel_mode: "stand_alone".
 - parameter_broadcast: False.
......
@@ -88,9 +88,9 @@ def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
 lst.append(Tensor(scaling_sens, mstype.float32))
 return tuple(lst)
-def _get_mirror_mean():
-"""Get if using mirror_mean."""
-return auto_parallel_context().get_mirror_mean()
+def _get_gradients_mean():
+"""Get if using gradients_mean."""
+return auto_parallel_context().get_gradients_mean()
 def _get_device_num():
......
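Every training wrapper updated in the model zoo below consumes the renamed helper through the same wiring; a condensed sketch of that pattern using the internal helpers shown in this diff (the wrapper class name itself is illustrative, not part of the change):

from mindspore import nn
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)

class _GradReducerWiringSketch(nn.Cell):
    """Illustrative cell showing only the reducer setup touched by this rename."""
    def __init__(self, optimizer):
        super(_GradReducerWiringSketch, self).__init__(auto_prefix=False)
        self.optimizer = optimizer
        self.grad_reducer = None
        self.reducer_flag = _get_parallel_mode() in (ParallelMode.DATA_PARALLEL,
                                                     ParallelMode.HYBRID_PARALLEL)
        if self.reducer_flag:
            # _get_gradients_mean() is the renamed _get_mirror_mean() helper.
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)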
@@ -66,7 +66,7 @@ def model_fine_tune(flags, train_net, fix_weight_layer):
 para.requires_grad = False
 if __name__ == "__main__":
 if args_opt.distribute == "true":
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
 init()
 args_opt.base_size = config.crop_size
 args_opt.crop_size = config.crop_size
......
@@ -54,7 +54,7 @@ if __name__ == '__main__':
 rank = args_opt.rank_id
 device_num = args_opt.device_num
 context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True, parameter_broadcast=True)
+gradients_mean=True, parameter_broadcast=True)
 init()
 else:
 rank = 0
......
@@ -78,7 +78,7 @@ if __name__ == '__main__':
 if device_num > 1:
 context.reset_auto_parallel_context()
 context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 init()
 elif device_target == "GPU":
 init()
@@ -86,7 +86,7 @@ if __name__ == '__main__':
 if device_num > 1:
 context.reset_auto_parallel_context()
 context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 else:
 raise ValueError("Unsupported platform.")
......
@@ -58,7 +58,7 @@ if __name__ == '__main__':
 cfg.group_size = get_group_size()
 parallel_mode = ParallelMode.DATA_PARALLEL
 context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 cfg.rank = 0
 cfg.group_size = 1
......
@@ -58,7 +58,7 @@ if __name__ == '__main__':
 rank = args_opt.rank_id
 device_num = args_opt.device_num
 context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True, parameter_broadcast=True)
+gradients_mean=True, parameter_broadcast=True)
 init()
 else:
 rank = 0
......
@@ -39,7 +39,7 @@ def context_device_init(config):
 init("nccl")
 context.set_auto_parallel_context(device_num=get_group_size(),
 parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 elif config.platform == "Ascend":
 context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id,
@@ -47,7 +47,7 @@ def context_device_init(config):
 if config.run_distribute:
 context.set_auto_parallel_context(device_num=config.rank_size,
 parallel_mode=ParallelMode.DATA_PARALLEL,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 auto_parallel_context().set_all_reduce_fusion_split_indices([140])
 init()
 else:
......
@@ -57,7 +57,7 @@ elif args_opt.device_target == "GPU":
 init()
 context.set_auto_parallel_context(device_num=get_group_size(),
 parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 context.set_context(mode=context.GRAPH_MODE,
 device_target="GPU",
 save_graphs=False)
@@ -77,7 +77,7 @@ def train_on_ascend():
 context.set_auto_parallel_context(device_num=rank_size,
 parallel_mode=ParallelMode.DATA_PARALLEL,
 parameter_broadcast=True,
-mirror_mean=True)
+gradients_mean=True)
 init()
 # define network
......
@@ -55,7 +55,7 @@ if args_opt.device_target == "GPU":
 init()
 context.set_auto_parallel_context(device_num=get_group_size(),
 parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 else:
 raise ValueError("Unsupported device_target.")
......
@@ -24,7 +24,7 @@ import mindspore.ops.composite as C
 import mindspore.common.dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 GRADIENT_CLIP_TYPE = 1
@@ -921,7 +921,7 @@ class NASNetAMobileTrainOneStepWithClipGradient(nn.Cell):
 if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
 self.reducer_flag = True
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
@@ -58,7 +58,7 @@ if __name__ == '__main__':
 cfg.group_size = get_group_size()
 parallel_mode = ParallelMode.DATA_PARALLEL
 context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 cfg.rank = 0
 cfg.group_size = 1
......
@@ -76,7 +76,7 @@ if __name__ == '__main__':
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
 context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
 auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
 else:
@@ -86,7 +86,7 @@ if __name__ == '__main__':
 else:
 init()
 context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 if args_opt.net == "resnet50":
 auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
 ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
......
@@ -76,11 +76,11 @@ if __name__ == '__main__':
 context.set_auto_parallel_context(device_num=rank_size,
 parallel_mode=ParallelMode.DATA_PARALLEL,
 parameter_broadcast=True,
-mirror_mean=True)
+gradients_mean=True)
 init()
 context.set_auto_parallel_context(device_num=args_opt.device_num,
 parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
 # define network
......
@@ -129,7 +129,7 @@ class DistributedGradReducerThor(Cell):
 >>> ParallelMode.HYBRID_PARALLEL]:
 >>> self.reducer_flag = True
 >>> if self.reducer_flag:
->>> mean = context.get_auto_parallel_context("mirror_mean")
+>>> mean = context.get_auto_parallel_context("gradients_mean")
 >>> if mean.get_device_num_is_set():
 >>> degree = context.get_auto_parallel_context("device_num")
 >>> else:
......
@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 from mindspore._checkparam import check_bool
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.optim.optimizer import Optimizer
-from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
 from src.grad_reducer_thor import DistributedGradReducerThor
 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -85,7 +85,7 @@ class THOR_GPU(Optimizer):
 self.assign = P.Assign()
 self.mul = P.Mul()
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
 self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
@@ -191,7 +191,7 @@ class THOR(Optimizer):
 1.0 / 196, 1.0 / 196, 1.0 / 196,
 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
 1.0]
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
 self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
......
@@ -94,7 +94,7 @@ if __name__ == '__main__':
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
 context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
 auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
 auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
@@ -105,7 +105,7 @@ if __name__ == '__main__':
 else:
 init()
 context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 auto_parallel_context().set_all_reduce_fusion_split_indices([107])
 ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
......
@@ -117,7 +117,7 @@ def test(cloud_args=None):
 args.group_size = get_group_size()
 parallel_mode = ParallelMode.DATA_PARALLEL
 context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 args.rank = 0
 args.group_size = 1
......
@@ -179,7 +179,7 @@ def train(cloud_args=None):
 args.group_size = get_group_size()
 parallel_mode = ParallelMode.DATA_PARALLEL
 context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 args.rank = 0
 args.group_size = 1
......
@@ -60,7 +60,7 @@ if __name__ == '__main__':
 cfg.group_size = get_group_size()
 parallel_mode = ParallelMode.DATA_PARALLEL
 context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 cfg.rank = 0
 cfg.group_size = 1
......
@@ -392,7 +392,7 @@ class TrainingWrapper(nn.Cell):
 if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
 self.reducer_flag = True
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 if auto_parallel_context().get_device_num_is_set():
 degree = context.get_auto_parallel_context("device_num")
 else:
......
@@ -60,7 +60,7 @@ def main():
 if args_opt.distribute:
 device_num = args_opt.device_num
 context.reset_auto_parallel_context()
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
 device_num=device_num)
 init()
 rank = args_opt.device_id % device_num
......
@@ -140,7 +140,7 @@ if __name__ == '__main__':
 device_num = args.group_size
 context.reset_auto_parallel_context()
 context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-parameter_broadcast=True, mirror_mean=True)
+parameter_broadcast=True, gradients_mean=True)
 else:
 context.set_context(device_id=args.device_id)
 context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
......
@@ -14,7 +14,7 @@
 # ============================================================================
 """Automatic differentiation with grad clip."""
 import numpy as np
-from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
+from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
 _get_parallel_mode)
 from mindspore.context import ParallelMode
 from mindspore.common import dtype as mstype
@@ -93,7 +93,7 @@ class TrainOneStepCellWithGradClip(Cell):
 if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
 self.reducer_flag = True
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
@@ -64,7 +64,7 @@ if __name__ == '__main__':
 context.reset_auto_parallel_context()
 context.set_auto_parallel_context(device_num=device_num,
 parallel_mode=ParallelMode.DATA_PARALLEL,
-mirror_mean=True)
+gradients_mean=True)
 else:
 device_num = 1
 rank = 0
......
@@ -255,7 +255,7 @@ def test():
 context.reset_auto_parallel_context()
 parallel_mode = ParallelMode.STAND_ALONE
-context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1)
+context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
 args.logger.info('Creating Network....')
 network = YOLOV3DarkNet53(is_training=False)
......
@@ -421,7 +421,7 @@ class TrainingWrapper(nn.Cell):
 if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
 self.reducer_flag = True
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 if auto_parallel_context().get_device_num_is_set():
 degree = context.get_auto_parallel_context("device_num")
 else:
......
@@ -178,7 +178,7 @@ def train():
 else:
 parallel_mode = ParallelMode.STAND_ALONE
 degree = 1
-context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)
+context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
 network = YOLOV3DarkNet53(is_training=True)
 # default is kaiming-normal
......
@@ -254,7 +254,7 @@ def test():
 context.reset_auto_parallel_context()
 parallel_mode = ParallelMode.STAND_ALONE
-context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=1)
+context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=1)
 args.logger.info('Creating Network....')
 network = YOLOV3DarkNet53(is_training=False)
......
@@ -421,7 +421,7 @@ class TrainingWrapper(nn.Cell):
 if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
 self.reducer_flag = True
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 if auto_parallel_context().get_device_num_is_set():
 degree = context.get_auto_parallel_context("device_num")
 else:
......
@@ -162,7 +162,7 @@ def train():
 else:
 parallel_mode = ParallelMode.STAND_ALONE
 degree = 1
-context.set_auto_parallel_context(parallel_mode=parallel_mode, mirror_mean=True, device_num=degree)
+context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
 network = YOLOV3DarkNet53(is_training=True)
 # default is kaiming-normal
......
@@ -656,7 +656,7 @@ class TrainingWrapper(nn.Cell):
 if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
 self.reducer_flag = True
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 if auto_parallel_context().get_device_num_is_set():
 degree = context.get_auto_parallel_context("device_num")
 else:
......
@@ -92,7 +92,7 @@ def main():
 if args_opt.distribute:
 device_num = args_opt.device_num
 context.reset_auto_parallel_context()
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
 device_num=device_num)
 init()
 rank = args_opt.device_id % device_num
......
@@ -85,7 +85,7 @@ def run_pretrain():
 ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
 context.reset_auto_parallel_context()
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
 device_num=device_num)
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 if bert_net_cfg.num_hidden_layers == 12:
......
@@ -66,7 +66,7 @@ class BertFinetuneCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
@@ -167,7 +167,7 @@ class BertSquadCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
......
@@ -283,7 +283,7 @@ class BertTrainOneStepCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
@@ -87,7 +87,7 @@ def run_pretrain():
 ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
 context.reset_auto_parallel_context()
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
 device_num=device_num)
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 if bert_net_cfg.num_hidden_layers == 12:
......
@@ -301,7 +301,7 @@ class BertTrainOneStepCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
@@ -129,7 +129,7 @@ class DistributedGradReducerThor(Cell):
 >>> ParallelMode.HYBRID_PARALLEL]:
 >>> self.reducer_flag = True
 >>> if self.reducer_flag:
->>> mean = context.get_auto_parallel_context("mirror_mean")
+>>> mean = context.get_auto_parallel_context("gradients_mean")
 >>> if mean.get_device_num_is_set():
 >>> degree = context.get_auto_parallel_context("device_num")
 >>> else:
......
@@ -20,7 +20,7 @@ from mindspore.common.parameter import ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
 from .grad_reducer_thor import DistributedGradReducerThor
 momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -83,7 +83,7 @@ class THOR(Optimizer):
 self.damping = damping
 self.one = Tensor(1, mstype.int32)
 self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer_g = DistributedGradReducerThor(self.parameters, 3, mean, degree)
......
@@ -23,7 +23,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from .transformer import Transformer
 from .grad_clip import GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE, ClipGradients
@@ -251,7 +251,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
......
@@ -234,7 +234,7 @@ def _setup_parallel_env(platform):
 parallel_mode=ParallelMode.DATA_PARALLEL,
 device_num=MultiAscend.get_group_size(),
 parameter_broadcast=True,
-mirror_mean=True
+gradients_mean=True
 )
......
@@ -81,7 +81,7 @@ def run_general_distill():
 rank = D.get_rank()
 save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
 context.reset_auto_parallel_context()
-context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
 device_num=device_num)
 else:
 rank = 0
......
@@ -318,7 +318,7 @@ class BertTrainCell(nn.Cell):
 self.grad_reducer = F.identity
 self.degree = 1
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 self.degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
 self.cast = P.Cast()
@@ -568,7 +568,7 @@ class BertEvaluationCell(nn.Cell):
 self.grad_reducer = F.identity
 self.degree = 1
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 self.degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, self.degree)
 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
......
@@ -23,7 +23,7 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from .transformer_model import TransformerModel
@@ -168,7 +168,7 @@ class TransformerTrainOneStepCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = context.get_auto_parallel_context("mirror_mean")
+mean = context.get_auto_parallel_context("gradients_mean")
 degree = get_group_size()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
@@ -256,7 +256,7 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell):
 self.reducer_flag = True
 self.grad_reducer = None
 if self.reducer_flag:
-mean = _get_mirror_mean()
+mean = _get_gradients_mean()
 degree = _get_device_num()
 self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
 self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
......
...@@ -118,7 +118,7 @@ def run_transformer_train(): ...@@ -118,7 +118,7 @@ def run_transformer_train():
if args.distribute == "true": if args.distribute == "true":
device_num = args.device_num device_num = args.device_num
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
parameter_broadcast=True, device_num=device_num) parameter_broadcast=True, device_num=device_num)
D.init() D.init()
rank_id = args.device_id % device_num rank_id = args.device_id % device_num
......
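A minimal launch sketch, assuming a hypothetical 8-device job, showing the renamed keyword argument in the context setup used by the run scripts above.

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication import management as D

device_num = 8  # hypothetical device count
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,        # formerly mirror_mean=True
                                  parameter_broadcast=True,
                                  device_num=device_num)
D.init()  # requires a configured distributed environment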
...@@ -56,7 +56,7 @@ if __name__ == '__main__': ...@@ -56,7 +56,7 @@ if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
init() init()
rank_id = int(os.environ.get('RANK_ID')) rank_id = int(os.environ.get('RANK_ID'))
elif args_opt.device_target == "GPU": elif args_opt.device_target == "GPU":
...@@ -65,7 +65,7 @@ if __name__ == '__main__': ...@@ -65,7 +65,7 @@ if __name__ == '__main__':
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=get_group_size(), context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True) gradients_mean=True)
rank_id = get_rank() rank_id = get_rank()
else: else:
print("Unsupported device_target ", args_opt.device_target) print("Unsupported device_target ", args_opt.device_target)
......
...@@ -367,7 +367,7 @@ class TrainStepWrap(nn.Cell): ...@@ -367,7 +367,7 @@ class TrainStepWrap(nn.Cell):
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL) ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("gradients_mean")
degree = context.get_auto_parallel_context("device_num") degree = context.get_auto_parallel_context("device_num")
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree) self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree) self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
......
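A sketch of the Wide&Deep pattern above, where a single renamed flag feeds two reducers (one per optimizer). The function name and its optimizer_w/optimizer_d arguments are hypothetical stand-ins for the cell attributes in the hunk.

from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

def build_wide_deep_reducers(optimizer_w, optimizer_d):
    # Hypothetical helper: returns (None, None) when no reduction is needed,
    # otherwise one DistributedGradReducer per optimizer, both driven by the
    # renamed "gradients_mean" key and the configured device number.
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        return None, None
    mean = context.get_auto_parallel_context("gradients_mean")  # formerly "mirror_mean"
    degree = context.get_auto_parallel_context("device_num")
    reducer_w = DistributedGradReducer(optimizer_w.parameters, mean, degree)
    reducer_d = DistributedGradReducer(optimizer_d.parameters, mean, degree)
    return reducer_w, reducer_d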
...@@ -147,8 +147,8 @@ if __name__ == "__main__": ...@@ -147,8 +147,8 @@ if __name__ == "__main__":
init() init()
if wide_deep_config.host_device_mix == 1: if wide_deep_config.host_device_mix == 1:
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
else: else:
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True) parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
train_and_eval(wide_deep_config) train_and_eval(wide_deep_config)
...@@ -119,7 +119,7 @@ if __name__ == "__main__": ...@@ -119,7 +119,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True) context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
init() init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size()) device_num=get_group_size())
train_and_eval(wide_deep_config) train_and_eval(wide_deep_config)
...@@ -119,7 +119,7 @@ if __name__ == "__main__": ...@@ -119,7 +119,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
init() init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size()) device_num=get_group_size())
train_and_eval(wide_deep_config) train_and_eval(wide_deep_config)
...@@ -554,7 +554,7 @@ class TrainStepWrap(nn.Cell): ...@@ -554,7 +554,7 @@ class TrainStepWrap(nn.Cell):
ParallelMode.HYBRID_PARALLEL): ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("gradients_mean")
degree = context.get_auto_parallel_context("device_num") degree = context.get_auto_parallel_context("device_num")
self.grad_reducer_w = DistributedGradReducer( self.grad_reducer_w = DistributedGradReducer(
self.optimizer_w.parameters, mean, degree) self.optimizer_w.parameters, mean, degree)
......
...@@ -113,6 +113,6 @@ if __name__ == "__main__": ...@@ -113,6 +113,6 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True) save_graphs=True)
init() init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=get_group_size()) device_num=get_group_size())
train_and_eval(wide_and_deep_config) train_and_eval(wide_and_deep_config)
...@@ -34,7 +34,7 @@ from mindspore.context import ParallelMode ...@@ -34,7 +34,7 @@ from mindspore.context import ParallelMode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) context.set_context(device_id=int(os.getenv('DEVICE_ID')))
init() init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
np.random.seed(10) np.random.seed(10)
......
...@@ -31,7 +31,7 @@ from src.config import WideDeepConfig ...@@ -31,7 +31,7 @@ from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True) context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
init() init()
......
...@@ -24,7 +24,7 @@ from mindspore.nn.optim import Adam, FTRL ...@@ -24,7 +24,7 @@ from mindspore.nn.optim import Adam, FTRL
# from mindspore.nn.metrics import Metric # from mindspore.nn.metrics import Metric
from mindspore.common.initializer import Uniform, initializer from mindspore.common.initializer import Uniform, initializer
# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size from mindspore.communication.management import get_group_size
...@@ -299,7 +299,7 @@ class TrainStepWrap(nn.Cell): ...@@ -299,7 +299,7 @@ class TrainStepWrap(nn.Cell):
self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL, self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
ParallelMode.HYBRID_PARALLEL) ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag: if self.reducer_flag:
mean = _get_mirror_mean() mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()
self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree) self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree) self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
......
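Some cells use MindSpore's private parallel utilities instead of the public context getters. A minimal sketch, assuming those underscore-prefixed helpers remain importable, of the renamed _get_gradients_mean in that path; build_reducer_from_private_utils is a hypothetical name.

from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean

def build_reducer_from_private_utils(optimizer):
    # Hypothetical helper mirroring the TrainStepWrap hunk above, but using the
    # internal utilities (note: private API, subject to change).
    if _get_parallel_mode() not in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        return None
    mean = _get_gradients_mean()   # formerly _get_mirror_mean()
    degree = _get_device_num()
    return DistributedGradReducer(optimizer.parameters, mean, degree)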
...@@ -30,7 +30,7 @@ from src.config import WideDeepConfig ...@@ -30,7 +30,7 @@ from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
init() init()
......
...@@ -656,7 +656,7 @@ class TrainingWrapper(nn.Cell): ...@@ -656,7 +656,7 @@ class TrainingWrapper(nn.Cell):
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True self.reducer_flag = True
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set(): if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num") degree = context.get_auto_parallel_context("device_num")
else: else:
......
...@@ -78,7 +78,7 @@ def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32): ...@@ -78,7 +78,7 @@ def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32):
def test_lenet_nccl(): def test_lenet_nccl():
context.set_auto_parallel_context(parallel_mode="data_parallel", mirror_mean=True, device_num=get_group_size()) context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True, device_num=get_group_size())
net = LeNet() net = LeNet()
net.set_train() net.set_train()
......
...@@ -279,7 +279,7 @@ class BertTrainOneStepCell(nn.Cell): ...@@ -279,7 +279,7 @@ class BertTrainOneStepCell(nn.Cell):
self.reducer_flag = True self.reducer_flag = True
self.grad_reducer = None self.grad_reducer = None
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size() degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
...@@ -61,7 +61,7 @@ class BertFinetuneCell(nn.Cell): ...@@ -61,7 +61,7 @@ class BertFinetuneCell(nn.Cell):
self.reducer_flag = True self.reducer_flag = True
self.grad_reducer = None self.grad_reducer = None
if self.reducer_flag: if self.reducer_flag:
mean = context.get_auto_parallel_context("mirror_mean") mean = context.get_auto_parallel_context("gradients_mean")
degree = get_group_size() degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
......
...@@ -130,7 +130,7 @@ class DistributedGradReducerThor(Cell): ...@@ -130,7 +130,7 @@ class DistributedGradReducerThor(Cell):
>>> ParallelMode.HYBRID_PARALLEL]: >>> ParallelMode.HYBRID_PARALLEL]:
>>> self.reducer_flag = True >>> self.reducer_flag = True
>>> if self.reducer_flag: >>> if self.reducer_flag:
>>> mean = context.get_auto_parallel_context("mirror_mean") >>> mean = context.get_auto_parallel_context("gradients_mean")
>>> if mean.get_device_num_is_set(): >>> if mean.get_device_num_is_set():
>>> degree = context.get_auto_parallel_context("device_num") >>> degree = context.get_auto_parallel_context("device_num")
>>> else: >>> else:
......
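The docstring hunk above calls get_device_num_is_set() on the boolean `mean`, which appears to be a pre-existing typo; the non-docstring hunks resolve the reduction degree roughly as sketched below (the get_group_size fallback is an assumption based on the other train cells in this diff).

from mindspore import context
from mindspore.communication.management import get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context

mean = context.get_auto_parallel_context("gradients_mean")   # formerly "mirror_mean"
if auto_parallel_context().get_device_num_is_set():
    degree = context.get_auto_parallel_context("device_num")
else:
    degree = get_group_size()   # assumed fallback, mirroring the other cells in this diff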
...@@ -20,7 +20,7 @@ from mindspore.common.parameter import ParameterTuple ...@@ -20,7 +20,7 @@ from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_mirror_mean from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from .grad_reducer_thor import DistributedGradReducerThor from .grad_reducer_thor import DistributedGradReducerThor
...@@ -87,7 +87,7 @@ class THOR(Optimizer): ...@@ -87,7 +87,7 @@ class THOR(Optimizer):
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0] 1.0]
mean = _get_mirror_mean() mean = _get_gradients_mean()
degree = _get_device_num() degree = _get_device_num()
self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree) self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree) self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
......
...@@ -137,7 +137,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): ...@@ -137,7 +137,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
os.environ['RANK_SIZE'] = str(device_num) os.environ['RANK_SIZE'] = str(device_num)
if enable_hccl: if enable_hccl:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True) gradients_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
init() init()
...@@ -240,7 +240,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): ...@@ -240,7 +240,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
os.environ['RANK_SIZE'] = str(device_num) os.environ['RANK_SIZE'] = str(device_num)
if enable_hccl: if enable_hccl:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True, parameter_broadcast=True) gradients_mean=True, parameter_broadcast=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
......
...@@ -97,7 +97,8 @@ if __name__ == "__main__": ...@@ -97,7 +97,8 @@ if __name__ == "__main__":
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
if device_target == "GPU": if device_target == "GPU":
context.set_auto_parallel_context(parallel_mode="data_parallel", mirror_mean=True, device_num=get_group_size()) context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True,
device_num=get_group_size())
net_with_criterion = WithLossCell(network, criterion) net_with_criterion = WithLossCell(network, criterion)
train_network = TrainOneStepCell(net_with_criterion, net_opt) train_network = TrainOneStepCell(net_with_criterion, net_opt)
train_network.set_train() train_network.set_train()
......
...@@ -58,7 +58,7 @@ def test_data_parallel_dense(): ...@@ -58,7 +58,7 @@ def test_data_parallel_dense():
"""test_data_parallel_dense""" """test_data_parallel_dense"""
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=8) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8)
inp = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01) inp = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([32, 768]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = DenseMMNet() net = DenseMMNet()
......
...@@ -80,7 +80,7 @@ def test_lenet5_train_step_training_pynative(): ...@@ -80,7 +80,7 @@ def test_lenet5_train_step_training_pynative():
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
device_num=8, mirror_mean=True) device_num=8, gradients_mean=True)
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
DatasetLenet(predict, label, 2) DatasetLenet(predict, label, 2)
......
...@@ -97,7 +97,7 @@ def test_on_momentum(): ...@@ -97,7 +97,7 @@ def test_on_momentum():
def test_data_parallel_with_cast(): def test_data_parallel_with_cast():
"""test_data_parallel_with_cast""" """test_data_parallel_with_cast"""
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, device_num=8) context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=8)
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = LeNet5() net = LeNet5()
......
...@@ -46,7 +46,7 @@ class Net(nn.Cell): ...@@ -46,7 +46,7 @@ class Net(nn.Cell):
def test_dense_gen_graph(): def test_dense_gen_graph():
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, mirror_mean=True, device_num=8) context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
init() init()
network = Net(512, 128) network = Net(512, 128)
......
...@@ -20,17 +20,17 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context ...@@ -20,17 +20,17 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
def test_set_auto_parallel_context(): def test_set_auto_parallel_context():
context.set_auto_parallel_context(device_num=4, global_rank=3, mirror_mean=True, gradient_fp32_sync=False, context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False,
parallel_mode="auto_parallel", parameter_broadcast=False) parallel_mode="auto_parallel", parameter_broadcast=False)
device_num = context.get_auto_parallel_context("device_num") device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank") global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean") gradients_mean = context.get_auto_parallel_context("gradients_mean")
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync") gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
assert device_num == 4 assert device_num == 4
assert global_rank == 3 assert global_rank == 3
assert mirror_mean assert gradients_mean
assert not gradient_fp32_sync assert not gradient_fp32_sync
assert parallel_mode == "auto_parallel" assert parallel_mode == "auto_parallel"
assert not parameter_broadcast assert not parameter_broadcast
...@@ -45,9 +45,9 @@ def test_set_auto_parallel_context(): ...@@ -45,9 +45,9 @@ def test_set_auto_parallel_context():
global_rank = auto_parallel_context().get_global_rank() global_rank = auto_parallel_context().get_global_rank()
assert global_rank == 4 assert global_rank == 4
auto_parallel_context().set_mirror_mean(True) auto_parallel_context().set_gradients_mean(True)
mirror_mean = auto_parallel_context().get_mirror_mean() gradients_mean = auto_parallel_context().get_gradients_mean()
assert mirror_mean assert gradients_mean
auto_parallel_context().set_gradient_fp32_sync(False) auto_parallel_context().set_gradient_fp32_sync(False)
gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync() gradient_fp32_sync = auto_parallel_context().get_gradient_fp32_sync()
...@@ -86,7 +86,7 @@ def test_reset_auto_parallel_context(): ...@@ -86,7 +86,7 @@ def test_reset_auto_parallel_context():
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
device_num = context.get_auto_parallel_context("device_num") device_num = context.get_auto_parallel_context("device_num")
global_rank = context.get_auto_parallel_context("global_rank") global_rank = context.get_auto_parallel_context("global_rank")
mirror_mean = context.get_auto_parallel_context("mirror_mean") gradients_mean = context.get_auto_parallel_context("gradients_mean")
gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync") gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync")
parallel_mode = context.get_auto_parallel_context("parallel_mode") parallel_mode = context.get_auto_parallel_context("parallel_mode")
parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast")
...@@ -94,7 +94,7 @@ def test_reset_auto_parallel_context(): ...@@ -94,7 +94,7 @@ def test_reset_auto_parallel_context():
parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set()
assert device_num == 1 assert device_num == 1
assert global_rank == 0 assert global_rank == 0
assert not mirror_mean assert not gradients_mean
assert gradient_fp32_sync assert gradient_fp32_sync
assert parallel_mode == "stand_alone" assert parallel_mode == "stand_alone"
assert not parameter_broadcast assert not parameter_broadcast
......
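A self-contained check, mirroring the updated tests above, that the renamed key round-trips through set/get, is exposed by the auto_parallel_context accessors, and resets to its default of False.

from mindspore import context
from mindspore.parallel._auto_parallel_context import auto_parallel_context

context.set_auto_parallel_context(gradients_mean=True)           # formerly mirror_mean=True
assert context.get_auto_parallel_context("gradients_mean")

auto_parallel_context().set_gradients_mean(False)                 # formerly set_mirror_mean
assert not auto_parallel_context().get_gradients_mean()           # formerly get_mirror_mean

context.reset_auto_parallel_context()
assert not context.get_auto_parallel_context("gradients_mean")    # default is False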
...@@ -65,7 +65,7 @@ def test_two_matmul(): ...@@ -65,7 +65,7 @@ def test_two_matmul():
out = self.matmul2(out, b) out = self.matmul2(out, b)
return out return out
context.set_auto_parallel_context(device_num=8, global_rank=0, mirror_mean=True) context.set_auto_parallel_context(device_num=8, global_rank=0, gradients_mean=True)
strategy1 = ((4, 2), (2, 1)) strategy1 = ((4, 2), (2, 1))
strategy2 = ((2, 4), (4, 1)) strategy2 = ((2, 4), (4, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
...@@ -90,7 +90,7 @@ def test_two_matmul_repeated_calculation1(): ...@@ -90,7 +90,7 @@ def test_two_matmul_repeated_calculation1():
out = self.matmul2(out, b) out = self.matmul2(out, b)
return out return out
context.set_auto_parallel_context(device_num=64, global_rank=5, mirror_mean=True) context.set_auto_parallel_context(device_num=64, global_rank=5, gradients_mean=True)
strategy1 = ((2, 4), (4, 8)) strategy1 = ((2, 4), (4, 8))
strategy2 = ((1, 1), (1, 1)) strategy2 = ((1, 1), (1, 1))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
......
...@@ -148,7 +148,7 @@ def test_compile_model_train_O2_parallel(): ...@@ -148,7 +148,7 @@ def test_compile_model_train_O2_parallel():
dataset_shapes = ((16, 16), (16, 16)) dataset_shapes = ((16, 16), (16, 16))
context.set_auto_parallel_context( context.set_auto_parallel_context(
global_rank=0, device_num=8, global_rank=0, device_num=8,
mirror_mean=True, parameter_broadcast=True, gradients_mean=True, parameter_broadcast=True,
parallel_mode=ParallelMode.DATA_PARALLEL) parallel_mode=ParallelMode.DATA_PARALLEL)
dataset = MindDataSet(dataset_types, dataset_shapes) dataset = MindDataSet(dataset_types, dataset_shapes)
......