未验证 提交 615d931c 编写于 作者: J Jiabin Yang 提交者: GitHub

Support to onnx test (#42698)

* support to onnx test

* add comments

* remove log

* remove log

* update paddle2onnx version
上级 e3ee2ad8
...@@ -21,7 +21,7 @@ import numpy as np ...@@ -21,7 +21,7 @@ import numpy as np
import paddle import paddle
from paddle.static import InputSpec from paddle.static import InputSpec
from paddle.fluid.framework import in_dygraph_mode from paddle.fluid.framework import in_dygraph_mode, _test_eager_guard
class LinearNet(paddle.nn.Layer): class LinearNet(paddle.nn.Layer):
...@@ -45,43 +45,46 @@ class Logic(paddle.nn.Layer): ...@@ -45,43 +45,46 @@ class Logic(paddle.nn.Layer):
class TestExportWithTensor(unittest.TestCase): class TestExportWithTensor(unittest.TestCase):
def setUp(self): def func_with_tensor(self):
self.x_spec = paddle.static.InputSpec( self.x_spec = paddle.static.InputSpec(
shape=[None, 128], dtype='float32') shape=[None, 128], dtype='float32')
def test_with_tensor(self):
if in_dygraph_mode():
return
model = LinearNet() model = LinearNet()
paddle.onnx.export(model, 'linear_net', input_spec=[self.x_spec]) paddle.onnx.export(model, 'linear_net', input_spec=[self.x_spec])
def test_with_tensor(self):
with _test_eager_guard():
self.func_with_tensor()
self.func_with_tensor()
class TestExportWithTensor1(unittest.TestCase): class TestExportWithTensor1(unittest.TestCase):
def setUp(self): def func_with_tensor(self):
self.x = paddle.to_tensor(np.random.random((1, 128))) self.x = paddle.to_tensor(np.random.random((1, 128)))
def test_with_tensor(self):
if in_dygraph_mode():
return
model = LinearNet() model = LinearNet()
paddle.onnx.export(model, 'linear_net', input_spec=[self.x]) paddle.onnx.export(model, 'linear_net', input_spec=[self.x])
def test_with_tensor(self):
with _test_eager_guard():
self.func_with_tensor()
self.func_with_tensor()
class TestExportPrunedGraph(unittest.TestCase): class TestExportPrunedGraph(unittest.TestCase):
def setUp(self): def func_prune_graph(self):
model = Logic()
self.x = paddle.to_tensor(np.array([1])) self.x = paddle.to_tensor(np.array([1]))
self.y = paddle.to_tensor(np.array([-1])) self.y = paddle.to_tensor(np.array([-1]))
def test_prune_graph(self):
if in_dygraph_mode():
return
model = Logic()
paddle.jit.to_static(model) paddle.jit.to_static(model)
out = model(self.x, self.y, z=True) out = model(self.x, self.y, z=True)
paddle.onnx.export( paddle.onnx.export(
model, 'pruned', input_spec=[self.x], output_spec=[out]) model, 'pruned', input_spec=[self.x], output_spec=[out])
def test_prune_graph(self):
# test eager
with _test_eager_guard():
self.func_prune_graph()
self.func_prune_graph()
if __name__ == '__main__': if __name__ == '__main__':
if not in_dygraph_mode(): unittest.main()
unittest.main()
...@@ -8,7 +8,7 @@ pygame==2.1.0 ...@@ -8,7 +8,7 @@ pygame==2.1.0
hypothesis hypothesis
opencv-python<=4.2.0.32 opencv-python<=4.2.0.32
visualdl visualdl
paddle2onnx>=0.8.2 paddle2onnx>=0.9.6
scipy>=1.6; python_version >= "3.7" scipy>=1.6; python_version >= "3.7"
scipy>=1.5; python_version == "3.6" scipy>=1.5; python_version == "3.6"
prettytable prettytable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册