From f3e37fcfcfc010f276121edf9acf903d0047c748 Mon Sep 17 00:00:00 2001 From: Channingss Date: Wed, 5 Aug 2020 06:05:10 +0000 Subject: [PATCH] add cast(int64) for embedding --- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 0af1b6e..0de55e8 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -627,6 +627,12 @@ class OpSet9(): elif axis == 0 and len(indices_shape) > 1: if val_x.out_shapes[0] is not None and isinstance( val_x, ONNXGraphDataNode): + if indices.dtype != 'int64': + node.fluid_code.add_layer( + 'cast', + inputs=indices, + output=indices, + param_attr={'dtype': string('int64')}) node.fluid_code.add_layer( 'embedding', inputs=indices, -- GitLab