diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index f4b9d3fe332af785894da5e241a82f0fcdc6b1f8..5946ff668887f185f9abedb504f2915c16576a57 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1150,9 +1150,7 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) starts, ends, axes, steps = None, None, None, None layer_attrs = {} - if val_x.dtype not in [ - "float16", "float32", "float64", "int32", "int64" - ]: + if val_x.dtype == 'uint8': self.paddle_graph.add_layer( 'paddle.cast', inputs={"x": val_x.name}, @@ -1878,6 +1876,7 @@ class OpSet9(): if axes is None: axes_node = self.graph.get_input_node(node, idx=1, copy=True) axes = _const_weight_or_none(axes_node, necessary=True) + # deal with scalar(0D) tensor if len(val_x.out_shapes[0]) <= 1 and len(axes) == 1 and axes[0] == 0: self.paddle_graph.add_layer( "paddle.cast",