diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 1de6d26d05b9e410b6e81f07c2f21829d8b8c54c..546b9d2601df57322c22d751c94ced76331c1eb6 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -42,6 +42,7 @@ message ShardingConfig {
   optional bool optimize_offload = 9 [ default = false ];
   optional bool pp_allreduce_in_optimize = 10 [ default = false ];
   optional int32 pp_degree = 11 [ default = 1 ];
+  optional bool optimize_cast = 12 [ default = false ];
 }
 
 message HybridConfig {
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index 051f6b11c2609a4ebebe77a4bdd3dbfec6ce3f51..d43292ddbd32e91a1df61f5b8f1fc93a7e33c31f 100644
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -888,6 +888,9 @@ class DistributedStrategy(object):
             pp_allreduce_in_optimize(bool, optional): [Hybrid parallelism ONLY] move the allreduce operations from backward stage to update(optimize) stage when pipeline parallelsim is on. This configuration will affect the communication speed of Hybrid parallelism training depeneded on network topology. this strategy is experimental by now.. Default is False.
+            optimize_cast(bool, optional): [Hybrid parallelism ONLY] Move the AMP cast ops that convert fp32 params to fp16 params into the optimize (update) stage. optimize_cast keeps a
+            persistent fp16 copy of each param, which takes more memory but runs faster, trading space for time. Recommended only when pipeline parallelism is used or
+            gradient_merge_acc_step is large. Default is False.
 
         Examples:
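The flag documented above is set through fleet's DistributedStrategy together with the other sharding options. A minimal sketch of how it might be enabled, mirroring the unit test added at the end of this diff (the model, optimizer, and fleet initialization around it are assumed to exist elsewhere):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.sharding = True
    strategy.sharding_configs = {
        "sharding_degree": 1,
        "mp_degree": 1,
        "pp_degree": 2,
        "dp_degree": 2,
        # keep a persistent fp16 copy of each param and re-cast it
        # only after the optimizer updates the fp32 master weights
        "optimize_cast": True,
    }
    strategy.pipeline = True
    strategy.pipeline_configs = {
        "schedule_mode": "1F1B",
        "micro_batch_size": 2,
        "accumulate_steps": 4,
    }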
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
index f6741b165ce07280aea7a95acebc13ac0a291704..a96705b09e835eb0e91a94a669bf3b6e18de2e61 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
+from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
 from paddle.fluid import core, unique_name
 
 __all__ = []
@@ -84,7 +84,7 @@ class OffloadHelper(object):
                 dtype=var.dtype,
                 persistable=True)
 
-    def offload_fp32param(self, block, startup_block):
+    def offload_fp32param(self, block, startup_block, offload=True):
         """
         (p_fp16) = cast(p)
         (p_fp16_recompute) = cast(p)
@@ -113,11 +113,12 @@
 
         # step1: record param
         for idx, op in reversed(list(enumerate(block.ops))):
-            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+            if is_update_op(op):
                 param = op.desc.input("Param")[0]
                 param_to_idx[param] = idx
 
-        # step2: remove param which can't offload
+        # step2: remove params which can't be offloaded and record the
+        # param->fp16param and fp16param->recompute_var mappings
         for idx, op in enumerate(block.ops):
             if is_optimizer_op(op):
                 break
@@ -125,7 +126,7 @@
                 if input_name not in param_to_idx:
                     continue
 
-                # param is real used by fp32 op
+                # param is used directly by a fp32 op
                 if op.type != 'cast':
                     remove_param(input_name)
                     continue
@@ -154,17 +155,19 @@
         # step3: main_block add offload, cast op
         # change recompute to fp16, remove cast(param) to fp16
         for idx, op in reversed(list(enumerate(block.ops))):
-            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
+            if is_update_op(op):
                 param = op.desc.input("Param")[0]
                 if param not in param_to_idx: continue
 
                 # step3.1: create offload_var
                 offload_var_name = self._get_offload_var_name(param)
                 param_name_to_offload_name[param] = offload_var_name
-                self._create_offload_var(param, offload_var_name,
-                                         [block, startup_block])
+                if offload:
+                    self._create_offload_var(param, offload_var_name,
+                                             [block, startup_block])
 
-                # step3.2: insert cast op and offload op
-                self._insert_offload_op(block, idx + 1, param, offload_var_name)
+                    # step3.2: insert cast op and offload op
+                    self._insert_offload_op(block, idx + 1, param,
+                                            offload_var_name)
 
                 assert param in param_to_fp16
                 fp16_param_name = param_to_fp16[param]
@@ -173,8 +176,9 @@
                 self._insert_cast_op(block, idx + 1, param,
                                      param_to_fp16[param])
 
-                # step3.3: insert fetch op
-                self._insert_fetch_op(block, idx, offload_var_name, param)
+                if offload:
+                    # step3.3: insert fetch op
+                    self._insert_fetch_op(block, idx, offload_var_name, param)
                 continue
 
             # step3.4: remove cast op
@@ -206,9 +210,10 @@
             if out_name in param_name_to_offload_name:
                 var_name = out_name
-                offload_var_name = param_name_to_offload_name[var_name]
-                self._insert_offload_op(startup_block, idx + 1, var_name,
-                                        offload_var_name)
+                if offload:
+                    offload_var_name = param_name_to_offload_name[var_name]
+                    self._insert_offload_op(startup_block, idx + 1,
+                                            var_name, offload_var_name)
                 self._insert_cast_op(startup_block, idx + 1, var_name,
                                      param_to_fp16[var_name])
 
@@ -217,6 +222,19 @@
         block._sync_with_cpp()
         startup_block._sync_with_cpp()
 
+    def cast_fp32param_in_optimize(self, block, startup_block):
+        """
+        (p_fp16) = cast(p)
+        (p_fp16_recompute) = cast(p)
+        (pout,) = adam(p)
+        ===========================>
+        rename(p_fp16_recompute, p_fp16)
+
+        (pout,) = adam(p)
+        (p_fp16) = cast(p)
+        """
+        self.offload_fp32param(block, startup_block, offload=False)
+
     def offload(self, block, startup_block):
         """
         (m1, m2) = prefetch(m1@offload, m2@offload)
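For intuition, the transformation performed by cast_fp32param_in_optimize (and by offload_fp32param with offload=False) can be sketched outside of Paddle: instead of re-casting the fp32 master weights to fp16 before every forward pass, a persistent fp16 copy is kept and refreshed once per optimizer step. The snippet below is only a framework-free illustration of that trade-off; the class and method names are made up and are not part of Paddle's API:

    import numpy as np

    class Fp16ParamCache:
        """fp32 master weights plus a persistent fp16 copy of them."""

        def __init__(self, shape):
            self.master = np.random.rand(*shape).astype(np.float32)
            # the persistent fp16 param costs extra memory ...
            self.fp16 = self.master.astype(np.float16)

        def forward(self, x):
            # ... but the forward pass reads the cached copy directly,
            # so no fp32->fp16 cast is needed on this path
            return x.astype(np.float16) @ self.fp16

        def update(self, grad, lr=0.01):
            # the optimizer still updates the fp32 master weights, and the
            # cast runs once here, after the update, which is where the
            # pass moves the cast op
            self.master -= lr * grad.astype(np.float32)
            self.fp16 = self.master.astype(np.float16)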
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index 93901b38873b95691696f08cf3ef30e935142c25..5c2f24054f835cda631d50c278fee9ab7d657777 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -400,7 +400,14 @@ class ShardingOptimizer(MetaOptimizerBase):
             logger.info("Sharding with optimize offload !")
             offload_helper = OffloadHelper()
             offload_helper.offload(main_block, startup_block)
+            # optimize_cast is already included in offload_fp32param
             offload_helper.offload_fp32param(main_block, startup_block)
+        elif sharding_configs['optimize_cast']:
+            logger.info("Sharding with optimize cast !")
+            # NOTE(wangxi): optimize_cast persists the fp16 params, which
+            # takes more memory but runs faster. Trade space for time.
+            offload_helper = OffloadHelper()
+            offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
 
     def _dump_program_for_debug(self):
         main_block = self._main_program.global_block()
@@ -444,6 +451,7 @@
         # loss div dp_degree
         self._insert_loss_grad_scale_op()
 
+        # apply optimize offload or optimize cast
         self._apply_optimize_offload_pass()
 
         # step6: (optional) sharding gradient merge
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
index d70a58c7d8ab41b0b1907c8543841aaf343911d0..5a981a470cb4ef6f742577950190cf375b0f1c6d 100755
--- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
@@ -859,6 +859,197 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
 
         self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
 
+    def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        strategy.amp = True
+        strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": 2,
+            "dp_degree": 2,
+            "optimize_cast": True,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        strategy.fp16_allreduce = True
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
+            'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
+            'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_sync_comm_stream'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'scale', 'scale',
+            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
+            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
+            'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
+            'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
+            'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
+            'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', 'cast',
+            'momentum', 'cast', 'momentum', 'cast', 'momentum', 'momentum',
+            'cast'
+        ])
+
+        # amp check_finite_and_unscale, allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
+
+        # should have ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of dp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_3":
+                dp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
+
+    def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        strategy.amp = True
+        strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 1,
+            "pp_degree": 2,
+            "dp_degree": 2,
+            "optimize_offload": True,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        strategy.fp16_allreduce = True
+        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
+            'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
+            'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
+            'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
+            'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
+        ])
+
+        self.assertEqual(main_prog_op_types, [
+            'recv_v2', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', 'mul',
+            'cast', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
+            'elementwise_mul', 'fill_constant', 'scale', 'scale',
+            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2',
+            'softmax_grad', 'elementwise_add_grad', 'cast', 'mul_grad',
+            'tanh_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
+            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'send_v2',
+            'fill_constant', 'cast', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'fill_constant', 'sum', 'fill_constant',
+            'sum', 'fill_constant', 'sum', 'fill_constant', 'sum',
+            'fill_constant', 'sum', 'coalesce_tensor', 'c_allreduce_sum',
+            'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast', 'cast',
+            'c_sync_comm_stream', 'check_finite_and_unscale', 'cast',
+            'c_allreduce_max', 'cast', 'update_loss_scaling', 'memcpy',
+            'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
+            'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'memcpy',
+            'momentum', 'cast', 'memcpy', 'memcpy', 'momentum', 'cast',
+            'memcpy', 'memcpy', 'momentum', 'cast', 'memcpy', 'momentum',
+            'memcpy', 'momentum', 'cast', 'memcpy'
+        ])
+
+        # amp check_finite_and_unscale, allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 1)
+
+        # should have ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+        self.assertIn(self.dp_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of dp group
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_3":
+                dp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
+
 
 if __name__ == "__main__":
     unittest.main()
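The two new tests differ in a way that is easy to spot in the expected op lists: with optimize_cast the update stage interleaves momentum and cast ops only, while optimize_offload additionally emits memcpy ops that fetch and offload the fp32 master params. When debugging a program built with either configuration, a small helper along these lines (the function name is illustrative, not part of the change) can confirm which pass ran:

    from collections import Counter

    def count_op_types(program):
        # Count op types in the global block of the compiled program.
        # With optimize_cast there should be no 'memcpy' ops; with
        # optimize_offload the fetch/offload of fp32 master params
        # shows up as 'memcpy' ops around each 'momentum' update.
        return Counter(op.type for op in program.global_block().ops)

    # e.g. count_op_types(train_prog)['memcpy'] == 0 for optimize_cast,
    # and > 0 when optimize_offload is enabled (see the lists above).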