diff --git a/python/paddle/fleet/base/meta_optimizer_factory.py b/python/paddle/fleet/base/meta_optimizer_factory.py index 9b94ac513399a6582c3ff64102b4bacd04c264bf..3bcee843a587f94ac6fd85192f92aab33ee7e08f 100644 --- a/python/paddle/fleet/base/meta_optimizer_factory.py +++ b/python/paddle/fleet/base/meta_optimizer_factory.py @@ -15,6 +15,7 @@ from ..meta_optimizers import RecomputeOptimizer from ..meta_optimizers import GradientMergeOptimizer from ..meta_optimizers import GraphExecutionOptimizer +from ..meta_optimizers import PipelineOptimizer __all__ = ["MetaOptimizerFactory"] @@ -22,6 +23,7 @@ meta_optimizer_names = [ "RecomputeOptimizer", "GradientMergeOptimizer", "GraphExecutionOptimizer", + "PipelineOptimizer", ] diff --git a/python/paddle/fleet/meta_optimizers/__init__.py b/python/paddle/fleet/meta_optimizers/__init__.py index 2133eba0810cf15a532c57d125ea655aaaea3367..cb22c45bf9c0f819319f31e128e74f69c61daa49 100644 --- a/python/paddle/fleet/meta_optimizers/__init__.py +++ b/python/paddle/fleet/meta_optimizers/__init__.py @@ -14,8 +14,10 @@ from .recompute_optimizer import RecomputeOptimizer from .gradient_merge_optimizer import GradientMergeOptimizer from .graph_execution_optimizer import GraphExecutionOptimizer +from .pipeline_optimizer import PipelineOptimizer __all__ = [ 'RecomputeOptimizer', 'GradientMergeOptimizer', + 'PipelineOptimizer', ] diff --git a/python/paddle/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/fleet/meta_optimizers/pipeline_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd919f30f688d1b12fac258c2d6c9dc47fbf049 --- /dev/null +++ b/python/paddle/fleet/meta_optimizers/pipeline_optimizer.py @@ -0,0 +1,60 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from paddle.fluid.optimizer import PipelineOptimizer as PO +from .meta_optimizer_base import MetaOptimizerBase + +__all__ = ["PipelineOptimizer"] + + +class PipelineOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(PipelineOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + # we do not allow meta optimizer to be inner optimizer currently + self.meta_optimizers_white_list = [] + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(PipelineOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + num_microbatches = user_defined_strategy.pipeline_configs['micro_batch'] + self.wrapped_opt = PO(self.inner_opt, num_microbatches=num_microbatches) + + def _can_apply(self): + if self.user_defined_strategy.pipeline == True: + return True + return False + + def _disable_strategy(self, dist_strategy): + dist_strategy.pipeline = False + dist_strategy.pipeline_configs = {"micro_batch": 1} + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + return self.wrapped_opt.backward(loss, startup_program, parameter_list, + no_grad_set, callbacks) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + optimize_ops, params_grads, prog_list = \ + self.wrapped_opt.minimize(loss, startup_program, + parameter_list, no_grad_set) + return optimize_ops, params_grads diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 9d11bbc607a58ef3f8730dc9400ad01775148593..b74c1a8eda131a5a109c3d11061d1af592166d7e 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -32,6 +32,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_checkpoint) list(APPEND MIXED_DIST_TEST_OPS test_collective_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_base) list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) @@ -364,6 +365,7 @@ if(WITH_DISTRIBUTE) if(NOT APPLE) py_test_modules(test_fleet_base MODULES test_fleet_base ENVS ${dist_ENVS}) py_test_modules(test_fleet_meta_optimizer MODULES test_fleet_meta_optimizer ENVS ${dist_ENVS}) + py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) endif(NOT APPLE) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0005a4a8dbebff04cd9b11d0af082b01c718ca48 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + + +class TestFleetMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" + + def test_pipeline_optimizer(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + with paddle.fluid.device_guard("cpu"): + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data( + name="y", shape=[1], dtype='int64') + data_loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=[input_x, input_y], + capacity=64, + use_double_buffer=True, + iterable=False) + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + + with paddle.fluid.device_guard("gpu:0"): + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], + size=2, + act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.fleet.DistributedStrategy() + strategy.pipeline = True + strategy.pipeline_configs = {'micro_batch': 2} + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main()