未验证 提交 615d931c 编写于 作者: J Jiabin Yang 提交者: GitHub

Support to onnx test (#42698)

* support to onnx test

* add comments

* remove log

* remove log

* update paddle2onnx version
上级 e3ee2ad8
......@@ -21,7 +21,7 @@ import numpy as np
import paddle
from paddle.static import InputSpec
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode, _test_eager_guard
class LinearNet(paddle.nn.Layer):
......@@ -45,43 +45,46 @@ class Logic(paddle.nn.Layer):
class TestExportWithTensor(unittest.TestCase):
def setUp(self):
def func_with_tensor(self):
self.x_spec = paddle.static.InputSpec(
shape=[None, 128], dtype='float32')
def test_with_tensor(self):
if in_dygraph_mode():
return
model = LinearNet()
paddle.onnx.export(model, 'linear_net', input_spec=[self.x_spec])
def test_with_tensor(self):
with _test_eager_guard():
self.func_with_tensor()
self.func_with_tensor()
class TestExportWithTensor1(unittest.TestCase):
def setUp(self):
def func_with_tensor(self):
self.x = paddle.to_tensor(np.random.random((1, 128)))
def test_with_tensor(self):
if in_dygraph_mode():
return
model = LinearNet()
paddle.onnx.export(model, 'linear_net', input_spec=[self.x])
def test_with_tensor(self):
with _test_eager_guard():
self.func_with_tensor()
self.func_with_tensor()
class TestExportPrunedGraph(unittest.TestCase):
def setUp(self):
def func_prune_graph(self):
model = Logic()
self.x = paddle.to_tensor(np.array([1]))
self.y = paddle.to_tensor(np.array([-1]))
def test_prune_graph(self):
if in_dygraph_mode():
return
model = Logic()
paddle.jit.to_static(model)
out = model(self.x, self.y, z=True)
paddle.onnx.export(
model, 'pruned', input_spec=[self.x], output_spec=[out])
def test_prune_graph(self):
# test eager
with _test_eager_guard():
self.func_prune_graph()
self.func_prune_graph()
if __name__ == '__main__':
if not in_dygraph_mode():
unittest.main()
unittest.main()
......@@ -8,7 +8,7 @@ pygame==2.1.0
hypothesis
opencv-python<=4.2.0.32
visualdl
paddle2onnx>=0.8.2
paddle2onnx>=0.9.6
scipy>=1.6; python_version >= "3.7"
scipy>=1.5; python_version == "3.6"
prettytable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册