From d2cdc7e3a633cb651082be3229f5eb1b609f3140 Mon Sep 17 00:00:00 2001
From: ShenLiang <1422485404@qq.com>
Date: Wed, 29 Mar 2023 21:39:07 -0500
Subject: [PATCH] [BugFix]Fix segment fault in order setting (#52293)

* fix bug in proto

* add utest
---
 paddle/fluid/framework/distributed_strategy.proto |  1 -
 .../fleet/base/distributed_strategy.py            | 10 +++++++++-
 python/paddle/distributed/fleet/fleet.py          |  4 +---
 .../fleet/test_fleet_distributed_strategy.py      | 15 +++++++++++++++
 4 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index d9182c488f2..b9055d38d38 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -55,7 +55,6 @@ message HybridConfig {
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
   optional int32 sharding_degree = 4 [ default = 1 ];
-  repeated string order = 5 ;
 }
 
 message AMPConfig {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 455f7bca375..950fddaf9db 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
+
 import google.protobuf
 import google.protobuf.text_format
 
@@ -149,6 +151,7 @@ class DistributedStrategy:
             if _global_flags().is_public(key):
                 self.strategy.sync_nccl_allreduce = bool(_global_flags()[key])
 
+        self.hybrid_parallel_order = ['dp', 'pp', 'sharding', 'mp']
         self.__lock_attr = True
         logger.info("distributed strategy initialized")
 
@@ -1691,8 +1694,13 @@ class DistributedStrategy:
 
     @hybrid_configs.setter
     def hybrid_configs(self, configs):
+        hybrid_config = copy.deepcopy(configs)
+        if "order" in hybrid_config:
+            self.hybrid_parallel_order = hybrid_config["order"]
+            hybrid_config.pop('order')
+
         check_configs_key(
-            self.strategy.hybrid_configs, configs, "hybrid_configs"
+            self.strategy.hybrid_configs, hybrid_config, "hybrid_configs"
         )
         assign_configs_value(self.strategy.hybrid_configs, configs)
 
diff --git a/python/paddle/distributed/fleet/fleet.py b/python/paddle/distributed/fleet/fleet.py
index eda074100ee..9debd488d2e 100755
--- a/python/paddle/distributed/fleet/fleet.py
+++ b/python/paddle/distributed/fleet/fleet.py
@@ -412,9 +412,7 @@ class Fleet:
             "mp": ['model', self.mp_degree],
         }
 
-        order = self.hybrid_configs["order"]
-        if not order:
-            order = ['dp', 'pp', 'sharding', 'mp']
+        order = self._user_defined_strategy.hybrid_parallel_order
         if order[:].sort() != list(d_hybrid_degree.keys())[:].sort():
             raise AssertionError(
                 'The order of hybrid_config setting is incorrect.'
diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
index e773014629d..99f235b5887 100644
--- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_distributed_strategy.py
@@ -84,6 +84,21 @@ class TestStrategyConfig(unittest.TestCase):
         self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
         self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
 
+    def test_hybrid_parallel_configs_order(self):
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 4,
+            "order": ['sharding', 'mp', 'dp', 'pp'],
+        }
+        self.assertEqual(strategy.hybrid_configs["dp_degree"], 1)
+        self.assertEqual(strategy.hybrid_configs["mp_degree"], 2)
+        self.assertEqual(strategy.hybrid_configs["pp_degree"], 4)
+        self.assertEqual(
+            strategy.hybrid_parallel_order, ['sharding', 'mp', 'dp', 'pp']
+        )
+
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.localsgd = True
-- 
GitLab
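
For reference, a minimal usage sketch (not part of the patch itself; it mirrors the new unit test above and assumes a Paddle build that includes this fix) showing how the "order" key is handled after this change: the hybrid_configs setter pops it into hybrid_parallel_order on the Python side instead of forwarding it to the HybridConfig protobuf message.

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    # "order" is consumed by the hybrid_configs setter and stored separately;
    # only the *_degree keys are assigned into the protobuf HybridConfig.
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 2,
        "pp_degree": 4,
        "order": ['sharding', 'mp', 'dp', 'pp'],
    }
    print(strategy.hybrid_parallel_order)  # ['sharding', 'mp', 'dp', 'pp']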