未验证 提交 f48611f3 编写于 作者: S ShenLiang 提交者: GitHub

[Cherry-Pick]Add identity hcg for hybridparallel (#53787)

* add utest

* rm hack code
上级 cc6dcc7d
...@@ -47,11 +47,16 @@ class OrthogonalStrategy: ...@@ -47,11 +47,16 @@ class OrthogonalStrategy:
""" """
def __init__(self, list_of_strategy, fused_strategy_dict={}): def __init__(
self, list_of_strategy, fused_strategy_dict={}, strategy_rank_list=None
):
self._list_of_strategy = list_of_strategy self._list_of_strategy = list_of_strategy
self._fused_strategy_dict = fused_strategy_dict self._fused_strategy_dict = fused_strategy_dict
self._rank = dist.get_rank() self._strategy_rank_list = (
self._rank_list_dict = {} strategy_rank_list
if strategy_rank_list is not None
else list(range(dist.get_world_size()))
)
self._name_to_group_dict = {} self._name_to_group_dict = {}
self._name_to_degree_dict = {} self._name_to_degree_dict = {}
self._list_of_strategy_name = [ self._list_of_strategy_name = [
...@@ -67,16 +72,17 @@ class OrthogonalStrategy: ...@@ -67,16 +72,17 @@ class OrthogonalStrategy:
list_of_coord = [ list_of_coord = [
self._coordinate(*coord) for coord in itertools.product(*ranges) self._coordinate(*coord) for coord in itertools.product(*ranges)
] ]
self._coord_to_rank_dict = dict( self._coord_to_rank_dict = dict(
zip(list_of_coord, range(len(list_of_coord))) zip(list_of_coord, self._strategy_rank_list)
) )
for idx, strategy in enumerate(list_of_strategy): for idx, strategy in enumerate(list_of_strategy):
strategy_name = strategy[0] strategy_name = strategy[0]
self._name_to_degree_dict[strategy_name] = strategy[1] self._name_to_degree_dict[strategy_name] = strategy[1]
self._rank_list_dict[strategy_name] = self._calc_rank_list(idx) rank_list = self._calc_rank_list(idx)
self._name_to_group_dict[strategy_name] = strategy[2]( self._name_to_group_dict[strategy_name] = strategy[2](
self._rank_list_dict[strategy_name] rank_list,
) )
self._name_to_fused_group_dict = {} self._name_to_fused_group_dict = {}
...@@ -136,11 +142,13 @@ class OrthogonalStrategy: ...@@ -136,11 +142,13 @@ class OrthogonalStrategy:
num_of_ranks = functools.reduce( num_of_ranks = functools.reduce(
lambda x, y: x * y, self._list_of_degree lambda x, y: x * y, self._list_of_degree
) )
assert (
num_of_ranks == dist.get_world_size() assert num_of_ranks == len(
self._strategy_rank_list
), "There are total {} ranks, but need {} ranks in this strategy.".format( ), "There are total {} ranks, but need {} ranks in this strategy.".format(
dist.get_world_size(), num_of_ranks len(self._strategy_rank_list), num_of_ranks
) )
for fused_strategy in self._fused_strategy_dict.values(): for fused_strategy in self._fused_strategy_dict.values():
for strategy in fused_strategy: for strategy in fused_strategy:
assert ( assert (
......
...@@ -12,7 +12,9 @@ ...@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed.fleet.layers.mpu import RNGStatesTracker
class StrategyGroupBase: class StrategyGroupBase:
...@@ -39,6 +41,9 @@ class StrategyGroupBase: ...@@ -39,6 +41,9 @@ class StrategyGroupBase:
""" """
def __init__(self, list_of_ranks): def __init__(self, list_of_ranks):
"""
Initialize the communication group.
"""
assert ( assert (
dist.is_initialized() dist.is_initialized()
), "The global communication group need to be initialized." ), "The global communication group need to be initialized."
...@@ -46,6 +51,19 @@ class StrategyGroupBase: ...@@ -46,6 +51,19 @@ class StrategyGroupBase:
self._rank = dist.get_rank() self._rank = dist.get_rank()
self._list_of_ranks = list_of_ranks self._list_of_ranks = list_of_ranks
self._group = self._create_group() self._group = self._create_group()
self.random_states_tracker = RNGStatesTracker()
def add_random_seed(self, name, seed):
"""
Add random seed for current rank.
"""
self.random_states_tracker.add(name, seed)
def get_random_states_tracker(self):
"""
Get the random states tracker.
"""
return self.random_states_tracker
@property @property
def world_size(self): def world_size(self):
...@@ -74,17 +92,28 @@ class StrategyGroupBase: ...@@ -74,17 +92,28 @@ class StrategyGroupBase:
return self._group return self._group
def _create_group(self): def _create_group(self):
list_of_group = [] self.list_of_group = []
for ranks in self._list_of_ranks: for ranks in self._list_of_ranks:
group = dist.new_group(ranks=ranks) group = dist.new_group(ranks=ranks)
if self._rank in ranks: if self._rank in ranks:
list_of_group.append(group) self.list_of_group.append(group)
assert (
len(list_of_group) > 0 if not self.list_of_group:
), "Rank {} does not belong to the list_of_ranks {}.".format( return None
self._rank, self._list_of_ranks else:
) return (
return list_of_group if len(list_of_group) > 1 else list_of_group[0] self.list_of_group[0]
if len(self.list_of_group) == 1
else self.list_of_group
)
def __repr__(self):
debug_str = f"seed: {self._seed}; "
if not self.list_of_group:
return debug_str + "No group."
for i in range(len(self.list_of_group)):
debug_str += f"Group[{i}]: {str(self.list_of_group[i])}; "
return debug_str
class DPGroup(StrategyGroupBase): class DPGroup(StrategyGroupBase):
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import unittest import unittest
import paddle
import paddle.distributed as dist import paddle.distributed as dist
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
from paddle.distributed.fleet.base.strategy_group import ( from paddle.distributed.fleet.base.strategy_group import (
...@@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase): ...@@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase):
self.assertEqual(fused_group.group.nranks, 1) self.assertEqual(fused_group.group.nranks, 1)
class TestOrthogonalStrategyCustomAPI(unittest.TestCase):
def setUp(self):
self._num_of_ranks = 2
dist.init_parallel_env()
self._global_rank = dist.get_rank()
self._strategy = OrthogonalStrategy(
[
("dp", 1, DPGroup),
("mp", 2, MPGroup),
("sharding", 1, ShardingGroup),
("pp", 1, PPGroup),
],
fused_strategy_dict={"checkness": ["mp", "sharding", "pp"]},
strategy_rank_list=[1, 0],
)
self._strategy.strategy_group("mp").add_random_seed("local_seed", 123)
self._strategy.strategy_group("mp").add_random_seed("global_seed", 321)
def test_orthogonal_strategy(self):
mp_group = self._strategy.strategy_group("mp")
self.assertEqual(mp_group.world_size, self._num_of_ranks)
self.assertEqual(mp_group.group.nranks, self._num_of_ranks)
self.assertEqual(
self._strategy.rank_in_strategy("mp"), self._global_rank
)
fused_group = self._strategy.fused_strategy_group("checkness")
self.assertEqual(fused_group.world_size, 2)
self.assertEqual(fused_group.group.nranks, 2)
with mp_group.random_states_tracker.rng_state("local_seed"):
a = paddle.randint(0, 100, [10]).numpy()[0]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册