tf_op_mapper.py 59.1 KB
Newer Older
S
SunAhong1993 已提交
1
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
J
jiangjiajun 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
J
jiangjiajun 已提交
14

S
SunAhong1993 已提交
15
from x2paddle.decoder.tf_decoder import TFGraph, TFGraphNode
16
from x2paddle.core.program import PaddleGraph
J
jiangjiajun 已提交
17
from x2paddle.core.op_mapper import OpMapper
J
jiangjiajun 已提交
18
from x2paddle.core.util import *
J
jiangjiajun 已提交
19 20 21
from x2paddle import program
import traceback
import math
J
jiangjiajun 已提交
22
import inspect
J
jiangjiajun 已提交
23
import numpy
J
jiangjiajun 已提交
24
import sys
25

J
jiangjiajun 已提交
26 27 28 29 30 31 32 33 34 35 36 37
name_counter = dict()


def gen_name(op_name, var_name):
    """Return a unique layer name of the form ``<op_name>_<var_name>_<k>``.

    ``k`` is how many times the same ``<op_name>_<var_name>`` prefix has been
    requested before this call (starting at 0), so repeated requests never
    collide.  The per-prefix counters live in the module-level
    ``name_counter`` dict.
    """
    prefix = "{}_{}".format(op_name, var_name)
    index = name_counter[prefix] + 1 if prefix in name_counter else 0
    name_counter[prefix] = index
    return "{}_{}".format(prefix, index)

J
jiangjiajun 已提交
38

J
jiangjiajun 已提交
39 40 41 42
# compute padding size for SAME mode
def get_same_padding(in_size, kernel_size, stride):
    """Return ``[pad_before, pad_after]`` for SAME padding along one axis.

    The target output length is ``ceil(in_size / stride)``.  The total padding
    needed to reach it is clamped at zero and split so that the smaller half
    comes first (an odd extra element goes after).
    """
    out_size = int(math.ceil(in_size * 1.0 / stride))
    total = max((out_size - 1) * stride + kernel_size - in_size, 0)
    before = int(total / 2)
    return [before, total - before]

J
jiangjiajun 已提交
49

J
jiangjiajun 已提交
50
class TFOpMapper(OpMapper):
J
jiangjiajun 已提交
51
    # TF ops translated one-to-one into a single Paddle call.  Entry format:
    # [paddle_kernel, {tf_attr_name: paddle_kwarg_name}, ...]; the optional
    # dicts rename TF node attributes into keyword arguments of the kernel.
    directly_map_ops = {
        'Relu': ['paddle.nn.functional.relu'],
        'Relu6': ['paddle.nn.functional.relu6'],
        'Abs': ['paddle.abs'],
        'Sigmoid': ['paddle.nn.functional.sigmoid'],
        'Softmax': ['paddle.nn.functional.softmax'],
        'Exp': ['paddle.exp'],
        'Rsqrt': ['paddle.rsqrt'],
        'Sqrt': ['paddle.sqrt'],
        'swish_f32': ['paddle.nn.functional.swish'],
        'Tanh': ['paddle.tanh'],
        'Softplus': ['paddle.nn.functional.softplus'],
        'LeakyRelu':
        ['paddle.nn.functional.leaky_relu', dict(alpha='negative_slope')],
        'Floor': ['paddle.floor'],
        'Erf': ['paddle.erf'],
        'Square': ['paddle.square']
    }
    # Binary TF ops mapped onto Paddle kernels that take inputs ``x`` and ``y``.
    elementwise_ops = {
        'Add': 'paddle.add',
        'AddV2': 'paddle.add',
        'RealDiv': 'paddle.divide',
        'DivNoNan': 'paddle.divide',
        # TODO (syf): replace
        'Sub': 'paddle.subtract',
        'Maximum': 'paddle.maximum',
        'Minimum': 'paddle.minimum',
        'Mul': 'paddle.multiply',
        'FloorDiv': 'paddle.floor_divide',
        'FloorMod': 'paddle.floor_mod',
        # Fully qualified like every other kernel; a bare 'logical_and' is not
        # a name the generated Paddle program can resolve.
        'LogicalAnd': 'paddle.logical_and',
    }
    # Comparison ops: lowered like elementwise ops, output dtype set to bool.
    bool_ops = {
        'LessEqual': 'paddle.less_equal',
        'GreaterEqual': 'paddle.greater_equal',
        'Greater': 'paddle.greater_than',
        'NotEqual': 'paddle.not_equal',
        'Equal': 'paddle.equal',
    }

J
jiangjiajun 已提交
91 92
    def __init__(self, decoder):
        """Convert every node of ``decoder.tf_graph`` into Paddle layers.

        Raises ``Exception`` up front (after ``op_checker`` prints the
        offending op types) if the graph contains ops this mapper cannot
        handle.  Graph inputs that are not feed nodes (Placeholder /
        OneShotIterator / IteratorV2) are removed from the input list.
        """
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
        self.params = dict()
        self.paddle_graph = PaddleGraph(
            parent_layer=None, graph_type="static", source_type="tf")
        self.params_output2id = dict()

        # Filter in place so the list object held by self.graph is updated.
        feed_types = ("Placeholder", "OneShotIterator", "IteratorV2")
        input_nodes = self.graph.input_nodes
        input_nodes[:] = [
            name for name in input_nodes
            if self.graph.get_node(name).layer_type in feed_types
        ]

        self.paddle_graph.inputs = self.graph.input_nodes
        self.paddle_graph.outputs = self.graph.output_nodes

        node_count = sum(1 for node in self.graph.node_map.values()
                         if isinstance(node, TFGraphNode))
        print("Total nodes: {}".format(node_count))
        print("Nodes converting ...")
        for index, node_name in enumerate(self.graph.topo_sort, start=1):
            sys.stderr.write("\rConverting node {} ...     ".format(index))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                self.directly_map(node)
            elif op in self.elementwise_ops:
                self.elementwise_map(node)
            elif op in self.bool_ops:
                self.bool_map(node)
            elif hasattr(self, op):
                getattr(self, op)(node)
        print("\nNodes converted.")
        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.params)
139

S
SunAhong1993 已提交
140 141 142 143 144 145 146
    def op_checker(self):
        """Return True if every op in the graph can be converted.

        An op is supported when the mapper defines a method named after it or
        the op appears in one of the mapping tables.  Otherwise the set of
        unsupported op types is printed and False is returned.
        """
        tables = (self.directly_map_ops, self.elementwise_ops, self.bool_ops)
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            op = self.graph.get_node(node_name).layer_type
            if hasattr(self, op):
                continue
            if any(op in table for table in tables):
                continue
            unsupported_ops.add(op)

        if not unsupported_ops:
            return True
        print("\n========= {} OPs are not supported yet ===========".format(
            len(unsupported_ops)))
        for op in unsupported_ops:
            print("========== {} ============".format(op))
        return False
J
jiangjiajun 已提交
159

J
jiangjiajun 已提交
160 161 162
    def directly_map(self, node):
        """Emit a single Paddle call for an op listed in ``directly_map_ops``.

        The node's first input becomes the kernel's ``x`` argument.  Every
        dict after the kernel name in the table entry maps TF attribute names
        to Paddle keyword names.  All pairs of every dict are forwarded, not
        only the first key.
        """
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0])
        attr = dict()
        for param in op_info[1:]:
            for tf_param_name, pd_param_name in param.items():
                attr[pd_param_name] = node.get_attr(tf_param_name)

        self.paddle_graph.add_layer(
            kernel=op_info[0],
            inputs={"x": input.name},
            outputs=[node.name],
            **attr)
J
jiangjiajun 已提交
176

S
SunAhong1993 已提交
177 178 179 180
    def elementwise_map(self, node, op_type=None):
        """Emit the binary Paddle op ``op_type(x, y)`` for ``node``.

        When ``op_type`` is omitted it is looked up in ``elementwise_ops``.
        The static shapes of both operands are attached to the created layer
        as ``input_shapes``.
        """
        if op_type is None:
            assert node.layer_type in self.elementwise_ops
            op_type = self.elementwise_ops[node.layer_type]
        lhs, rhs = [self.graph.get_node(node.layer.input[i]) for i in (0, 1)]
        shapes = {"x": lhs.out_shapes[0], "y": rhs.out_shapes[0]}
        layer_id = self.paddle_graph.add_layer(
            kernel=op_type,
            inputs={"x": lhs.name, "y": rhs.name},
            outputs=[node.name])
        self.paddle_graph.layers[layer_id].input_shapes = shapes

