提交 d8bd5a09 编写于 作者: Z zhaojichen

add global batch normalization

上级 d2b04664
......@@ -21,6 +21,7 @@ class Hccl():
_instance = None
_rank_id = 0
_rank_size = 1
_group_size = 4
def __init__(self):
pass
......@@ -47,6 +48,10 @@ class Hccl():
def rank_size(self):
return self._rank_size
@property
def group_size(self):
return self._group_size
@rank_size.setter
def rank_size(self, size):
self._rank_size = size
......@@ -68,7 +73,7 @@ def get_rank_size(group=None):
def get_group_size(group=None):
hccl = Hccl()
if group is None:
return hccl.rank_size
return hccl.group_size
if isinstance(group, str):
return int(group.split("-")[0])
raise ValueError
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册