From 83cd1859475446f47b981ff1df0f8ae7d8f6cd50 Mon Sep 17 00:00:00 2001 From: Dong Daxiang <35550832+guru4elephant@users.noreply.github.com> Date: Fri, 21 Aug 2020 13:27:08 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91Meta=20from=20opt?= =?UTF-8?q?imizer=20(#26392)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * consider the combination of different strategies to work together --- .../distributed/fleet/base/fleet_base.py | 1 + .../fleet/base/strategy_compiler.py | 14 +++++ .../fleet/meta_optimizers/amp_optimizer.py | 7 ++- .../fleet/meta_optimizers/dgc_optimizer.py | 1 + .../gradient_merge_optimizer.py | 9 ++- .../graph_execution_optimizer.py | 1 + .../fleet/meta_optimizers/lamb_optimizer.py | 3 +- .../fleet/meta_optimizers/lars_optimizer.py | 3 +- .../meta_optimizers/localsgd_optimizer.py | 1 + .../meta_optimizers/meta_optimizer_base.py | 39 +++++++++++-- .../meta_optimizers/pipeline_optimizer.py | 1 + .../meta_optimizers/recompute_optimizer.py | 8 ++- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../tests/unittests/launch_function_helper.py | 15 +++++ ...st_fleet_graph_execution_meta_optimizer.py | 17 ++++-- .../test_fleet_meta_optimizer_base.py | 58 +++++++++++++++++++ 16 files changed, 166 insertions(+), 14 deletions(-) create mode 100755 python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 8093ad504c1..6e090d90655 100644 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -78,6 +78,7 @@ class Fleet(object): def init(self, role_maker): self._role_maker = role_maker self.strategy_compiler = StrategyCompiler() + return None def is_first_worker(self): """ diff --git a/python/paddle/distributed/fleet/base/strategy_compiler.py b/python/paddle/distributed/fleet/base/strategy_compiler.py index a5ff247a21f..4097fc1237f 100644 --- a/python/paddle/distributed/fleet/base/strategy_compiler.py +++ b/python/paddle/distributed/fleet/base/strategy_compiler.py @@ -114,4 +114,18 @@ class StrategyCompiler(StrategyCompilerBase): 0] return_graph = None if graph_optimizers == None else graph_optimizers[ 0] + + if meta_optimizers == None or graph_optimizers == None: + return return_meta, return_graph + + # do heuristic filter here, if any meta optimizer in graph optimizers is in + # any meta optimizers' black list, set return_graph to None + need_graph_opt = True + for graph_opt in graph_optimizers: + for program_opt in meta_optimizers: + if graph_opt.__class__.__name__ in program_opt.meta_optimizers_black_list: + need_graph_opt = False + if not need_graph_opt: + return_graph = None + return return_meta, return_graph diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index 6b1756c3695..66db14209b4 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -23,7 +23,12 @@ class AMPOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.amp_opt = None # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = [ + "LarsOptimizer", "LambOptimizer", "RecomputeOptimizer", + "LocalSGDOptimizer", "GradientMergeOptimizer", + "GraphExecutionOptimizer" + ] + 
self.meta_optimizers_black_list = ["DGCOptimizer"] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index 361175a11c5..f34786f9dc3 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -25,6 +25,7 @@ class DGCOptimizer(MetaOptimizerBase): self.dgc_opt = None # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py index 28cbce317a9..bd52179a358 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py @@ -16,13 +16,20 @@ from .meta_optimizer_base import MetaOptimizerBase __all__ = ["GradientMergeOptimizer"] +# amp + gradient merge + lamb + class GradientMergeOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(GradientMergeOptimizer, self).__init__(optimizer) self.inner_opt = optimizer self.wrapped_opt = GM(optimizer) - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = [ + "LarsOptimizer", + "LambOptimizer", + "GraphExecutionOptimizer", + ] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index b9ff31a068e..ace31687338 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -25,6 +25,7 @@ class GraphExecutionOptimizer(MetaOptimizerBase): self.inner_opt = optimizer # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] + self.meta_optimizers_black_list = [] def _is_graph_out(self): return True diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index d9a31c17e0d..7e08a02eb1d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -25,7 +25,8 @@ class LambOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.lamb_opt = None # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = ["GraphExecutionOptimizer"] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index a54a4fc5599..09c418fa791 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -24,7 +24,8 @@ class LarsOptimizer(MetaOptimizerBase): self.inner_opt = 
optimizer self.lars_opt = None # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = ["GraphExecutionOptimizer"] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index c807815ff46..e22127c1399 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -25,6 +25,7 @@ class LocalSGDOptimizer(MetaOptimizerBase): super(LocalSGDOptimizer, self).__init__(optimizer) self.inner_opt = optimizer self.meta_optimizers_white_list = [] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer"] self.snapshot_key = '@SNAPSHOT' def _can_apply(self): diff --git a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py index 04800cefdda..12a4d904340 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py +++ b/python/paddle/distributed/fleet/meta_optimizers/meta_optimizer_base.py @@ -14,10 +14,16 @@ __all__ = ["MetaOptimizerBase"] +from paddle.fluid.optimizer import Optimizer -class MetaOptimizerBase(object): + +class MetaOptimizerBase(Optimizer): def __init__(self, optimizer): - pass + self.inner_opt = optimizer + self._learning_rate = self.inner_opt._learning_rate + self._learning_rate_map = self.inner_opt._learning_rate_map + self.meta_optimizers_white_list = [] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): @@ -26,7 +32,7 @@ class MetaOptimizerBase(object): self.user_defined_optimizer = user_defined_optimizer self.user_defined_strategy = user_defined_strategy - def _update_inner_optimier(self, optimizer): + def _update_inner_optimizer(self, optimizer): self.inner_opt = optimizer def _can_apply(self): @@ -44,12 +50,37 @@ class MetaOptimizerBase(object): raise NotImplementedError("you should implement disable strategy in {}". 
format(type(self).__name__)) + def apply_gradients(self, params_grads): + return self.inner_opt.apply_gradients(params_grads=params_grads) + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + return self.inner_opt.backward(loss, startup_program, parameter_list, + no_grad_set, callbacks) + + def apply_optimize(self, loss, startup_program, params_grads): + return self.inner_opt.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + def minimize_impl(self, loss, startup_program=None, parameter_list=None, no_grad_set=None): - raise NotImplementedError("meta optimizer not implemented") + params_grads = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + + optimize_ops = self.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) + + return optimize_ops, params_grads def minimize(self, loss, diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index 8a0c48aa544..fe9221307cb 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -95,6 +95,7 @@ class PipelineOptimizer(MetaOptimizerBase): self.inner_opt = optimizer # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index 07b69f19e7e..45130b44712 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -24,7 +24,13 @@ class RecomputeOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.wrapped_opt = RO(optimizer) # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = [ + "LarsOptimizer", + "LambOptimizer", + "GradientMergeOptimizer", + "GraphExecutionOptimizer", + ] + self.meta_optimizers_black_list = [] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c565f8da4fa..33d9326681d 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -46,6 +46,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_dgc_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function) list(APPEND MIXED_DIST_TEST_OPS test_fleet_graph_executor) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer_base) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -399,6 +400,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) + 
py_test_modules(test_fleet_meta_optimizer_base MODULES test_fleet_meta_optimizer_base ENVS ${dist_ENVS}) if(NOT WIN32) py_test_modules(test_fleet_localsgd_meta_optimizer MODULES test_fleet_localsgd_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_lars_meta_optimizer MODULES test_fleet_lars_meta_optimizer ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/launch_function_helper.py b/python/paddle/fluid/tests/unittests/launch_function_helper.py index 64fee35710a..ecfe39b80e9 100644 --- a/python/paddle/fluid/tests/unittests/launch_function_helper.py +++ b/python/paddle/fluid/tests/unittests/launch_function_helper.py @@ -13,6 +13,8 @@ # limitations under the License. from multiprocessing import Pool, Process import os +import socket +from contextlib import closing def launch_func(func, env_dict): @@ -20,3 +22,16 @@ def launch_func(func, env_dict): os.environ[key] = env_dict[key] proc = Process(target=func) return proc + + +def _find_free_port(port_set): + def __free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + while True: + port = __free_port() + if port not in port_set: + port_set.add(port) + return port diff --git a/python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py index 3e97ab3bfc6..8b2c32c2f02 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_graph_execution_meta_optimizer.py @@ -15,7 +15,7 @@ import unittest import paddle import os -from launch_function_helper import launch_func +from launch_function_helper import launch_func, _find_free_port class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): @@ -71,20 +71,27 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): proc_b.join() def test_graph_execution_optimizer(self): + + port_set = set() + port_a = _find_free_port(port_set) + port_b = _find_free_port(port_set) + node_a = { "PADDLE_TRAINER_ID": "0", - "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36001", + "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:{}".format(port_a), "PADDLE_TRAINERS_NUM": "2", - "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002", + "PADDLE_TRAINER_ENDPOINTS": + "127.0.0.1:{},127.0.0.1:{}".format(port_a, port_b), "http_proxy": "", "https_proxy": "" } node_b = { "PADDLE_TRAINER_ID": "1", - "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:36002", + "PADDLE_CURRENT_ENDPOINT": "127.0.0.1:{}".format(port_b), "PADDLE_TRAINERS_NUM": "2", - "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002", + "PADDLE_TRAINER_ENDPOINTS": + "127.0.0.1:{},127.0.0.1:{}".format(port_a, port_b), "http_proxy": "", "https_proxy": "" } diff --git a/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py new file mode 100755 index 00000000000..81bb3d36d72 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer_base.py @@ -0,0 +1,58 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +from paddle import fluid +import os +import paddle.distributed.fleet as fleet +import paddle.fluid.incubate.fleet.base.role_maker as role_maker +from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase + + +class TestFleetMetaOptimizerBase(unittest.TestCase): + def net(main_prog, startup_prog): + with fluid.program_guard(main_prog, startup_prog): + with fluid.unique_name.guard(): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data( + name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, + size=64, + act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=256, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], + size=2, + act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01) + opt = MetaOptimizerBase(optimizer) + opt_ops, params_grads = opt.minimize(avg_cost) + opt.apply_optimize(avg_cost, + paddle.static.default_startup_program(), + params_grads) + return None + + net(fluid.default_startup_program(), fluid.default_main_program()) + + +if __name__ == "__main__": + unittest.main() -- GitLab
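The combination heuristic introduced in strategy_compiler.py above can be exercised on its own. The following is a minimal, hypothetical sketch (the classes are illustrative stand-ins, not the Paddle implementations): every meta optimizer declares a meta_optimizers_black_list, and if any selected graph-level optimizer appears in any selected program-level optimizer's black list, the graph-level pass is dropped so the remaining strategies can still be combined.

    # Hypothetical stand-ins for Paddle's meta optimizers; only the
    # black-list attribute matters for the combination rule sketched here.
    class LocalSGDLikeOptimizer(object):
        # a program-pass optimizer that cannot coexist with graph execution
        meta_optimizers_black_list = ["GraphExecutionLikeOptimizer"]

    class AMPLikeOptimizer(object):
        # conflicts with DGC, but not with graph execution
        meta_optimizers_black_list = ["DGCLikeOptimizer"]

    class GraphExecutionLikeOptimizer(object):
        meta_optimizers_black_list = []

    def filter_graph_optimizers(meta_optimizers, graph_optimizers):
        """Return the graph optimizers that survive the black-list check,
        or None when any program-level optimizer black-lists one of them."""
        if meta_optimizers is None or graph_optimizers is None:
            return graph_optimizers
        for graph_opt in graph_optimizers:
            for program_opt in meta_optimizers:
                if type(graph_opt).__name__ in program_opt.meta_optimizers_black_list:
                    # a single conflict disables the whole graph-level pass
                    return None
        return graph_optimizers

    if __name__ == "__main__":
        # LocalSGD black-lists graph execution, so the graph pass is dropped:
        print(filter_graph_optimizers(
            [LocalSGDLikeOptimizer()], [GraphExecutionLikeOptimizer()]))  # None
        # AMP does not, so the graph pass is kept:
        print(filter_graph_optimizers(
            [AMPLikeOptimizer()], [GraphExecutionLikeOptimizer()]))

This mirrors the behaviour of the new hunk in StrategyCompiler: when a conflict is found, return_graph is set to None while the program-level meta optimizer chain is kept, which is what the white-list/black-list pairs added to the individual meta optimizers in this patch rely on.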