S
SunAhong1993 已提交
195 196 197 198
    def bool_map(self, node):
        """Lower a comparison op through ``elementwise_map``; mark output bool."""
        self.elementwise_map(node, op_type=self.bool_ops[node.layer_type])
        node.set_dtype("bool")
J
jiangjiajun 已提交
199

200 201
    def Placeholder(self, node):
        """Declare a graph input via ``paddle.static.data`` (TF shape/dtype).

        The static shape must be known (non-empty), otherwise an
        AssertionError naming the node is raised.
        """
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        self.paddle_graph.add_layer(
            kernel="paddle.static.data",
            inputs={},
            outputs=[node.name],
            dtype=string(node.dtype),
            shape=shape,
            name=string(node.name))
J
jiangjiajun@baidu.com 已提交
212

J
jiangjiajun 已提交
213 214 215 216 217 218 219
    def Const(self, node):
        """Materialize a TF constant.

        A constant with an empty static shape (a scalar) becomes a
        ``paddle.full`` layer of shape ``[1]``.  Infinite scalars are emitted
        as ``float('inf')`` / ``float('-inf')`` source text.  Any other
        constant is stored in ``self.params``, declared through
        ``paddle.static.create_parameter``, and its layer id is recorded in
        ``self.params_output2id``.
        """
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happend"
            # Spell non-finite scalars out as source text; -inf needs the same
            # treatment as the +inf case, not a raw float value.
            if value == float('inf'):
                value = "float('inf')"
            elif value == float('-inf'):
                value = "float('-inf')"
            self.paddle_graph.add_layer(
                kernel="paddle.full",
                inputs={},
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=value)
            return

        self.params[node.name] = node.value
        layer_id = self.paddle_graph.add_layer(
            kernel="paddle.static.create_parameter",
            inputs={},
            outputs=[node.name],
            dtype=string(dtype),
            shape=shape,
            name=string(node.name),
            default_initializer="paddle.nn.initializer.Constant(value=0.0)")
        self.params_output2id[node.name] = layer_id
J
jiangjiajun 已提交
241 242

    def Transpose(self, node):
        """Emit ``paddle.transpose``; a non-Const ``perm`` is evaluated eagerly."""
        source = self.graph.get_node(node.layer.input[0])
        perm_node = self.graph.get_node(node.layer.input[1])
        if perm_node.layer_type == "Const":
            perm_array = perm_node.value
        else:
            perm_array = self.decoder.infer_tensor(
                perm_node, use_diff_inputs=False)
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": source.name},
            outputs=[node.name],
            perm=perm_array.tolist())

    def Fill(self, node):
        """Emit ``paddle.full`` for TF ``Fill(dims, value)``.

        ``value`` must be Const.  Const ``dims`` become a static ``shape``
        attribute.  Otherwise ``dims`` is wired in as a tensor input and the
        result is reshaped to the node's statically inferred output shape.
        """
        dims = self.graph.get_node(node.layer.input[0])
        input_value = self.graph.get_node(node.layer.input[1])
        assert input_value.layer_type == "Const", "Value of fill OP should be Const"
        static_dims = dims.layer_type == "Const"
        inputs = {} if static_dims else {"shape": dims.name}
        attr = {"shape": dims.value.tolist()} if static_dims else {}
        attr["dtype"] = string(input_value.dtype)
        attr["fill_value"] = input_value.value

        self.paddle_graph.add_layer(
            "paddle.full", inputs=inputs, outputs=[node.name], **attr)
        if not static_dims:
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
J
jiangjiajun 已提交
278

J
jiangjiajun 已提交
279 280 281 282 283 284 285 286 287 288 289 290 291
    def DepthToSpace(self, node):
        """Lower TF ``DepthToSpace`` to reshape/transpose + ``pixel_shuffle``.

        NHWC inputs are transposed to NCHW first and back at the end.  The
        channel dimension is regrouped (reshape, swap axes 1 and 2, reshape
        back to ``[0, c, h, w]``) before ``pixel_shuffle`` with
        ``upscale_factor=block_size``.
        """
        source = self.graph.get_node(node.layer.input[0])
        block_size = node.get_attr("block_size")
        data_format = node.get_attr("data_format").decode()
        is_nhwc = data_format == "NHWC"
        if is_nhwc:
            n, h, w, c = source.out_shapes[0]
        else:
            n, c, h, w = source.out_shapes[0]

        current = source.name
        if is_nhwc:
            nchw_name = gen_name("depth_to_space", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": source.name},
                outputs=[nchw_name],
                perm=[0, 3, 1, 2])
            current = nchw_name

        grouped_name = gen_name("depth_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": current},
            outputs=[grouped_name],
            shape=[0, block_size * block_size, -1, h, w])

        swapped_name = gen_name("depth_to_space", "transpose")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": grouped_name},
            outputs=[swapped_name],
            perm=[0, 2, 1, 3, 4])

        merged_name = gen_name("depth_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": swapped_name},
            outputs=[merged_name],
            shape=[0, c, h, w])

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.pixel_shuffle",
            inputs={"x": merged_name},
            outputs=[node.name],
            upscale_factor=block_size)

        if is_nhwc:
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
333

S
SunAhong1993 已提交
334 335 336 337
    def Where(self, node):
        """Map TF ``Where``.

        With one input this is ``paddle.nonzero(cond)``.  With three inputs it
        is the select ``paddle.where(cond, x, y)``.
        """
        cond = self.graph.get_input_node(node, 0)
        if len(node.layer.input) == 1:
            self.paddle_graph.add_layer(
                "paddle.nonzero", inputs={"x": cond.name}, outputs=[node.name])
            return
        x, y = [self.graph.get_input_node(node, i) for i in (1, 2)]
        self.paddle_graph.add_layer(
            "paddle.where",
            inputs={"condition": cond.name,
                    "x": x.name,
                    "y": y.name},
            outputs=[node.name])
349

S
add beg  
SunAhong1993 已提交
350 351
    def Neg(self, node):
        """Negate the input by emitting ``paddle.scale`` with ``scale=-1``."""
        source = self.graph.get_input_node(node, 0)
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": source.name},
            outputs=[node.name],
            scale=-1)
J
jiangjiajun 已提交
358 359 360

    def MaxPool(self, node):
        """Lower TF ``MaxPool`` to ``paddle.nn.functional.max_pool2d``.

        NHWC inputs (and their ``ksize``/``strides`` vectors) are permuted to
        NCHW order around the pool.  The spatial entries ``[2:4]`` are passed
        as kernel size and stride, and the TF padding string is forwarded
        verbatim.
        """
        source = self.graph.get_node(node.layer.input[0])
        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        is_nhwc = data_format == "NHWC"

        pool_input = source.name
        if is_nhwc:
            pool_input = gen_name("max_pool", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": source.name},
                outputs=[pool_input],
                perm=[0, 3, 1, 2])
            strides = [strides[axis] for axis in (0, 3, 1, 2)]
            k_size = [k_size[axis] for axis in (0, 3, 1, 2)]

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.max_pool2d",
            inputs={"x": pool_input},
            outputs=[node.name],
            kernel_size=k_size[2:4],
            stride=strides[2:4],
            padding=string(pad_mode))

        if is_nhwc:
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
J
jiangjiajun 已提交
393 394

    def Conv2D(self, node):
        """Lower TF ``Conv2D`` to ``paddle.nn.functional.conv2d``.

        The filter must be known at conversion time: either a Const, or a
        tensor evaluated eagerly with ``decoder.infer_tensor``.  It is
        transposed with axes ``(3, 2, 0, 1)`` (TF ``[kh, kw, in, out]`` to
        ``[out, in, kh, kw]``), stored in ``self.params`` and declared with
        ``create_parameter``.  NHWC inputs are transposed to NCHW around the
        convolution.  When the input channel count is unknown (-1), the input
        is reshaped so that dim 1 equals the filter's input-channel count.
        """
        input = self.graph.get_node(node.layer.input[0])
        kernel = self.graph.get_node(node.layer.input[1])

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        if data_format == "NHWC":
            n, h, w, c = input.out_shapes[0]
        else:
            n, c, h, w = input.out_shapes[0]

        if kernel.layer_type == 'Const':
            kernel_value = kernel.value
        else:
            kernel_value = self.decoder.infer_tensor(
                kernel, use_diff_inputs=False)
        # Split outputs get a per-consumer weight name (presumably because one
        # Split can feed several convolutions).
        if kernel.layer_type == 'Split':
            kernel_weight_name = "{}_{}_kernel".format(node.name, kernel.name)
        else:
            kernel_weight_name = kernel.name.replace('/', '_')

        self.params[kernel_weight_name] = numpy.transpose(kernel_value,
                                                          (3, 2, 0, 1))
        self.paddle_graph.add_layer(
            kernel="paddle.static.nn.create_parameter",
            inputs={},
            outputs=[kernel_weight_name],
            shape=self.params[kernel_weight_name].shape,
            dtype=string(str(self.params[kernel_weight_name].dtype)),
            name=string(kernel_weight_name))

        input_name = input.name
        if data_format == "NHWC":
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name("conv2d", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        if c == -1:
            # Pin the unknown channel dim to the filter's input channels;
            # 0 keeps the corresponding input dimension.
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[input_name],
                shape=[0, k_size[2], 0, 0])

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d",
            inputs={"x": input_name,
                    "weight": kernel_weight_name},
            outputs=[node.name],
            bias=None,
            stride=strides[2:4],
            dilation=dilations[2:4],
            padding=string(pad_mode))

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
465

