Unverified commit 178d7e5e, authored by LiYuRio, committed by GitHub

add strategy group (#47021)

Parent d68c38ef
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import collections
import functools
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase
class OrthogonalStrategy():
"""
    A hybrid of multiple distributed strategies. Strategies need to be orthogonal, which means the ranks are organized like
    a square if there are two strategies, a cube if there are three strategies, and so on.
    Args:
        list_of_strategy(list): Each strategy in the list should be represented as a tuple in the form (strategy_name, degree, strategy_class).
        fused_strategy_dict(dict, optional): Existing strategies can be fused into a new strategy. Use the name of the new strategy as the key and a list of
            the strategy names you want to fuse as the value.
    Returns:
        The instance of the orthogonal strategy.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import DPGroup, MPGroup, PPGroup
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
dist.init_parallel_env()
strategy = OrthogonalStrategy([("dp", 2, DPGroup), ("mp", 2, MPGroup), ("pp", 2, PPGroup)], fused_strategy_dict={"check": ["mp", "pp"]})
"""
def __init__(self, list_of_strategy, fused_strategy_dict={}):
self._list_of_strategy = list_of_strategy
self._fused_strategy_dict = fused_strategy_dict
self._rank = dist.get_rank()
self._rank_list_dict = {}
self._name_to_group_dict = {}
self._name_to_degree_dict = {}
self._list_of_strategy_name = [
strategy[0] for strategy in list_of_strategy
]
self._list_of_degree = [strategy[1] for strategy in list_of_strategy]
self._coordinate = collections.namedtuple('Coordinate',
self._list_of_strategy_name)
self._check_valid_strategy()
ranges = [range(degree) for degree in self._list_of_degree]
list_of_coord = [
self._coordinate(*coord) for coord in itertools.product(*ranges)
]
self._coord_to_rank_dict = dict(
zip(list_of_coord, range(len(list_of_coord))))
for idx, strategy in enumerate(list_of_strategy):
strategy_name = strategy[0]
self._name_to_degree_dict[strategy_name] = strategy[1]
self._rank_list_dict[strategy_name] = self._calc_rank_list(idx)
self._name_to_group_dict[strategy_name] = strategy[2](
self._rank_list_dict[strategy_name])
self._name_to_fused_group_dict = {}
self._create_fused_group()
def strategy_group(self, name):
"""
        Get the strategy group with the specified name.
        Args:
            name: The name of the strategy group.
        Returns:
            An instance of the specified strategy group.
"""
assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format(
name)
return self._name_to_group_dict[name]
def fused_strategy_group(self, name):
"""
        Get the fused strategy group with the specified name.
        Args:
            name: The name of the fused strategy group.
        Returns:
            (StrategyGroupBase): An instance of the fused strategy group.
"""
assert name in self._name_to_fused_group_dict, "Fused strategy group {} is not created.".format(
name)
return self._name_to_fused_group_dict[name]
def rank_in_strategy(self, name):
"""
        Get the local rank in the strategy group with the specified name.
        Args:
            name: The name of the strategy group.
        Returns:
            (Integer): The local rank in the specified strategy group.
"""
assert name in self._list_of_strategy_name, "Strategy group {} is not created.".format(
name)
return self._name_to_group_dict[name].group.rank
def _check_valid_strategy(self):
        assert len(self._list_of_strategy_name) == len(
            set(self._list_of_strategy_name)
        ), "Defined duplicated strategies: {}".format(self._list_of_strategy)
num_of_ranks = functools.reduce(lambda x, y: x * y,
self._list_of_degree)
        assert num_of_ranks == dist.get_world_size(
        ), "There are {} ranks in total, but this strategy requires {} ranks.".format(
            dist.get_world_size(), num_of_ranks)
for fused_strategy in self._fused_strategy_dict.values():
for strategy in fused_strategy:
                assert strategy in self._list_of_strategy_name, "Cannot fuse strategy {} because it has not been defined.".format(
                    strategy)
def _create_fused_group(self):
for name in self._fused_strategy_dict:
fused_strategy = self._fused_strategy_dict[name]
non_fused_strategy = list(
set(self._list_of_strategy_name).difference(fused_strategy))
non_fused_ranges = []
for strategy in non_fused_strategy:
non_fused_ranges.append(
range(self._name_to_degree_dict[strategy]))
fused_ranges = []
for strategy in fused_strategy:
fused_ranges.append(range(self._name_to_degree_dict[strategy]))
rank_list = []
for non_fused_ranks in itertools.product(*non_fused_ranges):
coord_dict = {}
ranks = []
for i, non_fused_rank in enumerate(non_fused_ranks):
coord_dict[non_fused_strategy[i]] = non_fused_rank
for fused_ranks in itertools.product(*fused_ranges):
for i, fused_rank in enumerate(fused_ranks):
coord_dict[fused_strategy[i]] = fused_rank
ranks.append(self._coord_to_rank_dict[self._coordinate(
**coord_dict)])
rank_list.append(ranks)
self._name_to_fused_group_dict[name] = StrategyGroupBase(rank_list)
def _calc_rank_list(self, strategy_axis):
ranges = []
for idx, degree in enumerate(self._list_of_degree):
if idx == strategy_axis:
continue
ranges.append(range(degree))
rank_list = []
for coord in itertools.product(*ranges):
ranks = []
for val in range(self._list_of_degree[strategy_axis]):
coord_list = list(coord)
coord_list.insert(strategy_axis, val)
ranks.append(
self._coord_to_rank_dict[self._coordinate(*coord_list)])
rank_list.append(ranks)
return rank_list
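For reference, the coordinate-to-rank indexing that _calc_rank_list relies on can be reproduced offline. The following is a minimal standalone sketch (hypothetical, not part of this commit) assuming a ("dp", 2), ("mp", 2), ("pp", 2) layout on 8 ranks; it builds no communication groups, and the names degrees and rank_lists are illustrative only:

import collections
import itertools

degrees = [("dp", 2), ("mp", 2), ("pp", 2)]
names = [name for name, _ in degrees]
Coordinate = collections.namedtuple("Coordinate", names)

# Enumerate coordinates in itertools.product order and number them as global ranks.
coords = [Coordinate(*c) for c in itertools.product(*(range(d) for _, d in degrees))]
coord_to_rank = {c: r for r, c in enumerate(coords)}

def rank_lists(axis):
    # Group global ranks that differ only along the chosen strategy axis.
    other_ranges = [range(d) for i, (_, d) in enumerate(degrees) if i != axis]
    lists = []
    for rest in itertools.product(*other_ranges):
        ranks = []
        for val in range(degrees[axis][1]):
            coord = list(rest)
            coord.insert(axis, val)
            ranks.append(coord_to_rank[Coordinate(*coord)])
        lists.append(ranks)
    return lists

print(rank_lists(0))  # dp groups: [[0, 4], [1, 5], [2, 6], [3, 7]]
print(rank_lists(2))  # pp groups: [[0, 1], [2, 3], [4, 5], [6, 7]]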
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed as dist
class StrategyGroupBase():
"""
    The base class of a communication group for a distributed strategy.
    Args:
        list_of_ranks: A 2D array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in the same sublist
            belong to the same communication group.
    Returns:
        The instance of the strategy group.
Examples:
.. code-block:: python
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase
dist.init_parallel_env()
            strategy_group = StrategyGroupBase([[0, 1], [2, 3]])
print(strategy_group.world_size) # 2
"""
def __init__(self, list_of_ranks):
assert dist.is_initialized(
), "The global communication group need to be initialized."
assert len(list_of_ranks), "The list_of_ranks can not be empty."
self._rank = dist.get_rank()
self._list_of_ranks = list_of_ranks
self._group = self._create_group()
@property
def world_size(self):
"""
        The world size of the communication group.
        Returns:
            An integer if every group has the same world size, or a list of world sizes if they differ.
"""
world_size_list = []
for ranks in self._list_of_ranks:
world_size_list.append(len(ranks))
        is_uniform = all(world_size == world_size_list[0]
                         for world_size in world_size_list)
        return world_size_list[0] if is_uniform else world_size_list
@property
def group(self):
"""
        The communication group(s) the current rank belongs to.
        Returns:
            A Group if the current rank belongs to a single communication group, or a list of Groups if it belongs to several.
"""
return self._group
def _create_group(self):
list_of_group = []
for ranks in self._list_of_ranks:
group = dist.new_group(ranks=ranks)
if self._rank in ranks:
list_of_group.append(group)
assert len(
list_of_group
) > 0, "Rank {} does not belong to the list_of_ranks {}.".format(
self._rank, self._list_of_ranks)
return list_of_group if len(list_of_group) > 1 else list_of_group[0]
class DPGroup(StrategyGroupBase):
"""
The communication group strategy for data parallel.
Args:
        list_of_ranks: A 2D array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in the same sublist
            belong to the same communication group.
Returns:
The instance of data parallel strategy group.
"""
def __init__(self, list_of_ranks):
super(DPGroup, self).__init__(list_of_ranks)
assert not isinstance(
self.group, list), "Rank {} belongs to multi dp groups".format(
self._rank)
class MPGroup(StrategyGroupBase):
"""
The communication group strategy for model parallel.
Args:
        list_of_ranks: A 2D array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in the same sublist
            belong to the same communication group.
Returns:
The instance of model parallel strategy group.
"""
def __init__(self, list_of_ranks):
super(MPGroup, self).__init__(list_of_ranks)
assert not isinstance(
self.group, list), "Rank {} belongs to multi mp groups".format(
self._rank)
class ShardingGroup(StrategyGroupBase):
"""
The communication group strategy for sharding parallel.
Args:
        list_of_ranks: A 2D array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in the same sublist
            belong to the same communication group.
Returns:
The instance of sharding parallel strategy group.
"""
def __init__(self, list_of_ranks):
super(ShardingGroup, self).__init__(list_of_ranks)
assert not isinstance(
self.group,
list), "Rank {} belongs to multi sharding groups".format(self._rank)
class PPGroup(StrategyGroupBase):
"""
The communication group strategy for pipeline parallel.
Args:
        list_of_ranks: A 2D array, such as `[[0, 1, 2, 3], [4, 5, 6, 7]]`. Ranks in the same sublist
            belong to the same communication group.
Returns:
The instance of pipeline parallel strategy group.
"""
def __init__(self, list_of_ranks):
super(PPGroup, self).__init__(list_of_ranks)
assert not isinstance(
self.group, list), "Rank {} belongs to multi pp groups".format(
self._rank)
self._send_next_group = None
self._send_prev_group = None
self._recv_next_group = None
self._recv_prev_group = None
self._rank_of_next_stage = None
self._rank_of_prev_stage = None
if self.world_size > 1:
self._create_p2p_group()
@property
def rank_of_prev_stage(self):
"""
Rank of the previous pp stage.
Returns:
            The global rank of the previous pp stage, or `None` if there is no previous stage.
"""
return self._rank_of_prev_stage
@property
def rank_of_next_stage(self):
"""
Rank of the next pp stage.
Returns:
            The global rank of the next pp stage, or `None` if there is no next stage.
"""
return self._rank_of_next_stage
@property
def p2p_groups(self):
"""
        Communication subgroups used to exchange data with the previous and next pipeline stages.
        Returns:
            Four subgroups: send to next, send to prev, recv from next, and recv from prev.
"""
return self._send_next_group, self._send_prev_group, self._recv_next_group, self._recv_prev_group
def _create_p2p_group(self):
degree = self.world_size
for ranks in self._list_of_ranks:
for idx, rank in enumerate(ranks):
next_rank = ranks[(idx + 1) % degree]
prev_rank = ranks[(idx - 1) % degree]
if self._rank == rank:
self._rank_of_next_stage = next_rank
self._rank_of_prev_stage = prev_rank
next_group = dist.new_group(ranks=[rank, next_rank])
if self._rank == rank:
self._send_next_group = next_group
elif self._rank == next_rank:
self._recv_prev_group = next_group
prev_group = dist.new_group(ranks=[prev_rank, rank])
if self._rank == rank:
self._send_prev_group = prev_group
elif self._rank == prev_rank:
self._recv_next_group = prev_group
assert self._send_next_group and self._send_prev_group and self._recv_next_group and self._recv_prev_group,\
"Error occurs while creating p2p group for rank {}.".format(self._rank)
@@ -391,5 +391,27 @@ if(WITH_MPI)
    "PADDLE_DIST_UT_PORT=21672;http_proxy=;https_proxy=")
  endif()
endif()
if((WITH_ROCM OR WITH_GPU) AND (LINUX))
bash_test_modules(
test_strategy_group
START_BASH
test_strategy_group.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21814;http_proxy=;https_proxy=")
set_tests_properties(test_strategy_group PROPERTIES TIMEOUT "120")
endif()
if((WITH_ROCM OR WITH_GPU) AND (LINUX))
bash_test_modules(
test_orthogonal_strategy
START_BASH
test_orthogonal_strategy.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21958;http_proxy=;https_proxy=")
set_tests_properties(test_orthogonal_strategy PROPERTIES TIMEOUT "120")
endif()
add_subdirectory(fleet)
add_subdirectory(multinode)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import DPGroup, ShardingGroup, MPGroup, PPGroup
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
class TestOrthogonalStrategyAPI(unittest.TestCase):
def setUp(self):
self._num_of_ranks = 2
dist.init_parallel_env()
self._global_rank = dist.get_rank()
self._strategy = OrthogonalStrategy(
[("dp", 2, DPGroup), ("mp", 1, MPGroup),
("sharding", 1, ShardingGroup), ("pp", 1, PPGroup)],
fused_strategy_dict={"checkness": ["mp", "sharding", "pp"]})
def test_orthogonal_strategy(self):
dp_group = self._strategy.strategy_group("dp")
self.assertEqual(dp_group.world_size, self._num_of_ranks)
self.assertEqual(dp_group.group.nranks, self._num_of_ranks)
self.assertEqual(self._strategy.rank_in_strategy("dp"),
self._global_rank)
fused_group = self._strategy.fused_strategy_group("checkness")
self.assertEqual(fused_group.world_size, 1)
self.assertEqual(fused_group.group.nranks, 1)
if __name__ == '__main__':
unittest.main()
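For context, the rank layout implied by the degrees used in this test can be enumerated directly. The snippet below is a hypothetical offline check (not part of the test suite): with dp=2 and mp=sharding=pp=1 there are exactly two global ranks, the "dp" rank list is [[0, 1]], and the fused "checkness" group collapses to the single-rank lists [[0], [1]], which is why the test expects a world size of 2 for dp and 1 for the fused group.

import itertools

degrees = {"dp": 2, "mp": 1, "sharding": 1, "pp": 1}
names = list(degrees)
coords = list(itertools.product(*(range(degrees[n]) for n in names)))
coord_to_rank = {coord: rank for rank, coord in enumerate(coords)}

def groups_along(axes):
    # Collect ranks whose coordinates agree on every axis outside `axes`.
    grouped = {}
    for coord, rank in coord_to_rank.items():
        key = tuple(c for i, c in enumerate(coord) if i not in axes)
        grouped.setdefault(key, []).append(rank)
    return list(grouped.values())

dp_axis = {names.index("dp")}
fused_axes = {names.index(n) for n in ("mp", "sharding", "pp")}
print(groups_along(dp_axis))     # [[0, 1]]
print(groups_along(fused_axes))  # [[0], [1]]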
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.base.strategy_group import StrategyGroupBase, DPGroup, MPGroup, PPGroup, ShardingGroup
def _check_using_all_reduce(group):
data = paddle.to_tensor([1, 2, 3])
result = paddle.to_tensor([2, 4, 6])
dist.all_reduce(data, group=group)
assert np.array_equal(data, result)
def _check_using_send(group, dst):
data = paddle.to_tensor([1, 2, 3])
dist.send(data, dst=dst, group=group)
def _check_using_recv(group, src):
result = paddle.to_tensor([1, 2, 3])
data = paddle.to_tensor([0, 0, 0])
dist.recv(data, src=src, group=group)
assert np.array_equal(data, result)
class TestStrategyGroupAPI(unittest.TestCase):
def setUp(self):
self._num_of_ranks = 2
self._list_of_rank = [[0, 1]]
self._list_of_ranks = [[0, 1], [0, 1]]
dist.init_parallel_env()
self._global_rank = dist.get_rank()
self._peer_rank = 0 if self._global_rank == 1 else 1
def test_strategy_group_base(self):
strategy_group = StrategyGroupBase(self._list_of_rank)
self.assertEqual(strategy_group.world_size, self._num_of_ranks)
self.assertEqual(strategy_group.group.nranks, self._num_of_ranks)
_check_using_all_reduce(strategy_group.group)
def test_data_parallel_group(self):
dp_group = DPGroup(self._list_of_rank)
self.assertEqual(dp_group.world_size, self._num_of_ranks)
self.assertEqual(dp_group.group.nranks, self._num_of_ranks)
_check_using_all_reduce(dp_group.group)
def test_model_parallel_group(self):
mp_group = MPGroup(self._list_of_rank)
self.assertEqual(mp_group.world_size, self._num_of_ranks)
self.assertEqual(mp_group.group.nranks, self._num_of_ranks)
_check_using_all_reduce(mp_group.group)
def test_sharding_parallel_group(self):
sharding_group = ShardingGroup(self._list_of_rank)
self.assertEqual(sharding_group.world_size, self._num_of_ranks)
self.assertEqual(sharding_group.group.nranks, self._num_of_ranks)
_check_using_all_reduce(sharding_group.group)
def test_pipeline_parallel_group(self):
pp_group = PPGroup(self._list_of_rank)
send_next_group, send_prev_group, recv_next_group, recv_prev_group = pp_group.p2p_groups
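        # The two branches below mirror each other: every blocking send issued by
        # rank 0 is matched, in the same order, by a recv on rank 1 (and vice
        # versa), so the p2p calls complete without deadlocking.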
if self._global_rank == 0:
self.assertEqual(pp_group.rank_of_next_stage, 1)
self.assertEqual(pp_group.rank_of_prev_stage, 1)
_check_using_send(send_next_group, self._peer_rank)
_check_using_send(send_prev_group, self._peer_rank)
_check_using_recv(recv_prev_group, self._peer_rank)
_check_using_recv(recv_next_group, self._peer_rank)
else:
self.assertEqual(pp_group.rank_of_next_stage, 0)
self.assertEqual(pp_group.rank_of_prev_stage, 0)
_check_using_recv(recv_prev_group, self._peer_rank)
_check_using_recv(recv_next_group, self._peer_rank)
_check_using_send(send_next_group, self._peer_rank)
_check_using_send(send_prev_group, self._peer_rank)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 orthogonal_strategy.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 strategy_group.py
@@ -46,3 +46,5 @@ test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_pro
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
test_mpi_comm,linux,,,DIST,test_mpi_comm.sh,2,,http_proxy=;https_proxy=,WITH_MPI
test_strategy_group,linux,rocm;gpu,120,DIST,test_strategy_group.sh,2,,http_proxy=;https_proxy=,
test_orthogonal_strategy,linux,rocm;gpu,120,DIST,test_orthogonal_strategy.sh,2,,http_proxy=;https_proxy=,