diff --git a/README.md b/README.md index 96b1f83561f1b0087e7fda6fd98027b7b78c151c..05bd30e0af7cab9b3bedb079c513cc1d506cbcac 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ paddlepaddle >= 1.8.0 **按需安装以下依赖** tensorflow : tensorflow == 1.14.0 caffe : 无 -onnx : onnx == 1.6.0 +onnx : onnx >= 1.6.0 ## 安装 ### 安装方式一(推荐) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 8c1198c1aee8faf37938c6b67aa7df1b3806a1b2..c3ba7220ac8d7562705c27951bbff1098f5aee9a 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -170,8 +170,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False): try: import onnx version = onnx.version.version - if version != '1.6.0': - print("[ERROR] onnx==1.6.0 is required") + if version < '1.6.0': + print("[ERROR] onnx>=1.6.0 is required") return except: print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".") diff --git a/x2paddle/decoder/onnx_shape_inference.py b/x2paddle/decoder/onnx_shape_inference.py index 910bf2dbfead6f5ec292af1302926fff02315cf3..ff3fe71c32a6a435f171cc76321ce8acac3c37d3 100644 --- a/x2paddle/decoder/onnx_shape_inference.py +++ b/x2paddle/decoder/onnx_shape_inference.py @@ -267,8 +267,9 @@ class SymbolicShapeInference: if pending_nodes and self.verbose_ > 0: print('SymbolicShapeInference: orphaned nodes discarded: ') - print('\n'.join( - [n.op_type + ': ' + n.output[0] for n in pending_nodes])) + for n in pending_nodes: + print(n.op_type + ': ' + n.output[0]) + if input_shapes is not None: for input_name, shape in input_shapes.items(): for idx in range(len(self.out_mp_.graph.input)): diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index c08be94567bc9a7fd1caa9c206d2f6cf6bd30a92..20538cc7051b725abc09d728610e9caf3b13a0d2 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -487,16 +487,6 @@ class OpSet9(): node.fluid_code.add_layer( 'hard_shrink', inputs=val_x, output=node, param_attr=attr) - def Greater(self, node): - val_x = self.graph.get_input_node(node, idx=0, copy=True) - val_y = self.graph.get_input_node(node, idx=1, copy=True) - node.fluid_code.add_layer( - 'greater_than', - inputs={'x': val_x, - 'y': val_y}, - output=node, - param_attr=None) - @print_mapping_info def Constant(self, node): val_output = self.graph.get_node(node.layer.output[0], copy=True) @@ -566,25 +556,26 @@ class OpSet9(): def Expand(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True) - if len(val_shape.outputs) == 1: self.omit_nodes.append(val_shape.layer_name) - - val_y = self.graph.get_node(node.layer.output[0], copy=True) - out_shape = node.out_shapes[0] val_x_dtype = val_x.dtype - name_ones = node.layer_name + '_ones' - attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)} + attr_ones = { + 'shape': val_shape.layer_name, + 'dtype': string(val_x_dtype), + 'value': 1 + } node.fluid_code.add_layer( - 'ones', inputs=None, output=name_ones, param_attr=attr_ones) + 'fill_constant', + inputs=None, + output=name_ones, + param_attr=attr_ones) inputs = {'x': name_ones, 'y': val_x} - attr = {'name': string(node.layer_name)} node.fluid_code.add_layer( 'elementwise_mul', inputs=inputs, output=node.layer_name, - param_attr=attr) + param_attr=None) @print_mapping_info def Gather(self, node): @@ -652,9 +643,15 @@ class OpSet9(): elif axis == 0 and len(indices_shape) > 1: if val_x.out_shapes[0] is not None and isinstance( val_x, ONNXGraphDataNode): + indices_cast = indices.layer_name + '_cast' node.fluid_code.add_layer( - 'embedding', + 'cast', inputs=indices, + output=indices_cast, + param_attr={'dtype': string('int64')}) + node.fluid_code.add_layer( + 'embedding', + inputs=indices_cast, output=node, use_fluid=True, param_attr={ @@ -663,7 +660,6 @@ class OpSet9(): }) else: from functools import reduce - #indices_shape = [1,7] reshape_shape = reduce(lambda x, y: x * y, indices_shape) indices_reshape = indices.layer_name + '_shape' node.fluid_code.add_layer( @@ -703,7 +699,7 @@ class OpSet9(): perm = list(range(len(val_x.out_shapes[0]))) perm = [axis] + perm[:axis] + perm[axis + 1:] attr_trans = {'perm': perm} - name_trans = val_x.layer_name + '_trans' + name_trans = val_x.layer_name + '_transpose' node.fluid_code.add_layer( 'transpose', inputs=val_x, @@ -715,8 +711,12 @@ class OpSet9(): 'index': indices_reshape}, output=node, param_attr=None) + input_transpose = node.layer_name + '_transpose' node.fluid_code.add_layer( - 'transpose', inputs=node, output=node, param_attr=attr_trans) + 'transpose', + inputs=node, + output=input_transpose, + param_attr=attr_trans) val_x_shape = val_x.out_shapes[0] reshaped_shape = [] for i in perm: @@ -725,7 +725,7 @@ class OpSet9(): reshaped_shape.append(i) node.fluid_code.add_layer( 'reshape', - inputs=node, + inputs=input_transpose, output=node, param_attr={'shape': reshaped_shape}) @@ -859,17 +859,21 @@ class OpSet9(): } else: if starts.dtype != 'int32': + starts_cast = starts.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=starts, - output=starts, + output=starts_cast, param_attr={'dtype': string('int32')}) + attr['starts'] = starts_cast if ends.dtype != 'int32': + ends_cast = ends.layer_name + '_cast' node.fluid_code.add_layer( 'cast', inputs=ends, - output=ends, + output=ends_cast, param_attr={'dtype': string('int32')}) + attr['ends'] = ends_cast else: starts = node.get_attr('starts') ends = node.get_attr('ends') @@ -1138,7 +1142,7 @@ class OpSet9(): x_shape = val_x.out_shapes[0] y_shape = val_y.out_shapes[0] inputs = {"x": val_x, "y": val_y} - if y_shape[0] == 1 and x_shape[-1] != 1: + if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1: y_squeeze = val_y.layer_name + '_squeeze' node.fluid_code.add_layer( "squeeze", @@ -1286,7 +1290,6 @@ class OpSet9(): 'y': cast_condition}, output=mul_val_x, param_attr=None) - mul_val_y = val_y.layer_name + '_mul' node.fluid_code.add_layer( "elementwise_mul", @@ -1339,7 +1342,8 @@ class OpSet9(): if val_repeats.dtype != 'int32': attr = {"dtype": string("int32")} node.fluid_code.add_layer( - "cast", inputs=repeats, + "cast", + inputs=repeats, output="{}.tmp".format(repeats), param_attr=attr) repeats = "{}.tmp".format(repeats)