S
SunAhong1993 已提交
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
    def Conv3D(self, node):
        """Lower TF ``Conv3D`` to ``paddle.nn.functional.conv3d``.

        The filter must be known at conversion time: either a Const, or a
        tensor evaluated eagerly with ``decoder.infer_tensor``.  It is
        transposed with axes ``(4, 3, 0, 1, 2)`` (TF
        ``[kd, kh, kw, in, out]`` to ``[out, in, kd, kh, kw]``), stored in
        ``self.params`` and declared with ``create_parameter``.  NDHWC inputs
        are transposed to NCDHW around the convolution.  When the input
        channel count is unknown (-1), the input is reshaped so that dim 1
        equals the filter's input-channel count (``k_size[3]``).
        """
        input = self.graph.get_input_node(node, 0)
        kernel = self.graph.get_input_node(node, 1)

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        if data_format == "NDHWC":
            n, d, h, w, c = input.out_shapes[0]
        else:
            n, c, d, h, w = input.out_shapes[0]

        if kernel.layer_type == 'Const':
            kernel_value = kernel.value
        else:
            kernel_value = self.decoder.infer_tensor(
                kernel, use_diff_inputs=False)
        # Split outputs get a per-consumer weight name (presumably because one
        # Split can feed several convolutions).
        if kernel.layer_type == 'Split':
            kernel_weight_name = "{}_{}_kernel".format(node.name, kernel.name)
        else:
            kernel_weight_name = kernel.name.replace('/', '_')

        self.params[kernel_weight_name] = numpy.transpose(kernel_value,
                                                          (4, 3, 0, 1, 2))
        self.paddle_graph.add_layer(
            kernel="paddle.static.nn.create_parameter",
            inputs={},
            outputs=[kernel_weight_name],
            shape=self.params[kernel_weight_name].shape,
            dtype=string(str(self.params[kernel_weight_name].dtype)),
            name=string(kernel_weight_name))

        input_name = input.name
        if data_format == "NDHWC":
            strides = [strides[i] for i in [0, 4, 1, 2, 3]]
            dilations = [dilations[i] for i in [0, 4, 1, 2, 3]]
            transpose_name = gen_name("conv3d", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 4, 1, 2, 3])
            input_name = transpose_name

        if c == -1:
            # The 3-D filter's input-channel count is at index 3 (index 2 is
            # the kernel width), matching the (4, 3, 0, 1, 2) transpose above.
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[input_name],
                shape=[0, k_size[3], 0, 0, 0])

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv3d",
            inputs={"x": input_name,
                    "weight": kernel_weight_name},
            outputs=[node.name],
            bias=None,
            stride=strides[2:5],
            dilation=dilations[2:5],
            padding=string(pad_mode))

        if data_format == "NDHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 4, 1])
J
jiangjiajun 已提交
537

J
jiangjiajun 已提交
538
    def BiasAdd(self, node):
        """Add the bias tensor (input 1) to the value (input 0) via ``paddle.add``."""
        value, bias = [self.graph.get_node(node.layer.input[i]) for i in (0, 1)]
        self.paddle_graph.add_layer(
            kernel="paddle.add",
            inputs={"x": value.name, "y": bias.name},
            outputs=[node.name])
J
jiangjiajun 已提交
546 547

    def FusedBatchNorm(self, node):
        """Lower TF ``FusedBatchNorm`` to ``paddle.nn.functional.batch_norm``.

        Inputs 1-4 (scale, offset, moving mean, moving variance) must all be
        Const nodes.  NHWC inputs are transposed to NCHW around the op, and
        the TF ``epsilon`` attribute is forwarded.
        """
        source, gamma, beta, moving_mean, moving_var = [
            self.graph.get_node(node.layer.input[i]) for i in range(5)
        ]
        data_format = node.get_attr("data_format").decode()
        for stat in (gamma, beta, moving_mean, moving_var):
            assert stat.layer_type == "Const"
        is_nhwc = data_format == "NHWC"

        bn_input = source.name
        if is_nhwc:
            bn_input = gen_name("batch_norm", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": source.name},
                outputs=[bn_input],
                perm=[0, 3, 1, 2])

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.batch_norm",
            inputs={
                "x": bn_input,
                "running_mean": moving_mean.name,
                "running_var": moving_var.name,
                "weight": gamma.name,
                "bias": beta.name
            },
            outputs=[node.name],
            epsilon=node.get_attr("epsilon"))

        if is_nhwc:
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
588

S
SunAhong1993 已提交
589 590
    def FusedBatchNormV3(self, node):
        """Lower ``FusedBatchNormV3`` exactly like ``FusedBatchNorm``."""
        return self.FusedBatchNorm(node)
J
jiangjiajun 已提交
591 592 593 594 595 596 597 598

    def Mean(self, node):
        """Emit ``paddle.mean`` over the Const reduction axes of TF ``Mean``."""
        source = self.graph.get_node(node.layer.input[0])
        reduce_idx = self.graph.get_node(node.layer.input[1])
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        axes = reduce_idx.value.tolist()
        self.paddle_graph.add_layer(
            kernel="paddle.mean",
            inputs={"x": source.name},
            outputs=[node.name],
            axis=axes,
            keepdim=node.get_attr("keep_dims"))
J
jiangjiajun 已提交
605 606

    def Reshape(self, node):
        """Emit ``paddle.reshape`` for TF ``Reshape``.

        A Const target shape is baked in as an attribute.  A dynamic target
        shape is wired in as a tensor input.  If TF statically knows any
        output dim (> 0), a second reshape pins those dims, and unknown
        (negative) dims become 0, which ``paddle.reshape`` documents as
        "copy the corresponding input dim".
        """
        source = self.graph.get_input_node(node, 0)
        param = self.graph.get_input_node(node, 1)
        input_name = source.name

        if param.layer_type == "Const":
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[node.name],
                shape=param.value.tolist())
            return

        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": input_name,
                    "shape": param.name},
            outputs=[node.name])
        static_shape = numpy.array(node.out_shapes[0])
        if not (static_shape > 0).any():
            return
        static_shape[static_shape < 0] = 0
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": node.name},
            outputs=[node.name],
            shape=static_shape.tolist())

    def Pad(self, node):
        """Emit ``paddle.nn.functional.pad`` from a Const TF paddings tensor.

        The paddings value is flattened row-major.  For TF's ``[ndim, 2]``
        paddings this yields ``[before0, after0, before1, after1, ...]``.
        An optional third input supplies the fill value; it must also be
        Const, and the fill value defaults to 0.
        """
        source = self.graph.get_input_node(node, 0)
        paddings_node = self.graph.get_input_node(node, 1)
        assert paddings_node.layer_type == "Const", "Padding should be Const"
        pad = paddings_node.value.flatten().tolist()

        fill_value = 0
        if len(node.layer.input) > 2:
            value_node = self.graph.get_input_node(node, 2)
            assert value_node.layer_type == "Const", "Padding should be Const"
            fill_value = value_node.value

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.pad",
            inputs={"x": source.name},
            outputs=[node.name],
            pad=pad,
            value=fill_value)
652

