“8f7b020ba8925d3709b9712024195ce94cf518e2”上不存在“paddle/legacy/gserver/layers/SliceProjection.cpp”
未验证 提交 f48611f3 编写于 作者: S ShenLiang 提交者: GitHub

[Cherry-Pick]Add identity hcg for hybridparallel (#53787)

* add utest

* rm hack code
上级 cc6dcc7d
......@@ -47,11 +47,16 @@ class OrthogonalStrategy:
"""
def __init__(self, list_of_strategy, fused_strategy_dict={}):
def __init__(
self, list_of_strategy, fused_strategy_dict={}, strategy_rank_list=None
):
self._list_of_strategy = list_of_strategy
self._fused_strategy_dict = fused_strategy_dict
self._rank = dist.get_rank()
self._rank_list_dict = {}
self._strategy_rank_list = (
strategy_rank_list
if strategy_rank_list is not None
else list(range(dist.get_world_size()))
)
self._name_to_group_dict = {}
self._name_to_degree_dict = {}
self._list_of_strategy_name = [
......@@ -67,16 +72,17 @@ class OrthogonalStrategy:
list_of_coord = [
self._coordinate(*coord) for coord in itertools.product(*ranges)
]
self._coord_to_rank_dict = dict(
zip(list_of_coord, range(len(list_of_coord)))
zip(list_of_coord, self._strategy_rank_list)
)
for idx, strategy in enumerate(list_of_strategy):
strategy_name = strategy[0]
self._name_to_degree_dict[strategy_name] = strategy[1]
self._rank_list_dict[strategy_name] = self._calc_rank_list(idx)
rank_list = self._calc_rank_list(idx)
self._name_to_group_dict[strategy_name] = strategy[2](
self._rank_list_dict[strategy_name]
rank_list,
)
self._name_to_fused_group_dict = {}
......@@ -136,11 +142,13 @@ class OrthogonalStrategy:
num_of_ranks = functools.reduce(
lambda x, y: x * y, self._list_of_degree
)
assert (
num_of_ranks == dist.get_world_size()
assert num_of_ranks == len(
self._strategy_rank_list
), "There are total {} ranks, but need {} ranks in this strategy.".format(
dist.get_world_size(), num_of_ranks
len(self._strategy_rank_list), num_of_ranks
)
for fused_strategy in self._fused_strategy_dict.values():
for strategy in fused_strategy:
assert (
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.distributed as dist
from paddle.distributed.fleet.layers.mpu import RNGStatesTracker
class StrategyGroupBase:
......@@ -39,6 +41,9 @@ class StrategyGroupBase:
"""
def __init__(self, list_of_ranks):
"""
Initialize the communication group.
"""
assert (
dist.is_initialized()
), "The global communication group need to be initialized."
......@@ -46,6 +51,19 @@ class StrategyGroupBase:
self._rank = dist.get_rank()
self._list_of_ranks = list_of_ranks
self._group = self._create_group()
self.random_states_tracker = RNGStatesTracker()
def add_random_seed(self, name, seed):
"""
Add random seed for current rank.
"""
self.random_states_tracker.add(name, seed)
def get_random_states_tracker(self):
"""
Get the random states tracker.
"""
return self.random_states_tracker
@property
def world_size(self):
......@@ -74,17 +92,28 @@ class StrategyGroupBase:
return self._group
def _create_group(self):
    """
    Create one communication group per rank list and keep those that
    contain the current rank.

    Every entry of ``self._list_of_ranks`` is passed to ``dist.new_group``;
    the groups whose rank list contains ``self._rank`` are collected in
    ``self.list_of_group`` (which ``__repr__`` also reads).

    Returns:
        ``None`` when the current rank belongs to none of the rank lists,
        the single group when it belongs to exactly one, otherwise the
        list of all groups it belongs to.
    """
    self.list_of_group = []
    for ranks in self._list_of_ranks:
        # NOTE(review): new_group is invoked even for rank lists that do not
        # include this rank -- presumably because group creation has to be
        # called by every process; confirm against the Paddle documentation.
        group = dist.new_group(ranks=ranks)
        if self._rank in ranks:
            self.list_of_group.append(group)

    # A rank that is in no list is allowed: it simply has no group.
    if not self.list_of_group:
        return None
    if len(self.list_of_group) == 1:
        return self.list_of_group[0]
    return self.list_of_group
def __repr__(self):
debug_str = f"seed: {self._seed}; "
if not self.list_of_group:
return debug_str + "No group."
for i in range(len(self.list_of_group)):
debug_str += f"Group[{i}]: {str(self.list_of_group[i])}; "
return debug_str
class DPGroup(StrategyGroupBase):
......
......@@ -14,6 +14,7 @@
import unittest
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.base.orthogonal_strategy import OrthogonalStrategy
from paddle.distributed.fleet.base.strategy_group import (
......@@ -52,5 +53,40 @@ class TestOrthogonalStrategyAPI(unittest.TestCase):
self.assertEqual(fused_group.group.nranks, 1)
class TestOrthogonalStrategyCustomAPI(unittest.TestCase):
    """
    Exercises OrthogonalStrategy when it is built with an explicit
    ``strategy_rank_list`` and seeds are registered on the "mp" group.
    """

    def setUp(self):
        """Initialise the parallel environment and build the strategy."""
        self._num_of_ranks = 2
        dist.init_parallel_env()
        self._global_rank = dist.get_rank()

        strategy_spec = [
            ("dp", 1, DPGroup),
            ("mp", 2, MPGroup),
            ("sharding", 1, ShardingGroup),
            ("pp", 1, PPGroup),
        ]
        fused_spec = {"checkness": ["mp", "sharding", "pp"]}
        self._strategy = OrthogonalStrategy(
            strategy_spec,
            fused_strategy_dict=fused_spec,
            strategy_rank_list=[1, 0],
        )

        # Register both seeds on the "mp" group, looking the group up anew
        # for each registration.
        for seed_name, seed_value in (("local_seed", 123), ("global_seed", 321)):
            self._strategy.strategy_group("mp").add_random_seed(
                seed_name, seed_value
            )

    def test_orthogonal_strategy(self):
        """Check group sizes, the rank within "mp", and the seeded RNG scope."""
        mp_group = self._strategy.strategy_group("mp")
        expected_ranks = self._num_of_ranks
        self.assertEqual(mp_group.world_size, expected_ranks)
        self.assertEqual(mp_group.group.nranks, expected_ranks)

        rank_in_mp = self._strategy.rank_in_strategy("mp")
        self.assertEqual(rank_in_mp, self._global_rank)

        fused_group = self._strategy.fused_strategy_group("checkness")
        self.assertEqual(fused_group.world_size, 2)
        self.assertEqual(fused_group.group.nranks, 2)

        # Draw one value inside the "local_seed" RNG scope; the value itself
        # is not asserted on, the draw only has to succeed.
        rng_scope = mp_group.random_states_tracker.rng_state("local_seed")
        with rng_scope:
            sample = paddle.randint(0, 100, [10]).numpy()[0]
# Run the test cases only when this file is executed as a script.
if '__main__' == __name__:
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册