diff --git a/tools/check_code_style.sh b/tools/check_code_style.sh index 235c291e6800fe2c668786c8d422a791aa08c0b2..fa5126cd537ae6c299dfa73cdff9743586c0d738 100644 --- a/tools/check_code_style.sh +++ b/tools/check_code_style.sh @@ -7,6 +7,7 @@ function abort(){ trap 'abort' 0 set -e +TRAVIS_BUILD_DIR=${PWD} cd $TRAVIS_BUILD_DIR export PATH=/usr/bin:$PATH pre-commit install diff --git a/x2paddle/op_mapper/onnx_op_mapper.py b/x2paddle/op_mapper/onnx_op_mapper.py index 161c44492b1cf92d3fb2702f2604ee24ca63e08b..48245cb8e04e4d4bd066450db87a7fb1bd2f1fcc 100644 --- a/x2paddle/op_mapper/onnx_op_mapper.py +++ b/x2paddle/op_mapper/onnx_op_mapper.py @@ -51,14 +51,15 @@ def get_same_padding(in_size, kernel_size, stride): return [pad0, pad1] -class ONNXOpMapper(OpMapper): +class ONNXOpMapper(OpMapper): elementwise_ops = { 'Add': 'elementwise_add', 'Div': 'elementwise_div', 'Sub': 'elementwise_sub', 'Mul': 'elementwise_mul', - 'Pow': 'elementwise_pow',} - + 'Pow': 'elementwise_pow', + } + def __init__(self, decoder, save_dir): super(ONNXOpMapper, self).__init__() self.decoder = decoder @@ -70,10 +71,10 @@ class ONNXOpMapper(OpMapper): self.is_inference = False self.tmp_data_dir = os.path.join(save_dir, 'tmp_data') self.get_output_shapes() - + if not self.op_checker(): raise Exception("Model are not supported yet.") - + #mapping op print("Total nodes: {}".format( sum([ @@ -160,8 +161,8 @@ class ONNXOpMapper(OpMapper): for opt in layer.output: if opt in value_infos: value_info = value_infos[opt] - if len(value_info['shape'] - ) == 0 or value_info['dtype'] is None or 0 in value_info['shape']: + if len(value_info['shape']) == 0 or value_info[ + 'dtype'] is None or 0 in value_info['shape']: if self.is_inference == False: self.get_results_of_inference( onnx_model, value_infos, @@ -258,25 +259,25 @@ class ONNXOpMapper(OpMapper): if child_func_code is not None: self.used_custom_layers[op + '_child_func'] = child_func_code + def elementwise_map(self, node): assert node.layer_type in self.elementwise_ops op_type = self.elementwise_ops[node.layer_type] val_x = self.graph.get_input_node(node, idx=0, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True) - - if len(val_x.out_shapes[0]) 0: val_y_reshaped = val_y_shape[slice_idx:] @@ -351,7 +352,7 @@ class ONNXOpMapper(OpMapper): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_scales = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_node(node.layer.output[0], copy=True) - + out_shape_ = val_y.out_shapes[0] if out_shape_ is not None: assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported' @@ -375,17 +376,19 @@ class ONNXOpMapper(OpMapper): assert len( in_shape) == 4, 'only 4-D Tensor as X and Y supported' out_shape_ = [in_shape[2] * scale, in_shape[3] * scale] - + mode = node.get_attr('mode', 'nearest') - + fluid_op = 'resize_{}'.format(mode) if 'linear' in mode: - print('Warnning: paddle not support resize wiht mode: linear, we use bilinear replace linear') + print( + 'Warnning: paddle not support resize wiht mode: linear, we use bilinear replace linear' + ) fluid_op = 'resize_bilinear' - + if isinstance(val_scales, ONNXGraphNode): scale, _, _ = self.get_dynamic_shape(val_scales.layer_name) - + attr = { 'scale': scale, 'out_shape': out_shape, @@ -446,12 +449,12 @@ class ONNXOpMapper(OpMapper): def Unsqueeze(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) axes = node.get_attr('axes') - - if len(val_x.out_shapes[0])==0: + + if len(val_x.out_shapes[0]) == 0: node.fluid_code.add_layer('assign', - inputs=val_x, - output=node, - param_attr=None) + inputs=val_x, + output=node, + param_attr=None) else: attr = {'axes': axes, 'name': string(node.layer_name)} node.fluid_code.add_layer('unsqueeze', @@ -459,9 +462,6 @@ class ONNXOpMapper(OpMapper): output=node, param_attr=attr) - - - def Shrink(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) bias = node.get_attr('bias') @@ -845,7 +845,6 @@ class ONNXOpMapper(OpMapper): output=node, param_attr=attr) - def Sum(self, node): val_inps = node.layer.input inputs = {