From 911c859337aef85f2b9efae5bd0e767ec17e36f1 Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 5 Aug 2021 15:33:31 +0800 Subject: [PATCH] optimize pipeline performance with recompute and amp, test=allcase (#34519) --- python/paddle/fluid/backward.py | 7 ++ .../contrib/mixed_precision/fp16_lists.py | 2 + .../contrib/mixed_precision/fp16_utils.py | 23 ++++++- .../test_fleet_pipeline_meta_optimizer.py | 66 ++++++++++++++++--- 4 files changed, 87 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index ee1bc73c61..8bf27f6d2f 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -945,6 +945,13 @@ def _append_backward_ops_with_checkpoints_( for op_desc in reversed(added_descs): grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op_desc, cpt.to_text(no_grad_dict[block.idx]), []) + + # Set device for grad_op according to forward Op + if op_desc.has_attr(device_attr_name): + op_device = op_desc.attr(device_attr_name) + for g_op_desc in grad_op_desc: + g_op_desc._set_attr(device_attr_name, op_device) + for key in var_name_dict: _rename_arg_(grad_op_desc, key, var_name_dict[key]) grad_op_descs.extend(grad_op_desc) diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 37fe1e505f..703146736e 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -150,6 +150,8 @@ gray_list = { 'c_identity', 'c_concat', 'c_allreduce_sum', + 'concat', + 'split', } # The set of ops that don't support fp16 calculation diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index 16dfb2bd50..5978d3829a 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -110,6 +110,27 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) out_var = block.vars.get(cast_name) if out_var is None or out_var.dtype != dest_dtype: + op_device = op.attr('op_device') + # NOTE(wangxi): optimize for pipeline, reduce one send. + # if in_var is stop_gradient and prev_op device is `all`, + # set cast_op device to `all`, can reduce send cast_var. + # TODO: need remove this after we unified the dynamic + # and static pipeline interface. + if src_dtype == core.VarDesc.VarType.FP32 and in_var.stop_gradient: + prev_op = None + if in_var.op is op: + prev_op = find_true_prev_op(block.ops, op, + in_var_name) + elif in_var.op is not None: + prev_op = in_var.op + + prev_op_device = None + if prev_op is not None: + prev_op_device = prev_op.attr('op_device') + + if prev_op_device is not None and 'all' in prev_op_device: + op_device = prev_op_device + out_var = block.create_var( name=cast_name, dtype=dest_dtype, @@ -124,7 +145,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): attrs={ "in_dtype": in_var.dtype, "out_dtype": out_var.dtype, - "op_device": op.attr("op_device") + "op_device": op_device }) num_cast_ops += 1 _rename_arg(op, in_var.name, out_var.name) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py index a9c37d7853..3f8d994ad1 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py @@ -14,6 +14,10 @@ import unittest import paddle +import paddle.fluid as fluid +import paddle.static as static +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker import os paddle.enable_static() @@ -25,26 +29,34 @@ class TestFleetMetaOptimizer(unittest.TestCase): os.environ[ "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002" - def test_pipeline_optimizer(self): - import paddle.distributed.fleet as fleet - import paddle.distributed.fleet.base.role_maker as role_maker - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - with paddle.fluid.device_guard("gpu:0"): + def net(self): + with static.device_guard("gpu:0"): input_x = paddle.fluid.layers.data( name="x", shape=[32], dtype='float32') input_y = paddle.fluid.layers.data( name="y", shape=[1], dtype='int64') + input_z = paddle.fluid.layers.data( + name="z", shape=[1], dtype="float32") + with static.device_guard("gpu:all"): + input_z = input_z * 1.0 + input_z.stop_gradient = True fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_1 = fc_1 * input_z - with paddle.fluid.device_guard("gpu:1"): + with static.device_guard("gpu:1"): fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + fc_2 = fc_2 * input_z prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') cost = paddle.fluid.layers.cross_entropy( input=prediction, label=input_y) avg_cost = paddle.fluid.layers.mean(x=cost) + return avg_cost + + def test_pipeline_optimizer(self): + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) strategy = paddle.distributed.fleet.DistributedStrategy() strategy.pipeline = True @@ -53,9 +65,43 @@ class TestFleetMetaOptimizer(unittest.TestCase): 'accumulate_steps': 2 } - optimizer = paddle.fluid.optimizer.Adam(0.01) - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) - optimizer.minimize(avg_cost) + train_prog, startup_prog = static.Program(), static.Program() + with static.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + avg_cost = self.net() + + optimizer = paddle.fluid.optimizer.Adam(0.01) + optimizer = fleet.distributed_optimizer( + optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + def test_pipeline_amp_optimizer(self): + """ test pipeline& with device:all """ + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.amp = True + strategy.pipeline = True + strategy.pipeline_configs = { + 'micro_batch_size': 1, + 'accumulate_steps': 2 + } + + train_prog, startup_prog = static.Program(), static.Program() + with static.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + avg_cost = self.net() + + optimizer = paddle.fluid.optimizer.Adam(0.01) + optimizer = fleet.distributed_optimizer( + optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + ops = train_prog._pipeline_opt['section_program'].global_block().ops + ops = [op.type for op in ops] + self.assertEqual(ops.count('send_v2'), 1) + self.assertEqual(ops.count('recv_v2'), 1) if __name__ == "__main__": -- GitLab