提交 616b9ea3 编写于 作者: Z zhaojichen

add global batch normalization

上级 b5e98042
......@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_local_rank_size, get_rank
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
......@@ -71,7 +71,7 @@ class _BatchNorm(Cell):
self.group = check_int_positive(group)
if self.group != 1:
self.rank_id = get_rank()
self.rank_size = get_local_rank_size()
self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group)
self.rank_list_idx = len(self.rank_list)
......
......@@ -65,6 +65,14 @@ def get_rank_size(group=None):
return int(group.split("-")[0])
raise ValueError
def get_group_size(group=None):
hccl = Hccl()
if group is None:
return hccl.rank_size
if isinstance(group, str):
return int(group.split("-")[0])
raise ValueError
# pylint: disable=unused-argument
def get_world_rank_from_group_rank(group, group_rank_id):
return group_rank_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册