Unverified commit eef0a943, authored by WangXi, committed by GitHub

[hybrid] optimizer sharding support optimize cast (#35878)

Parent d5268a6e
......@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name
from .shard import Shard
__all__ = []
......@@ -23,11 +25,8 @@ class OffloadHelper(object):
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def __init__(self, ring_id=None):
self.ring_id = ring_id
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
......@@ -50,6 +49,21 @@ class OffloadHelper(object):
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_broadcast_op(self, block, idx, param):
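# NOTE: ring_id is None when dp_degree <= 1 (see sharding.py below), i.e.
# there is no data-parallel group to broadcast within, so the op is skipped.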
if self.ring_id is None:
return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
......@@ -206,6 +220,8 @@ class OffloadHelper(object):
# step5: add offload ops to startup_block
visited_vars = set()
# FIXME(wangxi): should insert at idx; need to move comm init to the head first.
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
......@@ -213,13 +229,16 @@ class OffloadHelper(object):
if out_name in param_name_to_offload_name:
var_name = out_name
# FIXME(wangxi): offload should be inserted after the param broadcast
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
self._insert_offload_op(startup_block, insert_idx,
var_name, offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should be inserted after the param broadcast.
# The inserted op order is: broadcast, cast, offload.
self._insert_broadcast_op(startup_block, insert_idx,
var_name)
visited_vars.add(out_name)
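# The resulting tail of startup_block is a repeated
# 'c_broadcast, cast[, memcpy]' pattern per parameter, as asserted by the
# updated unit tests below.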
......@@ -303,3 +322,181 @@ class OffloadHelper(object):
block._sync_with_cpp()
startup_block._sync_with_cpp()
def opt_sharding_cast_fp32param(self,
block,
startup_block,
params,
offload=False):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(pout,) = adam(p)
(p_fp16) = cast(p)
broadcast(p_fp16)
"""
global_params = set()
local_params = set()
param_to_fp16 = dict()
# recompute vars which need to be renamed to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
global_params.remove(input_name)
if input_name in local_params:
local_params.remove(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
global_params = set(params)
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
local_params.add(param)
# step2: remove params which can't be cast/offloaded, and
# record param->fp16param, fp16param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
# TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
if op.type == 'coalesce_tensor':
continue
for input_name in op.desc.input_arg_names():
if input_name not in global_params:
continue
# param is consumed by a non-cast (fp32) op, so it cannot be converted
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by cast ops,
# which cast fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: add offload and cast ops to main_block;
# switch recompute vars to fp16 and remove the cast(param)-to-fp16 ops
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in global_params:
continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in global_params:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: remove fp32 params which are no longer needed
for idx, op in enumerate(block.ops):
if op.type not in ['coalesce_tensor', 'c_broadcast']:
continue
for input_name in op.desc.input_arg_names():
if input_name in param_to_fp16:
op._rename_input(input_name, param_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in param_to_fp16:
op._rename_output(output_name, param_to_fp16[output_name])
for param in global_params:
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
if param not in local_params:
block._remove_var(param, sync=False)
# step6: add offload/cast/broadcast ops to startup_block
visited_vars = set()
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars: continue
if out_name in param_to_fp16:
var_name = out_name
if offload:
self._insert_offload_op(
startup_block, idx + 1, var_name,
param_name_to_offload_name[var_name])
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
self._insert_broadcast_op(startup_block, insert_idx,
var_name)
if var_name not in local_params:
param = startup_block.var(out_name)
param.persistable = False
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
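# A minimal sketch (helper variables such as dp_ring_id, shard and
# params_grads are assumed to be set up as in sharding.py further below) of
# how this pass is driven under optimizer sharding: run the cast/rename
# rewrite first, then fuse the per-param c_broadcast ops it leaves behind.
offload_helper = OffloadHelper(ring_id=dp_ring_id)
offload_helper.opt_sharding_cast_fp32param(
    main_block, startup_block, [x[0].name for x in params_grads])
fuse_opt_broadcast_param_ops(
    main_block, dp_ring_id, shard, strategy=strategy)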
......@@ -14,7 +14,7 @@
import paddle
from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
import re
......@@ -366,6 +366,24 @@ def insert_allreduce_ops(block,
class FuseHelper(object):
@staticmethod
def sort_vars_by_dtype(block, vars_name):
fp32_vars = []
fp16_vars = []
other_vars = []
for var in vars_name:
dtype = block.var(var).dtype
if dtype == paddle.float32:
fp32_vars.append(var)
elif dtype == paddle.float16:
fp16_vars.append(var)
else:
other_vars.append(var)
assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse"
fp32_vars.extend(fp16_vars)
return fp32_vars
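# For example (hypothetical var names), ['w0_fp32', 'b0_fp16', 'w1_fp32'] is
# returned as ['w0_fp32', 'w1_fp32', 'b0_fp16'], so that vars of the same
# dtype end up adjacent when the fused groups are built below.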
@staticmethod
def get_fused_groups(block, vars_name, fuse_size=32.):
""" coalesce tensor, get fused group """
......@@ -639,6 +657,54 @@ def insert_broadcast_param_ops(block,
return param_in_this_device
def fuse_opt_broadcast_param_ops(block,
ring_id,
shard,
op_role=OpRole.Optimize,
strategy=None):
"""
fuse optimizer sharding broadcast param ops
"""
if strategy is None or not strategy.fuse_all_reduce_ops:
return
fuse_size = strategy.fuse_grad_size_in_MB
nranks = shard.worker_num
device_to_vars = [[] for _ in range(nranks)]
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op) or op.type != 'c_broadcast':
break
var = op.input_arg_names[0]
root_id = op.attr('root')
device_to_vars[root_id].insert(0, var)
block._remove_op(idx, sync=False)
insert_idx = idx + 1
for root_id, vars_name in enumerate(device_to_vars):
vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)
fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Param")
for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_broadcast',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'root': root_id,
'use_calc_stream': True,
OP_ROLE_KEY: op_role
})
block._sync_with_cpp()
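# Net effect: for each root rank, the per-param c_broadcast ops emitted by
# opt_sharding_cast_fp32param are replaced by one coalesce_tensor plus one
# c_broadcast per fused group (vars sorted by dtype, roughly fuse_size MB per
# group), i.e. the 'coalesce_tensor, c_broadcast' pattern checked by the new
# unit test below.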
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
......
......@@ -329,6 +329,7 @@ class ShardingOptimizer(MetaOptimizerBase):
if self.pp_degree == 1: return
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -399,6 +400,8 @@ class ShardingOptimizer(MetaOptimizerBase):
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)
# NOTE(wangxi): when optimize_cast is on, broadcast fusion is deferred until after the optimize_cast pass
optimize_cast = sharding_configs['optimize_cast']
optimizer_param = utils.insert_broadcast_param_ops(
main_block,
len_of_ops,
......@@ -407,10 +410,10 @@ class ShardingOptimizer(MetaOptimizerBase):
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=strategy)
strategy=None if optimize_cast else strategy)
logger.info("Optimizer param in this rank {}".format(
optimizer_param))
if not strategy.fuse_grad_merge:
if not strategy.fuse_grad_merge and not optimize_cast:
assert len(accumulated_grad_names) == len(optimizer_param)
elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
insert_allreduce_ops(
......@@ -458,18 +461,20 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._sync_with_cpp()
def _apply_optimize_offload_pass(self):
def _apply_optimize_offload_pass(self, params_grads):
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
# optimize offload should only be enabled when gradient merge is enabled and
# acc_step is quite large (e.g. >> 100), since its memcpy cannot be
# overlapped with computation and would otherwise slow down training severely.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper = OffloadHelper(ring_id=dp_ring_id)
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
......@@ -477,8 +482,17 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast will persist the fp16 param, which
# takes more memory but is faster. Trade space for time.
offload_helper = OffloadHelper()
offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
offload_helper = OffloadHelper(ring_id=dp_ring_id)
if self._optimizer_sharding:
offload_helper.opt_sharding_cast_fp32param(
main_block, startup_block,
[x[0].name for x in params_grads])
# NOTE(wangxi): fuse the broadcasts after optimize_cast
utils.fuse_opt_broadcast_param_ops(
main_block, dp_ring_id, self._shard, strategy=strategy)
else:
offload_helper.cast_fp32param_in_optimize(main_block,
startup_block)
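# A sketch of the user-facing switches that select these passes (keys taken
# from the unit test below; defaults may differ):
#
#     strategy.sharding_configs = {
#         "sharding_degree": 1,
#         "pp_degree": 2,
#         "dp_degree": 2,
#         "_dp_as_optimizer_sharding": True,   # optimizer sharding over dp
#         "optimize_cast": True,               # or "optimize_offload": True
#     }
#     strategy.fuse_all_reduce_ops = True      # enables fuse_opt_broadcast_param_ops
#     strategy.fuse_grad_size_in_MB = 32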
def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
......@@ -525,7 +539,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self._insert_loss_grad_scale_op()
# apply optimize offload or optimize cast
self._apply_optimize_offload_pass()
self._apply_optimize_offload_pass(params_grads)
# step6: (optional) sharding gradient merge
self._sharding_gradient_merge()
......@@ -1381,17 +1395,50 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()
params_name = []
broadcast_params = []
# NOTE(wangxi): if a param is not persistable, program.clone will
# fail, so we remove the non-persistable param and re-add it as a plain var
for param in params:
broadcast_params.append(param)
# optimize_cast needs to broadcast the fp16 param
fp16_param_name = param.name + '.cast_fp16'
if startup_block.has_var(fp16_param_name):
fp16_param = startup_block.var(fp16_param_name)
broadcast_params.append(fp16_param)
for param in broadcast_params:
params_name.append(param.name)
if not param.persistable:
name = param.name
shape = param.shape
dtype = param.dtype
type = param.type
lod_level = param.lod_level
stop_gradient = param.stop_gradient
trainable = param.trainable
optimize_attr = param.optimize_attr
regularizer = param.regularizer
have_dist_attr = False
is_distributed = False
if hasattr(param, 'is_distributed'):
have_dist_attr = True
is_distributed = param.is_distributed
startup_block._remove_var(name, sync=False)
var = startup_block.create_var(
name=name,
shape=shape,
dtype=dtype,
type=type,
lod_level=lod_level,
stop_gradient=stop_gradient,
trainable=trainable,
persistable=False)
if have_dist_attr:
var.is_distributed = is_distributed
# offload and optimize_cast may have already inserted broadcast ops
broadcast_params = set()
for op in startup_block.ops:
if op.type == 'c_broadcast':
broadcast_params.add(op.desc.output_arg_names()[0])
for param in params_name:
if param in broadcast_params: continue
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
......@@ -1399,15 +1446,19 @@ class ShardingOptimizer(MetaOptimizerBase):
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': broadcast_params},
outputs={'Out': broadcast_params},
inputs={'X': params_name},
outputs={'Out': params_name},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
startup_block._sync_with_cpp()
# sharding gradient merge
def create_persistable_gradients_and_insert_merge_ops(
self, main_block, startup_block, insert_idx, grad_names, shard):
......
......@@ -321,6 +321,82 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_broadcast'
])
def test_opt_sharding_with_pp_amp_ckp_fuse_gm_optcast(self):
train_prog, startup_prog = static.Program(), static.Program()
avg_cost, strategy = self.pp_net(train_prog, startup_prog)
self.set_strategy(strategy, 'pipeline')
self.set_strategy(strategy, 'amp')
strategy.amp_configs = {'custom_black_varnames': ['fc_6.b_0'], }
strategy.recompute = True
strategy.recompute_configs = {
"checkpoints":
["fc_0.tmp_2", "fc_1.tmp_2", "fc_2.tmp_2", "fc_3.tmp_2"]
}
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 1,
"pp_degree": 2,
"dp_degree": 2,
"_dp_as_optimizer_sharding": True,
'optimize_cast': True,
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
strategy.fuse_grad_merge = True
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
train_prog = train_prog._pipeline_opt['section_program']
startup_prog = startup_prog._pipeline_opt['startup_program']
# self._debug = True
self.debug_program(train_prog, startup_prog)
startup_prog_ops = startup_prog.global_block().ops
main_prog_ops = train_prog.global_block().ops
# check program
startup_prog_op_types = [op.type for op in startup_prog_ops]
main_prog_op_types = [op.type for op in main_prog_ops]
# global, sharding, pp_send, pp_recv
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
'recv_v2', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'coalesce_tensor', 'coalesce_tensor',
'coalesce_tensor', 'coalesce_tensor', 'coalesce_tensor',
'coalesce_tensor', 'fill_constant', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad', 'cast',
'elementwise_add_grad', 'cast', 'mul_grad', 'cast', 'tanh_grad',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
'send_v2', 'cast', 'sum', 'sum', 'cast', 'sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'check_finite_and_unscale', 'cast', 'c_allreduce_max',
'c_allreduce_max', 'cast', 'update_loss_scaling', 'momentum',
'cast', 'momentum', 'cast', 'momentum', 'cast', 'momentum',
'momentum', 'cast', 'coalesce_tensor', 'c_broadcast', 'c_broadcast',
'coalesce_tensor', 'c_broadcast'
])
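# With optimize_cast, the startup program interleaves c_broadcast and cast
# per parameter, and the main program ends with fused
# 'coalesce_tensor, c_broadcast' groups instead of one c_broadcast per param.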
class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
def setUp(self):
......
......@@ -922,18 +922,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -1019,19 +1018,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'memcpy', 'fill_constant', 'cast',
'memcpy', 'uniform_random', 'cast', 'memcpy', 'fill_constant',
'cast', 'memcpy', 'uniform_random', 'cast', 'memcpy',
'fill_constant', 'cast', 'memcpy', 'uniform_random', 'cast',
'memcpy', 'fill_constant', 'fill_constant', 'fill_constant',
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'c_sync_comm_stream'
])
......@@ -1122,18 +1119,17 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
# ring: mp, pp_group, pp_pair, pp_pair
self.assertEqual(startup_prog_op_types, [
'uniform_random', 'cast', 'fill_constant', 'cast', 'uniform_random',
'cast', 'fill_constant', 'cast', 'uniform_random', 'cast',
'fill_constant', 'cast', 'uniform_random', 'cast', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'uniform_random', 'fill_constant', 'uniform_random',
'fill_constant', 'uniform_random', 'fill_constant',
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......