提交 fa0df3cf 编写于 作者: S SunAhong1993

fix the onnx gather

上级 a5641ca4
...@@ -730,11 +730,14 @@ class OpSet9(): ...@@ -730,11 +730,14 @@ class OpSet9():
inputs={'x': name_trans, inputs={'x': name_trans,
'index': indices.name}, 'index': indices.name},
outputs=[node.name]) outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', 'paddle.transpose',
inputs={"x": node.name}, inputs={"x": node.name},
outputs=[node.name], outputs=[node.name],
perm=perm) perm=new_perm)
if len(indices_shape) < 1: if len(indices_shape) < 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.squeeze', 'paddle.squeeze',
...@@ -811,11 +814,15 @@ class OpSet9(): ...@@ -811,11 +814,15 @@ class OpSet9():
'index': indices_reshape}, 'index': indices_reshape},
outputs=[node.name]) outputs=[node.name])
input_transpose = node.name + '_transpose' input_transpose = node.name + '_transpose'
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', 'paddle.transpose',
inputs={"x": node.name}, inputs={"x": node.name},
outputs=[input_transpose], outputs=[input_transpose],
perm=perm) perm=new_perm)
perm = new_perm
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
......
...@@ -696,11 +696,14 @@ class OpSet9(): ...@@ -696,11 +696,14 @@ class OpSet9():
inputs={'x': name_trans, inputs={'x': name_trans,
'index': indices.name}, 'index': indices.name},
outputs=[node.name]) outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', 'paddle.transpose',
inputs={"x": node.name}, inputs={"x": node.name},
outputs=[node.name], outputs=[node.name],
perm=perm) perm=new_perm)
if len(indices_shape) < 1: if len(indices_shape) < 1:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.squeeze', 'paddle.squeeze',
...@@ -772,11 +775,15 @@ class OpSet9(): ...@@ -772,11 +775,15 @@ class OpSet9():
'index': indices_reshape}, 'index': indices_reshape},
outputs=[node.name]) outputs=[node.name])
input_transpose = node.name + '_transpose' input_transpose = node.name + '_transpose'
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', 'paddle.transpose',
inputs={"x": node.name}, inputs={"x": node.name},
outputs=[input_transpose], outputs=[input_transpose],
perm=perm) perm=new_perm)
perm = new_perm
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册