未验证 提交 f0166f39 编写于 作者: M mamingjie-China 提交者: GitHub

Merge pull request #3 from PaddlePaddle/develop

更新数据
...@@ -567,9 +567,9 @@ class ONNXOpMapper(OpMapper): ...@@ -567,9 +567,9 @@ class ONNXOpMapper(OpMapper):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0] indices_shape = indices.out_shapes[0]
axis = node.get_attr('axis') axis = node.get_attr('axis', 0)
assert len( assert len(
indices_shape) <= 1, "Gather op don't support dim of indice >1 " indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1: if axis == 0 and len(indices_shape) <= 1:
node.fluid_code.add_layer('gather', node.fluid_code.add_layer('gather',
inputs={ inputs={
...@@ -598,6 +598,45 @@ class ONNXOpMapper(OpMapper): ...@@ -598,6 +598,45 @@ class ONNXOpMapper(OpMapper):
inputs=node, inputs=node,
output=node, output=node,
param_attr=attr_trans) param_attr=attr_trans)
elif len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
node.fluid_code.add_layer('reshape',
inputs=indices,
output=indices,
param_attr={'shape': [
reshape_shape,
]})
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer('transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer('gather',
inputs={
'input': name_trans,
'index': indices
},
output=node,
param_attr=None)
node.fluid_code.add_layer('transpose',
inputs=node,
output=node,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer('reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
def Slice(self, node): def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册