From 7dbab1031de8a65686a79a6e2fa9cd0a2724e6a9 Mon Sep 17 00:00:00 2001
From: sandyhouse
Date: Sun, 7 Feb 2021 16:39:26 +0800
Subject: [PATCH] update, test=develop

---
 paddle/fluid/framework/section_worker.cc         |  1 -
 .../fleet/base/distributed_strategy.py           | 54 +++++++++++++++++++
 .../fleet/meta_optimizers/amp_optimizer.py       | 11 +++-
 .../meta_optimizers/sharding_optimizer.py        |  3 +-
 4 files changed, 65 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index 5cb0497ece7..fb75a26e5b5 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -13,7 +13,6 @@ limitations under the License. */
 #include
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
-
 #include "paddle/fluid/platform/device_context.h"
 
 namespace paddle {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index f7a28f15e9b..3bad28bbd14 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -736,6 +736,60 @@ class DistributedStrategy(object):
                           "sharding_configs")
         assign_configs_value(self.strategy.sharding_configs, configs)
 
+    @property
+    def model_parallel(self):
+        """
+        Indicating whether we are using model parallelism for distributed training.
+
+        Examples:
+
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.model_parallel = True
+
+        """
+        return self.strategy.model_parallel
+
+    @model_parallel.setter
+    @is_strict_auto
+    def model_parallel(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.model_parallel = flag
+        else:
+            print("WARNING: model_parallel should have value of bool type")
+
+    @property
+    def model_parallel_configs(self):
+        """
+        Set model parallelism configurations.
+
+        **Notes**:
+            **Detailed arguments for model_parallel_configs**
+
+            **parallelism**: degree of model parallelism
+
+        Examples:
+
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.model_parallel = True
+            strategy.model_parallel_configs = {"parallelism": 12}
+
+        """
+
+        return get_msg_dict(self.strategy.model_parallel_configs)
+
+    @model_parallel_configs.setter
+    @is_strict_auto
+    def model_parallel_configs(self, configs):
+        check_configs_key(self.strategy.model_parallel_configs, configs,
+                          "model_parallel_configs")
+        assign_configs_value(self.strategy.model_parallel_configs, configs)
+
     @property
     def pipeline(self):
         """
diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
index eefa206c8d8..cf6962357cb 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py
@@ -50,7 +50,8 @@ class AMPOptimizer(MetaOptimizerBase):
             self.inner_opt, amp_lists, config['init_loss_scaling'],
             config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
             config['incr_ratio'], config['decr_ratio'],
-            config['use_dynamic_loss_scaling'])
+            config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
+            config['use_fp16_guard'])
 
         # if worker_num > 1, all cards will communication with each other,
         # add is_distributed to optimize amp, overlap communication and
@@ -113,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase):
             self.wrapped_opt.minimize(loss, startup_program,
                                       parameter_list, no_grad_set)
         return optimize_ops, params_grads
+
+    def amp_init(self,
+                 place,
+                 scope=None,
+                 test_program=None,
+                 use_fp16_test=False):
+        return self.wrapped_opt.amp_init(place, scope, test_program,
+                                         use_fp16_test)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 35bfd6a1b0c..2e544ceb718 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -87,8 +87,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
             "as_outer_parallelism"]
         self._inner_parallelism_size = int(
-            self.user_defined_strategy.sharding_configs[
-                "inner_parallelism_size"])
+            self.user_defined_strategy.sharding_configs["parallelism"])
         self.use_pipeline = self.user_defined_strategy.sharding_configs[
             "use_pipeline"]
-- 
GitLab
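
For context, the snippet below pieces together the strategy switches this patch adds with the AMP options it forwards. It is a minimal usage sketch assembled from the docstrings above, not part of the patch itself: the fleet.init/fleet.distributed_optimizer wiring, the Momentum optimizer, the parallelism degree of 4, and the loss-scaling numbers are illustrative assumptions about a typical static-graph collective job.

    # Illustrative only: model_parallel, model_parallel_configs, use_pure_fp16,
    # use_fp16_guard and amp_init() relate to this patch; the rest is assumed
    # boilerplate for a static-graph fleet job.
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()

    strategy = fleet.DistributedStrategy()
    strategy.model_parallel = True
    strategy.model_parallel_configs = {"parallelism": 4}  # degree of model parallelism
    strategy.amp = True
    strategy.amp_configs = {
        "init_loss_scaling": 32768.0,
        "use_dynamic_loss_scaling": True,
        "use_pure_fp16": True,    # now forwarded to decorate() by AMPOptimizer
        "use_fp16_guard": False,  # now forwarded to decorate() by AMPOptimizer
    }

    fleet.init(is_collective=True, strategy=strategy)
    optimizer = paddle.optimizer.Momentum(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # loss = ...  # model definition omitted
    # optimizer.minimize(loss)
    # For pure-fp16 runs, cast parameters after the startup program has run;
    # AMPOptimizer now forwards this call to the wrapped optimizer:
    # optimizer.amp_init(paddle.CUDAPlace(0))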