未验证 提交 67c63639 编写于 作者: W WangXi 提交者: GitHub

[hybrid] fix pipeline section program Parameter (#35847)

上级 5ba9fe6e
......@@ -213,6 +213,7 @@ class OffloadHelper(object):
if out_name in param_name_to_offload_name:
var_name = out_name
# FIXME(wangxi): the offload op should be inserted after the param is broadcast
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
......
......@@ -1380,10 +1380,18 @@ class ShardingOptimizer(MetaOptimizerBase):
return
startup_block = self._startup_program.global_block()
params = []
for param in startup_block.iter_parameters():
params.append(param)
params = startup_block.all_parameters()
broadcast_params = []
for param in params:
broadcast_params.append(param)
# optimize_cast needs the fp16 copy of the param to be broadcast as well
fp16_param_name = param.name + '.cast_fp16'
if startup_block.has_var(fp16_param_name):
fp16_param = startup_block.var(fp16_param_name)
broadcast_params.append(fp16_param)
for param in broadcast_params:
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
......@@ -1395,8 +1403,8 @@ class ShardingOptimizer(MetaOptimizerBase):
})
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
inputs={'X': broadcast_params},
outputs={'Out': broadcast_params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
......
......@@ -4381,6 +4381,18 @@ class PipelineOptimizer(object):
name=var,
type=core.VarDesc.VarType.READER,
persistable=source_var.persistable)
elif isinstance(source_var, Parameter):
dest_var = block.create_parameter(
name=source_var.name,
shape=source_var.shape,
dtype=source_var.dtype,
type=source_var.type,
lod_level=source_var.lod_level,
stop_gradient=source_var.stop_gradient,
trainable=source_var.trainable,
optimize_attr=source_var.optimize_attr,
regularizer=source_var.regularizer,
error_clip=source_var.error_clip)
else:
dest_var = block._clone_variable(source_var, False)
self._clone_var_attr(dest_var, source_var)
......
......@@ -71,6 +71,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
])
......@@ -152,6 +154,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
])
......@@ -212,7 +216,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -284,7 +290,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -376,7 +384,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -427,7 +435,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......
......@@ -762,7 +762,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -928,7 +930,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -1023,7 +1028,11 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -1121,7 +1130,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......@@ -1211,7 +1223,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])
self.assertEqual(main_prog_op_types, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册