From 0443b480b80895d7b030ba3e4c9c64c659756e0f Mon Sep 17 00:00:00 2001 From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com> Date: Mon, 7 Sep 2020 18:52:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91add=20auto=20para?= =?UTF-8?q?llel=20L1=20implementations=20(#27090)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add auto parallel L1 implementation test=develop --- .../fleet/base/distributed_strategy.py | 127 ++++++++++++++++++ .../distributed/fleet/base/fleet_base.py | 13 ++ .../fleet/meta_optimizers/amp_optimizer.py | 11 ++ .../fleet/meta_optimizers/dgc_optimizer.py | 4 + .../gradient_merge_optimizer.py | 4 + .../graph_execution_optimizer.py | 9 +- .../fleet/meta_optimizers/lamb_optimizer.py | 7 + .../fleet/meta_optimizers/lars_optimizer.py | 7 + .../meta_optimizers/localsgd_optimizer.py | 4 + .../meta_optimizers/meta_optimizer_base.py | 4 + .../parameter_server_graph_optimizer.py | 5 + .../parameter_server_optimizer.py | 5 + .../meta_optimizers/pipeline_optimizer.py | 4 + .../meta_optimizers/recompute_optimizer.py | 4 + .../fluid/tests/unittests/CMakeLists.txt | 2 + .../fluid/tests/unittests/test_fleet_auto.py | 51 +++++++ 16 files changed, 257 insertions(+), 4 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_auto.py diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 9c1793fd5b..62967a202a 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -15,10 +15,25 @@ import paddle from paddle.distributed.fleet.proto import distributed_strategy_pb2 from paddle.fluid.framework import Variable, set_flags, core +from paddle.fluid.wrapped_decorator import wrap_decorator import google.protobuf.text_format __all__ = ["DistributedStrategy"] +non_auto_func_called = True + + +def __non_auto_func_called__(func): + def __impl__(*args, **kwargs): + global non_auto_func_called + non_auto_func_called = False + return func(*args, **kwargs) + + return __impl__ + + +is_strict_auto = wrap_decorator(__non_auto_func_called__) + def get_msg_dict(msg): res_dict = {} @@ -164,6 +179,7 @@ class DistributedStrategy(object): return execution_strategy @execution_strategy.setter + @is_strict_auto def execution_strategy(self, strategy): fields = self.strategy.execution_strategy.DESCRIPTOR.fields for f in fields: @@ -203,6 +219,7 @@ class DistributedStrategy(object): return build_strategy @build_strategy.setter + @is_strict_auto def build_strategy(self, strategy): fields = self.strategy.build_strategy.DESCRIPTOR.fields for f in fields: @@ -237,6 +254,7 @@ class DistributedStrategy(object): return self.strategy.a_sync @a_sync.setter + @is_strict_auto def a_sync(self, flag): if isinstance(flag, bool): self.strategy.a_sync = flag @@ -287,6 +305,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.a_sync_configs) @a_sync_configs.setter + @is_strict_auto def a_sync_configs(self, configs): check_configs_key(self.strategy.a_sync_configs, configs, "a_sync_configs") @@ -309,6 +328,7 @@ class DistributedStrategy(object): return self.strategy.amp @amp.setter + @is_strict_auto def amp(self, flag): if isinstance(flag, bool): self.strategy.amp = flag @@ -351,6 +371,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.amp_configs) @amp_configs.setter + @is_strict_auto def amp_configs(self, configs): check_configs_key(self.strategy.amp_configs, configs, "amp_configs") assign_configs_value(self.strategy.amp_configs, configs) @@ -388,6 +409,7 @@ class DistributedStrategy(object): return self.strategy.sync_nccl_allreduce @sync_nccl_allreduce.setter + @is_strict_auto def sync_nccl_allreduce(self, flag): if isinstance(flag, bool): self.strategy.sync_nccl_allreduce = flag @@ -411,6 +433,7 @@ class DistributedStrategy(object): return self.strategy.use_hierarchical_allreduce @use_hierarchical_allreduce.setter + @is_strict_auto def use_hierarchical_allreduce(self, flag): if isinstance(flag, bool): self.strategy.use_hierarchical_allreduce = flag @@ -435,6 +458,7 @@ class DistributedStrategy(object): return self.strategy.hierarchical_allreduce_inter_nranks @hierarchical_allreduce_inter_nranks.setter + @is_strict_auto def hierarchical_allreduce_inter_nranks(self, value): if isinstance(value, int): self.strategy.hierarchical_allreduce_inter_nranks = value @@ -461,6 +485,7 @@ class DistributedStrategy(object): return self.strategy.sync_batch_norm @sync_batch_norm.setter + @is_strict_auto def sync_batch_norm(self, flag): if isinstance(flag, bool): self.strategy.sync_batch_norm = flag @@ -483,6 +508,7 @@ class DistributedStrategy(object): return self.strategy.fuse_all_reduce_ops @fuse_all_reduce_ops.setter + @is_strict_auto def fuse_all_reduce_ops(self, flag): if isinstance(flag, bool): self.strategy.fuse_all_reduce_ops = flag @@ -506,6 +532,7 @@ class DistributedStrategy(object): return self.strategy.fuse_grad_size_in_MB @fuse_grad_size_in_MB.setter + @is_strict_auto def fuse_grad_size_in_MB(self, value): if isinstance(value, int): self.strategy.fuse_grad_size_in_MB = value @@ -517,6 +544,7 @@ class DistributedStrategy(object): return self.strategy.fuse_grad_size_in_TFLOPS @_fuse_grad_size_in_TFLOPS.setter + @is_strict_auto def _fuse_grad_size_in_TFLOPS(self, value): if isinstance(value, float): self.strategy.fuse_grad_size_in_TFLOPS = value @@ -543,6 +571,7 @@ class DistributedStrategy(object): return self.strategy.nccl_comm_num @nccl_comm_num.setter + @is_strict_auto def nccl_comm_num(self, value): if isinstance(value, int): self.strategy.nccl_comm_num = value @@ -550,6 +579,7 @@ class DistributedStrategy(object): print("WARNING: nccl_comm_num should have value of int type") @recompute.setter + @is_strict_auto def recompute(self, flag): if isinstance(flag, bool): self.strategy.recompute = flag @@ -574,6 +604,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.recompute_configs) @recompute_configs.setter + @is_strict_auto def recompute_configs(self, configs): check_configs_key(self.strategy.recompute_configs, configs, "checkpoint_configs") @@ -598,6 +629,7 @@ class DistributedStrategy(object): return self.strategy.pipeline @pipeline.setter + @is_strict_auto def pipeline(self, flag): if isinstance(flag, bool): self.strategy.pipeline = flag @@ -634,6 +666,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.pipeline_configs) @pipeline_configs.setter + @is_strict_auto def pipeline_configs(self, configs): check_configs_key(self.strategy.pipeline_configs, configs, "pipeline_configs") @@ -658,6 +691,7 @@ class DistributedStrategy(object): return self.strategy.localsgd @localsgd.setter + @is_strict_auto def localsgd(self, flag): if isinstance(flag, bool): self.strategy.localsgd = flag @@ -690,6 +724,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.localsgd_configs) @localsgd_configs.setter + @is_strict_auto def localsgd_configs(self, configs): check_configs_key(self.strategy.localsgd_configs, configs, "localsgd_configs") @@ -714,6 +749,7 @@ class DistributedStrategy(object): return self.strategy.dgc @dgc.setter + @is_strict_auto def dgc(self, flag): if isinstance(flag, bool): self.strategy.dgc = flag @@ -749,6 +785,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.dgc_configs) @dgc_configs.setter + @is_strict_auto def dgc_configs(self, configs): check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs") assign_configs_value(self.strategy.dgc_configs, configs) @@ -776,6 +813,7 @@ class DistributedStrategy(object): return self.strategy.gradient_merge @gradient_merge.setter + @is_strict_auto def gradient_merge(self, flag): if isinstance(flag, bool): self.strategy.gradient_merge = flag @@ -803,6 +841,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.gradient_merge_configs) @gradient_merge_configs.setter + @is_strict_auto def gradient_merge_configs(self, configs): check_configs_key(self.strategy.gradient_merge_configs, configs, "gradient_configs") @@ -827,6 +866,7 @@ class DistributedStrategy(object): return self.strategy.lars @lars.setter + @is_strict_auto def lars(self, flag): if isinstance(flag, bool): self.strategy.lars = flag @@ -862,6 +902,7 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.lars_configs) @lars_configs.setter + @is_strict_auto def lars_configs(self, configs): check_configs_key(self.strategy.lars_configs, configs, "lars_configs") assign_configs_value(self.strategy.lars_configs, configs) @@ -887,6 +928,7 @@ class DistributedStrategy(object): return self.strategy.lamb @lamb.setter + @is_strict_auto def lamb(self, flag): if isinstance(flag, bool): self.strategy.lamb = flag @@ -917,15 +959,21 @@ class DistributedStrategy(object): return get_msg_dict(self.strategy.lamb_configs) @lamb_configs.setter + @is_strict_auto def lamb_configs(self, configs): check_configs_key(self.strategy.lamb_configs, configs, "lamb_configs") assign_configs_value(self.strategy.lamb_configs, configs) @property def elastic(self): + """ + Indicating whether we want to do current distributed training on clusters with elastic resources. + Currently, this is configuration is not valid. + """ return self.strategy.elastic @elastic.setter + @is_strict_auto def elastic(self, flag): if isinstance(flag, bool): self.strategy.elastic = flag @@ -934,6 +982,25 @@ class DistributedStrategy(object): @property def auto(self): + """ + Indicating whether we are using auto-parallel configuration + This feature is currently an experimental feature. Currently, + auto-parallelism can be used only when a user does not set any other + strategy configs except auto. For details, please reference the following + code example + Default Value: False + + Examples: + .. code-block:: python + + import paddle + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.auto = True + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + """ return self.strategy.auto @auto.setter @@ -945,9 +1012,27 @@ class DistributedStrategy(object): @property def cudnn_exhaustive_search(self): + """ + Indicating whether to use exhaustive search method to choose convolution algorithms. + Exhaustive search attempts all cuDNN algorithms to choose the fastest algorithm. + This method is time-consuming, the choosed algorithm will be cached for the given layer specifications. + Once the layer specifications (like batch size, feature map size) are changed, it will search again. + Default Value: True + + Examples: + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.cudnn_exhaustive_search = False + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + """ return self.strategy.cudnn_exhaustive_search @cudnn_exhaustive_search.setter + @is_strict_auto def cudnn_exhaustive_search(self, flag): if isinstance(flag, bool): self.strategy.cudnn_exhaustive_search = flag @@ -958,9 +1043,28 @@ class DistributedStrategy(object): @property def conv_workspace_size_limit(self): + """ + The workspace limit size in MB unit for choosing cuDNN convolution algorithms. + The inner funciton of cuDNN obtain the fastest suited algorithm that fits within this memory limit. + Usually, large workspace size may lead to choose faster algorithms, + but significant increasing memory workspace. Users need to trade-off between memory and speed. + Default Value: 4000 + + Examples: + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.conv_workspace_size_limit = 1024 + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + + """ return self.strategy.conv_workspace_size_limit @conv_workspace_size_limit.setter + @is_strict_auto def conv_workspace_size_limit(self, value): if isinstance(value, int): self.strategy.conv_workspace_size_limit = value @@ -971,9 +1075,26 @@ class DistributedStrategy(object): @property def cudnn_batchnorm_spatial_persistent(self): + """ + Indicates whether to use the mode CUDNN_BATCHNORM_SPATIAL_PERSISTENT function in batchnorm. + This is only useful in cudnn. + Default Value: True + + Examples: + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.cudnn_batchnorm_spatial_persistent = True + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy) + + """ return self.strategy.cudnn_batchnorm_spatial_persistent @cudnn_batchnorm_spatial_persistent.setter + @is_strict_auto def cudnn_batchnorm_spatial_persistent(self, flag): if isinstance(flag, bool): self.strategy.cudnn_batchnorm_spatial_persistent = flag @@ -1005,6 +1126,12 @@ class DistributedStrategy(object): if core.globals().is_public(key): core.globals()[key] = values[i] + def _is_strict_auto(self): + global non_auto_func_called + if self.strategy.auto and non_auto_func_called: + return True + return False + def __repr__(self): fields = self.strategy.DESCRIPTOR.fields for f in fields: diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 8c748060e6..b918949269 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import print_function +import copy import warnings import paddle from paddle.fluid.framework import dygraph_only @@ -1008,6 +1009,18 @@ class Fleet(object): MetaOptimizerFactory()._get_valid_meta_optimizers( self.user_defined_optimizer) + context["user_defined_strategy"] = copy.copy(self.user_defined_strategy) + + # trigger the auto-parallel in very strict condition + # strategy = DistributedStrategy() + # strategy.auto = True + # optimizer = paddle.optimizer.SGD(learning_rate=0.1) + # optimizer = fleet.distributed_optimizer(optimizer, strategy) + if self.user_defined_strategy._is_strict_auto(): + # turn on all the strategy for each optimizer + for opt in distributed_optimizer_list: + opt._enable_strategy(self.user_defined_strategy) + valid_optimizer_list = [] valid_graph_optimizer_list = [] can_not_apply_optimizer_list = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index b1952276e4..938bd25884 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -42,6 +42,17 @@ class AMPOptimizer(MetaOptimizerBase): dist_strategy.amp = False dist_strategy.amp_configs = {} + def _enable_strategy(self, dist_strategy): + dist_strategy.amp = True + dist_strategy.amp_configs = { + "init_loss_scaling": 32768.0, + "incr_every_n_steps": 1000, + "decr_every_n_nan_or_inf": 2, + "incr_ratio": 2.0, + "decr_ratio": 8.0, + "use_dynamic_loss_scaling": True + } + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index f1c6defc5c..d292f58456 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -69,6 +69,10 @@ class DGCOptimizer(MetaOptimizerBase): dist_strategy.dgc = False dist_strategy.dgc_configs = {} + def _enable_strategy(self, dist_strategy): + dist_strategy.dgc = True + dist_strategy.dgc_configs = {"rampup_begin_step": 0, "rampup_step": 1} + def backward(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py index 7db79ad7b5..bb0c631e08 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py @@ -45,6 +45,10 @@ class GradientMergeOptimizer(MetaOptimizerBase): dist_strategy.gradient_merge = False dist_strategy.gradient_merge_configs = {} + def _enable_strategy(self, dist_strategy): + # we currently do not support auto-enable gradient merge + return + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index ace3168733..03304f1b68 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -148,9 +148,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase): sync_allreduce = dist_strategy.sync_nccl_allreduce if sync_allreduce: - paddle.fluid.framework.set_flags({ - "FLAGS_sync_nccl_allreduce": True - }) exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1 if local_build_strategy.use_hierarchical_allreduce: exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1 @@ -191,7 +188,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase): def _disable_strategy(self, dist_strategy): # TODO(guru4elephant): should close all PE related flags here - pass + return + + def _enable_strategy(self, dist_strategy): + # by default, graph execution strategy is enabled + return def minimize(self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index 9fa29c4078..3a9f2be533 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -75,6 +75,13 @@ class LambOptimizer(MetaOptimizerBase): dist_strategy.lamb = False dist_strategy.lamb_configs = {} + def _enable_strategy(self, dist_strategy): + dist_strategy.lamb = True + dist_strategy.lamb_configs = { + "lamb_weight_decay": 0.01, + "exclude_from_weight_decay": [] + } + def backward(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index a7b856ff5b..cb12154ddc 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -59,6 +59,13 @@ class LarsOptimizer(MetaOptimizerBase): dist_strategy.lars = False dist_strategy.lars_configs = {} + def _enable_strategy(self, dist_strategy): + dist_strategy.lars = True + dist_strategy.lars_configs = { + "lars_coeff": 0.01, + "lars_weight_decay": 0.0005, + } + def backward(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index e22127c139..3ac2fd374a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -42,6 +42,10 @@ class LocalSGDOptimizer(MetaOptimizerBase): dist_strategy.localsgd = False dist_strategy.localsgd_configs = {} + def _enable_strategy(self, dist_strategy): + dist_strategy.localsgd = True + dist_strategy.localsgd_configs = {"k_steps": 1} + def snapshot_name(self, param_name): return param_name + self.snapshot_key diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index 073148e11a..b105c25b3a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -48,6 +48,10 @@ class MetaOptimizerBase(Optimizer): raise NotImplementedError("you should implement disable strategy in {}". format(type(self).__name__)) + def _enable_strategy(self, dist_strategy): + raise NotImplementedError("you should implement enable strategy in {}". + format(type(self).__name__)) + def apply_gradients(self, params_grads): return self.inner_opt.apply_gradients(params_grads=params_grads) diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py index 878ed7422d..c9260dd2f8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py @@ -39,6 +39,11 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer): def _disable_strategy(self, dist_strategy): dist_strategy.a_sync_configs = {} + def _enable_strategy(self, dist_strategy): + # only open up the async mode for auto-parallel + dist_strategy.a_sync = True + dist_strategy.a_sync_configs = {} + def _is_graph_out(self): return True diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py index ecb198bedf..f394a792e3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py @@ -157,4 +157,9 @@ class ParameterServerOptimizer(MetaOptimizerBase): return None, None def _disable_strategy(self, dist_strategy): + dist_strategy.a_sync_configs = {} self.user_defined_strategy.a_sync_configs = {} + + def _enable_strategy(self, dist_strategy): + dist_strategy.a_sync = True + dist_strategy.a_sync_configs = {} diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index d5a45e2b4e..32c54d4486 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -111,6 +111,10 @@ class PipelineOptimizer(MetaOptimizerBase): dist_strategy.pipeline = False dist_strategy.pipeline_configs = {} + def _enable_strategy(self, dist_strategy): + # we do not support enable pipeline automatically right now + return + def minimize_impl(self, loss, startup_program=None, diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index 3eb3ca6127..267656824c 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -49,6 +49,10 @@ class RecomputeOptimizer(MetaOptimizerBase): dist_strategy.recompute = False dist_strategy.recompute_configs = {} + def _enable_strategy(self, dist_strategy): + # we do not support automatically recompute checkpoints currently + return + def backward(self, loss, startup_program=None, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index a25cba029d..9358132519 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -47,6 +47,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function) list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor) list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_auto) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -458,6 +459,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS}) + py_test_modules(test_fleet_auto MODULES test_fleet_auto ENVS ${dist_ENVS}) if(NOT WIN32) py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_auto.py b/python/paddle/fluid/tests/unittests/test_fleet_auto.py new file mode 100644 index 0000000000..020f2f4db3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_auto.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker + + +class TestDistributedStrategyAuto(unittest.TestCase): + def setUp(self): + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + "127.0.0.1:36001,127.0.0.2:36001" + + def test_distributed_strategy_auto(self): + fleet.init(is_collective=True) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.auto = True + optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main() -- GitLab