onnx::Gather not supported?
Created by: chenjun2hao
The ONNX model uses onnx::Gather. The relevant part of the exported graph is:

%position_enc.weight : Float(264, 1024)
%2 : Float(1, 30, 1024) = onnx::Gather(%position_enc.weight, %x)
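
For context, this Gather is just the nn.Embedding lookup: it selects rows of the (264, 1024) weight table with a rank-2 index tensor of shape (1, 30), producing the (1, 30, 1024) output. A minimal sketch of the equivalent operation in plain PyTorch (shapes are taken from the graph above; the variable names are mine):

import torch

weight = torch.randn(264, 1024)       # stands in for %position_enc.weight
idx = torch.randint(0, 264, (1, 30))  # the index input %x, rank 2
out = weight[idx]                     # row selection, same as onnx::Gather(weight, idx) on axis 0
print(out.shape)                      # torch.Size([1, 30, 1024]), i.e. %2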
When converting this model with x2paddle, it reports an error:
Traceback (most recent call last):
  File "/root/miniconda3/envs/torch120/bin/x2paddle", line 11, in <module>
    load_entry_point('x2paddle==0.5.0', 'console_scripts', 'x2paddle')()
  File "/root/miniconda3/envs/torch120/lib/python3.6/site-packages/x2paddle-0.5.0-py3.6.egg/x2paddle/convert.py", line 211, in main
    onnx2paddle(args.model, args.save_dir)
  File "/root/miniconda3/envs/torch120/lib/python3.6/site-packages/x2paddle-0.5.0-py3.6.egg/x2paddle/convert.py", line 157, in onnx2paddle
    mapper = ONNXOpMapper(model, save_dir)
  File "/root/miniconda3/envs/torch120/lib/python3.6/site-packages/x2paddle-0.5.0-py3.6.egg/x2paddle/op_mapper/onnx_op_mapper.py", line 81, in __init__
    func(node)
  File "/root/miniconda3/envs/torch120/lib/python3.6/site-packages/x2paddle-0.5.0-py3.6.egg/x2paddle/op_mapper/onnx_op_mapper.py", line 500, in Gather
    indices_shape) <= 1, "Gather op don't support dim of indice >1 "
AssertionError: Gather op don't support dim of indice >1
The PyTorch code is:
import torch
import torch.nn as nn

class test_model(nn.Module):
    def __init__(self, src_n_position=264, d_word_vec=1024):
        super().__init__()
        # embedding table of shape (264, 1024), exported as %position_enc.weight
        self.position_enc = nn.Embedding(src_n_position, d_word_vec)

    def forward(self, x):
        # embedding lookup, exported as onnx::Gather(%position_enc.weight, %x)
        x = self.position_enc(x)
        return x

net2 = test_model()
x = torch.randint(30, (1, 30))
torch.onnx.export(net2, x, './Fonnx/tranformertwo.onnx', verbose=True, input_names=['x'])
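
For what it's worth, the index tensor x above has shape (1, 30), so the converter sees rank-2 indices and trips the len(indices_shape) <= 1 assertion. One possible workaround (untested, my own idea rather than anything confirmed by x2paddle) is to flatten the indices to 1-D before the embedding lookup and restore the batch shape afterwards, so the exported Gather receives rank-1 indices:

import torch
import torch.nn as nn

class test_model_flat(nn.Module):
    # hypothetical variant: flatten the indices so the exported Gather gets 1-D indices
    def __init__(self, src_n_position=264, d_word_vec=1024):
        super().__init__()
        self.position_enc = nn.Embedding(src_n_position, d_word_vec)

    def forward(self, x):
        flat = x.reshape(-1)                            # (1, 30) -> (30,)
        emb = self.position_enc(flat)                   # Gather with rank-1 indices
        return emb.reshape(x.shape[0], x.shape[1], -1)  # back to (1, 30, 1024)

net3 = test_model_flat()
x = torch.randint(30, (1, 30))
torch.onnx.export(net3, x, './Fonnx/tranformertwo_flat.onnx', verbose=True, input_names=['x'])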