From 4f403d3e3565d8c87c997f05f08d27b19c1109d1 Mon Sep 17 00:00:00 2001
From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Date: Sat, 17 Sep 2022 16:38:57 +0800
Subject: [PATCH] update strategy (#46138)

---
 .../distributed/auto_parallel/constants.py     |  2 +-
 .../paddle/distributed/auto_parallel/engine.py | 18 ++++++++++++------
 .../distributed/auto_parallel/strategy.py      |  8 ++++----
 .../passes/auto_parallel_sharding.py           | 17 +++++++++++++----
 .../auto_parallel/clip_grad_by_global_norm.py  |  2 +-
 .../auto_parallel/optimization_tuner_api.py    |  2 +-
 .../auto_parallel/sharding_pass_unittest.py    |  2 +-
 .../unittests/auto_parallel/test_strategy.py   |  8 ++++----
 8 files changed, 37 insertions(+), 22 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py
index 82b3d4554b..f04d2994ab 100644
--- a/python/paddle/distributed/auto_parallel/constants.py
+++ b/python/paddle/distributed/auto_parallel/constants.py
@@ -81,7 +81,7 @@ set_field_default_config(AMP, "use_optimizer_fp16", False)
 SHARDING = "sharding"
 set_field_default_config(SHARDING, "enable", False)
 set_field_default_config(SHARDING, "stage", 1)
-set_field_default_config(SHARDING, "sharding_degree", 8)
+set_field_default_config(SHARDING, "degree", 8)
 set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
 set_field_default_config(SHARDING, "enable_tuning", False)
 set_field_default_config(SHARDING, "tuning_range", [])
diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index 07b505efbd..c4e4fbcfb9 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -564,9 +564,11 @@ class Engine:
         self._infer_sample_spec(train_data, batch_size, train_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("train")
 
         assert self.mode in self._dist_main_progs, \
-            "train model is not ready, please call `engine.prepare()` first."
+            "train model is not ready, please call `engine._prepare_single_mode('train')` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch,
                                                    collate_fn)
@@ -621,8 +623,8 @@ class Engine:
                 self.evaluate(valid_data, valid_sample_split, batch_size,
                               valid_steps, collate_fn, callbacks)
                 self._switch_mode("train")
-
-            self._reset_metrics()
+            else:
+                self._reset_metrics()
         return outputs
 
     def evaluate(self,
@@ -682,9 +684,11 @@ class Engine:
         self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("eval")
 
         assert self.mode in self._dist_main_progs, \
-            "eval model is not ready, please call `engine.prepare()` first."
+            "eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
         valid_dataloader = self._create_dataloader(valid_data,
                                                    batch_size,
                                                    steps_per_epoch=steps,
@@ -785,9 +789,11 @@ class Engine:
         self._infer_sample_spec(test_data, batch_size, test_sample_split)
         if not self._mode_init_states[self.mode]:
             self._prepare_single_mode(self.mode)
+        else:
+            self._switch_mode("predict")
 
         assert self.mode in self._dist_main_progs, \
-            "predict model is not ready, please call `engine.prepare()` first."
+            "predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
         test_dataloader = self._create_dataloader(test_data,
                                                   batch_size,
                                                   steps_per_epoch=steps,
@@ -1059,7 +1065,7 @@ class Engine:
         """
         if training:
             assert 'train' in self._serial_main_progs, \
-                "training model is not ready, please call `engine.prepare()` first."
+                "training model is not ready, please call `engine._prepare_single_mode('train')` first."
             serial_program = self._serial_main_progs["train"]
             dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
             dist_context = self._dist_contexts["train"]
diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py
index c196b321ea..e40fde9664 100644
--- a/python/paddle/distributed/auto_parallel/strategy.py
+++ b/python/paddle/distributed/auto_parallel/strategy.py
@@ -117,7 +117,7 @@ class TuningConfig(BaseConfig):
 
 class Strategy(BaseConfig):
     """
-    The `Strategy` object is used to configure the paralleization and optimization beheviors. 
+    The `Strategy` object is used to configure the paralleization and optimization beheviors.
 
     Args:
         config (dict|string, optional): If this is None, the default configurations will used.
@@ -136,13 +136,13 @@ class Strategy(BaseConfig):
             sharding = strategy.sharding
             self.assertEqual(sharding.enabled, False)
             self.assertEqual(sharding.stage, 1)
-            self.assertEqual(sharding.sharding_degree, 8)
+            self.assertEqual(sharding.degree, 8)
             sharding.enabled = True
             sharding.stage = 2
-            sharding.sharding_degree = 2
+            sharding.degree = 2
             self.assertEqual(sharding.enabled, True)
             self.assertEqual(sharding.stage, 2)
-            self.assertEqual(sharding.sharding_degree, 2)
+            self.assertEqual(sharding.degree, 2)
 
     """
 
diff --git a/python/paddle/distributed/passes/auto_parallel_sharding.py b/python/paddle/distributed/passes/auto_parallel_sharding.py
index 8b4d2288b7..dcc786f8ff 100644
--- a/python/paddle/distributed/passes/auto_parallel_sharding.py
+++ b/python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
         super(ShardingPass, self).__init__()
         self.set_attr("dist_context", None)
         self.set_attr("stage", None)
-        self.set_attr("sharding_degree", None)
+        self.set_attr("sharding_degree", None)  # for parallelizer
+        self.set_attr("degree", None)  # for parallelizer_v2
         self.set_attr("params_grads", [])
         self.set_attr("global_rank", -1)
         self.dp_groups = set()
@@ -67,8 +68,15 @@ class ShardingPass(PassBase):
         if self.get_attr("stage") not in [1, 2, 3]:
             return False
-        if (not isinstance(self.get_attr("sharding_degree"),
-                           int)) or self.get_attr("sharding_degree") <= 1:
+        if self.get_attr("sharding_degree") is not None:
+            if (not isinstance(self.get_attr("sharding_degree"), int)) \
+                or self.get_attr("sharding_degree") <= 1:
+                return False
+        elif self.get_attr("degree") is not None:
+            if (not isinstance(self.get_attr("degree"), int)) \
+                or self.get_attr("degree") <= 1:
+                return False
+        else:
             return False
         if len(self.get_attr("params_grads")) <= 0:
             return False
 
@@ -83,7 +91,8 @@ class ShardingPass(PassBase):
 
     def _apply_single_impl(self, main_program, startup_program, context):
         self._dist_context = self.get_attr("dist_context")
-        self.sharding_world_size = int(self.get_attr("sharding_degree"))
+        self.sharding_world_size = int(
+            self.get_attr("sharding_degree") or self.get_attr("degree"))
         self.stage = int(self.get_attr("stage"))
         self.global_rank = int(self.get_attr("global_rank"))
         params_grads = self.get_attr("params_grads")
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
index 1a8c5e6072..5409f6919f 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py
@@ -31,7 +31,7 @@ def apply_pass(use_sharding=False):
     strategy.reinit = True
     if use_sharding:
         sharding = strategy.sharding
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.stage = 2
 
     return strategy
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py
index a245329a93..c8e553c486 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py
@@ -96,7 +96,7 @@ def train(fetch):
     # sharding config
     sharding = dist_strategy.sharding
     sharding.enable = True
-    sharding.sharding_degree = 2
+    sharding.degree = 2
     sharding.stage = 3
     sharding.enable_tuning = True
     sharding.tuning_range = [0, 1, 2, 3]
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
index 70dfd5f87d..4613a726ce 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py
@@ -32,7 +32,7 @@ def apply_pass(use_sharding=False, stage=None):
     if use_sharding:
         sharding = strategy.sharding
         sharding.enable = True
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.stage = 1
 
     return strategy
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
index 9fae8d970b..d5a660e3f2 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
@@ -45,7 +45,7 @@ class TestStrategy(unittest.TestCase):
         sharding = strategy.sharding
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)
-        self.assertEqual(sharding.sharding_degree, 8)
+        self.assertEqual(sharding.degree, 8)
         self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
         self.assertEqual(sharding.enable_tuning, False)
         self.assertEqual(sharding.tuning_range, [])
@@ -112,13 +112,13 @@ class TestStrategy(unittest.TestCase):
         sharding = strategy.sharding
         sharding.enable = True
         sharding.stage = 2
-        sharding.sharding_degree = 2
+        sharding.degree = 2
         sharding.segment_broadcast_MB = 64.0
         sharding.enable_tuning = True
        sharding.tuning_range = [1, 2, 3]
         self.assertEqual(sharding.enable, True)
         self.assertEqual(sharding.stage, 2)
-        self.assertEqual(sharding.sharding_degree, 2)
+        self.assertEqual(sharding.degree, 2)
         self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0)
         self.assertEqual(sharding.enable_tuning, True)
         self.assertEqual(sharding.tuning_range, [1, 2, 3])
@@ -175,7 +175,7 @@ class TestStrategy(unittest.TestCase):
         #     enable: false
         #     enable_tuning: true
         #     segment_broadcast_MB: 64.0
-        #     sharding_degree: 8
+        #     degree: 8
         #     stage: 2
         #   tuning_range: None
         # split_data: false
-- 
GitLab
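After this rename, user code configures sharding through `strategy.sharding.degree` rather than `sharding.sharding_degree`, while `ShardingPass` still accepts the legacy `sharding_degree` attribute for the old parallelizer and falls back to `degree` for parallelizer_v2. Below is a minimal usage sketch, not taken from the patch itself; it assumes `Strategy` can be imported directly from the `auto_parallel.strategy` module touched above, which may differ across Paddle releases.

```python
# Hypothetical illustration of the renamed sharding knob; the import path is
# an assumption based on the strategy.py file changed in this patch.
from paddle.distributed.auto_parallel.strategy import Strategy

strategy = Strategy()           # config=None -> default configurations

sharding = strategy.sharding
sharding.enable = True          # defaults to False (see constants.py)
sharding.stage = 2              # ShardingPass accepts stages 1, 2, 3
sharding.degree = 2             # was `sharding.sharding_degree` before this patch

assert sharding.degree == 2
assert sharding.stage == 2
```

Inside the pass, `_apply_single_impl` resolves the effective size with `self.get_attr("sharding_degree") or self.get_attr("degree")`, so either attribute name still works when the pass is applied directly.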