未验证 提交 2cd05d5d 编写于 作者: W WangXi 提交者: GitHub

[hybrid] refine pipeline stage and mp send/recv check (#34870)

上级 8c8667f0
...@@ -4397,6 +4397,10 @@ class PipelineOptimizer(object): ...@@ -4397,6 +4397,10 @@ class PipelineOptimizer(object):
return op_role & int(self._op_role.Backward) and op_role & int( return op_role & int(self._op_role.Backward) and op_role & int(
self._op_role.Loss) self._op_role.Loss)
def _is_forward_op(self, op):
return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) == int(self._op_role.Forward))
def _is_backward_op(self, op): def _is_backward_op(self, op):
return self._op_role_key in op.attr_names and ( return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) & int(self._op_role.Backward)) int(op.attr(self._op_role_key)) & int(self._op_role.Backward))
...@@ -4705,10 +4709,6 @@ class PipelineOptimizer(object): ...@@ -4705,10 +4709,6 @@ class PipelineOptimizer(object):
int(self._op_role.Optimize), int(self._op_role.Optimize),
int(self._op_role.Backward) | int(self._op_role.Loss), int(self._op_role.Backward) | int(self._op_role.Loss),
] ]
pre_stage_id = None
decrease_flag = False
in_optimize = False
in_forward = True
for op in block.ops: for op in block.ops:
if not op._has_kernel(op.type): if not op._has_kernel(op.type):
assert op.type == "conditional_block" and ( assert op.type == "conditional_block" and (
...@@ -4724,10 +4724,6 @@ class PipelineOptimizer(object): ...@@ -4724,10 +4724,6 @@ class PipelineOptimizer(object):
op_role, op_role,
op.type, op.type,
valid_op_role_value) valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
if int(op_role) == int(self._op_role.Backward):
in_forward = False
assert op.has_attr(self._op_device_key), ( assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type, "op ({}) has no {} attribute.".format(op.type,
...@@ -4739,7 +4735,6 @@ class PipelineOptimizer(object): ...@@ -4739,7 +4735,6 @@ class PipelineOptimizer(object):
if device == f"{self._device}:all": continue if device == f"{self._device}:all": continue
dev_type = device.split(':')[0] dev_type = device.split(':')[0]
stage_id = int(device.split(':')[1])
assert dev_type == "gpu" or dev_type == 'npu', ( assert dev_type == "gpu" or dev_type == 'npu', (
"Now only gpu and npu devices are supported " "Now only gpu and npu devices are supported "
"for pipeline parallelism.") "for pipeline parallelism.")
...@@ -4747,26 +4742,6 @@ class PipelineOptimizer(object): ...@@ -4747,26 +4742,6 @@ class PipelineOptimizer(object):
if device not in device_list: if device not in device_list:
device_list.append(device) device_list.append(device)
if not in_optimize:
if pre_stage_id is not None:
interval = stage_id - pre_stage_id
assert abs(interval) <= 1, \
"The stage interval of two consecutive ops in the pipeline must be < = 1," \
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if in_forward:
assert interval >= 0, \
"Pipeline stage must be sequential increment in Forward, prev_stage={}, " \
"please check the stage of op={}".format(pre_stage_id, op)
else:
# FIXME(wangxi): recompute check failed
pass
#assert interval <=0, \
# "Pipeline stage must be sequential decrement in Backward, prev_stage={}, " \
# "please check the stage of op={}".format(pre_stage_id, op)
pre_stage_id = stage_id
return device_list return device_list
def _insert_sendrecv_ops_for_boundaries(self, block): def _insert_sendrecv_ops_for_boundaries(self, block):
...@@ -4820,6 +4795,25 @@ class PipelineOptimizer(object): ...@@ -4820,6 +4795,25 @@ class PipelineOptimizer(object):
device_type = cur_device.split(':')[0] + ':' device_type = cur_device.split(':')[0] + ':'
def _check_stage(cur_id, prev_id):
# check send/recv stage valid
is_forward = self._is_forward_op(op)
is_backward = self._is_backward_op(op)
assert is_forward or is_backward, \
'send/recv in pipeline should only be inserted in forward or backward,' \
'please check the op_role of op={}'.format(op)
if is_forward:
assert prev_id < cur_id, \
"In forward, send/recv can only be passed forward, but now " \
"prev_stage={} great than cur_stage={}, please check op_device of op={}".format(
prev_id, cur_id, op)
elif is_backward:
assert prev_id > cur_id, \
"In backward, send/recv can only be passed backward, but now " \
"prev_stage={} less than cur_stage={}, please check op_device of op={}".format(
prev_id, cur_id, op)
def _insert_send_recv(cur_id, prev_id): def _insert_send_recv(cur_id, prev_id):
cur_dev = device_type + str(cur_id) cur_dev = device_type + str(cur_id)
prev_dev = device_type + str(prev_id) prev_dev = device_type + str(prev_id)
...@@ -4890,9 +4884,9 @@ class PipelineOptimizer(object): ...@@ -4890,9 +4884,9 @@ class PipelineOptimizer(object):
var_shape[0] = self.micro_batch_size if var_shape[ var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0] 0] < 0 else var_shape[0]
numel = np.prod(var.shape) numel = np.prod(var_shape)
assert numel % self.mp_degree == 0, \ use_mp = (self.mp_degree > 1) and (
"The numel={} must be divisible by mp_degree={}".format(numel, self.mp_degree) numel % self.mp_degree == 0)
if 'subprog' in var.name: if 'subprog' in var.name:
# For recompute, if the checkpoints var is layer_norm_6.tmp_2 # For recompute, if the checkpoints var is layer_norm_6.tmp_2
...@@ -4919,6 +4913,8 @@ class PipelineOptimizer(object): ...@@ -4919,6 +4913,8 @@ class PipelineOptimizer(object):
extra_index_info['index'] += 1 extra_index_info['index'] += 1
return return
_check_stage(cur_id, prev_id)
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='c_sync_calc_stream', type='c_sync_calc_stream',
...@@ -4931,8 +4927,7 @@ class PipelineOptimizer(object): ...@@ -4931,8 +4927,7 @@ class PipelineOptimizer(object):
extra_index_info['index'] += 1 extra_index_info['index'] += 1
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='send_v2' type='send_v2' if not use_mp else 'partial_send',
if self.mp_degree == 1 else 'partial_send',
inputs={'X': var}, inputs={'X': var},
attrs={ attrs={
self._op_device_key: prev_dev, self._op_device_key: prev_dev,
...@@ -4968,8 +4963,7 @@ class PipelineOptimizer(object): ...@@ -4968,8 +4963,7 @@ class PipelineOptimizer(object):
extra_index_info['index'] += 1 extra_index_info['index'] += 1
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='recv_v2' type='recv_v2' if not use_mp else 'partial_recv',
if self.mp_degree == 1 else 'partial_recv',
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var_shape, 'out_shape': var_shape,
...@@ -4984,7 +4978,7 @@ class PipelineOptimizer(object): ...@@ -4984,7 +4978,7 @@ class PipelineOptimizer(object):
'id': self.mp_rank, 'id': self.mp_rank,
}) })
extra_index_info['index'] += 1 extra_index_info['index'] += 1
if self.mp_degree > 1: if use_mp:
block._insert_op_without_sync( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='partial_allgather', type='partial_allgather',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册