Error when using ONNX export
Created by: zhangguangzhi
I added ONNX export code to my program, but at runtime it fails with the following error:
```
Traceback (most recent call last):
  File "main.py", line 401, in <module>
    main()
  File "main.py", line 191, in main
    torch.onnx.export(net, dummy_input, "main.onnx")
  File "/usr/local/lib/python2.7/dist-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/usr/local/lib/python2.7/dist-packages/torch/onnx/__init__.py", line 116, in _export
    trace, torch_out = torch.jit.trace(model, args)
  File "/usr/local/lib/python2.7/dist-packages/torch/jit/__init__.py", line 217, in trace
    return TracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/jit/__init__.py", line 241, in forward
    trace, (out_vars, out_struct) = traced_inner(in_vars, in_struct)
  File "/usr/local/lib/python2.7/dist-packages/torch/jit/__init__.py", line 259, in wrapper
    out_vars, out_struct = f(in_vars, in_struct)
  File "/usr/local/lib/python2.7/dist-packages/torch/jit/__init__.py", line 236, in traced_inner
    return _flatten(self.inner(*args, **kwargs))
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 325, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 68, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/data_parallel.py", line 78, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/parallel/parallel_apply.py", line 67, in parallel_apply
    raise output
TypeError: forward() takes exactly 3 arguments (2 given)
```
This is the code I added to the program:
```python
dummy_input = Variable(torch.randn(4, 3, 32, 32))
torch.onnx.export(net, dummy_input, "main.onnx")
```
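One way to check this: the last line of the traceback says `forward()` takes exactly 3 arguments but got 2. In Python 2.7 those counts include `self`, so the model's `forward()` appears to expect two inputs, while `torch.onnx.export` passed only `dummy_input`. The stdlib-only sketch below (it assumes Python 3, because it uses `inspect.signature`) counts the inputs a `forward()` method requires and reproduces the mismatch. `Net`, its `forward(self, x, coord)` signature, `count_forward_inputs` and `try_call` are all hypothetical stand-ins, not taken from the program above; `try_call` only imitates how a tuple of export arguments is unpacked into the model call.

```python
import inspect


class Net(object):
    """Hypothetical stand-in for the real model; only the forward() signature matters."""

    def forward(self, x, coord):
        return (x, coord)

    def __call__(self, *inputs):
        # Like nn.Module.__call__, pass every positional input on to forward().
        return self.forward(*inputs)


def count_forward_inputs(model):
    """Count required positional inputs of model.forward (self is excluded for a bound method)."""
    params = inspect.signature(model.forward).parameters.values()
    return sum(
        1
        for p in params
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
    )


def try_call(model, args):
    """Call the model the way export/trace do: a tuple is unpacked, anything else is one input."""
    if not isinstance(args, tuple):
        args = (args,)
    try:
        model(*args)
        return "ok"
    except TypeError as exc:
        return "TypeError: %s" % exc


net = Net()
print(count_forward_inputs(net))                      # prints: 2
print(try_call(net, "dummy_input"))                   # prints a TypeError: one input is missing
print(try_call(net, ("dummy_input", "dummy_coord")))  # prints: ok
```

If the real model's `forward()` does take a second input, the export call would need one dummy tensor per input, passed as a tuple, for example `torch.onnx.export(net.module, (dummy_input, dummy_coord), "main.onnx")`, where `dummy_coord` is a hypothetical tensor with whatever shape the second argument expects. `net.module` is the model wrapped by `DataParallel` (which appears in the traceback); exporting it directly skips the multi-GPU replication path.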
For reference, these are the imports in my program:
```python
import argparse
import os
import time
import numpy as np
import data
from importlib import import_module
import shutil
from utils import *
import sys
sys.path.append('../')
from split_combine import SplitComb
import torch
from torch.nn import DataParallel
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
from config_training import config as config_training
import torch.onnx
import torchvision
import torchvision.transforms as transforms
import onnx
```
Can you help me figure out what is causing this and how to fix it?