S
SunAhong1993 已提交
653
    def MirrorPad(self, node):
        """Lower ``MirrorPad`` through ``Pad``.

        NOTE(review): ``Pad`` never reads a padding-mode attribute, so this
        emits constant padding.  If TF's MirrorPad ``mode`` (REFLECT /
        SYMMETRIC) matters for a model, this lowering needs revisiting.
        """
        return self.Pad(node)
655

S
SunAhong1993 已提交
656 657
    def PadV2(self, node):
        """Lower ``PadV2`` (explicit fill value as input 2) through ``Pad``."""
        return self.Pad(node)
J
jiangjiajun 已提交
658 659

    def Squeeze(self, node):
        """Emit ``paddle.squeeze`` over TF's ``squeeze_dims`` attribute."""
        source = self.graph.get_input_node(node, 0)
        axes = node.get_attr('squeeze_dims')
        self.paddle_graph.add_layer(
            kernel="paddle.squeeze",
            inputs={"x": source.name},
            outputs=[node.name],
            axis=axes)
J
jiangjiajun 已提交
667 668

    def Shape(self, node):
        """Emit ``paddle.shape`` of the node's first input."""
        source_name = self.graph.get_input_node(node, 0).name
        self.paddle_graph.add_layer(
            kernel="paddle.shape",
            inputs={"input": source_name},
            outputs=[node.name])

S
SunAhong1993 已提交
676 677 678 679
    def Size(self, node):
        """Compute the element count as ``paddle.prod(paddle.shape(x))``.

        Both layers write to ``node.name``; the product overwrites the shape.
        """
        source_name = self.graph.get_input_node(node, 0).name
        steps = (
            ("paddle.shape", {"input": source_name}),
            ("paddle.prod", {"x": node.name}),
        )
        for kernel, inputs in steps:
            self.paddle_graph.add_layer(
                kernel=kernel, inputs=inputs, outputs=[node.name])

S
SunAhong1993 已提交
686 687 688
    def Ceil(self, node):
        """Convert TF Ceil to the element-wise ``paddle.ceil``."""
        x_name = self.graph.get_input_node(node, 0).name
        self.paddle_graph.add_layer(
            kernel="paddle.ceil",
            inputs={"x": x_name},
            outputs=[node.name])

J
jiangjiajun 已提交
692
    def ArgMax(self, node):
        """Convert TF ArgMax to ``paddle.argmax``.

        The dimension (input 1) must be a Const node; its value is passed
        through as ``axis``. NOTE(review): TF's ``output_type`` attribute is
        not read here.
        """
        source = self.graph.get_input_node(node, 0)
        axis_node = self.graph.get_input_node(node, 1)
        assert axis_node.layer_type == "Const", "ArgMax only support Const parameter"
        self.paddle_graph.add_layer(
            kernel="paddle.argmax",
            inputs={"x": source.name},
            outputs=[node.name],
            axis=axis_node.value)
702

S
SunAhong1993 已提交
703 704 705 706 707 708 709 710 711 712 713 714
    def TopKV2(self, node):
        """Convert TF TopKV2 to ``paddle.topk``.

        ``k`` (input 1) must be a Const node; its value and the ``sorted``
        attribute are forwarded unchanged.
        NOTE(review): only a single output name is registered for the layer;
        confirm how downstream nodes address the values/indices outputs.
        """
        input = self.graph.get_input_node(node, 0)
        k = self.graph.get_input_node(node, 1)
        # The message previously named "ArgMax" (copy-paste), pointing users
        # at the wrong operator when this assertion fired.
        assert k.layer_type == "Const", "TopKV2 only support Const parameter"
        k = k.value
        sort = node.get_attr('sorted')
        self.paddle_graph.add_layer(
            kernel="paddle.topk",
            inputs={"x": input.name},
            outputs=[node.name],
            k=k,
            sorted=sort)
J
jiangjiajun 已提交
715 716

    def MatMul(self, node):
        """Convert TF MatMul (and the BatchMatMul variants) to ``paddle.matmul``.

        The transpose flags are read from ``transpose_a``/``transpose_b``;
        when ``get_attr`` returns None for those, ``adj_x``/``adj_y`` (the
        BatchMatMul attribute names) are used instead.
        """
        lhs = self.graph.get_input_node(node, 0)
        rhs = self.graph.get_input_node(node, 1)

        def _flag(primary, fallback):
            value = node.get_attr(primary)
            return node.get_attr(fallback) if value is None else value

        self.paddle_graph.add_layer(
            kernel="paddle.matmul",
            inputs={"x": lhs.name,
                    "y": rhs.name},
            outputs=[node.name],
            transpose_x=_flag('transpose_a', 'adj_x'),
            transpose_y=_flag('transpose_b', 'adj_y'))

    def BatchMatMul(self, node):
        """Convert TF BatchMatMul via the shared MatMul converter."""
        converter = self.MatMul
        return converter(node)

    def BatchMatMulV2(self, node):
        """Convert TF BatchMatMulV2 via the shared MatMul converter."""
        converter = self.MatMul
        return converter(node)
J
jiangjiajun@baidu.com 已提交
738

J
jiangjiajun 已提交
739
    def DepthwiseConv2dNative(self, node):
        """Convert TF DepthwiseConv2dNative to a grouped ``paddle.nn.functional.conv2d``.

        The TF filter (input 1, must be Const) has layout
        [H, W, in_channels, channel_multiplier]. A grouped Paddle conv with
        ``groups = in_channels`` expects a weight of shape
        [in_channels * channel_multiplier, 1, H, W], so the filter is
        transposed with perm (2, 3, 0, 1) and then reshaped to that layout.
        NHWC inputs are transposed to NCHW around the convolution.
        """
        input = self.graph.get_node(node.layer.input[0])
        kernel = self.graph.get_node(node.layer.input[1])
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()

        # One group per input channel. The previous k_size[3] * in_shape[1]
        # only matched this for channel_multiplier == 1, and gave -1 when the
        # input's channel dim was unknown.
        in_channels = k_size[2]

        if len(kernel.outputs) == 1:
            # Sole consumer: rewrite the stored parameter in place.
            weight = numpy.transpose(self.params[kernel.name], (2, 3, 0, 1))
            channels, multiplier, k_h, k_w = weight.shape
            # [C, M, H, W] -> [C*M, 1, H, W]; output channel c*M + m matches
            # TF's depthwise channel ordering.
            self.params[kernel.name] = weight.reshape(
                (channels * multiplier, 1, k_h, k_w))
            layer = self.paddle_graph.layers[self.params_output2id[kernel.name]]
            layer.attrs["shape"] = self.params[kernel.name].shape
            weight_name = kernel.name
        else:
            # Shared kernel: convert into a separate variable so the other
            # consumers of kernel.name still see the original TF layout.
            weight_name = gen_name('depthwise_conv2d', 'weight')
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": kernel.name},
                outputs=[weight_name],
                perm=[2, 3, 0, 1])
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": weight_name},
                outputs=[weight_name],
                shape=[-1, 1, k_size[0], k_size[1]])

        input_name = input.name
        if data_format == "NHWC":
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name('depthwise_conv2d', 'transpose')
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d",
            inputs={"x": input_name,
                    "weight": weight_name},
            outputs=[node.name],
            stride=strides[2:4],
            dilation=dilations[2:4],
            groups=in_channels,
            padding=string(pad_mode),
            bias=None)

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
J
jiangjiajun 已提交
793 794

    def AvgPool(self, node):
        """Convert TF AvgPool to ``paddle.nn.functional.avg_pool2d``.

        The Paddle op is fed NCHW data: an NHWC input is transposed to NCHW
        first (with ``ksize``/``strides`` permuted the same way) and the
        pooled result is transposed back to NHWC afterwards.
        """
        source = self.graph.get_input_node(node, 0)

        ksize = node.get_attr("ksize")
        strides = node.get_attr("strides")
        layout = node.get_attr("data_format").decode()
        padding = node.get_attr("padding").decode()

        is_nhwc = layout == "NHWC"
        pool_input = source.name
        if is_nhwc:
            pool_input = gen_name("avg_pool", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": source.name},
                outputs=[pool_input],
                perm=[0, 3, 1, 2])
            order = (0, 3, 1, 2)
            strides = [strides[axis] for axis in order]
            ksize = [ksize[axis] for axis in order]

        # TODO(syf): The op has diff.
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.avg_pool2d",
            inputs={"x": pool_input},
            outputs=[node.name],
            kernel_size=ksize[2:4],
            stride=strides[2:4],
            padding=string(padding))

        if not is_nhwc:
            return
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": node.name},
            outputs=[node.name],
            perm=[0, 2, 3, 1])
