From ae4db404b9f039b102e615fdef8e2370dd816c3d Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Fri, 13 May 2022 17:20:39 +0800 Subject: [PATCH] fixed Where and ConstantOfShape --- .../op_mapper/onnx2paddle/opset9/opset.py | 42 ++++--------------- 1 file changed, 8 insertions(+), 34 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 104bd67..23cc125 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1239,7 +1239,6 @@ class OpSet9(): @print_mapping_info def ConstantOfShape(self, node): val_shape = self.graph.get_input_node(node, idx=0, copy=True) - val_y = self.graph.get_node(node.layer.output[0], copy=True) value = node.get_attr('value') dtype = value.dtype @@ -1248,6 +1247,8 @@ class OpSet9(): 'this is not supported') if len(value) == 1: value = value[0] + if value == float('inf') or value == float('-inf'): + value = string(value) layer_attrs = {'dtype': string(dtype), 'fill_value': value} self.paddle_graph.add_layer( "paddle.full", @@ -1951,40 +1952,13 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=2, copy=True) - not_condition = condition.name + '_not' - self.paddle_graph.add_layer( - "paddle.logical_not", - inputs={"x": condition.name}, - outputs=[not_condition]) - cast_not_condition = not_condition + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": not_condition}, - outputs=[cast_not_condition], - dtype=string(val_x.dtype)) - cast_condition = condition.name + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": condition.name}, - outputs=[cast_condition], - dtype=string(val_x.dtype)) - mul_val_x = val_x.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_x.name, - 'y': cast_condition}, - outputs=[mul_val_x]) - mul_val_y = val_y.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_y.name, - 'y': cast_not_condition}, - outputs=[mul_val_y]) - self.paddle_graph.add_layer( - "paddle.add", - inputs={'x': mul_val_x, - 'y': mul_val_y}, + "paddle.where", + inputs={ + 'condition': condition.name, + 'x': val_x.name, + 'y': val_y.name + }, outputs=[node.name]) @print_mapping_info -- GitLab