diff --git a/x2paddle/__init__.py b/x2paddle/__init__.py
index a5f830a2c0bb0dba9efc6d960eac2adef22b6d74..bc8c296f6a56f243a9b7fbea63cd56eceeb2777d 100644
--- a/x2paddle/__init__.py
+++ b/x2paddle/__init__.py
@@ -1 +1 @@
-__version__ = "0.7.1"
+__version__ = "0.7.2"
diff --git a/x2paddle/convert.py b/x2paddle/convert.py
index c0c8fb98131b5f063b79334775ae63e27ab16c85..92ec1cf502943ee6220aac887a60bbab34f0b0cd 100644
--- a/x2paddle/convert.py
+++ b/x2paddle/convert.py
@@ -19,32 +19,38 @@ import sys
 
 def arg_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model",
-                        "-m",
-                        type=_text_type,
-                        default=None,
-                        help="define model file path for tensorflow or onnx")
-    parser.add_argument("--prototxt",
-                        "-p",
-                        type=_text_type,
-                        default=None,
-                        help="prototxt file of caffe model")
-    parser.add_argument("--weight",
-                        "-w",
-                        type=_text_type,
-                        default=None,
-                        help="weight file of caffe model")
-    parser.add_argument("--save_dir",
-                        "-s",
-                        type=_text_type,
-                        default=None,
-                        help="path to save translated model")
+    parser.add_argument(
+        "--model",
+        "-m",
+        type=_text_type,
+        default=None,
+        help="define model file path for tensorflow or onnx")
+    parser.add_argument(
+        "--prototxt",
+        "-p",
+        type=_text_type,
+        default=None,
+        help="prototxt file of caffe model")
+    parser.add_argument(
+        "--weight",
+        "-w",
+        type=_text_type,
+        default=None,
+        help="weight file of caffe model")
+    parser.add_argument(
+        "--save_dir",
+        "-s",
+        type=_text_type,
+        default=None,
+        help="path to save translated model")
     parser.add_argument(
         "--framework",
         "-f",
         type=_text_type,
         default=None,
-        help="define which deeplearning framework(tensorflow/caffe/onnx)")
+        help=
+        "define which deep learning framework (tensorflow/caffe/onnx/paddle2onnx)"
+    )
     parser.add_argument(
         "--caffe_proto",
         "-c",
@@ -52,27 +58,30 @@ def arg_parser():
         default=None,
         help="optional: the .py file compiled by caffe proto file of caffe model"
     )
-    parser.add_argument("--version",
-                        "-v",
-                        action="store_true",
-                        default=False,
-                        help="get version of x2paddle")
+    parser.add_argument(
+        "--version",
+        "-v",
+        action="store_true",
+        default=False,
+        help="get version of x2paddle")
     parser.add_argument(
         "--without_data_format_optimization",
         "-wo",
         action="store_true",
         default=False,
         help="tf model conversion without data format optimization")
-    parser.add_argument("--define_input_shape",
-                        "-d",
-                        action="store_true",
-                        default=False,
-                        help="define input shape for tf model")
-    parser.add_argument("--params_merge",
-                        "-pm",
-                        action="store_true",
-                        default=False,
-                        help="define whether merge the params")
+    parser.add_argument(
+        "--define_input_shape",
+        "-d",
+        action="store_true",
+        default=False,
+        help="define input shape for tf model")
+    parser.add_argument(
+        "--params_merge",
+        "-pm",
+        action="store_true",
+        default=False,
+        help="define whether to merge the params")
     return parser
 
 
@@ -177,6 +186,14 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
     mapper.save_inference_model(save_dir, params_merge)
 
 
+def paddle2onnx(model_path, save_dir):
+    from x2paddle.decoder.paddle_decoder import PaddleDecoder
+    from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper
+    model = PaddleDecoder(model_path, '__model__', '__params__')
+    mapper = PaddleOpMapper()
+    mapper.convert(model.program, save_dir)
+
+
 def main():
     if len(sys.argv) < 2:
         print("Use \"x2paddle -h\" to print the help information")
@@ -249,8 +266,14 @@ def main():
             if args.params_merge:
                 params_merge = True
             onnx2paddle(args.model, args.save_dir, params_merge)
+
args.framework == "paddle2onnx": + assert args.model is not None, "--model should be defined while translating paddle model to onnx" + paddle2onnx(args.model, args.save_dir) + else: - raise Exception("--framework only support tensorflow/caffe/onnx now") + raise Exception( + "--framework only support tensorflow/caffe/onnx/paddle2onnx now") if __name__ == "__main__": diff --git a/x2paddle/decoder/paddle_decoder.py b/x2paddle/decoder/paddle_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..77dae2d4eed9ed41912324d05dcf473cc7327a7e --- /dev/null +++ b/x2paddle/decoder/paddle_decoder.py @@ -0,0 +1,14 @@ +import paddle.fluid as fluid + + +class PaddleDecoder(object): + def __init__(self, + model_dir, + model_filename='__model__', + params_filename=None): + exe = fluid.Executor(fluid.CPUPlace()) + [self.program, feed, fetchs] = fluid.io.load_inference_model( + model_dir, + exe, + model_filename=model_filename, + params_filename=params_filename) diff --git a/x2paddle/op_mapper/paddle_op_mapper.py b/x2paddle/op_mapper/paddle_op_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..690237419f327ea23d19387dfbb5f9f636a65050 --- /dev/null +++ b/x2paddle/op_mapper/paddle_op_mapper.py @@ -0,0 +1,506 @@ +import math +import x2paddle +import os +import numpy as np +import paddle.fluid.core as core +import paddle.fluid as fluid +import onnx +from onnx import helper, onnx_pb + + +class PaddleOpMapper(object): + def __init__(self): + self.paddle_onnx_dtype_map = { + core.VarDesc.VarType.FP32: onnx_pb.TensorProto.FLOAT, + core.VarDesc.VarType.FP64: onnx_pb.TensorProto.DOUBLE, + core.VarDesc.VarType.INT32: onnx_pb.TensorProto.INT32, + core.VarDesc.VarType.INT16: onnx_pb.TensorProto.INT16, + core.VarDesc.VarType.INT16: onnx_pb.TensorProto.UINT16, + core.VarDesc.VarType.INT64: onnx_pb.TensorProto.INT64, + core.VarDesc.VarType.BOOL: onnx_pb.TensorProto.BOOL + } + + self.name_counter = dict() + + def get_name(self, op_name, var_name): + name = 'p2o.{}.{}'.format(op_name, var_name) + if name not in self.name_counter: + self.name_counter[name] = 0 + else: + self.name_counter[name] += 1 + return name + '.{}'.format(self.name_counter[name]) + + def make_constant_node(self, name, dtype, value=None): + if isinstance(value, list): + dims = (len(value), ) + elif value is None: + dims = () + value = [] + else: + dims = () + value = [value] + tensor = helper.make_tensor( + name=name, data_type=dtype, dims=dims, vals=value) + node = helper.make_node( + 'Constant', inputs=[], outputs=[name], value=tensor) + return node + + def conv2d(self, op, block): + kernel_shape = block.var(op.input('Filter')[0]).shape + node = helper.make_node( + 'Conv', + inputs=op.input('Input') + op.input('Filter'), + outputs=op.output('Output'), + dilations=op.attr('dilations'), + kernel_shape=kernel_shape[-2:], + strides=op.attr('strides'), + group=op.attr('groups'), + pads=op.attr('paddings') + op.attr('paddings')) + return node + + def relu(self, op, block): + node = helper.make_node( + 'Relu', inputs=op.input('X'), outputs=op.output('Out')) + return node + + def elementwise_add(self, op, block): + axis = op.attr('axis') + x_shape = block.var(op.input('X')[0]).shape + y_shape = block.var(op.input('Y')[0]).shape + if len(y_shape) == 1 and axis == 1: + shape_name = self.get_name(op.type, 'shape') + shape_value = [1] * len(x_shape) + shape_value[axis] = y_shape[0] + shape_node = self.make_constant_node( + shape_name, onnx_pb.TensorProto.INT64, shape_value) + temp_value = self.get_name(op.type, 
+            temp_value = self.get_name(op.type, 'temp')
+            y_node = helper.make_node(
+                'Reshape',
+                inputs=[op.input('Y')[0], shape_name],
+                outputs=[temp_value])
+            node = helper.make_node(
+                'Add',
+                inputs=[op.input('X')[0], temp_value],
+                outputs=op.output('Out'))
+            return [shape_node, y_node, node]
+        elif len(x_shape) == len(y_shape):
+            node = helper.make_node(
+                'Add',
+                inputs=[op.input('X')[0], op.input('Y')[0]],
+                outputs=op.output('Out'))
+            return node
+        else:
+            raise Exception("Unexpected situation happened in elementwise_add")
+
+    def pool2d(self, op, block):
+        pool_type = {
+            'max': ('MaxPool', 'GlobalMaxPool'),
+            'avg': ('AveragePool', 'GlobalAveragePool')
+        }
+        if op.attr('global_pooling'):
+            node = helper.make_node(
+                pool_type[op.attr('pooling_type')][1],
+                inputs=op.input('X'),
+                outputs=op.output('Out'),
+            )
+        else:
+            node = helper.make_node(
+                pool_type[op.attr('pooling_type')][0],
+                inputs=op.input('X'),
+                outputs=op.output('Out'),
+                kernel_shape=op.attr('ksize'),
+                strides=op.attr('strides'),
+                pads=op.attr('paddings') + op.attr('paddings'))
+        return node
+
+    def softmax(self, op, block):
+        node = helper.make_node(
+            'Softmax',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            axis=op.attr('axis'))
+        return node
+
+    def scale(self, op, block):
+        scale = op.attr('scale')
+        bias = op.attr('bias')
+        if math.fabs(scale - 1.0) < 1e-06 and math.fabs(bias - 0.0) < 1e-06:
+            node = helper.make_node(
+                'Identity', inputs=op.input('X'), outputs=op.output('Out'))
+            return node
+        else:
+            scale_name = self.get_name(op.type, 'scale')
+            bias_name = self.get_name(op.type, 'bias')
+            scale_node = self.make_constant_node(
+                scale_name, onnx_pb.TensorProto.FLOAT, scale)
+            bias_node = self.make_constant_node(bias_name,
+                                                onnx_pb.TensorProto.FLOAT, bias)
+            temp_tensor_name = self.get_name(op.type, 'temporary')
+            if op.attr('bias_after_scale'):
+                node1 = helper.make_node(
+                    'Mul',
+                    inputs=[scale_name, op.input('X')[0]],
+                    outputs=[temp_tensor_name])
+                node2 = helper.make_node(
+                    'Add',
+                    inputs=[bias_name, temp_tensor_name],
+                    outputs=op.output('Out'))
+            else:
+                node1 = helper.make_node(
+                    'Add',
+                    inputs=[bias_name, op.input('X')[0]],
+                    outputs=[temp_tensor_name])
+                node2 = helper.make_node(
+                    'Mul',
+                    inputs=[scale_name, temp_tensor_name],
+                    outputs=op.output('Out'))
+            return [scale_node, bias_node, node1, node2]
+
+    def mul(self, op, block):
+        x_shape = block.var(op.input('X')[0]).shape
+        y_shape = block.var(op.input('Y')[0]).shape
+        out_shape = list(block.var(op.output('Out')[0]).shape)
+        x_num_col_dims = op.attr('x_num_col_dims')
+        y_num_col_dims = op.attr('y_num_col_dims')
+        flatten_x_name = 'flatten_{}'.format(op.input('X')[0])
+        flatten_y_name = 'flatten_{}'.format(op.input('Y')[0])
+        shape_name = 'temp_shape_{}'.format(op.output('Out')[0])
+        temp_out_name = 'temp_{}'.format(op.output('Out')[0])
+        flatten_x = helper.make_node(
+            'Flatten',
+            inputs=op.input('X'),
+            outputs=[flatten_x_name],
+            axis=x_num_col_dims)
+        flatten_y = helper.make_node(
+            'Flatten',
+            inputs=op.input('Y'),
+            outputs=[flatten_y_name],
+            axis=y_num_col_dims)
+        shape_node = self.make_constant_node(
+            shape_name, onnx_pb.TensorProto.INT64, out_shape)
+        node = helper.make_node(
+            'MatMul',
+            inputs=[flatten_x_name, flatten_y_name],
+            outputs=[temp_out_name])
+        reshape_out = helper.make_node(
+            'Reshape',
+            inputs=[temp_out_name, shape_name],
+            outputs=op.output('Out'))
+        return [flatten_x, flatten_y, shape_node, node, reshape_out]
+
+    def batch_norm(self, op, block):
+        kwargs = {
+            'epsilon': op.attr('epsilon'),
+            'momentum': op.attr('momentum')
+        }
+        inputs = op.input('X') + op.input('Scale') + op.input(
+            'Bias') + op.input('Mean') + op.input('Variance')
+        node = helper.make_node(
+            'BatchNormalization',
+            inputs=inputs,
+            outputs=op.output('Y'),
+            **kwargs)
+        return node
+
+    def concat(self, op, block):
+        node = helper.make_node(
+            'Concat',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            axis=op.attr('axis'))
+        return node
+
+    def depthwise_conv2d(self, op, block):
+        return self.conv2d(op, block)
+
+    def relu6(self, op, block):
+        min_name = self.get_name(op.type, 'min')
+        max_name = self.get_name(op.type, 'max')
+        min_node = self.make_constant_node(min_name, onnx_pb.TensorProto.FLOAT,
+                                           0)
+        max_node = self.make_constant_node(max_name, onnx_pb.TensorProto.FLOAT,
+                                           op.attr('threshold'))
+        node = helper.make_node(
+            'Clip',
+            inputs=[op.input('X')[0], min_name, max_name],
+            outputs=op.output('Out'),
+        )
+        return [min_node, max_node, node]
+
+    def shape(self, op, block):
+        node = helper.make_node(
+            'Shape', inputs=op.input('Input'), outputs=op.output('Out'))
+        return node
+
+    def split(self, op, block):
+        sections = op.attr('sections')
+        if len(sections) > 0:
+            node = helper.make_node(
+                'Split',
+                inputs=op.input('X'),
+                outputs=op.output('Out'),
+                axis=op.attr('axis'),
+                split=sections)
+        else:
+            node = helper.make_node(
+                'Split',
+                inputs=op.input('X'),
+                outputs=op.output('Out'),
+                axis=op.attr('axis'))
+        return node
+
+    def slice(self, op, block):
+        axes = op.attr('axes')
+        starts = op.attr('starts')
+        ends = op.attr('ends')
+        axes_name = self.get_name(op.type, 'axes')
+        starts_name = self.get_name(op.type, 'starts')
+        ends_name = self.get_name(op.type, 'ends')
+
+        axes_node = self.make_constant_node(axes_name,
+                                            onnx_pb.TensorProto.INT64, axes)
+        starts_node = self.make_constant_node(starts_name,
+                                              onnx_pb.TensorProto.INT64, starts)
+        ends_node = self.make_constant_node(ends_name,
+                                            onnx_pb.TensorProto.INT64, ends)
+        node = helper.make_node(
+            "Slice",
+            inputs=[op.input('Input')[0], starts_name, ends_name, axes_name],
+            outputs=op.output('Out'),
+        )
+        return [starts_node, ends_node, axes_node, node]
+
+    def fill_constant(self, op, block):
+        value = op.attr('value')
+        dtype = op.attr('dtype')
+        shape = op.attr('shape')
+        value = np.ones(shape) * value
+        node = helper.make_node(
+            'Constant',
+            inputs=[],
+            outputs=op.output('Out'),
+            value=helper.make_tensor(
+                name=op.output('Out')[0],
+                data_type=self.paddle_onnx_dtype_map[dtype],
+                dims=shape,
+                vals=value.tolist()))
+        return node
+
+    def transpose2(self, op, block):
+        node = helper.make_node(
+            'Transpose',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            perm=op.attr('perm'))
+        return node
+
+    def reshape2(self, op, block):
+        input_names = op.input_names
+        if 'Shape' in input_names and len(op.input('Shape')) > 0:
+            node = helper.make_node(
+                'Reshape',
+                inputs=[op.input('X')[0],
+                        op.input('Shape')[0]],
+                outputs=op.output('Out'))
+        else:
+            shape = op.attr('shape')
+            shape_name = self.get_name(op.type, 'shape')
+            shape_node = self.make_constant_node(
+                shape_name, onnx_pb.TensorProto.INT64, shape)
+            node = helper.make_node(
+                'Reshape',
+                inputs=[op.input('X')[0], shape_name],
+                outputs=op.output('Out'))
+            return [shape_node, node]
+        return node
+
+    def dropout(self, op, block):
+        dropout_mode = op.attr('dropout_implementation')
+        dropout_prob = op.attr('dropout_prob')
+        if dropout_mode == 'upscale_in_train':
+            node = helper.make_node(
+                'Identity', inputs=op.input('X'), outputs=op.output('Out'))
+            return node
+        elif dropout_mode == 'downgrade_in_infer':
+            scale_name = self.get_name(op.type, 'scale')
+            scale_node = self.make_constant_node(
+                scale_name, onnx_pb.TensorProto.FLOAT, 1 - dropout_prob)
+            node = helper.make_node(
+                "Mul",
+                inputs=[op.input('X')[0], scale_name],
+                outputs=op.output('Out'))
+            return [scale_node, node]
+        else:
+            raise Exception("Unexpected situation happened")
+
+    def reduce_mean(self, op, block):
+        node = helper.make_node(
+            'ReduceMean',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            axes=op.attr('axes'),
+            keepdims=op.attr('keep_dim'))
+        return node
+
+    def nearest_interp(self, op, block):
+        input_names = op.input_names
+        if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
+            node = helper.make_node(
+                'Resize',
+                inputs=[op.input('X')[0], '',
+                        op.input('OutSize')[0]],
+                outputs=op.output('Out'))
+        elif 'Scale' in input_names and len(op.input('Scale')) > 0:
+            node = helper.make_node(
+                'Resize',
+                inputs=[op.input('X')[0],
+                        op.input('Scale')[0]],
+                outputs=op.output('Out'))
+        else:
+            out_shape = [op.attr('out_h'), op.attr('out_w')]
+            scale = op.attr('scale')
+            if out_shape.count(-1) > 0:
+                scale_name = self.get_name(op.type, 'scale')
+                scale_node = self.make_constant_node(
+                    scale_name, onnx_pb.TensorProto.FLOAT, [1, 1, scale, scale])
+                roi_name = self.get_name(op.type, 'roi')
+                roi_node = self.make_constant_node(roi_name,
+                                                   onnx_pb.TensorProto.FLOAT,
+                                                   [1, 1, 1, 1, 1, 1, 1, 1])
+                node = helper.make_node(
+                    'Resize',
+                    inputs=[op.input('X')[0], roi_name, scale_name],
+                    outputs=op.output('Out'),
+                    mode='nearest')
+                return [scale_node, roi_node, node]
+            else:
+                raise Exception("Unexpected situation happened")
+        return node
+
+    def hard_sigmoid(self, op, block):
+        slope = op.attr('slope')
+        offset = op.attr('offset')
+        node = helper.make_node(
+            'HardSigmoid',
+            inputs=op.input('X'),
+            outputs=op.output('Out'),
+            alpha=slope,
+            beta=offset)
+        return node
+
+    def elementwise_mul(self, op, block):
+        axis = op.attr('axis')
+        x_shape = block.var(op.input('X')[0]).shape
+        y_shape = block.var(op.input('Y')[0]).shape
+        if len(y_shape) == 1 and axis == 1:
+            shape_name = self.get_name(op.type, 'shape')
+            shape_value = [1] * len(x_shape)
+            shape_value[axis] = y_shape[0]
+            shape_node = self.make_constant_node(
+                shape_name, onnx_pb.TensorProto.INT64, shape_value)
+            temp_value = self.get_name(op.type, 'temp')
+            y_node = helper.make_node(
+                'Reshape',
+                inputs=[op.input('Y')[0], shape_name],
+                outputs=[temp_value])
+            node = helper.make_node(
+                'Mul',
+                inputs=[op.input('X')[0], temp_value],
+                outputs=op.output('Out'))
+            return [shape_node, y_node, node]
+        elif len(x_shape) == len(y_shape):
+            node = helper.make_node(
+                'Mul',
+                inputs=[op.input('X')[0], op.input('Y')[0]],
+                outputs=op.output('Out'))
+            return node
+        else:
+            raise Exception("Unexpected situation happened in elementwise_mul")
+
+    def feed(self, op, block):
+        name = op.output('Out')[0]
+        var = block.var(name)
+        tensor_info = helper.make_tensor_value_info(
+            name=name,
+            shape=var.shape,
+            elem_type=self.paddle_onnx_dtype_map[var.dtype])
+        return tensor_info
+
+    def fetch(self, op, block):
+        name = op.input('X')[0]
+        var = block.var(name)
+        tensor_info = helper.make_tensor_value_info(
+            name=name,
+            shape=var.shape,
+            elem_type=self.paddle_onnx_dtype_map[var.dtype])
+        return tensor_info
+
+    def convert_weights(self, program):
+        var_names = program.global_block().vars
+        nodes = list()
+        for name in var_names:
+            var = program.global_block().var(name)
+            if name.endswith('feed') or name.endswith('fetch'):
+                continue
+            if not var.persistable:
+                continue
+            weight = np.array(fluid.global_scope().find_var(name).get_tensor())
+            tensor = helper.make_tensor(
+                name=name,
+                dims=var.shape,
+                data_type=self.paddle_onnx_dtype_map[var.dtype],
+                vals=weight.flatten().tolist())
+            node = helper.make_node(
+                'Constant', inputs=[], outputs=[name], value=tensor)
+            nodes.append(node)
+        return nodes
+
+    def convert(self, program, save_dir):
+        weight_nodes = self.convert_weights(program)
+        op_nodes = list()
+        input_nodes = list()
+        output_nodes = list()
+
+        unsupported_ops = set()
+
+        for block in program.blocks:
+            for op in block.ops:
+                print('Translating op: {}'.format(op.type))
+                if not hasattr(self, op.type):
+                    unsupported_ops.add(op.type)
+                    continue
+                # after the first unsupported op, only scan for more
+                if len(unsupported_ops) > 0:
+                    continue
+                node = getattr(self, op.type)(op, block)
+                if op.type == 'feed':
+                    input_nodes.append(node)
+                elif op.type == 'fetch':
+                    output_nodes.append(node)
+                else:
+                    if isinstance(node, list):
+                        op_nodes = op_nodes + node
+                    else:
+                        op_nodes.append(node)
+
+        if len(unsupported_ops) > 0:
+            print("There are {} ops that are not supported yet".format(
+                len(unsupported_ops)))
+            for op in unsupported_ops:
+                print("=========== {} ===========".format(op))
+            return
+
+        graph = helper.make_graph(
+            nodes=weight_nodes + op_nodes,
+            name='onnx_model_from_paddle',
+            initializer=[],
+            inputs=input_nodes,
+            outputs=output_nodes)
+        model = helper.make_model(graph, producer_name='X2Paddle')
+        onnx.checker.check_model(model)
+
+        if not os.path.isdir(save_dir):
+            os.makedirs(save_dir)
+        with open(os.path.join(save_dir, 'x2paddle_model.onnx'), 'wb') as f:
+            f.write(model.SerializeToString())
+        print("Translated model saved in {}".format(
+            os.path.join(save_dir, 'x2paddle_model.onnx')))
diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
index 75ab10e2980272308f0615d202a284e1ce6bba96..05a06ac91247b11111200c1406a5466a616f4847 100644
--- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py
+++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py
@@ -43,6 +43,7 @@ class TFOpMapperNHWC(OpMapper):
         'Sqrt': ['sqrt'],
         'swish_f32': ['swish'],
         'Tanh': ['tanh'],
+        'Softplus': ['softplus'],
         'LeakyRelu': ['leaky_relu', {
             'alpha': 'alpha'
         }]
@@ -128,26 +129,18 @@ class TFOpMapperNHWC(OpMapper):
 
         if len(input.out_shapes[0]) == 4 and op_info[0] != 'shape':
             attr1 = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer('transpose',
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr1)
+            node.fluid_code.add_layer(
+                'transpose', inputs=input, output=node, param_attr=attr1)
             input = node
-            node.fluid_code.add_layer(op_info[0],
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                op_info[0], inputs=input, output=node, param_attr=attr)
             input = node
             attr2 = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer('transpose',
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr2)
+            node.fluid_code.add_layer(
+                'transpose', inputs=input, output=node, param_attr=attr2)
         else:
-            node.fluid_code.add_layer(op_info[0],
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                op_info[0], inputs=input, output=node, param_attr=attr)
 
     def elementwise_map(self, node):
         assert node.layer_type in self.elementwise_ops
@@ -208,42 +201,37 @@ class TFOpMapperNHWC(OpMapper):
             raise Exception("Unexpected situation happend")
         if x_need_expand:
             attr = {"expand_times": x_expand_times}
-            node.fluid_code.add_layer("expand",
-                                      inputs=x_input,
-                                      output="x_tmp",
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "expand", inputs=x_input, output="x_tmp", param_attr=attr)
             x_input = "x_tmp"
         if y_need_expand:
= {"expand_times": y_expand_times} - node.fluid_code.add_layer("expand", - inputs=y_input, - output="y_tmp", - param_attr=attr) + node.fluid_code.add_layer( + "expand", inputs=y_input, output="y_tmp", param_attr=attr) y_input = "y_tmp" if len(x_shape) == 4 and len(y_shape) == 4: - node.fluid_code.add_layer("transpose", - inputs=x_input, - output=x_input, - param_attr={'perm': [0, 3, 1, 2]}) - node.fluid_code.add_layer("transpose", - inputs=y_input, - output=y_input, - param_attr={'perm': [0, 3, 1, 2]}) + node.fluid_code.add_layer( + "transpose", + inputs=x_input, + output=x_input, + param_attr={'perm': [0, 3, 1, 2]}) + node.fluid_code.add_layer( + "transpose", + inputs=y_input, + output=y_input, + param_attr={'perm': [0, 3, 1, 2]}) inputs = {"x": x_input, "y": y_input} - node.fluid_code.add_layer(op_type, - inputs=inputs, - output=node, - param_attr=None) - node.fluid_code.add_layer("transpose", - inputs=node, - output=node, - param_attr={'perm': [0, 2, 3, 1]}) + node.fluid_code.add_layer( + op_type, inputs=inputs, output=node, param_attr=None) + node.fluid_code.add_layer( + "transpose", + inputs=node, + output=node, + param_attr={'perm': [0, 2, 3, 1]}) else: inputs = {"x": x_input, "y": y_input} - node.fluid_code.add_layer(op_type, - inputs=inputs, - output=node, - param_attr=None) + node.fluid_code.add_layer( + op_type, inputs=inputs, output=node, param_attr=None) def Placeholder(self, node): shape = node.out_shapes[0] @@ -259,10 +247,8 @@ class TFOpMapperNHWC(OpMapper): 'append_batch_size': False } - node.fluid_code.add_layer("data", - inputs=None, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "data", inputs=None, output=node, param_attr=attr) def Const(self, node): shape = node.out_shapes[0] @@ -282,10 +268,8 @@ class TFOpMapperNHWC(OpMapper): 'name': string(node.layer_name), 'default_initializer': initializer } - node.fluid_code.add_layer("create_parameter", - inputs=None, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "create_parameter", inputs=None, output=node, param_attr=attr) def Transpose(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -296,10 +280,8 @@ class TFOpMapperNHWC(OpMapper): perm = perm.value.tolist() attr = {'perm': perm} - node.fluid_code.add_layer("transpose", - inputs=input, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "transpose", inputs=input, output=node, param_attr=attr) def MaxPool(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -316,10 +298,8 @@ class TFOpMapperNHWC(OpMapper): if not channel_first: attr = {"perm": [0, 3, 1, 2]} - node.fluid_code.add_layer("transpose", - inputs=input, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "transpose", inputs=input, output=node, param_attr=attr) in_shape = [in_shape[i] for i in [0, 3, 1, 2]] strides = [strides[i] for i in [0, 3, 1, 2]] k_size = [k_size[i] for i in [0, 3, 1, 2]] @@ -331,17 +311,13 @@ class TFOpMapperNHWC(OpMapper): "pool_stride": strides[2:4], "pool_padding": string(pad_mode) } - node.fluid_code.add_layer("pool2d", - inputs=input, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "pool2d", inputs=input, output=node, param_attr=attr) if not channel_first: attr = {"perm": [0, 2, 3, 1]} - node.fluid_code.add_layer("transpose", - inputs=node, - output=node, - param_attr=attr) + node.fluid_code.add_layer( + "transpose", inputs=node, output=node, param_attr=attr) def Conv2D(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -373,10 +349,8 @@ 
@@ -373,10 +349,8 @@ class TFOpMapperNHWC(OpMapper):
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
             attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=input, output=node, param_attr=attr)
             input = node
 
         attr = {
@@ -393,25 +367,19 @@ class TFOpMapperNHWC(OpMapper):
         if len(node.dilation) == 1:
             attr['dilation'] = [1, node.dilation[0]]
-        node.fluid_code.add_layer("conv2d",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "conv2d", inputs=input, output=node, param_attr=attr)
 
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=node, output=node, param_attr=attr)
 
     def BiasAdd(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         bias = self.graph.get_node(node.layer.input[1], copy=True)
         inputs = {"x": input, "y": bias}
-        node.fluid_code.add_layer("elementwise_add",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
+        node.fluid_code.add_layer(
+            "elementwise_add", inputs=inputs, output=node, param_attr=None)
 
     def FusedBatchNorm(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -433,10 +401,8 @@ class TFOpMapperNHWC(OpMapper):
 
         if not channel_first:
             attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=input, output=node, param_attr=attr)
             input = node
 
         attr = {
@@ -448,17 +414,13 @@ class TFOpMapperNHWC(OpMapper):
             "is_test": True
         }
 
-        node.fluid_code.add_layer("batch_norm",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "batch_norm", inputs=input, output=node, param_attr=attr)
 
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=node, output=node, param_attr=attr)
 
     def DepthwiseConv2dNative(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -487,10 +449,8 @@ class TFOpMapperNHWC(OpMapper):
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
             attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=input, output=node, param_attr=attr)
             input = node
 
         attr = {
@@ -504,17 +464,13 @@ class TFOpMapperNHWC(OpMapper):
             "use_cudnn": False,
             "padding": string(pad_mode)
         }
-        node.fluid_code.add_layer("conv2d",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "conv2d", inputs=input, output=node, param_attr=attr)
 
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=node, output=node, param_attr=attr)
 
     def Reshape(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -533,15 +489,14 @@ class TFOpMapperNHWC(OpMapper):
             assert len(param.out_shapes[0]
                        ) == 1, "Unexpected situation of shape parameter"
             attr = {"shape": [-1]}
-            node.fluid_code.add_layer("reshape",
-                                      inputs=param,
-                                      output="shape_param",
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "reshape",
+                inputs=param,
+                output="shape_param",
+                param_attr=attr)
             attr = {"num_or_sections": param.out_shapes[0][0], "dim": 0}
-            node.fluid_code.add_layer("split",
-                                      inputs="shape_param",
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "split", inputs="shape_param", output=node, param_attr=attr)
             new_param = "["
             for i in range(param.out_shapes[0][0]):
                 new_param += (node.layer_name + "[{}]".format(i) + ", ")
@@ -565,10 +520,8 @@ class TFOpMapperNHWC(OpMapper):
                 attr["shape"][index] = int(total_size)
                 attr["shape"][0] = -1
 
-        node.fluid_code.add_layer("reshape",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "reshape", inputs=input, output=node, param_attr=attr)
 
     def AvgPool(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -588,10 +541,8 @@ class TFOpMapperNHWC(OpMapper):
             strides = [strides[i] for i in [0, 3, 1, 2]]
             k_size = [k_size[i] for i in [0, 3, 1, 2]]
             attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=input, output=node, param_attr=attr)
             input = node
 
         attr = {
@@ -600,17 +551,13 @@ class TFOpMapperNHWC(OpMapper):
             "pool_stride": strides[2:4],
             "pool_padding": string(pad_mode)
         }
-        node.fluid_code.add_layer("pool2d",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "pool2d", inputs=input, output=node, param_attr=attr)
 
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=node, output=node, param_attr=attr)
 
     def SplitV(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -625,10 +572,8 @@ class TFOpMapperNHWC(OpMapper):
             "num_or_sections": num_sections.value.tolist(),
             "dim": dim.value
         }
-        node.fluid_code.add_layer("split",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "split", inputs=input, output=node, param_attr=attr)
 
     def ConcatV2(self, node):
         inputs = [
@@ -643,10 +588,8 @@ class TFOpMapperNHWC(OpMapper):
             axis += len(inputs[0].out_shapes[0])
 
         attr = {"axis": axis}
-        node.fluid_code.add_layer("concat",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "concat", inputs=inputs, output=node, param_attr=attr)
 
     def Tile(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -660,10 +603,8 @@ class TFOpMapperNHWC(OpMapper):
             if expand_times[i] < 0:
                 expand_times[i] = 1
         attr = {"expand_times": expand_times}
-        node.fluid_code.add_layer("expand",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "expand", inputs=input, output=node, param_attr=attr)
 
     def Pack(self, node):
         inputs = [
@@ -671,10 +612,8 @@ class TFOpMapperNHWC(OpMapper):
         ]
         axis = node.get_attr("axis")
         attr = {"axis": axis}
-        node.fluid_code.add_layer("stack",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "stack", inputs=inputs, output=node, param_attr=attr)
 
     def Pad(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -695,30 +634,22 @@ class TFOpMapperNHWC(OpMapper):
 
         if new_padding is not None:
             if input.tf_data_format == "NHWC":
                 attr = {"perm": [0, 3, 1, 2]}
-                node.fluid_code.add_layer("transpose",
-                                          inputs=input,
-                                          output=node,
-                                          param_attr=attr)
+                node.fluid_code.add_layer(
+                    "transpose", inputs=input, output=node, param_attr=attr)
                 input = node
             attr = {"paddings": new_padding}
-            node.fluid_code.add_layer("pad2d",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "pad2d", inputs=input, output=node, param_attr=attr)
             if input.tf_data_format == "NHWC":
                 attr = {"perm": [0, 2, 3, 1]}
-                node.fluid_code.add_layer("transpose",
-                                          inputs=node,
-                                          output=node,
-                                          param_attr=attr)
+                node.fluid_code.add_layer(
+                    "transpose", inputs=node, output=node, param_attr=attr)
             return
 
         attr = {"paddings": paddings}
-        node.fluid_code.add_layer("pad",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "pad", inputs=input, output=node, param_attr=attr)
 
     def Range(self, node):
         start = self.graph.get_node(node.layer.input[0], copy=True)
@@ -746,10 +677,8 @@ class TFOpMapperNHWC(OpMapper):
             "step": delta,
         }
         attr = {"dtype": string(node.dtype)}
-        node.fluid_code.add_layer("range",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "range", inputs=inputs, output=node, param_attr=attr)
 
     def Mean(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -759,10 +688,8 @@ class TFOpMapperNHWC(OpMapper):
         keep_dims = node.get_attr("keep_dims")
 
         attr = {"dim": dims, "keep_dim": keep_dims}
-        node.fluid_code.add_layer("reduce_mean",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "reduce_mean", inputs=input, output=node, param_attr=attr)
 
     def MatMul(self, node):
         x = self.graph.get_node(node.layer.input[0], copy=True)
@@ -776,15 +703,11 @@ class TFOpMapperNHWC(OpMapper):
             shape = x.out_shapes[0]
             shape[-1] = y.out_shapes[0][0]
             attr = {"shape": shape}
-            node.fluid_code.add_layer("reshape",
-                                      inputs=x,
-                                      output=x,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "reshape", inputs=x, output=x, param_attr=attr)
         attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
-        node.fluid_code.add_layer("matmul",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "matmul", inputs=inputs, output=node, param_attr=attr)
 
     def ArgMax(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -793,10 +716,8 @@ class TFOpMapperNHWC(OpMapper):
         self.add_omit_nodes(axis.layer_name, node.layer_name)
         axis = axis.value
         attr = {"axis": axis}
-        node.fluid_code.add_layer("argmax",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "argmax", inputs=input, output=node, param_attr=attr)
 
     def StridedSlice(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -862,25 +783,19 @@ class TFOpMapperNHWC(OpMapper):
             "starts": new_begin,
             "ends": new_end
         }
-        node.fluid_code.add_layer("slice",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "slice", inputs=input, output=node, param_attr=attr)
         if len(new_axes) > 0:
             attr = {"axes": new_axes}
-            node.fluid_code.add_layer("unsqueeze",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "unsqueeze", inputs=node, output=node, param_attr=attr)
         if len(shrink_axes) > 0:
            if len(input.out_shapes[0]) + len(new_axes) <= 1:
                 pass
             else:
                 attr = {"axes": shrink_axes}
-                node.fluid_code.add_layer("squeeze",
-                                          inputs=node,
-                                          output=node,
-                                          param_attr=attr)
+                node.fluid_code.add_layer(
+                    "squeeze", inputs=node, output=node, param_attr=attr)
 
     def Slice(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -909,10 +824,8 @@ class TFOpMapperNHWC(OpMapper):
             "ends": size
         }
 
-        node.fluid_code.add_layer("slice",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "slice", inputs=input, output=node, param_attr=attr)
 
     def Conv2DBackpropInput(self, node):
         out_shape = self.graph.get_node(node.layer.input[0], copy=True)
@@ -950,10 +863,8 @@ class TFOpMapperNHWC(OpMapper):
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
             attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=input, output=node, param_attr=attr)
             input = node
         else:
             self.data_format_propagation(node)
@@ -968,17 +879,13 @@ class TFOpMapperNHWC(OpMapper):
             "padding": string(pad_mode),
             "output_size": out_shape[1:3]
         }
-        node.fluid_code.add_layer("conv2d_transpose",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "conv2d_transpose", inputs=input, output=node, param_attr=attr)
 
         if not channel_first:
             attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "transpose", inputs=node, output=node, param_attr=attr)
 
     def Max(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -988,10 +895,8 @@ class TFOpMapperNHWC(OpMapper):
         dim = reduce_idx.value.tolist()
 
         attr = {"dim": dim, "keep_dim": keep_dims}
-        node.fluid_code.add_layer("reduce_max",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "reduce_max", inputs=input, output=node, param_attr=attr)
 
     def Sum(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1001,19 +906,15 @@ class TFOpMapperNHWC(OpMapper):
         dim = reduce_idx.value.tolist()
 
         attr = {"dim": dim, "keep_dim": keep_dims}
-        node.fluid_code.add_layer("reduce_sum",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "reduce_sum", inputs=input, output=node, param_attr=attr)
 
     def Cast(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         dtype = node.dtype_map[node.get_attr('DstT')]
         attr = {"dtype": string(dtype)}
-        node.fluid_code.add_layer("cast",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "cast", inputs=input, output=node, param_attr=attr)
 
     def Split(self, node):
         dim = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1024,28 +925,22 @@ class TFOpMapperNHWC(OpMapper):
         dim = dim.value
 
         attr = {"num_or_sections": num_split, "dim": dim}
-        node.fluid_code.add_layer("split",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "split", inputs=input, output=node, param_attr=attr)
 
     def Squeeze(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         squeeze_dims = node.get_attr('squeeze_dims')
         attr = {"axes": squeeze_dims}
-        node.fluid_code.add_layer("squeeze",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "squeeze", inputs=input, output=node, param_attr=attr)
 
     def Softmax(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         axis = node.get_attr("axis")
         attr = {"axis": axis}
-        node.fluid_code.add_layer("softmax",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "softmax", inputs=input, output=node, param_attr=attr)
 
     def ResizeNearestNeighbor(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1058,20 +953,14 @@ class TFOpMapperNHWC(OpMapper):
                 resize_shape, node.out_shapes[0])
         align_corners = node.get_attr("align_corners")
         attr = {"perm": [0, 3, 1, 2]}
-        node.fluid_code.add_layer("transpose",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "transpose", inputs=input, output=node, param_attr=attr)
         attr = {"align_corners": align_corners, "out_shape": resize_shape}
-        node.fluid_code.add_layer("resize_nearest",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "resize_nearest", inputs=node, output=node, param_attr=attr)
         attr = {"perm": [0, 2, 3, 1]}
-        node.fluid_code.add_layer("transpose",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "transpose", inputs=node, output=node, param_attr=attr)
 
     def ResizeBilinear(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1084,33 +973,25 @@ class TFOpMapperNHWC(OpMapper):
                 resize_shape, node.out_shapes[0])
         align_corners = node.get_attr("align_corners")
         attr = {"perm": [0, 3, 1, 2]}
-        node.fluid_code.add_layer("transpose",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "transpose", inputs=input, output=node, param_attr=attr)
         attr = {
             "align_corners": align_corners,
             "out_shape": resize_shape,
             "align_mode": 1
         }
-        node.fluid_code.add_layer("resize_bilinear",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "resize_bilinear", inputs=node, output=node, param_attr=attr)
         attr = {"perm": [0, 2, 3, 1]}
-        node.fluid_code.add_layer("transpose",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "transpose", inputs=node, output=node, param_attr=attr)
 
     def GreaterEqual(self, node):
         x = self.graph.get_node(node.layer.input[0], copy=True)
         y = self.graph.get_node(node.layer.input[1], copy=True)
         inputs = {"x": x, "y": y}
-        node.fluid_code.add_layer("greater_equal",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
+        node.fluid_code.add_layer(
+            "greater_equal", inputs=inputs, output=node, param_attr=None)
 
     def RandomUniform(self, node):
         shape = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1123,29 +1004,24 @@ class TFOpMapperNHWC(OpMapper):
 
         if shape[0] < 0:
             input = self.batch_node
-            node.fluid_code.add_layer("uniform_random_batch_size_like",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "uniform_random_batch_size_like",
+                inputs=input,
+                output=node,
+                param_attr=attr)
         else:
-            node.fluid_code.add_layer("uniform_random",
-                                      inputs=None,
-                                      output=node,
-                                      param_attr=attr)
+            node.fluid_code.add_layer(
+                "uniform_random", inputs=None, output=node, param_attr=attr)
 
     def SquaredDifference(self, node):
         x = self.graph.get_node(node.layer.input[0], copy=True)
         y = self.graph.get_node(node.layer.input[1], copy=True)
         inputs = {"x": x, "y": y}
-        node.fluid_code.add_layer("elementwise_sub",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
+        node.fluid_code.add_layer(
+            "elementwise_sub", inputs=inputs, output=node, param_attr=None)
         inputs = {"x": node, "y": node}
-        node.fluid_code.add_layer("elementwise_mul",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
+        node.fluid_code.add_layer(
+            "elementwise_mul", inputs=inputs, output=node, param_attr=None)
 
     def ExpandDims(self, node):
         x = self.graph.get_node(node.layer.input[0], copy=True)
@@ -1156,19 +1032,15 @@ class TFOpMapperNHWC(OpMapper):
             dim = self.decoder.infer_tensor(y)
             self.add_omit_nodes(y.layer_name, node.layer_name)
         attr = {'axes': [dim]}
-        node.fluid_code.add_layer("unsqueeze",
-                                  inputs=x,
-                                  output=node,
-                                  param_attr=attr)
+        node.fluid_code.add_layer(
+            "unsqueeze", inputs=x, output=node, param_attr=attr)
 
     def BatchToSpaceND(self, node):
         x = self.graph.get_node(node.layer.input[0], copy=True)
         y = self.graph.get_node(node.layer.input[1], copy=True)
         if hasattr(node, 'skip') and node.skip:
-            node.fluid_code.add_layer("=",
-                                      inputs=x,
-                                      output=node,
-                                      param_attr=None)
+            node.fluid_code.add_layer(
+                "=", inputs=x, output=node, param_attr=None)
         else:
             raise Exception("BatchToSpaceND is not supported")
 
@@ -1176,9 +1048,7 @@ class TFOpMapperNHWC(OpMapper):
         x = self.graph.get_node(node.layer.input[0], copy=True)
         y = self.graph.get_node(node.layer.input[1], copy=True)
         if hasattr(node, 'skip') and node.skip:
-            node.fluid_code.add_layer("=",
-                                      inputs=x,
-                                      output=node,
-                                      param_attr=None)
+            node.fluid_code.add_layer(
+                "=", inputs=x, output=node, param_attr=None)
         else:
             raise Exception("SpaceToBatchND is not supported")
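
A minimal sketch of how the paddle2onnx path added above can be exercised end to end, assuming an exported Paddle inference model directory ('paddle_model' and 'onnx_out' below are placeholder paths, not part of this change):

    # CLI form, using the new --framework value wired up in convert.py:
    #     x2paddle --framework paddle2onnx --model paddle_model --save_dir onnx_out
    # Programmatic form, mirroring convert.paddle2onnx():
    from x2paddle.decoder.paddle_decoder import PaddleDecoder
    from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper

    model = PaddleDecoder('paddle_model', '__model__', '__params__')
    mapper = PaddleOpMapper()
    mapper.convert(model.program, 'onnx_out')  # writes onnx_out/x2paddle_model.onnx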