J
jiangjiajun 已提交
830 831

    def Pack(self, node):
        """Convert TF Pack (``tf.stack``) to ``paddle.stack``.

        Every input is stacked along the ``axis`` attribute; when the node's
        recorded output is 1-D the result is additionally reshaped to ``[-1]``.
        """
        names = [
            self.graph.get_input_node(node, idx).name
            for idx in range(len(node.inputs))
        ]
        self.paddle_graph.add_layer(
            kernel="paddle.stack",
            inputs={"x": names},
            outputs=[node.name],
            axis=node.get_attr("axis"))
        if len(node.out_shapes[0]) != 1:
            return
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": node.name},
            outputs=[node.name],
            shape=[-1])

    def Unpack(self, node):
        """Convert TF Unpack (``tf.unstack``) to ``paddle.unstack``.

        Outputs are named ``<layer_name>_p<i>``; when ``num`` is 1 the single
        output name is written as ``[<layer_name>]`` instead. A 1-D input whose
        length equals ``num`` is first unsqueezed at axis 0 and then unstacked
        along axis 1; any other 1-D case raises.
        """
        source = self.graph.get_input_node(node, 0)
        axis = node.get_attr("axis")
        num = node.get_attr("num")
        source_shape = source.out_shapes[0]
        unstack_input = source.name

        if len(source_shape) == 1:
            if not (source_shape[0] > 0 and num == source_shape[0]):
                raise Exception("Unexpected situation happend in Unpack OP")
            self.paddle_graph.add_layer(
                kernel="paddle.unsqueeze",
                inputs={"x": source.name},
                outputs=[node.name],
                axis=[0])
            unstack_input = node.name
            axis = 1

        if num == 1:
            output_names = ["[{}]".format(node.layer_name)]
        else:
            output_names = [
                "{}_p{}".format(node.layer_name, i) for i in range(num)
            ]

        self.paddle_graph.add_layer(
            kernel="paddle.unstack",
            inputs={"x": unstack_input},
            outputs=output_names,
            axis=axis,
            num=num)
J
jiangjiajun 已提交
877

J
jiangjiajun 已提交
878
    def ConcatV2(self, node):
        """Convert TF ConcatV2 to ``paddle.concat``.

        The axis is the last input and must be Const; every preceding input is
        a tensor to concatenate. A negative axis is normalised with the rank
        of the first tensor's recorded output shape.
        """
        tensors = [
            self.graph.get_input_node(node, idx)
            for idx in range(len(node.inputs) - 1)
        ]
        axis_node = self.graph.get_input_node(node, -1)
        assert axis_node.layer_type == "Const", "axis for ConcatV2 must be type Const"
        axis = axis_node.value
        if axis < 0:
            axis += len(tensors[0].out_shapes[0])

        self.paddle_graph.add_layer(
            kernel="paddle.concat",
            inputs={"x": [t.name for t in tensors]},
            outputs=[node.name],
            axis=axis)
894

S
SunAhong1993 已提交
895 896 897 898 899 900 901 902 903
    def Concat(self, node):
        """Convert TF Concat (v1) to ``paddle.concat``.

        Unlike ConcatV2, the axis is the FIRST input (must be Const) and the
        remaining inputs are the tensors to concatenate. A negative axis is
        normalised with the rank of the first tensor's recorded output shape.
        """
        inputs_list = [
            self.graph.get_input_node(node, i)
            for i in range(1, len(node.inputs))
        ]
        axis = self.graph.get_input_node(node, 0)
        # Message previously named "ConcatV2" (copy-paste from ConcatV2).
        assert axis.layer_type == "Const", "axis for Concat must be type Const"
        axis = axis.value
        if axis < 0:
            axis += len(inputs_list[0].out_shapes[0])

        input_names = [i.name for i in inputs_list]
        self.paddle_graph.add_layer(
            kernel="paddle.concat",
            inputs={"x": input_names},
            outputs=[node.name],
            axis=axis)
911

S
SunAhong1993 已提交
912 913 914 915 916 917 918 919 920 921
    def AddN(self, node):
        """Convert TF AddN (element-wise sum of all inputs) to ``paddle.add_n``.

        Every input is an addend. The previous ``range(len(node.inputs) - 1)``
        (copied from ConcatV2, whose last input is the axis) silently dropped
        the final addend.
        """
        input_names = [
            self.graph.get_input_node(node, i).name
            for i in range(len(node.inputs))
        ]
        self.paddle_graph.add_layer(
            kernel="paddle.add_n",
            inputs={"inputs": input_names},
            outputs=[node.name])
J
jiangjiajun 已提交
922

J
jiangjiajun 已提交
923
    def StridedSlice(self, node):
        """Convert TF StridedSlice to ``paddle.slice`` plus optional unsqueeze/squeeze.

        Supported subset (asserted): every stride equals 1 and
        ``ellipsis_mask`` is 0. ``begin``/``end``/``strides`` come from Const
        inputs or are evaluated via ``self.decoder.infer_tensor``. Bits of
        ``new_axis_mask`` are applied afterwards with ``paddle.unsqueeze``;
        bits of ``shrink_axis_mask`` with ``paddle.squeeze`` (skipped when the
        result would be at most 1-D). A bool input is sliced as int32 and the
        result is cast back to bool.
        """
        input = self.graph.get_input_node(node, 0)
        begin = self.graph.get_input_node(node, 1)
        end = self.graph.get_input_node(node, 2)
        strides = self.graph.get_input_node(node, 3)

        def _to_list(param):
            # Normalise Const values and inferred tensors alike to Python
            # lists. The padding below relies on list concatenation; on the
            # numpy array returned by infer_tensor, ``+ [0] * k`` broadcast-adds
            # instead, silently leaving begin/end too short.
            if param.layer_type == "Const":
                return param.value.tolist()
            return self.decoder.infer_tensor(param).tolist()

        strides = _to_list(strides)
        begin = _to_list(begin)
        end = _to_list(end)

        assert len(set(strides)) == 1 and strides[
            0] == 1, "Only support strides be 1 in StridedSlice OP"

        rank = len(input.out_shapes[0])
        if len(begin) < rank:
            begin = begin + [0] * (rank - len(begin))
        if len(end) < rank:
            end = end + [0] * (rank - len(end))
        # An end of 0 (including the padding just added) is read as "to the
        # end of the dimension". NOTE(review): presumably this targets
        # encodings such as x[-1] (begin=-1, end=0, shrink bit set); it also
        # turns an explicit empty TF slice into a full one -- confirm.
        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        begin_mask = node.get_attr('begin_mask')
        end_mask = node.get_attr('end_mask')
        ellipsis_mask = node.get_attr('ellipsis_mask')
        new_axis_mask = node.get_attr('new_axis_mask')
        shrink_axis_mask = node.get_attr('shrink_axis_mask')

        assert ellipsis_mask == 0, "(OP:{} Name:{})Only support ellipsis_mask be 0[now: {}] in StridedSlice OP".format(
            node.layer_type, node.layer.name, ellipsis_mask)

        # TODO codes without validation
        # Use it carefully
        new_begin = list()
        new_end = list()
        new_axes = list()
        shrink_axes = list()
        for i, item in enumerate(begin):
            if (new_axis_mask >> i) & 1:
                # Inserted axes take no slice entry; they are unsqueezed below.
                new_axes.append(i)
                continue
            if (shrink_axis_mask >> i) & 1:
                shrink_axes.append(i)
            new_begin.append(0 if (begin_mask >> i) & 1 else item)
            new_end.append(999999 if (end_mask >> i) & 1 else end[i])

        slice_input = input.name
        if input.dtype == "bool":
            # Cast into a fresh variable instead of overwriting input.name,
            # which may also feed other nodes that expect the bool tensor.
            slice_input = gen_name("strided_slice", "cast")
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": input.name},
                outputs=[slice_input],
                dtype=string("int32"))

        self.paddle_graph.add_layer(
            kernel="paddle.slice",
            inputs={"input": slice_input},
            outputs=[node.name],
            axes=list(range(len(new_begin))),
            starts=new_begin,
            ends=new_end)

        if input.dtype == "bool":
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": node.name},
                outputs=[node.name],
                dtype=string("bool"))

        if len(new_axes) > 0:
            self.paddle_graph.add_layer(
                kernel="paddle.unsqueeze",
                inputs={"x": node.name},
                outputs=[node.name],
                axis=new_axes)
        if len(shrink_axes) > 0:
            if len(input.out_shapes[0]) + len(new_axes) <= 1:
                pass
            else:
                self.paddle_graph.add_layer(
                    kernel="paddle.squeeze",
                    inputs={"x": node.name},
                    outputs=[node.name],
                    axis=shrink_axes)
