Unverified · Commit bec9fc9a authored by W WangXi, committed by GitHub

[hybrid] Fix model parallel non-distributed param broadcast (#36186)

Parent f703558d
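The gist of the fix: under hybrid parallelism, a parameter that is not split across model-parallel ranks (its `is_distributed` attribute is absent or False) must also be broadcast within the mp group at startup, otherwise mp ranks start from different initial values. A minimal, hedged sketch of that selection rule follows; `broadcast_rings` is an illustrative helper, not part of Paddle, and the real logic lives in `OffloadHelper._insert_broadcast_op` and `ShardingOptimizer._initialization_broadcast` below.

```python
# Hedged sketch of the ring-selection rule behind this fix (illustrative
# helper, not part of Paddle): a param is broadcast in the mp group only
# when it is NOT split across model-parallel ranks, and in the dp group
# whenever hybrid data parallelism is enabled.
def broadcast_rings(param, mp_ring_id=None, dp_ring_id=None):
    rings = []
    # non-distributed params must stay identical inside the mp group
    if mp_ring_id is not None and not getattr(param, 'is_distributed', False):
        rings.append(mp_ring_id)
    # every param is synchronized across the dp group
    if dp_ring_id is not None:
        rings.append(dp_ring_id)
    return rings
```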
......@@ -25,8 +25,9 @@ class OffloadHelper(object):
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self, ring_id=None):
self.ring_id = ring_id
def __init__(self, mp_ring_id=None, dp_ring_id=None):
self.mp_ring_id = mp_ring_id
self.dp_ring_id = dp_ring_id
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
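A hedged usage sketch of the new constructor, assuming the `OffloadHelper` class defined in this file; the ring ids are placeholders, and the optimizer below passes `self.mp_ring_id` / `self.dp_ring_id` only when the corresponding degree is greater than 1.

```python
# Illustrative only: ring ids 0 and 2 are placeholders, not values Paddle assigns.
helper = OffloadHelper(mp_ring_id=0, dp_ring_id=2)   # hybrid mp + dp
dp_only = OffloadHelper(dp_ring_id=2)                # mp_degree == 1
local = OffloadHelper()                              # no broadcast will be inserted
```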
......@@ -49,20 +50,31 @@ class OffloadHelper(object):
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_broadcast_op(self, block, idx, param):
if self.ring_id is None:
return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})
def _insert_broadcast_op(self, block, idx, param_name):
rings = []
if self.dp_ring_id is not None:
rings.append(self.dp_ring_id)
# need to sync non-distributed params in the mp group
if self.mp_ring_id is not None:
param = block.var(param_name)
if not hasattr(param, 'is_distributed') or not param.is_distributed:
rings.append(self.mp_ring_id)
# the insert op order is: mp, dp
for ring in rings:
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param_name},
outputs={'Out': param_name},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
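Note the ordering subtlety in `_insert_broadcast_op`: `rings` is built as [dp, mp], yet the comment says the resulting op order is mp, dp. That is because every `c_broadcast` is inserted at the same index, so the op inserted last ends up first. A tiny sketch with a plain Python list standing in for the block's op list (purely illustrative):

```python
ops = ['fill_constant', 'cast']                 # stand-in for startup ops around idx
idx = 1
for ring in ['dp_broadcast', 'mp_broadcast']:   # same append order as `rings`
    ops.insert(idx, ring)                       # each insert lands at the same index
print(ops)
# ['fill_constant', 'mp_broadcast', 'dp_broadcast', 'cast']
# -> the op inserted last ('mp') executes first, matching "mp, dp"
```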
......@@ -236,7 +248,7 @@ class OffloadHelper(object):
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should be inserted after the param broadcast.
# the insert op order is: broadcast, cast, offload
# the insert op order is: {mp, dp} broadcast, cast, offload
self._insert_broadcast_op(startup_block, insert_idx,
var_name)
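Putting the ordering together: for one non-distributed fp32 parameter with optimize_cast/offload enabled under mp+dp hybrid parallelism, the startup block ends up with roughly the per-parameter sequence below. This is a hedged illustration; which ops actually appear depends on the enabled options and the param's `is_distributed` flag, and it mirrors the expected op lists in the updated unit tests.

```python
# Illustrative per-param startup op sequence (hedged, not exhaustive):
expected_startup_ops_per_param = [
    'c_broadcast',   # mp ring: only for non-distributed params when mp_degree > 1
    'c_broadcast',   # dp ring: only when dp_degree > 1
    'cast',          # fp32 -> fp16 persistable copy (optimize_cast / offload_fp32param)
    'memcpy',        # offload the fp32 master weight (only when offload is enabled)
]
```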
......@@ -489,6 +501,8 @@ class OffloadHelper(object):
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should be inserted after the param broadcast.
# the insert op order is: {mp, dp} broadcast, cast, offload
self._insert_broadcast_op(startup_block, insert_idx,
var_name)
......
......@@ -467,14 +467,16 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
offload_helper = OffloadHelper(
mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id)
# optimize_offload should only be enabled when gradient merge is enabled and
# acc_step is quite large (e.g. >> 100), since its memcpy cannot be
# overlapped with computation and would otherwise slow down training severely.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper(ring_id=dp_ring_id)
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
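For context, a hedged example of how these code paths are typically reached through fleet's DistributedStrategy. The keys and values below are illustrative and should be checked against the fleet documentation for your Paddle version; the point is that `mp_degree`/`dp_degree` greater than 1 is what makes the optimizer hand both ring ids to `OffloadHelper`, and that `optimize_offload` only pays off with a large gradient-merge acc step because its memcpy cannot be overlapped with compute.

```python
# Hedged configuration sketch; verify keys against the fleet docs for your version.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 2,
    "mp_degree": 2,                  # > 1 => mp_ring_id is passed to OffloadHelper
    "dp_degree": 2,                  # > 1 => dp_ring_id is passed to OffloadHelper
    "gradient_merge_acc_step": 128,  # optimize_offload is only worthwhile for large acc steps
    "optimize_offload": True,        # offload fp32 master weights; implies the fp16 cast
}
```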
......@@ -482,7 +484,6 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast persists the fp16 param; it takes more
# memory but is faster. Trade space for time.
offload_helper = OffloadHelper(ring_id=dp_ring_id)
if self._optimizer_sharding:
offload_helper.opt_sharding_cast_fp32param(
main_block, startup_block,
......@@ -554,6 +555,10 @@ class ShardingOptimizer(MetaOptimizerBase):
# init param broadcast should be called after startup pruning
self._initialization_broadcast()
# NOTE(wangxi): if a param is not persistable, program.clone will
# fail, so we remove the non-persistable param and recreate it as a var
self._recreate_not_persist_param_as_var()
self._dump_program_for_debug()
# GPU needs to wait until the server is ready; GPU and NPU use a layered connection
......@@ -1385,23 +1390,14 @@ class ShardingOptimizer(MetaOptimizerBase):
return
def _initialization_broadcast(self):
"""
this function ensures the initialization across the dp group is
identical when hybrid-dp is used.
"""
if not self.hybrid_dp:
return
startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()
params_name = []
def _recreate_not_persist_param_as_var(self):
def recreate_not_persist_param_as_var(program):
block = program.global_block()
params = block.all_parameters()
for param in params:
if param.persistable:
continue
# NOTE(wangxi): if a param is not persistable, program.clone will
# fail, so we remove the non-persistable param and re-add it as a var
for param in params:
params_name.append(param.name)
if not param.persistable:
name = param.name
shape = param.shape
dtype = param.dtype
......@@ -1411,15 +1407,14 @@ class ShardingOptimizer(MetaOptimizerBase):
trainable = param.trainable
optimize_attr = param.optimize_attr
regularizer = param.regularizer
have_dist_attr = False
is_distributed = False
if hasattr(param, 'is_distributed'):
have_dist_attr = True
is_distributed = param.is_distributed
startup_block._remove_var(name, sync=False)
var = startup_block.create_var(
block._remove_var(name, sync=False)
var = block.create_var(
name=name,
shape=shape,
dtype=dtype,
......@@ -1431,6 +1426,31 @@ class ShardingOptimizer(MetaOptimizerBase):
if have_dist_attr:
var.is_distributed = is_distributed
block._sync_with_cpp()
recreate_not_persist_param_as_var(self._startup_program)
recreate_not_persist_param_as_var(self._main_program)
def _initialization_broadcast(self):
"""
this function ensures that parameter initialization across the dp group
is identical when hybrid-dp is used, and that the initialization of
non-distributed params across the mp group is identical.
"""
if self.dp_degree <= 1 and self.mp_degree <= 1:
return
startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()
params_name = []
not_dist_param_name = set()
for param in params:
params_name.append(param.name)
if not hasattr(param, 'is_distributed') or not param.is_distributed:
not_dist_param_name.add(param.name)
# offload and optimize_cast will insert broadcast ops
broadcast_params = set()
for op in startup_block.ops:
......@@ -1439,23 +1459,25 @@ class ShardingOptimizer(MetaOptimizerBase):
for param in params_name:
if param in broadcast_params: continue
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params_name},
outputs={'Out': params_name},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
rings = []
# need to sync non-distributed params in the mp group
if self.mp_degree > 1 and param in not_dist_param_name:
rings.append(self.mp_ring_id)
if self.dp_degree > 1:
rings.append(self.dp_ring_id)
for ring in rings:
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()
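Because every broadcast is issued with `'use_calc_stream': True`, it runs on the computation stream and there is nothing left to synchronize on the comm stream, which is why the trailing `c_sync_comm_stream` op disappears from the expected op lists in the tests below. A hedged walk-through of the per-param ring selection, using made-up parameter names and `is_distributed` flags for illustration:

```python
# Made-up params: a column-split embedding (is_distributed=True) versus a
# replicated layer-norm weight (is_distributed=False). Values are illustrative.
example_params = [('embedding_0.w_0', True), ('layer_norm_0.w_0', False)]
mp_degree, dp_degree = 2, 2
for name, is_distributed in example_params:
    rings = []
    if mp_degree > 1 and not is_distributed:
        rings.append('mp_ring')   # replicated params must match across mp ranks
    if dp_degree > 1:
        rings.append('dp_ring')   # hybrid-dp always syncs across dp ranks
    print(name, '->', rings)
# embedding_0.w_0 -> ['dp_ring']
# layer_norm_0.w_0 -> ['mp_ring', 'dp_ring']
```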
......
......@@ -72,8 +72,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -155,8 +154,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -218,7 +216,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -292,7 +290,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -371,7 +369,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'c_sync_comm_stream'
'cast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -460,7 +458,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
'c_comm_init', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -511,7 +509,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......
......@@ -655,7 +655,9 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -764,7 +766,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -932,7 +934,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -1029,7 +1031,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -1129,7 +1131,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......@@ -1221,7 +1223,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast'
])
self.assertEqual(main_prog_op_types, [
......