From dfb5a46ec453902bb5564be417810f15ea0c770d Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Tue, 17 Nov 2020 20:43:04 +0800 Subject: [PATCH] fix --- .../dygraph/onnx2paddle/opset9/opset.py | 44 +++---------------- 1 file changed, 7 insertions(+), 37 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py index 6b48646..39bfd23 100644 --- a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py @@ -298,6 +298,8 @@ class OpSet9(): # TODO(syf): all use inputs['out_shape'] = var_hw + ipt = inputs.pop("x") + inputs["input"] = ipt mode = node.get_attr('mode', 'nearest') attrs = {"align_corners": False} self.paddle_graph.add_layer( @@ -455,24 +457,11 @@ class OpSet9(): outputs=[node.layer_name], shape=[1]) else: - if str(val_x.dtype) == 'bool': - val_x_cast = val_x.layer_name + '_cast' - self.paddle_graph.add_layer( - 'paddle.cast', - inputs={"x": self.get_node_name(val_x)}, - outputs=[val_x_cast], - dtype=string('int64')) - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": val_x_cast}, - outputs=[node.layer_name], - **layer_attrs) - else: - self.paddle_graph.add_layer( - 'paddle.unsqueeze', - inputs={"x": self.get_node_name(val_x)}, - outputs=[node.layer_name], - **layer_attrs) + self.paddle_graph.add_layer( + 'paddle.unsqueeze', + inputs={"x": self.get_node_name(val_x)}, + outputs=[node.layer_name], + **layer_attrs) @print_mapping_info def Shrink(self, node): @@ -1013,25 +1002,6 @@ class OpSet9(): inputs={'x': self.get_node_name(val_x)}, outputs=[node.layer_name], shape=node.out_shapes[0]) - elif val_shape.dtype == 'int64': - val_shape_cast = val_shape.layer_name + '_cast' - self.paddle_graph.add_layer( - 'paddle.cast', - inputs={'x': self.get_node_name(val_shape)}, - outputs=[val_shape_cast], - dtype=string('int32')) - # shape may be [], come form Gather by scalar indices - if len(val_shape.out_shapes[0]) > 0: - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={'x': self.get_node_name(val_shape_cast)}, - outputs=[val_shape_cast], - shape=val_shape.out_shapes[0]) - self.paddle_graph.add_layer( - 'paddle.reshape', - inputs={'x': elf.get_node_name(val_x), - 'shape': val_shape_cast}, - outputs=[node.layer_name]) else: # shape may be [], come form Gather by scalar indices if len(val_shape.out_shapes[0]) > 0: -- GitLab