From d8bd5a09c4281c467e2576d140056c29b6826867 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Thu, 16 Apr 2020 21:13:45 -0400 Subject: [PATCH] add global batch normalization --- tests/ut/python/hccl_test/manage/api.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/ut/python/hccl_test/manage/api.py b/tests/ut/python/hccl_test/manage/api.py index b684df526..04ce7da6d 100644 --- a/tests/ut/python/hccl_test/manage/api.py +++ b/tests/ut/python/hccl_test/manage/api.py @@ -21,6 +21,7 @@ class Hccl(): _instance = None _rank_id = 0 _rank_size = 1 + _group_size = 4 def __init__(self): pass @@ -47,6 +48,10 @@ class Hccl(): def rank_size(self): return self._rank_size + @property + def group_size(self): + return self._group_size + @rank_size.setter def rank_size(self, size): self._rank_size = size @@ -68,7 +73,7 @@ def get_rank_size(group=None): def get_group_size(group=None): hccl = Hccl() if group is None: - return hccl.rank_size + return hccl.group_size if isinstance(group, str): return int(group.split("-")[0]) raise ValueError -- GitLab