From 9e7e37395ecc9a24aca150bdcd81cd103fc783ef Mon Sep 17 00:00:00 2001 From: Channingss Date: Fri, 31 Jul 2020 08:35:58 +0000 Subject: [PATCH] fix bug: Reshape attr:shape may come from Gather by a scalar indices --- .../op_mapper/onnx2paddle/opset9/opset.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e1ebdf2..ec98c88 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -814,11 +814,13 @@ class OpSet9(): inputs=val_shape, output=val_shape_cast, param_attr={'dtype': string('int32')}) - node.fluid_code.add_layer( - 'reshape', - inputs=val_shape_cast, - output=val_shape_cast, - param_attr={'shape': val_shape.out_shapes[0]}) + # shape may be [], come form Gather by scalar indices + if len(val_shape.out_shapes[0]) > 0: + node.fluid_code.add_layer( + 'reshape', + inputs=val_shape_cast, + output=val_shape_cast, + param_attr={'shape': val_shape.out_shapes[0]}) node.fluid_code.add_layer( 'reshape', inputs={'x': val_x, @@ -826,11 +828,13 @@ class OpSet9(): output=node, param_attr=attr) else: - node.fluid_code.add_layer( - 'reshape', - inputs=val_shape, - output=val_shape, - param_attr={'shape': val_shape.out_shapes[0]}) + # shape may be [], come form Gather by scalar indices + if len(val_shape.out_shapes[0]) > 0: + node.fluid_code.add_layer( + 'reshape', + inputs=val_shape, + output=val_shape, + param_attr={'shape': val_shape.out_shapes[0]}) node.fluid_code.add_layer( 'reshape', inputs={'x': val_x, -- GitLab