From 1b4723939be13c1005bf70a1dca0c128e04cade1 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Tue, 27 Aug 2019 13:39:52 +0800 Subject: [PATCH] add mode of without data format optimization for tensorflow --- x2paddle/convert.py | 47 +- x2paddle/core/fluid_code.py | 5 +- x2paddle/decoder/tf_decoder.py | 34 +- x2paddle/op_mapper/tf_op_mapper.py | 72 +- x2paddle/op_mapper/tf_op_mapper_nhwc.py | 1058 +++++++++++++++++++++++ 5 files changed, 1185 insertions(+), 31 deletions(-) create mode 100644 x2paddle/op_mapper/tf_op_mapper_nhwc.py diff --git a/x2paddle/convert.py b/x2paddle/convert.py index d3342fa..0bb2a59 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -57,11 +57,25 @@ def arg_parser(): action="store_true", default=False, help="get version of x2paddle") + parser.add_argument( + "--without_data_format_optimization", + "-wo", + action="store_true", + default=False, + help="tf model conversion without data format optimization") + parser.add_argument("--define_input_shape", + "-d", + action="store_true", + default=False, + help="define input shape for tf model") return parser -def tf2paddle(model_path, save_dir): +def tf2paddle(model_path, + save_dir, + without_data_format_optimization=False, + define_input_shape=False): # check tensorflow installation and version try: import tensorflow as tf @@ -77,17 +91,23 @@ def tf2paddle(model_path, save_dir): from x2paddle.decoder.tf_decoder import TFDecoder from x2paddle.op_mapper.tf_op_mapper import TFOpMapper + from x2paddle.op_mapper.tf_op_mapper_nhwc import TFOpMapperNHWC from x2paddle.optimizer.tf_optimizer import TFOptimizer print("Now translating model from tensorflow to paddle.") - model = TFDecoder(model_path) - mapper = TFOpMapper(model) - optimizer = TFOptimizer(mapper) - # neccesary optimization - optimizer.delete_redundance_code() - # optimizer below is experimental - optimizer.merge_activation() - optimizer.merge_bias() + model = TFDecoder(model_path, define_input_shape=define_input_shape) + if not without_data_format_optimization: + mapper = TFOpMapper(model) + optimizer = TFOptimizer(mapper) + # neccesary optimization + optimizer.delete_redundance_code() + # optimizer below is experimental + optimizer.merge_activation() + optimizer.merge_bias() + else: + mapper = TFOpMapperNHWC(model) + optimizer = TFOptimizer(mapper) + optimizer.delete_redundance_code() mapper.save_inference_model(save_dir) @@ -155,7 +175,14 @@ def main(): if args.framework == "tensorflow": assert args.model is not None, "--model should be defined while translating tensorflow model" - tf2paddle(args.model, args.save_dir) + without_data_format_optimization = False + define_input_shape = False + if args.without_data_format_optimization: + without_data_format_optimization = True + if args.define_input_shape: + define_input_shape = True + tf2paddle(args.model, args.save_dir, without_data_format_optimization, + define_input_shape) elif args.framework == "caffe": assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" diff --git a/x2paddle/core/fluid_code.py b/x2paddle/core/fluid_code.py index c971e5a..42b51f1 100644 --- a/x2paddle/core/fluid_code.py +++ b/x2paddle/core/fluid_code.py @@ -64,11 +64,8 @@ class Layer(object): else: layer_code = layer_code + key + "={}, ".format( input.layer_name) - elif isinstance(input, str): - layer_code = layer_code + key + "={}, ".format(input) else: - raise Exception( - "Element of inputs should GraphNode or String") + layer_code = layer_code + 
key + "={}, ".format(input) elif isinstance(self.inputs, GraphNode): if hasattr(self.inputs, "index"): layer_code += (self.inputs.layer_name + diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index 47acda7..bd8db4a 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -39,7 +39,7 @@ class TFGraphNode(GraphNode): self.pd_data_format = "NCHW" self.fluid_code = FluidCode() - self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"} + self.dtype_map = {1: "float32", 3: "int32", 4: "uint8", 9: "int64"} @property def out_shapes(self): @@ -52,7 +52,11 @@ class TFGraphNode(GraphNode): @property def dtype(self): - dtype = self.layer.attr["dtype"].type + keys = ['dtype', 'Tidx', 'T'] + for k in keys: + dtype = self.layer.attr[k].type + if dtype > 0: + break if dtype not in self.dtype_map: raise Exception("Dtype[{}] not in dtype_map".format(dtype)) return self.dtype_map[dtype] @@ -198,9 +202,10 @@ class TFGraph(Graph): class TFDecoder(object): - def __init__(self, pb_model, data_format="NHWC"): + def __init__(self, pb_model, data_format="NHWC", define_input_shape=False): self.sess = tf.Session() self.input_info = dict() + self.define_input_shape = define_input_shape with gfile.FastGFile(pb_model, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) @@ -229,10 +234,15 @@ class TFDecoder(object): if layer.op != "Placeholder": continue graph_node = TFGraphNode(layer) - dtype = graph_node.dtype + dtype = graph_node.layer.attr['dtype'].type + print("========dtype", dtype) need_define_shape = 0 - if not graph_node.get_attr("shape"): + if self.define_input_shape: + need_define_shape = 3 + elif graph_node.layer.attr[ + 'shape'].shape.unknown_rank or not graph_node.get_attr( + "shape"): need_define_shape = 1 else: value = graph_node.layer.attr["shape"].shape @@ -241,13 +251,21 @@ class TFDecoder(object): need_define_shape = 2 if need_define_shape > 0: + shape = None + if graph_node.get_attr("shape"): + value = value = graph_node.layer.attr["shape"].shape + shape = [dim.size for dim in value.dim] if need_define_shape == 1: print("Unknown shape for input tensor[tensor name: \"{}\"]". 
format(layer.name)) - else: + elif need_define_shape == 2: print( "\nShape[now is {}] for input tensor[tensor name: \"{}\"] not support yet" .format(shape, layer.name)) + else: + print( + "Define shape[now is {}] for input tensor[tensor name: \"{}\']" + .format(shape, layer.name)) print( "Use your keyboard type the shape of input tensor below :)") @@ -264,12 +282,14 @@ class TFDecoder(object): for dim in shape.strip().split(',') ] assert shape.count(None) <= 1, "Only one dimension can be None" + print("]]]]]]]]]dtype", dtype) x2paddle_input = tf.placeholder(dtype=dtype, shape=shape, name="x2paddle_{}".format( layer.name)) input_map["{}:0".format(layer.name)] = x2paddle_input - shape[shape.index(None)] = -1 + if shape.count(None) > 0: + shape[shape.index(None)] = -1 self.input_info["x2paddle_{}".format(layer.name)] = (shape, dtype) else: diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 87dfccf..30a3fcd 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -57,7 +57,10 @@ class TFOpMapper(OpMapper): 'Sigmoid': ['sigmoid'], 'Exp': ['exp'], 'Rsqrt': ['rsqrt'], - 'swish_f32': ['swish'] + 'swish_f32': ['swish'], + 'LeakyRelu': ['leaky_relu', { + 'alpha': 'alpha' + }] } elementwise_ops = { 'Add': 'elementwise_add', @@ -639,14 +642,20 @@ class TFOpMapper(OpMapper): def Tile(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) expand_times = self.graph.get_node(node.layer.input[1], copy=True) - assert expand_times.layer_type == "Const" self.omit_nodes.append(expand_times.layer_name) - expand_times = expand_times.value.tolist() + if expand_times.layer_type == "Const": + expand_times = expand_times.value.tolist() + else: + expand_times = self.decoder.infer_shape_tensor(expand_times) if input.tf_data_format == "NHWC": if len(input.out_shapes[0]) == 4: expand_times = [expand_times[i] for i in [0, 3, 1, 2]] elif len(input.out_shape[0]) == 3: expand_times = [expand_times[i] for i in [2, 0, 1]] + for i in range(len(expand_times)): + if expand_times[i] < 0: + expand_times[i] = 1 + attr = {"expand_times": expand_times} node.fluid_code.add_layer("expand", inputs=input, @@ -699,20 +708,27 @@ class TFOpMapper(OpMapper): limit = self.graph.get_node(node.layer.input[1], copy=True) delta = self.graph.get_node(node.layer.input[2], copy=True) if start.layer_type == "Const": - self.omit_nodes.append(start.layer_name) start = start.value + else: + start = self.decoder.infer_tensor(start) if limit.layer_type == "Const": - self.omit_nodes.append(limit.layer_name) limit = limit.value + else: + limit = self.decoder.infer_tensor(limit) if delta.layer_type == "Const": - self.omit_nodes.append(delta.layer_name) delta = delta.value + else: + delta = self.decoder.infer_tensor(delta) + self.omit_nodes.append(start.layer_name) + self.omit_nodes.append(limit.layer_name) + limit = self.decoder.infer_tensor(limit) + inputs = {"start": start, "end": limit, "step": delta} attr = {"dtype": string(node.dtype)} - node.fluid_code.append("range", - inputs=inputs, - output=node, - param_attr=None) + node.fluid_code.add_layer("range", + inputs=inputs, + output=node, + param_attr=None) def Mean(self, node): input = self.graph.get_node(node.layer.input[0], copy=True) @@ -1011,3 +1027,39 @@ class TFOpMapper(OpMapper): inputs=input, output=node, param_attr=attr) + + def ResizeNearestNeighbor(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + 
self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor( + resize_shape, node.out_shapes[0]) + align_corners = node.get_attr("align_corners") + attr = {"align_corners": align_corners, "out_shape": resize_shape} + node.fluid_code.add_layer("resize_nearest", + inputs=input, + output=node, + param_attr=attr) + + def ResizeBilinear(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor( + resize_shape, node.out_shapes[0]) + align_corners = node.get_attr("align_corners") + attr = { + "align_corners": align_corners, + "out_shape": resize_shape, + "align_mode": 1 + } + node.fluid_code.add_layer("resize_bilinear", + inputs=input, + output=node, + param_attr=attr) diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py new file mode 100644 index 0000000..38b790d --- /dev/null +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -0,0 +1,1058 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from x2paddle.decoder.tf_decoder import TFGraph +from x2paddle.core.op_mapper import OpMapper +from x2paddle.core.util import * +import inspect +import numpy +import sys + + +# compute padding size for SAME mode +def get_same_padding(in_size, kernel_size, stride): + new_size = int(math.ceil(in_size * 1.0 / stride)) + pad_size = (new_size - 1) * stride + kernel_size - in_size + pad0 = int(pad_size / 2) + pad1 = pad_size - pad0 + return [pad0, pad1] + + +class TFOpMapperNHWC(OpMapper): + directly_map_ops = { + 'Relu': ['relu'], + 'Relu6': ['relu6'], + 'Shape': ['shape'], + 'Abs': ['abs'], + 'Sigmoid': ['sigmoid'], + 'Exp': ['exp'], + 'Rsqrt': ['rsqrt'], + 'swish_f32': ['swish'], + 'LeakyRelu': ['leaky_relu', { + 'alpha': 'alpha' + }] + } + elementwise_ops = { + 'Add': 'elementwise_add', + 'RealDiv': 'elementwise_div', + 'Sub': 'elementwise_sub', + 'Maximum': 'elementwise_max', + 'Mul': 'elementwise_mul' + } + + def __init__(self, decoder): + super(TFOpMapperNHWC, self).__init__() + self.decoder = decoder + self.graph = decoder.tf_graph + self.weights = dict() + self.omit_nodes = list() + self.used_custom_layers = dict() + + not_placeholder = list() + for name in self.graph.input_nodes: + if self.graph.get_node(name).layer_type != "Placeholder": + not_placeholder.append(name) + for name in not_placeholder: + idx = self.graph.input_nodes.index(name) + del self.graph.input_nodes[idx] + + unsupported_ops = set() + print("Total nodes: {}".format(len(self.graph.topo_sort))) + for node_name in self.graph.topo_sort: + node = self.graph.get_node(node_name) + op = node.layer_type + if op in self.directly_map_ops: + if len(unsupported_ops) > 0: + continue + self.directly_map(node) + elif op in self.elementwise_ops: + if len(unsupported_ops) > 0: + continue + self.elementwise_map(node) + elif hasattr(self, op): + if len(unsupported_ops) > 0: + continue + func = getattr(self, op) + func(node) + else: + unsupported_ops.add(op) + continue + if len(unsupported_ops) > 0: + print("========= {} OPs are not supported yet ===========".format( + len(unsupported_ops))) + for op in unsupported_ops: + print("========== {} ============".format(op)) + sys.exit(-1) + + def directly_map(self, node): + assert node.layer_type in self.directly_map_ops + op_info = self.directly_map_ops[node.layer_type] + input = self.graph.get_node(node.layer.input[0], copy=True) + attr = dict() + for param in op_info[1:]: + tf_param_name = list(param.keys())[0] + pd_param_name = list(param.values())[0] + tf_param = node.get_attr(tf_param_name) + attr[pd_param_name] = tf_param + node.fluid_code.add_layer(op_info[0], + inputs=input, + output=node, + param_attr=attr) + + def elementwise_map(self, node): + assert node.layer_type in self.elementwise_ops + op_type = self.elementwise_ops[node.layer_type] + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + x_shape = x.out_shapes[0] + y_shape = y.out_shapes[0] + # incomplement broadcasting support for paddle + x_input = x + y_input = y + if len(x_shape) < len(y_shape): + unrevertable_ops = [ + "elementwise_sub", "elementwise_div", "elementwise_floordiv", + "elementwise_mod", "elementwise_pow" + ] + if op_type not in unrevertable_ops: + x_input = y + y_input = x + x_shape = y.out_shapes[0] + y_shape = x.out_shapes[0] + else: + raise Exception("Unexpected situation happend") + + if len(x_shape) == 4 and len(y_shape) == 1: + inputs = {"x": x_input, "y": y_input} + node.fluid_code.add_layer(op_type, inputs=inputs, output=node) + 
return + + is_sub_seq = True + for i in range(len(y_shape)): + index = -1 * i - 1 + if y_shape[index] != x_shape[index]: + is_sub_seq = False + if not is_sub_seq: + x_expand_times = [1] * len(x_shape) + y_expand_times = [1] * len(y_shape) + x_need_expand = False + y_need_expand = False + for i in range(len(y_shape)): + index = -1 * i - 1 + if y_shape[index] != x_shape[index]: + if y_shape[index] == 1: + y_expand_times[index] = x_shape[index] + y_need_expand = True + elif x_shape[index] == 1: + x_expand_times[index] = y_shape[index] + x_need_expand = True + else: + raise Exception("Unexpected situation happend") + if x_need_expand: + attr = {"expand_times": x_expand_times} + node.fluid_code.add_layer("expand", + inputs=x_input, + output="x_tmp", + param_attr=attr) + x_input = "x_tmp" + if y_need_expand: + attr = {"expand_times": y_expand_times} + node.fluid_code.add_layer("expand", + inputs=y_input, + output="y_tmp", + param_attr=attr) + y_input = "y_tmp" + inputs = {"x": x_input, "y": y_input} + node.fluid_code.add_layer(op_type, + inputs=inputs, + output=node, + param_attr=None) + + def Placeholder(self, node): + shape = node.out_shapes[0] + assert len(shape) != 0, "Unknown shape of input nodes[{}].".format( + node.layer_name) + dtype = node.dtype + attr = { + 'dtype': string(dtype), + 'shape': shape, + 'name': string(node.layer_name), + 'append_batch_size': False + } + node.fluid_code.add_layer("data", + inputs=None, + output=node, + param_attr=attr) + + def Const(self, node): + shape = node.out_shapes[0] + dtype = node.dtype + value = node.value + initializer = "Constant(0.0)" + if len(shape) == 0: + assert value.size == 1, "Unexpected situation happend" + shape = [1] + initializer = "Constant({})".format(value) + + self.weights[node.layer_name] = node.value + + attr = { + 'dtype': string(dtype), + 'shape': shape, + 'name': string(node.layer_name), + 'default_initializer': initializer + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node, + param_attr=attr) + + def Transpose(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + perm = self.graph.get_node(node.layer.input[1], copy=True) + assert perm.layer_type == "Const", "Perm of transpose OP should be Const" + del self.weights[perm.layer_name.replace('/', '_')] + perm.fluid_code.clear() + perm = perm.value.tolist() + + attr = {'perm': perm} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + + def MaxPool(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + + in_shape = input.out_shapes[0] + if in_shape.count(-1) > 2: + in_shape = self.decoder.infer_tensor(input).shape + + k_size = node.get_attr("ksize") + strides = node.get_attr("strides") + data_format = node.get_attr("data_format").decode() + pad_mode = node.get_attr("padding").decode() + channel_first = data_format == "NCHW" + + if not channel_first: + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + in_shape = [in_shape[i] for i in [0, 3, 1, 2]] + strides = [strides[i] for i in [0, 3, 1, 2]] + k_size = [k_size[i] for i in [0, 3, 1, 2]] + input = node + + if pad_mode == "SAME": + pad_h = get_same_padding(in_shape[2], k_size[2], strides[2]) + pad_w = get_same_padding(in_shape[3], k_size[3], strides[3]) + pad_h = pad_h[0] + pad_h[1] + pad_w = pad_w[0] + pad_w[1] + attr = {"paddings": [0, pad_h, 0, pad_w], "pad_value": -10000.0} + if pad_h + pad_w != 0: + node.fluid_code.add_layer("pad2d", + 
inputs=input, + output=node, + param_attr=attr) + input = node + attr = { + "pool_size": k_size[2:4], + "pool_type": string("max"), + "pool_stride": strides[2:4] + } + node.fluid_code.add_layer("pool2d", + inputs=input, + output=node, + param_attr=attr) + + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def Conv2D(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + kernel = self.graph.get_node(node.layer.input[1], copy=True) + assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const" + self.omit_nodes.append(kernel.layer_name) + + node.fluid_code.add_note("#{} : {}".format(node.layer.name, + node.layer_name)) + + in_shape = input.out_shapes[0] + if in_shape.count(-1) > 2: + in_shape = self.decoder.infer_tensor(input).shape + k_size = kernel.out_shapes[0] + if k_size.count(-1) > 2: + k_size = self.decoder.infer_tensor(kernel).shape + + strides = node.get_attr("strides") + dilations = node.get_attr("dilations") + data_format = node.get_attr("data_format").decode() + pad_mode = node.get_attr("padding").decode() + channel_first = data_format == "NCHW" + padding = 0 + + self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose( + kernel.value, (3, 2, 0, 1)) + + if not channel_first: + in_shape = [in_shape[i] for i in [0, 3, 1, 2]] + strides = [strides[i] for i in [0, 3, 1, 2]] + dilations = [dilations[i] for i in [0, 3, 1, 2]] + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + + if pad_mode == "SAME": + pad_h = get_same_padding(in_shape[2], k_size[0], strides[2]) + pad_w = get_same_padding(in_shape[3], k_size[1], strides[3]) + if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]: + padding = [pad_h[0], pad_w[0]] + else: + attr = {"paddings": pad_h + pad_w, "pad_value": 0.0} + node.fluid_code.add_layer("pad2d", + inputs=input, + output=node, + param_attr=attr) + input = node + attr = { + "bias_attr": False, + "param_attr": string(kernel.layer_name), + "num_filters": k_size[3], + "filter_size": k_size[0:2], + "stride": strides[2:4], + "dilation": dilations[2:4], + "padding": padding + } + node.fluid_code.add_layer("conv2d", + inputs=input, + output=node, + param_attr=attr) + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def BiasAdd(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + bias = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {"x": input, "y": bias} + node.fluid_code.add_layer("elementwise_add", + inputs=inputs, + output=node, + param_attr=None) + + def FusedBatchNorm(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + gamma = self.graph.get_node(node.layer.input[1], copy=True) + beta = self.graph.get_node(node.layer.input[2], copy=True) + moving_mean = self.graph.get_node(node.layer.input[3], copy=True) + moving_var = self.graph.get_node(node.layer.input[4], copy=True) + data_format = node.get_attr("data_format").decode() + channel_first = data_format == "NCHW" + + assert gamma.layer_type == "Const" + assert beta.layer_type == "Const" + assert moving_mean.layer_type == "Const" + assert moving_var.layer_type == "Const" + self.omit_nodes.append(gamma.layer_name) + self.omit_nodes.append(beta.layer_name) + self.omit_nodes.append(moving_mean.layer_name) + 
self.omit_nodes.append(moving_var.layer_name) + + if not channel_first: + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + + attr = { + "epsilon": node.get_attr("epsilon"), + "param_attr": string(gamma.layer_name), + "bias_attr": string(beta.layer_name), + "moving_mean_name": string(moving_mean.layer_name), + "moving_variance_name": string(moving_var.layer_name), + "is_test": True + } + + node.fluid_code.add_layer("batch_norm", + inputs=input, + output=node, + param_attr=attr) + + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def DepthwiseConv2dNative(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + kernel = self.graph.get_node(node.layer.input[1], copy=True) + assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const" + self.omit_nodes.append(kernel.layer_name) + + in_shape = input.out_shapes[0] + if in_shape.count(-1) > 2: + in_shape = self.decoder.infer_tensor(input).shape + k_size = kernel.out_shapes[0] + if k_size.count(-1) > 2: + k_size = self.decoder.infer_tensor(kernel).shape + + strides = node.get_attr("strides") + dilations = node.get_attr("dilations") + data_format = node.get_attr("data_format").decode() + pad_mode = node.get_attr("padding").decode() + channel_first = data_format == "NCHW" + padding = 0 + + self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose( + kernel.value, (2, 3, 0, 1)) + + if not channel_first: + in_shape = [in_shape[i] for i in [0, 3, 1, 2]] + strides = [strides[i] for i in [0, 3, 1, 2]] + dilations = [dilations[i] for i in [0, 3, 1, 2]] + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + + if pad_mode == "SAME": + pad_h = get_same_padding(in_shape[2], k_size[0], strides[2]) + pad_w = get_same_padding(in_shape[3], k_size[1], strides[3]) + if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]: + padding = [pad_h[0], pad_w[0]] + else: + attr = {"paddings": pad_h + pad_w, "pad_value": 0.0} + node.fluid_code.add_layer("pad2d", + inputs=input, + output=node, + param_attr=attr) + input = node + + attr = { + "bias_attr": False, + "param_attr": string(kernel.layer_name), + "num_filters": in_shape[1], + "filter_size": k_size[0:2], + "stride": strides[2:4], + "dilation": dilations[2:4], + "groups": k_size[3] * in_shape[1], + "use_cudnn": False, + "padding": padding + } + node.fluid_code.add_layer("conv2d", + inputs=input, + output=node, + param_attr=attr) + + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def Reshape(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + param = self.graph.get_node(node.layer.input[1], copy=True) + if param.layer_type == "Const": + attr = {"shape": param.value.tolist()} + self.omit_nodes.append(param.layer_name) + else: + # Here is a trick method to solove tensor parameter in tensorflow + shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0]) + if shape.count(-1) <= 1: + attr = {"shape": shape} + self.omit_nodes.append(param.layer_name) + else: + assert len(param.out_shapes[0] + ) == 1, "Unexpected situation of shape parameter" + attr = {"shape": [-1]} + node.fluid_code.add_layer("reshape", + inputs=param, + output="shape_param", + param_attr=attr) + attr = 
{"num_or_sections": param.out_shapes[0][0], "dim": 0} + node.fluid_code.add_layer("split", + inputs="shape_param", + output=node, + param_attr=attr) + new_param = "[" + for i in range(param.out_shapes[0][0]): + new_param += (node.layer_name + "[{}]".format(i) + ", ") + new_param = new_param.strip(", ") + "]" + attr = {"shape": new_param} + node.fluid_code.add_layer("reshape", + inputs=input, + output=node, + param_attr=attr) + + def AvgPool(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + + in_shape = input.out_shapes[0] + if in_shape.count(-1) > 2: + in_shape = self.decoder.infer_tensor(input).shape + + k_size = node.get_attr("ksize") + strides = node.get_attr("strides") + data_format = node.get_attr("data_format").decode() + pad_mode = node.get_attr("padding").decode() + channel_first = data_format == "NCHW" + + if not channel_first: + in_shape = [in_shape[i] for i in [0, 3, 1, 2]] + strides = [strides[i] for i in [0, 3, 1, 2]] + k_size = [k_size[i] for i in [0, 3, 1, 2]] + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + + attr = { + "pool_size": k_size[2:4], + "pool_type": string("avg"), + "pool_stride": strides[2:4] + } + if pad_mode == "SAME": + pad_h = get_same_padding(in_shape[2], k_size[2], strides[2]) + pad_w = get_same_padding(in_shape[3], k_size[3], strides[3]) + assert pad_h[0] == pad_h[1] and pad_w[0] == pad_w[ + 1], "Cannot map AvgPool" + attr["pool_padding"] = [pad_h[0], pad_w[0]] + node.fluid_code.add_layer("pool2d", + inputs=input, + output=node, + param_attr=attr) + + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def SplitV(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + num_sections = self.graph.get_node(node.layer.input[1], copy=True) + dim = self.graph.get_node(node.layer.input[2], copy=True) + assert num_sections.layer_type == "Const" + assert dim.layer_type == "Const" + self.omit_nodes.append(num_sections.layer_name) + self.omit_nodes.append(dim.layer_name) + dim = dim.value + attr = { + "num_or_sections": num_sections.value.tolist(), + "dim": dim.value + } + node.fluid_code.add_layer("split", + inputs=input, + output=node, + param_attr=attr) + + def ConcatV2(self, node): + inputs = [ + self.graph.get_node(name, copy=True) + for name in node.layer.input[:-1] + ] + axis = self.graph.get_node(node.layer.input[-1], copy=True) + assert axis.layer_type == "Const" + self.omit_nodes.append(axis.layer_name) + axis = axis.value + if axis < 0: + axis += len(inputs[0].out_shapes[0]) + + attr = {"axis": axis} + node.fluid_code.add_layer("concat", + inputs=inputs, + output=node, + param_attr=attr) + + def Tile(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + expand_times = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(expand_times.layer_name) + if expand_times.layer_type == "Const": + expand_times = expand_times.value.tolist() + else: + expand_times = self.decoder.infer_shape_tensor(expand_times) + for i in range(len(expand_times)): + if expand_times[i] < 0: + expand_times[i] = 1 + attr = {"expand_times": expand_times} + node.fluid_code.add_layer("expand", + inputs=input, + output=node, + param_attr=attr) + + def Pack(self, node): + inputs = [ + self.graph.get_node(name, copy=True) for name in node.layer.input + ] + axis = node.get_attr("axis") + attr = {"axis": axis} + 
node.fluid_code.add_layer("stack", + inputs=inputs, + output=node, + param_attr=attr) + + def Pad(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + paddings = self.graph.get_node(node.layer.input[1], copy=True) + assert paddings.layer_type == "Const", "Padding should be Const" + self.omit_nodes.append(paddings.layer_name) + paddings = paddings.value.flatten().tolist() + data_format = input.tf_data_format + + if len(input.out_shapes[0]) == 4: + new_padding = None + if input.tf_data_format == "NHWC": + if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0: + new_padding = paddings[2:6] + else: + if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0: + new_padding = paddings[4:] + if new_padding is not None: + if input.tf_data_format == "NHWC": + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + attr = {"paddings": new_padding} + node.fluid_code.add_layer("pad2d", + inputs=input, + output=node, + param_attr=attr) + if input.tf_data_format == "NHWC": + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + return + + attr = {"paddings": paddings} + node.fluid_code.add_layer("pad", + inputs=input, + output=node, + param_attr=attr) + + def Range(self, node): + start = self.graph.get_node(node.layer.input[0], copy=True) + limit = self.graph.get_node(node.layer.input[1], copy=True) + delta = self.graph.get_node(node.layer.input[2], copy=True) + self.omit_nodes.append(start.layer_name) + self.omit_nodes.append(limit.layer_name) + self.omit_nodes.append(delta.layer_name) + if start.layer_type == "Const": + start = start.value + else: + start = self.decoder.infer_tensor(start) + if limit.layer_type == "Const": + limit = limit.value + else: + limit = self.decoder.infer_tensor(limit) + if delta.layer_type == "Const": + delta = delta.value + else: + delta = self.decoder.infer_tensor(delta) + dtype = node.dtype + inputs = { + "start": start, + "end": limit, + "step": delta, + "dtype": string(dtype) + } + attr = {"dtype": string(node.dtype)} + node.fluid_code.add_layer("range", + inputs=inputs, + output=node, + param_attr=None) + + def Mean(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + reduce_idx = self.graph.get_node(node.layer.input[1], copy=True) + assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]" + dims = reduce_idx.value.tolist() + keep_dims = node.get_attr("keep_dims") + + attr = {"dim": dims, "keep_dim": keep_dims} + node.fluid_code.add_layer("reduce_mean", + inputs=input, + output=node, + param_attr=attr) + + def MatMul(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + transpose_a = node.get_attr('transpose_a') + transpose_b = node.get_attr('transpose_b') + inputs = {"x": x, "y": y} + # fix paddle shape infer problem + # should be removed after paddle 1.6 + if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0: + shape = x.out_shapes[0] + shape[-1] = y.out_shapes[0][0] + attr = {"shape": shape} + node.fluid_code.add_layer("reshape", + inputs=x, + output=x, + param_attr=attr) + attr = {"transpose_x": transpose_a, "transpose_y": transpose_b} + node.fluid_code.add_layer("matmul", + inputs=inputs, + output=node, + param_attr=attr) + + def ArgMax(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + axis = 
self.graph.get_node(node.layer.input[1], copy=True) + assert axis.layer_type == "Const", "ArgMax only support Const parameter" + self.omit_nodes.append(axis.layer_name) + axis = axis.value + attr = {"axis": axis} + node.fluid_code.add_layer("argmax", + inputs=input, + output=node, + param_attr=attr) + + def StridedSlice(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + begin = self.graph.get_node(node.layer.input[1], copy=True) + end = self.graph.get_node(node.layer.input[2], copy=True) + strides = self.graph.get_node(node.layer.input[3], copy=True) + assert begin.layer_type == "Const" + assert end.layer_type == "Const" + assert strides.layer_type == "Const" + self.omit_nodes.append(begin.layer_name) + self.omit_nodes.append(end.layer_name) + self.omit_nodes.append(strides.layer_name) + strides = strides.value.tolist() + assert len(set(strides)) == 1 and strides[ + 0] == 1, "Only support strides be 1 in StridedSlice OP" + + begin = begin.value.tolist() + end = end.value.tolist() + + for i in range(len(end)): + if end[i] == 0: + end[i] = 999999 + + begin_mask = node.get_attr('begin_mask') + end_mask = node.get_attr('end_mask') + ellipsis_mask = node.get_attr('ellipsis_mask') + new_axis_mask = node.get_attr('new_axis_mask') + shrink_axis_mask = node.get_attr('shrink_axis_mask') + + assert ellipsis_mask == 0, "(OP:{} Name:{})Only support ellipsis_mask be 0[now: {}] n StridedSlice OP".format( + node.layer_type, node.layer.name, ellipsis_mask) + + # TODO codes without validation + # Use it carefully + new_begin = list() + new_end = list() + new_axes = list() + shrink_axes = list() + for i, item in enumerate(begin): + mask = (new_axis_mask >> i) & 1 + if mask != 0: + new_axes.append(i) + continue + + mask = (shrink_axis_mask >> i) & 1 + if mask != 0: + shrink_axes.append(i) + + mask = (begin_mask >> i) & 1 + if mask != 0: + new_begin.append(0) + else: + new_begin.append(item) + + mask = (end_mask >> i) & 1 + if mask != 0: + new_end.append(999999) + else: + new_end.append(end[i]) + + attr = { + "axes": [i for i in range(len(new_begin))], + "starts": new_begin, + "ends": new_end + } + node.fluid_code.add_layer("slice", + inputs=input, + output=node, + param_attr=attr) + print(node.layer.name) + if len(new_axes) > 0: + attr = {"axes": new_axes} + node.fluid_code.add_layer("unsqueeze", + inputs=node, + output=node, + param_attr=attr) + if len(shrink_axes) > 0: + if len(input.out_shapes[0]) + len(new_axes) <= 1: + pass + else: + attr = {"axes": shrink_axes} + node.fluid_code.add_layer("squeeze", + inputs=node, + output=node, + param_attr=attr) + + def Slice(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + begin = self.graph.get_node(node.layer.input[1], copy=True) + size = self.graph.get_node(node.layer.input[2], copy=True) + # assert begin.layer_type == "Const" + # assert size.layer_type == "Const" + self.omit_nodes.append(begin.layer_name) + self.omit_nodes.append(size.layer_name) + if begin.layer_type == "Const": + begin = begin.value.tolist() + else: + begin = self.decoder.infer_tensor(begin).tolist() + if size.layer_type == "const": + size = size.value.tolist() + else: + size = self.decoder.infer_tensor(size).tolist() + + attr = {"shape": size, "offsets": begin} + node.fluid_code.add_layer("crop", + inputs=input, + output=node, + param_attr=attr) + + def Conv2DBackpropInput(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + kernel = self.graph.get_node(node.layer.input[1], copy=True) + assert kernel.layer_type == 
"Const", "Kernel of Conv2DBackpropInput should be Const" + self.omit_nodes.append(kernel.layer_name) + + in_shape = input.out_shapes[0] + if in_shape.count(-1) > 2: + in_shape = self.decoder.infer_tensor(input).shape + k_size = kernel.out_shapes[0] + if k_size.count(-1) > 2: + k_size = self.decoder.infer_tensor(kernel).shape + + strides = node.get_attr("strides") + dilations = node.get_attr("dilations") + data_format = node.get_attr("data_format").decode() + pad_mode = node.get_attr("padding").decode() + channel_first = data_format == "NCHW" + self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose( + kernel.value, (3, 2, 0, 1)) + + if not channel_first: + in_shape = [in_shape[i] for i in [0, 3, 1, 2]] + strides = [strides[i] for i in [0, 3, 1, 2]] + dilations = [dilations[i] for i in [0, 3, 1, 2]] + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + input = node + padding = 0 + if pad_mode == "SAME": + pad_h = get_same_padding(in_shape[2], k_size[0], strides[2]) + pad_w = get_same_padding(in_shape[3], k_size[1], strides[3]) + if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]: + padding = [pad_h[0], pad_w[0]] + else: + attr = {"paddings": pad_h + pad_w, "pad_value": 0.0} + node.fluid_code.add_layer("pad2d", + inputs=input, + output=node, + param_attr=attr) + input = node + attr = { + "bias_attr": False, + "param_attr": string(kernel.layer_name), + "num_filters": k_size[3], + "filter_size": k_size[0:2], + "stride": strides[2:4], + "dilation": dilations[2:4], + "padding": padding + } + node.fluid_code.add_layer("conv2d_transpose", + inputs=input, + output=node, + param_attr=attr) + if not channel_first: + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def Max(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + reduce_idx = self.graph.get_node(node.layer.input[1], copy=True) + assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]" + keep_dims = node.get_attr("keep_dims") + dim = reduce_idx.value.tolist() + + attr = {"dim": dim, "keep_dim": keep_dims} + node.fluid_code.add_layer("reduce_max", + inputs=input, + output=node, + param_attr=attr) + + def Sum(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + reduce_idx = self.graph.get_node(node.layer.input[1], copy=True) + assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]" + keep_dims = node.get_attr("keep_dims") + dim = reduce_idx.value.tolist() + + attr = {"dim": dim, "keep_dim": keep_dims} + node.fluid_code.add_layer("reduce_sum", + inputs=input, + output=node, + param_attr=attr) + + def Cast(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + dtype = node.dtype_map[node.get_attr('DstT')] + attr = {"dtype": string(dtype)} + node.fluid_code.add_layer("cast", + inputs=input, + output=node, + param_attr=attr) + + def FloorDiv(self, node): + x = self.graph.get_node(node.layer.input[0], copy=True) + y = self.graph.get_node(node.layer.input[1], copy=True) + inputs = {'x': x, 'y': y} + node.fluid_code.add_layer("elementwise_div", + inputs=inputs, + output=node, + param_attr=None) + node.fluid_code.add_layer("floor", + inputs=node, + output=node, + param_attr=None) + + def Split(self, node): + dim = self.graph.get_node(node.layer.input[0], copy=True) + input = self.graph.get_node(node.layer.input[1], copy=True) + assert dim.layer_type == "Const" + 
self.omit_nodes.append(dim.layer_name) + num_split = node.get_attr('num_split') + dim = dim.value + + attr = {"num_or_sections": num_split, "dim": dim} + node.fluid_code.add_layer("split", + inputs=input, + output=node, + param_attr=attr) + + def Squeeze(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + squeeze_dims = node.get_attr('squeeze_dims') + attr = {"axes": squeeze_dims} + node.fluid_code.add_layer("squeeze", + inputs=input, + output=node, + param_attr=attr) + + def Softmax(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + axis = node.get_attr("axis") + attr = {"axis": axis} + node.fluid_code.add_layer("softmax", + inputs=input, + output=node, + param_attr=attr) + + def ResizeNearestNeighbor(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor( + resize_shape, node.out_shapes[0]) + align_corners = node.get_attr("align_corners") + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + attr = {"align_corners": align_corners, "out_shape": resize_shape} + node.fluid_code.add_layer("resize_nearest", + inputs=node, + output=node, + param_attr=attr) + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) + + def ResizeBilinear(self, node): + input = self.graph.get_node(node.layer.input[0], copy=True) + resize_shape = self.graph.get_node(node.layer.input[1], copy=True) + self.omit_nodes.append(resize_shape.layer_name) + if resize_shape.layer_type == "Const": + resize_shape = resize_shape.value.tolist() + else: + resize_shape = self.decoder.infer_shape_tensor( + resize_shape, node.out_shapes[0]) + align_corners = node.get_attr("align_corners") + attr = {"perm": [0, 3, 1, 2]} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + attr = { + "align_corners": align_corners, + "out_shape": resize_shape, + "align_mode": 1 + } + node.fluid_code.add_layer("resize_bilinear", + inputs=node, + output=node, + param_attr=attr) + attr = {"perm": [0, 2, 3, 1]} + node.fluid_code.add_layer("transpose", + inputs=node, + output=node, + param_attr=attr) -- GitLab
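
Usage sketch (not part of the patch): one way the options introduced here could be exercised programmatically, assuming the package is installed and a frozen GraphDef is available. The keyword names match the tf2paddle() signature added in x2paddle/convert.py; the file paths are placeholders. The equivalent CLI flags --without_data_format_optimization/-wo and --define_input_shape/-d are wired to the same arguments in main().

    from x2paddle.convert import tf2paddle

    # Convert a frozen TensorFlow model while keeping the NHWC data format
    # (routes through TFOpMapperNHWC and skips the data-format optimization),
    # and ask TFDecoder to let the user define input tensor shapes explicitly.
    tf2paddle("frozen_model.pb",                      # path to the frozen TensorFlow GraphDef
              "inference_model",                      # output directory for the Paddle model
              without_data_format_optimization=True,  # use TFOpMapperNHWC instead of TFOpMapper
              define_input_shape=True)                # enable shape definition in TFDecoder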