diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py index 48245cb8e04e4d4bd066450db87a7fb1bd2f1fcc..2dcfcb98f9b1071604d02633eb1e9742c33701df 100644 --- a/x2paddle/op_mapper/onnx_op_mapper.py +++ b/x2paddle/op_mapper/onnx_op_mapper.py @@ -567,9 +567,9 @@ class ONNXOpMapper(OpMapper): val_x = self.graph.get_input_node(node, idx=0, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True) indices_shape = indices.out_shapes[0] - axis = node.get_attr('axis') + axis = node.get_attr('axis', 0) assert len( - indices_shape) <= 1, "Gather op don't support dim of indice >1 " + indices_shape) <= 2, "Gather op don't support dim of indice >2 " if axis == 0 and len(indices_shape) <= 1: node.fluid_code.add_layer('gather', inputs={ @@ -598,6 +598,45 @@ class ONNXOpMapper(OpMapper): inputs=node, output=node, param_attr=attr_trans) + elif len(indices_shape) > 1: + from functools import reduce + reshape_shape = reduce(lambda x, y: x * y, indices_shape) + node.fluid_code.add_layer('reshape', + inputs=indices, + output=indices, + param_attr={'shape': [ + reshape_shape, + ]}) + + perm = list(range(len(val_x.out_shapes[0]))) + perm = [axis] + perm[:axis] + perm[axis + 1:] + attr_trans = {'perm': perm} + name_trans = val_x.layer_name + '_trans' + node.fluid_code.add_layer('transpose', + inputs=val_x, + output=name_trans, + param_attr=attr_trans) + node.fluid_code.add_layer('gather', + inputs={ + 'input': name_trans, + 'index': indices + }, + output=node, + param_attr=None) + node.fluid_code.add_layer('transpose', + inputs=node, + output=node, + param_attr=attr_trans) + val_x_shape = val_x.out_shapes[0] + reshaped_shape = [] + for i in perm: + reshaped_shape.append(indices_shape[i]) + for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: + reshaped_shape.append(i) + node.fluid_code.add_layer('reshape', + inputs=node, + output=node, + param_attr={'shape': reshaped_shape}) def Slice(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True)