diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 486792093a35cf9c3078b6e3201fedad86a9662d..d60e07674edad0fa40d2fddebc45b0ae68c5df24 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4867,6 +4867,39 @@ class PipelineOptimizer(object):
                     })
                 extra_index_info['index'] += 1
             elif self.schedule_mode == '1F1B':  # 1F1B
+                var_shape = list(var.shape)
+                var_shape[0] = self.micro_batch_size if var_shape[
+                    0] < 0 else var_shape[0]
+
+                numel = np.prod(var.shape)
+                assert numel % self.mp_degree == 0, \
+                    "The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree)
+
+                if 'subprog' in var.name:
+                    # For recompute, if the checkpoint var is layer_norm_6.tmp_2,
+                    # it will be sent twice: layer_norm_6.tmp_2 for the forward
+                    # pass and layer_norm_6.tmp_2.subprog_* for the recompute pass.
+                    # We keep the var received for the forward pass and copy its
+                    # value to the recompute var, saving one send/recv pair.
+                    # origin_name (layer_norm_6.tmp_2 here) is used to look up
+                    # the var kept for the forward pass.
+                    origin_name = var.name.split('subprog')[0][0:-1]
+                    associate_var = block.var(origin_name)
+                    block._insert_op_without_sync(
+                        index=index + extra_index_info['index'],
+                        type='assign',
+                        inputs={'X': [associate_var]},
+                        outputs={'Out': [var]},
+                        attrs={
+                            'out_shape': var_shape,
+                            'dtype': var.dtype,
+                            self._op_device_key: cur_dev,
+                            self._op_role_key: op_role,
+                            'use_calc_stream': True,
+                        })
+                    extra_index_info['index'] += 1
+                    return
+
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
                     type='c_sync_calc_stream',
@@ -4894,7 +4927,6 @@ class PipelineOptimizer(object):
                     })
                 extra_index_info['index'] += 1
                 insert_index = None
-
                 if int(op_role) == int(self._op_role.Backward):
                     insert_index = extra_index_info[
                         'first_optimize_index']
@@ -4902,7 +4934,6 @@ class PipelineOptimizer(object):
                 else:
                     insert_index = index
                     new_op_role = self._op_role.Backward
-
                 sync_comm_op = block._insert_op_without_sync(
                     index=insert_index + extra_index_info['index'],
                     type='c_sync_comm_stream',
@@ -4913,18 +4944,9 @@ class PipelineOptimizer(object):
                         self._op_role_key: new_op_role,
                         'ring_id': ring_id,
                     })
-
                 if int(op_role) == int(self._op_role.Forward):
                     sync_comm_op._set_attr('pipeline_flag', '')
                 extra_index_info['index'] += 1
-
-                var_shape = list(var.shape)
-                var_shape[0] = self.micro_batch_size if var_shape[
-                    0] < 0 else var_shape[0]
-
-                numel = np.prod(var.shape)
-                assert numel % self.mp_degree == 0, \
-                    "The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree)
                 block._insert_op_without_sync(
                     index=index + extra_index_info['index'],
                     type='recv_v2'
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index fcb2dbfa2ec0a22bfd31ebbfef581580d3b30618..3c5d4403e8b880f0bcd3c9365e3e6b9eea331248 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
 list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
 list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
@@ -56,6 +57,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer_with_recompute.py b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer_with_recompute.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67b26e0aef65a3e69002e99465bbca1281757fa
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer_with_recompute.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import os
+
+paddle.enable_static()
+
+
+class TestFleetMetaOptimizer(unittest.TestCase):
+    def setUp(self):
+        os.environ["PADDLE_TRAINER_ID"] = "1"
+        os.environ[
+            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
+
+    def test_pipeline_optimizer(self):
+        import paddle.distributed.fleet as fleet
+        import paddle.distributed.fleet.base.role_maker as role_maker
+        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+        fleet.init(role)
+        with paddle.fluid.device_guard("gpu:0"):
+            input_x = paddle.fluid.layers.data(
+                name="x", shape=[32], dtype='float32')
+            input_y = paddle.fluid.layers.data(
+                name="y", shape=[1], dtype='int64')
+            fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
+            fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
+            fc_3 = paddle.fluid.layers.fc(input=fc_2, size=64, act='tanh')
+            fc_4 = paddle.fluid.layers.fc(input=fc_3, size=64, act='tanh')
+            fc_5 = paddle.fluid.layers.fc(input=fc_4, size=64, act='tanh')
+            fc_6 = paddle.fluid.layers.fc(input=fc_5, size=64, act='tanh')
+
+        with paddle.fluid.device_guard("gpu:1"):
+            fc_7 = paddle.fluid.layers.fc(input=fc_6, size=64, act='tanh')
+            prediction = paddle.fluid.layers.fc(input=[fc_7],
+                                                size=2,
+                                                act='softmax')
+            cost = paddle.fluid.layers.cross_entropy(
+                input=prediction, label=input_y)
+            avg_cost = paddle.fluid.layers.mean(x=cost)
+
+        strategy = paddle.distributed.fleet.DistributedStrategy()
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            'micro_batch_size': 1,
+            'accumulate_steps': 2,
+            'schedule_mode': '1F1B'
+        }
+
+        checkpoints = ['fc_5.tmp_0', 'fc_7.tmp_0']
+        strategy.recompute = True
+        strategy.recompute_configs = {
+            "checkpoints": checkpoints,
+            "enable_offload": False,
+            "checkpoint_shape": []
+        }
+
+        optimizer = paddle.fluid.optimizer.Adam(0.01)
+        optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+        optimizer.minimize(avg_cost)
+
+
+if __name__ == "__main__":
+    unittest.main()
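
To make the recv-side dedup in the optimizer.py hunk concrete, here is a minimal standalone sketch of the decision it encodes. This is illustrative only: resolve_recv is a hypothetical helper, not a Paddle API, and the micro_batch_size/mp_degree values are assumptions (they happen to match the test's configuration).

    import numpy as np

    micro_batch_size = 1   # assumed value, matches the test's pipeline_configs
    mp_degree = 1          # assumed value; no tensor parallelism

    def resolve_recv(var_name, var_shape):
        # Mirror the patch: a -1 batch dim is replaced by the micro batch size.
        shape = list(var_shape)
        shape[0] = micro_batch_size if shape[0] < 0 else shape[0]
        numel = np.prod(var_shape)
        assert numel % mp_degree == 0, \
            "The numel={} must be divisible by mp_degree={}".format(numel, mp_degree)
        if 'subprog' in var_name:
            # Same string surgery as the patch: drop 'subprog_*' (and the
            # trailing '.') to recover the forward-pass checkpoint name.
            origin_name = var_name.split('subprog')[0][0:-1]
            return 'assign', origin_name, shape
        return 'recv_v2', var_name, shape

    print(resolve_recv('layer_norm_6.tmp_2', (-1, 64)))
    # -> ('recv_v2', 'layer_norm_6.tmp_2', [1, 64])
    print(resolve_recv('layer_norm_6.tmp_2.subprog_0', (-1, 64)))
    # -> ('assign', 'layer_norm_6.tmp_2', [1, 64])

The invariant this relies on is that a recompute tensor's name embeds its forward-pass origin, so the second transfer can be replaced by a local assign on the receiving device.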
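
The repeated extra_index_info['index'] += 1 bookkeeping in the hunk exists because every op spliced into the block shifts later insertion points one slot to the right. A toy sketch of that counter, using made-up op names rather than a real Paddle block:

    ops = ['send_v2', 'recv_v2']          # stand-in for the block's op list
    extra_index_info = {'index': 0}
    index = 1                             # original position of recv_v2

    for new_op in ['c_sync_calc_stream', 'assign']:
        ops.insert(index + extra_index_info['index'], new_op)
        extra_index_info['index'] += 1    # account for the op just added

    print(ops)
    # -> ['send_v2', 'c_sync_calc_stream', 'assign', 'recv_v2']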