未验证 提交 3d6bd6a4 编写于 作者: C Charles-hit 提交者: GitHub

add check ops for prim (#52302)

* add check ops for prim

* fix pow and concat composite registration

* modify log

* add note and remove useless code

* remove useless code

* modify program to check

* remove useless note
上级 3d4d7c19
......@@ -605,6 +605,11 @@ class PrimForwardChecker:
)
ret = flatten(_as_list(self.public_python_api(*args)))
primapi.to_prim(main_program.blocks)
# ensure the operator not in program if check_prim is True
forward_ops = [op.type for op in main_program.blocks[0].ops]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret)
......@@ -675,6 +680,16 @@ class PrimForwardChecker:
)
net = PrimNet(self.public_python_api)
net = apply_to_static(net, False)
# ensure the operator not in program if check_prim is True
forward_ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.forward_program.block(0)
.ops
]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
ret = flatten(_as_list(net(args)))
ret = paddle.utils.map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
......@@ -761,6 +776,16 @@ class PrimForwardChecker:
net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn
)
# check the operator not in program if check prim is True
forward_ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.forward_program.block(0)
.ops
]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
ret = flatten(_as_list(net(args)))
ret = paddle.utils.map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
......@@ -1055,6 +1080,12 @@ class PrimGradChecker(PrimForwardChecker):
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
# check the backward operator not in program when check_prim is True
ops = [op.type for op in main_program.blocks[0].ops]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
......@@ -1140,6 +1171,17 @@ class PrimGradChecker(PrimForwardChecker):
)
net = PrimNet(self.public_python_api)
net = apply_to_static(net, False)
# check the backward operator not in program when check_prim is True
ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.backward_program.block(0)
.ops
]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
......@@ -1259,6 +1301,18 @@ class PrimGradChecker(PrimForwardChecker):
net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn
)
# check the backward operator not in program when check_prim is True
ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.backward_program.block(0)
.ops
]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册