提交 fa0df3cf 编写于 作者: S SunAhong1993

fix the onnx gather

上级 a5641ca4
......@@ -730,11 +730,14 @@ class OpSet9():
inputs={'x': name_trans,
'index': indices.name},
outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name],
perm=perm)
perm=new_perm)
if len(indices_shape) < 1:
self.paddle_graph.add_layer(
'paddle.squeeze',
......@@ -811,11 +814,15 @@ class OpSet9():
'index': indices_reshape},
outputs=[node.name])
input_transpose = node.name + '_transpose'
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[input_transpose],
perm=perm)
perm=new_perm)
perm = new_perm
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
......
......@@ -696,11 +696,14 @@ class OpSet9():
inputs={'x': name_trans,
'index': indices.name},
outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name],
perm=perm)
perm=new_perm)
if len(indices_shape) < 1:
self.paddle_graph.add_layer(
'paddle.squeeze',
......@@ -772,11 +775,15 @@ class OpSet9():
'index': indices_reshape},
outputs=[node.name])
input_transpose = node.name + '_transpose'
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[input_transpose],
perm=perm)
perm=new_perm)
perm = new_perm
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册