未验证 提交 3a5f1f22 编写于 作者: Y Yuang Liu 提交者: GitHub

[hybird optim] reduce rend/recv times for recompute, test=develop (#34248)

上级 7f2b5be3
......@@ -4867,6 +4867,39 @@ class PipelineOptimizer(object):
})
extra_index_info['index'] += 1
elif self.schedule_mode == '1F1B': # 1F1B
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
numel = np.prod(var.shape)
assert numel % self.mp_degree == 0, \
"The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree)
if 'subprog' in var.name:
# For recompute, if the checkpoints var is layer_norm_6.tmp_2
# this var will be sent twice, layer_norm_6.tmp_2 for forward pass,
# layer_norm_6.tmp_2.subprog_* for recompute pass.
# We can store the first sent var and copy the value to the
# second one to reduce one send/recv op.
# The origin_ckpt_name is layer_norm_6.tmp_2, which will be used
# to find the stored var for the forward pass.
origin_name = var.name.split('subprog')[0][0:-1]
associate_var = block.var(origin_name)
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='assign',
inputs={'X': [associate_var]},
outputs={'Out': [var]},
attrs={
'out_shape': var_shape,
'dtype': var.dtype,
self._op_device_key: cur_dev,
self._op_role_key: op_role,
'use_calc_stream': True,
})
extra_index_info['index'] += 1
return
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='c_sync_calc_stream',
......@@ -4894,7 +4927,6 @@ class PipelineOptimizer(object):
})
extra_index_info['index'] += 1
insert_index = None
if int(op_role) == int(self._op_role.Backward):
insert_index = extra_index_info[
'first_optimize_index']
......@@ -4902,7 +4934,6 @@ class PipelineOptimizer(object):
else:
insert_index = index
new_op_role = self._op_role.Backward
sync_comm_op = block._insert_op_without_sync(
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
......@@ -4913,18 +4944,9 @@ class PipelineOptimizer(object):
self._op_role_key: new_op_role,
'ring_id': ring_id,
})
if int(op_role) == int(self._op_role.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
extra_index_info['index'] += 1
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
numel = np.prod(var.shape)
assert numel % self.mp_degree == 0, \
"The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree)
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='recv_v2'
......
......@@ -17,6 +17,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
list(APPEND DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
list(APPEND DIST_TEST_OPS test_gen_nccl_id_op)
......@@ -56,6 +57,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_2)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_base_3)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer_with_recompute)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_raw_program_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
paddle.enable_static()
class TestFleetMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
def test_pipeline_optimizer(self):
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
with paddle.fluid.device_guard("gpu:0"):
input_x = paddle.fluid.layers.data(
name="x", shape=[32], dtype='float32')
input_y = paddle.fluid.layers.data(
name="y", shape=[1], dtype='int64')
fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh')
fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh')
fc_3 = paddle.fluid.layers.fc(input=fc_2, size=64, act='tanh')
fc_4 = paddle.fluid.layers.fc(input=fc_3, size=64, act='tanh')
fc_5 = paddle.fluid.layers.fc(input=fc_4, size=64, act='tanh')
fc_6 = paddle.fluid.layers.fc(input=fc_5, size=64, act='tanh')
with paddle.fluid.device_guard("gpu:1"):
fc_7 = paddle.fluid.layers.fc(input=fc_6, size=64, act='tanh')
prediction = paddle.fluid.layers.fc(input=[fc_7],
size=2,
act='softmax')
cost = paddle.fluid.layers.cross_entropy(
input=prediction, label=input_y)
avg_cost = paddle.fluid.layers.mean(x=cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.pipeline = True
strategy.pipeline_configs = {
'micro_batch_size': 1,
'accumulate_steps': 2,
'schedule_mode': '1F1B'
}
checkpoints = ['fc_5.tmp_0', 'fc_7.tmp_0']
strategy.recompute = True
strategy.recompute_configs = {
"checkpoints": checkpoints,
"enable_offload": False,
"checkpoint_shape": []
}
optimizer = paddle.fluid.optimizer.Adam(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册