diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 0af1b6ec777a4f2352085fbb35cea7515d95c8e6..0de55e8cec3d1a31c8eec50df2ed5f840ceb6829 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -627,6 +627,12 @@ class OpSet9(): elif axis == 0 and len(indices_shape) > 1: if val_x.out_shapes[0] is not None and isinstance( val_x, ONNXGraphDataNode): + if indices.dtype != 'int64': + node.fluid_code.add_layer( + 'cast', + inputs=indices, + output=indices, + param_attr={'dtype': string('int64')}) node.fluid_code.add_layer( 'embedding', inputs=indices,