#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import math
import inspect
import numpy
import sys


# compute padding size for SAME mode
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    if pad_size < 0:
        pad_size = 0
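    # When pad_size is odd, TF's SAME padding puts the extra pixel on the
    # bottom/right side, so the leading pad takes the floor of the half.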
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]


class TFOpMapperNHWC(OpMapper):
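    # Maps TensorFlow ops onto fluid layers while keeping tensors in their
    # original NHWC layout at the graph boundaries: layout-sensitive ops are
    # wrapped in a pair of NHWC->NCHW / NCHW->NHWC transposes around the
    # corresponding fluid call.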
    directly_map_ops = {
        'Relu': ['relu'],
        'Relu6': ['relu6'],
        'Shape': ['shape'],
        'Abs': ['abs'],
        'Sigmoid': ['sigmoid'],
        'Exp': ['exp'],
        'Rsqrt': ['rsqrt'],
        'Sqrt': ['sqrt'],
        'swish_f32': ['swish'],
        'Tanh': ['tanh'],
        'Softplus': ['softplus'],
        'LeakyRelu': ['leaky_relu', {
            'alpha': 'alpha'
        }],
        'Floor': ['floor'],
        'Erf': ['erf']
    }
    elementwise_ops = {
        'Add': 'elementwise_add',
        'AddV2': 'elementwise_add',
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
        'Minimum': 'elementwise_min',
        'LessEqual': 'less_equal',
        'Mul': 'elementwise_mul',
        'FloorDiv': 'elementwise_floordiv'
    }

    def __init__(self, decoder):
        super(TFOpMapperNHWC, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.weights = dict()
        self.batch_node = None
        self.omit_nodes = list()
        self.used_custom_layers = dict()

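        # Inputs that are not Placeholder or iterator nodes (e.g. constants
        # wired into the input list) are dropped from the model's input set.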
        not_placeholder = list()
        for name in self.graph.input_nodes:
            if self.graph.get_node(name).layer_type not in (
                    "Placeholder", "OneShotIterator", "IteratorV2"):
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]

        unsupported_ops = set()
        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...     ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.directly_map(node)
            elif op in self.elementwise_ops:
                if len(unsupported_ops) > 0:
                    continue
                self.elementwise_map(node)
            elif hasattr(self, op):
                if len(unsupported_ops) > 0:
                    continue
                func = getattr(self, op)
                try:
                    func(node)
                except Exception as e:
                    unsupported_ops.add(op)
                    print(e)
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
            print("========= {} OPs are not supported yet ===========".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            sys.exit(-1)
        sys.stderr.write("\nDone!\n")

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
        out_node = self.graph.get_node(out_node_name)
        index = in_node.outputs.index(out_node_name)
        del in_node.outputs[index]
        index = out_node.inputs.index(in_node_name)
        del out_node.inputs[index]
        self.omit_nodes.append(in_node.layer_name)

    def directly_map(self, node):
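        # One TF op -> one fluid layer, attributes renamed per directly_map_ops.
        # 4-D tensors are assumed NHWC, so the mapped layer is wrapped in
        # NHWC->NCHW / NCHW->NHWC transposes (except for the shape op).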
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0], copy=True)
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param

        if len(input.out_shapes[0]) == 4 and op_info[0] != 'shape':
            attr1 = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                'transpose', inputs=input, output=node, param_attr=attr1)
            input = node
            node.fluid_code.add_layer(
                op_info[0], inputs=input, output=node, param_attr=attr)
            input = node
            attr2 = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                'transpose', inputs=input, output=node, param_attr=attr2)
        else:
            node.fluid_code.add_layer(
                op_info[0], inputs=input, output=node, param_attr=attr)

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            op_type, inputs=inputs, output=node, param_attr=None)

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        dtype = node.dtype
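        # A negative batch dimension marks this placeholder as the node that
        # carries the (dynamic) batch size.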
        if shape[0] < 0:
            self.batch_node = node
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'append_batch_size': False
        }

        node.fluid_code.add_layer(
            "data", inputs=None, output=node, param_attr=attr)

    def Const(self, node):
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        initializer = "Constant(0.0)"
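        # Scalar constants have no shape in TF; materialize them as shape-[1]
        # parameters whose initializer is the scalar value itself.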
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            initializer = "Constant({})".format(value)

        self.weights[node.layer_name] = node.value

        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'default_initializer': initializer
        }
        node.fluid_code.add_layer(
            "create_parameter", inputs=None, output=node, param_attr=attr)

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        perm = self.graph.get_node(node.layer.input[1], copy=True)
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
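        # The Const perm node was already registered as a weight; drop it and
        # inline the permutation into the transpose attr instead.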
        del self.weights[perm.layer_name.replace('/', '_')]
        perm.fluid_code.clear()
        perm = perm.value.tolist()

        attr = {'perm': perm}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)

    def Fill(self, node):
        dims = self.graph.get_node(node.layer.input[0], copy=True)
        input_value = self.graph.get_node(node.layer.input[1], copy=True)

        assert input_value.layer_type == "Const", "Value of fill OP should be Const"

        self.add_omit_nodes(input_value.layer_name, node.layer_name)
        input_value = input_value.value
        input_dtype = string(input_value.dtype)
        attr = {'value': input_value, 'dtype': input_dtype}

        node.fluid_code.add_layer(
            "fill_constant", inputs=dims, output=node, param_attr=attr)

    def DepthToSpace(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        block_size = node.get_attr("block_size")
        data_format = node.get_attr("data_format").decode()

        if data_format == "NHWC":
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=input, param_attr=attr)
        n, h, w, c = input.out_shapes[0]

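        # TF's DepthToSpace orders the channel blocks differently from fluid's
        # pixel_shuffle; the reshape/transpose/reshape below rearranges them
        # so pixel_shuffle reproduces TF's output.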
        attr = {'shape': [0, block_size * block_size, -1, h, w]}
        node.fluid_code.add_layer(
            "reshape", inputs=input, output=input, param_attr=attr)

        attr = {'perm': [0, 2, 1, 3, 4]}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=input, param_attr=attr)
        attr = {'shape': [0, c, h, w]}
        node.fluid_code.add_layer(
            "reshape", inputs=input, output=input, param_attr=attr)

        attr = {'upscale_factor': block_size}
        node.fluid_code.add_layer(
            "pixel_shuffle", inputs=input, output=node, param_attr=attr)

        if data_format == "NHWC":
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

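        # fluid's pool2d works on NCHW data, so NHWC inputs are transposed in
        # and the result transposed back afterwards.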
        if not channel_first:
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            input = node

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("max"),
            "pool_stride": strides[2:4],
            "pool_padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

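        # TF stores conv kernels as HWIO; Paddle expects OIHW, hence the
        # (3, 2, 0, 1) transpose when the weights are registered below.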
        if kernel.layer_type == 'Const':
            kernel_value = kernel.value
            kernel_weight_name = kernel.layer_name.replace('/', '_')
        else:
            kernel_value = self.decoder.infer_tensor(kernel)
            if kernel.layer_type == 'Split':
                kernel_weight_name = "{}_{}_kernel".format(node.layer_name,
                                                           kernel.layer_name)
            else:
                kernel_weight_name = kernel.layer_name.replace('/', '_')
        self.weights[kernel_weight_name] = numpy.transpose(kernel_value,
                                                           (3, 2, 0, 1))

        if not channel_first:
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel_weight_name),
            "num_filters": k_size[3],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode)
        }

        if hasattr(node, 'dilation') and attr['dilation'] == [1, 1]:
            if len(node.dilation) == 1:
                attr['dilation'] = [1, node.dilation[0]]
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)
        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        bias = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": input, "y": bias}
        node.fluid_code.add_layer(
            "elementwise_add", inputs=inputs, output=node, param_attr=None)

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        gamma = self.graph.get_node(node.layer.input[1], copy=True)
        beta = self.graph.get_node(node.layer.input[2], copy=True)
        moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
        moving_var = self.graph.get_node(node.layer.input[4], copy=True)
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"
        self.add_omit_nodes(gamma.layer_name, node.layer_name)
        self.add_omit_nodes(beta.layer_name, node.layer_name)
        self.add_omit_nodes(moving_mean.layer_name, node.layer_name)
        self.add_omit_nodes(moving_var.layer_name, node.layer_name)

        if not channel_first:
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "epsilon": node.get_attr("epsilon"),
            "param_attr": string(gamma.layer_name),
            "bias_attr": string(beta.layer_name),
            "moving_mean_name": string(moving_mean.layer_name),
            "moving_variance_name": string(moving_var.layer_name),
            "is_test": True
        }

        node.fluid_code.add_layer(
            "batch_norm", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)

        in_shape = input.out_shapes[0]
        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

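        # Depthwise kernels are stored [H, W, C, M] in TF; transpose them to
        # [C, M, H, W] and express the depthwise conv as a grouped conv2d.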
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (2, 3, 0, 1))

        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": in_shape[1],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "groups": k_size[3] * in_shape[1],
            "use_cudnn": False,
            "padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "conv2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Reshape(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        param = self.graph.get_node(node.layer.input[1], copy=True)
        if param.layer_type == "Const":
            self.add_omit_nodes(param.layer_name, node.layer_name)
            shape = param.value.tolist()
        else:
            shape = param
        inputs = {"x": input, "shape": shape}
        node.fluid_code.add_layer(
            "reshape", inputs=inputs, output=node, param_attr=None)
        if param.layer_type != "Const":
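            # When the target shape came from a tensor, fluid cannot infer the
            # static output shape; re-reshape with the known dims (unknown
            # ones zeroed) so later shape inference still works.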
            out_shape = numpy.array(node.out_shapes[0])
            if (out_shape > 0).any():
                out_shape[out_shape < 0] = 0
                attr = {'shape': out_shape.tolist()}
                node.fluid_code.add_layer(
                    "reshape", inputs=node, output=node, param_attr=attr)

    def AvgPool(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"

        if not channel_first:
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node

        attr = {
            "pool_size": k_size[2:4],
            "pool_type": string("avg"),
            "pool_stride": strides[2:4],
            "pool_padding": string(pad_mode)
        }
        node.fluid_code.add_layer(
            "pool2d", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def SplitV(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        num_sections = self.graph.get_node(node.layer.input[1], copy=True)
        dim = self.graph.get_node(node.layer.input[2], copy=True)
        assert num_sections.layer_type == "Const"
        assert dim.layer_type == "Const"
        self.add_omit_nodes(num_sections.layer_name, node.layer_name)
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        dim = dim.value
        attr = {
            "num_or_sections": num_sections.value.tolist(),
            "dim": dim
        }
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def ConcatV2(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input[:-1]
        ]
        axis = self.graph.get_node(node.layer.input[-1], copy=True)
        assert axis.layer_type == "Const"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        if axis < 0:
            axis += len(inputs[0].out_shapes[0])
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "concat", inputs=inputs, output=node, param_attr=attr)

    def Tile(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        expand_times = self.graph.get_node(node.layer.input[1], copy=True)
        if expand_times.layer_type == "Const":
            self.add_omit_nodes(expand_times.layer_name, node.layer_name)
            expand_times = expand_times.value.tolist()
        inputs = {"x": input, "expand_times": expand_times}
        node.fluid_code.add_layer(
            "expand", inputs=inputs, output=node, param_attr=None)

    def Pack(self, node):
        inputs = [
            self.graph.get_node(
                name, copy=True) for name in node.layer.input
        ]
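        # If any input's last dimension is statically known, all inputs are
        # first reshaped so their last dimensions agree before stacking.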
        reshape_shape = list()
        for input_node in inputs:
            k_size = input_node.out_shapes[0]
            if len(k_size) and k_size[-1] != -1:
                reshape_shape = [0] * len(k_size)
                reshape_shape[-1] = k_size[-1]
                break
        if len(reshape_shape):
            for i, input_node in enumerate(inputs):
                node.fluid_code.add_layer(
                    "reshape",
                    inputs=input_node,
                    output='tmp_{}'.format(i),
                    param_attr={"shape": reshape_shape})
        axis = node.get_attr("axis")
        attr = {"axis": axis}
        if len(reshape_shape):
            inputs = ['tmp_{}'.format(i) for i in range(len(inputs))]
        node.fluid_code.add_layer(
            "stack", inputs=inputs, output=node, param_attr=attr)

    def Pad(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        paddings = self.graph.get_node(node.layer.input[1], copy=True)
        assert paddings.layer_type == "Const", "Padding should be Const"
        self.add_omit_nodes(paddings.layer_name, node.layer_name)
        paddings = paddings.value.flatten().tolist()
        data_format = input.tf_data_format

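        # If only the H and W axes of a 4-D tensor are padded, pad2d (which
        # expects NCHW input) is used instead of the generic pad op.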
        if len(input.out_shapes[0]) == 4:
            new_padding = None
            if input.tf_data_format == "NHWC":
                if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0:
                    new_padding = paddings[2:6]
            else:
                if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
                    new_padding = paddings[4:]
            if new_padding is not None:
                if input.tf_data_format == "NHWC":
                    attr = {"perm": [0, 3, 1, 2]}
                    node.fluid_code.add_layer(
                        "transpose", inputs=input, output=node, param_attr=attr)
                    input = node
                attr = {"paddings": new_padding}
                node.fluid_code.add_layer(
                    "pad2d", inputs=input, output=node, param_attr=attr)
                if input.tf_data_format == "NHWC":
                    attr = {"perm": [0, 2, 3, 1]}
                    node.fluid_code.add_layer(
                        "transpose", inputs=node, output=node, param_attr=attr)

                return

        attr = {"paddings": paddings}
        node.fluid_code.add_layer(
            "pad", inputs=input, output=node, param_attr=attr)

    def Range(self, node):
        start = self.graph.get_node(node.layer.input[0], copy=True)
        limit = self.graph.get_node(node.layer.input[1], copy=True)
        delta = self.graph.get_node(node.layer.input[2], copy=True)

        if start.layer_type == "Const":
            self.add_omit_nodes(start.layer_name, node.layer_name)
            start = start.value
        if limit.layer_type == "Const":
            self.add_omit_nodes(limit.layer_name, node.layer_name)
            limit = limit.value
        if delta.layer_type == "Const":
            self.add_omit_nodes(delta.layer_name, node.layer_name)
            delta = delta.value

        dtype = node.dtype
        inputs = {
            "start": start,
            "end": limit,
            "step": delta,
        }
        attr = {"dtype": string(node.dtype)}
        node.fluid_code.add_layer(
            "range", inputs=inputs, output=node, param_attr=attr)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_mean", inputs=input, output=node, param_attr=attr)

    def MatMul(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        inputs = {"x": x, "y": y}
        # fix paddle shape infer problem
        # should be removed after paddle 1.6
        if x.out_shapes[0][-1] < 0 and y.out_shapes[0][0] > 0:
            shape = x.out_shapes[0]
            shape[-1] = y.out_shapes[0][0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=x, output=x, param_attr=attr)
        if transpose_a is None:
            transpose_a = node.get_attr('adj_x')
        if transpose_b is None:
            transpose_b = node.get_attr('adj_y')
        attr = {"transpose_x": transpose_a, "transpose_y": transpose_b}
        node.fluid_code.add_layer(
            "matmul", inputs=inputs, output=node, param_attr=attr)

    def BatchMatMul(self, node):
        return self.MatMul(node)

    def BatchMatMulV2(self, node):
        return self.MatMul(node)

    def ArgMax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = self.graph.get_node(node.layer.input[1], copy=True)
        assert axis.layer_type == "Const", "ArgMax only support Const parameter"
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        axis = axis.value
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "argmax", inputs=input, output=node, param_attr=attr)

    def StridedSlice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        end = self.graph.get_node(node.layer.input[2], copy=True)
        strides = self.graph.get_node(node.layer.input[3], copy=True)
        assert begin.layer_type == "Const"
        assert end.layer_type == "Const"
        assert strides.layer_type == "Const"
        self.add_omit_nodes(begin.layer_name, node.layer_name)
        self.add_omit_nodes(end.layer_name, node.layer_name)
        self.add_omit_nodes(strides.layer_name, node.layer_name)
        strides = strides.value.tolist()
        assert len(set(strides)) == 1 and strides[
            0] == 1, "Only support strides be 1 in StridedSlice OP"

        begin = begin.value.tolist()
        end = end.value.tolist()

        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        begin_mask = node.get_attr('begin_mask')
        end_mask = node.get_attr('end_mask')
        ellipsis_mask = node.get_attr('ellipsis_mask')
        new_axis_mask = node.get_attr('new_axis_mask')
        shrink_axis_mask = node.get_attr('shrink_axis_mask')

        assert ellipsis_mask == 0, "(OP:{} Name:{}) Only support ellipsis_mask be 0 [now: {}] in StridedSlice OP".format(
            node.layer_type, node.layer.name, ellipsis_mask)

        # TODO codes without validation
        # Use it carefully
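        # Each *_mask is a bitfield over the slice axes: bit i of begin_mask /
        # end_mask means "slice from the start / to the end" on axis i,
        # new_axis_mask inserts a dimension there, shrink_axis_mask drops one.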
        new_begin = list()
        new_end = list()
        new_axes = list()
        shrink_axes = list()
        for i, item in enumerate(begin):
            mask = (new_axis_mask >> i) & 1
            if mask != 0:
                new_axes.append(i)
                continue

            mask = (shrink_axis_mask >> i) & 1
            if mask != 0:
                shrink_axes.append(i)

            mask = (begin_mask >> i) & 1
            if mask != 0:
                new_begin.append(0)
            else:
                new_begin.append(item)

            mask = (end_mask >> i) & 1
            if mask != 0:
                new_end.append(999999)
            else:
                new_end.append(end[i])

        attr = {
            "axes": [i for i in range(len(new_begin))],
            "starts": new_begin,
            "ends": new_end
        }
        node.fluid_code.add_layer(
            "slice", inputs=input, output=node, param_attr=attr)
        if len(new_axes) > 0:
            attr = {"axes": new_axes}
            node.fluid_code.add_layer(
                "unsqueeze", inputs=node, output=node, param_attr=attr)
        if len(shrink_axes) > 0:
            if len(input.out_shapes[0]) + len(new_axes) <= 1:
                pass
            else:
                attr = {"axes": shrink_axes}
                node.fluid_code.add_layer(
                    "squeeze", inputs=node, output=node, param_attr=attr)

    def Slice(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        begin = self.graph.get_node(node.layer.input[1], copy=True)
        size = self.graph.get_node(node.layer.input[2], copy=True)
        if begin.layer_type == "Const":
            self.add_omit_nodes(begin.layer_name, node.layer_name)
            begin = begin.value.tolist()
        else:
            shape = begin.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=begin, output=begin, param_attr=attr)
        if size.layer_type == "Const":
            self.add_omit_nodes(size.layer_name, node.layer_name)
            size = size.value.tolist()
        else:
            shape = size.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape", inputs=size, output=size, param_attr=attr)
        inputs = {"x": input, "offsets": begin, "shape": size}
        node.fluid_code.add_layer(
            "crop_tensor", inputs=inputs, output=node, param_attr=None)

    def Conv2DBackpropInput(self, node):
        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
        input = self.graph.get_node(node.layer.input[2], copy=True)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        self.add_omit_nodes(kernel.layer_name, node.layer_name)
        self.add_omit_nodes(out_shape.layer_name, node.layer_name)

        if out_shape.layer_type == "Const":
            out_shape = out_shape.value.tolist()
        else:
            out_shape = self.decoder.infer_shape_tensor(out_shape,
                                                        node.out_shapes[0])

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        channel_first = data_format == "NCHW"

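        # TF stores transposed-conv kernels as [H, W, out_C, in_C]; the
        # (3, 2, 0, 1) transpose below yields Paddle's [in_C, out_C, H, W].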
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            attr = {"perm": [0, 3, 1, 2]}
            node.fluid_code.add_layer(
                "transpose", inputs=input, output=node, param_attr=attr)
            input = node
        else:
            self.graph.data_format_propagation(node)

        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
            "num_filters": k_size[2],
            "filter_size": k_size[0:2],
            "stride": strides[2:4],
            "dilation": dilations[2:4],
            "padding": string(pad_mode),
            "output_size": out_shape[1:3]
        }
        node.fluid_code.add_layer(
            "conv2d_transpose", inputs=input, output=node, param_attr=attr)

        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer(
                "transpose", inputs=node, output=node, param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_max", inputs=input, output=node, param_attr=attr)

    def Sum(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()

        attr = {"dim": dim, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_sum", inputs=input, output=node, param_attr=attr)

    def Cast(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        dtype = node.dtype_map[node.get_attr('DstT')]
        attr = {"dtype": string(dtype)}
        node.fluid_code.add_layer(
            "cast", inputs=input, output=node, param_attr=attr)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
        input = self.graph.get_node(node.layer.input[1], copy=True)
        assert dim.layer_type == "Const"
        self.add_omit_nodes(dim.layer_name, node.layer_name)
        num_split = node.get_attr('num_split')
        dim = dim.value

        attr = {"num_or_sections": num_split, "dim": dim}
        node.fluid_code.add_layer(
            "split", inputs=input, output=node, param_attr=attr)

    def Squeeze(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        squeeze_dims = node.get_attr('squeeze_dims')
        attr = {"axes": squeeze_dims}
        node.fluid_code.add_layer(
            "squeeze", inputs=input, output=node, param_attr=attr)

    def Softmax(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        axis = node.get_attr("axis")
        attr = {"axis": axis}
        node.fluid_code.add_layer(
            "softmax", inputs=input, output=node, param_attr=attr)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        if resize_shape.layer_type == "Const":
            self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
            resize_shape = resize_shape.value.tolist()
        else:
            shape = resize_shape.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape",
                inputs=resize_shape,
                output=resize_shape,
                param_attr=attr)

        align_corners = node.get_attr("align_corners")
        attr = {"perm": [0, 3, 1, 2]}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)
        inputs = {"input": node, "out_shape": resize_shape}
        attr = {"align_corners": align_corners}
        node.fluid_code.add_layer(
            "resize_nearest", inputs=inputs, output=node, param_attr=attr)
        attr = {"perm": [0, 2, 3, 1]}
        node.fluid_code.add_layer(
            "transpose", inputs=node, output=node, param_attr=attr)

    def ResizeBilinear(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        resize_shape = self.graph.get_node(node.layer.input[1], copy=True)
        if resize_shape.layer_type == "Const":
            self.add_omit_nodes(resize_shape.layer_name, node.layer_name)
            resize_shape = resize_shape.value.tolist()
        else:
            shape = resize_shape.out_shapes[0]
            attr = {"shape": shape}
            node.fluid_code.add_layer(
                "reshape",
                inputs=resize_shape,
                output=resize_shape,
                param_attr=attr)
        align_corners = node.get_attr("align_corners")
        attr = {"perm": [0, 3, 1, 2]}
        node.fluid_code.add_layer(
            "transpose", inputs=input, output=node, param_attr=attr)
        inputs = {"input": node, "out_shape": resize_shape}
        attr = {
            # "out_shape": resize_shape,
            "align_corners": align_corners,
            "align_mode": 1
        }
        node.fluid_code.add_layer(
            "resize_bilinear", inputs=inputs, output=node, param_attr=attr)
        attr = {"perm": [0, 2, 3, 1]}
        node.fluid_code.add_layer(
            "transpose", inputs=node, output=node, param_attr=attr)

    def GreaterEqual(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "greater_equal", inputs=inputs, output=node, param_attr=None)

    def RandomUniform(self, node):
        shape = self.graph.get_node(node.layer.input[0], copy=True)
        if shape.layer_type == "Const":
            self.add_omit_nodes(shape.layer_name, node.layer_name)
            shape = shape.value.tolist()
        attr = {"min": 0.0, "max": 0.9999}

        node.fluid_code.add_layer(
            "uniform_random", inputs=shape, output=node, param_attr=attr)

    def SquaredDifference(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        inputs = {"x": x, "y": y}
        node.fluid_code.add_layer(
            "elementwise_sub", inputs=inputs, output=node, param_attr=None)
        inputs = {"x": node, "y": node}
        node.fluid_code.add_layer(
            "elementwise_mul", inputs=inputs, output=node, param_attr=None)

    def ExpandDims(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if y.layer_type == 'Const':
            self.add_omit_nodes(y.layer_name, node.layer_name)
            dim = y.value.tolist()
            if not isinstance(dim, list):
                dim = [dim]
            attr = {'axes': dim}
        else:
            attr = {'axes': y}
        node.fluid_code.add_layer(
            "unsqueeze", inputs=x, output=node, param_attr=attr)

    def BatchToSpaceND(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if hasattr(node, 'skip') and node.skip:
            node.fluid_code.add_layer(
                "=", inputs=x, output=node, param_attr=None)
        else:
            raise Exception("BatchToSpaceND is not supported")

    def SpaceToBatchND(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        y = self.graph.get_node(node.layer.input[1], copy=True)
        if hasattr(node, 'skip') and node.skip:
            node.fluid_code.add_layer(
                "=", inputs=x, output=node, param_attr=None)
        else:
            raise Exception("SpaceToBatchND is not supported")

    def OneHot(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        depth = self.graph.get_node(node.layer.input[1], copy=True)
        on_value = self.graph.get_node(node.layer.input[2], copy=True)
        off_value = self.graph.get_node(node.layer.input[3], copy=True)
        assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot'
        assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot'
        assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot'
        self.add_omit_nodes(depth.layer_name, node.layer_name)
        self.add_omit_nodes(on_value.layer_name, node.layer_name)
        self.add_omit_nodes(off_value.layer_name, node.layer_name)
        depth = depth.value
        on_value = on_value.value
        off_value = off_value.value
        assert math.fabs(on_value -
                         1.0) < 1e-06, "on_value should be 1 in OneHot"
        assert math.fabs(off_value -
                         0.0) < 1e-06, "off_value should be 0 in OneHot"
        attr = {'depth': depth}
        node.fluid_code.add_layer(
            "one_hot",
            inputs=input,
            output=node,
            param_attr=attr,
            use_fluid=True)

    def Pow(self, node):
        x = self.graph.get_node(node.layer.input[0], copy=True)
        factor = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(factor.layer_name, node.layer_name)
        if factor.layer_type == 'Const':
            factor = factor.value.tolist()
        else:
            factor = self.decoder.infer_tensor(factor)
        attr = {'factor': factor}
        node.fluid_code.add_layer("pow", inputs=x, output=node, param_attr=attr)

    def All(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
        reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
        self.add_omit_nodes(reduce_idx.layer_name, node.layer_name)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        attr = {"dim": dims, "keep_dim": keep_dims}
        node.fluid_code.add_layer(
            "reduce_all", inputs=input, output=node, param_attr=attr)

    def GatherV2(self, node):
        embeddings = self.graph.get_node(node.layer.input[0], copy=True)
        index = self.graph.get_node(node.layer.input[1], copy=True)
        axis = self.graph.get_node(node.layer.input[2], copy=True)
        self.add_omit_nodes(axis.layer_name, node.layer_name)
        assert axis.layer_type == 'Const', "Only support Const parameter[axis]"
        axis = axis.value.tolist()
        assert axis == 0, "Only support axis=0 in GatherV2 OP"
        attr = {'overwrite': False}
        if len(index.out_shapes[0]) != 1:
            reshape_attr = {"shape": [-1]}
            node.fluid_code.add_layer(
                "reshape", inputs=index, output=index, param_attr=reshape_attr)
        inputs = {'input': embeddings, 'index': index}
        node.fluid_code.add_layer(
            "gather", inputs=inputs, output=node, param_attr=attr)

    def OneShotIterator(self, node):
        return self.Placeholder(node)

    def IteratorV2(self, node):
        dtype_map = {
            1: "float32",
            3: "int32",
            4: "uint8",
            9: "int64",
            10: "bool"
        }
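        # One fluid data layer is created per iterator output; the list
        # assignment emitted first lets each output be indexed by position.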
        shapes = node.out_shapes
        dtypes = node.layer.attr['output_types'].list.type
        node.fluid_code.add_note("{} = [0] * {}".format(node.layer_name,
                                                        len(shapes)))
        for i, shape in enumerate(shapes):
            attr = {
                'dtype': string(dtype_map[dtypes[i]]),
                'shape': shape,
                'name': string("{}_{}".format(node.layer_name, i)),
                'append_batch_size': False
            }
            output = "{}[{}]".format(node.layer_name, i)
            node.fluid_code.add_layer(
                "data", inputs=None, output=output, param_attr=attr)