diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index d17e68276cd1ce576029cf306a18469aef2ffdb0..05b7a16f1594f370cbf73ab7fdb4c98e3bb76024 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -43,6 +43,12 @@ message GradientMergeConfig {
   optional bool avg = 2 [ default = true ];
 }
 
+message DGCConfig {
+  optional int32 rampup_begin_step = 1 [ default = 0 ];
+  optional int32 rampup_step = 2 [ default = 1 ];
+  repeated float sparsity = 3;
+}
+
 message LarsConfig {
   optional float lars_coeff = 1 [ default = 0.001 ];
   optional float lars_weight_decay = 2 [ default = 0.0005 ];
@@ -114,6 +120,7 @@ message DistributedStrategy {
   optional AMPConfig amp_configs = 102;
   optional LocalSGDConfig localsgd_configs = 103;
   optional GradientMergeConfig gradient_merge_configs = 104;
+  optional DGCConfig dgc_configs = 105;
   optional PipelineConfig pipeline_configs = 106;
   optional AsyncConfig a_sync_configs = 107;
   optional LarsConfig lars_configs = 108;
diff --git a/python/paddle/fleet/base/distributed_strategy.py b/python/paddle/fleet/base/distributed_strategy.py
index 4cc7beadd80a071f7b22bb46f0b157bdffbd74f2..43e50ca0bee6b324655f7dcfb5e5da2ebc0e85a8 100644
--- a/python/paddle/fleet/base/distributed_strategy.py
+++ b/python/paddle/fleet/base/distributed_strategy.py
@@ -604,6 +604,15 @@ class DistributedStrategy(object):
         else:
             print("WARNING: dgc should have value of bool type")
 
+    @property
+    def dgc_configs(self):
+        return get_msg_dict(self.strategy.dgc_configs)
+
+    @dgc_configs.setter
+    def dgc_configs(self, configs):
+        check_configs_key(self.strategy.dgc_configs, configs, "dgc_configs")
+        assign_configs_value(self.strategy.dgc_configs, configs)
+
     @property
     def lamb(self):
         return self.strategy.lamb
diff --git a/python/paddle/fleet/base/meta_optimizer_factory.py b/python/paddle/fleet/base/meta_optimizer_factory.py
index 89ebb0ec601e249c58fd43995df1530f44940af4..802f6c4dab7f3a98cc11d9bb1956db5ee33b2746 100755
--- a/python/paddle/fleet/base/meta_optimizer_factory.py
+++ b/python/paddle/fleet/base/meta_optimizer_factory.py
@@ -19,6 +19,7 @@ from ..meta_optimizers import GraphExecutionOptimizer
 from ..meta_optimizers import PipelineOptimizer
 from ..meta_optimizers import LocalSGDOptimizer
 from ..meta_optimizers import LarsOptimizer
+from ..meta_optimizers import DGCOptimizer
 
 __all__ = ["MetaOptimizerFactory"]
 
@@ -30,6 +31,7 @@ meta_optimizer_names = [
    "PipelineOptimizer",
    "LocalSGDOptimizer",
    "LarsOptimizer",
+    "DGCOptimizer",
 ]
diff --git a/python/paddle/fleet/meta_optimizers/__init__.py b/python/paddle/fleet/meta_optimizers/__init__.py
index aa6708e758a78cf2cb10f8ebda81d50ac796b548..718805c5aadaf3476fa1fc495a355395fec6396d 100755
--- a/python/paddle/fleet/meta_optimizers/__init__.py
+++ b/python/paddle/fleet/meta_optimizers/__init__.py
@@ -18,6 +18,7 @@ from .graph_execution_optimizer import GraphExecutionOptimizer
 from .pipeline_optimizer import PipelineOptimizer
 from .localsgd_optimizer import LocalSGDOptimizer
 from .lars_optimizer import LarsOptimizer
+from .dgc_optimizer import DGCOptimizer
 
 __all__ = [
     'AMPOptimizer',
@@ -26,4 +27,5 @@
     'PipelineOptimizer',
     'LocalSGDOptimizer',
     'LarsOptimizer',
+    'DGCOptimizer',
 ]
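Note for reviewers: the new `dgc_configs` key round-trips between the Python `DistributedStrategy` wrapper and the `DGCConfig` protobuf message added above. A minimal usage sketch (field names come from the proto; leaving `sparsity` unset lets `DGCOptimizer` fall back to `[0.999]`):

```python
import paddle.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.dgc = True  # switch the DGC meta optimizer on
# keys must match the DGCConfig proto fields; unknown keys are rejected
# by check_configs_key in the dgc_configs setter
strategy.dgc_configs = {
    "rampup_begin_step": 0,  # step at which compression starts
    "rampup_step": 1,        # length of the sparsity warmup
    "sparsity": [0.999],     # finally keep only the top 0.1% of gradients
}
print(strategy.dgc_configs)  # read back as a plain dict via get_msg_dict
```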
diff --git a/python/paddle/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/fleet/meta_optimizers/dgc_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9a28fdaf11dd0d4d45cfd3fb1904b80dc136711
--- /dev/null
+++ b/python/paddle/fleet/meta_optimizers/dgc_optimizer.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid.optimizer import Momentum, DGCMomentumOptimizer
+from .meta_optimizer_base import MetaOptimizerBase
+import logging
+
+__all__ = ["DGCOptimizer"]
+
+
+class DGCOptimizer(MetaOptimizerBase):
+    def __init__(self, optimizer):
+        super(DGCOptimizer, self).__init__(optimizer)
+        self.inner_opt = optimizer
+        self.dgc_opt = None
+        # we do not allow meta optimizer to be inner optimizer currently
+        self.meta_optimizers_white_list = []
+
+    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
+                        user_defined_strategy):
+        super(DGCOptimizer, self)._set_basic_info(
+            loss, role_maker, user_defined_optimizer, user_defined_strategy)
+
+        opt = self.inner_opt
+        if not isinstance(opt, Momentum):
+            return
+
+        configs = self.user_defined_strategy.dgc_configs
+        if len(configs['sparsity']) == 0:
+            # default is [0.999]
+            configs['sparsity'] = [0.999]
+
+        self.dgc_opt = DGCMomentumOptimizer(
+            learning_rate=opt._learning_rate,
+            momentum=opt._momentum,
+            rampup_begin_step=configs['rampup_begin_step'],
+            rampup_step=configs['rampup_step'],
+            sparsity=configs['sparsity'],
+            parameter_list=opt._parameter_list,
+            use_nesterov=opt._use_nesterov,
+            num_trainers=self.role_maker.worker_num(),
+            regularization=opt.regularization,
+            grad_clip=opt._grad_clip,
+            name=opt._name)
+
+    def _can_apply(self):
+        if self.user_defined_strategy.dgc:
+            if not isinstance(self.inner_opt, Momentum):
+                logging.warn("dgc only works on Momentum optimizer")
+                return False
+            if self.role_maker.worker_num() <= 1:
+                logging.warn("dgc only works on multi cards")
+                return False
+
+            return True
+
+        return False
+
+    def _disable_strategy(self, dist_strategy):
+        dist_strategy.dgc = False
+        dist_strategy.dgc_configs = {
+            'rampup_begin_step': 0,
+            'rampup_step': 1,
+            'sparsity': [0.999]
+        }
+
+    def backward(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None,
+                 callbacks=None):
+        return self.dgc_opt.backward(loss, startup_program, parameter_list,
+                                     no_grad_set, callbacks)
+
+    def minimize_impl(self,
+                      loss,
+                      startup_program=None,
+                      parameter_list=None,
+                      no_grad_set=None):
+        optimize_ops, params_grads = \
+            self.dgc_opt.minimize(loss, startup_program,
+                                  parameter_list, no_grad_set)
+        return optimize_ops, params_grads
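For context on the three config fields consumed above: DGC (Lin et al., Deep Gradient Compression) warms sparsity up instead of jumping straight to the final value. The real schedule lives inside `DGCMomentumOptimizer`; the helper below is a hypothetical sketch of how such a `rampup_begin_step`/`rampup_step`/`sparsity` triple is typically interpreted, not code from this patch:

```python
def current_sparsity(step, rampup_begin_step, rampup_step, sparsity):
    """Illustrative only: pick the warmup sparsity for a given step.

    Before rampup_begin_step no compression is applied; during the ramp
    the schedule walks through the `sparsity` list; afterwards the last
    (highest) value sticks, e.g. 0.999 = keep the top 0.1% of gradients.
    """
    if step < rampup_begin_step:
        return 0.0
    progress = (step - rampup_begin_step) / float(rampup_step)
    idx = min(int(progress * len(sparsity)), len(sparsity) - 1)
    return sparsity[idx]

# with the default applied by DGCOptimizer._set_basic_info:
assert current_sparsity(10, 0, 1, [0.999]) == 0.999
```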
diff --git a/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py
index 1a3cfda94b98c9514208433dfcf5947caea8537c..9ba184fb0089589a86d6444d12cf402b9687b041 100644
--- a/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py
+++ b/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py
@@ -40,7 +40,8 @@ class MetaOptimizerBase(object):
         return True
 
     def _disable_strategy(self, dist_strategy):
-        raise NotImplementedError("you should implement disable strategy")
+        raise NotImplementedError("you should implement disable strategy in {}".
+                                  format(type(self).__name__))
 
     def minimize_impl(self,
                       loss,
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index e66f640665e2ba9ca9aab51af3f65b50169de404..c84d2ac3796efe9d16641552f1be939a666aa4cf 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -47,9 +47,8 @@ __all__ = [
     'AdamOptimizer', 'AdamaxOptimizer', 'DpsgdOptimizer',
     'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'FtrlOptimizer', 'Adadelta',
     'AdadeltaOptimizer', 'ModelAverage', 'LarsMomentum',
-    'LarsMomentumOptimizer', 'DGCMomentumOptimizer', 'LambOptimizer',
-    'ExponentialMovingAverage', 'PipelineOptimizer', 'LookaheadOptimizer',
-    'RecomputeOptimizer'
+    'LarsMomentumOptimizer', 'LambOptimizer', 'ExponentialMovingAverage',
+    'PipelineOptimizer', 'LookaheadOptimizer', 'RecomputeOptimizer'
 ]
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 7696839843b41d21eb6fdea0664ca69b917d8a0e..d73b9511b76ed6585c662264e99fe41f3354bc29 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -39,6 +39,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
 foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
@@ -388,6 +389,7 @@ if(WITH_DISTRIBUTE)
         py_test_modules(test_dgc_op MODULES test_dgc_op)
         py_test_modules(test_dgc_momentum_op MODULES test_dgc_momentum_op)
         py_test_modules(test_dgc_optimizer MODULES test_dgc_optimizer)
+        py_test_modules(test_fleet_dgc_meta_optimizer MODULES test_fleet_dgc_meta_optimizer)
     else()
         # if not with dgc, must close all dgc tests
         list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_dgc_nccl")
diff --git a/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py b/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
index 07746dd9f6cff297feacfa2dac24d89b2af876ab..0b9b85d5d52c38f748679a92a99ec61c3dec7903 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_mnist_dgc_nccl.py
@@ -39,7 +39,6 @@ class TestDistMnistNCCL2DGC(TestDistBase):
         self._nccl2_mode = True
         self._use_dgc = True
 
-    @unittest.skip(reason="Skip unstable ut")
     def test_dist_train(self):
         import paddle.fluid as fluid
         if fluid.core.is_compiled_with_cuda():
@@ -69,7 +68,6 @@ class TestDistMnistNCCL2DGCMultiCards(TestDistBase):
         self._nccl2_mode = True
         self._use_dgc = True
 
-    @unittest.skip(reason="Skip unstable ut")
     def test_dist_train(self):
         import paddle.fluid as fluid
         if fluid.core.is_compiled_with_cuda():
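The `meta_optimizer_base.py` hunk above makes the `_disable_strategy` contract self-describing: the exception now names the subclass that forgot to override it. A hedged sketch of the effect (`BrokenOptimizer` is a hypothetical subclass used only for illustration):

```python
from paddle.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase


class BrokenOptimizer(MetaOptimizerBase):
    pass  # deliberately omits _disable_strategy


try:
    BrokenOptimizer(None)._disable_strategy(None)
except NotImplementedError as e:
    print(e)  # "you should implement disable strategy in BrokenOptimizer"
```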
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0590650bd02f5535b9c35bae187e77bc7274901c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import os
+import paddle.fleet as fleet
+import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+
+
+class TestFleetDGCOptimizer(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ID"] = "1"
+        os.environ[
+            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
+
+    def net(self):
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+        input_x = paddle.fluid.layers.data(
+            name="x", shape=[32], dtype='float32')
+        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
+
+        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh')
+        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
+        cost = paddle.fluid.layers.cross_entropy(
+            input=prediction, label=input_y)
+        avg_cost = paddle.fluid.layers.mean(x=cost)
+
+        strategy = paddle.fleet.DistributedStrategy()
+        strategy.dgc = True
+        strategy.dgc_configs = {
+            "rampup_begin_step": 128,
+            "rampup_step": 100,
+            "sparsity": [0.996, 0.999]
+        }
+        return avg_cost, strategy
+
+    def test_dgc_optimizer(self):
+        avg_cost, strategy = self.net()
+        optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertIn('dgc', ops)
+        self.assertIn('dgc_momentum', ops)
+
+    def test_dgc_not_apply_with_adam(self):
+        avg_cost, strategy = self.net()
+        optimizer = paddle.optimizer.Adam(learning_rate=0.01)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertNotIn('dgc', ops)
+        self.assertNotIn('dgc_momentum', ops)
+
+    def test_dgc_not_apply_with_one_worker(self):
+        os.environ["PADDLE_TRAINER_ID"] = "0"
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
+
+        avg_cost, strategy = self.net()
+        optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertNotIn('dgc', ops)
+        self.assertNotIn('dgc_momentum', ops)
+
+
+if __name__ == "__main__":
+    unittest.main()
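Putting the pieces together, end-to-end usage mirrors the unit test above; nothing in this sketch goes beyond what the test exercises (the two-trainer environment comes from its `setUp`):

```python
import os
import paddle
import paddle.fleet as fleet
import paddle.fluid.incubate.fleet.base.role_maker as role_maker

# two-trainer environment, as in TestFleetDGCOptimizer.setUp
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"

fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
fc = paddle.fluid.layers.fc(input=x, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=fc, size=2, act='softmax')
cost = paddle.fluid.layers.cross_entropy(input=prediction, label=y)
avg_cost = paddle.fluid.layers.mean(x=cost)

strategy = fleet.DistributedStrategy()
strategy.dgc = True
strategy.dgc_configs = {
    "rampup_begin_step": 128,
    "rampup_step": 100,
    "sparsity": [0.996, 0.999],
}

# DGC only kicks in for Momentum with more than one trainer; otherwise
# the meta optimizer disables itself and plain Momentum is used
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
# the main program now contains 'dgc' and 'dgc_momentum' ops
```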