1027

S
SunAhong1993 已提交
1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040
    def Prod(self, node):
        """Convert TF Prod (reduce product) to ``paddle.prod``.

        The reduction indices (input 1) must be Const; ``keep_dims`` maps to
        ``keepdim``. The output is registered under ``node.layer_name``.
        """
        source = self.graph.get_input_node(node, 0)
        indices = self.graph.get_input_node(node, 1)
        assert indices.layer_type == "Const"
        self.paddle_graph.add_layer(
            kernel="paddle.prod",
            inputs={"x": source.name},
            outputs=[node.layer_name],
            keepdim=node.get_attr('keep_dims'),
            axis=indices.value)
J
jiangjiajun 已提交
1041 1042

    def Split(self, node):
        """Convert TF Split (equal parts) to ``paddle.split``.

        Input 0 is the split dimension (must be Const), input 1 the tensor;
        the ``num_split`` outputs are named ``<layer_name>_p<i>``.
        """
        axis_node = self.graph.get_input_node(node, 0)
        source = self.graph.get_input_node(node, 1)
        assert axis_node.layer_type == "Const"
        parts = node.get_attr('num_split')
        out_names = ["{}_p{}".format(node.layer_name, i) for i in range(parts)]
        self.paddle_graph.add_layer(
            kernel="paddle.split",
            inputs={"x": source.name},
            outputs=out_names,
            num_or_sections=parts,
            axis=axis_node.value)
1057

S
SunAhong1993 已提交
1058 1059 1060 1061 1062 1063 1064 1065
    def SplitV(self, node):
        """Convert TF SplitV (explicit section sizes) to ``paddle.split``.

        Input 1 (``size_splits``) and input 2 (the split dimension) must both
        be Const; one output ``<layer_name>_p<i>`` is created per section.
        """
        source = self.graph.get_input_node(node, 0)
        sizes_node = self.graph.get_input_node(node, 1)
        assert sizes_node.layer_type == "Const", "size_splits of SplitV OP should be Const"
        sections = sizes_node.value.tolist()
        axis_node = self.graph.get_input_node(node, 2)
        assert axis_node.layer_type == "Const", "dim of SplitV OP should be Const"
        axis = axis_node.value

        out_names = [
            "{}_p{}".format(node.layer_name, i) for i in range(len(sections))
        ]
        self.paddle_graph.add_layer(
            kernel="paddle.split",
            inputs={"x": source.name},
            outputs=out_names,
            num_or_sections=sections,
            axis=axis)
1076 1077

    def Slice(self, node):
        """Convert TF Slice to ``paddle.crop``.

        ``begin`` (input 1) becomes the ``offsets`` attribute, taken from the
        Const value or evaluated via ``infer_tensor``. ``size`` (input 2)
        becomes ``shape``: a list attribute when Const, otherwise the size
        tensor reshaped to its recorded output shape and wired in as an input.
        """
        source = self.graph.get_input_node(node, 0)
        begin = self.graph.get_input_node(node, 1)
        size = self.graph.get_input_node(node, 2)

        layer_inputs = {"x": source.name}
        if begin.layer_type == "Const":
            offsets = begin.value.tolist()
        else:
            offsets = self.decoder.infer_tensor(
                begin, use_diff_inputs=False).tolist()
        layer_attrs = {'offsets': offsets}

        if size.layer_type == "Const":
            layer_attrs['shape'] = size.value.tolist()
        else:
            size_shape = size.out_shapes[0]
            size_name = gen_name("slice", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": size.name},
                outputs=[size_name],
                shape=size_shape)
            layer_inputs['shape'] = size_name

        self.paddle_graph.add_layer(
            kernel="paddle.crop",
            inputs=layer_inputs,
            outputs=[node.name],
            **layer_attrs)
J
jiangjiajun 已提交
1105 1106

    def ResizeNearestNeighbor(self, node):
        """Convert TF ResizeNearestNeighbor to nearest-neighbour interpolation.

        NOTE(review): only ``align_corners`` is read from the node; other TF
        resize attributes (e.g. ``half_pixel_centers``) are not consulted.
        """
        self._resize_nhwc_image(node, "nearest", "resize_nearest")

    def ResizeBilinear(self, node):
        """Convert TF ResizeBilinear to bilinear interpolation.

        NOTE(review): only ``align_corners`` is read from the node; other TF
        resize attributes (e.g. ``half_pixel_centers``) are not consulted.
        """
        self._resize_nhwc_image(node, "bilinear", "resize_bilinear")

    def _resize_nhwc_image(self, node, mode, name_prefix):
        """Emit ``paddle.nn.functional.interpolate`` for a TF image-resize node.

        Args:
            node: the TF resize node; input 0 is the image (treated as NHWC),
                input 1 the target spatial size.
            mode: Paddle interpolation mode ("nearest" or "bilinear").
            name_prefix: prefix handed to ``gen_name`` for temporaries.

        A Const size is inlined as the ``size`` attribute; otherwise the size
        tensor is reshaped to its recorded output shape and wired in as the
        ``size`` input. The image is transposed NHWC -> NCHW before the
        interpolation and the result transposed back to NHWC.
        """
        image = self.graph.get_input_node(node, 0)
        size_node = self.graph.get_input_node(node, 1)
        layer_inputs = {"x": image.name}
        layer_attrs = {
            "align_corners": node.get_attr("align_corners"),
            "mode": string(mode),
            "align_mode": 1
        }

        if size_node.layer_type == "Const":
            layer_attrs["size"] = size_node.value.tolist()
        else:
            size_shape = size_node.out_shapes[0]
            size_name = gen_name(name_prefix, "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": size_node.name},
                outputs=[size_name],
                shape=size_shape)
            layer_inputs["size"] = size_name

        # The temporary keeps its historical "reshape" label even though it
        # holds the NCHW-transposed image.
        nchw_name = gen_name(name_prefix, "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": image.name},
            outputs=[nchw_name],
            perm=[0, 3, 1, 2])
        layer_inputs["x"] = nchw_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.interpolate",
            inputs=layer_inputs,
            outputs=[node.name],
            **layer_attrs)

        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": node.name},
            outputs=[node.name],
            perm=[0, 2, 3, 1])

    def Cast(self, node):
        """Convert TF Cast to ``paddle.cast`` using the node's target dtype."""
        source = self.graph.get_input_node(node, 0)
        target_dtype = node.dtype
        self.paddle_graph.add_layer(
            kernel="paddle.cast",
            inputs={"x": source.name},
            outputs=[node.name],
            dtype=string(target_dtype))

    def Sum(self, node):
        """Convert TF Sum (reduce sum) to ``paddle.sum``.

        The reduction indices (input 1) must be Const and are passed as a
        list ``axis``; ``keep_dims`` maps to ``keepdim``.
        """
        source = self.graph.get_input_node(node, 0)
        axes_node = self.graph.get_input_node(node, 1)
        assert axes_node.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep = node.get_attr("keep_dims")
        self.paddle_graph.add_layer(
            kernel="paddle.sum",
            inputs={"x": source.name},
            outputs=[node.name],
            axis=axes_node.value.tolist(),
            keepdim=keep)
J
jiangjiajun 已提交
1220 1221

    def Max(self, node):
        """Convert TF Max (reduce max) to ``paddle.max``.

        The reduction indices (input 1) must be Const and are passed as a
        list ``axis``; ``keep_dims`` maps to ``keepdim``.
        """
        source = self.graph.get_input_node(node, 0)
        axes_node = self.graph.get_input_node(node, 1)
        assert axes_node.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep = node.get_attr("keep_dims")
        self.paddle_graph.add_layer(
            kernel="paddle.max",
            inputs={"x": source.name},
            outputs=[node.name],
            axis=axes_node.value.tolist(),
            keepdim=keep)
1233

