未验证 提交 67c63639 编写于 作者: W WangXi 提交者: GitHub

[hybird] fix pipeline section program Parameter (#35847)

上级 5ba9fe6e
...@@ -213,6 +213,7 @@ class OffloadHelper(object): ...@@ -213,6 +213,7 @@ class OffloadHelper(object):
if out_name in param_name_to_offload_name: if out_name in param_name_to_offload_name:
var_name = out_name var_name = out_name
# FIXME(wangxi): offload should insert after broadcast param
if offload: if offload:
offload_var_name = param_name_to_offload_name[var_name] offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, self._insert_offload_op(startup_block, idx + 1,
......
...@@ -1380,10 +1380,18 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1380,10 +1380,18 @@ class ShardingOptimizer(MetaOptimizerBase):
return return
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()
params = []
for param in startup_block.iter_parameters(): broadcast_params = []
params.append(param) for param in params:
broadcast_params.append(param)
# optimize_cast need broadcast fp16 param
fp16_param_name = param.name + '.cast_fp16'
if startup_block.has_var(fp16_param_name):
fp16_param = startup_block.var(fp16_param_name)
broadcast_params.append(fp16_param)
for param in broadcast_params:
startup_block.append_op( startup_block.append_op(
type='c_broadcast', type='c_broadcast',
inputs={'X': param}, inputs={'X': param},
...@@ -1395,8 +1403,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -1395,8 +1403,8 @@ class ShardingOptimizer(MetaOptimizerBase):
}) })
startup_block.append_op( startup_block.append_op(
type='c_sync_comm_stream', type='c_sync_comm_stream',
inputs={'X': params}, inputs={'X': broadcast_params},
outputs={'Out': params}, outputs={'Out': broadcast_params},
attrs={'ring_id': self.dp_ring_id, attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward}) OP_ROLE_KEY: OpRole.Forward})
......
...@@ -4381,6 +4381,18 @@ class PipelineOptimizer(object): ...@@ -4381,6 +4381,18 @@ class PipelineOptimizer(object):
name=var, name=var,
type=core.VarDesc.VarType.READER, type=core.VarDesc.VarType.READER,
persistable=source_var.persistable) persistable=source_var.persistable)
elif isinstance(source_var, Parameter):
dest_var = block.create_parameter(
name=source_var.name,
shape=source_var.shape,
dtype=source_var.dtype,
type=source_var.type,
lod_level=source_var.lod_level,
stop_gradient=source_var.stop_gradient,
trainable=source_var.trainable,
optimize_attr=source_var.optimize_attr,
regularizer=source_var.regularizer,
error_clip=source_var.error_clip)
else: else:
dest_var = block._clone_variable(source_var, False) dest_var = block._clone_variable(source_var, False)
self._clone_var_attr(dest_var, source_var) self._clone_var_attr(dest_var, source_var)
......
...@@ -71,6 +71,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer): ...@@ -71,6 +71,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream' 'c_sync_comm_stream'
]) ])
...@@ -152,6 +154,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer): ...@@ -152,6 +154,8 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream' 'c_sync_comm_stream'
]) ])
...@@ -212,7 +216,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer): ...@@ -212,7 +216,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -284,7 +290,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer): ...@@ -284,7 +290,9 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -376,7 +384,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer): ...@@ -376,7 +384,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_comm_stream' 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -427,7 +435,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer): ...@@ -427,7 +435,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant', 'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
......
...@@ -762,7 +762,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -762,7 +762,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -928,7 +930,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -928,7 +930,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream' 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -1023,7 +1028,11 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -1023,7 +1028,11 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -1121,7 +1130,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -1121,7 +1130,10 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream' 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
...@@ -1211,7 +1223,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -1211,7 +1223,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream' 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
]) ])
self.assertEqual(main_prog_op_types, [ self.assertEqual(main_prog_op_types, [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册