From 178d7e5ed1d3a44231f5c73a1f7f68602d6fcd26 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Tue, 18 Oct 2022 20:05:34 +0800 Subject: [PATCH] add strategy group (#47021) --- .../fleet/base/orthogonal_strategy.py | 181 ++++++++++++++ .../distributed/fleet/base/strategy_group.py | 227 ++++++++++++++++++ .../tests/unittests/collective/CMakeLists.txt | 22 ++ .../collective/orthogonal_strategy.py | 45 ++++ .../unittests/collective/strategy_group.py | 95 ++++++++ .../collective/test_orthogonal_strategy.sh | 17 ++ .../collective/test_strategy_group.sh | 17 ++ .../tests/unittests/collective/testslist.csv | 2 + 8 files changed, 606 insertions(+) create mode 100644 python/paddle/distributed/fleet/base/orthogonal_strategy.py create mode 100644 python/paddle/distributed/fleet/base/strategy_group.py create mode 100644 python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py create mode 100644 python/paddle/fluid/tests/unittests/collective/strategy_group.py create mode 100644 python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh create mode 100644 python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh diff --git a/python/paddle/distributed/fleet/base/orthogonal_strategy.py b/python/paddle/distributed/fleet/base/orthogonal_strategy.py new file mode 100644 index 0000000000..d2ba6e4461 --- /dev/null +++ b/python/paddle/distributed/fleet/base/orthogonal_strategy.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import collections +import functools +import paddle.distributed as dist +from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase + + +class OrthogonalStrategy(): + """ + A hybrid of multiple distributed strategies. Strategies need to be orthogonal, means the ranks are organized like + a square if there are two strategies, a cube if there aree three strategies, etc. + + Args: + list_of_strategy(list): Stategy in the list should be represented as tuple, format as (strategy_name, degree, strategy_class). + fused_strategy_dict(dict, optional): Exist strategies can be fused to new strategy. Use the name of new strategy as key, a list of + strategy names you want to fuse as value. + + Returns: + The instance of strategy. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + from paddle.distributed.fleet.base.strategy_group import DPGroup, MPGroup, PPGroup + from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy + + dist.init_parallel_env() + strategy = OrthogonalStrategy([("dp", 2, DPGroup), ("mp", 2, MPGroup), ("pp", 2, PPGroup)], fused_strategy_dict={"check": ["mp", "pp"]}) + + """ + + def __init__(self, list_of_strategy, fused_strategy_dict={}): + self._list_of_strategy = list_of_strategy + self._fused_strategy_dict = fused_strategy_dict + self._rank = dist.get_rank() + self._rank_list_dict = {} + self._name_to_group_dict = {} + self._name_to_degree_dict = {} + self._list_of_strategy_name = [ + strategy[0] for strategy in list_of_strategy + ] + self._list_of_degree = [strategy[1] for strategy in list_of_strategy] + self._coordinate = collections.namedtuple('Coordinate', + self._list_of_strategy_name) + self._check_valid_strategy() + + ranges = [range(degree) for degree in self._list_of_degree] + list_of_coord = [ + self._coordinate(*coord) for coord in itertools.product(*ranges) + ] + self._coord_to_rank_dict = dict( + zip(list_of_coord, range(len(list_of_coord)))) + + for idx, strategy in enumerate(list_of_strategy): + strategy_name = strategy[0] + self._name_to_degree_dict[strategy_name] = strategy[1] + self._rank_list_dict[strategy_name] = self._calc_rank_list(idx) + self._name_to_group_dict[strategy_name] = strategy[2]( + self._rank_list_dict[strategy_name]) + + self._name_to_fused_group_dict = {} + self._create_fused_group() + + def strategy_group(self, name): + """ + Get strategy group with specific name. + + Args: + name: The name of strategy group + + Returns: + An instance of specific strategy group. + """ + assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format( + name) + return self._name_to_group_dict[name] + + def fused_strategy_group(self, name): + """ + Get fused strategy group with specific name. + + Args: + name: The name of fused strategy group + + Returns: + (StrategyGroupBase): An instance of strategy group. + """ + assert name in self._name_to_fused_group_dict, "Fused strategy group {} is not created.".format( + name) + return self._name_to_fused_group_dict[name] + + def rank_in_strategy(self, name): + """ + Get local rank in strategy group with specific name. + + Args: + name: The name of strategy group + + Returns: + (Integer): Local rank in specific strategy. + """ + assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format( + name) + return self._name_to_group_dict[name].group.rank + + def _check_valid_strategy(self): + assert len(self._list_of_strategy_name) == len( + set(self._list_of_strategy_name) + ), "Defined duplicated strategies: {}".format(list_of_strategy) + num_of_ranks = functools.reduce(lambda x, y: x * y, + self._list_of_degree) + assert num_of_ranks == dist.get_world_size( + ), "There are total {} ranks, but need {} ranks in this strategy.".format( + dist.get_world_size(), num_of_ranks) + for fused_strategy in self._fused_strategy_dict.values(): + for strategy in fused_strategy: + assert strategy in self._list_of_strategy_name, "Can not fuse strategy {} without defined previous.".format( + strategy) + + def _create_fused_group(self): + for name in self._fused_strategy_dict: + fused_strategy = self._fused_strategy_dict[name] + non_fused_strategy = list( + set(self._list_of_strategy_name).difference(fused_strategy)) + non_fused_ranges = [] + for strategy in non_fused_strategy: + non_fused_ranges.append( + range(self._name_to_degree_dict[strategy])) + fused_ranges = [] + for strategy in fused_strategy: + fused_ranges.append(range(self._name_to_degree_dict[strategy])) + + rank_list = [] + for non_fused_ranks in itertools.product(*non_fused_ranges): + coord_dict = {} + ranks = [] + for i, non_fused_rank in enumerate(non_fused_ranks): + coord_dict[non_fused_strategy[i]] = non_fused_rank + for fused_ranks in itertools.product(*fused_ranges): + for i, fused_rank in enumerate(fused_ranks): + coord_dict[fused_strategy[i]] = fused_rank + ranks.append(self._coord_to_rank_dict[self._coordinate( + **coord_dict)]) + rank_list.append(ranks) + self._name_to_fused_group_dict[name] = StrategyGroupBase(rank_list) + + def _calc_rank_list(self, strategy_axis): + ranges = [] + for idx, degree in enumerate(self._list_of_degree): + if idx == strategy_axis: + continue + ranges.append(range(degree)) + + rank_list = [] + for coord in itertools.product(*ranges): + ranks = [] + for val in range(self._list_of_degree[strategy_axis]): + coord_list = list(coord) + coord_list.insert(strategy_axis, val) + ranks.append( + self._coord_to_rank_dict[self._coordinate(*coord_list)]) + rank_list.append(ranks) + + return rank_list diff --git a/python/paddle/distributed/fleet/base/strategy_group.py b/python/paddle/distributed/fleet/base/strategy_group.py new file mode 100644 index 0000000000..94ab04b6c4 --- /dev/null +++ b/python/paddle/distributed/fleet/base/strategy_group.py @@ -0,0 +1,227 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.distributed as dist + + +class StrategyGroupBase(): + """ + The base class of communication group with distributed strategy. + + Args: + list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents + they are in the same communication group. + + Returns: + The instance of strategy group. + + Examples: + .. code-block:: python + + import paddle.distributed as dist + from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase + + dist.init_parallel_env() + strategy_group = dist.fleet.base.strategy_group.StrategyGroupBase([[0, 1], [2, 3]]) + print(strategy_group.world_size) # 2 + + """ + + def __init__(self, list_of_ranks): + assert dist.is_initialized( + ), "The global communication group need to be initialized." + assert len(list_of_ranks), "The list_of_ranks can not be empty." + self._rank = dist.get_rank() + self._list_of_ranks = list_of_ranks + self._group = self._create_group() + + @property + def world_size(self): + """ + The world size of communication group. + + Returns: + Integer if the world_size of each group are equal, or a list of world_size if they are not equal. + """ + world_size_list = [] + for ranks in self._list_of_ranks: + world_size_list.append(len(ranks)) + is_value = all(world_size == world_size_list[0] + for world_size in world_size_list) + return world_size_list[0] if is_value else world_size_list + + @property + def group(self): + """ + The communication group which current rank belongs to. + + Returns: + Group if current rank only belong to single communication group, or a list of Group if it belongs many. + """ + return self._group + + def _create_group(self): + list_of_group = [] + for ranks in self._list_of_ranks: + group = dist.new_group(ranks=ranks) + if self._rank in ranks: + list_of_group.append(group) + assert len( + list_of_group + ) > 0, "Rank {} does not belong to the list_of_ranks {}.".format( + self._rank, self._list_of_ranks) + return list_of_group if len(list_of_group) > 1 else list_of_group[0] + + +class DPGroup(StrategyGroupBase): + """ + The communication group strategy for data parallel. + + Args: + list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents + they are in the same communication group. + + Returns: + The instance of data parallel strategy group. + """ + + def __init__(self, list_of_ranks): + super(DPGroup, self).__init__(list_of_ranks) + assert not isinstance( + self.group, list), "Rank {} belongs to multi dp groups".format( + self._rank) + + +class MPGroup(StrategyGroupBase): + """ + The communication group strategy for model parallel. + + Args: + list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents + they are in the same communication group. + + Returns: + The instance of model parallel strategy group. + """ + + def __init__(self, list_of_ranks): + super(MPGroup, self).__init__(list_of_ranks) + assert not isinstance( + self.group, list), "Rank {} belongs to multi mp groups".format( + self._rank) + + +class ShardingGroup(StrategyGroupBase): + """ + The communication group strategy for sharding parallel. + + Args: + list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents + they are in the same communication group. + + Returns: + The instance of sharding parallel strategy group. + """ + + def __init__(self, list_of_ranks): + super(ShardingGroup, self).__init__(list_of_ranks) + assert not isinstance( + self.group, + list), "Rank {} belongs to multi sharding groups".format(self._rank) + + +class PPGroup(StrategyGroupBase): + """ + The communication group strategy for pipeline parallel. + + Args: + list_of_ranks: A 2D-array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in sublist represents + they are in the same communication group. + + Returns: + The instance of pipeline parallel strategy group. + """ + + def __init__(self, list_of_ranks): + super(PPGroup, self).__init__(list_of_ranks) + assert not isinstance( + self.group, list), "Rank {} belongs to multi pp groups".format( + self._rank) + + self._send_next_group = None + self._send_prev_group = None + self._recv_next_group = None + self._recv_prev_group = None + self._rank_of_next_stage = None + self._rank_of_prev_stage = None + + if self.world_size > 1: + self._create_p2p_group() + + @property + def rank_of_prev_stage(self): + """ + Rank of the previous pp stage. + + Returns: + The global rank of previous pp stage. `None` if without previous. + """ + return self._rank_of_prev_stage + + @property + def rank_of_next_stage(self): + """ + Rank of the next pp stage. + + Returns: + The global rank of next pp stage. `None` if without next. + """ + return self._rank_of_next_stage + + @property + def p2p_groups(self): + """ + Communication subgroup in order to switch data with previous and next stage. + + Returns: + Four subgroups including send/recv to/from prev/next. + """ + return self._send_next_group, self._send_prev_group, self._recv_next_group, self._recv_prev_group + + def _create_p2p_group(self): + degree = self.world_size + for ranks in self._list_of_ranks: + for idx, rank in enumerate(ranks): + next_rank = ranks[(idx + 1) % degree] + prev_rank = ranks[(idx - 1) % degree] + + if self._rank == rank: + self._rank_of_next_stage = next_rank + self._rank_of_prev_stage = prev_rank + + next_group = dist.new_group(ranks=[rank, next_rank]) + + if self._rank == rank: + self._send_next_group = next_group + elif self._rank == next_rank: + self._recv_prev_group = next_group + + prev_group = dist.new_group(ranks=[prev_rank, rank]) + if self._rank == rank: + self._send_prev_group = prev_group + elif self._rank == prev_rank: + self._recv_next_group = prev_group + + assert self._send_next_group and self._send_prev_group and self._recv_next_group and self._recv_prev_group,\ + "Error occurs while creating p2p group for rank {}.".format(self._rank) diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index a4f42cdb6e..4e19583be6 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -391,5 +391,27 @@ if(WITH_MPI) "PADDLE_DIST_UT_PORT=21672;http_proxy=;https_proxy=") endif() endif() +if((WITH_ROCM OR WITH_GPU) AND (LINUX)) + bash_test_modules( + test_strategy_group + START_BASH + test_strategy_group.sh + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21814;http_proxy=;https_proxy=") + set_tests_properties(test_strategy_group PROPERTIES TIMEOUT "120") +endif() +if((WITH_ROCM OR WITH_GPU) AND (LINUX)) + bash_test_modules( + test_orthogonal_strategy + START_BASH + test_orthogonal_strategy.sh + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21958;http_proxy=;https_proxy=") + set_tests_properties(test_orthogonal_strategy PROPERTIES TIMEOUT "120") +endif() add_subdirectory(fleet) add_subdirectory(multinode) diff --git a/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py b/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py new file mode 100644 index 0000000000..5dbc624b79 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/orthogonal_strategy.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.distributed as dist +from paddle.distributed.fleet.base.strategy_group import DPGroup, ShardingGroup, MPGroup, PPGroup +from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy + + +class TestOrthogonalStrategyAPI(unittest.TestCase): + + def setUp(self): + self._num_of_ranks = 2 + dist.init_parallel_env() + self._global_rank = dist.get_rank() + self._strategy = OrthogonalStrategy( + [("dp", 2, DPGroup), ("mp", 1, MPGroup), + ("sharding", 1, ShardingGroup), ("pp", 1, PPGroup)], + fused_strategy_dict={"checkness": ["mp", "sharding", "pp"]}) + + def test_orthogonal_strategy(self): + dp_group = self._strategy.strategy_group("dp") + self.assertEqual(dp_group.world_size, self._num_of_ranks) + self.assertEqual(dp_group.group.nranks, self._num_of_ranks) + self.assertEqual(self._strategy.rank_in_strategy("dp"), + self._global_rank) + + fused_group = self._strategy.fused_strategy_group("checkness") + self.assertEqual(fused_group.world_size, 1) + self.assertEqual(fused_group.group.nranks, 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/strategy_group.py b/python/paddle/fluid/tests/unittests/collective/strategy_group.py new file mode 100644 index 0000000000..247a232aec --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/strategy_group.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase, DPGroup, MPGroup, PPGroup, ShardingGroup + + +def _check_using_all_reduce(group): + data = paddle.to_tensor([1, 2, 3]) + result = paddle.to_tensor([2, 4, 6]) + dist.all_reduce(data, group=group) + assert np.array_equal(data, result) + + +def _check_using_send(group, dst): + data = paddle.to_tensor([1, 2, 3]) + dist.send(data, dst=dst, group=group) + + +def _check_using_recv(group, src): + result = paddle.to_tensor([1, 2, 3]) + data = paddle.to_tensor([0, 0, 0]) + dist.recv(data, src=src, group=group) + assert np.array_equal(data, result) + + +class TestStrategyGroupAPI(unittest.TestCase): + + def setUp(self): + self._num_of_ranks = 2 + self._list_of_rank = [[0, 1]] + self._list_of_ranks = [[0, 1], [0, 1]] + dist.init_parallel_env() + self._global_rank = dist.get_rank() + self._peer_rank = 0 if self._global_rank == 1 else 1 + + def test_strategy_group_base(self): + strategy_group = StrategyGroupBase(self._list_of_rank) + self.assertEqual(strategy_group.world_size, self._num_of_ranks) + self.assertEqual(strategy_group.group.nranks, self._num_of_ranks) + _check_using_all_reduce(strategy_group.group) + + def test_data_parallel_group(self): + dp_group = DPGroup(self._list_of_rank) + self.assertEqual(dp_group.world_size, self._num_of_ranks) + self.assertEqual(dp_group.group.nranks, self._num_of_ranks) + _check_using_all_reduce(dp_group.group) + + def test_model_parallel_group(self): + mp_group = MPGroup(self._list_of_rank) + self.assertEqual(mp_group.world_size, self._num_of_ranks) + self.assertEqual(mp_group.group.nranks, self._num_of_ranks) + _check_using_all_reduce(mp_group.group) + + def test_sharding_parallel_group(self): + sharding_group = ShardingGroup(self._list_of_rank) + self.assertEqual(sharding_group.world_size, self._num_of_ranks) + self.assertEqual(sharding_group.group.nranks, self._num_of_ranks) + _check_using_all_reduce(sharding_group.group) + + def test_pipeline_parallel_group(self): + pp_group = PPGroup(self._list_of_rank) + send_next_group, send_prev_group, recv_next_group, recv_prev_group = pp_group.p2p_groups + if self._global_rank == 0: + self.assertEqual(pp_group.rank_of_next_stage, 1) + self.assertEqual(pp_group.rank_of_prev_stage, 1) + _check_using_send(send_next_group, self._peer_rank) + _check_using_send(send_prev_group, self._peer_rank) + _check_using_recv(recv_prev_group, self._peer_rank) + _check_using_recv(recv_next_group, self._peer_rank) + else: + self.assertEqual(pp_group.rank_of_next_stage, 0) + self.assertEqual(pp_group.rank_of_prev_stage, 0) + _check_using_recv(recv_prev_group, self._peer_rank) + _check_using_recv(recv_next_group, self._peer_rank) + _check_using_send(send_next_group, self._peer_rank) + _check_using_send(send_prev_group, self._peer_rank) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh b/python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh new file mode 100644 index 0000000000..6b4df2b124 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_orthogonal_strategy.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 orthogonal_strategy.py diff --git a/python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh b/python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh new file mode 100644 index 0000000000..d6c3a0e79f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_strategy_group.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 strategy_group.py diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv index 1f6584f7b9..5d554aeee8 100644 --- a/python/paddle/fluid/tests/unittests/collective/testslist.csv +++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv @@ -46,3 +46,5 @@ test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_pro test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=.., test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=, test_mpi_comm,linux,,,DIST,test_mpi_comm.sh,2,,http_proxy=;https_proxy=,WITH_MPI +test_strategy_group,linux,rocm;gpu,120,DIST,test_strategy_group.sh,2,,http_proxy=;https_proxy=, +test_orthogonal_strategy,linux,rocm;gpu,120,DIST,test_orthogonal_strategy.sh,2,,http_proxy=;https_proxy=, -- GitLab