diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 104bd67025755298cd55d82285fb035aa9643533..23cc1258edcce779c814534ddf6a970a43ae39cf 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1239,7 +1239,6 @@ class OpSet9(): @print_mapping_info def ConstantOfShape(self, node): val_shape = self.graph.get_input_node(node, idx=0, copy=True) - val_y = self.graph.get_node(node.layer.output[0], copy=True) value = node.get_attr('value') dtype = value.dtype @@ -1248,6 +1247,8 @@ class OpSet9(): 'this is not supported') if len(value) == 1: value = value[0] + if value == float('inf') or value == float('-inf'): + value = string(value) layer_attrs = {'dtype': string(dtype), 'fill_value': value} self.paddle_graph.add_layer( "paddle.full", @@ -1951,40 +1952,13 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=2, copy=True) - not_condition = condition.name + '_not' - self.paddle_graph.add_layer( - "paddle.logical_not", - inputs={"x": condition.name}, - outputs=[not_condition]) - cast_not_condition = not_condition + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": not_condition}, - outputs=[cast_not_condition], - dtype=string(val_x.dtype)) - cast_condition = condition.name + '_cast' - self.paddle_graph.add_layer( - "paddle.cast", - inputs={"x": condition.name}, - outputs=[cast_condition], - dtype=string(val_x.dtype)) - mul_val_x = val_x.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_x.name, - 'y': cast_condition}, - outputs=[mul_val_x]) - mul_val_y = val_y.name + '_mul' - self.paddle_graph.add_layer( - "paddle.multiply", - inputs={'x': val_y.name, - 'y': cast_not_condition}, - outputs=[mul_val_y]) - self.paddle_graph.add_layer( - "paddle.add", - inputs={'x': mul_val_x, - 'y': mul_val_y}, + "paddle.where", + inputs={ + 'condition': condition.name, + 'x': val_x.name, + 'y': val_y.name + }, outputs=[node.name]) @print_mapping_info