diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index da34916f3c2c0452bc11ce206f9ea12bd164da26..557465df9dc60869f7fd4673a56faad4a01df317 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -617,12 +617,11 @@ class OpSet9(): elif axis == 0 and len(indices_shape) > 1: if val_x.out_shapes[0] is not None and isinstance( val_x, ONNXGraphDataNode): - if indices.dtype != 'int64': - node.fluid_code.add_layer( - 'cast', - inputs=indices, - output=indices, - param_attr={'dtype': string('int64')}) + node.fluid_code.add_layer( + 'cast', + inputs=indices, + output=indices, + param_attr={'dtype': string('int64')}) node.fluid_code.add_layer( 'embedding', inputs=indices, @@ -634,7 +633,6 @@ class OpSet9(): }) else: from functools import reduce - #indices_shape = [1,7] reshape_shape = reduce(lambda x, y: x * y, indices_shape) indices_reshape = indices.layer_name + '_shape' node.fluid_code.add_layer(