From 02b0be083358b6ac5627b7fd324213b467cf007c Mon Sep 17 00:00:00 2001
From: WangXi
Date: Thu, 16 Sep 2021 10:51:58 +0800
Subject: [PATCH] [hybrid] remove scale op in insert_scale_loss_grad_ops
 (#35775)

---
 .../fleet/meta_optimizers/sharding/utils.py   |  15 ++-
 .../meta_optimizers/sharding_optimizer.py     |   2 +-
 python/paddle/fluid/optimizer.py              |  17 ++-
 .../test_fleet_hybrid_meta_optimizer.py       |  99 +++++++++-------
 .../test_fleet_sharding_meta_optimizer.py     | 108 +++++++++---------
 5 files changed, 125 insertions(+), 116 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
index d6acf541be5..0b8f67a0a7c 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
@@ -738,14 +738,13 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
     '''
     for idx, op in reversed(list(enumerate(block.ops))):
         if is_loss_grad_op(op):
-            loss_grad_var = block.vars[op.output_arg_names[0]]
-            block._insert_op_without_sync(
-                idx + 1,
-                type='scale',
-                inputs={'X': loss_grad_var},
-                outputs={'Out': loss_grad_var},
-                attrs={'scale': scale,
-                       OP_ROLE_KEY: OpRole.Backward})
+            assert op.type == 'fill_constant', \
+                "loss_grad_op must be fill_constant op, " \
+                "but this op is {}".format(op.type)
+            assert op.has_attr('value')
+            loss_scale = float(op.attr('value'))
+            loss_scale = loss_scale / scale
+            op._set_attr('value', loss_scale)
             break


diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 7f5c780f1f5..1f96ab07d60 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -455,7 +455,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         global_dp_degree = self.sharding_degree * self.dp_degree
         assert int(global_dp_degree) == global_dp_degree
         if global_dp_degree > 1:
-            insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
+            insert_scale_loss_grad_ops(main_block, scale=global_dp_degree)

         main_block._sync_with_cpp()

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index bd5a6a26cc8..ed351dcbefd 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -5019,16 +5019,13 @@ class PipelineOptimizer(object):
         if self._num_microbatches == 1: return
         for index, op in reversed(tuple(enumerate(list(block.ops)))):
             if self._is_loss_grad_op(op):
-                loss_grad_var = block.vars[op.output_arg_names[0]]
-                block._insert_op(
-                    index=index + 1,
-                    type='scale',
-                    inputs={'X': loss_grad_var},
-                    outputs={'Out': loss_grad_var},
-                    attrs={
-                        'scale': 1.0 / self._num_microbatches,
-                        self._op_role_key: self._op_role.Backward
-                    })
+                assert op.type == 'fill_constant', \
+                    "loss_grad_op must be fill_constant op, " \
+                    "but this op is {}".format(op.type)
+                assert op.has_attr('value')
+                loss_scale = float(op.attr('value'))
+                loss_scale = loss_scale / self._num_microbatches
+                op._set_attr('value', loss_scale)
                 break

     def _rename_gradient_var_name(self, block):
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
index a832bf8adfc..6de2d2fb092 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py
@@ -18,6 +18,7 @@ import paddle.static as static
 import unittest

 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op

 paddle.enable_static()

@@ -77,10 +78,10 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -158,10 +159,10 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
             'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
@@ -220,8 +221,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast',
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add',
             'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
-            'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
-            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'fill_constant', 'elementwise_mul_grad', 'mean_grad',
+            'cross_entropy_grad2', 'cast', 'softmax_grad',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
             'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
@@ -293,23 +294,23 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
             'tanh', 'cast', 'cast', 'mul', 'cast', 'elementwise_add',
             'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
             'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
-            'coalesce_tensor', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'c_sync_calc_stream', 'send_v2', 'cast', 'sum', 'cast', 'sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
-            'check_finite_and_unscale', 'cast', 'c_allreduce_max',
-            'c_allreduce_max', 'cast', 'update_loss_scaling', 'squared_l2_norm',
-            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm',
-            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
-            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
+            'send_v2', 'cast', 'sum', 'cast', 'sum', 'c_reduce_sum',
+            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
+            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
+            'update_loss_scaling', 'squared_l2_norm', 'squared_l2_norm',
+            'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
+            'c_allreduce_sum', 'c_allreduce_sum', 'sqrt', 'fill_constant',
+            'elementwise_max', 'elementwise_div',
             'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
-            'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
-            'momentum', 'momentum', 'momentum', 'coalesce_tensor',
-            'c_broadcast', 'coalesce_tensor', 'c_broadcast'
+            'elementwise_mul', 'momentum', 'momentum', 'momentum', 'momentum',
+            'momentum', 'coalesce_tensor', 'c_broadcast', 'coalesce_tensor',
+            'c_broadcast'
         ])


@@ -327,7 +328,10 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
         self._debug = False

     def test_opt_sharding_with_pp_amp_gclip_boundary(self):
-        """ test optimizer sharding without parameter """
+        """
+        test optimizer sharding without parameter
+        test loss grad scale value
+        """
         train_prog, startup_prog = static.Program(), static.Program()
         avg_cost, strategy = self.boundary_net(train_prog, startup_prog)

@@ -357,6 +361,16 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
         startup_prog_op_types = [op.type for op in startup_prog_ops]
         main_prog_op_types = [op.type for op in main_prog_ops]

+        # check loss scale for hybrid
+        for op in main_prog_ops:
+            if is_loss_grad_op(op):
+                self.assertEqual(op.type, 'fill_constant')
+                self.assertTrue(op.has_attr('value'))
+                scale = strategy.pipeline_configs[
+                    'accumulate_steps'] * strategy.sharding_configs['dp_degree']
+                loss_scale = 1.0 / scale
+                self.assertAlmostEqual(float(op.attr('value')), loss_scale)
+
         # global, sharding, pp_send, pp_recv
         self.assertEqual(startup_prog_op_types, [
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
@@ -367,14 +381,13 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):

         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'fill_constant', 'c_allreduce_sum', 'c_allreduce_sum', 'sqrt',
+            'fill_constant', 'elementwise_max', 'elementwise_div', 'c_broadcast'
         ])

     def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
@@ -419,14 +432,14 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):

         self.assertEqual(main_prog_op_types, [
             'recv_v2', 'cast', 'matmul', 'cast', 'reduce_mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'reduce_mean_grad', 'cast', 'matmul_grad',
-            'c_sync_calc_stream', 'send_v2', 'fill_constant', 'cast', 'sum',
-            'c_reduce_sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
-            'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'squared_l2_norm', 'sum', 'c_allreduce_sum',
-            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'elementwise_mul', 'momentum', 'c_broadcast'
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'reduce_mean_grad', 'cast', 'matmul_grad', 'c_sync_calc_stream',
+            'send_v2', 'fill_constant', 'cast', 'sum', 'c_reduce_sum',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'c_allreduce_max', 'cast', 'update_loss_scaling',
+            'squared_l2_norm', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
+            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'elementwise_mul', 'momentum', 'c_broadcast'
         ])

diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index 96230f6d274..1dd368f0848 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -16,12 +16,11 @@ import unittest
 import paddle
 import os
 import paddle.distributed.fleet as fleet
-import paddle.distributed.fleet.base.role_maker as role_maker
-import paddle.fluid.core as core
 import paddle.fluid as fluid

 from fleet_meta_optimizer_base import TestFleetMetaOptimizer
 import paddle.distributed.fleet.meta_optimizers.sharding as sharding
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op

 paddle.enable_static()
@@ -52,8 +51,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -91,16 +90,16 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_sync_comm_stream', 'cast', 'mul', 'elementwise_add', 'cast',
             'tanh', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
             'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
-            'mean', 'elementwise_mul', 'fill_constant', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
-            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
-            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
+            'mean', 'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-            'c_sync_comm_stream', 'cast', 'cast', 'cast',
-            'check_finite_and_unscale', 'cast', 'c_allreduce_max', 'cast',
-            'update_loss_scaling', 'momentum', 'momentum', 'momentum'
+            'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', 'cast',
+            'cast', 'cast', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
+            'momentum', 'momentum'
         ])

     def test_sharding_recompute_optimizer(self):
@@ -132,11 +131,11 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'mul', 'elementwise_add',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
             'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
-            'mul_grad', 'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
+            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
             'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
         ])
@@ -177,7 +176,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
             'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
             'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
-            'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
+            'fill_constant', 'elementwise_mul_grad', 'mean_grad',
             'cross_entropy_grad2', 'cast', 'softmax_grad',
             'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
             'elementwise_add', 'cast', 'tanh_grad', 'cast',
@@ -222,8 +221,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -259,8 +258,8 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -397,11 +396,14 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])

         # check loss scale for sharding hybrid dp
-        scale_ = -1
         for op in main_prog_ops:
-            if op.type == "scale":
-                scale_ = float(op.desc.attr("scale"))
-        self.assertEqual(scale_, 0.25)
+            if is_loss_grad_op(op):
+                self.assertEqual(op.type, 'fill_constant')
+                self.assertTrue(op.has_attr('value'))
+                scale = strategy.sharding_configs[
+                    'sharding_degree'] * strategy.sharding_configs['dp_degree']
+                loss_scale = 1.0 / scale
+                self.assertAlmostEqual(float(op.attr('value')), loss_scale)

         # check program (allreudce)
         ops = [op.type for op in main_prog_ops]
@@ -411,8 +413,8 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -474,8 +476,8 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
             'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
             'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -543,11 +545,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'c_broadcast', 'c_sync_comm_stream', 'recv_v2', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
-            'cross_entropy2', 'mean', 'fill_constant', 'scale', 'scale',
-            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
-            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
-            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
-            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
+            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
+            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
@@ -742,11 +743,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'mul', 'cast', 'elementwise_add', 'tanh', 'cast', 'mul', 'cast',
             'elementwise_add', 'tanh', 'cast', 'mul', 'cast', 'elementwise_add',
             'softmax', 'cross_entropy2', 'mean', 'elementwise_mul',
-            'fill_constant', 'scale', 'scale', 'elementwise_mul_grad',
-            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
-            'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad',
-            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
-            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'fill_constant', 'elementwise_mul_grad', 'mean_grad',
+            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
+            'cast', 'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
             'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
             'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
@@ -908,10 +908,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'send_v2', 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
@@ -1003,10 +1003,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
-            'elementwise_mul', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'send_v2', 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
@@ -1102,8 +1102,8 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
             'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
             'elementwise_mul', 'coalesce_tensor', 'coalesce_tensor',
-            'coalesce_tensor', 'coalesce_tensor', 'fill_constant', 'scale',
-            'scale', 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
+            'coalesce_tensor', 'coalesce_tensor', 'fill_constant',
+            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
             'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
             'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
@@ -1193,10 +1193,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
             'elementwise_add', 'tanh', 'cast', 'mul', 'cast', 'elementwise_add',
             'softmax', 'cross_entropy2', 'mean', 'elementwise_mul',
             'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
-            'coalesce_tensor', 'fill_constant', 'scale', 'scale',
-            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
-            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
+            'mean_grad', 'cross_entropy_grad2', 'softmax_grad',
+            'elementwise_add_grad', 'cast', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'tanh_grad',
             'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
             'send_v2', 'cast', 'sum', 'sum', 'c_allreduce_sum', 'c_allreduce_sum',
-- 
GitLab
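
Note on the technique (illustrative commentary, not part of the upstream commit): the backward pass seeds loss@GRAD with a fill_constant op whose value is 1.0, or the loss-scaling factor under AMP. Because that seed is a compile-time constant, dividing the op's 'value' attribute by the gradient-averaging factor is equivalent to the standalone scale op this patch removes, and it drops one op from every program, which is why the 'scale' entries vanish from the expected op lists in the tests above. Callers now pass the divisor itself (e.g. scale=global_dp_degree) and the helper divides internally, instead of passing the reciprocal. The sketch below is a minimal, self-contained rendering of the idea; Op, the list-based block, and is_loss_grad_op are hypothetical stand-ins that mimic only the attr()/_set_attr() surface the patch relies on, not Paddle's real Program API.

    # Standalone sketch of the rewrite performed by insert_scale_loss_grad_ops().
    # Op and is_loss_grad_op are hypothetical stand-ins for Paddle's op objects
    # and op-role check; only the attr()/_set_attr() surface is mimicked.


    class Op:
        def __init__(self, type, attrs):
            self.type = type
            self.attrs = attrs

        def has_attr(self, name):
            return name in self.attrs

        def attr(self, name):
            return self.attrs[name]

        def _set_attr(self, name, value):
            self.attrs[name] = value


    def is_loss_grad_op(op):
        # Stand-in predicate; Paddle inspects the op-role attribute instead.
        return op.attrs.get('is_loss_grad', False)


    def insert_scale_loss_grad_ops(ops, scale=1.0):
        # Fold 1/scale into the fill_constant that seeds loss@GRAD, instead
        # of inserting a separate `scale` op after it (the old behaviour).
        for op in reversed(ops):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', \
                    "loss_grad_op must be fill_constant op, " \
                    "but this op is {}".format(op.type)
                assert op.has_attr('value')
                op._set_attr('value', float(op.attr('value')) / scale)
                break


    if __name__ == '__main__':
        # The backward pass seeds loss@GRAD with 1.0; with sharding_degree=2
        # and dp_degree=2 the seed must become 1/4, as the unit tests assert.
        ops = [Op('fill_constant', {'value': 1.0, 'is_loss_grad': True})]
        insert_scale_loss_grad_ops(ops, scale=4.0)
        assert ops[0].attr('value') == 0.25
        print('loss@GRAD seed:', ops[0].attr('value'))

The asserts mirror the patch's own guards: folding the factor into an attribute is only sound while the loss-gradient seed really is a fill_constant, so anything else fails loudly rather than silently training with an unscaled gradient.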