Unverified commit d121cf29, authored by zhenhailiu, committed by GitHub

add sep group (#56271)

* sep group

* add test

* test ok

* polish

* test cmake script generated

* add sep group

* format

* polish

* polish
Parent commit: e2b05dcc
......@@ -83,9 +83,10 @@ message HybridConfig {
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
optional MpConfig mp_configs = 5;
optional PpConfig pp_configs = 6;
optional DygraphShardingConfig sharding_configs = 7;
optional int32 sep_degree = 5 [ default = 1 ];
optional MpConfig mp_configs = 6;
optional PpConfig pp_configs = 7;
optional DygraphShardingConfig sharding_configs = 8;
}
message AMPConfig {
......
......@@ -153,7 +153,7 @@ class DistributedStrategy:
if _global_flags().is_public(key):
self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'sep', 'mp']
self.sync_param_name = ["embedding", "layer_norm", ".b_"]
self.__lock_attr = True
......@@ -1718,10 +1718,10 @@ class DistributedStrategy:
def hybrid_configs(self):
"""
Dynamic graph hybrid parallel strategy configuration. Three-way hybrid parallelism
Dynamic graph hybrid parallel strategy configuration. Five-way hybrid parallelism
needs to meet the following relationships
total_number_GPUs = dp_degree * mp_degree * pp_degree
total_number_GPUs = dp_degree * mp_degree * pp_degree * sharding_degree * sep_degree
**Note**:
**dp_degree(int)**: set number of GPUs in a data parallel group. Default -1.
......@@ -1732,8 +1732,9 @@ class DistributedStrategy:
**mp_degree(int)**: set number of GPUs in a model parallel group. Default 1
**pp_degree(int)**: set number of GPUs in a pipeline parallel group. Default 1
**order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','mp']
**sep_degree(int)**: set number of GPUs in a sep parallel group. Default 1
**sharding_degree(int)**: set number of GPUs in a sharding parallel group. Default 1
**order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','sep', 'mp']
Examples:
.. code-block:: python
......@@ -1744,7 +1745,7 @@ class DistributedStrategy:
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 1,
"order":['dp','pp','sharding','mp']}
"order":['dp','pp','sharding', 'sep', 'mp']}
"""
return get_msg_dict(self.strategy.hybrid_configs)
......
......@@ -60,8 +60,8 @@ class ParallelMode:
class CommunicateTopology:
def __init__(
self,
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[1, 1, 1, 1],
hybrid_group_names=["data", "pipe", "sharding", "sep", "model"],
dims=[1, 1, 1, 1, 1],
):
self._parallel_names = hybrid_group_names
self._dims = dims
......@@ -112,6 +112,33 @@ class CommunicateTopology:
assert axis_name in self._parallel_names
return self._dims[self._parallel_names.index(axis_name)]
def get_fused_ranks(self, fused_axis):
    """Return rank groups obtained by fusing the given parallel axes.

    For every coordinate combination of the non-fused axes, collect the
    global ranks of all coordinates spanned by the fused axes, yielding
    one rank group per non-fused coordinate.

    Args:
        fused_axis (list[str]): names of the parallel axes to fuse; each
            must be a member of ``self._parallel_names``.

    Returns:
        list[list[int]]: one list of global ranks per non-fused coordinate.
    """
    # Fix: keep the declared axis order instead of
    # ``list(set(...).difference(...))`` — string hashing is randomized per
    # process, so set iteration order can differ between distributed
    # workers and yield differently ordered results on different ranks.
    fused = set(fused_axis)
    non_fused_axis = [
        name for name in self._parallel_names if name not in fused
    ]

    def _axis_range(axis_name):
        # Coordinate value range along one named axis.
        return range(self._dims[self._parallel_names.index(axis_name)])

    non_fused_ranges = [_axis_range(name) for name in non_fused_axis]
    fused_ranges = [_axis_range(name) for name in fused_axis]

    rank_list = []
    for non_fused_ranks in product(*non_fused_ranges):
        coord_dict = dict(zip(non_fused_axis, non_fused_ranks))
        ranks = []
        for fused_ranks in product(*fused_ranges):
            coord_dict.update(zip(fused_axis, fused_ranks))
            ranks.append(self._coord2rank[self.coordinate(**coord_dict)])
        rank_list.append(ranks)
    return rank_list
def get_comm_list(self, axis_name):
assert axis_name in self._parallel_names
other_axis_names = [
......@@ -153,20 +180,23 @@ class HybridCommunicateGroup:
self._mp_degree = self._topo.get_dim('model')
self._pp_degree = self._topo.get_dim('pipe')
self._sharding_degree = self._topo.get_dim('sharding')
self._sep_degree = self._topo.get_dim('sep')
self._data_parallel_id = self._get_data_parallel_id()
self._model_parallel_id = self._get_model_parallel_id()
self._sharding_parallel_id = self._get_sharding_parallel_id()
self._sep_parallel_id = self._get_sep_parallel_id()
self.stage_id = self._get_pipe_parallel_id()
assert self._check_vaild_topo(), (
"Here is an unreasonable topogy setting. world_size: {}, but"
"mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(
"mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}, sep_num: {}".format(
self.nranks,
self._mp_degree,
self._sharding_degree,
self._pp_degree,
self._dp_degree,
self._sep_degree,
)
)
......@@ -183,6 +213,10 @@ class HybridCommunicateGroup:
self._sharding_group, self._sharding_comm_group = self._set_comm_group(
"sharding"
)
self._sep_group = None
if self._sep_degree > 1:
# create comm group for sep parallel
self._sep_group, self._sep_comm_group = self._set_comm_group("sep")
# create global group for check inf_nan / clip global norm
self._check_group, self._check_comm_group = self._set_check_group(
......@@ -195,6 +229,16 @@ class HybridCommunicateGroup:
self.sharding_check_comm_group,
) = self._set_check_group("sharding")
# create fused comm group
if self._sep_degree > 1:
(
self._dp_sep_group,
self._dp_sep_comm_group,
) = self.create_fuse_group(["data", "sep"])
self._pp_mp_group, self._pp_mp_comm_group = self.create_fuse_group(
["pipe", "model"]
)
# create p2p group
self.is_first_stage = self.stage_id == 0
self.is_last_stage = self.stage_id == (self._pp_degree - 1)
......@@ -209,20 +253,22 @@ class HybridCommunicateGroup:
debug_str = (
"HybridParallelInfo: rank_id: %d, mp_degree: %d, "
"sharding_degree: %d, pp_degree: %d, dp_degree: %d"
"sharding_degree: %d, pp_degree: %d, dp_degree: %d, sep_degree: %d"
% (
self.global_rank,
self._mp_degree,
self._sharding_degree,
self._pp_degree,
self._dp_degree,
self._sep_degree,
)
)
debug_str += ", mp_group: {}, sharding_group: {}, pp_group: {}, dp_group: {}, check/clip group: {}".format(
debug_str += ", mp_group: {}, sharding_group: {}, pp_group: {}, dp_group: {}, sep:group: {}, check/clip group: {}".format(
self._mp_group,
self._sharding_group,
self._pp_group,
self._dp_group,
self._sep_group,
self._check_group,
)
logger.info(debug_str)
......@@ -257,9 +303,13 @@ class HybridCommunicateGroup:
* self._mp_degree
* self._pp_degree
* self._sharding_degree
* self._sep_degree
== self.nranks
)
def _check_sep_exist(self):
assert self._sep_degree > 1, "sep not exist"
def _set_comm_group(self, parallel_method="data"):
parallel_group = []
parallel_comm_group = None
......@@ -404,6 +454,23 @@ class HybridCommunicateGroup:
def get_pipe_parallel_world_size(self):
    """Return the pipeline parallel degree (ranks per pipe group)."""
    return self._pp_degree
def _get_sep_parallel_id(self):
    # This rank's coordinate along the "sep" axis of the topology.
    return self._topo.get_coord(self.global_rank).sep

def get_sep_parallel_rank(self):
    """Return this rank's index within its sep parallel group."""
    return self._sep_parallel_id

def get_sep_parallel_world_size(self):
    """Return the sep parallel degree (ranks per sep group)."""
    return self._sep_degree

def get_sep_parallel_group(self):
    """Return the sep communication group; asserts sep_degree > 1."""
    self._check_sep_exist()
    return self._sep_comm_group

def get_sep_parallel_group_src_rank(self):
    """Return the global rank of the first member of this sep group."""
    self._check_sep_exist()
    return self._sep_comm_group.ranks[0]
def get_pipe_parallel_group(self):
    """Return the pipeline parallel communication group."""
    return self._pp_comm_group
......@@ -447,6 +514,44 @@ class HybridCommunicateGroup:
self.global_rank, pipe=stage_id, **kwargs
)
# fuse comm group message
def get_dp_sep_parallel_group(self):
    """Return the fused data+sep communication group.

    Only valid when sep_degree > 1, since the fused group is created
    under that condition; otherwise the assertion fires.
    """
    self._check_sep_exist()
    return self._dp_sep_comm_group

def get_pp_mp_parallel_group(self):
    """Return the fused pipe+model communication group.

    NOTE: this fused group is also only created when sep_degree > 1,
    which is why the sep-existence check guards a pp/mp group.
    """
    self._check_sep_exist()
    return self._pp_mp_comm_group
def create_fuse_group(self, fused_strategy_list):
    """Create communication group(s) fusing several parallel strategies.

    Args:
        fused_strategy_list (list[str]): parallel axis names (e.g.
            ``["data", "sep"]``) whose ranks are fused into joint groups.

    Returns:
        A pair of (rank list(s), comm group(s)) covering the current rank:
        single objects when this rank belongs to exactly one fused group,
        otherwise two parallel lists.
    """
    assert (
        len(fused_strategy_list) > 0
    ), "the length of fused_strategy_list must be greater than 0."
    parallel_group = []
    parallel_comm_group = []
    parallel_groups = self._topo.get_fused_ranks(fused_strategy_list)
    # Sort so every worker creates the groups in the same order —
    # collective group creation must follow an identical call sequence
    # on all ranks.
    parallel_groups.sort()
    for group in parallel_groups:
        # new_group is called on every rank for every group; only the
        # groups containing this rank are remembered.
        comm_group = paddle.distributed.new_group(ranks=group)
        if self.global_rank in group:
            parallel_group.append(group)
            parallel_comm_group.append(comm_group)
    assert len(parallel_group) > 0
    assert len(parallel_comm_group) > 0
    logger.info(
        "Total {} comm group(s) of fused {} create successfully!".format(
            len(parallel_groups), fused_strategy_list
        )
    )
    # Unwrap the common single-membership case for caller convenience.
    if len(parallel_group) > 1:
        return parallel_group, parallel_comm_group
    else:
        return parallel_group[0], parallel_comm_group[0]
class _CommunicateGroup:
"""tmp for static"""
......
......@@ -370,21 +370,26 @@ class Fleet:
return self
def _init_hybrid_parallel_env(self):
"""initialize the hybrid environment"""
"""initialize the hybrid environment."""
self.hybrid_configs = self._user_defined_strategy.hybrid_configs
self.dp_degree = self.hybrid_configs["dp_degree"]
self.mp_degree = self.hybrid_configs["mp_degree"]
self.pp_degree = self.hybrid_configs["pp_degree"]
self.sep_degree = self.hybrid_configs["sep_degree"]
self.sharding_degree = self.hybrid_configs["sharding_degree"]
assert self.mp_degree >= 0, "mp_degree should be greater or equal to 0"
assert self.pp_degree >= 0, "pp_degree should be greater or equal to 0"
assert (
self.sep_degree >= 0
), "sep_degree should be greater or equal to 0"
assert (
self.sharding_degree >= 0
), "sharding_degree should be greater or equal to 0"
self.mp_degree = max(self.mp_degree, 1)
self.pp_degree = max(self.pp_degree, 1)
self.sep_degree = max(self.sep_degree, 1)
if self.dp_degree < 0:
nranks = paddle.distributed.get_world_size()
......@@ -397,6 +402,7 @@ class Fleet:
"pp": ['pipe', self.pp_degree],
"sharding": ['sharding', self.sharding_degree],
"mp": ['model', self.mp_degree],
"sep": ["sep", self.sep_degree],
}
order = self._user_defined_strategy.hybrid_parallel_order
......
......@@ -243,7 +243,7 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
ENVS
"PADDLE_DIST_UT_PORT=21222;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_pipeline PROPERTIES TIMEOUT "120")
set_tests_properties(test_pipeline PROPERTIES TIMEOUT "160")
endif()
if(LOCAL_ALL_ARCH AND (LINUX OR APPLE))
py_test_modules(
......@@ -332,7 +332,7 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
LABELS
"RUN_TYPE=DIST"
ENVS
"NVIDIA_TF32_OVERRIDE=0;PADDLE_DIST_UT_PORT=21234;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
"PADDLE_DIST_UT_PORT=21234;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_hybrid_parallel_inference_helper PROPERTIES TIMEOUT
"120")
......@@ -351,12 +351,8 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX OR WIN32))
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
py_test_modules(
test_recv_save_op
MODULES
test_recv_save_op
ENVS
"NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
test_recv_save_op MODULES test_recv_save_op ENVS
"http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
py_test_modules(
......@@ -435,6 +431,21 @@ if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT
"200")
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
bash_test_modules(
test_parallel_dygraph_sep_parallel
START_BASH
../../legacy_test/dist_test.sh
TIMEOUT
"120"
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21242;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_parallel_dygraph_sep_parallel PROPERTIES TIMEOUT
"120")
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
bash_test_modules(
test_dygraph_group_sharded_api_for_eager
......@@ -705,7 +716,7 @@ if((WITH_GPU OR WITH_ROCM) AND LOCAL_ALL_PLAT)
LABELS
"RUN_TYPE=DIST"
ENVS
"NVIDIA_TF32_OVERRIDE=0;PADDLE_DIST_UT_PORT=21274;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
"PADDLE_DIST_UT_PORT=21274;NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python"
)
set_tests_properties(test_parallel_dygraph_mnist PROPERTIES TIMEOUT "200")
endif()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
import paddle
from paddle.distributed import fleet
class TestDistMPTraining(unittest.TestCase):
    # NOTE(review): despite the "MP" in the name (apparently carried over
    # from the model-parallel test this was based on), this case exercises
    # the sep parallel groups (sep_degree=2, all other degrees 1).

    def setUp(self):
        """Seed all RNGs and init fleet with a sep-only hybrid config."""
        random.seed(2023)
        np.random.seed(2023)
        paddle.seed(2023)
        self.strategy = fleet.DistributedStrategy()
        self.strategy.hybrid_configs = {
            "sharding_degree": 1,
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": 1,
            "sep_degree": 2,
        }
        fleet.init(is_collective=True, strategy=self.strategy)

    def test_basic_hcg(self):
        """Check the sep-related accessors of the hybrid communicate group."""
        hcg = fleet.get_hybrid_communicate_group()
        # With sep_degree=2 the sep rank is 0 or 1 and the group src is 0.
        assert hcg.get_sep_parallel_rank() >= 0
        assert hcg.get_sep_parallel_world_size() == 2
        assert hcg.get_sep_parallel_group_src_rank() == 0
        assert hcg.get_sep_parallel_group() is not None
        # Fused groups are created because sep_degree > 1.
        assert hcg.get_dp_sep_parallel_group() is not None
        assert hcg.get_pp_mp_parallel_group() is not None
# Allow running this distributed test script directly.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from legacy_test.test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridParallel(TestMultipleGpus):
    def test_hybrid_parallel_hcg(self):
        # Launches hybrid_parallel_sep_model.py through the multi-GPU
        # runner inherited from TestMultipleGpus (presumably on 2 GPUs,
        # per the runner's name — confirm against legacy_test).
        self.run_mnist_2gpu('hybrid_parallel_sep_model.py')
# Allow running this launcher test directly.
if __name__ == "__main__":
    unittest.main()
......@@ -3,81 +3,83 @@ test_fleet_sharding_meta_optimizer,,GPU;XPU,350,DIST,test_runner.py,2,,http_prox
test_fleet_static_mp_layers,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dgc_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC
test_dgc_optimizer,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC
test_parallel_margin_cross_entropy,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_parallel_dygraph_transformer,,GPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_transformer,,ROCM,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_margin_cross_entropy,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_parallel_dygraph_transformer,,GPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_transformer,,ROCM,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_fp16_allreduce_meta_optimizer,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_rnn_dp,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mp_layers,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_tcp_store,LINUX;APPLE,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_rnn_dp,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mp_layers,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_tcp_store,LINUX;APPLE,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_sharding_stage3_for_eager,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_communicator_half_async,,,120,DIST,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel_sync_send,,GPU;XPU,300,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..;PADDLE_P2P_SYNC_SEND=1,
test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pp_adaptor,,GPU,500,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pipeline_parallel_sync_send,,GPU;XPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..;PADDLE_P2P_SYNC_SEND=1,
test_parallel_dygraph_pipeline_parallel_with_virtual_stage,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_pp_adaptor,,GPU,500,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_localsgd_meta_optimizer,LINUX,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_class_center_sample,,GPU,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_class_center_sample,,GPU,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_pipeline,,,160,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_utils,LINUX;APPLE,,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_static_model_parallel,,,240,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_no_sync,,GPU,300,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_dygraph_sharding_stage2,,,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_control_flow,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hybrid_parallel_inference_helper,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_static_model_parallel,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_no_sync,,GPU,300,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL
test_dygraph_sharding_stage2,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_control_flow,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_lars_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hybrid_parallel_inference_helper,,,120,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_rolemaker_new,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dist_mnist_gradient_merge,LINUX;WIN32,GPU;ROCM,360,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_recv_save_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_communicator_sync,,,,,test_runner.py,2,,FLAGS_communicator_send_queue_size=1;FLAGS_communicator_max_merge_var_num=1;http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_pipeline_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_pipeline_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_gradient_merge_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_amp_init,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_meta_optimizer_base,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_raw_program_meta_optimizer,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sharding_parallel,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_tensor_parallel,,,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_group_sharded_api_for_eager,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_raw_program_meta_optimizer,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sharding_parallel,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_tensor_parallel,,,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sep_parallel,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_group_sharded_api_for_eager,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_distributed_strategy,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_dgc_meta_optimizer,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC
test_parallel_dygraph_unused_variables,,,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_unused_variables,,,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_lamb_meta_optimizer,LINUX,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dgc_momentum_op,,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_DGC
test_parallel_dygraph_no_sync_gradient_check,,,60,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_pipeline_meta_optimizer_with_recompute,,GPU;XPU,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_no_sync_gradient_check,,,60,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_pipeline_meta_optimizer_with_recompute,,GPU;XPU,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_hybrid_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_qat,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sparse_embedding,,GPU,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_sparse_embedding,,ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_qat,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sparse_embedding,,GPU,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_sparse_embedding,,ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_amp_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sparse_embedding_over_height,,GPU,150,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_qat_meta_optimizer,,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sparse_embedding_over_height,,GPU,150,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL;${NCCL_VERSION} VERSION_GREATER_EQUAL 2212
test_parallel_dygraph_sparse_embedding_over_height,,ROCM,350,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_distributed_strategy,LINUX;APPLE,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_parallel_parallelizer,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_parallel_parallelizer,,,120,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_new_group,,GPU;XPU,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=,
test_c_comm_init_op,LINUX,GPU;XPU,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=,
test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=,
test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_ir_pass_pipeline,,,240,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_sync_batch_norm,,GPU;ROCM,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_imperative_auto_mixed_precision_for_eager,,GPU;ROCM,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_mixed_precision,,GPU;ROCM,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_recompute_for_eager,,GPU;ROCM,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dist_mnist_dgc_nccl,,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC
test_dist_se_resnext_dgc,,,,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC
test_auto_checkpoint,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint_multiple,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint_dist_basic,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dist_mnist_dgc_nccl,,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC
test_dist_se_resnext_dgc,,,,DIST,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,WITH_NCCL OR WITH_RCCL;WITH_DGC
test_auto_checkpoint,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint_multiple,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_auto_checkpoint_dist_basic,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs1,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs2,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_hdfs3,LINUX,,200,EXCLUSIVE:NIGHTLY,../../legacy_test/dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_checkpoint,LINUX,GPU;ROCM,200,EXCLUSIVE:NIGHTLY,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_fleet_log,,,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,200,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_dist_save_load,LINUX,GPU,300,DIST,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_dygraph_save_for_auto_infer,LINUX,GPU,300,DIST,test_runner.py,,,NVIDIA_TF32_OVERRIDE=0;http_proxy=;https_proxy=;PYTHONPATH=../..,
......@@ -214,6 +214,277 @@ class TestCommunicateTopology(unittest.TestCase):
self.assertEqual(topo.get_dim_size("pp"), 2)
self.assertEqual(topo.get_dim_size("sharding"), 2)
def test_topology_5D(self):
    """Exhaustively verify CommunicateTopology on a 2x2x2x2x2 hybrid mesh.

    Axis order, from outermost to innermost, is dp, pp, sharding, sep, mp,
    so a global rank decomposes as dp*16 + pp*8 + sharding*4 + sep*2 + mp.
    Every expected value below is derived from that formula rather than
    written out literally; the asserted values are identical.
    """
    axes = ["dp", "pp", "sharding", "sep", "mp"]
    topo = fleet.CommunicateTopology(axes, [2, 2, 2, 2, 2])

    # Rank stride contributed by flipping each axis coordinate by one.
    stride = {"dp": 16, "pp": 8, "sharding": 4, "sep": 2, "mp": 1}

    def axis_coord(rank, axis):
        # Coordinate of `rank` along `axis` (each axis has degree 2).
        return (rank // stride[axis]) % 2

    # test get_comm_list: for each axis, pair every rank whose coordinate
    # on that axis is 0 with its partner one stride away.
    for axis in axes:
        step = stride[axis]
        expected_pairs = [
            [base, base + step]
            for base in range(32)
            if axis_coord(base, axis) == 0
        ]
        np.testing.assert_array_equal(
            expected_pairs, topo.get_comm_list(axis)
        )

    # test get_fused_ranks: fusing two axes yields groups of 4 ranks that
    # agree on the coordinates of the remaining three axes.
    def fused_groups(outer, inner):
        so, si = stride[outer], stride[inner]
        bases = [
            r
            for r in range(32)
            if axis_coord(r, outer) == 0 and axis_coord(r, inner) == 0
        ]
        return [[r, r + si, r + so, r + so + si] for r in bases]

    np.testing.assert_array_equal(
        sorted(fused_groups("dp", "sep")),
        sorted(topo.get_fused_ranks(["dp", "sep"])),
    )
    np.testing.assert_array_equal(
        sorted(fused_groups("pp", "mp")),
        sorted(topo.get_fused_ranks(["pp", "mp"])),
    )

    # test get_hybrid_group_names
    np.testing.assert_array_equal(axes, topo.get_hybrid_group_names())

    # test get_dims: every axis has degree 2
    for axis in axes:
        np.testing.assert_array_equal(2, topo.get_dim(axis))

    # test world size
    self.assertEqual(topo.world_size(), 32)

    # test get_rank and get_coord over all 32 points of the mesh
    for rank in range(32):
        dp = axis_coord(rank, "dp")
        pp = axis_coord(rank, "pp")
        sharding = axis_coord(rank, "sharding")
        sep = axis_coord(rank, "sep")
        mp = axis_coord(rank, "mp")
        self.assertEqual(
            topo.get_rank(dp=dp, pp=pp, sharding=sharding, sep=sep, mp=mp),
            rank,
        )
        self.assertEqual(
            topo.get_coord(rank),
            topo.coordinate(dp, pp, sharding, sep, mp),
        )

    # test get_axis_list: all ranks holding a given coordinate on one axis
    for axis in axes:
        for value in (0, 1):
            expected_ranks = [
                r for r in range(32) if axis_coord(r, axis) == value
            ]
            self.assertEqual(
                topo.get_axis_list(axis, value), expected_ranks
            )

    # test get_dim_size
    for axis in axes:
        self.assertEqual(topo.get_dim_size(axis), 2)
# Allow running this test module directly (e.g. `python this_file.py`);
# unittest discovers and runs every TestCase defined above.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册