#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import math
import inspect
import numpy
import sys


# compute padding size for SAME mode
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]
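
# Worked example (illustrative): for in_size=5, kernel_size=3, stride=2,
# new_size = ceil(5 / 2) = 3 and pad_size = (3 - 1) * 2 + 3 - 5 = 2,
# so the function returns [1, 1], matching TensorFlow's SAME padding.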


def nhwc_dim_to_nchw(node, dim):
    tf_data_format = list(node.tf_data_format)
    pd_data_format = list(node.pd_data_format)
    if isinstance(dim, list):
        for i in range(len(dim)):
            char = tf_data_format[dim[i]]
            dim[i] = pd_data_format.index(char)
    else:
        char = tf_data_format[dim]
        dim = pd_data_format.index(char)
    return dim
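
# Illustrative example: with node.tf_data_format == "NHWC" and
# node.pd_data_format == "NCHW", a TensorFlow axis of 3 (the "C" axis)
# maps to Paddle axis 1; lists of axes are remapped element-wise.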


class TFOpMapper(OpMapper):
    directly_map_ops = {
        'Relu': ['relu'],
        'Relu6': ['relu6'],
        'Shape': ['shape'],
        'Abs': ['abs'],
        'Sigmoid': ['sigmoid'],
        'Exp': ['exp'],
        'Rsqrt': ['rsqrt'],
        'swish_f32': ['swish'],
        'LeakyRelu': ['leaky_relu', {
            'alpha': 'alpha'
        }]
    }
    elementwise_ops = {
        'Add': 'elementwise_add',
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
        'Mul': 'elementwise_mul',
        'FloorDiv': 'elementwise_floordiv'
    }
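
    # Each directly_map_ops entry is [paddle_op, {tf_attr: paddle_attr}, ...]:
    # the TF op becomes a single fluid layer, with the listed attributes
    # copied through under their Paddle names. elementwise_ops map TF binary
    # ops onto fluid elementwise layers, with broadcasting handled in
    # elementwise_map below.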

    def __init__(self, decoder):
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.batch_node = None
        self.weights = dict()
        self.omit_nodes = list()
        self.used_custom_layers = dict()

        not_placeholder = list()
        for name in self.graph.input_nodes:
            if self.graph.get_node(name).layer_type != "Placeholder":
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]

        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        unsupported_ops = set()
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...    ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.directly_map(node)
            elif op in self.elementwise_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.elementwise_map(node)
            elif hasattr(self, op):
                if len(unsupported_ops) > 0:
                    continue
                func = getattr(self, op)
                func(node)
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
            sys.stderr.write(
                "=========={} Ops are not supported yet======\n".format(
                    len(unsupported_ops)))
            for op in unsupported_ops:
                sys.stderr.write("========== {} ==========\n".format(op))
            sys.exit(-1)
        sys.stderr.write('\nDone!\n')

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
        out_node = self.graph.get_node(out_node_name)
        index = in_node.outputs.index(out_node_name)
        del in_node.outputs[index]
        index = out_node.inputs.index(in_node_name)
        del out_node.inputs[index]
        self.omit_nodes.append(in_node.layer_name)
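
    # add_omit_nodes detaches the edge between producer and consumer and
    # records the producer (typically a Const whose value is folded into the
    # consumer's attributes) so no standalone layer is emitted for it.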

    def directly_map(self, node):
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0], copy=True)
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param
        node.fluid_code.add_layer(op_info[0],
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
        if len(x_shape) == 0:
            x_shape = [1]
        if len(y_shape) == 0:
            y_shape = [1]
        # work around incomplete broadcasting support in Paddle
        x_input = x
        y_input = y
        if len(x_shape) < len(y_shape):
            unrevertable_ops = [
                "elementwise_sub", "elementwise_div", "elementwise_floordiv",
                "elementwise_mod", "elementwise_pow"
            ]
            if op_type not in unrevertable_ops:
                x_input = y
                y_input = x
                x_shape = y.out_shapes[0]
                y_shape = x.out_shapes[0]
            else:
                raise Exception("Unexpected situation happend")

        if len(x_shape) == 4 and len(y_shape) == 1:
            if x_input.tf_data_format == "NHWC":
                axis = 1
            else:
                axis = -1
            attr = {"axis": axis}
            inputs = {"x": x_input, "y": y_input}
            node.fluid_code.add_layer(op_type,
                                      inputs=inputs,
                                      output=node,
                                      param_attr=attr)
            return

        is_sub_seq = True
        for i in range(len(y_shape)):
            index = -1 * i - 1
            if y_shape[index] != x_shape[index]:
                is_sub_seq = False
        if not is_sub_seq:
            x_expand_times = [1] * len(x_shape)
            y_expand_times = [1] * len(y_shape)
            x_need_expand = False
            y_need_expand = False
            for i in range(len(y_shape)):
                index = -1 * i - 1
                if y_shape[index] != x_shape[index]:
                    if y_shape[index] == 1:
                        y_expand_times[index] = x_shape[index]
                        y_need_expand = True
                    elif x_shape[index] == 1:
                        x_expand_times[index] = y_shape[index]
                        x_need_expand = True
                    else:
                        raise Exception("Unexpected situation happend")
            if x_need_expand:
                if len(x_expand_times) == 3 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [2, 0, 1]]
                if len(x_expand_times) == 4 and x.tf_data_format == "NHWC":
                    x_expand_times = [x_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": x_expand_times}
                node.fluid_code.add_layer("expand",
                                          inputs=x_input,
                                          output="x_tmp",
                                          param_attr=attr)
                x_input = "x_tmp"
            if y_need_expand:
                if len(y_expand_times) == 3 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [2, 0, 1]]
                if len(y_expand_times) == 4 and y.tf_data_format == "NHWC":
                    y_expand_times = [y_expand_times[i] for i in [0, 3, 1, 2]]
                attr = {"expand_times": y_expand_times}
                node.fluid_code.add_layer("expand",
                                          inputs=y_input,
                                          output="y_tmp",
                                          param_attr=attr)
                y_input = "y_tmp"
        inputs = {"x": x_input, "y": y_input}
        node.fluid_code.add_layer(op_type,
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)
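
    # Illustrative example of the manual broadcast above: x of shape
    # [2, 3, 1] and y of shape [2, 1, 4] are expanded with
    # x_expand_times = [1, 1, 4] and y_expand_times = [1, 3, 1] before the
    # elementwise op is emitted.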

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        if node.tf_data_format == "NHWC" and len(shape) == 4:
            shape = [shape[i] for i in [0, 3, 1, 2]]
        elif node.tf_data_format == "NCHW" and len(shape) == 4:
            self.graph.data_format_propagation(node)
        dtype = node.dtype
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'append_batch_size': False
        }
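        # remember the placeholder that carries the unknown batch dimension;
        # RandomUniform reuses it for uniform_random_batch_size_like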
        if shape[0] < 0:
            self.batch_node = node

        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happend"
            shape = [1]
            initializer = "Constant({})".format(value)

        self.weights[node.layer_name] = node.value

        if node.tf_data_format == "NHWC":
            if len(shape) == 4:
                shape = [shape[i] for i in [0, 3, 1, 2]]
            if len(shape) == 3:
                shape = [shape[i] for i in [2, 0, 1]]
                self.weights[node.layer_name] = numpy.transpose(
                    node.value, (2, 0, 1))
        elif node.tf_data_format == "NCHW":
            if len(shape) == 4:
                self.graph.data_format_propagation(node)

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
        del self.weights[perm.layer_name.replace('/', '_')]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        if perm == [0, 3, 1, 2] and input.tf_data_format == "NHWC":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NCHW"
            self.graph.data_format_propagation(node)
        elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
            input_name = input.layer_name
            if hasattr(input, "index"):
                input_name = input_name + "[{}]".format(input.index)
            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
                                                       input_name))
            node.tf_data_format = "NHWC"
            self.graph.data_format_propagation(node)
        elif len(input.out_shapes[0]) > 4:
            tf_data_format = list(input.tf_data_format)
            pd_data_format = list(input.pd_data_format)
            new_perm = [i for i in range(len(perm))]
            for i in range(len(perm)):
                char0 = tf_data_format[i]
                char1 = tf_data_format[perm[i]]
                index0 = pd_data_format.index(char0)
                index1 = pd_data_format.index(char1)
                new_perm[index0] = index1
            node.tf_data_format = [tf_data_format[i] for i in perm]
            node.pd_data_format = [pd_data_format[i] for i in perm]
            attr = {'perm': new_perm}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        elif len(node.out_shapes[0]) != 4:
            attr = {'perm': perm}
            node.fluid_code.add_layer("transpose",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        else:
            raise Exception("Unexpected situation happend in Transpose OP")

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[3], strides[3])
            pad_h = pad_h[0] + pad_h[1]
            pad_w = pad_w[0] + pad_w[1]
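            # pad only on the bottom/right, with a strongly negative
            # pad_value so padded cells can never win the max pooling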
            attr = {"paddings": [0, pad_h, 0, pad_w], "pad_value": -10000.0}
            node.fluid_code.add_layer("pad2d",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
            input = node
        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("max"),
            "pool_padding": padding,
            "pool_stride": strides[2:4]
        }
        node.fluid_code.add_layer("pool2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2D should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        axis = -1
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = 1
        inputs = {"x": input, "y": bias}
        attr = {"axis": axis}
        node.fluid_code.add_layer("elementwise_add",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        gamma = self.graph.get_node(node.layer.input[1], copy=True)
        beta = self.graph.get_node(node.layer.input[2], copy=True)
        moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
        moving_var = self.graph.get_node(node.layer.input[4], copy=True)
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"
        self.add_omit_nodes(gamma.layer_name, node.layer_name)
        self.add_omit_nodes(beta.layer_name, node.layer_name)
        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
        self.add_omit_nodes(moving_var.layer_name, node.layer_name)
        if channel_first:
            self.graph.data_format_propagation(node)

        attr = {
            "epsilon": node.get_attr("epsilon"),
            "param_attr": string(gamma.layer_name),
            "bias_attr": string(beta.layer_name),
            "moving_mean_name": string(moving_mean.layer_name),
            "moving_variance_name": string(moving_var.layer_name),
            "is_test": True
        }

        node.fluid_code.add_layer("batch_norm",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2dNative should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        padding = 0

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (2, 3, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": in_shape[1],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "groups": k_size[3] * in_shape[1],
            "use_cudnn": False,
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Reshape(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        param = self.graph.get_node(node.layer.input[1], copy=True)
        if param.layer_type == "Const":
            attr = {"shape": param.value.tolist()}
            self.add_omit_nodes(param.layer_name, node.layer_name)
        else:
            # Here is a trick to handle a shape parameter that is itself a tensor
            shape = self.decoder.infer_shape_tensor(param, node.out_shapes[0])
            if shape.count(-1) <= 1:
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            elif shape.count(-1) == 2 and shape[0] == -1:
                shape[0] = 0
                attr = {"shape": shape}
                self.add_omit_nodes(param.layer_name, node.layer_name)
            else:
                assert len(param.out_shapes[0]
                           ) == 1, "Unexpected situation of shape parameter"
                attr = {"shape": [-1]}
                node.fluid_code.add_layer("reshape",
                                          inputs=param,
                                          output="shape_param",
                                          param_attr=attr)
                attr = {"num_or_sections": param.out_shapes[0][0], "dim": 0}
                node.fluid_code.add_layer("split",
                                          inputs="shape_param",
                                          output=node,
                                          param_attr=attr)
                new_param = "["
                for i in range(param.out_shapes[0][0]):
                    new_param += (node.layer_name + "[{}]".format(i) + ", ")
                new_param = new_param.strip(", ") + "]"
                attr = {"shape": new_param}

        if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
            if len(attr["shape"]) < 3:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer("transpose",
                                          inputs=input,
                                          output=node,
                                          param_attr=perm)
                node.fluid_code.add_layer("reshape",
                                          inputs=node,
                                          output=node,
                                          param_attr=attr)
                return

        if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
            input_shape = self.decoder.infer_tensor(input).shape
            if input_shape[1] == attr["shape"][1]:
                attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
            else:
                perm = {"perm": [0, 2, 3, 1]}
                node.fluid_code.add_layer("transpose",
                                          inputs=input,
                                          output=node,
                                          param_attr=perm)
                node.fluid_code.add_layer("reshape",
                                          inputs=node,
                                          output=node,
                                          param_attr=attr)
                perm = {"perm": [0, 3, 1, 2]}
                node.fluid_code.add_layer("transpose",
                                          inputs=node,
                                          output=node,
                                          param_attr=perm)
                return
        node.fluid_code.add_layer("reshape",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def AvgPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("avg"),
            "pool_stride": strides[2:4]
        }
        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[3], strides[3])
            assert pad_h[0] == pad_h[1] and pad_w[0] == pad_w[
                1], "Cannot map AvgPool"
            attr["pool_padding"] = [pad_h[0], pad_w[0]]
        node.fluid_code.add_layer("pool2d",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def SplitV(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        num_sections = self.graph.get_node(node.layer.input[1], copy=True)
        dim = self.graph.get_node(node.layer.input[2], copy=True)
        assert num_sections.layer_type == "Const"
        assert dim.layer_type == "Const"
        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)
        attr = {
            "num_or_sections": num_sections.value.tolist(),
            "dim": dim.value
        }
        node.fluid_code.add_layer("split",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ConcatV2(self, node):
        inputs = [
            self.graph.get_node(name, copy=True)
            for name in node.layer.input[:-1]
        ]
        axis = self.graph.get_node(node.layer.input[-1], copy=True)
        assert axis.layer_type == "Const"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if inputs[0].tf_data_format == "NHWC" and len(
                inputs[0].out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(inputs[0], axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("concat",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Tile(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        expand_times = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(expand_times.layer_name, node.layer_name)
        if expand_times.layer_type == "Const":
            expand_times = expand_times.value.tolist()
        else:
            expand_times = self.decoder.infer_shape_tensor(expand_times)
        if input.tf_data_format == "NHWC":
            if len(input.out_shapes[0]) == 4:
                expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
            elif len(input.out_shapes[0]) == 3:
                expand_times = [expand_times[i] for i in [2, 0, 1]]
        for i in range(len(expand_times)):
            if expand_times[i] < 0:
                expand_times[i] = 1

        attr = {"expand_times": expand_times}
        node.fluid_code.add_layer("expand",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Pack(self, node):
        inputs = [
            self.graph.get_node(name, copy=True) for name in node.layer.input
        ]
        axis = node.get_attr("axis")
        if inputs[0].tf_data_format == "NHWC" and len(
                inputs[0].out_shapes[0]) == 4:
            tf_data_format = list(inputs[0].tf_data_format)
            tf_data_format.insert(axis, str(len(tf_data_format)))
            axis = nhwc_dim_to_nchw(inputs[0], axis)
            pd_data_format = list(inputs[0].pd_data_format)
            pd_data_format.insert(axis, str(len(pd_data_format)))
            node.tf_data_format = "".join(tf_data_format)
            node.pd_data_format = "".join(pd_data_format)

        attr = {"axis": axis}
        node.fluid_code.add_layer("stack",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Pad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]

        pad_op = "pad"
        if len(input.out_shapes[0]) == 4:
            if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                paddings = paddings[4:]
                pad_op = "pad2d"
        attr = {"paddings": paddings}
        node.fluid_code.add_layer(pad_op,
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Range(self, node):
        start = self.graph.get_node(node.layer.input[0], copy=True)
        limit = self.graph.get_node(node.layer.input[1], copy=True)
        delta = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(start.layer_name, node.layer_name)
        self.add_omit_nodes(limit.layer_name, node.layer_name)
        self.add_omit_nodes(delta.layer_name, node.layer_name)
        if start.layer_type == "Const":
            start = start.value
        else:
            start = self.decoder.infer_tensor(start)
        if limit.layer_type == "Const":
            limit = limit.value
        else:
            limit = self.decoder.infer_tensor(limit)
        if delta.layer_type == "Const":
            delta = delta.value
        else:
            delta = self.decoder.infer_tensor(delta)

        inputs = {"start": start, "end": limit, "step": delta}
        attr = {"dtype": string(node.dtype)}
        node.fluid_code.add_layer("range",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(dims)):
                dims[i] = nhwc_dim_to_nchw(input, dims[i])

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_mean",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def MatMul(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        inputs = {"x": x, "y": y}
        # fix paddle shape infer problem
        # should be removed after paddle 1.6
        if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0:
            shape = x.out_shapes[0]
            shape[-1] = y.out_shapes[0][0]
            attr = {"shape": shape}
            node.fluid_code.add_layer("reshape",
                                      inputs=x,
                                      output=x,
                                      param_attr=attr)
        attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
        node.fluid_code.add_layer("matmul",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def ArgMax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = self.graph.get_node(node.layer.input[1], copy=True)
        assert axis.layer_type == "Const", "ArgMax only support Const parameter"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("argmax",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def StridedSlice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        end = self.graph.get_node(node.layer.input[2], copy=True)
        strides = self.graph.get_node(node.layer.input[3], copy=True)
        assert begin.layer_type == "Const"
        assert end.layer_type == "Const"
        assert strides.layer_type == "Const"
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(end.layer_name, node.layer_name)
        self.add_omit_nodes(strides.layer_name, node.layer_name)
        strides = strides.value.tolist()
        assert len(set(strides)) == 1 and strides[0] == 1

        begin = begin.value.tolist()
        end = end.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            begin = [begin[i] for i in [0, 3, 1, 2]]
            end = [end[i] for i in [0, 3, 1, 2]]

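        # TensorFlow can encode "slice to the end of the axis" as end == 0
        # (via its end mask); Paddle's slice needs an explicit bound, so a
        # large sentinel is substituted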
        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        attr = {
            "axes": [i for i in range(len(strides))],
            "starts": begin,
            "ends": end
        }
        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Slice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        size = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(size.layer_name, node.layer_name)
        if begin.layer_type == "Const":
            begin = begin.value.tolist()
        else:
            begin = self.decoder.infer_tensor(begin).tolist()
        if size.layer_type == "Const":
            size = size.value.tolist()
        else:
            size = self.decoder.infer_tensor(size).tolist()

        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            size = [size[i] for i in [0, 3, 1, 2]]
            begin = [begin[i] for i in [0, 3, 1, 2]]

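        # Paddle's slice takes absolute end indices, so convert TF's
        # (begin, size) pair into ends; a negative size means "to the end
        # of the axis" and is approximated with a large bound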
        for i in range(len(size)):
            if size[i] < 0:
                size[i] = 99999999
            else:
                size[i] = size[i] + begin[i]

        attr = {
            "axes": [i for i in range(len(size))],
            "starts": begin,
            "ends": size
        }
        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Conv2DBackpropInput(self, node):
        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        input = self.graph.get_node(node.layer.input[2], copy=True)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        self.add_omit_nodes(kernel.layer_name, node.layer_name)
        self.add_omit_nodes(out_shape.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
        else:
            self.graph.data_format_propagation(node)

        padding = 0
        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
            pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
                padding = [pad_h[0], pad_w[0]]
            else:
                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
                node.fluid_code.add_layer("pad2d",
                                          inputs=input,
                                          output=node,
                                          param_attr=attr)
                input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": padding
        }
        node.fluid_code.add_layer("conv2d_transpose",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_max",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Sum(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer("reduce_sum",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Cast(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        dtype = node.dtype_map[node.get_attr('DstT')]
        attr = {"dtype": string(dtype)}
        node.fluid_code.add_layer("cast",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
        input = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        num_split = node.get_attr('num_split')
        dim = dim.value
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            dim = nhwc_dim_to_nchw(input, dim)

        attr = {"num_or_sections": num_split, "dim": dim}
        node.fluid_code.add_layer("split",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            for i in range(len(squeeze_dims)):
                squeeze_dims[i] = nhwc_dim_to_nchw(input, squeeze_dims[i])
        attr = {"axes": squeeze_dims}
        node.fluid_code.add_layer("squeeze",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Softmax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = node.get_attr("axis")
        if axis is None:
            axis = -1 + len(input.out_shapes[0])
        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
            axis = nhwc_dim_to_nchw(input, axis)
        attr = {"axis": axis}
        node.fluid_code.add_layer("softmax",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(
                resize_shape, node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {"align_corners": align_corners, "out_shape": resize_shape}
        node.fluid_code.add_layer("resize_nearest",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def ResizeBilinear(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
        else:
            resize_shape = self.decoder.infer_shape_tensor(
                resize_shape, node.out_shapes[0])
        align_corners = node.get_attr("align_corners")
        attr = {
            "align_corners": align_corners,
            "out_shape": resize_shape,
            "align_mode": 1
        }
        node.fluid_code.add_layer("resize_bilinear",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def GreaterEqual(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer("greater_equal",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=None)

    def RandomUniform(self, node):
        shape = self.graph.get_node(node.layer.input[0], copy=True)
        self.add_omit_nodes(shape.layer_name, node.layer_name)
        if shape.layer_type == "Const":
            shape = shape.value.tolist()
        else:
            shape = self.decoder.infer_shape_tensor(shape)
        if len(shape) == 4 and node.tf_data_format == "NHWC":
            shape = [shape[i] for i in [0, 3, 1, 2]]
        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
        if shape[0] < 0:
            input = self.batch_node
            node.fluid_code.add_layer("uniform_random_batch_size_like",
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
        else:
            node.fluid_code.add_layer("uniform_random",
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)