未验证 提交 b04dcff4 编写于 作者: J Jason 提交者: GitHub

Merge pull request #148 from Channingss/develop

improve conversion for gather op
......@@ -567,9 +567,9 @@ class ONNXOpMapper(OpMapper):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_shape = indices.out_shapes[0]
axis = node.get_attr('axis')
axis = node.get_attr('axis', 0)
assert len(
indices_shape) <= 1, "Gather op don't support dim of indice >1 "
indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
node.fluid_code.add_layer('gather',
inputs={
......@@ -598,6 +598,45 @@ class ONNXOpMapper(OpMapper):
inputs=node,
output=node,
param_attr=attr_trans)
elif len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
node.fluid_code.add_layer('reshape',
inputs=indices,
output=indices,
param_attr={'shape': [
reshape_shape,
]})
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
node.fluid_code.add_layer('transpose',
inputs=val_x,
output=name_trans,
param_attr=attr_trans)
node.fluid_code.add_layer('gather',
inputs={
'input': name_trans,
'index': indices
},
output=node,
param_attr=None)
node.fluid_code.add_layer('transpose',
inputs=node,
output=node,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
node.fluid_code.add_layer('reshape',
inputs=node,
output=node,
param_attr={'shape': reshaped_shape})
def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册