diff --git a/op_list.md b/op_list.md
index aff73baf9a9f9447e215972c44479709b8a4c2d0..8b47b62b0300d9532a4fba974f923b758a2154e2 100644
--- a/op_list.md
+++ b/op_list.md
@@ -1,5 +1,5 @@
 # List of OPs supported by X2Paddle
-> X2Paddle currently supports 50+ TensorFlow OPs and 30+ Caffe Layers, covering most operations commonly used by CV classification models. The lists below give all OPs currently supported by X2Paddle.
+> X2Paddle currently supports 70+ TensorFlow OPs and 30+ Caffe Layers, covering most operations commonly used by CV classification models. The lists below give all OPs currently supported by X2Paddle.
 
 **Note:** Some OPs are not yet supported. If you encounter an unsupported OP during conversion, you can add it yourself or report it to us. Feel free to tell us via [issue feedback](https://github.com/PaddlePaddle/X2Paddle/issues/new) (model name, plus the code implementation or a way to obtain the model), and we will follow up promptly :)
 
@@ -7,20 +7,24 @@
 | No. | OP | No. | OP | No. | OP | No. | OP |
 |------|------|------|------|------|------|------|------|
-| 1 | Relu | 2 | Relu6 | 3 | Shape | 4 | Abs |
-| 5 | Sigmoid | 6 | Exp | 7 | Rsqrt | 8 | swish_f32 |
-| 9 | Tanh | 10 | LeakyRelu | 11 | Add | 12 | RealDiv |
-| 13 | Sub | 14 | Maximum | 15 | Mul | 16 | FloorDiv |
-| 17 | Placeholder | 18 | Const | 19 | Transpose | 20 | FusedBatchNorm |
-| 21 | Conv2D | 22 | BiasAdd | 23 | MaxPool | 24 | DepthwiseConv2dNative |
-| 25 | Reshape | 26 | AvgPool | 27 | SplitV | 28 | SquaredDifference |
-| 29 | Tile | 30 | Pack | 31 | Pad | 32 | ResizeBilinear |
-| 33 | Mean | 34 | MatMul | 35 | ArgMax | 36 | StridedSlice |
-| 37 | Slice | 38 | Sum | 39 | Max | 40 | Conv2DBackpropInput |
-| 41 | Cast | 42 | Split | 43 | Squeeze | 44 | ResizeNearestNeighbor |
-| 45 | Softmax | 46 | Range | 47 | ConcatV2 | 48 | MirrorPad |
-| 49 | Identity | 50 | GreaterEqual | 51 | StopGradient | 52 | Minimum |
-| 53 | RadnomUniform | 54 | Fill | 55 | Floor | 56 | DepthToSpace |
+| 1 | Relu | 2 | Relu6 | 3 | Shape | 4 | Abs |
+| 5 | Sigmoid | 6 | Exp | 7 | Rsqrt | 8 | swish_f32 |
+| 9 | Tanh | 10 | LeakyRelu | 11 | Add | 12 | RealDiv |
+| 13 | Sub | 14 | Maximum | 15 | Mul | 16 | FloorDiv |
+| 17 | Placeholder | 18 | Const | 19 | Transpose | 20 | FusedBatchNorm |
+| 21 | Conv2D | 22 | BiasAdd | 23 | MaxPool | 24 | DepthwiseConv2dNative |
+| 25 | Reshape | 26 | AvgPool | 27 | SplitV | 28 | SquaredDifference |
+| 29 | Tile | 30 | Pack | 31 | Pad | 32 | ResizeBilinear |
+| 33 | Mean | 34 | MatMul | 35 | ArgMax | 36 | StridedSlice |
+| 37 | Slice | 38 | Sum | 39 | Max | 40 | Conv2DBackpropInput |
+| 41 | Cast | 42 | Split | 43 | Squeeze | 44 | ResizeNearestNeighbor |
+| 45 | Softmax | 46 | Range | 47 | ConcatV2 | 48 | MirrorPad |
+| 49 | Identity | 50 | GreaterEqual | 51 | StopGradient | 52 | Minimum |
+| 53 | RandomUniform | 54 | Fill | 55 | Floor | 56 | DepthToSpace |
+| 57 | Sqrt | 58 | Softplus | 59 | Erf | 60 | AddV2 |
+| 61 | LessEqual | 62 | BatchMatMul | 63 | BatchMatMulV2 | 64 | ExpandDims |
+| 65 | BatchToSpaceND | 66 | SpaceToBatchND | 67 | OneHot | 68 | Pow |
+| 69 | All | 70 | GatherV2 | 71 | IteratorV2 | | |
 
 
 ## Caffe
diff --git a/x2paddle/decoder/onnx_shape_inference.py b/x2paddle/decoder/onnx_shape_inference.py
index dae2d0268a00dcac0b1d4a6c796f55431a985ad3..910bf2dbfead6f5ec292af1302926fff02315cf3 100644
--- a/x2paddle/decoder/onnx_shape_inference.py
+++ b/x2paddle/decoder/onnx_shape_inference.py
@@ -267,9 +267,8 @@ class SymbolicShapeInference:
 
         if pending_nodes and self.verbose_ > 0:
             print('SymbolicShapeInference: orphaned nodes discarded: ')
-            print(
-                * [n.op_type + ': ' + n.output[0] for n in pending_nodes],
-                sep='\n')
+            print('\n'.join(
+                [n.op_type + ': ' + n.output[0] for n in pending_nodes]))
         if input_shapes is not None:
             for input_name, shape in input_shapes.items():
                 for idx in range(len(self.out_mp_.graph.input)):
diff --git a/x2paddle/op_mapper/caffe_custom_layer/normalize.py b/x2paddle/op_mapper/caffe_custom_layer/normalize.py
index 15bbb2043a1e385d987c577d099c8db902de3f26..19c583acc7ab73cd63d0bc7f373488e9437aec4c 100644
--- a/x2paddle/op_mapper/caffe_custom_layer/normalize.py
+++ b/x2paddle/op_mapper/caffe_custom_layer/normalize.py
@@ -17,7 +17,7 @@ def normalize_layer(inputs,
     scale_param = fluid.layers.create_parameter(
         shape=[1] if channel_shared else [1, 1, 1, input_shape[0][1]],
         dtype=input.dtype,
-        attr=name + '_scale')
+        attr=fluid.ParamAttr(name=name + '_scale'))
     scale_param = fluid.layers.reshape(x=scale_param, \
         shape=[1] if channel_shared else [input_shape[0][1]])
     out = fluid.layers.elementwise_mul(
diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
index 3cb60ffa4145a455f0552c092ae31a324b3702ad..534d264a1f0e47e20de976e4bdf23dfc440b7606 100644
--- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py
@@ -32,15 +32,33 @@ import shutil
 
 _logger = _logging.getLogger(__name__)
 
-def _const_weight_or_none(node):
+def _const_weight_or_none(node, necessary=False):
     if 'Constant' in node.layer_type:
         return node.value
     if isinstance(node, ONNXGraphDataNode):
         return node.weight
+    if necessary:
+        assert False, '{} should be an initializer or Constant operator.'.format(
+            node.layer_name)
     return None
 
 
-def get_same_padding(in_size, kernel_size, stride):
+def _is_static_shape(shape):
+    negative_dims = 0
+    error_dims = 0
+    for dim in shape:
+        if dim < 0:
+            negative_dims += 1
+        if dim < -1:
+            error_dims += 1
+    if negative_dims > 1:
+        return False
+    if error_dims > 0:
+        return False
+    return True
+
+
+def _get_same_padding(in_size, kernel_size, stride):
     new_size = int(math.ceil(in_size * 1.0 / stride))
     pad_size = (new_size - 1) * stride + kernel_size - in_size
     pad0 = int(pad_size / 2)
@@ -228,42 +246,9 @@ class OpSet9():
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_input_node(node, idx=1, copy=True)
-        val_y_shape = val_y.out_shapes[0]
-        val_x_shape = val_x.out_shapes[0]
-
-        if len(val_x_shape) < len(val_y_shape):
-            val_x, val_y = val_y, val_x
-            val_y_shape, val_x_shape = val_x_shape, val_y_shape
-
-        str_y_shape = ','.join(str(e) for e in val_y_shape)
-        str_x_shape = ','.join(str(e) for e in val_x_shape)
-        slice_idx = 0
-        if str_y_shape not in str_x_shape:
-            for dim in val_y_shape:
-                if dim == 1:
-                    slice_idx += 1
-                else:
-                    break
-        attr = {"name": string(node.layer_name)}
-        if slice_idx < len(val_y_shape) and slice_idx > 0:
-            val_y_reshaped = val_y_shape[slice_idx:]
-            var_y_reshaped = val_y.layer_name + '_reshaped'
-            attr_reshaped = {
-                'shape': val_y_reshaped,
-                'name': string(var_y_reshaped)
-            }
-            node.fluid_code.add_layer(
-                'reshape',
-                inputs=val_y,
-                output=var_y_reshaped,
-                param_attr=attr_reshaped)
-            inputs = {'x': val_x, 'y': var_y_reshaped}
-            node.fluid_code.add_layer(
-                op_type, inputs=inputs, output=node, param_attr=attr)
-        else:
-            inputs = {'x': val_x, 'y': val_y}
-            node.fluid_code.add_layer(
-                op_type, inputs=inputs, output=node, param_attr=attr)
+        inputs = {'x': val_x, 'y': val_y}
+        node.fluid_code.add_layer(
+            op_type, inputs=inputs, output=node, param_attr=None)
 
     @print_mapping_info
     def place_holder(self, node):
@@ -476,8 +461,21 @@
                     output=node,
                     param_attr={'shape': [1]})
         else:
-            node.fluid_code.add_layer(
-                'unsqueeze', inputs=val_x, output=node, param_attr=attr)
+            if str(val_x.dtype) == 'bool':
+                val_x_cast = val_x.layer_name + '_cast'
+                node.fluid_code.add_layer(
+                    'cast',
+                    inputs=val_x,
+                    output=val_x_cast,
+                    param_attr={'dtype': string('int64')})
+                node.fluid_code.add_layer(
+                    'unsqueeze',
+                    inputs=val_x_cast,
+                    output=node,
+                    param_attr=attr)
+            else:
+                node.fluid_code.add_layer(
+                    'unsqueeze', inputs=val_x, output=node, param_attr=attr)
 
     @print_mapping_info
     def Shrink(self, node):
@@ -597,12 +595,35 @@
         #assert len(
         #    indices_shape) <= 2, "Gather op don't support dim of indice >2 "
         if axis == 0 and len(indices_shape) <= 1:
-            node.fluid_code.add_layer(
-                'gather',
-                inputs={'input': val_x,
-                        'index': indices},
-                output=node,
-                param_attr=None)
+            if len(val_x.out_shapes[0]) <= 1:
+                node.fluid_code.add_layer(
+                    'gather',
+                    inputs={'input': val_x,
+                            'index': indices},
+                    output=node,
+                    param_attr=None)
+            elif len(val_x.out_shapes[0]) > 1:
+                if len(indices_shape) == 0:
+                    gather_ = node.layer_name + '_1'
+                    node.fluid_code.add_layer(
+                        'gather',
+                        inputs={'input': val_x,
+                                'index': indices},
+                        output=gather_,
+                        param_attr=None)
+                    node.fluid_code.add_layer(
+                        'squeeze',
+                        inputs={'input': gather_,
+                                'axes': [0]},
+                        output=node,
+                        param_attr=None)
+                else:
+                    node.fluid_code.add_layer(
+                        'gather',
+                        inputs={'input': val_x,
+                                'index': indices},
+                        output=node,
+                        param_attr=None)
         elif axis > 0 and len(indices_shape) <= 1:
             perm = list(range(len(val_x.out_shapes[0])))
             perm = [axis] + perm[:axis] + perm[axis + 1:]
@@ -621,6 +642,13 @@
                 param_attr=None)
             node.fluid_code.add_layer(
                 'transpose', inputs=node, output=node, param_attr=attr_trans)
+            if len(indices_shape) < 1:
+                node.fluid_code.add_layer(
+                    'squeeze',
+                    inputs={'input': node,
+                            'axes': [0]},
+                    output=node,
+                    param_attr=None)
         elif axis == 0 and len(indices_shape) > 1:
             if val_x.out_shapes[0] is not None and isinstance(
                     val_x, ONNXGraphDataNode):
@@ -701,6 +729,86 @@
                 output=node,
                 param_attr={'shape': reshaped_shape})
 
+    @print_mapping_info
+    def ScatterND(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        indices = self.graph.get_input_node(node, idx=1, copy=True)
+        updates = self.graph.get_input_node(node, idx=2, copy=True)
+        if len(indices.out_shapes[0]) == 1:
+            node.fluid_code.add_layer(
+                'scatter',
+                inputs={'input': val_x,
+                        'index': indices,
+                        'updates': updates},
+                output=node,
+                param_attr=None)
+        else:
+            input_inner_indices = node.layer_name + '_input_inner_indices'
+            node.fluid_code.add_layer(
+                'scatter_nd',
+                inputs={
+                    'shape': val_x.out_shapes[0],
+                    'index': indices,
+                    'updates': updates
+                },
+                output=input_inner_indices,
+                param_attr=None)
+
+            constant_minus_one = node.layer_name + '_constant_minus_one'
+            node.fluid_code.add_layer(
+                'fill_constant',
+                inputs=None,
+                output=constant_minus_one,
+                param_attr={
+                    'shape': updates.out_shapes[0],
+                    'dtype': string(updates.dtype),
+                    'value': -1
+                })
+
+            indices_mask = node.layer_name + '_indices_mask'
+            node.fluid_code.add_layer(
+                'scatter_nd',
+                inputs={
+                    'shape': val_x.out_shapes[0],
+                    'index': indices,
+                    'updates': constant_minus_one
+                },
+                output=indices_mask,
+                param_attr=None)
+
+            constant_1 = node.layer_name + '_constant_1'
+            node.fluid_code.add_layer(
+                'fill_constant',
+                inputs=None,
+                output=constant_1,
+                param_attr={
+                    'shape': val_x.out_shapes[0],
+                    'dtype': string(val_x.dtype),
+                    'value': 1
+                })
+            input_out_indices_mask = node.layer_name + '_input_out_indices_mask'
+            node.fluid_code.add_layer(
+                "elementwise_add",
+                inputs={"x": indices_mask,
+                        "y": constant_1},
+                output=input_out_indices_mask,
+                param_attr=None)
+
+            input_out_indices = node.layer_name + '_input_out_indices'
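+            # input_out_indices_mask is 0 at every position touched by indices
+            # and 1 elsewhere, so the multiply below keeps the values of val_x
+            # only at positions that receive no update.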
+            node.fluid_code.add_layer(
+                "elementwise_mul",
+                inputs={"x": val_x,
+                        "y": input_out_indices_mask},
+                output=input_out_indices,
+                param_attr=None)
+
+            node.fluid_code.add_layer(
+                "elementwise_add",
+                inputs={"x": input_inner_indices,
+                        "y": input_out_indices},
+                output=node,
+                param_attr=None)
+
     @print_mapping_info
     def Range(self, node):
         val_start = self.graph.get_input_node(node, idx=0, copy=True)
@@ -724,7 +832,7 @@
         ends = self.graph.get_input_node(node, idx=2, copy=True)
         if len(node.inputs) > 3:
             axes = self.graph.get_input_node(node, idx=3, copy=True)
-            axes = _const_weight_or_none(axes)
+            axes = _const_weight_or_none(axes, necessary=True)
         if len(node.inputs) > 4:
             steps = self.graph.get_input_node(node, idx=4, copy=True)
             steps = _const_weight_or_none(steps)
@@ -828,6 +936,14 @@
                 inputs={'x': val_x},
                 output=node,
                 param_attr={'shape': shape_value.tolist()})
+        elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
+                0]):
+            node.fluid_code.add_layer(
+                'reshape',
+                inputs={'x': val_x,
+                        'shape': node.out_shapes[0]},
+                output=node,
+                param_attr=attr)
         elif val_shape.dtype == 'int64':
             val_shape_cast = val_shape.layer_name + '_cast'
             node.fluid_code.add_layer(
@@ -879,6 +995,11 @@
             node.fluid_code.add_layer(
                 'cast', inputs=val_input, output=node, param_attr=attr)
 
+    @print_mapping_info
+    def Not(self, node):
+        val_input = self.graph.get_input_node(node, idx=0, copy=True)
+        node.fluid_code.add_layer('logical_not', inputs=val_input, output=node)
+
     @print_mapping_info
     def AveragePool(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
@@ -897,11 +1018,11 @@
 
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
             input_shape = val_x.out_shapes[0]
-            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
-                                     strides[0])
-            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
-                                     strides[1])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
+                                      strides[0])
+            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
+                                      strides[1])
+            paddings = pad_h + pad_w
 
         attr = {
             "pool_size": kernel_shape,
@@ -1171,7 +1292,6 @@
     def NonZero(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_x_dim = len(val_x.out_shapes[0])
-        print(val_x.layer_name, val_x.out_shapes[0])
         if val_x_dim == 1:
             node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
             node.fluid_code.add_layer(
@@ -1232,11 +1352,11 @@
 
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
             input_shape = val_x.out_shapes[0]
-            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
-                                     strides[0])
-            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
-                                     strides[1])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
+                                      strides[0])
+            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
+                                      strides[1])
+            paddings = pad_h + pad_w
 
         attr = {
             "pool_size": kernel_shape,
@@ -1293,23 +1413,23 @@
         kernel_shape = node.get_attr('kernel_shape')
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
-        num_out_channels = val_w.out_shapes[0][0]  # OI...
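+        # ONNX Conv weights are laid out as (num_out, num_in / group, kH, kW),
+        # so the output channel count is the first dimension.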
+        num_out_channels = val_w.out_shapes[0][0]
         fluid_op = 'conv{}d'.format(convnd)
 
         num_groups = node.get_attr('group', 1)
-        strides = node.get_attr('strides', [1] * convnd)  # optional
-        dilations = node.get_attr('dilations', [1] * convnd)  # optional
-        pads = node.get_attr('pads', [0] * (convnd * 2))  # optional
+        strides = node.get_attr('strides', [1] * convnd)
+        dilations = node.get_attr('dilations', [1] * convnd)
+        pads = node.get_attr('pads', [0] * (convnd * 2))
 
         input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
 
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
-            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
-                                     strides[0])
-            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
-                                     strides[1])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
+                                      strides[0])
+            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
+                                      strides[1])
+            paddings = pad_h + pad_w
 
         attr = {
             "num_filters": num_out_channels,
@@ -1379,183 +1499,3 @@
         }
         node.fluid_code.add_layer(
             fluid_op, inputs=val_x, output=node, param_attr=attr)
-
-    @print_mapping_info
-    def GRU(self, node):
-        val_x = self.graph.get_input_node(node, idx=0, copy=True)
-        val_w = self.graph.get_input_node(node, idx=1, copy=True)
-        val_r = self.graph.get_input_node(node, idx=2, copy=True)
-
-        val_b = None
-        val_len = None
-        val_xh = None
-        miss_arg_num = 0
-        num_ipt = len(node.layer.input)
-        if num_ipt > 3 and node.layer.input[3] != '':
-            val_b = self.graph.get_input_node(node, idx=3, copy=True)
-        else:
-            miss_arg_num += 1
-        if num_ipt > 4 and node.layer.input[4] != '':
-            val_len = self.graph.get_input_node(
-                node, idx=4 - miss_arg_num, copy=True)
-        else:
-            miss_arg_num += 1
-        if num_ipt > 5 and node.layer.input[5] != '':
-            val_xh = self.graph.get_input_node(
-                node, idx=5 - miss_arg_num, copy=True)
-
-        x_shape = val_x.out_shapes[0]
-
-        assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
-        assert node.get_attr('clip', None) is None, 'clipping not supported'
-
-        hidden_size = node.get_attr('hidden_size', None)
-        if hidden_size is None:
-            r_shape = val_r.out_shapes[0]
-            if r_shape:
-                hidden_size = r_shape[-1]
-        if hidden_size is None:
-            w_shape = var_w.out_shapes[0]
-            if w_shape:
-                hidden_size = w_shape[-2] // 3
-        if hidden_size is None and val_b:
-            b_shape = val_b.out_shapes[0]
-            if b_shape:
-                hidden_size = b_shape[-1] // 6
-        if hidden_size is None and val_xh:
-            xh_shape = val_xh.out_shapes[0]
-            if xh_shape:
-                hidden_size = xh_shape[-1]
-
-        direction = node.get_attr('direction', 'forward')
-        assert direction != 'bidirectional', 'direction = bidirectional not supported'
-
-        activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
-        assert len(activations) == 2, 'bidirectional operation not supported'
-
-        assert node.get_attr('linear_before_reset',
-                             0) == 0, 'only linear_before_reset = 0 supported'
-
-        activations = [s.lower() for s in activations]
-        gate_activation, candidate_activation = activations
-        is_reverse = direction == 'reverse'
-
-        var_x0 = node.layer_name + '_x0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_x,
-            output=var_x0,
-            param_attr={'axes': [1],
-                        'name': string(var_x0)})
-
-        var_w0 = node.layer_name + '_w0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_w,
-            output=var_w0,
-            param_attr={'axes': [0],
-                        'name': string(var_w0)})
-
-        var_fc = node.layer_name + '_fc'
-        var_mm = (node.layer_name + '_mm') if val_b else var_fc
-        node.fluid_code.add_layer(
-            'matmul',
-            inputs={'x': var_x0,
-                    'y': var_w0},
-            output=var_mm,
-            param_attr={
-                'transpose_x': 0,
-                'transpose_y': 1,
-                'name': string(var_mm)
-            })
-
-        var_r0 = node.layer_name + '_r0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_r,
-            output=var_r0,
-            param_attr={'axes': [0],
-                        'name': string(var_r0)})
-
-        var_r0t = node.layer_name + '_r0t'
-
-        node.fluid_code.add_layer(
-            'transpose',
-            inputs=var_r0,
-            output=var_r0t,
-            param_attr={'perm': [1, 0],
-                        'name': string(var_r0t)})
-        if val_b:
-            var_bi = node.layer_name + '_bi'
-            var_bh = node.layer_name + '_bh'
-            node.fluid_code.add_layer(
-                'split',
-                inputs=val_b,
-                output=var_bi + ',' + var_bh,
-                param_attr={
-                    'dim': 1,
-                    'num_or_sections': [hidden_size * 3, hidden_size * 3],
-                    'name': string(node.layer_name + '.b/split')
-                })
-            var_bi0 = node.layer_name + '_bi0'
-            node.fluid_code.add_layer(
-                'squeeze',
-                inputs=var_bi,
-                output=var_bi0,
-                param_attr={'axes': [0],
-                            'name': string(var_bi0)})
-
-            node.fluid_code.add_layer(
-                'elementwise_add',
-                inputs=[var_mm, var_bi0],
-                output=var_fc,
-                param_attr={
-                    'axes': 1,
-                    'name': string(node.layer_name + '.i/bias')
-                })
-
-        if val_xh:
-            var_xh0 = node.layer_name + '_xh0'
-            node.fluid_code.add_layer(
-                'squeeze',
-                inputs=val_xh,
-                output=var_xh0,
-                param_attr={'axes': [1],
-                            'name': string(var_xh0)})
-        var_y00 = node.layer_name + '_y00'
-
-        attr = {
-            'origin_mode': True,
-            'h_0': var_xh0 if val_xh else None,
-            'is_reverse': is_reverse,
-            'gate_activation': string(gate_activation),
-            'candidate_activation': string(candidate_activation),
-            'param_attr': string(var_r0t),
-            'bias_attr': string(var_bh) if val_b else False,
-        }
-        node.fluid_code.add_layer(
-            'dynamic_gru',
-            inputs=var_fc + ',' + str(hidden_size),
-            output=var_y00,
-            param_attr=attr)
-
-        num_opt = len(node.layer.output)
-
-        if num_opt > 0 and node.layer.output[0] != '':
-            node.fluid_code.add_layer(
-                'unsqueeze',
-                inputs=var_y00,
-                output=node.layer.output[0],
-                param_attr={
-                    'axes': [1, 1],
-                    'name': string(node.layer.output[0])
-                })
-        if num_opt > 1 and node.layer.output[1] != '':
-            node.fluid_code.add_layer(
-                'unsqueeze',
-                inputs=var_y00,
-                output=node.layer.output[1],
-                param_attr={
-                    'axes': [1, 1],
-                    'name': string(node.layer.output[1])
-                })
diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py
index 80bb56863ad1a149ccc216c6a087b1a249f94e4f..3ef8523a519569840c8883754b9ac3c4922ba37d 100644
--- a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py
+++ b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py
@@ -875,6 +875,14 @@ class OpSet9(object):
                 axes=op.attr('axes'))
         return node
 
+    def cast(self, op, block):
+        node = helper.make_node(
+            'Cast',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            to=self.paddle_onnx_dtype_map[op.attr('out_dtype')])
+        return node
+
     def arg_max(self, op, block):
         node = helper.make_node(
             'ArgMax',
diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
index 2bb0adeac0c7cfc11030be01d7366b21cc86d6bd..a5198cc780203722042b8ae043fb169b94eeb3be 100644
--- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py
+++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
@@ -299,6 +299,10 @@ class TFOpMapperNHWC(OpMapper):
         data_format = node.get_attr("data_format").decode()
         pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+        if data_format == "NHWC":
+            n, h, w, c = input.out_shapes[0]
+        else:
+            n, c, h, w = input.out_shapes[0]
 
         if kernel.layer_type == 'Const':
             kernel_value = kernel.value
@@ -329,10 +333,15 @@
             "dilation": dilations[2:4],
             "padding": string(pad_mode)
         }
-
         if hasattr(node, 'dilation') and attr['dilation'] == [1, 1]:
             if len(node.dilation) == 1:
                 attr['dilation'] = [1, node.dilation[0]]
+
+        if c == -1:
+            reshape_attr = {"shape": [0, k_size[2], 0, 0]}
+            node.fluid_code.add_layer(
+                "reshape", inputs=input, output=input, param_attr=reshape_attr)
+
         node.fluid_code.add_layer(
             "conv2d", inputs=input, output=node, param_attr=attr)
         if not channel_first:
@@ -748,11 +757,12 @@
             self.add_omit_nodes(begin.layer_name, node.layer_name)
             begin = begin.value.tolist()
         else:
-            begin = begin
-            shape = begin.out_shapes[0]
-            attr = {"shape": shape}
-            node.fluid_code.add_layer(
-                "reshape", inputs=begin, output=begin, param_attr=attr)
+            begin = self.decoder.infer_tensor(begin).tolist()
+
+#            shape = begin.out_shapes[0]
+#            attr = {"shape": shape}
+#            node.fluid_code.add_layer(
+#                "reshape", inputs=begin, output=begin, param_attr=attr)
         if size.layer_type == "Const":
             self.add_omit_nodes(size.layer_name, node.layer_name)
             size = size.value.tolist()
diff --git a/x2paddle/optimizer/tf_optimizer.py b/x2paddle/optimizer/tf_optimizer.py
index 6d3c0cdd017c6d046451e5837d2b75ef649cd6a8..daeda64dceb8cdfa31cf10fa31edb0aae176170e 100644
--- a/x2paddle/optimizer/tf_optimizer.py
+++ b/x2paddle/optimizer/tf_optimizer.py
@@ -863,6 +863,9 @@ class TFOptimizer(object):
                     weight = numpy.expand_dims(weight, 2)
                     weight = numpy.expand_dims(weight, 3)
                     self.op_mapper.weights[in_nodes3[0].layer_name] = weight
+                    # Works around a bug in Paddle 1.8.3; may change in a future version.
+                    self.op_mapper.weights[in_nodes3[0].layer_name +
+                                           '_1'] = weight.reshape(1, -1)
                     in_nodes3[0].fluid_code.layers[0].param_attr["shape"] = [
                         1, in_shape[-1], 1, 1
                     ]
@@ -885,7 +888,7 @@
                 node.fluid_code.clear()
                 attr = {
                     "mode": string(mode),
-                    "param_attr": string(in_nodes3[0].layer_name)
+                    "param_attr": string(in_nodes3[0].layer_name + "_1")
                 }
 
                 node.fluid_code.add_layer(
diff --git a/x2paddle_model_zoo.md b/x2paddle_model_zoo.md
index 73c65511de064115d958483172a38949db93fe75..464abdf547b1c9c7a3698c05579131426c44b59e 100644
--- a/x2paddle_model_zoo.md
+++ b/x2paddle_model_zoo.md
@@ -1,5 +1,5 @@
 # X2Paddle model test suite
-> X2Paddle currently supports 50+ TensorFlow OPs and 40+ Caffe Layers, covering most operations commonly used by CV classification models. We have tested X2Paddle's conversion on the models listed below.
+> X2Paddle currently supports 70+ TensorFlow OPs and 40+ Caffe Layers, covering most operations commonly used by CV classification models. We have tested X2Paddle's conversion on the models listed below.
 
 **Note:** Limited by differences between frameworks, some models may currently be impossible to convert, such as TensorFlow models containing control flow, NLP models, and so on. For common CV models, if you find one that cannot be converted, fails during conversion, or shows a large diff, feel free to tell us via [issue feedback](https://github.com/PaddlePaddle/X2Paddle/issues/new) (model name, plus the code implementation or a way to obtain the model), and we will follow up promptly :)
 
@@ -20,10 +20,13 @@
 | ResNet_V1_101 | [code](https://github.com/tensorflow/models/tree/master/research/slim/nets) |-|
 | ResNet_V2_101 | [code](https://github.com/tensorflow/models/tree/master/research/slim/nets) |-|
 | UNet | [code1](https://github.com/jakeret/tf_unet )/[code2](https://github.com/lyatdawn/Unet-Tensorflow) |-|
-|MTCNN | [code](https://github.com/AITTSMD/MTCNN-Tensorflow) |-|
-|YOLO-V3| [code](https://github.com/YunYang1994/tensorflow-yolov3) | Conversion requires disabling the NHWC->NCHW optimization; see [FAQ Q2](FAQ.md) |
-| FALSR | [code](https://github.com/xiaomi-automl/FALSR) | - |
-| DCSCN | [code](https://modelzoo.co/model/dcscn-super-resolution) | - |
+| MTCNN | [code](https://github.com/AITTSMD/MTCNN-Tensorflow) |-|
+| YOLO-V3 | [code](https://github.com/YunYang1994/tensorflow-yolov3) | Conversion requires disabling the NHWC->NCHW optimization; see [FAQ Q2](FAQ.md) |
+| FALSR | [code](https://github.com/xiaomi-automl/FALSR) | Requires the without_data_format_optimization option |
+| DCSCN | [code](https://modelzoo.co/model/dcscn-super-resolution) | Requires the without_data_format_optimization option |
+| Bert(albert) | [code](https://github.com/google-research/albert#pre-trained-models) | Requires the without_data_format_optimization option |
+| Bert(chinese_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) | Requires the without_data_format_optimization option |
+| Bert(multi_cased_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) | Requires the without_data_format_optimization option |
 
 
 ## Caffe
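For reference, the `ScatterND` mapping added in `opset.py` above emulates ONNX ScatterND with two `scatter_nd` calls plus elementwise masking, since Paddle's `scatter_nd` only writes updates into a zero tensor. Below is a minimal NumPy sketch of the same trick, assuming unique indices and using a hand-rolled `scatter_nd` as a stand-in for `fluid.layers.scatter_nd`:

```python
import numpy as np

def scatter_nd(shape, index, updates):
    # Stand-in for fluid.layers.scatter_nd: a zero tensor of `shape` with
    # `updates` written at the positions in `index` (assumed unique here).
    out = np.zeros(shape, dtype=updates.dtype)
    for pos, upd in zip(index, updates):
        out[tuple(pos)] = upd
    return out

def scatter_nd_via_masking(x, index, updates):
    # inner: the new values at the target positions, zeros elsewhere.
    inner = scatter_nd(x.shape, index, updates)
    # Scatter -1 at the targets, then add 1: the mask is 0 at the targets
    # and 1 everywhere else.
    minus_one = np.full(updates.shape, -1, dtype=x.dtype)
    mask = scatter_nd(x.shape, index, minus_one) + 1
    # Zero x at the targets, keep it elsewhere, then merge in the updates.
    return inner + x * mask

x = np.arange(12, dtype=np.float32).reshape(3, 4)
index = np.array([[0, 1], [2, 3]])
updates = np.array([100.0, 200.0], dtype=np.float32)
print(scatter_nd_via_masking(x, index, updates))
```

With duplicate indices the real `scatter_nd` sums the colliding updates and the mask would dip below zero, so the decomposition assumes each target position is written at most once.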