未验证 提交 0c037d2d 编写于 作者: F fangshuixun007 提交者: GitHub

fix test sync_with_cpp (#32212)

fix test sync_with_cpp (#32212)
上级 e6bc358d
...@@ -116,7 +116,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -116,7 +116,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
persistable=False, persistable=False,
stop_gradient=in_var.stop_gradient) stop_gradient=in_var.stop_gradient)
block._insert_op( block._insert_op_without_sync(
idx, idx,
type="cast", type="cast",
inputs={"X": in_var}, inputs={"X": in_var},
...@@ -490,6 +490,7 @@ def rewrite_program(main_prog, amp_lists): ...@@ -490,6 +490,7 @@ def rewrite_program(main_prog, amp_lists):
main_prog (Program): The main program for training. main_prog (Program): The main program for training.
""" """
block = main_prog.global_block() block = main_prog.global_block()
block._sync_with_cpp()
ops = block.ops ops = block.ops
white_op_set = set() white_op_set = set()
black_op_set = set() black_op_set = set()
...@@ -578,6 +579,7 @@ def update_role_var_grad(main_prog, params_grads): ...@@ -578,6 +579,7 @@ def update_role_var_grad(main_prog, params_grads):
params_grads (list): A list of params and grads. params_grads (list): A list of params and grads.
""" """
block = main_prog.global_block() block = main_prog.global_block()
block._sync_with_cpp()
BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
for p, g in params_grads: for p, g in params_grads:
...@@ -585,7 +587,7 @@ def update_role_var_grad(main_prog, params_grads): ...@@ -585,7 +587,7 @@ def update_role_var_grad(main_prog, params_grads):
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast': if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
role = op.attr('op_role') role = op.attr('op_role')
if role & int(BACKWARD) and op.has_attr('op_role_var'): if role & int(BACKWARD) and op.has_attr('op_role_var'):
op.desc.remove_attr("op_role_var") op._remove_attr("op_role_var")
else: else:
raise ValueError("The cast op {0} must be in BACKWARD role " raise ValueError("The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op)) "and have op_role_var attr.".format(op))
...@@ -610,11 +612,19 @@ def update_role_var_grad(main_prog, params_grads): ...@@ -610,11 +612,19 @@ def update_role_var_grad(main_prog, params_grads):
raise ValueError("The cast op {0}'s output should not be" raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it" "used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])) "is used by {1}".format(op, post_ops[0]))
#add new op in the python and cpp at the same time
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
new_op = framework.Operator(
block=block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None)
block.ops.append(new_op)
op_idx = find_op_index(block.desc, op.desc) op_idx = find_op_index(block.desc, op.desc)
if op_idx == -1: if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op)) raise ValueError("The op {0} is not in program".format(op))
block.desc._remove_op(op_idx, op_idx + 1) block._remove_op(op_idx, sync=False)
block._sync_with_cpp() block._sync_with_cpp()
...@@ -3239,10 +3239,7 @@ class Block(object): ...@@ -3239,10 +3239,7 @@ class Block(object):
Operator: the insert Operator. Operator: the insert Operator.
""" """
self._sync_with_cpp() self._sync_with_cpp()
op_desc = self.desc._insert_op(index) return self._insert_op_without_sync(index, *args, **kwargs)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
def _insert_op_without_sync(self, index, *args, **kwargs): def _insert_op_without_sync(self, index, *args, **kwargs):
""" """
......
...@@ -4352,7 +4352,7 @@ class PipelineOptimizer(object): ...@@ -4352,7 +4352,7 @@ class PipelineOptimizer(object):
ring_id = self._pp_ring_map[pair_key] ring_id = self._pp_ring_map[pair_key]
if self.schedule_mode == 'F-then-B': # F-then-B if self.schedule_mode == 'F-then-B': # F-then-B
block._insert_op( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='send_v2', type='send_v2',
inputs={'X': var}, inputs={'X': var},
...@@ -4364,7 +4364,7 @@ class PipelineOptimizer(object): ...@@ -4364,7 +4364,7 @@ class PipelineOptimizer(object):
'ring_id': ring_id 'ring_id': ring_id
}) })
extra_index_info['index'] += 1 extra_index_info['index'] += 1
block._insert_op( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='recv_v2', type='recv_v2',
outputs={'Out': [var]}, outputs={'Out': [var]},
...@@ -4379,7 +4379,7 @@ class PipelineOptimizer(object): ...@@ -4379,7 +4379,7 @@ class PipelineOptimizer(object):
}) })
extra_index_info['index'] += 1 extra_index_info['index'] += 1
elif self.schedule_mode == '1F1B': # 1F1B elif self.schedule_mode == '1F1B': # 1F1B
block._insert_op( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='c_sync_calc_stream', type='c_sync_calc_stream',
inputs={'X': [var]}, inputs={'X': [var]},
...@@ -4389,7 +4389,7 @@ class PipelineOptimizer(object): ...@@ -4389,7 +4389,7 @@ class PipelineOptimizer(object):
self._op_role_key: op_role, self._op_role_key: op_role,
}) })
extra_index_info['index'] += 1 extra_index_info['index'] += 1
block._insert_op( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='send_v2', type='send_v2',
inputs={'X': var}, inputs={'X': var},
...@@ -4409,7 +4409,7 @@ class PipelineOptimizer(object): ...@@ -4409,7 +4409,7 @@ class PipelineOptimizer(object):
else: else:
insert_index = index insert_index = index
new_op_role = self._op_role.Backward new_op_role = self._op_role.Backward
block._insert_op( block._insert_op_without_sync(
index=insert_index + extra_index_info['index'], index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream', type='c_sync_comm_stream',
inputs={'X': [var]}, inputs={'X': [var]},
...@@ -4424,7 +4424,7 @@ class PipelineOptimizer(object): ...@@ -4424,7 +4424,7 @@ class PipelineOptimizer(object):
var_shape = list(var.shape) var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[ var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0] 0] < 0 else var_shape[0]
block._insert_op( block._insert_op_without_sync(
index=index + extra_index_info['index'], index=index + extra_index_info['index'],
type='recv_v2', type='recv_v2',
outputs={'Out': [var]}, outputs={'Out': [var]},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册