diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 38b7acd93c445fe2062169da84dfc314782d536b..64a2efcfdccda277a1fbdef4535e2fa9927f4061 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -83,9 +83,10 @@ message HybridConfig { optional int32 mp_degree = 2 [ default = 1 ]; optional int32 pp_degree = 3 [ default = 1 ]; optional int32 sharding_degree = 4 [ default = 1 ]; - optional MpConfig mp_configs = 5; - optional PpConfig pp_configs = 6; - optional DygraphShardingConfig sharding_configs = 7; + optional int32 sep_degree = 5 [ default = 1 ]; + optional MpConfig mp_configs = 6; + optional PpConfig pp_configs = 7; + optional DygraphShardingConfig sharding_configs = 8; } message AMPConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 23d9327623fbcc5c490758047fbda66c21ec3d03..5cb9480eca8b6cdc201a3586f617c5fc87f553f8 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -153,7 +153,7 @@ class DistributedStrategy: if _global_flags().is_public(key): self.strategy.sync_nccl_allreduce = bool(_global_flags()[key]) - self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp'] + self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'sep', 'mp'] self.sync_param_name = ["embedding", "layer_norm", ".b_"] self.__lock_attr = True @@ -1718,10 +1718,10 @@ class DistributedStrategy: def hybrid_configs(self): """ - Dynamic graph hybrid parallel strategy configuration. Three-way hybrid parallelism + Dynamic graph hybrid parallel strategy configuration. Five-way hybrid parallelism needs to meet the following relationships - total_number_GPUs = dp_degree * mp_degree * pp_degree + total_number_GPUs = dp_degree * mp_degree * pp_degree * sharding_degree * sep_degree **Note**: **dp_degree(int)**: set number of GPUs in a data parallel group. Default -1. @@ -1732,8 +1732,9 @@ class DistributedStrategy: **mp_degree(int)**: set number of GPUs in a model parallel group. Default 1 **pp_degree(int)**: set number of GPUs in a pipeline parallel group. Default 1 - - **order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','mp'] + **sep_degree(int)**: set number of GPUs in a sep parallel group. Default 1 + **sharding_degree(int)**: set number of GPUs in a sharding parallel group. Default 1 + **order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','sep', 'mp'] Examples: .. code-block:: python @@ -1744,7 +1745,7 @@ class DistributedStrategy: "dp_degree": 1, "mp_degree": 2, "pp_degree": 1, - "order":['dp','pp','sharding','mp']} + "order":['dp','pp','sharding', 'sep', 'mp']} """ return get_msg_dict(self.strategy.hybrid_configs) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 4c655a0295c4fe2e460e3cc4fad7ddebb0ab7deb..bec592e6bb534d504d7bbf5343bccd41923dda79 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -60,8 +60,8 @@ class ParallelMode: class CommunicateTopology: def __init__( self, - hybrid_group_names=["data", "pipe", "sharding", "model"], - dims=[1, 1, 1, 1], + hybrid_group_names=["data", "pipe", "sharding", "sep", "model"], + dims=[1, 1, 1, 1, 1], ): self._parallel_names = hybrid_group_names self._dims = dims @@ -112,6 +112,33 @@ class CommunicateTopology: assert axis_name in self._parallel_names return self._dims[self._parallel_names.index(axis_name)] + def get_fused_ranks(self, fused_axis): + non_fused_axis = list(set(self._parallel_names).difference(fused_axis)) + non_fused_ranges = [] + for axis_name in non_fused_axis: + non_fused_ranges.append( + range(self._dims[self._parallel_names.index(axis_name)]) + ) + fused_ranges = [] + for axis_name in fused_axis: + fused_ranges.append( + range(self._dims[self._parallel_names.index(axis_name)]) + ) + + rank_list = [] + for non_fused_ranks in product(*non_fused_ranges): + coord_dict = {} + ranks = [] + for i, non_fused_rank in enumerate(non_fused_ranks): + coord_dict[non_fused_axis[i]] = non_fused_rank + for fused_ranks in product(*fused_ranges): + for i, fused_rank in enumerate(fused_ranks): + coord_dict[fused_axis[i]] = fused_rank + ranks.append(self._coord2rank[self.coordinate(**coord_dict)]) + rank_list.append(ranks) + + return rank_list + def get_comm_list(self, axis_name): assert axis_name in self._parallel_names other_axis_names = [ @@ -153,20 +180,23 @@ class HybridCommunicateGroup: self._mp_degree = self._topo.get_dim('model') self._pp_degree = self._topo.get_dim('pipe') self._sharding_degree = self._topo.get_dim('sharding') + self._sep_degree = self._topo.get_dim('sep') self._data_parallel_id = self._get_data_parallel_id() self._model_parallel_id = self._get_model_parallel_id() self._sharding_parallel_id = self._get_sharding_parallel_id() + self._sep_parallel_id = self._get_sep_parallel_id() self.stage_id = self._get_pipe_parallel_id() assert self._check_vaild_topo(), ( "Here is an unreasonable topogy setting. world_size: {}, but" - "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format( + "mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format( self.nranks, self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree, + self._sep_degree, ) ) @@ -183,6 +213,10 @@ class HybridCommunicateGroup: self._sharding_group, self._sharding_comm_group = self._set_comm_group( "sharding" ) + self._sep_group = None + if self._sep_degree > 1: + # create comm group for sep parallel + self._sep_group, self._sep_comm_group = self._set_comm_group("sep") # create global group for check inf_nan / clip global norm self._check_group, self._check_comm_group = self._set_check_group( @@ -195,6 +229,16 @@ class HybridCommunicateGroup: self.sharding_check_comm_group, ) = self._set_check_group("sharding") + # create fused comm group + if self._sep_degree > 1: + ( + self._dp_sep_group, + self._dp_sep_comm_group, + ) = self.create_fuse_group(["data", "sep"]) + self._pp_mp_group, self._pp_mp_comm_group = self.create_fuse_group( + ["pipe", "model"] + ) + # create p2p group self.is_first_stage = self.stage_id == 0 self.is_last_stage = self.stage_id == (self._pp_degree - 1) @@ -209,20 +253,22 @@ class HybridCommunicateGroup: debug_str = ( "HybridParallelInfo: rank_id: %d, mp_degree: %d, " - "sharding_degree: %d, pp_degree: %d, dp_degree: %d" + "sharding_degree: %d, pp_degree: %d, dp_degree: %d, sep_degree: %d" % ( self.global_rank, self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree, + self._sep_degree, ) ) - debug_str += ", mp_group: {}, sharding_group: {}, pp_group: {}, dp_group: {}, check/clip group: {}".format( + debug_str += ", mp_group: {}, sharding_group: {}, pp_group: {}, dp_group: {}, sep:group: {}, check/clip group: {}".format( self._mp_group, self._sharding_group, self._pp_group, self._dp_group, + self._sep_group, self._check_group, ) logger.info(debug_str) @@ -257,9 +303,13 @@ class HybridCommunicateGroup: * self._mp_degree * self._pp_degree * self._sharding_degree + * self._sep_degree == self.nranks ) + def _check_sep_exist(self): + assert self._sep_degree > 1, "sep not exist" + def _set_comm_group(self, parallel_method="data"): parallel_group = [] parallel_comm_group = None @@ -404,6 +454,23 @@ class HybridCommunicateGroup: def get_pipe_parallel_world_size(self): return self._pp_degree + def _get_sep_parallel_id(self): + return self._topo.get_coord(self.global_rank).sep + + def get_sep_parallel_rank(self): + return self._sep_parallel_id + + def get_sep_parallel_world_size(self): + return self._sep_degree + + def get_sep_parallel_group(self): + self._check_sep_exist() + return self._sep_comm_group + + def get_sep_parallel_group_src_rank(self): + self._check_sep_exist() + return self._sep_comm_group.ranks[0] + def get_pipe_parallel_group(self): return self._pp_comm_group @@ -447,6 +514,44 @@ class HybridCommunicateGroup: self.global_rank, pipe=stage_id, **kwargs ) + # fuse comm group message + def get_dp_sep_parallel_group(self): + self._check_sep_exist() + return self._dp_sep_comm_group + + def get_pp_mp_parallel_group(self): + self._check_sep_exist() + return self._pp_mp_comm_group + + def create_fuse_group(self, fused_strategy_list): + assert ( + len(fused_strategy_list) > 0 + ), "the length of fused_strategy_list must be greater than 0." + + parallel_group = [] + parallel_comm_group = [] + parallel_groups = self._topo.get_fused_ranks(fused_strategy_list) + parallel_groups.sort() + + for group in parallel_groups: + comm_group = paddle.distributed.new_group(ranks=group) + if self.global_rank in group: + parallel_group.append(group) + parallel_comm_group.append(comm_group) + + assert len(parallel_group) > 0 + assert len(parallel_comm_group) > 0 + + logger.info( + "Total {} comm group(s) of fused {} create successfully!".format( + len(parallel_groups), fused_strategy_list + ) + ) + if len(parallel_group) > 1: + return parallel_group, parallel_comm_group + else: + return parallel_group[0], parallel_comm_group[0] + class _CommunicateGroup: """tmp for static""" diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py index 2dab355264b4d1cc72e5a1197c2116c169224747..df0c39ee119aba3eb3536e6fbb20dedead9bd33e 100755 --- a/python/paddle/distributed/fleet/fleet.py +++ b/python/paddle/distributed/fleet/fleet.py @@ -370,21 +370,26 @@ class Fleet: return self def _init_hybrid_parallel_env(self): - """initialize the hybrid environment""" + """initialize the hybrid environment.""" self.hybrid_configs = self._user_defined_strategy.hybrid_configs self.dp_degree = self.hybrid_configs["dp_degree"] self.mp_degree = self.hybrid_configs["mp_degree"] self.pp_degree = self.hybrid_configs["pp_degree"] + self.sep_degree = self.hybrid_configs["sep_degree"] self.sharding_degree = self.hybrid_configs["sharding_degree"] assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0" assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0" + assert ( + self.sep_degree >= 0 + ), "sep_degree should be greater or equal to 0" assert ( self.sharding_degree >= 0 ), "sharding_degree should be greater or equal to 0" self.mp_degree = max(self.mp_degree, 1) self.pp_degree = max(self.pp_degree, 1) + self.sep_degree = max(self.sep_degree, 1) if self.dp_degree < 0: nranks = paddle.distributed.get_world_size() @@ -397,6 +402,7 @@ class Fleet: "pp": ['pipe', self.pp_degree], "sharding": ['sharding', self.sharding_degree], "mp": ['model', self.mp_degree], + "sep": ["sep", self.sep_degree], } order = self._user_defined_strategy.hybrid_parallel_order diff --git a/test/collective/fleet/CMakeLists.txt b/test/collective/fleet/CMakeLists.txt index 9d7f9bcdb96c95bc5044a10ab3026d85354376a1..74e4ab4ae64b2facf79b4ca44bc417025051463a 100644 --- a/test/collective/fleet/CMakeLists.txt +++ b/test/collective/fleet/CMakeLists.txt @@ -243,7 +243,7 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) ENVS "PADDLE_DIST_UT_PORT=21222;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" ) - set_tests_properties(test_pipeline PROPERTIES TIMEOUT "120") + set_tests_properties(test_pipeline PROPERTIES TIMEOUT "160") endif() if(LOCAL_ALL_ARCH AND (LINUX OR APPLE)) py_test_modules( @@ -332,7 +332,7 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) LABELS "RUN_TYPE=DIST" ENVS - "NVIDIA_TF32_OVERRIDE=0;PADDLE_DIST_UT_PORT=21234;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + "PADDLE_DIST_UT_PORT=21234;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" ) set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT "120") @@ -351,12 +351,8 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX OR WIN32)) endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) py_test_modules( - test_recv_save_op - MODULES - test_recv_save_op - ENVS - "NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" - ) + test_recv_save_op MODULES test_recv_save_op ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) py_test_modules( @@ -435,6 +431,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT "200") endif() +if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) + bash_test_modules( + test_parallel_dygraph_sep_parallel + START_BASH + ../../legacy_test/dist_test.sh + TIMEOUT + "120" + LABELS + "RUN_TYPE=DIST" + ENVS + "PADDLE_DIST_UT_PORT=21242;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + ) + set_tests_properties(test_parallel_dygraph_sep_parallel PROPERTIES TIMEOUT + "120") +endif() if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT) bash_test_modules( test_dygraph_group_sharded_api_for_eager @@ -705,7 +716,7 @@ if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT) LABELS "RUN_TYPE=DIST" ENVS - "NVIDIA_TF32_OVERRIDE=0;PADDLE_DIST_UT_PORT=21274;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" + "PADDLE_DIST_UT_PORT=21274;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python" ) set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT "200") endif() diff --git a/test/collective/fleet/hybrid_parallel_sep_model.py b/test/collective/fleet/hybrid_parallel_sep_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb83f27dc01d3427a162af2a74e0602d95ccc07 --- /dev/null +++ b/test/collective/fleet/hybrid_parallel_sep_model.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np + +import paddle +from paddle.distributed import fleet + + +class TestDistMPTraining(unittest.TestCase): + def setUp(self): + random.seed(2023) + np.random.seed(2023) + paddle.seed(2023) + + self.strategy = fleet.DistributedStrategy() + self.strategy.hybrid_configs = { + "sharding_degree": 1, + "dp_degree": 1, + "mp_degree": 1, + "pp_degree": 1, + "sep_degree": 2, + } + fleet.init(is_collective=True, strategy=self.strategy) + + def test_basic_hcg(self): + hcg = fleet.get_hybrid_communicate_group() + assert hcg.get_sep_parallel_rank() >= 0 + assert hcg.get_sep_parallel_world_size() == 2 + assert hcg.get_sep_parallel_group_src_rank() == 0 + assert hcg.get_sep_parallel_group() is not None + assert hcg.get_dp_sep_parallel_group() is not None + assert hcg.get_pp_mp_parallel_group() is not None + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/test_parallel_dygraph_sep_parallel.py b/test/collective/fleet/test_parallel_dygraph_sep_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3525d5792b84ed8ea2e8b155ab0775fff2ac08 --- /dev/null +++ b/test/collective/fleet/test_parallel_dygraph_sep_parallel.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus + + +class TestHybridParallel(TestMultipleGpus): + def test_hybrid_parallel_hcg(self): + self.run_mnist_2gpu('hybrid_parallel_sep_model.py') + + +if __name__ == "__main__": + unittest.main() diff --git a/test/collective/fleet/testslist.csv b/test/collective/fleet/testslist.csv index 751a67754cf556b2e0beaf2b17dffc0f4c78ad4e..c0369da06ac062060d95766bfcae512fef86b15c 100644 --- a/test/collective/fleet/testslist.csv +++ b/test/collective/fleet/testslist.csv @@ -3,81 +3,83 @@ test_fleet_sharding_meta_optimizer,,GPU;XPU,350,DIST,test_runner.py,2,,http_prox test_fleet_static_mp_layers,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dgc_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC test_dgc_optimizer,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC -test_parallel_margin_cross_entropy,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL -test_parallel_dygraph_transformer,,GPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 -test_parallel_dygraph_transformer,,ROCM,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_margin_cross_entropy,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_parallel_dygraph_transformer,,GPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 +test_parallel_dygraph_transformer,,ROCM,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_fp16_allreduce_meta_optimizer,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_rnn_dp,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_mp_layers,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL -test_tcp_store,LINUX;APPLE,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_rnn_dp,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_mp_layers,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_tcp_store,LINUX;APPLE,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL -test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_pipeline_parallel_sync_send,,GPU;XPU,300,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..;PADDLE_P2P_SYNC_SEND=1, -test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_pp_adaptor,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_pipeline_parallel_sync_send,,GPU;XPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..;PADDLE_P2P_SYNC_SEND=1, +test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_pp_adaptor,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL -test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_class_center_sample,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_pipeline,,,160,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_utils,LINUX;APPLE,,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_static_model_parallel,,,240,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_no_sync,,GPU,300,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL -test_dygraph_sharding_stage2,,,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_control_flow,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_hybrid_parallel_inference_helper,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_static_model_parallel,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_no_sync,,GPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL +test_dygraph_sharding_stage2,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_control_flow,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_hybrid_parallel_inference_helper,,,120,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_rolemaker_new,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dist_mnist_gradient_merge,LINUX;WIN32,GPU;ROCM,360,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_recv_save_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_communicator_sync,,,,,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_pipeline_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_pipeline_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_gradient_merge_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_amp_init,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_meta_optimizer_base,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_raw_program_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_sharding_parallel,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_tensor_parallel,,,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_dygraph_group_sharded_api_for_eager,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_raw_program_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_sharding_parallel,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_tensor_parallel,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_sep_parallel,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_group_sharded_api_for_eager,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_distributed_strategy,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_dgc_meta_optimizer,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC -test_parallel_dygraph_unused_variables,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_unused_variables,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_lamb_meta_optimizer,LINUX,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dgc_momentum_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC -test_parallel_dygraph_no_sync_gradient_check,,,60,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_fleet_pipeline_meta_optimizer_with_recompute,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_no_sync_gradient_check,,,60,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_pipeline_meta_optimizer_with_recompute,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_hybrid_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_qat,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_sparse_embedding,,GPU,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 -test_parallel_dygraph_sparse_embedding,,ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_qat,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_sparse_embedding,,GPU,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 +test_parallel_dygraph_sparse_embedding,,ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_amp_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_sparse_embedding_over_height,,GPU,150,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 -test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_fleet_qat_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_sparse_embedding_over_height,,GPU,150,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212 +test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_distributed_strategy,LINUX;APPLE,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_parallel_parallelizer,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_parallel_parallelizer,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_new_group,,GPU;XPU,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=, test_c_comm_init_op,LINUX,GPU;XPU,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=, test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=, -test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_ir_pass_pipeline,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., +test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_parallel_dygraph_sync_batch_norm,,GPU;ROCM,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_imperative_auto_mixed_precision_for_eager,,GPU;ROCM,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_mixed_precision,,GPU;ROCM,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_dygraph_recompute_for_eager,,GPU;ROCM,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_dist_mnist_dgc_nccl,,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC -test_dist_se_resnext_dgc,,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC -test_auto_checkpoint,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_checkpoint1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_checkpoint2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_checkpoint3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_checkpoint_multiple,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_auto_checkpoint_dist_basic,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dist_mnist_dgc_nccl,,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC +test_dist_se_resnext_dgc,,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC +test_auto_checkpoint,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_checkpoint1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_checkpoint2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_checkpoint3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_checkpoint_multiple,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_auto_checkpoint_dist_basic,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., -test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_dist_save_load,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../.., +test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../.., diff --git a/test/legacy_test/test_hybrid_parallel_topology.py b/test/legacy_test/test_hybrid_parallel_topology.py index 33d7bca3f4dcb8ad03536193614007c0b1fd8904..6614a59ed8e127a60f5a5d4df5cd4401508ad758 100644 --- a/test/legacy_test/test_hybrid_parallel_topology.py +++ b/test/legacy_test/test_hybrid_parallel_topology.py @@ -214,6 +214,277 @@ class TestCommunicateTopology(unittest.TestCase): self.assertEqual(topo.get_dim_size("pp"), 2) self.assertEqual(topo.get_dim_size("sharding"), 2) + def test_topology_5D(self): + topo = fleet.CommunicateTopology( + ["dp", "pp", "sharding", "sep", "mp"], [2, 2, 2, 2, 2] + ) + + # test get_comm_list + dp_comm_list = [ + [0, 16], + [1, 17], + [2, 18], + [3, 19], + [4, 20], + [5, 21], + [6, 22], + [7, 23], + [8, 24], + [9, 25], + [10, 26], + [11, 27], + [12, 28], + [13, 29], + [14, 30], + [15, 31], + ] + mp_comm_list = [ + [0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11], + [12, 13], + [14, 15], + [16, 17], + [18, 19], + [20, 21], + [22, 23], + [24, 25], + [26, 27], + [28, 29], + [30, 31], + ] + pp_comm_list = [ + [0, 8], + [1, 9], + [2, 10], + [3, 11], + [4, 12], + [5, 13], + [6, 14], + [7, 15], + [16, 24], + [17, 25], + [18, 26], + [19, 27], + [20, 28], + [21, 29], + [22, 30], + [23, 31], + ] + sharding_comm_list = [ + [0, 4], + [1, 5], + [2, 6], + [3, 7], + [8, 12], + [9, 13], + [10, 14], + [11, 15], + [16, 20], + [17, 21], + [18, 22], + [19, 23], + [24, 28], + [25, 29], + [26, 30], + [27, 31], + ] + sep_comm_list = [ + [0, 2], + [1, 3], + [4, 6], + [5, 7], + [8, 10], + [9, 11], + [12, 14], + [13, 15], + [16, 18], + [17, 19], + [20, 22], + [21, 23], + [24, 26], + [25, 27], + [28, 30], + [29, 31], + ] + + np.testing.assert_array_equal(dp_comm_list, topo.get_comm_list("dp")) + np.testing.assert_array_equal(mp_comm_list, topo.get_comm_list("mp")) + np.testing.assert_array_equal(pp_comm_list, topo.get_comm_list("pp")) + np.testing.assert_array_equal( + sharding_comm_list, topo.get_comm_list("sharding") + ) + np.testing.assert_array_equal(sep_comm_list, topo.get_comm_list("sep")) + + # test get_fused_ranks + dp_sep_fuse_comm_list = [ + [0, 2, 16, 18], + [1, 3, 17, 19], + [4, 6, 20, 22], + [5, 7, 21, 23], + [8, 10, 24, 26], + [9, 11, 25, 27], + [12, 14, 28, 30], + [13, 15, 29, 31], + ] + pp_mp_fuse_comm_list = [ + [0, 1, 8, 9], + [2, 3, 10, 11], + [4, 5, 12, 13], + [6, 7, 14, 15], + [16, 17, 24, 25], + [18, 19, 26, 27], + [20, 21, 28, 29], + [22, 23, 30, 31], + ] + + np.testing.assert_array_equal( + sorted(dp_sep_fuse_comm_list), + sorted(topo.get_fused_ranks(["dp", "sep"])), + ) + np.testing.assert_array_equal( + sorted(pp_mp_fuse_comm_list), + sorted(topo.get_fused_ranks(["pp", "mp"])), + ) + + # test get_hybrid_group_names + parallel_names = ["dp", "pp", "sharding", "sep", "mp"] + np.testing.assert_array_equal( + parallel_names, topo.get_hybrid_group_names() + ) + + # test get_dims + np.testing.assert_array_equal(2, topo.get_dim("dp")) + np.testing.assert_array_equal(2, topo.get_dim("mp")) + np.testing.assert_array_equal(2, topo.get_dim("pp")) + np.testing.assert_array_equal(2, topo.get_dim("sharding")) + np.testing.assert_array_equal(2, topo.get_dim("sep")) + + # test world size + self.assertEqual(topo.world_size(), 32) + + # test get_rank + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, sep=0, mp=0), 0) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, sep=0, mp=1), 1) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, sep=1, mp=0), 2) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, sep=1, mp=1), 3) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, sep=0, mp=0), 4) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, sep=0, mp=1), 5) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, sep=1, mp=0), 6) + self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, sep=1, mp=1), 7) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, sep=0, mp=0), 8) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, sep=0, mp=1), 9) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, sep=1, mp=0), 10) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, sep=1, mp=1), 11) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, sep=0, mp=0), 12) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, sep=0, mp=1), 13) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, sep=1, mp=0), 14) + self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, sep=1, mp=1), 15) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, sep=0, mp=0), 16) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, sep=0, mp=1), 17) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, sep=1, mp=0), 18) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, sep=1, mp=1), 19) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, sep=0, mp=0), 20) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, sep=0, mp=1), 21) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, sep=1, mp=0), 22) + self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, sep=1, mp=1), 23) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, sep=0, mp=0), 24) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, sep=0, mp=1), 25) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, sep=1, mp=0), 26) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, sep=1, mp=1), 27) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, sep=0, mp=0), 28) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, sep=0, mp=1), 29) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, sep=1, mp=0), 30) + self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, sep=1, mp=1), 31) + + # test get_coord + self.assertEqual(topo.get_coord(0), topo.coordinate(0, 0, 0, 0, 0)) + self.assertEqual(topo.get_coord(1), topo.coordinate(0, 0, 0, 0, 1)) + self.assertEqual(topo.get_coord(2), topo.coordinate(0, 0, 0, 1, 0)) + self.assertEqual(topo.get_coord(3), topo.coordinate(0, 0, 0, 1, 1)) + self.assertEqual(topo.get_coord(4), topo.coordinate(0, 0, 1, 0, 0)) + self.assertEqual(topo.get_coord(5), topo.coordinate(0, 0, 1, 0, 1)) + self.assertEqual(topo.get_coord(6), topo.coordinate(0, 0, 1, 1, 0)) + self.assertEqual(topo.get_coord(7), topo.coordinate(0, 0, 1, 1, 1)) + self.assertEqual(topo.get_coord(8), topo.coordinate(0, 1, 0, 0, 0)) + self.assertEqual(topo.get_coord(9), topo.coordinate(0, 1, 0, 0, 1)) + self.assertEqual(topo.get_coord(10), topo.coordinate(0, 1, 0, 1, 0)) + self.assertEqual(topo.get_coord(11), topo.coordinate(0, 1, 0, 1, 1)) + self.assertEqual(topo.get_coord(12), topo.coordinate(0, 1, 1, 0, 0)) + self.assertEqual(topo.get_coord(13), topo.coordinate(0, 1, 1, 0, 1)) + self.assertEqual(topo.get_coord(14), topo.coordinate(0, 1, 1, 1, 0)) + self.assertEqual(topo.get_coord(15), topo.coordinate(0, 1, 1, 1, 1)) + self.assertEqual(topo.get_coord(16), topo.coordinate(1, 0, 0, 0, 0)) + self.assertEqual(topo.get_coord(17), topo.coordinate(1, 0, 0, 0, 1)) + self.assertEqual(topo.get_coord(18), topo.coordinate(1, 0, 0, 1, 0)) + self.assertEqual(topo.get_coord(19), topo.coordinate(1, 0, 0, 1, 1)) + self.assertEqual(topo.get_coord(20), topo.coordinate(1, 0, 1, 0, 0)) + self.assertEqual(topo.get_coord(21), topo.coordinate(1, 0, 1, 0, 1)) + self.assertEqual(topo.get_coord(22), topo.coordinate(1, 0, 1, 1, 0)) + self.assertEqual(topo.get_coord(23), topo.coordinate(1, 0, 1, 1, 1)) + self.assertEqual(topo.get_coord(24), topo.coordinate(1, 1, 0, 0, 0)) + self.assertEqual(topo.get_coord(25), topo.coordinate(1, 1, 0, 0, 1)) + self.assertEqual(topo.get_coord(26), topo.coordinate(1, 1, 0, 1, 0)) + self.assertEqual(topo.get_coord(27), topo.coordinate(1, 1, 0, 1, 1)) + self.assertEqual(topo.get_coord(28), topo.coordinate(1, 1, 1, 0, 0)) + self.assertEqual(topo.get_coord(29), topo.coordinate(1, 1, 1, 0, 1)) + self.assertEqual(topo.get_coord(30), topo.coordinate(1, 1, 1, 1, 0)) + self.assertEqual(topo.get_coord(31), topo.coordinate(1, 1, 1, 1, 1)) + + # test get_axis_list + self.assertEqual( + topo.get_axis_list("dp", 0), + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + ) + self.assertEqual( + topo.get_axis_list("dp", 1), + [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], + ) + + self.assertEqual( + topo.get_axis_list("sep", 0), + [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29], + ) + self.assertEqual( + topo.get_axis_list("sep", 1), + [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31], + ) + + self.assertEqual( + topo.get_axis_list("mp", 0), + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30], + ) + self.assertEqual( + topo.get_axis_list("mp", 1), + [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31], + ) + self.assertEqual( + topo.get_axis_list("pp", 0), + [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23], + ) + self.assertEqual( + topo.get_axis_list("pp", 1), + [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31], + ) + self.assertEqual( + topo.get_axis_list("sharding", 0), + [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27], + ) + self.assertEqual( + topo.get_axis_list("sharding", 1), + [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31], + ) + + # test get_dim_size + self.assertEqual(topo.get_dim_size("dp"), 2) + self.assertEqual(topo.get_dim_size("mp"), 2) + self.assertEqual(topo.get_dim_size("pp"), 2) + self.assertEqual(topo.get_dim_size("sharding"), 2) + self.assertEqual(topo.get_dim_size("sep"), 2) + if __name__ == '__main__': unittest.main()