From e0eb5cf3aba1e34aac758a7b1fad38cbac0ecb89 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Tue, 21 Mar 2023 01:09:52 -0500
Subject: [PATCH] [OPT]Set order for hybridparallel setting (#51781)

* set order for hybridparallel

* fix bug

* fix ->

* fix ->

* fix ->

* add topology

* fix utest
---
 .../framework/distributed_strategy.proto |  1 +
 .../fleet/base/distributed_strategy.py   |  5 +++-
 python/paddle/distributed/fleet/fleet.py | 28 ++++++++++++++-----
 3 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 5f5e5a3fac5..1b329a0d65b 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -55,6 +55,7 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
+  repeated string order = 5 ;
 }
 
 message AMPConfig {

diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 3614d703922..57e4378041e 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1673,6 +1673,8 @@ class DistributedStrategy:
 
             **pp_degree(int)**: set number of GPUs in a pipeline parallel group. Default 1
 
+            **order(list(string))**: set hybrid parallel dimensions, the order is from outside to inside. Default ['dp','pp','sharding','mp']
+
         Examples:
 
           .. code-block:: python
             strategy.hybrid_configs = {
                 "dp_degree": 1,
                 "mp_degree": 2,
-                "pp_degree": 1}
+                "pp_degree": 1,
+                "order":['dp','pp','sharding','mp']}
         """
         return get_msg_dict(self.strategy.hybrid_configs)

diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py
index 74ee30e349f..6ce0b8f9189 100755
--- a/python/paddle/distributed/fleet/fleet.py
+++ b/python/paddle/distributed/fleet/fleet.py
@@ -405,14 +405,28 @@ class Fleet:
 
         self.dp_degree = max(self.dp_degree, 1)
 
+        d_hybrid_degree = {
+            "dp": ["data", self.dp_degree],
+            "pp": ['pipe', self.pp_degree],
+            "sharding": ['sharding', self.sharding_degree],
+            "mp": ['model', self.mp_degree],
+        }
+
+        order = self.hybrid_configs["order"]
+        if not order:
+            order = ['dp', 'pp', 'sharding', 'mp']
+        if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
+            assert False, "The order of hybrid_config setting is incorrect."
+
+        hybrid_group_names = []
+        dims = []
+        for h_name in order:
+            name, degree = d_hybrid_degree[h_name]
+            hybrid_group_names.append(name)
+            dims.append(degree)
+
         self._topology = tp.CommunicateTopology(
-            hybrid_group_names=["data", "pipe", "sharding", "model"],
-            dims=[
-                self.dp_degree,
-                self.pp_degree,
-                self.sharding_degree,
-                self.mp_degree,
-            ],
+            hybrid_group_names=hybrid_group_names, dims=dims
         )
 
         self._hcg = tp.HybridCommunicateGroup(self._topology)
--
GitLab
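
For context (not part of the patch): a minimal usage sketch of the new `order` option, extending the docstring example above. The 8-GPU collective launch, the degree values, and the final world-size check are illustrative assumptions, not taken from the patch.

    # Usage sketch (assumption: 8 GPUs, launched with something like
    #   python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py;
    # degree values are illustrative).
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 2,
        "pp_degree": 2,
        # New option added by this patch: nesting order of the parallel
        # dimensions, outermost first. An empty list falls back to the
        # default ['dp', 'pp', 'sharding', 'mp'].
        "order": ['dp', 'pp', 'sharding', 'mp'],
    }

    # fleet.init builds CommunicateTopology / HybridCommunicateGroup from the
    # hybrid_group_names and dims derived from `order` (see the fleet.py hunk).
    fleet.init(is_collective=True, strategy=strategy)

    hcg = fleet.get_hybrid_communicate_group()
    print(hcg.get_model_parallel_world_size())  # expected 2 with the degrees above

Per the fleet.py hunk, each key in `order` is mapped to a topology group name ('dp' -> 'data', 'pp' -> 'pipe', 'sharding' -> 'sharding', 'mp' -> 'model'), so reordering the list only changes how ranks are grouped into communication groups; the degrees themselves are unchanged.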