From ea254a8d490d3c05978816034d3aa749e299b764 Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Wed, 1 Jun 2022 10:20:57 +0800 Subject: [PATCH] Fixed Flatten op (#797) * fixed Flatten * fixed Where and ConstantOfShape * fixed inf bug * deal with comments --- .../op_mapper/onnx2paddle/opset9/opset.py | 87 +++++++++---------- 1 file changed, 41 insertions(+), 46 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 58bc0a5..caa9a46 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -262,6 +262,8 @@ class OpSet9(): shape = node.out_shapes[0] if hasattr(node.weight, "shape") and len(node.weight.shape) == 0: + if node.weight == float('inf') or node.weight == float('-inf'): + node.weight = string(node.weight) self.paddle_graph.add_layer( "paddle.full", inputs={}, @@ -792,6 +794,8 @@ class OpSet9(): if len(value) == 1: value = value.tolist() value = value[0] + if value == float('inf') or value == float('-inf'): + value = string(value) self.paddle_graph.add_layer( "paddle.full", inputs={}, @@ -1195,7 +1199,6 @@ class OpSet9(): @print_mapping_info def ConstantOfShape(self, node): val_shape = self.graph.get_input_node(node, idx=0, copy=True) - val_y = self.graph.get_node(node.layer.output[0], copy=True) value = node.get_attr('value') dtype = value.dtype @@ -1204,6 +1207,8 @@ class OpSet9(): 'this is not supported') if len(value) == 1: value = value[0] + if value == float('inf') or value == float('-inf'): + value = string(value) layer_attrs = {'dtype': string(dtype), 'fill_value': value} self.paddle_graph.add_layer( "paddle.full", @@ -1563,20 +1568,37 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) output_shape = val_x.out_shapes[0] axis = node.get_attr('axis', 1) - shape_list = [1, 1] if axis == 0: - for s in output_shape: - shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=[1, -1]) else: - for s in output_shape[:axis]: - shape_list[0] *= s - for s in output_shape[axis:]: - shape_list[1] *= s - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={"x": val_x.name}, - outputs=[node.name], - shape=shape_list) + if len(output_shape) != 0: + shape_list = [1, 1] + for s in output_shape[:axis]: + shape_list[0] *= s + for s in output_shape[axis:]: + shape_list[1] *= s + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={"x": val_x.name}, + outputs=[node.name], + shape=shape_list) + else: + # flatten + reshape + self.paddle_graph.add_layer( + "paddle.flatten", + inputs={"input": val_x.name}, + outputs=[val_x.name + "_flatten"], + start_axis=[0], + stop_axis=[axis]) + self.paddle_graph.add_layer( + 'paddle.reshape', + inputs={'x': val_x.name + "_flatten"}, + outputs=[node.name], + shape=[0, -1]) @print_mapping_info def Gemm(self, node): @@ -1846,40 +1868,13 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=2, copy=True) - not_condition = condition.name + '_not' - self.paddle_graph.add_layer( - "paddle.logical_not", - inputs={"x": condition.name}, - outputs=[not_condition]) - cast_not_condition = not_condition + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": not_condition}, - outputs=[cast_not_condition], - dtype=string(val_x.dtype)) - cast_condition = condition.name + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": condition.name}, - outputs=[cast_condition], - dtype=string(val_x.dtype)) - mul_val_x = val_x.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_x.name, - 'y': cast_condition}, - outputs=[mul_val_x]) - mul_val_y = val_y.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_y.name, - 'y': cast_not_condition}, - outputs=[mul_val_y]) - self.paddle_graph.add_layer( - "paddle.add", - inputs={'x': mul_val_x, - 'y': mul_val_y}, + "paddle.where", + inputs={ + 'condition': condition.name, + 'x': val_x.name, + 'y': val_y.name + }, outputs=[node.name]) @print_mapping_info -- GitLab