From 8ebffc78c9f999759a35921c71b83226200d8561 Mon Sep 17 00:00:00 2001
From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Date: Mon, 3 Aug 2020 11:33:29 +0800
Subject: [PATCH] add lars to fleet meta optimizer (#25884)

---
 .../framework/distributed_strategy.proto      | 14 +++-
 .../fleet/base/meta_optimizer_factory.py      |  2 +
 .../paddle/fleet/meta_optimizers/__init__.py  |  2 +
 .../fleet/meta_optimizers/lars_optimizer.py   | 83 +++++++++++++++++++
 .../fluid/tests/unittests/CMakeLists.txt      |  2 +
 .../test_fleet_lars_meta_optimizer.py         | 73 ++++++++++++++++
 6 files changed, 175 insertions(+), 1 deletion(-)
 mode change 100644 => 100755 paddle/fluid/framework/distributed_strategy.proto
 mode change 100644 => 100755 python/paddle/fleet/base/meta_optimizer_factory.py
 mode change 100644 => 100755 python/paddle/fleet/meta_optimizers/__init__.py
 create mode 100755 python/paddle/fleet/meta_optimizers/lars_optimizer.py
 mode change 100644 => 100755 python/paddle/fluid/tests/unittests/CMakeLists.txt
 create mode 100755 python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
old mode 100644
new mode 100755
index aafb4b91095..96ddc82b1c9
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -40,6 +40,17 @@ message GradientMergeConfig {
   optional bool avg = 2 [ default = true ];
 }
 
+message LarsConfig {
+  optional float lars_coeff = 1 [ default = 0.001 ];
+  optional float lars_weight_decay = 2 [ default = 0.0005 ];
+}
+
+message LambConfig {
+  optional float beta1 = 1 [ default = 0.001 ];
+  optional float beta2 = 2 [ default = 0.999 ];
+  optional float epsilon = 3 [ default = 0.000001 ];
+}
+
 message BuildStrategy {
   optional bool enable_sequential_execution = 1 [ default = false ];
   optional bool fuse_elewise_add_act_ops = 2 [ default = false ];
@@ -102,7 +113,8 @@ message DistributedStrategy {
   optional GradientMergeConfig gradient_merge_configs = 104;
   optional PipelineConfig pipeline_configs = 106;
   optional AsyncConfig a_sync_configs = 107;
-
+  optional LarsConfig lars_configs = 108;
+  optional LambConfig lamb_configs = 109;
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
 }
diff --git a/python/paddle/fleet/base/meta_optimizer_factory.py b/python/paddle/fleet/base/meta_optimizer_factory.py
old mode 100644
new mode 100755
index 31350e6934b..bbbd5fcacd6
--- a/python/paddle/fleet/base/meta_optimizer_factory.py
+++ b/python/paddle/fleet/base/meta_optimizer_factory.py
@@ -17,6 +17,7 @@ from ..meta_optimizers import GradientMergeOptimizer
 from ..meta_optimizers import GraphExecutionOptimizer
 from ..meta_optimizers import PipelineOptimizer
 from ..meta_optimizers import LocalSGDOptimizer
+from ..meta_optimizers import LarsOptimizer
 
 __all__ = ["MetaOptimizerFactory"]
 
@@ -26,6 +27,7 @@ meta_optimizer_names = [
     "GraphExecutionOptimizer",
     "PipelineOptimizer",
     "LocalSGDOptimizer",
+    "LarsOptimizer",
 ]
 
 
diff --git a/python/paddle/fleet/meta_optimizers/__init__.py b/python/paddle/fleet/meta_optimizers/__init__.py
old mode 100644
new mode 100755
index 95fbf4b7ddf..0beb06eacf8
--- a/python/paddle/fleet/meta_optimizers/__init__.py
+++ b/python/paddle/fleet/meta_optimizers/__init__.py
@@ -16,10 +16,12 @@ from .gradient_merge_optimizer import GradientMergeOptimizer
 from .graph_execution_optimizer import GraphExecutionOptimizer
 from .pipeline_optimizer import PipelineOptimizer
 from .localsgd_optimizer import LocalSGDOptimizer
+from .lars_optimizer import LarsOptimizer
 
 __all__ = [
     'RecomputeOptimizer',
     'GradientMergeOptimizer',
     'PipelineOptimizer',
     'LocalSGDOptimizer',
+    'LarsOptimizer',
 ]
diff --git a/python/paddle/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/fleet/meta_optimizers/lars_optimizer.py
new file mode 100755
index 00000000000..ff535e3ebf2
--- /dev/null
+++ b/python/paddle/fleet/meta_optimizers/lars_optimizer.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+from paddle.fluid.optimizer import Momentum, LarsMomentumOptimizer
+from .meta_optimizer_base import MetaOptimizerBase
+import logging
+
+__all__ = ["LarsOptimizer"]
+
+
+class LarsOptimizer(MetaOptimizerBase):
+    def __init__(self, optimizer):
+        super(LarsOptimizer, self).__init__(optimizer)
+        self.inner_opt = optimizer
+        self.lars_opt = None
+        # we do not allow meta optimizer to be inner optimizer currently
+        self.meta_optimizers_white_list = []
+
+    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
+                        user_defined_strategy):
+        super(LarsOptimizer, self)._set_basic_info(
+            loss, role_maker, user_defined_optimizer, user_defined_strategy)
+
+        opt = self.inner_opt
+        if not isinstance(opt, Momentum):
+            return
+
+        configs = self.user_defined_strategy.lars_configs
+
+        self.lars_opt = LarsMomentumOptimizer(
+            learning_rate=opt._learning_rate,
+            momentum=opt._momentum,
+            lars_coeff=configs['lars_coeff'],
+            lars_weight_decay=configs['lars_weight_decay'],
+            parameter_list=opt._parameter_list,
+            regularization=opt.regularization,
+            grad_clip=opt._grad_clip,
+            name=opt._name)
+
+    def _can_apply(self):
+        if self.user_defined_strategy.lars:
+            if not isinstance(self.inner_opt, Momentum):
+                logging.warn(
+                    "lars need the inner optimizer to be Momentum optimizer.")
+                return False
+            return True
+        return False
+
+    def _disable_strategy(self, dist_strategy):
+        dist_strategy.lars = False
+        dist_strategy.lars_configs = {
+            'lars_coeff': 0.001,
+            'lars_weight_decay': 0.0005,
+        }
+
+    def backward(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None,
+                 callbacks=None):
+        return self.lars_opt.backward(loss, startup_program, parameter_list,
+                                      no_grad_set, callbacks)
+
+    def minimize_impl(self,
+                      loss,
+                      startup_program=None,
+                      parameter_list=None,
+                      no_grad_set=None):
+        optimize_ops, params_grads = \
+            self.lars_opt.minimize(loss, startup_program,
+                                   parameter_list, no_grad_set)
+        return optimize_ops, params_grads
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
old mode 100644
new mode 100755
index fff2e8d7651..db3dc6b8594
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -36,6 +36,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor)
 foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
@@ -375,6 +376,7 @@ if(WITH_DISTRIBUTE)
         py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS})
         if(NOT WIN32)
             py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS})
+            py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS})
         endif(NOT WIN32)
     endif(NOT APPLE)
     if(WITH_DGC)
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
new file mode 100755
index 00000000000..960ffbd4035
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import os
+import paddle.fleet as fleet
+import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+
+
+class TestFleetLarsMetaOptimizer(unittest.TestCase):
+    def setUp(self):
+        os.environ["POD_IP"] = "127.0.0.1"
+        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
+        os.environ["PADDLE_TRAINERS_NUM"] = "2"
+        os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
+            "127.0.0.1:36001,127.0.0.2:36001"
+
+    def net(self):
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+        input_x = paddle.fluid.layers.data(
+            name="x", shape=[32], dtype='float32')
+        input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
+
+        fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+        fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh')
+        prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax')
+        cost = paddle.fluid.layers.cross_entropy(
+            input=prediction, label=input_y)
+        avg_cost = paddle.fluid.layers.mean(x=cost)
+
+        strategy = paddle.fleet.DistributedStrategy()
+        strategy.lars = True
+        strategy.lars_configs = {
+            "lars_coeff": 0.001,
+            "lars_weight_decay": 0.0005,
+        }
+
+        return avg_cost, strategy
+
+    def test_lars_optimizer(self):
+        avg_cost, strategy = self.net()
+        optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertIn('lars_momentum', ops)
+
+    def test_lars_not_apply_with_adam(self):
+        avg_cost, strategy = self.net()
+        optimizer = paddle.optimizer.Adam(learning_rate=0.01)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+        ops = [op.type for op in avg_cost.block.ops]
+        self.assertNotIn('lars_momentum', ops)
+
+
+if __name__ == "__main__":
+    unittest.main()
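
Usage note (not part of the patch): the sketch below distills the unit test above into a minimal standalone script showing how the new LARS meta optimizer is enabled through DistributedStrategy. It only uses APIs that appear in this patch, and it assumes a collective fleet environment, i.e. the PADDLE_* environment variables from the test's setUp (or a real distributed launch) are already set; treat it as an illustration of the intended workflow, not an official example.

    import paddle
    import paddle.fleet as fleet
    import paddle.fluid.incubate.fleet.base.role_maker as role_maker

    # Initialize fleet in collective mode (requires the PADDLE_* environment
    # variables shown in the test's setUp, or a real distributed launch).
    fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

    # A small feed-forward network; avg_cost is the loss to minimize.
    input_x = paddle.fluid.layers.data(name="x", shape=[32], dtype='float32')
    input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64')
    fc = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
    prediction = paddle.fluid.layers.fc(input=fc, size=2, act='softmax')
    cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
    avg_cost = paddle.fluid.layers.mean(x=cost)

    # Turn on LARS in the distributed strategy; the config keys mirror
    # LarsConfig in distributed_strategy.proto.
    strategy = paddle.fleet.DistributedStrategy()
    strategy.lars = True
    strategy.lars_configs = {
        "lars_coeff": 0.001,
        "lars_weight_decay": 0.0005,
    }

    # LarsOptimizer only applies when the inner optimizer is Momentum; with
    # any other inner optimizer (e.g. Adam) _can_apply() returns False and
    # the strategy falls back to the inner optimizer.
    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)  # the program now contains lars_momentum ops

After minimize(), the tests verify the rewrite by checking for the 'lars_momentum' op type in the block's op list, which is also a convenient way to confirm the meta optimizer took effect in your own programs.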