未验证 提交 82fc5ede 编写于 作者: X Xiaoyu Zhang 提交者: GitHub

fix convert bug (#46)

* fix convert bug

* refine

* update ci test.sh
上级 200a1fbe
...@@ -4,4 +4,5 @@ python3 -m pip install --user --upgrade pip ...@@ -4,4 +4,5 @@ python3 -m pip install --user --upgrade pip
if [ -f requirements.txt ]; then python3 -m pip install -r requirements.txt --user; fi if [ -f requirements.txt ]; then python3 -m pip install -r requirements.txt --user; fi
python3 -m pip install oneflow --user -U -f https://staging.oneflow.info/branch/master/cu110 python3 -m pip install oneflow --user -U -f https://staging.oneflow.info/branch/master/cu110
python3 setup.py install python3 setup.py install
python3 -m pytest examples/oneflow2onnx python3 -m pytest examples/oneflow2onnx/models
...@@ -470,6 +470,6 @@ def test_inceptionv3(): ...@@ -470,6 +470,6 @@ def test_inceptionv3():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(inceptionv3.state_dict(), tmpdirname) flow.save(inceptionv3.state_dict(), tmpdirname)
convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp") convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=True)
test_inceptionv3() test_inceptionv3()
...@@ -44,6 +44,7 @@ class LeNet(nn.Module): ...@@ -44,6 +44,7 @@ class LeNet(nn.Module):
return logits return logits
lenet = LeNet() lenet = LeNet()
lenet = lenet.to("cuda")
lenet.eval() lenet.eval()
class lenetGraph(flow.nn.Graph): class lenetGraph(flow.nn.Graph):
...@@ -58,7 +59,7 @@ class lenetGraph(flow.nn.Graph): ...@@ -58,7 +59,7 @@ class lenetGraph(flow.nn.Graph):
def test_lenet(): def test_lenet():
lenet_graph = lenetGraph() lenet_graph = lenetGraph()
lenet_graph._compile(flow.randn(1, 3, 32, 32)) lenet_graph._compile(flow.randn(1, 3, 32, 32).to("cuda"))
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(lenet.state_dict(), tmpdirname) flow.save(lenet.state_dict(), tmpdirname)
......
...@@ -309,7 +309,7 @@ def test_resnet(): ...@@ -309,7 +309,7 @@ def test_resnet():
resnet_graph = ResNetGraph() resnet_graph = ResNetGraph()
resnet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda")) resnet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))
print(resnet_graph._full_graph_proto) # print(resnet_graph._full_graph_proto)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(resnet.state_dict(), tmpdirname) flow.save(resnet.state_dict(), tmpdirname)
convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=False) convert_to_onnx_and_check(resnet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp", print_outlier=False)
......
...@@ -24,11 +24,7 @@ long_description += "Email: zhangxiaoyu@oneflow.org" ...@@ -24,11 +24,7 @@ long_description += "Email: zhangxiaoyu@oneflow.org"
setuptools.setup( setuptools.setup(
name="oneflow_onnx", name="oneflow_onnx",
<<<<<<< HEAD
version="0.5.1", version="0.5.1",
=======
version="0.5.0.rc",
>>>>>>> parent of cca89ba... release v0.5.0 (#42)
author="zhangxiaoyu", author="zhangxiaoyu",
author_email="zhangxiaoyu@oneflow.org", author_email="zhangxiaoyu@oneflow.org",
description="a toolkit for converting trained model of OneFlow to ONNX and ONNX to OneFlow.", description="a toolkit for converting trained model of OneFlow to ONNX and ONNX to OneFlow.",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册