From 6fb58aef574d8336510cdbc1730016112ad699f7 Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 31 Aug 2021 13:27:21 +0800 Subject: [PATCH] [cherry-pick][Hybrid Performance] Move the cast op of AMP which cast fp32 param to fp16 param to the optimizer (#34965) (#35296) Co-authored-by: WangXi --- .../framework/distributed_strategy.proto | 1 + .../fleet/base/distributed_strategy.py | 3 + .../sharding/offload_helper.py | 48 +++-- .../meta_optimizers/sharding_optimizer.py | 8 + .../test_fleet_sharding_meta_optimizer.py | 191 ++++++++++++++++++ 5 files changed, 236 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 1de6d26d05b..546b9d2601d 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -42,6 +42,7 @@ message ShardingConfig { optional bool optimize_offload = 9 [ default = false ]; optional bool pp_allreduce_in_optimize = 10 [ default = false ]; optional int32 pp_degree = 11 [ default = 1 ]; + optional bool optimize_cast = 12 [ default = false ]; } message HybridConfig { diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 051f6b11c26..d43292ddbd3 100644 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -888,6 +888,9 @@ class DistributedStrategy(object): pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on. This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False. + optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the cast op of AMP which cast fp32 param to fp16 param to optimizer. optimize_cast will persist fp16 param, it + will take more memory, but will be faster, trade space for time. Recommend to turn on only when using pipeline or gradient_merge_acc_step large. + Examples: diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py index f6741b165ce..a96705b09e8 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole +from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op from paddle.fluid import core, unique_name __all__ = [] @@ -84,7 +84,7 @@ class OffloadHelper(object): dtype=var.dtype, persistable=True) - def offload_fp32param(self, block, startup_block): + def offload_fp32param(self, block, startup_block, offload=True): """ (p_fp16) = cast(p) (p_fp16_recompute) = cast(p) @@ -113,11 +113,12 @@ class OffloadHelper(object): # step1: record param for idx, op in reversed(list(enumerate(block.ops))): - if op.type in ('adam', 'momentum', 'lars', 'lamb'): + if is_update_op(op): param = op.desc.input("Param")[0] param_to_idx[param] = idx - # step2: remove param which can't offload + # step2: remove param which can't offload and + # record param->fp16param, fp16param->recompute_var for idx, op in enumerate(block.ops): if is_optimizer_op(op): break @@ -125,7 +126,7 @@ class OffloadHelper(object): if input_name not in param_to_idx: continue - # param is real used by fp32 op + # param which will be used by fp32 op if op.type != 'cast': remove_param(input_name) continue @@ -154,17 +155,19 @@ class OffloadHelper(object): # step3: main_block add offload, cast op # change recompute to fp16, remove cast(param) to fp16 for idx, op in reversed(list(enumerate(block.ops))): - if op.type in ('adam', 'momentum', 'lars', 'lamb'): + if is_update_op(op): param = op.desc.input("Param")[0] if param not in param_to_idx: continue # step3.1: create offload_var offload_var_name = self._get_offload_var_name(param) param_name_to_offload_name[param] = offload_var_name - self._create_offload_var(param, offload_var_name, - [block, startup_block]) + if offload: + self._create_offload_var(param, offload_var_name, + [block, startup_block]) - # step3.2: insert cast op and offload op - self._insert_offload_op(block, idx + 1, param, offload_var_name) + # step3.2: insert cast op and offload op + self._insert_offload_op(block, idx + 1, param, + offload_var_name) assert param in param_to_fp16 fp16_param_name = param_to_fp16[param] @@ -173,8 +176,9 @@ class OffloadHelper(object): self._insert_cast_op(block, idx + 1, param, param_to_fp16[param]) - # step3.3: insert fetch op - self._insert_fetch_op(block, idx, offload_var_name, param) + if offload: + # step3.3: insert fetch op + self._insert_fetch_op(block, idx, offload_var_name, param) continue # step3.4: remove cast op @@ -206,9 +210,10 @@ class OffloadHelper(object): if out_name in param_name_to_offload_name: var_name = out_name - offload_var_name = param_name_to_offload_name[var_name] - self._insert_offload_op(startup_block, idx + 1, var_name, - offload_var_name) + if offload: + offload_var_name = param_name_to_offload_name[var_name] + self._insert_offload_op(startup_block, idx + 1, + var_name, offload_var_name) self._insert_cast_op(startup_block, idx + 1, var_name, param_to_fp16[var_name]) @@ -217,6 +222,19 @@ class OffloadHelper(object): block._sync_with_cpp() startup_block._sync_with_cpp() + def cast_fp32param_in_optimize(self, block, startup_block): + """ + (p_fp16) = cast(p) + (p_fp16_recompute) = cast(p) + (pout,) = adam(p) + ===========================> + rename(p_fp16_recompute, p_fp16) + + (pout,) = adam(p) + (p_fp16) = cast(p) + """ + self.offload_fp32param(block, startup_block, offload=False) + def offload(self, block, startup_block): """ (m1, m2) = prefetch(m1@offload, m2@offload) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 93901b38873..5c2f24054f8 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -400,7 +400,14 @@ class ShardingOptimizer(MetaOptimizerBase): logger.info("Sharding with optimize offload !") offload_helper = OffloadHelper() offload_helper.offload(main_block, startup_block) + # The optimize_cast is already included in offload_fp32param offload_helper.offload_fp32param(main_block, startup_block) + elif sharding_configs['optimize_cast']: + logger.info("Sharding with optimize cast !") + # NOTE(wangxi): optimize_cast will persist fp16 param, it + # will take more memory, but will be faster. Trade space for time. + offload_helper = OffloadHelper() + offload_helper.cast_fp32param_in_optimize(main_block, startup_block) def _dump_program_for_debug(self): main_block = self._main_program.global_block() @@ -444,6 +451,7 @@ class ShardingOptimizer(MetaOptimizerBase): # loss div dp_degree self._insert_loss_grad_scale_op() + # apply optimize offload or optimize cast self._apply_optimize_offload_pass() # step6: (optional) sharding gradient merge diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index d70a58c7d8a..5a981a470cb 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -859,6 +859,197 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002']) + def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + "optimize_cast": True, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random', + 'cast', 'fill_constant', 'cast', 'uniform_random', 'cast', + 'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul', + 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul', + 'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'elementwise_mul', 'fill_constant', 'scale', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad', + 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2', + 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', + 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum', + 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', + 'c_sync_comm_stream', 'check_finite_and_unscale', 'cast', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', + 'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast', + 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum', + 'cast' + ]) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + self.assertIn(self.dp_ring_id, created_ring_ids) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_3": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + + def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + strategy.amp = True + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "mp_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + "optimize_offload": True, + } + strategy.pipeline = True + strategy.pipeline_configs = { + "schedule_mode": "1F1B", + "micro_batch_size": 2, + "accumulate_steps": 4, + } + strategy.fp16_allreduce = True + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # ring: mp, pp_group, pp_pair, pp_pair + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast', + 'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant', + 'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy', + 'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast', + 'memcpy', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', + 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul', + 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul', + 'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'elementwise_mul', 'fill_constant', 'scale', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad', + 'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2', + 'fill_constant', 'cast', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant', + 'sum', 'fill_constant', 'sum', 'fill_constant', 'sum', + 'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum', + 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', + 'c_sync_comm_stream', 'check_finite_and_unscale', 'cast', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy', + 'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast', + 'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy', + 'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast', + 'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum', + 'memcpy', 'momentum', 'cast', 'memcpy' + ]) + + # amp check_finite_and_unscale, allreduce(pp) + self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1) + + # should has ring id for pp + created_ring_ids = [ + op.desc.attr("ring_id") for op in startup_prog_ops + if op.type == "c_comm_init" + ] + self.assertIn(self.pp_pair_ring_id, created_ring_ids) + self.assertIn(self.dp_ring_id, created_ring_ids) + + # check correctness of pp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_0": + pp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003']) + + # check correctness of dp group + for op in startup_prog_ops: + if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[ + 0] == "comm_id_3": + dp_group_waiting_ports = op.desc.attr("other_endpoints") + + self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002']) + if __name__ == "__main__": unittest.main() -- GitLab