# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.tf_decoder import TFGraph, TFGraphNode
from x2paddle.core.program import PaddleGraph 
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
from x2paddle import program
import traceback
import math
import inspect
import numpy
import sys

name_counter = dict()


def gen_name(op_name, var_name):
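    """Return a unique layer name built from op_name and var_name; a
    per-name counter is appended, so repeated calls yield e.g.
    "conv2d_transpose_0", "conv2d_transpose_1", ...
    """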
    name = "{}_{}".format(op_name, var_name)
    if name not in name_counter:
        name_counter[name] = 0
    else:
        name_counter[name] += 1
    name = name + '_' + str(name_counter[name])
    return name


# compute padding size for SAME mode
def get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    if pad_size < 0:
        pad_size = 0
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]
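
# Worked example (illustrative): get_same_padding(in_size=224, kernel_size=7,
# stride=2) gives new_size = ceil(224 / 2) = 112 and
# pad_size = (112 - 1) * 2 + 7 - 224 = 5, so the result is [2, 3]; the extra
# padding pixel goes to the bottom/right side, matching TF "SAME" semantics.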


class TFOpMapper(OpMapper):
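    """Map the ops of a TFGraph onto an equivalent static PaddleGraph.

    Each TF op is converted through one of three paths: a one-to-one kernel
    substitution (`directly_map_ops`), a binary elementwise kernel
    (`elementwise_ops`), or a dedicated method on this class named after the
    op (e.g. `Conv2D`).
    """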
    directly_map_ops = {
        'Relu': ['paddle.nn.functional.relu'],
        'Relu6': ['paddle.nn.functional.relu6'],
        'Abs': ['paddle.abs'],
        'Sigmoid': ['paddle.nn.functional.sigmoid'],
        'Softmax': ['paddle.nn.functional.softmax'],
        'Exp': ['paddle.exp'],
        'Rsqrt': ['paddle.rsqrt'],
        'Sqrt': ['paddle.sqrt'],
        'swish_f32': ['paddle.nn.functional.swish'],
        'Tanh': ['paddle.tanh'],
        'Softplus': ['paddle.nn.functional.softplus'],
        'LeakyRelu': ['paddle.nn.functional.leaky_relu',
                      dict(alpha='negative_slope')],
        'Floor': ['paddle.floor'],
        'Erf': ['paddle.erf'],
        'Square': ['paddle.square']
    }
    elementwise_ops = {
        'Add': 'paddle.add',
        'AddV2': 'paddle.add',
        'RealDiv': 'paddle.divide',
        'DivNoNan': 'paddle.divide',
        # TODO (syf): replace
        'Sub': 'fluid.layers.elementwise_sub',
        'Maximum': 'paddle.maximum',
        'Minimum': 'paddle.minimum',
        'LessEqual': 'paddle.less_equal',
        'GreaterEqual': 'paddle.greater_equal',
        'Greater': 'paddle.greater_than',
        'NotEqual': 'paddle.not_equal',
        'Equal': 'paddle.equal',
        'Mul': 'paddle.multiply',
        'FloorDiv': 'paddle.floor_divide',
        'FloorMod': 'paddle.floor_mod',
        'LogicalAnd': 'paddle.logical_and',
    }

    def __init__(self, decoder):
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
        self.params = dict()
        self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="static", source_type="tf")

        not_placeholder = list()
        for name in self.graph.input_nodes:
            node_type = self.graph.get_node(name).layer_type
            if node_type not in ("Placeholder", "OneShotIterator", "IteratorV2"):
                not_placeholder.append(name)
        for name in not_placeholder:
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]
        self.paddle_graph.inputs = self.graph.input_nodes
        self.paddle_graph.outputs = self.graph.output_nodes
        print("Total nodes: {}".format(
            sum([
                isinstance(node, TFGraphNode)
                for name, node in self.graph.node_map.items()
            ])))
        print("Nodes converting ...")
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...     ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
                self.directly_map(node)
            elif op in self.elementwise_ops:
                self.elementwise_map(node)
            elif hasattr(self, op):
                func = getattr(self, op)
                func(node)
        print("\nNodes converted.")
        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.params)
        
    def op_checker(self):
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if not hasattr(self, op) and \
                op not in self.directly_map_ops and \
                op not in self.elementwise_ops:
                unsupported_ops.add(op)
        if len(unsupported_ops) == 0:
            return True
        else:
            print("\n========= {} OPs are not supported yet ===========".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            return False

    def directly_map(self, node):
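        """Emit a single Paddle kernel for a one-to-one op mapping, renaming
        TF attributes to their Paddle equivalents where `directly_map_ops`
        provides a rename dict.
        """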
        assert node.layer_type in self.directly_map_ops
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_node(node.layer.input[0])
        attr = dict()
        for param in op_info[1:]:
            tf_param_name = list(param.keys())[0]
            pd_param_name = list(param.values())[0]
            tf_param = node.get_attr(tf_param_name)
            attr[pd_param_name] = tf_param

        self.paddle_graph.add_layer(
            kernel=op_info[0],
            inputs={"x": input.name},
            outputs=[node.name],
            **attr)

    def elementwise_map(self, node):
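        """Emit the binary elementwise kernel named in `elementwise_ops` with
        the two graph inputs as `x` and `y`, recording their shapes on the
        generated layer.
        """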
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
        x = self.graph.get_node(node.layer.input[0])
        y = self.graph.get_node(node.layer.input[1])
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
        layer_id = self.paddle_graph.add_layer(
            kernel=op_type,
            inputs={"x": x.name,
                    "y": y.name},
            outputs=[node.name])
        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}

    def Placeholder(self, node):
        shape = node.out_shapes[0]
        assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
            node.layer_name)
        dtype = node.dtype
        self.paddle_graph.add_layer(
            kernel="paddle.static.data",
            inputs={},
            outputs=[node.name],
            dtype=string(dtype),
            shape=shape,
            name=string(node.name))
    def Const(self, node):
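        """Convert a TF constant. Scalars are emitted as `paddle.full`;
        tensors are stashed in `self.params` and emitted as a created
        parameter so the value is saved with the model.
        """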
        shape = node.out_shapes[0]
        dtype = node.dtype
        value = node.value
        if len(shape) == 0:
            assert value.size == 1, "Unexpected situation happened"
            shape = [1]
            if value == float('inf'):
                value = "float('inf')"
            self.paddle_graph.add_layer(
                kernel="paddle.full",
                inputs={},
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=value)
            return
        self.params[node.name] = node.value
        self.paddle_graph.add_layer(
            kernel="paddle.static.create_parameter",
            inputs={},
            outputs=[node.name],
            dtype=string(dtype),
            shape=shape,
            name=string(node.name),
            default_initializer="paddle.nn.initializer.Constant(value=0.0)")

    def Transpose(self, node):
        input = self.graph.get_node(node.layer.input[0])
        perm = self.graph.get_node(node.layer.input[1])
        assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
        perm = perm.value.tolist()

        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": input.name},
            outputs=[node.name],
            perm=perm)

    def Fill(self, node):
        dims = self.graph.get_node(node.layer.input[0])
        input_value = self.graph.get_node(node.layer.input[1])
        inputs = dict()
        attr = dict()
        assert input_value.layer_type == "Const", "Value of fill OP should be Const"
        if dims.layer_type == "Const":
            attr["shape"] = dims.value.tolist()
        else:
            inputs["shape"] = dims.name
        attr["dtype"] = string(input_value.dtype)
        attr["fill_value"] = input_value.value

        self.paddle_graph.add_layer(
            "paddle.full",
            inputs=inputs,
            outputs=[node.name],
            **attr)
        if dims.layer_type != "Const":
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
    def DepthToSpace(self, node):
        input = self.graph.get_node(node.layer.input[0])

        block_size = node.get_attr("block_size")
        data_format = node.get_attr("data_format").decode()
        if data_format == "NHWC":
            n, h, w, c = input.out_shapes[0]
        else:
            n, c, h, w = input.out_shapes[0]

        input_name = input.name
        if data_format == "NHWC":
            transpose_name = gen_name("depth_to_space", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        shape = [0, block_size * block_size, -1, h, w]
        reshape_name = gen_name("depth_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": input_name},
            outputs=[reshape_name],
            shape=shape)

        transpose_name = gen_name("depth_to_space", "transpose")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": reshape_name},
            outputs=[transpose_name],
            perm=[0, 2, 1, 3, 4])

        reshape_name = gen_name("depth_to_space", "reshape")
        self.paddle_graph.add_layer(
            kernel="paddle.reshape",
            inputs={"x": transpose_name},
            outputs=[reshape_name],
            shape=[0, c, h, w])

        self.paddle_graph.add_layer(
            kernel="fluid.layers.pixel_shuffle",
            inputs={"x": reshape_name},
            outputs=[node.name],
            upscale_factor=block_size)

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
    def Where(self, node):
        if len(node.layer.input) == 1:
            cond = self.graph.get_input_node(node, 0)
            self.paddle_graph.add_layer(
                "paddle.nonzero",
                inputs={"x": cond.name},
                outputs=[node.name])
        else:
            cond = self.graph.get_input_node(node, 0)
            x = self.graph.get_input_node(node, 1)
            y = self.graph.get_input_node(node, 2)
            self.paddle_graph.add_layer(
                "paddle.where",
                inputs={"condition": cond.name,
                        "x": x.name,
                        "y": y.name},
                outputs=[node.name])
            
    def Neg(self, node):
        input = self.graph.get_input_node(node, 0)
        
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": input.name},
            outputs=[node.name],
            scale=-1)

    def MaxPool(self, node):
        input = self.graph.get_node(node.layer.input[0])

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()

        input_name = input.name
        if data_format == "NHWC":
            transpose_name = gen_name("max_pool", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            input_name = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.max_pool2d",
            inputs={"x": input_name},
            outputs=[node.name],
            kernel_size=k_size[2:4],
            stride=strides[2:4],
            padding=string(pad_mode))

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def Conv2D(self, node):
        input = self.graph.get_node(node.layer.input[0])
        kernel = self.graph.get_node(node.layer.input[1])

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        if data_format == "NHWC":
            n, h, w, c = input.out_shapes[0]
        else:
            n, c, h, w = input.out_shapes[0]

        if kernel.layer_type == 'Const':
            kernel_value = kernel.value
            kernel_weight_name = kernel.name.replace('/', '_')
        else:
            kernel_value = self.decoder.infer_tensor(kernel, use_diff_inputs=False)
            if kernel.layer_type == 'Split':
                kernel_weight_name = "{}_{}_kernel".format(node.name,
                                                           kernel.name)
            else:
                kernel_weight_name = kernel.name.replace('/', '_')
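        # TF convolution kernels are laid out HWIO (height, width, in
        # channels, out channels); Paddle expects OIHW, hence the transpose.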
        self.params[kernel_weight_name] = numpy.transpose(kernel_value,
                                                          (3, 2, 0, 1))
        self.paddle_graph.add_layer(
            kernel="paddle.static.nn.create_parameter",
            inputs={},
            outputs=[kernel_weight_name],
            shape=self.params[kernel_weight_name].shape,
            dtype=string(str(self.params[kernel_weight_name].dtype)),
            name=string(kernel_weight_name))
        
        input_name = input.name
        if data_format == "NHWC":
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name("conv2d", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        if c == -1:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[input_name],
                shape=[0, k_size[2], 0, 0])

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d",
            inputs={"x": input_name, "weight": kernel_weight_name},
            outputs=[node.name],
            bias=None,
            stride=strides[2:4],
            dilation=dilations[2:4],
            padding=string(pad_mode))

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
            
    def Conv3D(self, node):
        input = self.graph.get_input_node(node, 0)
        kernel = self.graph.get_input_node(node, 1)

        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()
        if data_format == "NDHWC":
            n, d, h, w, c = input.out_shapes[0]
        else:
            n, c, d, h, w = input.out_shapes[0]

        if kernel.layer_type == 'Const':
            kernel_value = kernel.value
            kernel_weight_name = kernel.name.replace('/', '_')
        else:
            kernel_value = self.decoder.infer_tensor(kernel, use_diff_inputs=False)
            if kernel.layer_type == 'Split':
                kernel_weight_name = "{}_{}_kernel".format(node.name,
                                                           kernel.name)
            else:
                kernel_weight_name = kernel.name.replace('/', '_')
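        # TF 3-D convolution kernels are laid out DHWIO; Paddle expects
        # OIDHW, hence the transpose below.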
        self.params[kernel_weight_name] = numpy.transpose(kernel_value,
                                                          (4, 3, 0, 1, 2))
        self.paddle_graph.add_layer(
            kernel="paddle.static.nn.create_parameter",
            inputs={},
            outputs=[kernel_weight_name],
            shape=self.params[kernel_weight_name].shape,
            dtype=string(str(self.params[kernel_weight_name].dtype)),
            name=string(kernel_weight_name))
        
        input_name = input.name
        if data_format == "NDHWC":
            strides = [strides[i] for i in [0, 4, 1, 2, 3]]
            dilations = [dilations[i] for i in [0, 4, 1, 2, 3]]
            transpose_name = gen_name("conv3d", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 4, 1, 2, 3])
            input_name = transpose_name

        if c == -1:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[input_name],
                shape=[0, k_size[2], 0, 0, 0])        
            
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv3d",
            inputs={"x": input_name, "weight": kernel_weight_name},
            outputs=[node.name],
            bias=None,
            stride=strides[2:5],
            dilation=dilations[2:5],
            padding=string(pad_mode))

        if data_format == "NDHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 4, 1])
    def BiasAdd(self, node):
        input = self.graph.get_node(node.layer.input[0])
        bias = self.graph.get_node(node.layer.input[1])
        self.paddle_graph.add_layer(
            kernel="paddle.add",
            inputs={"x": input.name,
                    "y": bias.name},
            outputs=[node.name])

    def FusedBatchNorm(self, node):
        input = self.graph.get_node(node.layer.input[0])
        gamma = self.graph.get_node(node.layer.input[1])
        beta = self.graph.get_node(node.layer.input[2])
        moving_mean = self.graph.get_node(node.layer.input[3])
        moving_var = self.graph.get_node(node.layer.input[4])
        data_format = node.get_attr("data_format").decode()

        assert gamma.layer_type == "Const"
        assert beta.layer_type == "Const"
        assert moving_mean.layer_type == "Const"
        assert moving_var.layer_type == "Const"

        input_name = input.name
        if data_format == "NHWC":
            transpose_name = gen_name("batch_norm", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.batch_norm",
            inputs={"x": input_name,
                    "running_mean": moving_mean.name,
                    "running_var": moving_var.name,
                    "weight": gamma.name,
                    "bias": beta.name},
            outputs=[node.name],
            epsilon=node.get_attr("epsilon"))

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])
            
    def FusedBatchNormV3(self, node):
        self.FusedBatchNorm(node)

    def Mean(self, node):
        input = self.graph.get_node(node.layer.input[0])
        reduce_idx = self.graph.get_node(node.layer.input[1])
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        dims = reduce_idx.value.tolist()
        keep_dims = node.get_attr("keep_dims")

        self.paddle_graph.add_layer(
            kernel="paddle.mean",
            inputs={"x": input.name},
            outputs=[node.name],
            axis=dims,
            keepdim=keep_dims)

    def Reshape(self, node):
        input = self.graph.get_input_node(node, 0)
        param = self.graph.get_input_node(node, 1)

        input_name = input.name

        if param.layer_type == "Const":
            shape = param.value.tolist()
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name},
                outputs=[node.name],
                shape=shape)
        else:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": input_name,
                        "shape": param.name},
                outputs=[node.name])
        if param.layer_type != "Const":
            out_shape = numpy.array(node.out_shapes[0])
            if (out_shape > 0).any():
                out_shape[out_shape < 0] = 0
                self.paddle_graph.add_layer(
                    kernel="paddle.reshape",
                    inputs={"x": node.name},
                    outputs=[node.name],
                    shape=out_shape.tolist())

    def Pad(self, node):
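        """Convert tf.pad. For a 4-D input with no padding on the batch and
        channel dimensions, pad only H/W in NCHW layout and transpose back;
        otherwise fall through to a direct pad.
        """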
        input = self.graph.get_node(node.layer.input[0])
        paddings = self.graph.get_node(node.layer.input[1])
        assert paddings.layer_type == "Const", "Padding should be Const"
        paddings = paddings.value.flatten().tolist()

        if len(input.out_shapes[0]) == 4:
            if paddings[0] + paddings[1] + paddings[6] + paddings[7] == 0:
                new_padding = paddings[2:6]
                transpose_name = gen_name("pad", "transpose")
                self.paddle_graph.add_layer(
                    kernel="paddle.transpose",
                    inputs={"x": input.name},
                    outputs=[transpose_name],
                    perm=[0, 3, 1, 2])
                self.paddle_graph.add_layer(
                    kernel="paddle.nn.functional.pad",
                    inputs={"x": transpose_name},
                    outputs=[node.name],
                    pad=new_padding)
                self.paddle_graph.add_layer(
                    kernel="paddle.transpose",
                    inputs={"x": node.name},
                    outputs=[node.name],
                    perm=[0, 2, 3, 1])
                return

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.pad",
            inputs={"x": input.name},
            outputs=[node.name],
            pad=paddings)
        
    def MirrorPad(self, node):
        input = self.graph.get_input_node(node, 0)
        paddings = self.graph.get_input_node(node, 1)
        assert paddings.layer_type == "Const", "Padding should be Const"
        paddings = numpy.flip(paddings.value, 0).flatten().tolist()
        transpose_name = gen_name("pad", "transpose")
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": input.name},
            outputs=[transpose_name],
            perm=[0, 3, 1, 2])
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.pad",
            inputs={"x": transpose_name},
            outputs=[node.name],
            pad=paddings)
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": node.name},
            outputs=[node.name],
            perm=[0, 2, 3, 1])

    def Squeeze(self, node):
        input = self.graph.get_input_node(node, 0)
        squeeze_dims = node.get_attr('squeeze_dims')
        self.paddle_graph.add_layer(
            kernel="paddle.squeeze",
            inputs={"x": input.name},
            outputs=[node.name],
            axis=squeeze_dims)

    def Shape(self, node):
        input = self.graph.get_input_node(node, 0)
        input_name = input.name
        self.paddle_graph.add_layer(
            kernel="paddle.shape",
            inputs={"input": input_name},
            outputs=[node.name])

    def Size(self, node):
        input = self.graph.get_input_node(node, 0)
        input_name = input.name
        self.paddle_graph.add_layer(
            kernel="paddle.shape",
            inputs={"input": input_name},
            outputs=[node.name])
        self.paddle_graph.add_layer(
            kernel="paddle.prod",
            inputs={"x": node.name},
            outputs=[node.name])
        
    def Ceil(self, node):
        input = self.graph.get_input_node(node, 0)
        self.paddle_graph.add_layer(
            kernel="paddle.ceil",
            inputs={"x": input.name},
            outputs=[node.name])

    def ArgMax(self, node):
        input = self.graph.get_input_node(node, 0)
        axis = self.graph.get_input_node(node, 1)
        assert axis.layer_type == "Const", "ArgMax only supports Const parameter[axis]"
        axis = axis.value
        self.paddle_graph.add_layer(
            kernel="paddle.argmax",
            inputs={"x": input.name},
            outputs=[node.name],
            axis=axis)
        
    def TopKV2(self, node):
        input = self.graph.get_input_node(node, 0)
        k = self.graph.get_input_node(node, 1)
        assert k.layer_type == "Const", "TopKV2 only supports Const parameter[k]"
        k = k.value
        sort = node.get_attr('sorted')
        self.paddle_graph.add_layer(
            kernel="paddle.topk",
            inputs={"x": input.name},
            outputs=[node.name],
            k=k,
            sorted=sort)

    def MatMul(self, node):
        x = self.graph.get_input_node(node, 0)
        y = self.graph.get_input_node(node, 1)
        transpose_a = node.get_attr('transpose_a')
        transpose_b = node.get_attr('transpose_b')
        if transpose_a is None:
            transpose_a = node.get_attr('adj_x')
        if transpose_b is None:
            transpose_b = node.get_attr('adj_y')
        self.paddle_graph.add_layer(
            kernel="paddle.matmul",
            inputs={"x": x.name,
                    "y": y.name},
            outputs=[node.name],
            transpose_x=transpose_a,
            transpose_y=transpose_b)

    def BatchMatMul(self, node):
        return self.MatMul(node)

    def BatchMatMulV2(self, node):
        return self.MatMul(node)
    def DepthwiseConv2dNative(self, node):
        input = self.graph.get_node(node.layer.input[0])
        kernel = self.graph.get_node(node.layer.input[1])
        assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
        in_shape = input.out_shapes[0]
        k_size = kernel.out_shapes[0]
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()

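        # Reorder the TF depthwise kernel [H, W, in_channels, multiplier]
        # into a channels-first layout for the grouped conv2d emitted below.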
        self.paddle_graph.add_layer(
            kernel="paddle.transpose",
            inputs={"x": kernel.name},
            outputs=[kernel.name],
            perm=[2, 3, 0, 1])
        input_name = input.name
        if data_format == "NHWC":
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name('depthwise_conv2d', 'transpose')
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d",
            inputs={"x": input_name,
                    "weight": kernel.name},
            outputs=[node.name],
            stride=strides[2:4],
            dilation=dilations[2:4],
            groups=k_size[3] * in_shape[1],
            padding=string(pad_mode),
            bias=None)

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def AvgPool(self, node):
        input = self.graph.get_input_node(node, 0)

        k_size = node.get_attr("ksize")
        strides = node.get_attr("strides")
        data_format = node.get_attr("data_format").decode()
        pad_mode = node.get_attr("padding").decode()

        input_name = input.name
        if data_format == "NHWC":
            transpose_name = gen_name("avg_pool", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            strides = [strides[i] for i in [0, 3, 1, 2]]
            k_size = [k_size[i] for i in [0, 3, 1, 2]]
            input_name = transpose_name
        
        # TODO(syf): The op has diff.

        self.paddle_graph.add_layer(
            kernel="fluid.layers.pool2d",
            inputs={"input": input_name},
            outputs=[node.name],
            pool_size=k_size[2:4],
            pool_type=string("avg"),
            pool_stride=strides[2:4],
            pool_padding=string(pad_mode))

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def Pack(self, node):
        inputs_list = list()
        for i in range(len(node.inputs)):
            inputs_list.append(self.graph.get_input_node(node, i))
        input_names = [i.name for i in inputs_list]
        axis = node.get_attr("axis")
        self.paddle_graph.add_layer(
            kernel="paddle.stack",
            inputs={"x": input_names},
            outputs=[node.name],
            axis=axis)
        if len(node.out_shapes[0]) == 1:
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=[-1])

    def Unpack(self, node):
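        """Convert tf.unstack. A 1-D input is first unsqueezed to shape
        [1, N] so the unstack can run along axis 1 and still produce `num`
        outputs.
        """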
        input = self.graph.get_input_node(node, 0)
        axis = node.get_attr("axis")
        num = node.get_attr("num")
        shape = input.out_shapes[0]
        input_name = input.name
        if len(shape) == 1:
            if shape[0] > 0 and num == shape[0]:
                self.paddle_graph.add_layer(
                    kernel="paddle.unsqueeze",
                    inputs={"x": input.name},
                    outputs=[node.name],
                    axis=[0])
                input_name = node.name
                axis = 1
            else:
                raise Exception("Unexpected situation happened in Unpack OP")
        layer_outputs = ["{}_p{}".format(node.layer_name, i) for i in range(num)]
        if len(layer_outputs) == 1:
            layer_outputs[0] = "[{}]".format(node.layer_name)
        self.paddle_graph.add_layer(
            kernel="paddle.unstack",
            inputs={"x": input_name},
            outputs=layer_outputs,
            axis=axis,
            num=num)
    def ConcatV2(self, node):
        inputs_list = list()
        for i in range(len(node.inputs) - 1):
            inputs_list.append(self.graph.get_input_node(node, i))
        axis = self.graph.get_input_node(node, -1)
        assert axis.layer_type == "Const", "axis for ConcatV2 must be type Const"
        axis = axis.value
        if axis < 0:
            axis += len(inputs_list[0].out_shapes[0])

        input_names = [i.name for i in inputs_list]
        self.paddle_graph.add_layer(
            kernel="paddle.concat",
            inputs={"x": input_names},
            outputs=[node.name],
            axis=axis)
        
    def Concat(self, node):
        inputs_list = list()
        for i in range(1, len(node.inputs)):
            inputs_list.append(self.graph.get_input_node(node, i))
        axis = self.graph.get_input_node(node, 0)
        assert axis.layer_type == "Const", "axis for Concat must be type Const"
        axis = axis.value
        if axis < 0:
            axis += len(inputs_list[0].out_shapes[0])
            
        input_names = [i.name for i in inputs_list]
        self.paddle_graph.add_layer(
            kernel="paddle.concat",
            inputs={"x": input_names},
            outputs=[node.name],
            axis=axis)
            
    def AddN(self, node):
        inputs_list = list()
        for i in range(len(node.inputs)):
            inputs_list.append(self.graph.get_input_node(node, i))

        input_names = [i.name for i in inputs_list]
        self.paddle_graph.add_layer(
            kernel="paddle.add_n",
            inputs={"inputs": input_names},
            outputs=[node.name])

    def StridedSlice(self, node):
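        """Convert tf.strided_slice (all strides must be 1). The begin/end,
        new-axis and shrink-axis masks are decoded bit by bit into
        paddle.slice arguments plus follow-up unsqueeze/squeeze layers.
        """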
        input = self.graph.get_input_node(node, 0)
        begin = self.graph.get_input_node(node, 1)
        end = self.graph.get_input_node(node, 2)
        strides = self.graph.get_input_node(node, 3)

        if strides.layer_type == "Const":
            strides = strides.value.tolist()
        else:
            strides = self.decoder.infer_tensor(strides)
        if begin.layer_type == "Const":
            begin = begin.value.tolist()
        else:
            begin = self.decoder.infer_tensor(begin)
        if end.layer_type == "Const":
            end = end.value.tolist()
        else:
            end = self.decoder.infer_tensor(end)

        assert len(set(strides)) == 1 and strides[
            0] == 1, "Only support strides of 1 in StridedSlice OP"

        if len(begin) < len(input.out_shapes[0]):
            begin = begin + [0] * (len(input.out_shapes[0]) - len(begin))
        if len(end) < len(input.out_shapes[0]):
            end = end + [0] * (len(input.out_shapes[0]) - len(end))
        for i in range(len(end)):
            if end[i] == 0:
                end[i] = 999999

        begin_mask = node.get_attr('begin_mask')
        end_mask = node.get_attr('end_mask')
        ellipsis_mask = node.get_attr('ellipsis_mask')
        new_axis_mask = node.get_attr('new_axis_mask')
        shrink_axis_mask = node.get_attr('shrink_axis_mask')

        assert ellipsis_mask == 0, "(OP:{} Name:{}) Only support ellipsis_mask to be 0 [now: {}] in StridedSlice OP".format(
            node.layer_type, node.layer.name, ellipsis_mask)

        # TODO: the mask decoding below has not been fully validated;
        # use it carefully
        new_begin = list()
        new_end = list()
        new_axes = list()
        shrink_axes = list()
        for i, item in enumerate(begin):
            mask = (new_axis_mask >> i) & 1
            if mask != 0:
                new_axes.append(i)
                continue

            mask = (shrink_axis_mask >> i) & 1
            if mask != 0:
                shrink_axes.append(i)

            mask = (begin_mask >> i) & 1
            if mask != 0:
                new_begin.append(0)
            else:
                new_begin.append(item)

            mask = (end_mask >> i) & 1
            if mask != 0:
                new_end.append(999999)
            else:
                new_end.append(end[i])
            
        if input.dtype == "bool":
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": input.name},
                outputs=[input.name],
                dtype=string("int32"))

        self.paddle_graph.add_layer(
            kernel="paddle.slice",
            inputs={"input": input.name},
            outputs=[node.name],
            axes=[i for i in range(len(new_begin))],
            starts=new_begin,
            ends=new_end)
        
        if input.dtype == "bool":
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": node.name},
                outputs=[node.name],
                dtype=string("bool"))

        if len(new_axes) > 0:
            self.paddle_graph.add_layer(
                kernel="paddle.unsqueeze",
                inputs={"x": node.name},
                outputs=[node.name],
                axis=new_axes)
        if len(shrink_axes) > 0:
            if len(input.out_shapes[0]) + len(new_axes) <= 1:
                pass
            else:
                self.paddle_graph.add_layer(
                    kernel="paddle.squeeze",
                    inputs={"x": node.name},
                    outputs=[node.name],
                    axis=shrink_axes)
                
    def Prod(self, node):
        input = self.graph.get_input_node(node, 0)
        reduction_indices = self.graph.get_input_node(node, 1)
        assert reduction_indices.layer_type == "Const"
        keep_dims = node.get_attr('keep_dims')
        axis = reduction_indices.value

        self.paddle_graph.add_layer(
            kernel="paddle.prod",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            keepdim=keep_dims,
            axis=axis)

    def Split(self, node):
        dim = self.graph.get_input_node(node, 0)
        input = self.graph.get_input_node(node, 1)
        assert dim.layer_type == "Const"
        num_split = node.get_attr('num_split')
        dim = dim.value

        self.paddle_graph.add_layer(
            kernel="paddle.split",
            inputs={"x": input.name},
            outputs=[
                "{}_p{}".format(node.layer_name, i) for i in range(num_split)
            ],
            num_or_sections=num_split,
            axis=dim)

    def Slice(self, node):
        input = self.graph.get_input_node(node, 0)
        begin = self.graph.get_input_node(node, 1)
        size = self.graph.get_input_node(node, 2)

        inputs = {"x": input.name}
        attrs = {}
        if begin.layer_type == "Const":
            begin = begin.value.tolist()
            attrs['offsets'] = begin
        else:
            #             shape = begin.out_shapes[0]
            #             reshape_name = gen_name("slice", "reshape")
            #             self.paddle_graph.add_layer(
            #                 kernel="fluid.layers.reshape",
            #                 inputs={"x": begin.name},
            #                 outputs=[reshape_name],
            #                 shape=shape)
            #             inputs['offsets'] = reshape_name
            begin = self.decoder.infer_tensor(begin, use_diff_inputs=False).tolist()
            attrs['offsets'] = begin
        if size.layer_type == "Const":
            size = size.value.tolist()
            attrs['shape'] = size
        else:
            shape = size.out_shapes[0]
            reshape_name = gen_name("slice", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": size.name},
                outputs=[reshape_name],
                shape=shape)
            inputs['shape'] = reshape_name
        self.paddle_graph.add_layer(
            kernel="paddle.crop",
            inputs=inputs,
            outputs=[node.name],
            **attrs)

    def ResizeNearestNeighbor(self, node):
        input = self.graph.get_input_node(node, 0)
        resize_shape = self.graph.get_input_node(node, 1)
        data_format = "NHWC"
        inputs = {"x": input.name}
        attrs = {"align_corners": node.get_attr("align_corners"),
                 "mode": string("nearest"),
                 "align_mode": 1}

        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
            attrs["size"] = resize_shape
        else:
            shape = resize_shape.out_shapes[0]
            reshape_name = gen_name("resize_nearest", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": resize_shape.name},
                outputs=[reshape_name],
                shape=shape)
            inputs["size"] = reshape_name

        if data_format == "NHWC":
            transpose_name = gen_name("resize_nearest", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            inputs["x"] = transpose_name
J
jiangjiajun 已提交
1135

S
SunAhong1993 已提交
1136
        self.paddle_graph.add_layer(
S
SunAhong1993 已提交
1137
            kernel="paddle.nn.functional.interpolate",
J
jiangjiajun 已提交
1138 1139 1140 1141 1142
            inputs=inputs,
            outputs=[node.name],
            **attrs)

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def ResizeBilinear(self, node):
        input = self.graph.get_input_node(node, 0)
        resize_shape = self.graph.get_input_node(node, 1)
        data_format = "NHWC"
        inputs = {"x": input.name}
        attrs = {"align_corners": node.get_attr("align_corners"),
                 "mode": string("bilinear"),
                 "align_mode": 1}

        if resize_shape.layer_type == "Const":
            resize_shape = resize_shape.value.tolist()
            attrs["size"] = resize_shape
        else:
            shape = resize_shape.out_shapes[0]
            reshape_name = gen_name("resize_bilinear", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": resize_shape.name},
                outputs=[reshape_name],
                shape=shape)
            inputs["size"] = reshape_name

        if data_format == "NHWC":
            transpose_name = gen_name("resize_bilinear", "reshape")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            inputs["x"] = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.interpolate",
            inputs=inputs,
            outputs=[node.name],
            **attrs)

        if data_format == "NHWC":
S
SunAhong1993 已提交
1187
            self.paddle_graph.add_layer(
S
SunAhong1993 已提交
1188
                kernel="paddle.transpose",
J
jiangjiajun 已提交
1189 1190 1191 1192 1193
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def Cast(self, node):
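        """Map TF Cast to paddle.cast with the dtype recorded on the node."""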
        input = self.graph.get_input_node(node, 0)
        dtype = node.dtype
        self.paddle_graph.add_layer(
            kernel="paddle.cast",
            inputs={"x": input.name},
            outputs=[node.name],
            dtype=string(dtype))

    def Sum(self, node):
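        """Map TF Sum (reduce_sum) to paddle.sum; the reduction axes must
        come from a Const node."""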
        input = self.graph.get_input_node(node, 0)
        reduce_idx = self.graph.get_input_node(node, 1)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()

        self.paddle_graph.add_layer(
            kernel="paddle.sum",
            inputs={"x": input.name},
            outputs=[node.name],
            axis=dim,
            keepdim=keep_dims)

    def Max(self, node):
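        """Map TF Max (reduce_max) to paddle.max; the reduction axes must
        come from a Const node."""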
        input = self.graph.get_input_node(node, 0)
        reduce_idx = self.graph.get_input_node(node, 1)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        keep_dims = node.get_attr("keep_dims")
        dim = reduce_idx.value.tolist()
        self.paddle_graph.add_layer(
            kernel="paddle.max",
            inputs={"x": input.name},
            outputs=[node.name],
            axis=dim,
            keepdim=keep_dims)

    def RandomUniform(self, node):
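        """Map TF RandomUniform to paddle.uniform over [0.0, 0.9999];
        the shape may be a Const or a tensor input."""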
        shape = self.graph.get_input_node(node, 0)
        if shape.layer_type == "Const":
            shape = shape.value.tolist()
            self.paddle_graph.add_layer(
                kernel="paddle.uniform",
                inputs={},
                outputs=[node.name],
                shape=shape,
                min=0.0,
                max=0.9999)
        else:
            self.paddle_graph.add_layer(
                kernel="paddle.uniform",
                inputs={'shape': shape.name},
                outputs=[node.name],
                min=0.0,
                max=0.9999)

    def Conv2DBackpropInput(self, node):
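        """Map TF Conv2DBackpropInput (transposed convolution) to
        paddle.nn.functional.conv2d_transpose. The kernel must be Const;
        its TF filter layout is transposed to Paddle's weight layout, and
        NHWC feature maps are transposed to NCHW around the op."""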
        out_shape = self.graph.get_input_node(node, 0)
        kernel = self.graph.get_input_node(node, 1)
        input = self.graph.get_input_node(node, 2)

        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"

        if out_shape.layer_type == "Const":
            out_shape = out_shape.value.tolist()
        else:
            out_shape = self.decoder.infer_tensor(out_shape,
                                                  out_shape=node.out_shapes[0])

        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input, use_diff_inputs=False).shape
        k_size = kernel.out_shapes[0]
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel, use_diff_inputs=False).shape

        pad_mode = node.get_attr("padding").decode()
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()

        kernel_name = node.name + ".weight"
        self.params[kernel_name] = numpy.transpose(kernel.value, (3, 2, 0, 1))

        input_name = input.name
        if data_format == "NHWC":
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
            dilations = [dilations[i] for i in [0, 3, 1, 2]]
            transpose_name = gen_name("conv2dbackpropinput", "transpose")
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": input.name},
                outputs=[transpose_name],
                perm=[0, 3, 1, 2])
            input_name = transpose_name

        self.paddle_graph.add_layer(
            kernel="paddle.static.create_parameter",
            inputs={},
            outputs=["{}_{}".format(node.name, kernel_name).replace(".", "_")],
            dtype=string(str(self.params[kernel_name].dtype)),
            shape=self.params[kernel_name].shape,
            name=string(kernel_name))

        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.conv2d_transpose",
            inputs={"x": input_name,
                    "weight": "{}_{}".format(node.name, kernel_name).replace(".", "_")},
            outputs=[node.name],
            bias=None,
            stride=strides[2:4],
            dilation=dilations[2:4],
            padding=string(pad_mode),
            output_size=out_shape[1:3])

        if data_format == "NHWC":
            self.paddle_graph.add_layer(
                kernel="paddle.transpose",
                inputs={"x": node.name},
                outputs=[node.name],
                perm=[0, 2, 3, 1])

    def Tile(self, node):
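        """Map TF Tile to paddle.tile; repeat_times may come from a Const
        or a tensor input (in the latter case the output is reshaped to the
        inferred static shape)."""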
        input = self.graph.get_node(node.layer.input[0])
        repeat_times = self.graph.get_node(node.layer.input[1])
        inputs = {"x": input.name}
        attr = dict()
        if repeat_times.layer_type == "Const":
            repeat_times = repeat_times.value.tolist()
            attr["repeat_times"] = repeat_times
        else:
            inputs["repeat_times"] = repeat_times.name

        self.paddle_graph.add_layer(
            kernel="paddle.tile",
            inputs=inputs,
            outputs=[node.name],
            **attr)

        if not isinstance(repeat_times, list) and repeat_times.layer_type != "Const":
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])

    def Range(self, node):
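        """Map TF Range to paddle.arange, promoting the dtype to float when
        any of start/limit/delta is float; non-Const bounds are passed as
        tensor inputs and the result is reshaped to the inferred shape."""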
        start = self.graph.get_node(node.layer.input[0])
        limit = self.graph.get_node(node.layer.input[1])
        delta = self.graph.get_node(node.layer.input[2])
        inputs = dict()
        attr = dict()

        dtype = 'int32'
        if start.dtype.startswith('float'):
            dtype = start.dtype
        if start.layer_type == "Const":
            attr["start"] = start.value
        else:
            inputs["start"] = start.name
        if limit.dtype.startswith('float'):
            dtype = limit.dtype
        if limit.layer_type == "Const":
            attr["end"] = limit.value
        else:
            inputs["end"] = limit.name
        if delta.dtype.startswith('float'):
            dtype = delta.dtype
        if delta.layer_type == "Const":
            attr["step"] = delta.value
        else:
            inputs["step"] = delta.name
        node.set_dtype(dtype)
        attr["dtype"] = string(node.dtype)

        self.paddle_graph.add_layer(
            kernel="paddle.arange",
            inputs=inputs,
            outputs=[node.name],
            **attr)
        if start.layer_type != "Const" or \
                limit.layer_type != "Const" or \
                delta.layer_type != "Const":
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=node.out_shapes[0])

    def SquaredDifference(self, node):
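        """Map TF SquaredDifference to (x - y) * (x - y), lowered as an
        elementwise subtract followed by paddle.multiply."""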
        x = self.graph.get_input_node(node, 0)
        y = self.graph.get_input_node(node, 1)
        inputs = {"x": x.name, "y": y.name}
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
        # TODO(syf)
        layer_id = self.paddle_graph.add_layer(
            "fluid.layers.elementwise_sub", inputs=inputs, outputs=[node.name])
        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}

        inputs = {"x": node.name, "y": node.name}
        x_shape = node.out_shapes[0]
        y_shape = node.out_shapes[0]
        layer_id = self.paddle_graph.add_layer(
            "paddle.multiply", inputs=inputs, outputs=[node.name])
        self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}

    def OneHot(self, node):
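        """Map TF OneHot to paddle.nn.functional.one_hot; depth, on_value
        and off_value must be Const, with on_value == 1 and off_value == 0."""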
        input = self.graph.get_input_node(node, 0)
        depth = self.graph.get_input_node(node, 1)
        on_value = self.graph.get_input_node(node, 2)
        off_value = self.graph.get_input_node(node, 3)
        assert depth.layer_type == 'Const', 'Parameter depth should be Const in OneHot'
        assert on_value.layer_type == 'Const', 'Parameter on_value should be Const in OneHot'
        assert off_value.layer_type == 'Const', 'Parameter off_value should be Const in OneHot'

        on_value = on_value.value
        off_value = off_value.value
        assert math.fabs(on_value -
                         1.0) < 1e-06, "on_value should be 1 in OneHot"
        assert math.fabs(off_value -
                         0.0) < 1e-06, "off_value should be 0 in OneHot"

        self.paddle_graph.add_layer(
            "paddle.nn.functional.one_hot",
            inputs={"x": input.name},
            outputs=[node.name],
            num_classes=depth.value)

    def Pow(self, node):
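        """Map TF Pow to paddle.pow; the exponent may be a Const or a
        tensor input."""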
        x = self.graph.get_input_node(node, 0)
        factor = self.graph.get_input_node(node, 1)
        inputs = {"x": x.name}
        attr = dict()
        if factor.layer_type == 'Const':
            attr["y"] = factor.value.tolist()
        else:
            inputs["y"] = factor.name
        self.paddle_graph.add_layer(
            "paddle.pow", inputs=inputs, outputs=[node.name], **attr)

    def All(self, node):
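        """Map TF All (reduce_all) to paddle.all, casting the input to bool
        first when needed; the node dtype is marked as bool afterwards
        (10 is DT_BOOL in TensorFlow's DataType enum)."""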
        input = self.graph.get_input_node(node, 0)
        reduce_idx = self.graph.get_input_node(node, 1)
        assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
        attr = dict()
        attr["axis"] = reduce_idx.value.tolist()
        attr["keepdim"] = node.get_attr("keep_dims")

        input_name = input.name
        if input.dtype != "bool":
            input_name = gen_name("all", "cast")
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": input.name},
                outputs=[input_name],
                dtype=string("bool"))
        self.paddle_graph.add_layer(
            "paddle.all",
            inputs={"x": input_name},
            outputs=[node.name],
            **attr)

        node.layer.attr['dtype'].type = 10

    def GatherV2(self, node):
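        """Map TF GatherV2 to paddle.gather. paddle.gather expects a 1-D
        index, so multi-dimensional indices are flattened before the gather
        and the result is reshaped back to the expected output shape."""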
        embeddings = self.graph.get_input_node(node, 0)
        index = self.graph.get_input_node(node, 1)
        axis = self.graph.get_input_node(node, 2)
        assert axis.layer_type == 'Const', "Only support Const parameter[axis]"
        axis = axis.value
        index_name = index.name
        if len(index.out_shapes[0]) != 1:
            reshape_name = gen_name("gather", "reshape")
            index_name = reshape_name
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": index.name},
                outputs=[reshape_name],
                shape=[-1])
        inputs = {'x': embeddings.name, 'index': index_name}
        self.paddle_graph.add_layer(
            "paddle.gather",
            inputs=inputs,
            outputs=[node.name],
            axis=axis)
        if len(index.out_shapes[0]) != 1:
            out_shape = node.out_shapes[0]
            self.paddle_graph.add_layer(
                kernel="paddle.reshape",
                inputs={"x": node.name},
                outputs=[node.name],
                shape=out_shape)
            
    def GatherNd(self, node):
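        """Map TF GatherNd to paddle.gather_nd."""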
        x = self.graph.get_input_node(node, 0)
        index = self.graph.get_input_node(node, 1)
        inputs = {'x': x.name, 'index': index.name}
        self.paddle_graph.add_layer(
            "paddle.gather_nd",
            inputs=inputs,
            outputs=[node.name])

    def ExpandDims(self, node):
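        """Map TF ExpandDims to paddle.unsqueeze; the axis may be a Const
        or a tensor input."""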
        x = self.graph.get_input_node(node, 0, copy=True)
        y = self.graph.get_input_node(node, 1, copy=True)
        inputs = {"x": x.name}
        attr = dict()
        if y.layer_type == 'Const':
            dim = y.value.tolist()
            if not isinstance(dim, list):
                dim = [dim]
            attr['axis'] = dim
        else:
            inputs['axis'] = y.name
        self.paddle_graph.add_layer(
            "paddle.unsqueeze",
            inputs=inputs,
            outputs=[node.name],
            **attr)