#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import inspect
import numpy
import math
import sys


# compute padding size for SAME mode
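# illustrative example: in_size=224, kernel_size=3, stride=2 gives
#   new_size = ceil(224 / 2) = 112, pad_size = (112 - 1) * 2 + 3 - 224 = 1 -> [pad0, pad1] = [0, 1]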
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    if pad_size < 0:
        pad_size = 0
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]


def nhwc_dim_to_nchw(node, dim):
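    # map an axis index in the node's TF layout to the matching axis in its
    # Paddle layout; e.g. assuming tf_data_format "NHWC" and pd_data_format
    # "NCHW", axis 3 (channels) maps to axis 1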
    tf_data_format = list(node.tf_data_format)
    pd_data_format = list(node.pd_data_format)
    if isinstance(dim, list):
        for i in range(len(dim)):
            char = tf_data_format[dim[i]]
            dim[i] = pd_data_format.index(char)
    else:
        char = tf_data_format[dim]
        dim = pd_data_format.index(char)
    return dim


class TFOpMapper(OpMapper):
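    # op tables (summary): directly_map_ops maps a TF op to
    # [paddle op name, {tf attr name: paddle attr name}, ...]; elementwise_ops maps
    # a TF op to a paddle elementwise op, with broadcasting handled in elementwise_map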
    directly_map_ops = {
        'Relu': ['relu'],
        'Relu6': ['relu6'],
        'Shape': ['shape'],
        'Abs': ['abs'],
        'Sigmoid': ['sigmoid'],
        'Exp': ['exp'],
        'Rsqrt': ['rsqrt'],
        'swish_f32': ['swish'],
        'Tanh': ['tanh'],
        'LeakyRelu': ['leaky_relu', {
            'alpha': 'alpha'
        }]
    }
    elementwise_ops = {
        'Add': 'elementwise_add',
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
        'Mul': 'elementwise_mul',
        'FloorDiv': 'elementwise_floordiv'
    }

    def __init__(self, decoder):
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.batch_node = None
        self.weights = dict()
        self.omit_nodes = list()
        self.used_custom_layers = dict()

        not_placeholder = list()
        for name in self.graph.input_nodes:
            if self.graph.get_node(name).layer_type != "Placeholder" \
               and self.graph.get_node(name).layer_type != "OneShotIterator":
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]

        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        unsupported_ops = set()
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...    ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.directly_map(node)
            elif op in self.elementwise_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.elementwise_map(node)
            elif hasattr(self, op):
                if len(unsupported_ops) > 0:
                    continue
                func = getattr(self, op)
                func(node)
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
            sys.stderr.write("=========={} Ops are not supported yet======\n".
                             format(len(unsupported_ops)))
            for op in unsupported_ops:
                sys.stderr.write("========== {} ==========\n".format(op))
            sys.exit(-1)
        sys.stderr.write('\nDone!\n')

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
        out_node = self.graph.get_node(out_node_name)
        index = in_node.outputs.index(out_node_name)
        #        del in_node.outputs[index]
        index = out_node.inputs.index(in_node_name)
        #        del out_node.inputs[index]
        self.omit_nodes.append(in_node.layer_name)

    def directly_map(self, node):
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0], copy=True)
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param
        node.fluid_code.add_layer(
            op_info[0], inputs=input, output=node, param_attr=attr)

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
        if len(x_shape) == 0:
            x_shape = [1]
        if len(y_shape) == 0:
            y_shape = [1]
        # incomplete broadcasting support for paddle
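        # if x has fewer dims than y, either swap the operands (when the op is
        # commutative) or reshape/expand the smaller tensor so paddle can broadcast,
        # e.g. a [C] bias against an NCHW tensor is reshaped to [1, C, 1, 1]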
        x_input = x
        y_input = y
        if len(x_shape) < len(y_shape):
            unrevertable_ops = [
                "elementwise_sub", "elementwise_div", "elementwise_floordiv",
                "elementwise_mod", "elementwise_pow"
            ]
            if op_type not in unrevertable_ops:
                x_input = y
                y_input = x
                x_shape = y.out_shapes[0]
                if len(x_shape) == 0:
                    x_shape = [1]
                y_shape = x.out_shapes[0]
                if len(y_shape) == 0:
                    y_shape = [1]
            else:
                if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
                        0] == y_shape[-1] and y_shape.count(-1) < 1:
                    shape = [1, x_shape[0], 1, 1]
                    attr = {"shape": shape}
                    node.fluid_code.add_layer(
                        "reshape",
                        inputs=x_input,
                        output="reshape_x",
                        param_attr=attr)
                    if y_shape[0] != 1:
                        attr = {"expand_times": [y_shape[0], 1, 1, 1]}
                        node.fluid_code.add_layer(
                            "expand",
                            inputs="reshape_x",
                            output="reshape_x",
                            param_attr=attr)
                    inputs = {"x": "reshape_x", "y": y_input}
                    node.fluid_code.add_layer(
                        op_type, inputs=inputs, output=node, param_attr=None)
                    return
                else:
                    raise Exception("Unexpected situation happened")

        if len(x_shape) == 4 and len(y_shape) == 1:
            if x_input.tf_data_format == "NHWC":
                axis = 1
            else:
                axis = -1
            attr = {"axis": axis}
            inputs = {"x": x_input, "y": y_input}
            node.fluid_code.add_layer(
                op_type, inputs=inputs, output=node, param_attr=attr)
            return

        is_sub_seq = True
        for i in range(len(y_shape)):
            index = -1 * i - 1
            if y_shape[index] != x_shape[index]:
                is_sub_seq = False
        if not is_sub_seq:
            if x_shape.count(-1) > 2:
                x_shape = self.decoder.infer_tensor_shape(x_input)
            if y_shape.count(-1) > 2:
                y_shape = self.decoder.infer_tensor_shape(y_input)
            x_expand_times = [1] * len(x_shape)
            y_expand_times = [1] * len(y_shape)
            x_need_expand = False
            y_need_expand = False
            for i in range(len(y_shape)):
                index = -1 * i - 1
                if y_shape[index] != x_shape[index]:
                    if y_shape[index] == 1:
                        y_expand_times[index] = x_shape[index]
                        y_need_expand = True
                    elif x_shape[index] == 1:
                        x_expand_times[index] = y_shape[index]
                        x_need_expand = True
                    else:
                        raise Exception("Unexpected situation happened")
            if x_need_expand:
                if len(x_expand_times) == 3 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [2, 0, 1]]
                if len(x_expand_times) == 4 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": x_expand_times}
                node.fluid_code.add_layer(
                    "expand", inputs=x_input, output="x_tmp", param_attr=attr)
                x_input = "x_tmp"
            if y_need_expand:
                if len(y_expand_times) == 3 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [2, 0, 1]]
                if len(y_expand_times) == 4 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": y_expand_times}
                node.fluid_code.add_layer(
                    "expand", inputs=y_input, output="y_tmp", param_attr=attr)
                y_input = "y_tmp"
        inputs = {"x": x_input, "y": y_input}
        node.fluid_code.add_layer(
            op_type, inputs=inputs, output=node, param_attr=None)

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        if node.tf_data_format == "NHWC" and len(shape) == 4:
            shape = [shape[i] for i in [0, 3, 1, 2]]
        elif node.tf_data_format == "NCHW" and len(shape) == 4:
            self.graph.data_format_propagation(node)
        dtype = node.dtype
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'append_batch_size': False
        }

        if shape[0] < 0:
            self.batch_node = node

        node.fluid_code.add_layer(
            "data", inputs=None, output=node, param_attr=attr)

    def OneShotIterator(self, node):
        return self.Placeholder(node)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            initializer = "Constant({})".format(value)

        self.weights[node.layer_name] = node.value

        if node.tf_data_format == "NHWC":
            if len(shape) == 4:
                shape = [shape[i] for i in [0, 3, 1, 2]]
            if len(shape) == 3:
                shape = [shape[i] for i in [2, 0, 1]]
                self.weights[node.layer_name] = numpy.transpose(node.value,
                                                                (2, 0, 1))
        elif node.tf_data_format == "NCHW":
            if len(shape) == 4:
                self.graph.data_format_propagation(node)

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer(
            "create_parameter", inputs=None, output=node, param_attr=attr)

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
        del self.weights[perm.layer_name.replace('/', '_')]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        if perm == [0, 3, 1, 2] and input.tf_data_format == "NHWC":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NCHW"
            self.graph.data_format_propagation(node)
        elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NHWC"
            self.graph.data_format_propagation(node)
        elif len(input.out_shapes[0]) > 4:
            tf_data_format = list(input.tf_data_format)
            pd_data_format = list(input.pd_data_format)
            new_perm = [i for i in range(len(perm))]
            for i in range(len(perm)):
                char0 = tf_data_format[i]
                char1 = tf_data_format[perm[i]]
                index0 = pd_data_format.index(char0)
                index1 = pd_data_format.index(char1)
                new_perm[index0] = index1
            node.tf_data_format = [tf_data_format[i] for i in perm]
            node.pd_data_format = [pd_data_format[i] for i in perm]
            attr = {'perm': new_perm}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
        elif len(node.out_shapes[0]) != 4:
            attr = {'perm': perm}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
        else:
            raise Exception("Unexpected situation happened in Transpose OP")

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("max"),
            "pool_padding": string(pad_mode),
            "pool_stride": strides[2:4]
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        axis = -1
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = 1
        inputs = {"x": input, "y": bias}
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "elementwise_add", inputs=inputs, output=node, param_attr=attr)

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        gamma = self.graph.get_node(node.layer.input[1], copy=True)
        beta = self.graph.get_node(node.layer.input[2], copy=True)
        moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
        moving_var = self.graph.get_node(node.layer.input[4], copy=True)
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"
        self.add_omit_nodes(gamma.layer_name, node.layer_name)
        self.add_omit_nodes(beta.layer_name, node.layer_name)
        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
        self.add_omit_nodes(moving_var.layer_name, node.layer_name)
        if channel_first:
            self.graph.data_format_propagation(node)

        attr = {
            "epsilon": node.get_attr("epsilon"),
            "param_attr": string(gamma.layer_name),
            "bias_attr": string(beta.layer_name),
            "moving_mean_name": string(moving_mean.layer_name),
            "moving_variance_name": string(moving_var.layer_name),
            "is_test": True
        }

        node.fluid_code.add_layer(
            "batch_norm", inputs=input, output=node, param_attr=attr)

    def FusedBatchNormV3(self, node):
        return self.FusedBatchNorm(node)

    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (2, 3, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": in_shape[1],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "groups": k_size[3] * in_shape[1],
            "use_cudnn": False,
            "padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)

    def Reshape(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        param = self.graph.get_node(node.layer.input[1], copy=True)
        is_variable = False
        if param.layer_type == "Const":
            attr = {"shape": param.value.tolist()}
            self.add_omit_nodes(param.layer_name, node.layer_name)
        else:
            # Here is a trick to handle a shape parameter that arrives as a tensor in TensorFlow
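            # try to infer a static shape first; only if several dimensions stay
            # unknown is the runtime shape tensor split into scalars and the target
            # shape assembled in the generated code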
            shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
            if shape.count(-1) <= 1:
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            elif shape.count(-1) == 2 and shape[0] == -1:
                shape[0] = 0
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            else:
                assert len(param.out_shapes[
                    0]) == 1, "Unexpected situation of shape parameter"
                attr = {"shape": [-1]}
                node.fluid_code.add_layer(
                    "reshape",
                    inputs=param,
                    output="shape_param",
                    param_attr=attr)
                attr = {"num_or_sections": param.out_shapes[0][0], "dim": 0}
                node.fluid_code.add_layer(
                    "split", inputs="shape_param", output=node, param_attr=attr)
                new_param = "["
                for i in range(param.out_shapes[0][0]):
                    new_param += (node.layer_name + "[{}]".format(i) + ", ")
                new_param = new_param.strip(", ") + "]"
                attr = {"shape": new_param}
                is_variable = True

        # to change [192, -1]->[-1, 192], always put -1 in the first dimension
        # optimization for Paddle-Lite
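        # illustrative example: in_shape [2, 3, 4] with target shape [0, -1] becomes
        # [2, 12] after resolving 0 and -1, then the leading dim is reset to -1 -> [-1, 12]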
        in_shape = input.out_shapes[0]
        if not is_variable and in_shape.count(-1) < 1:
            total_size = 1
            for i in range(len(in_shape)):
                total_size *= in_shape[i]
            for i in range(len(attr["shape"])):
                if attr["shape"][i] == 0:
                    attr["shape"][i] = in_shape[i]
                if attr["shape"][i] != -1:
                    total_size /= attr["shape"][i]
            if attr["shape"].count(-1) > 0:
                index = attr["shape"].index(-1)
                attr["shape"][index] = int(total_size)
                attr["shape"][0] = -1

        if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
            if len(attr["shape"]) < 3:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer(
                    "transpose", inputs=input, output=node, param_attr=perm)
                node.fluid_code.add_layer(
                    "reshape", inputs=node, output=node, param_attr=attr)
                return

        if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
            input_shape = self.decoder.infer_tensor(input).shape
            if input_shape[1] == attr["shape"][1]:
                attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
            else:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer(
                    "transpose", inputs=input, output=node, param_attr=perm)
                node.fluid_code.add_layer(
                    "reshape", inputs=node, output=node, param_attr=attr)
                perm = {"perm": [0, 3, 1, 2]}
                node.fluid_code.add_layer(
                    "transpose", inputs=node, output=node, param_attr=perm)
                return
        if len(attr["shape"]) == 5:
            attr["shape"] = [attr["shape"][i] for i in [0, 1, 4, 2, 3]]

        node.fluid_code.add_layer(
            "reshape", inputs=input, output=node, param_attr=attr)

    def AvgPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("avg"),
            "pool_stride": strides[2:4],
            "pool_padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

    def SplitV(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        num_sections = self.graph.get_node(node.layer.input[1], copy=True)
        dim = self.graph.get_node(node.layer.input[2], copy=True)
        assert num_sections.layer_type == "Const"
        assert dim.layer_type == "Const"
        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)
        attr = {
            "num_or_sections": num_sections.value.tolist(),
            "dim": dim.value
        }
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def ConcatV2(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input[:-1]
        ]
        axis = self.graph.get_node(node.layer.input[-1], copy=True)
        assert axis.layer_type == "Const"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if inputs[0].tf_data_format == "NHWC" and len(inputs[0].out_shapes[
                0]) == 4:
            axis = nhwc_dim_to_nchw(inputs[0], axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "concat", inputs=inputs, output=node, param_attr=attr)

    def Tile(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        expand_times = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(expand_times.layer_name, node.layer_name)
        if expand_times.layer_type == "Const":
            expand_times = expand_times.value.tolist()
        else:
            expand_times = self.decoder.infer_shape_tensor(expand_times)
        if input.tf_data_format == "NHWC":
            if len(input.out_shapes[0]) == 4:
                expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
            elif len(input.out_shapes[0]) == 3:
                expand_times = [expand_times[i] for i in [2, 0, 1]]
        for i in range(len(expand_times)):
            if expand_times[i] < 0:
                expand_times[i] = 1

        attr = {"expand_times": expand_times}
        node.fluid_code.add_layer(
            "expand", inputs=input, output=node, param_attr=attr)

    def Pack(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input
        ]
        axis = node.get_attr("axis")
        if inputs[0].tf_data_format == "NHWC" and len(inputs[0].out_shapes[
                0]) == 4:
            tf_data_format = list(inputs[0].tf_data_format)
            tf_data_format.insert(axis, str(len(tf_data_format)))
            axis = nhwc_dim_to_nchw(inputs[0], axis)
            pd_data_format = list(inputs[0].pd_data_format)
            pd_data_format.insert(axis, str(len(pd_data_format)))
            node.tf_data_format = "".join(tf_data_format)
            node.pd_data_format = "".join(pd_data_format)

        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "stack", inputs=inputs, output=node, param_attr=attr)

    def Pad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]

        pad_op = "pad"
        if len(input.out_shapes[0]) == 4:
            if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                paddings = paddings[4:]
                pad_op = "pad2d"
        attr = {"paddings": paddings}
        node.fluid_code.add_layer(
            pad_op, inputs=input, output=node, param_attr=attr)

    def MirrorPad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        mode = node.get_attr("mode").decode()
        assert mode == "REFLECT", "Only support 'REFLECT' mode in MirrorPad"
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]

        pad_op = "pad"
        if len(input.out_shapes[0]) == 4:
            if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                paddings = paddings[4:]
                pad_op = "pad2d"
        attr = {"paddings": paddings, "mode": string("reflect")}
        node.fluid_code.add_layer(
            pad_op, inputs=input, output=node, param_attr=attr)

    def Range(self, node):
        start = self.graph.get_node(node.layer.input[0], copy=True)
        limit = self.graph.get_node(node.layer.input[1], copy=True)
        delta = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(start.layer_name, node.layer_name)
        self.add_omit_nodes(limit.layer_name, node.layer_name)
        self.add_omit_nodes(delta.layer_name, node.layer_name)
        if start.layer_type == "Const":
            start = start.value
        else:
            start = self.decoder.infer_tensor(start)
        if limit.layer_type == "Const":
            limit = limit.value
        else:
            limit = self.decoder.infer_tensor(limit)
        if delta.layer_type == "Const":
            delta = delta.value
        else:
            delta = self.decoder.infer_tensor(delta)

        inputs = {"start": start, "end": limit, "step": delta}
        attr = {"dtype": string(node.dtype)}
        node.fluid_code.add_layer(
            "range", inputs=inputs, output=node, param_attr=attr)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(dims)):
                dims[i] = nhwc_dim_to_nchw(input, dims[i])

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_mean", inputs=input, output=node, param_attr=attr)

    def MatMul(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        inputs = {"x": x, "y": y}
        # fix paddle shape infer problem
        # should be removed after paddle 1.6
        if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0:
            shape = x.out_shapes[0]
            shape[-1] = y.out_shapes[0][0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=x, output=x, param_attr=attr)
        attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
        node.fluid_code.add_layer(
            "matmul", inputs=inputs, output=node, param_attr=attr)

    def ArgMax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = self.graph.get_node(node.layer.input[1], copy=True)
        assert axis.layer_type == "Const", "ArgMax only support Const parameter"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "argmax", inputs=input, output=node, param_attr=attr)

    def StridedSlice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        end = self.graph.get_node(node.layer.input[2], copy=True)
        strides = self.graph.get_node(node.layer.input[3], copy=True)
        assert begin.layer_type == "Const"
        assert end.layer_type == "Const"
        assert strides.layer_type == "Const"
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(end.layer_name, node.layer_name)
        self.add_omit_nodes(strides.layer_name, node.layer_name)
        strides = strides.value.tolist()
        assert len(set(strides)) == 1 and strides[0] == 1

        begin = begin.value.tolist()
        end = end.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            begin = [begin[i] for i in [0, 3, 1, 2]]
            end = [end[i] for i in [0, 3, 1, 2]]

        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        attr = {
            "axes": [i for i in range(len(strides))],
            "starts": begin,
            "ends": end
        }

        shrink_axis_mask = node.get_attr('shrink_axis_mask')
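        # shrink_axis_mask is a bitmask: bit i set means axis i was indexed with a
        # scalar and should be squeezed out, e.g. a mask of 0b101 squeezes dims [0, 2]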
        squeeze_dims = list()
        for i in range(len(begin)):
            x = shrink_axis_mask >> i & 1
            if x == 1:
                squeeze_dims.append(i)
        node.fluid_code.add_layer(
            "slice", inputs=input, output=node, param_attr=attr)
        if shrink_axis_mask > 0 and len(input.out_shapes[0]) == 5:
            attr = {"axes": squeeze_dims}
            node.fluid_code.add_layer(
                "squeeze", inputs=node, output=node, param_attr=attr)

    def Slice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        size = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(size.layer_name, node.layer_name)
        if begin.layer_type == "Const":
            begin = begin.value.tolist()
        else:
            begin = self.decoder.infer_tensor(begin).tolist()
        if size.layer_type == "Const":
            size = size.value.tolist()
        else:
            size = self.decoder.infer_tensor(size).tolist()

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            size = [size[i] for i in [0, 3, 1, 2]]
            begin = [begin[i] for i in [0, 3, 1, 2]]

        for i in range(len(size)):
            if size[i] < 0:
                size[i] = 99999999
            else:
                size[i] = size[i] + begin[i]

        attr = {
            "axes": [i for i in range(len(size))],
            "starts": begin,
            "ends": size
        }
        node.fluid_code.add_layer(
            "slice", inputs=input, output=node, param_attr=attr)

    def Conv2DBackpropInput(self, node):
        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        input = self.graph.get_node(node.layer.input[2], copy=True)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        self.add_omit_nodes(kernel.layer_name, node.layer_name)
        self.add_omit_nodes(out_shape.layer_name, node.layer_name)

        if out_shape.layer_type == "Const":
            out_shape = out_shape.value.tolist()
        else:
            out_shape = self.decoder.infer_shape_tensor(out_shape,
                                                        node.out_shapes[0])

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[2],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode),
            "output_size": out_shape[1:3]
        }
        node.fluid_code.add_layer(
            "conv2d_transpose", inputs=input, output=node, param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_max", inputs=input, output=node, param_attr=attr)

    def Sum(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_sum", inputs=input, output=node, param_attr=attr)

    def Cast(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        dtype = node.dtype_map[node.get_attr('DstT')]
        attr = {"dtype": string(dtype)}
        node.fluid_code.add_layer(
            "cast", inputs=input, output=node, param_attr=attr)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
        input = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        num_split = node.get_attr('num_split')
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"num_or_sections": num_split, "dim": dim}
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(squeeze_dims)):
                squeeze_dims[i] = nhwc_dim_to_nchw(input, squeeze_dims[i])
        attr = {"axes": squeeze_dims}
        node.fluid_code.add_layer(
            "squeeze", inputs=input, output=node, param_attr=attr)

    def Softmax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = node.get_attr("axis")
        if axis is None:
            axis = -1 + len(input.out_shapes[0])
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "softmax", inputs=input, output=node, param_attr=attr)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(resize_shape,
                                                           node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {"align_corners": align_corners, "out_shape": resize_shape}
        node.fluid_code.add_layer(
            "resize_nearest", inputs=input, output=node, param_attr=attr)

    def ResizeBilinear(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(resize_shape,
                                                           node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {
            "align_corners": align_corners,
            "out_shape": resize_shape,
            "align_mode": 1
        }
        node.fluid_code.add_layer(
            "resize_bilinear", inputs=input, output=node, param_attr=attr)

    def GreaterEqual(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "greater_equal", inputs=inputs, output=node, param_attr=None)

    def RandomUniform(self, node):
        shape = self.graph.get_node(node.layer.input[0], copy=True)
        self.add_omit_nodes(shape.layer_name, node.layer_name)
        if shape.layer_type == "Const":
            shape = shape.value.tolist()
        else:
            shape = self.decoder.infer_shape_tensor(shape)
        if len(shape) == 4 and node.tf_data_format == "NHWC":
            shape = [shape[i] for i in [0, 3, 1, 2]]
        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
        if shape[0] < 0:
            input = self.batch_node
            node.fluid_code.add_layer(
                "uniform_random_batch_size_like",
                inputs=input,
                output=node,
                param_attr=attr)
        else:
            node.fluid_code.add_layer(
                "uniform_random", inputs=None, output=node, param_attr=attr)

    def SquaredDifference(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "elementwise_sub", inputs=inputs, output=node, param_attr=None)
        inputs = {"x": node, "y": node}
        node.fluid_code.add_layer(
            "elementwise_mul", inputs=inputs, output=node, param_attr=None)