J
jiangjiajun 已提交
1234
    def RandomUniform(self, node):
        """Convert TF RandomUniform to ``paddle.uniform`` over [0.0, 0.9999].

        A Const shape is inlined as the ``shape`` attribute; otherwise the
        shape tensor is wired in as the ``shape`` input.
        NOTE(review): 0.9999 stands in for the open upper bound of [0, 1); no
        node attributes (dtype, seeds) are read here.
        """
        shape_node = self.graph.get_input_node(node, 0)
        layer_inputs = {}
        layer_attrs = {}
        if shape_node.layer_type == "Const":
            layer_attrs["shape"] = shape_node.value.tolist()
        else:
            layer_inputs['shape'] = shape_node.name
        layer_attrs["min"] = 0.0
        layer_attrs["max"] = 0.9999
        self.paddle_graph.add_layer(
            kernel="paddle.uniform",
            inputs=layer_inputs,
            outputs=[node.name],
            **layer_attrs)
1252 1253

    def Conv2DBackpropInput(self, node):
        """Convert TF Conv2DBackpropInput to ``paddle.nn.functional.conv2d_transpose``.

        Inputs: 0 = ``input_sizes`` (the requested output shape; Const or
        evaluated via ``infer_tensor``), 1 = filter (must be Const),
        2 = the tensor being up-sampled. The filter is registered as a new
        parameter ``<node.name>.weight`` transposed with perm (3, 2, 0, 1).
        NHWC data is transposed to NCHW around the Paddle op.
        """
        out_shape = self.graph.get_input_node(node, 0)
        kernel = self.graph.get_input_node(node, 1)
        input = self.graph.get_input_node(node, 2)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        if out_shape.layer_type == "Const":
            out_shape = out_shape.value.tolist()
        else:
            # Normalise to a plain list, as in the Const branch, so the sliced
            # output_size attribute is a list rather than an array.
            out_shape = numpy.asarray(
                self.decoder.infer_tensor(
                    out_shape, out_shape=node.out_shapes[0])).tolist()

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()

        kernel_name = node.name + ".weight"
        weight_var = "{}_{}".format(node.name, kernel_name).replace(".", "_")
        self.params[kernel_name] = numpy.transpose(kernel.value, (3, 2, 0, 1))

        input_name = input.name
        if data_format == "NHWC":
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name("conv2dbackpropinput", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name
            # input_sizes follows data_format: spatial dims at [1:3] for NHWC.
            output_size = out_shape[1:3]
        else:
            # ...and at [2:4] for NCHW. Using [1:3] here picked (C, H).
            output_size = out_shape[2:4]

        self.paddle_graph.add_layer(
            kernel="paddle.static.create_parameter",
            inputs={},
            outputs=[weight_var],
            dtype=string(str(self.params[kernel_name].dtype)),
            shape=self.params[kernel_name].shape,
            name=string(kernel_name))

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d_transpose",
            inputs={
                "x": input_name,
                "weight": weight_var
            },
            outputs=[node.name],
            bias=None,
            stride=strides[2:4],
            dilation=dilations[2:4],
            padding=string(pad_mode),
            output_size=output_size)

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
1324

J
jiangjiajun 已提交
1325 1326
    def Tile(self, node):
        """Convert TF Tile to ``paddle.tile``.

        A Const ``multiples`` input is inlined as the ``repeat_times``
        attribute. Otherwise the multiples tensor is wired in as an input and
        the result is reshaped to the node's recorded output shape.
        """
        input = self.graph.get_node(node.layer.input[0])
        repeat_times = self.graph.get_node(node.layer.input[1])
        inputs = {"x": input.name}
        attr = dict()
        # Decide once from the node itself. The old post-check re-tested the
        # rebound variable with isinstance(list) and then read .layer_type,
        # which raised AttributeError for a scalar (0-d) Const multiples.
        is_const = repeat_times.layer_type == "Const"
        if is_const:
            attr["repeat_times"] = repeat_times.value.tolist()
        else:
            inputs["repeat_times"] = repeat_times.name

        self.paddle_graph.add_layer(
            kernel="paddle.tile", inputs=inputs, outputs=[node.name], **attr)

        if not is_const:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
J
jiangjiajun 已提交
1346

J
jiangjiajun 已提交
1347 1348 1349 1350 1351 1352
    def Range(self, node):
        """Map TF ``Range`` onto ``paddle.arange``.

        Each of start/limit/delta is folded into an attribute when Const and
        wired in as a tensor input otherwise. The output dtype defaults to
        int32; any float-typed operand overrides it, the last one (in
        start, limit, delta order) winning. If any operand is non-Const the
        result is reshaped to the node's recorded output shape.
        """
        start = self.graph.get_node(node.layer.input[0])
        limit = self.graph.get_node(node.layer.input[1])
        delta = self.graph.get_node(node.layer.input[2])
        operands = ((start, "start"), (limit, "end"), (delta, "step"))

        inputs = {}
        attr = {}
        out_dtype = 'int32'
        for operand, arg_name in operands:
            if operand.dtype.startswith('float'):
                out_dtype = operand.dtype
            if operand.layer_type == "Const":
                attr[arg_name] = operand.value
            else:
                inputs[arg_name] = operand.name
        node.set_dtype(out_dtype)
        attr["dtype"] = string(node.dtype)

        self.paddle_graph.add_layer(
            kernel="paddle.arange", inputs=inputs, outputs=[node.name], **attr)

        all_const = all(operand.layer_type == "Const"
                        for operand, _ in operands)
        if not all_const:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
J
jiangjiajun 已提交
1386 1387

    def SquaredDifference(self, node):
        """Map TF ``SquaredDifference`` as ``(x - y) * (x - y)``.

        Emits ``paddle.subtract`` into ``node.name``, then ``paddle.multiply``
        of that result with itself, recording operand shapes on each emitted
        layer's ``input_shapes``.
        """
        lhs = self.graph.get_input_node(node, 0)
        rhs = self.graph.get_input_node(node, 1)
        sub_shapes = {"x": lhs.out_shapes[0], "y": rhs.out_shapes[0]}
        # TODO(syf)
        sub_id = self.paddle_graph.add_layer(
            "paddle.subtract",
            inputs={"x": lhs.name, "y": rhs.name},
            outputs=[node.name])
        self.paddle_graph.layers[sub_id].input_shapes = sub_shapes

        diff_shape = node.out_shapes[0]
        mul_id = self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={"x": node.name, "y": node.name},
            outputs=[node.name])
        self.paddle_graph.layers[mul_id].input_shapes = {
            "x": diff_shape,
            "y": diff_shape
        }
J
jiangjiajun 已提交
1410 1411

    def OneHot(self, node):
        """Map TF ``OneHot`` onto ``paddle.nn.functional.one_hot``.

        ``depth``, ``on_value`` and ``off_value`` must all be Const, and only
        the default encoding (on_value == 1, off_value == 0, within 1e-6) is
        supported, because the emitted call receives only ``num_classes``.
        NOTE: TF's ``axis`` attribute is not consulted here.

        Raises:
            AssertionError: if any of the above preconditions is violated.
        """
        src = self.graph.get_input_node(node, 0)
        depth = self.graph.get_input_node(node, 1)
        on_value = self.graph.get_input_node(node, 2)
        off_value = self.graph.get_input_node(node, 3)
        assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot'
        assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot'
        assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot'

        on_value = on_value.value
        off_value = off_value.value
        assert math.fabs(on_value -
                         1.0) < 1e-06, "on_value should be 1 in OneHot"
        assert math.fabs(off_value -
                         0.0) < 1e-06, "off_value should be 0 in OneHot"

        self.paddle_graph.add_layer(
            "paddle.nn.functional.one_hot",
            inputs={"x": src.name},
            outputs=[node.name],
            num_classes=depth.value)
J
jiangjiajun 已提交
1433 1434

    def Pow(self, node):
        """Map TF ``Pow`` onto ``paddle.pow``.

        A Const exponent is folded into the ``y`` attribute (via ``tolist()``);
        otherwise the exponent is wired in as the ``y`` tensor input.
        """
        base = self.graph.get_input_node(node, 0)
        exponent = self.graph.get_input_node(node, 1)
        inputs = {"x": base.name}
        attr = {}
        if exponent.layer_type != 'Const':
            inputs["y"] = exponent.name
        else:
            attr["y"] = exponent.value.tolist()
        self.paddle_graph.add_layer(
            "paddle.pow", inputs=inputs, outputs=[node.name], **attr)
