From 32e91854bd11c3478992c5eeffa53b9a4839545e Mon Sep 17 00:00:00 2001 From: Channingss Date: Mon, 10 Aug 2020 13:20:27 +0000 Subject: [PATCH] update --- .../op_mapper/onnx2paddle/onnx_op_mapper.py | 7 ---- .../op_mapper/onnx2paddle/opset9/opset.py | 39 ++++++++----------- 2 files changed, 16 insertions(+), 30 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/onnx_op_mapper.py b/x2paddle/op_mapper/onnx2paddle/onnx_op_mapper.py index ae80b2d..bbb9608 100644 --- a/x2paddle/op_mapper/onnx2paddle/onnx_op_mapper.py +++ b/x2paddle/op_mapper/onnx2paddle/onnx_op_mapper.py @@ -53,21 +53,14 @@ class ONNXOpMapper(OpMapper): def op_checker(self): unsupported_ops = set() - contain_ops = set() for node_name in self.graph.topo_sort: node = self.graph.get_node(node_name) op = node.layer_type - contain_ops.add(op) if not hasattr(self.opset, op) and \ op not in self.opset.default_op_mapping and \ op not in custom_layers and \ op not in self.opset.elementwise_ops: unsupported_ops.add(op) - - print("There are {} ops need converted , list as below".format( - len(contain_ops))) - for op in contain_ops: - print(op) if len(unsupported_ops) == 0: return True else: diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 58dfa77..74f9f5a 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -46,7 +46,7 @@ def _is_static_shape(shape): for dim in shape: if dim < 0: negtive_dims += 1 - if dim != -1: + if dim < -1: error_dims += 1 if negtive_dims > 1: return False @@ -513,8 +513,21 @@ class OpSet9(): output=node, param_attr={'shape': [1]}) else: - node.fluid_code.add_layer( - 'unsqueeze', inputs=val_x, output=node, param_attr=attr) + if str(val_x.dtype) == 'bool': + val_x_cast = val_x.layer_name + '_cast' + node.fluid_code.add_layer( + 'cast', + inputs=val_x, + output=val_x_cast, + param_attr={'dtype': string('int64')}) + node.fluid_code.add_layer( + 'unsqueeze', + inputs=val_x_cast, + output=node, + param_attr=attr) + else: + node.fluid_code.add_layer( + 'unsqueeze', inputs=val_x, output=node, param_attr=attr) @print_mapping_info def Shrink(self, node): @@ -783,9 +796,6 @@ class OpSet9(): param_attr=None) else: input_inner_indices = node.layer_name + '_input_inner_indices' - print('val_x shape:', val_x.out_shapes[0]) - print('indices shape:', indices.out_shapes[0]) - print('updates shape:', updates.out_shapes[0]) node.fluid_code.add_layer( 'scatter_nd', inputs={ @@ -1037,28 +1047,11 @@ class OpSet9(): node.fluid_code.add_layer( 'cast', inputs=val_input, output=node, param_attr=attr) - @print_mapping_info - def Cast(self, node): - val_input = self.graph.get_input_node(node, idx=0, copy=True) - val_output = self.graph.get_node(node.layer.output[0], copy=True) - - dtype = node.get_attr('to') - if not isinstance(dtype, np.dtype): - dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] - - output_dtype = val_output.dtype - if output_dtype: - assert dtype == output_dtype, 'dtype of to unmatches output' - attr = {'dtype': string(dtype)} - node.fluid_code.add_layer( - 'cast', inputs=val_input, output=node, param_attr=attr) - @print_mapping_info def Not(self, node): val_input = self.graph.get_input_node(node, idx=0, copy=True) node.fluid_code.add_layer('logical_not', inputs=val_input, output=node) - val_output = self.graph.get_node(node.layer.output[0], copy=True) node.fluid_code.add_layer( 'cast', inputs=node, -- GitLab