diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 5f5e5a3fac5fad2038f0182558dbbdc0c6838384..1b329a0d65b84453eee478d4d7c98f8eeb9f0bb3 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -55,6 +55,7 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
+  repeated string order = 5;
 }
 
 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 3614d703922832f6d8b1212bc9e734e7007793fb..57e4378041eaefcf66b7cd0ce557cdf6357a26eb 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1673,6 +1673,8 @@ class DistributedStrategy:
         **pp_degree(int)**: set number of GPUs in a pipeline parallel group. Default 1
 
+        **order(list(string))**: set the order of the hybrid parallel dimensions, from the outermost group to the innermost. Default ['dp', 'pp', 'sharding', 'mp']
+
         Examples:
           .. code-block:: python
 
@@ -1681,7 +1683,8 @@
             strategy.hybrid_configs = {
                 "dp_degree": 1,
                 "mp_degree": 2,
-                "pp_degree": 1}
+                "pp_degree": 1,
+                "order": ['dp', 'pp', 'sharding', 'mp']}
 
         """
         return get_msg_dict(self.strategy.hybrid_configs)
diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py
index 74ee30e349fa9b74b6df48701664f4fd24aff999..6ce0b8f9189db2f03d5e5198fb50c066e05b7a65 100755
--- a/python/paddle/distributed/fleet/fleet.py
+++ b/python/paddle/distributed/fleet/fleet.py
@@ -405,14 +405,30 @@
         self.dp_degree = max(self.dp_degree, 1)
 
+        # Map each short dimension name to its communication-group name and degree.
+        d_hybrid_degree = {
+            "dp": ["data", self.dp_degree],
+            "pp": ["pipe", self.pp_degree],
+            "sharding": ["sharding", self.sharding_degree],
+            "mp": ["model", self.mp_degree],
+        }
+
+        order = self.hybrid_configs["order"]
+        if not order:
+            order = ['dp', 'pp', 'sharding', 'mp']
+        # Validate that `order` is a permutation of the supported dimensions.
+        if sorted(order) != sorted(d_hybrid_degree.keys()):
+            raise ValueError("The order of hybrid_configs is invalid: {}".format(order))
+
+        hybrid_group_names = []
+        dims = []
+        for h_name in order:
+            name, degree = d_hybrid_degree[h_name]
+            hybrid_group_names.append(name)
+            dims.append(degree)
+
         self._topology = tp.CommunicateTopology(
-            hybrid_group_names=["data", "pipe", "sharding", "model"],
-            dims=[
-                self.dp_degree,
-                self.pp_degree,
-                self.sharding_degree,
-                self.mp_degree,
-            ],
+            hybrid_group_names=hybrid_group_names, dims=dims
         )
 
         self._hcg = tp.HybridCommunicateGroup(self._topology)
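
For reference, a minimal usage sketch of the new "order" key, not part of the diff itself. The degree values and the 4-GPU launch setup are illustrative assumptions; fleet.init builds the communication topology according to the requested order.

    # Run under `python -m paddle.distributed.launch` on a 4-GPU node (assumption).
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 2,
        "mp_degree": 2,
        "pp_degree": 1,
        "sharding_degree": 1,
        # Dimensions ordered from the outermost process group to the innermost;
        # omitting "order" falls back to ['dp', 'pp', 'sharding', 'mp'].
        "order": ['sharding', 'dp', 'pp', 'mp'],
    }
    fleet.init(is_collective=True, strategy=strategy)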