未验证 提交 0c037d2d 编写于 作者: F fangshuixun007 提交者: GitHub

fix test sync_with_cpp (#32212)

fix test sync_with_cpp (#32212)
上级 e6bc358d
......@@ -116,7 +116,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
persistable=False,
stop_gradient=in_var.stop_gradient)
block._insert_op(
block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
......@@ -490,6 +490,7 @@ def rewrite_program(main_prog, amp_lists):
main_prog (Program): The main program for training.
"""
block = main_prog.global_block()
block._sync_with_cpp()
ops = block.ops
white_op_set = set()
black_op_set = set()
......@@ -578,6 +579,7 @@ def update_role_var_grad(main_prog, params_grads):
params_grads (list): A list of params and grads.
"""
block = main_prog.global_block()
block._sync_with_cpp()
BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
for p, g in params_grads:
......@@ -585,7 +587,7 @@ def update_role_var_grad(main_prog, params_grads):
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
role = op.attr('op_role')
if role & int(BACKWARD) and op.has_attr('op_role_var'):
op.desc.remove_attr("op_role_var")
op._remove_attr("op_role_var")
else:
raise ValueError("The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op))
......@@ -610,11 +612,19 @@ def update_role_var_grad(main_prog, params_grads):
raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
#add new op in the python and cpp at the same time
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = framework.Operator(
block=block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None)
block.ops.append(new_op)
op_idx = find_op_index(block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
block.desc._remove_op(op_idx, op_idx + 1)
block._sync_with_cpp()
block._remove_op(op_idx, sync=False)
block._sync_with_cpp()
......@@ -3239,10 +3239,7 @@ class Block(object):
Operator: the insert Operator.
"""
self._sync_with_cpp()
op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
return self._insert_op_without_sync(index, *args, **kwargs)
def _insert_op_without_sync(self, index, *args, **kwargs):
"""
......
......@@ -4352,7 +4352,7 @@ class PipelineOptimizer(object):
ring_id = self._pp_ring_map[pair_key]
if self.schedule_mode == 'F-then-B': # F-then-B
block._insert_op(
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='send_v2',
inputs={'X': var},
......@@ -4364,7 +4364,7 @@ class PipelineOptimizer(object):
'ring_id': ring_id
})
extra_index_info['index'] += 1
block._insert_op(
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='recv_v2',
outputs={'Out': [var]},
......@@ -4379,7 +4379,7 @@ class PipelineOptimizer(object):
})
extra_index_info['index'] += 1
elif self.schedule_mode == '1F1B': # 1F1B
block._insert_op(
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='c_sync_calc_stream',
inputs={'X': [var]},
......@@ -4389,7 +4389,7 @@ class PipelineOptimizer(object):
self._op_role_key: op_role,
})
extra_index_info['index'] += 1
block._insert_op(
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='send_v2',
inputs={'X': var},
......@@ -4409,7 +4409,7 @@ class PipelineOptimizer(object):
else:
insert_index = index
new_op_role = self._op_role.Backward
block._insert_op(
block._insert_op_without_sync(
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
inputs={'X': [var]},
......@@ -4424,7 +4424,7 @@ class PipelineOptimizer(object):
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
block._insert_op(
block._insert_op_without_sync(
index=index + extra_index_info['index'],
type='recv_v2',
outputs={'Out': [var]},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册