From c1c18b089957b466faf122dc9490d5acd60a83ca Mon Sep 17 00:00:00 2001 From: lilong12 Date: Sat, 8 May 2021 10:42:25 +0800 Subject: [PATCH] Add raw program meta optimizer (#32597) * add raw program, test=develop --- .../framework/distributed_strategy.proto | 1 + .../fleet/base/distributed_strategy.py | 26 +++ .../fleet/meta_optimizers/__init__.py | 1 + .../meta_optimizers/raw_program_optimizer.py | 196 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 2 + .../test_fleet_raw_program_meta_optimizer.py | 53 +++++ 6 files changed, 279 insertions(+) create mode 100755 python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 654b88920ac..dbe9b8cb9aa 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -174,6 +174,7 @@ message DistributedStrategy { optional float last_comm_group_size_MB = 27 [ default = 1 ]; optional bool find_unused_parameters = 28 [ default = true ]; optional bool tensor_parallel = 29 [ default = false ]; + optional bool without_graph_optimization = 30 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index a44d008fe9a..469b45d2006 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -827,6 +827,32 @@ class DistributedStrategy(object): "sharding_configs") assign_configs_value(self.strategy.sharding_configs, configs) + @property + def without_graph_optimization(self): + """ + Run program using Executor other than ParallelExecutor. + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.without_graph_optimization = True + + """ + return self.strategy.without_graph_optimization + + @without_graph_optimization.setter + @is_strict_auto + def without_graph_optimization(self, flag): + if isinstance(flag, bool): + self.strategy.without_graph_optimization = flag + else: + print( + "WARNING: without_graph_optimization should have value of bool type" + ) + @property def pipeline(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 827835fde20..1788e044fe8 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -28,3 +28,4 @@ from .sharding_optimizer import ShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler from .tensor_parallel_optimizer import TensorParallelOptimizer +from .raw_program_optimizer import RawProgramOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py new file mode 100755 index 00000000000..243f6efe531 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -0,0 +1,196 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from __future__ import print_function +from __future__ import division +import os + +import paddle.fluid as fluid +from paddle.fluid import core, unique_name +from ..base.private_helper_function import wait_server_ready +from .meta_optimizer_base import MetaOptimizerBase +from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_loss_grad_op, is_backward_op, is_optimizer_op + + +class RawProgramOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(RawProgramOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + ] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self.global_ring_id = 0 + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(RawProgramOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + self.without_graph_optimization = user_defined_strategy.without_graph_optimization + + def _can_apply(self): + if not self.role_maker._is_collective: + return False + + if self.without_graph_optimization == True: + return True + return False + + def _disable_strategy(self, dist_strategy): + dist_strategy.without_graph_optimization = False + + def _enable_strategy(self, dist_strategy, context): + dist_strategy.without_graph_optimization = True + + def _broadcast_params(self, ring_id): + block = self.startup_program.global_block() + param = None + for param in block.iter_parameters(): + if param.is_distributed: + continue + + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + + if not param: return # no parameter on this device + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + def _get_process_group_info(self): + # global ring info + self.global_endpoints = self.endpoints + self.global_rank = self.rank + self.global_nranks = self.nranks + + def _init_process_group(self): + self._get_process_group_info() + collective_helper = CollectiveHelper(self.role_maker, wait_port=False) + # Create global ring for all gpus (ring_id = 0) + collective_helper._init_communicator( + self.startup_program, self.current_endpoint, self.global_endpoints, + self.global_rank, self.global_ring_id, True, self.global_ring_id, + True) + self._broadcast_params(self.global_ring_id) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self.endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.endpoints[self.role_maker._worker_index()] + self.rank = self.role_maker._worker_index() + self.nranks = self.role_maker._worker_num() + if startup_program is None: + startup_program = fluid.default_startup_program() + self.startup_program = startup_program + + block = loss.block + program = block.program + self.main_program = program + + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) + + self._init_process_group() + + self.main_program = program + if self.nranks > 1: + self._transpile_main_program(loss) + return optimize_ops, params_grads + + def _transpile_main_program(self, loss): + self._insert_loss_grad_ops(loss) + self._insert_allreduce_ops() + + def _insert_loss_grad_ops(self, loss): + """ + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + """ + block = self.main_program.global_block() + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / self.nranks, + OP_ROLE_KEY: OpRole.Backward + }) + + def _insert_allreduce_ops(self): + block = self.main_program.global_block() + ring_id = self.global_ring_id + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.attr(OP_ROLE_VAR_KEY) + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + offset = 1 + for i in range(0, len(op_role_var), 2): + param_name = op_role_var[i] + param = block.var(param_name) + grad_name = op_role_var[i + 1] + grad = block.var(grad_name) + if param.is_distributed: + continue + + block._insert_op( + idx + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={OP_ROLE_KEY: OpRole.Backward, }) + offset += 1 + block._insert_op( + idx + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in enumerate(block.ops): + if is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + break diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 8e998459cd4..110665186c0 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer) list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) +list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer) list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer) list(APPEND DIST_TEST_OPS test_gen_nccl_id_op) list(APPEND DIST_TEST_OPS test_parallel_dygraph_unused_variables) @@ -53,6 +54,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2) list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3) list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init) list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py new file mode 100644 index 00000000000..604109b262d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_raw_program_meta_optimizer.py @@ -0,0 +1,53 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + +paddle.enable_static() + + +class TestFleetMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["PADDLE_TRAINER_ID"] = "1" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" + + def test_pipeline_optimizer(self): + import paddle.distributed.fleet as fleet + import paddle.distributed.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.without_graph_optimization = True + + optimizer = paddle.fluid.optimizer.Adam(0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main() -- GitLab