未验证 提交 3d6bd6a4 编写于 作者: C Charles-hit 提交者: GitHub

add check ops for prim (#52302)

* add check ops for prim

* fix pow and concat composite registration

* modify log

* add note and remove useless code

* remove useless code

* modify program to check

* remove useless note
上级 3d4d7c19
...@@ -605,6 +605,11 @@ class PrimForwardChecker: ...@@ -605,6 +605,11 @@ class PrimForwardChecker:
) )
ret = flatten(_as_list(self.public_python_api(*args))) ret = flatten(_as_list(self.public_python_api(*args)))
primapi.to_prim(main_program.blocks) primapi.to_prim(main_program.blocks)
# ensure the operator not in program if check_prim is True
forward_ops = [op.type for op in main_program.blocks[0].ops]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(startup_program) exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret) ret = exe.run(main_program, feed=feed, fetch_list=ret)
...@@ -675,6 +680,16 @@ class PrimForwardChecker: ...@@ -675,6 +680,16 @@ class PrimForwardChecker:
) )
net = PrimNet(self.public_python_api) net = PrimNet(self.public_python_api)
net = apply_to_static(net, False) net = apply_to_static(net, False)
# ensure the operator not in program if check_prim is True
forward_ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.forward_program.block(0)
.ops
]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
ret = flatten(_as_list(net(args))) ret = flatten(_as_list(net(args)))
ret = paddle.utils.map_structure(lambda x: x.numpy(), ret) ret = paddle.utils.map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype): if OpTestUtils.is_bfloat16_type(self.dtype):
...@@ -761,6 +776,16 @@ class PrimForwardChecker: ...@@ -761,6 +776,16 @@ class PrimForwardChecker:
net = apply_to_static( net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn net, core.is_compiled_with_cinn() and self.enable_cinn
) )
# check the operator not in program if check prim is True
forward_ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.forward_program.block(0)
.ops
]
assert self.op_type not in forward_ops, (
"%s shouldn't appear in program when check_prim is True"
) % (self.op_type)
ret = flatten(_as_list(net(args))) ret = flatten(_as_list(net(args)))
ret = paddle.utils.map_structure(lambda x: x.numpy(), ret) ret = paddle.utils.map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype): if OpTestUtils.is_bfloat16_type(self.dtype):
...@@ -1055,6 +1080,12 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1055,6 +1080,12 @@ class PrimGradChecker(PrimForwardChecker):
var_dict={**inputs_dict, **outputs_dict} var_dict={**inputs_dict, **outputs_dict}
) )
ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars) ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
# check the backward operator not in program when check_prim is True
ops = [op.type for op in main_program.blocks[0].ops]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
exe = paddle.static.Executor(self.place) exe = paddle.static.Executor(self.place)
exe.run(startup_program) exe.run(startup_program)
actual_ret = exe.run(main_program, feed=feed, fetch_list=ret) actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
...@@ -1140,6 +1171,17 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1140,6 +1171,17 @@ class PrimGradChecker(PrimForwardChecker):
) )
net = PrimNet(self.public_python_api) net = PrimNet(self.public_python_api)
net = apply_to_static(net, False) net = apply_to_static(net, False)
# check the backward operator not in program when check_prim is True
ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.backward_program.block(0)
.ops
]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
out = _as_list(net(args)) out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"): if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig outputs_sig = self.op_test.python_out_sig
...@@ -1259,6 +1301,18 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1259,6 +1301,18 @@ class PrimGradChecker(PrimForwardChecker):
net = apply_to_static( net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn net, core.is_compiled_with_cinn() and self.enable_cinn
) )
# check the backward operator not in program when check_prim is True
ops = [
op.type
for op in net.forward.get_concrete_program(args)[1]
.backward_program.block(0)
.ops
]
backward_op_type = self.op_type + "_grad"
assert backward_op_type not in ops, (
"%s shouldn't appear in program when check_prim is True"
) % (backward_op_type)
out = _as_list(net(args)) out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"): if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig outputs_sig = self.op_test.python_out_sig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册