diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
index 3816e9b3051abfc026dd426d6988537d64185de0..3ad6e320316c61ed1b74829b6074685874eb61fc 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
 from paddle.fluid import core, unique_name
+from .shard import Shard

 __all__ = []

@@ -23,11 +25,8 @@ class OffloadHelper(object):
     cuda_place_type = 1
     cuda_pinned_place_type = 2

-    def __init__(self):
-        pass
-        "0: dst is on CPUPlace. "
-        "1: dst is on CUDAPlace. "
-        "2: dst is on CUDAPinnedPlace. "
+    def __init__(self, ring_id=None):
+        self.ring_id = ring_id

     def _insert_cast_op(self, block, idx, src_name, dst_name):
         src_var = block.var(src_name)
@@ -50,6 +49,21 @@ class OffloadHelper(object):
                 OP_ROLE_KEY: OpRole.Optimize
             })

+    def _insert_broadcast_op(self, block, idx, param):
+        if self.ring_id is None:
+            return
+        block._insert_op_without_sync(
+            idx,
+            type="c_broadcast",
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={
+                'ring_id': self.ring_id,
+                'root': 0,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Forward,
+            })
+
     def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
         src_var = block.var(src_name)
         dst_var = block.var(dst_name)
@@ -206,6 +220,8 @@ class OffloadHelper(object):

         # step5: startup_block add offload
         visited_vars = set()
+        # FIXME(wangxi): should insert at idx; need to move comm init to the head.
+        insert_idx = len(startup_block.ops)
         for idx, op in reversed(list(enumerate(startup_block.ops))):
             for out_name in op.output_arg_names:
                 if out_name in visited_vars:
@@ -213,13 +229,16 @@

                 if out_name in param_name_to_offload_name:
                     var_name = out_name
-                    # FIXME(wangxi): offload should insert after broadcast param
                     if offload:
                         offload_var_name = param_name_to_offload_name[var_name]
-                        self._insert_offload_op(startup_block, idx + 1,
+                        self._insert_offload_op(startup_block, insert_idx,
                                                 var_name, offload_var_name)
-                    self._insert_cast_op(startup_block, idx + 1, var_name,
+                    self._insert_cast_op(startup_block, insert_idx, var_name,
                                          param_to_fp16[var_name])
+                    # NOTE(wangxi): cast and offload should be inserted after the param broadcast.
+                    # The insert order is: broadcast, cast, offload.
+                    self._insert_broadcast_op(startup_block, insert_idx,
+                                              var_name)

                 visited_vars.add(out_name)

@@ -303,3 +322,181 @@ class OffloadHelper(object):

         block._sync_with_cpp()
         startup_block._sync_with_cpp()
+
+    def opt_sharding_cast_fp32param(self,
+                                    block,
+                                    startup_block,
+                                    params,
+                                    offload=False):
+        """
+        (p_fp16) = cast(p)
+        (p_fp16_recompute) = cast(p)
+        (pout,) = adam(p)
+        ===========================>
+        rename(p_fp16_recompute, p_fp16)
+
+        (pout,) = adam(p)
+        (p_fp16) = cast(p)
+        broadcast(p_fp16)
+        """
+        global_params = set()
+        local_params = set()
+        param_to_fp16 = dict()
+        # recompute vars which need to be renamed to fp16 params
+        fp16_param_to_recompute = dict()
+        recompute_to_fp16 = dict()
+
+        def remove_param(input_name):
+            global_params.remove(input_name)
+            if input_name in local_params:
+                local_params.remove(input_name)
+            if input_name in param_to_fp16:
+                fp16_param = param_to_fp16.pop(input_name)
+                if fp16_param in fp16_param_to_recompute:
+                    recompute = fp16_param_to_recompute.pop(fp16_param)
+                    recompute_to_fp16.pop(recompute)
+
+        # step1: record param
+        global_params = set(params)
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_update_op(op):
+                param = op.desc.input("Param")[0]
+                local_params.add(param)
+
+        # step2: remove params which can't be offloaded, and
+        # record param->fp16param, fp16param->recompute_var
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                break
+            # TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
+            if op.type == 'coalesce_tensor':
+                continue
+            for input_name in op.desc.input_arg_names():
+                if input_name not in global_params:
+                    continue
+
+                # param which will be used by a fp32 op
+                if op.type != 'cast':
+                    remove_param(input_name)
+                    continue
+
+                # param is only used by a cast op,
+                # which casts fp32_param to fp16_param
+                output_name = op.output_arg_names[0]
+                if 'cast_fp16' not in output_name:
+                    remove_param(input_name)
+                    continue
+
+                if 'subprog' not in output_name:
+                    assert output_name == input_name + '.cast_fp16'
+                    assert input_name not in param_to_fp16, \
+                        "There must be only one cast op from fp32 param to fp16 param."
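For context on the rewrite sketched in the docstring above: the pass drops the pre-optimizer casts of a sharded fp32 param, renames the recompute cast output to the persistent fp16 name, and re-inserts a cast plus a c_broadcast right after the optimizer update. A minimal, self-contained illustration on toy (op_type, inputs, outputs) tuples might look like the following; this is not the real Paddle block/desc API, and rewrite_param_cast is a made-up helper name used only for this sketch.

    # Toy sketch of the rewrite described in the docstring; plain tuples, not Paddle ops.
    def rewrite_param_cast(ops, param, fp16_name, recompute_name):
        new_ops = []
        for op_type, ins, outs in ops:
            if op_type == 'cast' and ins == [param]:
                # drop pre-optimizer casts of the fp32 param
                continue
            # rename the recompute cast output to the persistent fp16 name
            ins = [fp16_name if n == recompute_name else n for n in ins]
            outs = [fp16_name if n == recompute_name else n for n in outs]
            new_ops.append((op_type, ins, outs))
            if op_type == 'adam' and param in ins:
                # re-cast the updated fp32 param and broadcast the fp16 copy
                new_ops.append(('cast', [param], [fp16_name]))
                new_ops.append(('c_broadcast', [fp16_name], [fp16_name]))
        return new_ops

    ops = [
        ('cast', ['p'], ['p.cast_fp16']),
        ('matmul', ['x', 'p.cast_fp16'], ['y']),
        ('cast', ['p'], ['p.cast_fp16.subprog_0']),  # recompute copy
        ('matmul_grad', ['x', 'p.cast_fp16.subprog_0'], ['p@GRAD']),
        ('adam', ['p', 'p@GRAD'], ['p']),
    ]
    for op in rewrite_param_cast(ops, 'p', 'p.cast_fp16', 'p.cast_fp16.subprog_0'):
        print(op)

Running the sketch prints the forward/backward ops reading a single fp16 copy, followed by adam, cast, c_broadcast, which is the post-rewrite order shown in the docstring.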
+ param_to_fp16[input_name] = output_name + else: + # fp16-->recompute_var + assert input_name in param_to_fp16, \ + "param must first be cast to fp16" + fp16_param = param_to_fp16[input_name] + fp16_param_to_recompute[fp16_param] = output_name + recompute_to_fp16[output_name] = fp16_param + + param_name_to_offload_name = dict() + # step3: main_block add offload, cast op + # change recompute to fp16, remove cast(param) to fp16 + for idx, op in reversed(list(enumerate(block.ops))): + if is_update_op(op): + param = op.desc.input("Param")[0] + if param not in global_params: + continue + # step3.1: create offload_var + offload_var_name = self._get_offload_var_name(param) + param_name_to_offload_name[param] = offload_var_name + if offload: + self._create_offload_var(param, offload_var_name, + [block, startup_block]) + + # step3.2: insert cast op and offload op + self._insert_offload_op(block, idx + 1, param, + offload_var_name) + + assert param in param_to_fp16 + fp16_param_name = param_to_fp16[param] + fp16_param_var = block.var(fp16_param_name) + fp16_param_var.persistable = True + self._insert_cast_op(block, idx + 1, param, + param_to_fp16[param]) + + if offload: + # step3.3: insert fetch op + self._insert_fetch_op(block, idx, offload_var_name, param) + + continue + + # step3.4: remove cast op + if op.type == 'cast': + input_name = op.desc.input_arg_names()[0] + if input_name in global_params: + block._remove_op(idx, sync=False) + continue + + # step3.5: change recompute_param to fp16_param + for input_name in op.desc.input_arg_names(): + if input_name in recompute_to_fp16: + op._rename_input(input_name, recompute_to_fp16[input_name]) + for output_name in op.desc.output_arg_names(): + if output_name in recompute_to_fp16: + op._rename_output(output_name, + recompute_to_fp16[output_name]) + + # step4: remove recompute_param + for name in recompute_to_fp16.keys(): + block._remove_var(name, sync=False) + + # step5: remove fp32 param which not need + for idx, op in enumerate(block.ops): + if op.type not in ['coalesce_tensor', 'c_broadcast']: + continue + for input_name in op.desc.input_arg_names(): + if input_name in param_to_fp16: + op._rename_input(input_name, param_to_fp16[input_name]) + for output_name in op.desc.output_arg_names(): + if output_name in param_to_fp16: + op._rename_output(output_name, param_to_fp16[output_name]) + + for param in global_params: + assert param in param_to_fp16 + fp16_param_name = param_to_fp16[param] + fp16_param_var = block.var(fp16_param_name) + fp16_param_var.persistable = True + + if param not in local_params: + block._remove_var(param, sync=False) + + # step6: startup_block add offload + visited_vars = set() + insert_idx = len(startup_block.ops) + for idx, op in reversed(list(enumerate(startup_block.ops))): + for out_name in op.output_arg_names: + if out_name in visited_vars: continue + + if out_name in param_to_fp16: + var_name = out_name + if offload: + self._insert_offload_op( + startup_block, idx + 1, var_name, + param_name_to_offload_name[var_name]) + + self._insert_cast_op(startup_block, insert_idx, var_name, + param_to_fp16[var_name]) + + self._insert_broadcast_op(startup_block, insert_idx, + var_name) + + if var_name not in local_params: + param = startup_block.var(out_name) + param.persistable = False + + visited_vars.add(out_name) + + block._sync_with_cpp() + startup_block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index 
0b8f67a0a7cd9f141fcbaf9c447d8f64e8451e69..447b52ace697878cf2a4d3425e5e2c99d79c073d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -14,7 +14,7 @@ import paddle from paddle.fluid import core, unique_name from functools import reduce -from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op +from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY import re @@ -366,6 +366,24 @@ def insert_allreduce_ops(block, class FuseHelper(object): + @staticmethod + def sort_vars_by_dtype(block, vars_name): + fp32_vars = [] + fp16_vars = [] + other_vars = [] + for var in vars_name: + dtype = block.var(var).dtype + if dtype == paddle.float32: + fp32_vars.append(var) + elif dtype == paddle.float16: + fp16_vars.append(var) + else: + other_vars.append(var) + assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse" + + fp32_vars.extend(fp16_vars) + return fp32_vars + @staticmethod def get_fused_groups(block, vars_name, fuse_size=32.): """ coalesce tensor, get fused group """ @@ -639,6 +657,54 @@ def insert_broadcast_param_ops(block, return param_in_this_device +def fuse_opt_broadcast_param_ops(block, + ring_id, + shard, + op_role=OpRole.Optimize, + strategy=None): + """ + fuse optimizer sharding broadcast param ops + """ + if strategy is None or not strategy.fuse_all_reduce_ops: + return + + fuse_size = strategy.fuse_grad_size_in_MB + + nranks = shard.worker_num + device_to_vars = [[] for _ in range(nranks)] + + for idx, op in reversed(list(enumerate(block.ops))): + if not is_optimizer_op(op) or op.type != 'c_broadcast': + break + var = op.input_arg_names[0] + root_id = op.attr('root') + device_to_vars[root_id].insert(0, var) + block._remove_op(idx, sync=False) + + insert_idx = idx + 1 + for root_id, vars_name in enumerate(device_to_vars): + vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name) + groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size) + + fused_vars, insert_num = FuseHelper.insert_coalesce_tensor( + block, insert_idx, groups, op_role, prefix="Param") + + for fused_var in fused_vars: + block._insert_op_without_sync( + insert_idx + insert_num, + type='c_broadcast', + inputs={'X': fused_var}, + outputs={'Out': fused_var}, + attrs={ + 'ring_id': ring_id, + 'root': root_id, + 'use_calc_stream': True, + OP_ROLE_KEY: op_role + }) + + block._sync_with_cpp() + + def get_grad_device(grad_name, shard): assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( grad_name) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 1af646b3959e014a99713ccb984e1fc12a320c7e..75a69e5527bc18e71eb9286ce1bda60c0aeaaf1d 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -329,6 +329,7 @@ class ShardingOptimizer(MetaOptimizerBase): if self.pp_degree == 1: return strategy = self.user_defined_strategy + sharding_configs = strategy.sharding_configs main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() @@ -399,6 +400,8 @@ class ShardingOptimizer(MetaOptimizerBase): first_optimize_op_index += (len(main_block.ops) - len_of_ops) len_of_ops = 
len(main_block.ops) + # NOTE(wangxi): we fused after optimize_cast + optimize_cast = sharding_configs['optimize_cast'] optimizer_param = utils.insert_broadcast_param_ops( main_block, len_of_ops, @@ -407,10 +410,10 @@ class ShardingOptimizer(MetaOptimizerBase): OpRole.Optimize, use_calc_stream=True, rank=self.dp_rank, - strategy=strategy) + strategy=None if optimize_cast else strategy) logger.info("Optimizer param in this rank {}".format( optimizer_param)) - if not strategy.fuse_grad_merge: + if not strategy.fuse_grad_merge and not optimize_cast: assert len(accumulated_grad_names) == len(optimizer_param) elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp": insert_allreduce_ops( @@ -458,18 +461,20 @@ class ShardingOptimizer(MetaOptimizerBase): main_block._sync_with_cpp() - def _apply_optimize_offload_pass(self): + def _apply_optimize_offload_pass(self, params_grads): strategy = self.user_defined_strategy sharding_configs = strategy.sharding_configs main_block = self._main_program.global_block() startup_block = self._startup_program.global_block() + dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None + # optimize offload should be enable while gradient merge is enable and # acc_step is quite large (e.g. >> 100). Since its memcpy could not be # overlap with calc, otherwise it will slower down training severely. if sharding_configs["optimize_offload"]: logger.info("Sharding with optimize offload !") - offload_helper = OffloadHelper() + offload_helper = OffloadHelper(ring_id=dp_ring_id) offload_helper.offload(main_block, startup_block) # The optimize_cast is already included in offload_fp32param offload_helper.offload_fp32param(main_block, startup_block) @@ -477,8 +482,17 @@ class ShardingOptimizer(MetaOptimizerBase): logger.info("Sharding with optimize cast !") # NOTE(wangxi): optimize_cast will persist fp16 param, it # will take more memory, but will be faster. Trade space for time. 
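Background for the fused broadcast path wired up just below (utils.fuse_opt_broadcast_param_ops together with the FuseHelper.sort_vars_by_dtype / get_fused_groups helpers added above): params are ordered fp32-first, then packed into groups capped at roughly fuse_grad_size_in_MB, and each group gets one coalesce_tensor plus one c_broadcast. The following is a rough sketch under assumed sizes and a simplified grouping rule; fused_broadcast_groups is an illustrative name, not the actual FuseHelper implementation.

    # Toy grouping sketch: fp32 params first, then fp16, packed into groups of
    # at most fuse_size_mb MB; a new group also starts when the dtype changes.
    def fused_broadcast_groups(params, fuse_size_mb=32.0):
        """params: list of (name, numel, dtype) with dtype in {'fp32', 'fp16'}."""
        bytes_per = {'fp32': 4, 'fp16': 2}
        ordered = ([p for p in params if p[2] == 'fp32'] +
                   [p for p in params if p[2] == 'fp16'])
        groups, cur, cur_mb = [], [], 0.0
        for name, numel, dtype in ordered:
            size_mb = numel * bytes_per[dtype] / (1024.0 * 1024.0)
            if cur and (cur_mb + size_mb > fuse_size_mb or dtype != cur[-1][2]):
                groups.append(cur)
                cur, cur_mb = [], 0.0
            cur.append((name, numel, dtype))
            cur_mb += size_mb
        if cur:
            groups.append(cur)
        return groups  # one coalesce_tensor + one c_broadcast per group

    print(fused_broadcast_groups([
        ('fc_0.w_0', 6 * 1024 * 1024, 'fp32'),
        ('fc_1.w_0', 4 * 1024 * 1024, 'fp16'),
        ('fc_2.w_0', 4 * 1024 * 1024, 'fp16'),
    ]))

Each resulting group corresponds to one coalesce_tensor followed by a c_broadcast on the dp ring in the rewritten program.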
- offload_helper = OffloadHelper() - offload_helper.cast_fp32param_in_optimize(main_block, startup_block) + offload_helper = OffloadHelper(ring_id=dp_ring_id) + if self._optimizer_sharding: + offload_helper.opt_sharding_cast_fp32param( + main_block, startup_block, + [x[0].name for x in params_grads]) + # NOTE(wangxi): fused after optimize_cast + utils.fuse_opt_broadcast_param_ops( + main_block, dp_ring_id, self._shard, strategy=strategy) + else: + offload_helper.cast_fp32param_in_optimize(main_block, + startup_block) def _dump_program_for_debug(self): main_block = self._main_program.global_block() @@ -525,7 +539,7 @@ class ShardingOptimizer(MetaOptimizerBase): self._insert_loss_grad_scale_op() # apply optimize offload or optimize cast - self._apply_optimize_offload_pass() + self._apply_optimize_offload_pass(params_grads) # step6: (optional) sharding gradient merge self._sharding_gradient_merge() @@ -1381,17 +1395,50 @@ class ShardingOptimizer(MetaOptimizerBase): startup_block = self._startup_program.global_block() params = startup_block.all_parameters() + params_name = [] - broadcast_params = [] + # NOTE(wangxi): if param is not persistable, program.clone will + # failed, so we remove no persistable param, re add param as a var for param in params: - broadcast_params.append(param) - # optimize_cast need broadcast fp16 param - fp16_param_name = param.name + '.cast_fp16' - if startup_block.has_var(fp16_param_name): - fp16_param = startup_block.var(fp16_param_name) - broadcast_params.append(fp16_param) - - for param in broadcast_params: + params_name.append(param.name) + if not param.persistable: + name = param.name + shape = param.shape + dtype = param.dtype + type = param.type + lod_level = param.lod_level + stop_gradient = param.stop_gradient + trainable = param.trainable + optimize_attr = param.optimize_attr + regularizer = param.regularizer + + have_dist_attr = False + is_distributed = False + if hasattr(param, 'is_distributed'): + have_dist_attr = True + is_distributed = param.is_distributed + + startup_block._remove_var(name, sync=False) + var = startup_block.create_var( + name=name, + shape=shape, + dtype=dtype, + type=type, + lod_level=lod_level, + stop_gradient=stop_gradient, + trainable=trainable, + persistable=False) + if have_dist_attr: + var.is_distributed = is_distributed + + # offload and optimize_cast will insert broadcast op + broadcast_params = set() + for op in startup_block.ops: + if op.type == 'c_broadcast': + broadcast_params.add(op.desc.output_arg_names()[0]) + + for param in params_name: + if param in broadcast_params: continue startup_block.append_op( type='c_broadcast', inputs={'X': param}, @@ -1399,15 +1446,19 @@ class ShardingOptimizer(MetaOptimizerBase): attrs={ 'ring_id': self.dp_ring_id, 'root': 0, + 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Forward }) + startup_block.append_op( type='c_sync_comm_stream', - inputs={'X': broadcast_params}, - outputs={'Out': broadcast_params}, + inputs={'X': params_name}, + outputs={'Out': params_name}, attrs={'ring_id': self.dp_ring_id, OP_ROLE_KEY: OpRole.Forward}) + startup_block._sync_with_cpp() + # sharding gradient merge def create_persistable_gradients_and_insert_merge_ops( self, main_block, startup_block, insert_idx, grad_names, shard): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py index db8689c14c30f3785a1e38593acbb09756f7692f..6eb566935d9d52ced1444bad96dd16df94832fc0 100755 --- 
a/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_hybrid_meta_optimizer.py @@ -321,6 +321,82 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer): 'c_broadcast' ]) + def test_opt_sharding_with_pp_amp_ckp_fuse_gm_optcast(self): + train_prog, startup_prog = static.Program(), static.Program() + avg_cost, strategy = self.pp_net(train_prog, startup_prog) + + self.set_strategy(strategy, 'pipeline') + self.set_strategy(strategy, 'amp') + strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], } + strategy.recompute = True + strategy.recompute_configs = { + "checkpoints": + ["fc_0.tmp_2", "fc_1.tmp_2", "fc_2.tmp_2", "fc_3.tmp_2"] + } + + strategy.sharding = True + strategy.sharding_configs = { + "sharding_degree": 1, + "pp_degree": 2, + "dp_degree": 2, + "_dp_as_optimizer_sharding": True, + 'optimize_cast': True, + } + strategy.fuse_all_reduce_ops = True + strategy.fuse_grad_size_in_MB = 32 + strategy.fuse_grad_merge = True + + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + train_prog = train_prog._pipeline_opt['section_program'] + startup_prog = startup_prog._pipeline_opt['startup_program'] + + # self._debug = True + self.debug_program(train_prog, startup_prog) + + startup_prog_ops = startup_prog.global_block().ops + main_prog_ops = train_prog.global_block().ops + + # check program + startup_prog_op_types = [op.type for op in startup_prog_ops] + main_prog_op_types = [op.type for op in main_prog_ops] + + # global, sharding, pp_send, pp_recv + self.assertEqual(startup_prog_op_types, [ + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', + 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', + 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', + 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', + 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', + 'cast', 'c_broadcast', 'c_sync_comm_stream' + ]) + + self.assertEqual(main_prog_op_types, [ + 'recv_v2', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', + 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'softmax', 'cast', 'cross_entropy2', + 'mean', 'elementwise_mul', 'coalesce_tensor', 'coalesce_tensor', + 'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor', + 'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad', + 'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad', 'cast', + 'elementwise_add_grad', 'cast', 'mul_grad', 'cast', 'tanh_grad', + 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', + 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream', + 'send_v2', 'cast', 'sum', 'sum', 'cast', 'sum', 'c_reduce_sum', + 'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream', + 'check_finite_and_unscale', 'cast', 'c_allreduce_max', + 'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum', + 'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum', + 'momentum', 'cast', 'coalesce_tensor', 'c_broadcast', 'c_broadcast', + 'coalesce_tensor', 'c_broadcast' + ]) + class 
TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 61d98d32ec5fd7f2512a54ef998ae7b8ef392f2e..73eacd118ecad506aa993e65d76f70f3177b3d26 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -922,18 +922,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): # ring: mp, pp_group, pp_pair, pp_pair self.assertEqual(startup_prog_op_types, [ - 'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random', - 'cast', 'fill_constant', 'cast', 'uniform_random', 'cast', - 'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant', - 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', - 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream' + 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast', + 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', + 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', + 'c_broadcast', 'c_sync_comm_stream' ]) self.assertEqual(main_prog_op_types, [ @@ -1019,19 +1018,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): # ring: mp, pp_group, pp_pair, pp_pair self.assertEqual(startup_prog_op_types, [ - 'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast', - 'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant', - 'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy', - 'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast', - 'memcpy', 'fill_constant', 'fill_constant', 'fill_constant', + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', - 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast', 'memcpy', + 'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy', + 'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy', + 'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'c_sync_comm_stream' ]) @@ -1122,18 +1119,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): # 
ring: mp, pp_group, pp_pair, pp_pair self.assertEqual(startup_prog_op_types, [ - 'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random', - 'cast', 'fill_constant', 'cast', 'uniform_random', 'cast', - 'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant', - 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'uniform_random', 'fill_constant', 'uniform_random', + 'fill_constant', 'uniform_random', 'fill_constant', + 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', - 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream' + 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast', + 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', + 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', + 'c_broadcast', 'c_sync_comm_stream' ]) self.assertEqual(main_prog_op_types, [