J
jiangjiajun 已提交
1445 1446

    def All(self, node):
        """Map TF ``All`` onto ``paddle.all``.

        Only a Const reduction-axis input is supported. A non-bool input is
        first cast to bool through a generated intermediate. Afterwards the TF
        node's ``dtype`` attr type is set to 10 (presumably TF's DT_BOOL enum
        value — confirm).
        """
        source = self.graph.get_input_node(node, 0)
        axes = self.graph.get_input_node(node, 1)
        assert axes.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        attr = {
            "axis": axes.value.tolist(),
            "keepdim": node.get_attr("keep_dims"),
        }

        bool_name = source.name
        if source.dtype != "bool":
            bool_name = gen_name("all", "cast")
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": source.name},
                outputs=[bool_name],
                dtype=string("bool"))
        self.paddle_graph.add_layer(
            "paddle.all", inputs={"x": bool_name}, outputs=[node.name], **attr)

        node.layer.attr['dtype'].type = 10

    def GatherV2(self, node):
        """Map TF ``GatherV2`` onto ``paddle.gather``.

        Only a Const axis is supported. An index tensor whose recorded rank is
        not 1 is flattened to 1-D before the gather, and the gathered result
        is then reshaped back to the node's recorded output shape.
        """
        params = self.graph.get_input_node(node, 0)
        indices = self.graph.get_input_node(node, 1)
        axis_node = self.graph.get_input_node(node, 2)
        assert axis_node.layer_type == 'Const', "Only support Const parameter[axis]"
        axis = axis_node.value

        needs_flatten = len(indices.out_shapes[0]) != 1
        flat_index = indices.name
        if needs_flatten:
            flat_index = gen_name("gather", "reshape")
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": indices.name},
                outputs=[flat_index],
                shape=[-1])

        self.paddle_graph.add_layer(
            "paddle.gather",
            inputs={'x': params.name, 'index': flat_index},
            outputs=[node.name],
            axis=axis)

        if needs_flatten:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
1492

S
SunAhong1993 已提交
1493 1494 1495 1496 1497
    def GatherNd(self, node):
        """Map TF ``GatherNd`` directly onto ``paddle.gather_nd``."""
        params, indices = (self.graph.get_input_node(node, idx)
                           for idx in (0, 1))
        self.paddle_graph.add_layer(
            "paddle.gather_nd",
            inputs={'x': params.name, 'index': indices.name},
            outputs=[node.name])
J
jiangjiajun 已提交
1499 1500

    def ExpandDims(self, node):
        """Map TF ``ExpandDims`` onto ``paddle.unsqueeze``.

        A Const ``dim`` input becomes the ``axis`` attribute (a scalar is
        wrapped in a list); a non-Const one is wired in as a tensor input.
        """
        src = self.graph.get_input_node(node, 0, copy=True)
        dim_node = self.graph.get_input_node(node, 1, copy=True)
        inputs = {"x": src.name}
        attr = {}
        if dim_node.layer_type != 'Const':
            inputs['axis'] = dim_node.name
        else:
            axes = dim_node.value.tolist()
            attr['axis'] = axes if isinstance(axes, list) else [axes]
        self.paddle_graph.add_layer(
            "paddle.unsqueeze", inputs=inputs, outputs=[node.name], **attr)

S
SunAhong1993 已提交
1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527
    def ReverseV2(self, node):
        """Map TF ``ReverseV2`` onto ``paddle.flip``.

        A Const axis input becomes the ``axis`` attribute (a scalar is wrapped
        in a list); otherwise it is wired in as a tensor input.
        """
        src = self.graph.get_input_node(node, 0)
        axis_node = self.graph.get_input_node(node, 1)
        inputs = {"x": src.name}
        attr = {}
        if axis_node.layer_type != 'Const':
            inputs['axis'] = axis_node.name
        else:
            axes = axis_node.value.tolist()
            attr['axis'] = axes if isinstance(axes, list) else [axes]
        self.paddle_graph.add_layer(
            "paddle.flip", inputs=inputs, outputs=[node.name], **attr)

    def BatchToSpaceND(self, node):
        """Map TF ``BatchToSpaceND`` as reshape -> transpose -> reshape -> crop.

        ``block_shape`` and ``crops`` are used as Python lists below, so in
        practice both inputs must be Const. The first reshape assumes a 4-D
        input (it is unpacked into n/h/w/c).
        """
        x = self.graph.get_input_node(node, 0)
        block_shape = self.graph.get_input_node(node, 1)
        crops = self.graph.get_input_node(node, 2)
        if block_shape.layer_type == "Const":
            block_shape = block_shape.value.tolist()
        if crops.layer_type == "Const":
            crops = crops.value.tolist()
        data_format = x.get_attr("data_format")
        # NOTE(review): data_format is read from the *input* node, which need
        # not define one; assuming get_attr then yields None, fall back to NHWC
        # instead of crashing on ``None.decode()``.
        data_format = "NHWC" if data_format is None else data_format.decode()
        if data_format == "NHWC":
            n, h, w, c = x.out_shapes[0]
        else:
            n, c, h, w = x.out_shapes[0]
        input_name = x.name
        in_shape = x.out_shapes[0]
        nb = len(block_shape)

        # reshape: split the batch dim into block_shape + [remaining batch]
        shape = block_shape + [-1, h, w, c]
        reshape_name = gen_name("batch_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": input_name},
            outputs=[reshape_name],
            shape=shape)

        # transpose: interleave each block factor after its spatial dim
        perm = [nb]
        for i in range(nb):
            perm += [i + nb + 1, i]
        perm += [i + 2 * nb + 1 for i in range(len(in_shape) - nb - 1)]
        transpose_name = gen_name("batch_to_space", "transpose")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": reshape_name},
            outputs=[transpose_name],
            perm=perm)

        # reshape: merge each spatial dim with its block factor
        shape = [-1] + [b * d for b, d in zip(block_shape, in_shape[1:])] \
            + in_shape[1 + nb:]
        reshape_name = gen_name("batch_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": transpose_name},
            outputs=[reshape_name],
            shape=shape)

        # crop
        # Copy: ``shape`` was just handed to the reshape layer above, which
        # may keep a reference to it; editing it in place would also rewrite
        # that layer's target shape.
        crop_shape = list(shape)
        crop_offsets = [0] * len(shape)
        for i, (crop_begin, crop_end) in enumerate(crops):
            crop_shape[i + 1] = crop_shape[i + 1] - crop_begin - crop_end
            crop_offsets[i + 1] = crop_begin
        attrs = {'shape': crop_shape, 'offsets': crop_offsets}
        self.paddle_graph.add_layer(
            kernel="paddle.crop",
            inputs={"x": reshape_name},
            outputs=[node.name],
            **attrs)
S
SunAhong1993 已提交
1588

1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638
    def SpaceToBatchND(self, node):
        """Map TF ``SpaceToBatchND`` as zero-pad -> reshape -> transpose -> reshape.

        Assumes a 4-D input with two spatial dims: its recorded shape is
        unpacked as (n, h, w, c), and only the h/w dims are padded.
        ``block_shape`` and ``paddings`` are used as Python lists, so in
        practice both inputs must be Const.
        """
        x = self.graph.get_input_node(node, 0)
        block_node = self.graph.get_input_node(node, 1)
        pad_node = self.graph.get_input_node(node, 2)
        block_shape = (block_node.value.tolist()
                       if block_node.layer_type == "Const" else block_node)
        spatial_pads = (pad_node.value.flatten().tolist()
                        if pad_node.layer_type == "Const" else pad_node)

        # zero-pad: no padding on the first and last dims
        pad_name = gen_name("space_to_batch", "pad")
        full_pads = [0, 0] + spatial_pads + [0, 0]
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.pad",
            inputs={"x": x.name},
            outputs=[pad_name],
            pad=full_pads,
            value=0)

        # reshape: split each padded spatial dim by its block factor
        n, h, w, c = x.out_shapes[0]
        padded_h = h + full_pads[2] + full_pads[3]
        padded_w = w + full_pads[4] + full_pads[5]
        bh, bw = block_shape[0], block_shape[1]
        reshape_name = gen_name("space_to_batch", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": pad_name},
            outputs=[reshape_name],
            shape=[n, padded_h // bh, bh, padded_w // bw, bw, c])

        # transpose: move the block factors in front of the batch dim
        transpose_name = gen_name("space_to_batch", "transpose")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": reshape_name},
            outputs=[transpose_name],
            perm=[2, 4, 0, 1, 3, 5])

        # reshape: fold the block factors into the batch dim
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": transpose_name},
            outputs=[node.name],
            shape=[-1, padded_h // bh, padded_w // bw, c])