#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import inspect
import math
import numpy
import sys


# compute padding size for SAME mode
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]
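
# Worked example (illustrative values, not from the original source):
# in_size=7, kernel_size=3, stride=2 gives new_size = ceil(7 / 2) = 4 and
# pad_size = (4 - 1) * 2 + 3 - 7 = 2, so the returned padding is [1, 1].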


# Map an axis index expressed in the node's TF data format (e.g. NHWC) to the
# corresponding index in its Paddle data format (e.g. NCHW); for a 4-D NHWC
# tensor, the channel axis 3 maps to NCHW axis 1.
def nhwc_dim_to_nchw(node, dim):
    tf_data_format = list(node.tf_data_format)
    pd_data_format = list(node.pd_data_format)
    if isinstance(dim, list):
        for i in range(len(dim)):
            char = tf_data_format[dim[i]]
            dim[i] = pd_data_format.index(char)
    else:
        char = tf_data_format[dim]
        dim = pd_data_format.index(char)
    return dim


class TFOpMapper(OpMapper):
    # TF ops that map one-to-one onto a Paddle layer, with an optional
    # attribute-name mapping: 'TFOp': ['paddle_op', {'tf_attr': 'paddle_attr'}]
    directly_map_ops = {
        'Relu': ['relu'],
        'Relu6': ['relu6'],
        'Shape': ['shape'],
        'Abs': ['abs'],
        'Sigmoid': ['sigmoid'],
        'Exp': ['exp'],
        'Rsqrt': ['rsqrt'],
        'swish_f32': ['swish'],
        'LeakyRelu': ['leaky_relu', {
            'alpha': 'alpha'
        }]
    }
    elementwise_ops = {
        'Add': 'elementwise_add',
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
        'Mul': 'elementwise_mul',
        'FloorDiv': 'elementwise_floordiv'
    }

    def __init__(self, decoder):
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.batch_node = None
        self.weights = dict()
        self.omit_nodes = list()
        self.used_custom_layers = dict()

        not_placeholder = list()
        for name in self.graph.input_nodes:
            if self.graph.get_node(name).layer_type != "Placeholder":
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]

        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        unsupported_ops = set()
        # once an unsupported op is found, keep scanning so that every
        # unsupported op can be reported at once, but skip conversion work
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...    ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.directly_map(node)
            elif op in self.elementwise_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.elementwise_map(node)
            elif hasattr(self, op):
                if len(unsupported_ops) > 0:
                    continue
                func = getattr(self, op)
                func(node)
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
            sys.stderr.write(
                "=========={} Ops are not supported yet======\n".format(
                    len(unsupported_ops)))
            for op in unsupported_ops:
                sys.stderr.write("========== {} ==========\n".format(op))
            sys.exit(-1)
        sys.stderr.write('\nDone!\n')

    def add_omit_nodes(self, in_node_name, out_node_name):
        # detach in_node from out_node and record it as folded into the
        # consuming layer (e.g. a Const kernel merged into conv weights),
        # so no standalone layer is emitted for it
        in_node = self.graph.get_node(in_node_name)
        out_node = self.graph.get_node(out_node_name)
        index = in_node.outputs.index(out_node_name)
        del in_node.outputs[index]
        index = out_node.inputs.index(in_node_name)
        del out_node.inputs[index]
        self.omit_nodes.append(in_node.layer_name)

    def directly_map(self, node):
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0], copy=True)
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param
        node.fluid_code.add_layer(op_info[0],
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
        if len(x_shape) == 0:
            x_shape = [1]
        if len(y_shape) == 0:
            y_shape = [1]
        # work around Paddle's incomplete broadcasting support
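        # Illustrative example (shapes not from the original source): x of
        # shape [2, 3, 4, 5] and y of shape [3, 1, 5] differ at dim -2, so y
        # is tiled 4x along that axis with an explicit "expand" layer before
        # the elementwise op runs.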
        x_input = x
        y_input = y
        if len(x_shape) < len(y_shape):
            unrevertable_ops = [
                "elementwise_sub", "elementwise_div", "elementwise_floordiv",
                "elementwise_mod", "elementwise_pow"
            ]
            if op_type not in unrevertable_ops:
                x_input = y
                y_input = x
                x_shape = y.out_shapes[0]
                y_shape = x.out_shapes[0]
            else:
                raise Exception("Unexpected situation happened")

        if len(x_shape) == 4 and len(y_shape) == 1:
            if x_input.tf_data_format == "NHWC":
                axis = 1
            else:
                axis = -1
            attr = {"axis": axis}
            inputs = {"x": x_input, "y": y_input}
            node.fluid_code.add_layer(op_type,
                                      inputs=inputs,
                                      output=node,
                                      param_attr=attr)
            return

        is_sub_seq = True
        for i in range(len(y_shape)):
            index = -1 * i - 1
            if y_shape[index] != x_shape[index]:
                is_sub_seq = False
        if not is_sub_seq:
            x_expand_times = [1] * len(x_shape)
            y_expand_times = [1] * len(y_shape)
            x_need_expand = False
            y_need_expand = False
            for i in range(len(y_shape)):
                index = -1 * i - 1
                if y_shape[index] != x_shape[index]:
                    if y_shape[index] == 1:
                        y_expand_times[index] = x_shape[index]
                        y_need_expand = True
                    elif x_shape[index] == 1:
                        x_expand_times[index] = y_shape[index]
                        x_need_expand = True
                    else:
                        raise Exception("Unexpected situation happened")
            if x_need_expand:
                if len(x_expand_times) == 3 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [2, 0, 1]]
                if len(x_expand_times) == 4 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": x_expand_times}
                node.fluid_code.add_layer("expand",
                                          inputs=x_input,
                                          output="x_tmp",
                                          param_attr=attr)
                x_input = "x_tmp"
            if y_need_expand:
                if len(y_expand_times) == 3 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [2, 0, 1]]
                if len(y_expand_times) == 4 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": y_expand_times}
                node.fluid_code.add_layer("expand",
                                          inputs=y_input,
                                          output="y_tmp",
                                          param_attr=attr)
                y_input = "y_tmp"
        inputs = {"x": x_input, "y": y_input}
        node.fluid_code.add_layer(op_type,
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        if node.tf_data_format == "NHWC" and len(shape) == 4:
            shape = [shape[i] for i in [0, 3, 1, 2]]
        elif node.tf_data_format == "NCHW" and len(shape) == 4:
            self.graph.data_format_propagation(node)
        dtype = node.dtype
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'append_batch_size': False
        }
        if shape[0] < 0:
            self.batch_node = node

        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            initializer = "Constant({})".format(value)

        self.weights[node.layer_name] = node.value

        if node.tf_data_format == "NHWC":
            if len(shape) == 4:
                shape = [shape[i] for i in [0, 3, 1, 2]]
            if len(shape) == 3:
                shape = [shape[i] for i in [2, 0, 1]]
                self.weights[node.layer_name] = numpy.transpose(
                    node.value, (2, 0, 1))
        elif node.tf_data_format == "NCHW":
            if len(shape) == 4:
                self.graph.data_format_propagation(node)

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
        del self.weights[perm.layer_name.replace('/', '_')]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        if perm == [0, 3, 1, 2] and input.tf_data_format == "NHWC":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NCHW"
            self.graph.data_format_propagation(node)
        elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NHWC"
            self.graph.data_format_propagation(node)
        elif len(input.out_shapes[0]) > 4:
            tf_data_format = list(input.tf_data_format)
            pd_data_format = list(input.pd_data_format)
            new_perm = [i for i in range(len(perm))]
            for i in range(len(perm)):
                char0 = tf_data_format[i]
                char1 = tf_data_format[perm[i]]
                index0 = pd_data_format.index(char0)
                index1 = pd_data_format.index(char1)
                new_perm[index0] = index1
            node.tf_data_format = [tf_data_format[i] for i in perm]
            node.pd_data_format = [pd_data_format[i] for i in perm]
            attr = {'perm': new_perm}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        elif len(node.out_shapes[0]) != 4:
            attr = {'perm': perm}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        else:
            raise Exception("Unexpected situation happened in Transpose OP")

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[3], strides[3])
            pad_h = pad_h[0] + pad_h[1]
            pad_w = pad_w[0] + pad_w[1]
            attr = {"paddings": [0, pad_h, 0, pad_w], "pad_value": -10000.0}
            node.fluid_code.add_layer("pad2d",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            input = node
        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("max"),
            "pool_padding": padding,
            "pool_stride": strides[2:4]
        }
        node.fluid_code.add_layer("pool2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        axis = -1
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = 1
        inputs = {"x": input, "y": bias}
        attr = {"axis": axis}
        node.fluid_code.add_layer("elementwise_add",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        gamma = self.graph.get_node(node.layer.input[1], copy=True)
        beta = self.graph.get_node(node.layer.input[2], copy=True)
        moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
        moving_var = self.graph.get_node(node.layer.input[4], copy=True)
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"
        self.add_omit_nodes(gamma.layer_name, node.layer_name)
        self.add_omit_nodes(beta.layer_name, node.layer_name)
        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
        self.add_omit_nodes(moving_var.layer_name, node.layer_name)
        if channel_first:
            self.graph.data_format_propagation(node)

        attr = {
            "epsilon": node.get_attr("epsilon"),
            "param_attr": string(gamma.layer_name),
            "bias_attr": string(beta.layer_name),
            "moving_mean_name": string(moving_mean.layer_name),
            "moving_variance_name": string(moving_var.layer_name),
            "is_test": True
        }

        node.fluid_code.add_layer("batch_norm",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (2, 3, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": in_shape[1],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "groups": k_size[3] * in_shape[1],
            "use_cudnn": False,
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Reshape(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        param = self.graph.get_node(node.layer.input[1], copy=True)
        if param.layer_type == "Const":
            attr = {"shape": param.value.tolist()}
            self.add_omit_nodes(param.layer_name, node.layer_name)
        else:
            # trick to handle a target shape that is itself a tensor in the TF graph
            shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
            if shape.count(-1) <= 1:
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            elif shape.count(-1) == 2 and shape[0] == -1:
                shape[0] = 0
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            else:
                assert len(param.out_shapes[0]
                           ) == 1, "Unexpected situation of shape parameter"
                attr = {"shape": [-1]}
                node.fluid_code.add_layer("reshape",
                                          inputs=param,
                                          output="shape_param",
                                          param_attr=attr)
                attr = {"num_or_sections": param.out_shapes[0][0], "dim": 0}
                node.fluid_code.add_layer("split",
                                          inputs="shape_param",
                                          output=node,
                                          param_attr=attr)
                new_param = "["
                for i in range(param.out_shapes[0][0]):
                    new_param += (node.layer_name + "[{}]".format(i) + ", ")
                new_param = new_param.strip(", ") + "]"
                attr = {"shape": new_param}

        if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
            if len(attr["shape"]) < 3:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer("transpose",
                                          inputs=input,
                                          output=node,
                                          param_attr=perm)
                node.fluid_code.add_layer("reshape",
                                          inputs=node,
                                          output=node,
                                          param_attr=attr)
                return

        if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
            input_shape = self.decoder.infer_tensor(input).shape
            if input_shape[1] == attr["shape"][1]:
                attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
            else:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer("transpose",
                                          inputs=input,
                                          output=node,
                                          param_attr=perm)
                node.fluid_code.add_layer("reshape",
                                          inputs=node,
                                          output=node,
                                          param_attr=attr)
                perm = {"perm": [0, 3, 1, 2]}
                node.fluid_code.add_layer("transpose",
                                          inputs=node,
                                          output=node,
                                          param_attr=perm)
                return
        node.fluid_code.add_layer("reshape",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def AvgPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("avg"),
            "pool_stride": strides[2:4]
        }
        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[3], strides[3])
            assert pad_h[0] == pad_h[1] and pad_w[0] == pad_w[
                1], "Cannot map AvgPool"
            attr["pool_padding"] = [pad_h[0], pad_w[0]]
        node.fluid_code.add_layer("pool2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def SplitV(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        num_sections = self.graph.get_node(node.layer.input[1], copy=True)
        dim = self.graph.get_node(node.layer.input[2], copy=True)
        assert num_sections.layer_type == "Const"
        assert dim.layer_type == "Const"
        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)
        attr = {
            "num_or_sections": num_sections.value.tolist(),
            "dim": dim
        }
        node.fluid_code.add_layer("split",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ConcatV2(self, node):
        inputs = [
            self.graph.get_node(name, copy=True)
            for name in node.layer.input[:-1]
        ]
        axis = self.graph.get_node(node.layer.input[-1], copy=True)
        assert axis.layer_type == "Const"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if inputs[0].tf_data_format == "NHWC" and len(
                inputs[0].out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(inputs[0], axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("concat",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Tile(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        expand_times = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(expand_times.layer_name, node.layer_name)
        if expand_times.layer_type == "Const":
            expand_times = expand_times.value.tolist()
        else:
            expand_times = self.decoder.infer_shape_tensor(expand_times)
        if input.tf_data_format == "NHWC":
            if len(input.out_shapes[0]) == 4:
                expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
            elif len(input.out_shapes[0]) == 3:
                expand_times = [expand_times[i] for i in [2, 0, 1]]
        for i in range(len(expand_times)):
            if expand_times[i] < 0:
                expand_times[i] = 1

        attr = {"expand_times": expand_times}
        node.fluid_code.add_layer("expand",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Pack(self, node):
        inputs = [
            self.graph.get_node(name, copy=True) for name in node.layer.input
        ]
        axis = node.get_attr("axis")
        if inputs[0].tf_data_format == "NHWC" and len(
                inputs[0].out_shapes[0]) == 4:
            tf_data_format = list(inputs[0].tf_data_format)
            tf_data_format.insert(axis, str(len(tf_data_format)))
            axis = nhwc_dim_to_nchw(inputs[0], axis)
            pd_data_format = list(inputs[0].pd_data_format)
            pd_data_format.insert(axis, str(len(pd_data_format)))
            node.tf_data_format = "".join(tf_data_format)
            node.pd_data_format = "".join(pd_data_format)

        attr = {"axis": axis}
        node.fluid_code.add_layer("stack",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Pad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]

        pad_op = "pad"
        if len(input.out_shapes[0]) == 4:
            if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                paddings = paddings[4:]
                pad_op = "pad2d"
        attr = {"paddings": paddings}
        node.fluid_code.add_layer(pad_op,
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Range(self, node):
        start = self.graph.get_node(node.layer.input[0], copy=True)
        limit = self.graph.get_node(node.layer.input[1], copy=True)
        delta = self.graph.get_node(node.layer.input[2], copy=True)
        # register the scalar inputs as folded before their values are
        # extracted (their layer_name is still needed here)
        self.add_omit_nodes(start.layer_name, node.layer_name)
        self.add_omit_nodes(limit.layer_name, node.layer_name)
        self.add_omit_nodes(delta.layer_name, node.layer_name)
        if start.layer_type == "Const":
            start = start.value
        else:
            start = self.decoder.infer_tensor(start)
        if limit.layer_type == "Const":
            limit = limit.value
        else:
            limit = self.decoder.infer_tensor(limit)
        if delta.layer_type == "Const":
            delta = delta.value
        else:
            delta = self.decoder.infer_tensor(delta)

        inputs = {"start": start, "end": limit, "step": delta}
        attr = {"dtype": string(node.dtype)}
        node.fluid_code.add_layer("range",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(dims)):
                dims[i] = nhwc_dim_to_nchw(input, dims[i])

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_mean",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def MatMul(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        inputs = {"x": x, "y": y}
        # fix paddle shape infer problem
        # should be removed after paddle 1.6
        if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0:
            shape = x.out_shapes[0]
            shape[-1] = y.out_shapes[0][0]
            attr = {"shape": shape}
            node.fluid_code.add_layer("reshape",
                                      inputs=x,
                                      output=x,
                                      param_attr=attr)
        attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
        node.fluid_code.add_layer("matmul",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def ArgMax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = self.graph.get_node(node.layer.input[1], copy=True)
        assert axis.layer_type == "Const", "ArgMax only support Const parameter"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("argmax",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def StridedSlice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        end = self.graph.get_node(node.layer.input[2], copy=True)
        strides = self.graph.get_node(node.layer.input[3], copy=True)
        assert begin.layer_type == "Const"
        assert end.layer_type == "Const"
        assert strides.layer_type == "Const"
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(end.layer_name, node.layer_name)
        self.add_omit_nodes(strides.layer_name, node.layer_name)
        strides = strides.value.tolist()
        assert len(set(strides)) == 1 and strides[0] == 1

        begin = begin.value.tolist()
        end = end.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            begin = [begin[i] for i in [0, 3, 1, 2]]
            end = [end[i] for i in [0, 3, 1, 2]]

        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999  # treat an end of 0 as "slice to the end of this axis"

        attr = {
            "axes": [i for i in range(len(strides))],
            "starts": begin,
            "ends": end
        }
        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Slice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        size = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(size.layer_name, node.layer_name)
        if begin.layer_type == "Const":
            begin = begin.value.tolist()
        else:
            begin = self.decoder.infer_tensor(begin).tolist()
        if size.layer_type == "const":
            size = size.value.tolist()
        else:
            size = self.decoder.infer_tensor(size).tolist()

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            size = [size[i] for i in [0, 3, 1, 2]]
            begin = [begin[i] for i in [0, 3, 1, 2]]

        for i in range(len(size)):
            if size[i] < 0:
                size[i] = 99999999  # a negative size means "to the end of the axis"
            else:
                size[i] = size[i] + begin[i]

        attr = {
            "axes": [i for i in range(len(size))],
            "starts": begin,
            "ends": size
        }
        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Conv2DBackpropInput(self, node):
        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        input = self.graph.get_node(node.layer.input[2], copy=True)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        self.add_omit_nodes(kernel.layer_name, node.layer_name)
        self.add_omit_nodes(out_shape.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        padding = 0
        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d_transpose",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_max",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Sum(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_sum",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Cast(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        dtype = node.dtype_map[node.get_attr('DstT')]
        attr = {"dtype": string(dtype)}
        node.fluid_code.add_layer("cast",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
        input = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        num_split = node.get_attr('num_split')
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"num_or_sections": num_split, "dim": dim}
        node.fluid_code.add_layer("split",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(squeeze_dims)):
                squeeze_dims[i] = nhwc_dim_to_nchw(input, squeeze_dims[i])
        attr = {"axes": squeeze_dims}
        node.fluid_code.add_layer("squeeze",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Softmax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = node.get_attr("axis")
        if axis is None:
            axis = -1 + len(input.out_shapes[0])
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("softmax",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(
                resize_shape, node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {"align_corners": align_corners, "out_shape": resize_shape}
        node.fluid_code.add_layer("resize_nearest",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ResizeBilinear(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(
                resize_shape, node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {
            "align_corners": align_corners,
            "out_shape": resize_shape,
            "align_mode": 1
        }
        node.fluid_code.add_layer("resize_bilinear",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def GreaterEqual(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer("greater_equal",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)

    def RandomUniform(self, node):
        shape = self.graph.get_node(node.layer.input[0], copy=True)
        self.add_omit_nodes(shape.layer_name, node.layer_name)
        if shape.layer_type == "Const":
            shape = shape.value.tolist()
        else:
            shape = self.decoder.infer_shape_tensor(shape)
        if node.tf_data_format == "NHWC" and len(shape) == 4:
            shape = [shape[i] for i in [0, 3, 1, 2]]
        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
        if shape[0] < 0:
            input = self.batch_node
            node.fluid_code.add_layer("uniform_random_batch_size_like",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        else:
            node.fluid_code.add_layer("uniform_random",
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)