opset.py 67.6 KB
Newer Older
1
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
C
update  
channingss 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
Channingss 已提交
15
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
C
update  
channingss 已提交
16
from x2paddle.core.graph import GraphNode
C
channingss 已提交
17
from x2paddle.core.util import string
C
Channingss 已提交
18
from functools import reduce
C
update  
channingss 已提交
19
import numpy as np
C
channingss 已提交
20
import onnx
C
channingss 已提交
21
import onnx.numpy_helper as numpy_helper
C
channingss 已提交
22
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
C
update  
channingss 已提交
23
import logging as _logging
24
from collections import OrderedDict
C
channingss 已提交
25
import math
C
channingss 已提交
26
import os
S
SunAhong1993 已提交
27 28
import copy
import sys
C
channingss 已提交
29
import shutil
30

C
update  
channingss 已提交
31 32 33
_logger = _logging.getLogger(__name__)


C
Channingss 已提交
34
def _const_weight_or_none(node, necessary=False):
C
channings 已提交
35
    if 'Constant' in node.layer_type:
C
channingss 已提交
36
        return node.value
C
update  
channingss 已提交
37 38
    if isinstance(node, ONNXGraphDataNode):
        return node.weight
C
Channingss 已提交
39 40 41
    if necessary:
        assert '{} should be an initializer or Constant operator.'.format(
            node.layer_name)
C
update  
channingss 已提交
42 43 44
    return None


C
Channingss 已提交
45 46 47 48 49 50
def _is_static_shape(shape):
    negtive_dims = 0
    error_dims = 0
    for dim in shape:
        if dim < 0:
            negtive_dims += 1
C
update  
Channingss 已提交
51
        if dim < -1:
C
Channingss 已提交
52 53 54 55 56 57 58
            error_dims += 1
    if negtive_dims > 1:
        return False
    if error_dims > 0:
        return False
    return True

59

C
Channingss 已提交
60
def _get_same_padding(in_size, kernel_size, stride):
C
channingss 已提交
61 62 63 64 65 66 67
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]


68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
def print_mapping_info(func):
    def run_mapping(*args, **kwargs):
        node = args[1]
        try:
            res = func(*args, **kwargs)
        except:
            print("convert failed node:{}, op_type is {}".format(
                node.layer_name[9:], node.layer_type))
            raise
        else:
            #print("convert successfully node:{}, op_type is {}".format(
            #    node.layer_name[9:], node.layer_type))
            return res

    return run_mapping


C
Channingss 已提交
85
class OpSet9():
86
    elementwise_ops = {
S
SunAhong1993 已提交
87 88
        'Add': 'paddle.add',
        'Div': 'paddle.divide',
S
fix  
SunAhong1993 已提交
89
        'Sub': 'paddle.subtract',
S
SunAhong1993 已提交
90 91
        'Mul': 'paddle.multiply',
        'Pow': 'paddle.pow',
R
root 已提交
92
    }
93

S
SunAhong1993 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
    directly_map_ops = {
        'Ceil': ['paddle.ceil'],
        # reduce function
        'ReduceMean': ['paddle.mean',
                       dict(axes='axis', keepdims='keepdim'), 
                       dict(keepdims=1)],
        'ReduceSum': ['paddle.sum', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdims=1)],
        'ReduceMin': ['paddle.min', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdim=1)],
        'ReduceMax': ['paddle.max', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdim=1)],
S
for pad  
SunAhong1993 已提交
109 110 111
        'ReduceProd': ['paddle.prod', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdim=1)],
S
SunAhong1993 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
        # active function
        'Relu': ['paddle.nn.functional.relu'],
        'LeakyRelu': ['paddle.nn.functional.leaky_relu', 
                      dict(alpha='negative_slope'), 
                      dict(negative_slope=.01)],
        'Elu': ['paddle.nn.functional.elu', 
                dict(alpha='alpha'), 
                dict(alpha=1.)],
        'ThresholdedRelu': ['paddle.nn.functional.thresholded_relu', 
                            dict(alpha='threshold'),
                            dict(alpha=1.)],
        'Tanh': ['paddle.nn.functional.tanh'],
        'Sigmoid': ['paddle.nn.functional.sigmoid'],
        'Softsign': ['paddle.nn.functional.softsign'],
        'Softplus': ['paddle.nn.functional.softplus', 
                     dict(threshold='threshold'), 
                     dict(threshold=float(sys.maxsize))],
        'Exp': ['paddle.exp'],
        'Softmax': ['paddle.nn.functional.softmax', 
                    dict(axis='axis'), 
                    dict(axis=1)],
        'Sqrt': ['paddle.sqrt'],
        'Floor': ['paddle.floor'],
        'Abs': ['paddle.abs'],
        'Erf': ['paddle.erf'],
137 138
    }

S
SunAhong1993 已提交
139
    def __init__(self, decoder, paddle_graph):
C
Channingss 已提交
140
        super(OpSet9, self).__init__()
141
        self.graph = decoder.graph
S
SunAhong1993 已提交
142 143 144 145
        self.paddle_graph = paddle_graph
        self.input_index = 0
        self.inputs_info = dict()
        self.params = dict()
R
root 已提交
146

147
    @print_mapping_info
S
SunAhong1993 已提交
148
    def directly_map(self, node, *args, **kwargs):
C
update  
channingss 已提交
149
        inputs = node.layer.input
S
SunAhong1993 已提交
150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
        assert len(inputs) == 1, 'directly_map error with multi inputs'
        input = self.graph.get_input_node(node, idx=0, copy=True)
        onnx_attrs = node.attr_map
        if '' in onnx_attrs:
            onnx_attrs.pop('')
        if '_' in onnx_attrs:
            onnx_attrs.pop('_')
        op_info = self.directly_map_ops[node.layer_type]
        paddle_op = op_info[0]
        layer_attrs = dict()
        if len(op_info) > 1:
            attrs_name_map_dict = op_info[1]
            for onnx_attr_name, pd_attr_name in attrs_name_map_dict.items():
                if onnx_attr_name in onnx_attrs:
                    layer_attrs[pd_attr_name] = onnx_attrs[onnx_attr_name]
                else:
                    layer_attrs[pd_attr_name] = op_info[2][onnx_attr_name]
        self.paddle_graph.add_layer(
            kernel=paddle_op,
            inputs={"x": input.name},
            outputs=[node.name],
            **layer_attrs)
            
173
    @print_mapping_info
174 175 176 177
    def elementwise_map(self, node):
        op_type = self.elementwise_ops[node.layer_type]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
S
SunAhong1993 已提交
178 179 180 181 182 183 184
        inputs_dict = {'x': val_x.name, 
                       'y': val_y.name}
        self.paddle_graph.add_layer(
            op_type, 
            inputs=inputs_dict, 
            outputs=[node.name])
        
185
    @print_mapping_info
C
update  
channingss 已提交
186
    def place_holder(self, node):
C
channings 已提交
187 188
        shape = node.out_shapes[0]
        for i, dim_shape in enumerate(shape):
R
root 已提交
189 190 191
            if dim_shape == 0 and i == 0:
                shape[i] = 1
            if dim_shape == 0 and i != 0:
C
channings 已提交
192
                assert 'shape of input is not assigned'
S
SunAhong1993 已提交
193 194 195 196 197 198 199 200 201
        self.paddle_graph.add_layer(
            kernel="paddle.static.data",
            inputs={},
            outputs=[node.name],
            dtype=string(node.dtype),
            shape=shape,
            name=string(node.name))
        self.inputs_info["x{}".format(self.input_index)] = [shape, node.dtype]
        self.input_index += 1
C
update  
channingss 已提交
202

203
    @print_mapping_info
C
update  
channingss 已提交
204 205 206 207
    def create_parameter(self, node, parameter=None):
        if parameter is not None:
            node = parameter
        dtype = node.dtype
C
channingss 已提交
208
        shape = node.out_shapes[0]
S
fix  
SunAhong1993 已提交
209
        if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
S
SunAhong1993 已提交
210 211 212 213 214 215 216
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=node.weight)
217
        else:
S
SunAhong1993 已提交
218 219 220 221 222 223 224 225 226
            self.params[node.name] = node.weight
            self.paddle_graph.add_layer(
                kernel="paddle.static.create_parameter",
                inputs={},
                outputs=[node.name],
                dtype=string(dtype),
                shape=shape,
                name=string(node.name),
                default_initializer="paddle.nn.initializer.Constant(value=0.0)")
C
update  
channingss 已提交
227 228 229 230 231 232 233 234 235 236 237 238 239 240

    def _pad_if_asymmetric(self, node, pads, val_name):  # pads: SSEE
        assert len(pads) & 1 == 0
        symmetric = True
        ndims = len(pads) // 2
        for idx_dim in range(ndims):
            if pads[idx_dim] != pads[ndims + idx_dim]:
                symmetric = False
                break
        if symmetric:
            return pads[:ndims], val_name
        val_padded = self.Pad(node, op_independent=False)
        return [0] * ndims, val_padded

C
channingss 已提交
241
    def _interpolate(self, node):
C
channingss 已提交
242
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
SunAhong1993 已提交
243
        inputs = {'x': val_x.name}
S
fix  
SunAhong1993 已提交
244
        attrs = dict()
245
        if node.layer_type == 'Resize':
C
Channingss 已提交
246 247 248
            if len(node.layer.input) == 2:
                # opset 10
                val_scales = self.graph.get_input_node(node, idx=1, copy=True)
S
fix  
SunAhong1993 已提交
249 250 251 252
                # TODO(syf): paddle.nn.functional.interpolate will support the length  
                # which is the same as the rank of input.
#                 inputs['scale_factor'] = val_scales.name
                attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
C
Channingss 已提交
253 254 255
            elif len(node.layer.input) == 3:
                # opset 11
                val_scales = self.graph.get_input_node(node, idx=2, copy=True)
S
fix  
SunAhong1993 已提交
256 257 258 259
                # TODO(syf): paddle.nn.functional.interpolate will support the length  
                # which is the same as the rank of input.
#                 inputs['scale_factor'] = val_scales.name
                attrs['scale_factor'] = self.params[val_scales.name].tolist()[2:]
C
Channingss 已提交
260 261 262
            elif len(node.layer.input) == 4:
                # opset 11
                val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
S
SunAhong1993 已提交
263 264 265 266 267 268 269 270 271 272 273 274
                var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
                self.paddle_graph.add_layer(
                    'paddle.split',
                    inputs={"x": val_sizes.name},
                    outputs=[var_nc, var_hw],
                    num_or_sections=[2, 2],
                    axis=0)
                self.paddle_graph.add_layer(
                    "paddle.cast",
                    inputs={"x": var_hw},
                    outputs=[var_hw],
                    dtype=string('int32'))
S
SunAhong1993 已提交
275 276 277
                inputs['size'] = var_hw
                attrs = {"align_corners": False,
                         "mode": string(node.get_attr('mode', 'nearest'))}
S
SunAhong1993 已提交
278
                self.paddle_graph.add_layer(
S
docs  
SunAhong1993 已提交
279
                    kernel="paddle.nn.functional.interpolate",
S
SunAhong1993 已提交
280 281 282 283
                    inputs=inputs,
                    outputs=[node.name],
                    **attrs)
                return
284 285
        elif node.layer_type == 'Upsample':
            val_scales = self.graph.get_input_node(node, idx=1, copy=True)
C
Channingss 已提交
286
            inputs['scale'] = val_scales
R
root 已提交
287

C
channingss 已提交
288
        mode = node.get_attr('mode', 'nearest')
S
fix  
SunAhong1993 已提交
289
        attrs.update({"align_corners": False,
S
SunAhong1993 已提交
290
                 "mode": string(mode),
S
fix  
SunAhong1993 已提交
291
                 "align_mode": 1})
S
SunAhong1993 已提交
292 293 294
        val_x_shape = val_x.out_shapes[0]
        if mode == "linear" and len(val_x_shape) == 4:
            attrs["mode"] = string("bilinear")
S
fix  
SunAhong1993 已提交
295
            attrs["align_corners"] = True
S
SunAhong1993 已提交
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.interpolate",
            inputs=inputs,
            outputs=[node.name],
            **attrs)
        
    @print_mapping_info
    def HardSigmoid(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        alpha = node.get_attr('alpha', 0.2)
        beta = node.get_attr('beta', 0.5)
        self.paddle_graph.add_layer(
            kernel="paddle.scale",
            inputs={"x": val_x.name},
            outputs=[node.name + "_val"],
            scale=alpha,
            bias=beta)
        self.paddle_graph.add_layer(
            kernel="paddle.clip",
            inputs={"x": node.name + "_val"},
            outputs=[node.name],
            min=0.0,
            max=1.0)  
        
    @print_mapping_info
    def Shape(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            kernel="paddle.shape",
            inputs={"input": val_x.name},
            outputs=[node.name])
        self.paddle_graph.add_layer(
                'paddle.cast',
                inputs={"x": node.name},
                outputs=[node.name],
                dtype=string('int64'))   
R
root 已提交
332

333
    @print_mapping_info
C
channings 已提交
334 335 336
    def RoiAlign(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)
R
root 已提交
337 338 339

        pooled_height = node.get_attr('output_height')
        pooled_width = node.get_attr('output_width')
C
channings 已提交
340 341
        spatial_scale = node.get_attr('spatial_scale')
        sampling_ratio = node.get_attr('sampling_ratio')
S
SunAhong1993 已提交
342
        layer_attrs = {
R
root 已提交
343 344 345 346 347
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
            'sampling_ratio': sampling_ratio,
        }
S
SunAhong1993 已提交
348
        self.paddle_graph.add_layer(
S
SunAhong1993 已提交
349
            'paddle.fluid.layers.roi_align',
S
SunAhong1993 已提交
350 351 352 353
            inputs={'input': val_x.name,
                    'rois': val_rois.name},
            outputs=[node.name],
            **layer_attrs)
354 355

    @print_mapping_info
C
channings 已提交
356 357 358
    def MaxRoiPool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)
R
root 已提交
359

C
channings 已提交
360 361
        spatial_scale = node.get_attr('spatial_scale')
        pooled_height, pooled_width = node.get_attr('pooled_shape')
S
SunAhong1993 已提交
362
        layer_attrs = {
R
root 已提交
363 364 365 366
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
        }
S
SunAhong1993 已提交
367
        self.paddle_graph.add_layer(
S
SunAhong1993 已提交
368
            'paddle.fluid.layers.roi_pool',
S
SunAhong1993 已提交
369 370 371 372
            inputs={'input': val_x.name,
                    'rois': val_rois.name},
            outputs=[node.name],
            **layer_attrs)
373 374

    @print_mapping_info
C
update  
channingss 已提交
375
    def Pad(self, node, op_independent=True):
C
channingss 已提交
376
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
update  
channingss 已提交
377
        pads = node.get_attr('pads')
S
SunAhong1993 已提交
378 379 380 381 382 383 384 385
        is_pads_attr = True
        if pads is None:
            val_pad = self.graph.get_input_node(node, idx=1, copy=True)
            pad_shape = val_pad.out_shapes[0]
            is_pads_attr = False
            pads = _const_weight_or_none(val_pad)
            if pads is not None:
                is_pads_attr = True
C
update  
channingss 已提交
386 387
        mode = node.get_attr('mode', 'constant')
        value = node.get_attr('value', 0.)
C
channingss 已提交
388 389
        data_shape = val_x.out_shapes[0]
        output_shape = node.out_shapes[0]
S
for pad  
SunAhong1993 已提交
390
        assume_pad = False
S
SunAhong1993 已提交
391 392
        layer_attrs = {}
        layer_attrs['mode'] = string(mode)
S
for pad  
SunAhong1993 已提交
393 394 395 396 397 398
        layer_attrs['value'] = value
        if not op_independent:
            output_name = node.name + '_paded'
        else:
            output_name = node.name
        layer_outputs = [output_name]
S
SunAhong1993 已提交
399 400
        if is_pads_attr:
            paddings = []
S
for pad  
SunAhong1993 已提交
401 402
            paddle_op = 'paddle.nn.functional.pad'
            if len(pads) in [2, 4, 6]:
S
SunAhong1993 已提交
403
                if data_shape:
S
for pad  
SunAhong1993 已提交
404
                    assume_pad |= data_shape and 2 * (len(data_shape) - 2) == len(pads) # NCHW
S
SunAhong1993 已提交
405
                if output_shape:
S
for pad  
SunAhong1993 已提交
406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
                    assume_pad |= output_shape and 2 * (len(output_shape) - 2) == len(pads)  # NCHW
                if assume_pad:
                    if len(pads) == 2:
                        data_format = "NCL"
                    elif len(pads) == 4:
                        data_format = "NCHW"
                    else:
                        data_format = "NCDHW"
                    
                    paddings = np.array(pads).reshape(
                        (2, -1)).transpose().astype("int32")
                    paddings = np.flip(paddings, axis=0).flatten().tolist()
                    layer_attrs['pad'] = paddings
                    layer_attrs['data_format'] = data_format
                else:
                    if data_shape:
                        assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
                    if output_shape:
                        assume_pad |= output_shape and 2 * len(output_shape) == len(pads)  # NCHW
                    if assume_pad:
                        paddings = np.array(pads).reshape(
                            (2, -1)).transpose().astype("int32").flatten().tolist()
                        layer_attrs['pad'] = paddings
                    else:
                        raise Exception("The padding value {} is wrong!".format(pads))
S
SunAhong1993 已提交
431
            elif len(pads) == 8:
S
for pad  
SunAhong1993 已提交
432 433 434 435 436 437 438 439 440 441 442 443 444 445
                if data_shape:
                    assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * len(output_shape) == len(pads)  # NCHW
                if assume_pad:
                    paddings = np.array(pads).reshape(
                        (2, -1)).transpose().astype("int32")
                    paddings = np.flip(paddings, axis=0).flatten().tolist()
                    if sum(paddings[:4]) == 0:
                        paddings = paddings[4:]
                        layer_attrs['pad'] = paddings
                    else:
                        layer_attrs['pad'] = paddings
                        paddle_op = "custom_layer:pad_all_dim4_one_input"
S
SunAhong1993 已提交
446
            else:
S
for pad  
SunAhong1993 已提交
447 448 449 450 451 452 453
                 raise Exception("The padding value {} is wrong!".format(pads))
            self.paddle_graph.add_layer(
                paddle_op, 
                inputs={'x': val_x.name}, 
                outputs=layer_outputs, 
                **layer_attrs)
            if not op_independent:
S
SunAhong1993 已提交
454
                return node.name + '_paded'
C
update  
channingss 已提交
455
        else:
S
for pad  
SunAhong1993 已提交
456 457
            pads_len = val_pad.out_shapes[0][0]
            if pads_len in [2, 4, 6]:
S
SunAhong1993 已提交
458
                if data_shape:
S
for pad  
SunAhong1993 已提交
459
                    assume_pad |= data_shape and 2 * (len(data_shape) - 2) == pads_len # NCHW
S
SunAhong1993 已提交
460
                if output_shape:
S
for pad  
SunAhong1993 已提交
461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505
                    assume_pad |= output_shape and 2 * (len(output_shape) - 2) == pads_len  # NCHW 
                if assume_pad:
                    if pads_len == 2:
                        data_format = "NCL"
                    elif pads_len == 4:
                        data_format = "NCHW"
                    else:
                        data_format = "NCDHW"
                    self.paddle_graph.add_layer(
                        "custom_layer:pad_with_two_input", 
                        inputs={'x': val_x.name, 'pad': val_pad.name}, 
                        outputs=layer_outputs,
                        value=value,
                        mode=string(mode),
                        data_format=string(data_format))
                else:
                    if data_shape:
                        assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
                    if output_shape:
                        assume_pad |= output_shape and 2 * len(output_shape) == pads_len  # NCHW
                    if assume_pad:
                        if pads_len == 4:
                            self.paddle_graph.add_layer(
                                "custom_layer:pad_all_dim2", 
                                inputs={'x': val_x.name, 'pad': val_pad.name}, 
                                outputs=layer_outputs, 
                                value=value,
                                mode=string(mode))
                        else:
                            raise Exception("The padding value is wrong!")
            elif pads_len == 8:
                if data_shape:
                    assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * len(output_shape) == pads_len  # NCHW
                if assume_pad:
                    self.paddle_graph.add_layer(
                        "custom_layer:pad_all_dim4", 
                        inputs={'x': val_x.name, 'pad': val_pad.name}, 
                        outputs=layer_outputs, 
                        value=value,
                        mode=string(mode))
            else:
                print(pads_len)
                raise Exception("The padding value is wrong!")   
S
SunAhong1993 已提交
506 507
            if not op_independent:
                return node.name + '_paded'
C
update  
channingss 已提交
508

509
    @print_mapping_info
C
update  
channingss 已提交
510
    def Unsqueeze(self, node):
C
channingss 已提交
511
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
update  
channingss 已提交
512
        axes = node.get_attr('axes')
S
SunAhong1993 已提交
513
        layer_attrs = {'axis': axes}
R
root 已提交
514
        if len(val_x.out_shapes[0]) == 0:
S
SunAhong1993 已提交
515 516 517 518 519 520
            if node.name:
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": val_x.name},
                    outputs=[node.name],
                    shape=[1])
521
        else:
S
SunAhong1993 已提交
522 523 524 525 526
            self.paddle_graph.add_layer(
                'paddle.unsqueeze', 
                inputs={"x": val_x.name}, 
                outputs=[node.name],
                **layer_attrs)
527

528
    @print_mapping_info
C
channingss 已提交
529
    def Shrink(self, node):
C
channingss 已提交
530
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
channingss 已提交
531 532 533
        bias = node.get_attr('bias')
        lambd = node.get_attr('lambd')
        assert bias == 0.0, 'not support bias!=0'
S
SunAhong1993 已提交
534 535 536 537 538
        self.paddle_graph.add_layer(
            'paddle.nn.functional.hardshrink', 
            inputs={"x": val_x.name}, 
            outputs=[node.name], 
            threshold=lambd)
C
channingss 已提交
539

540
    @print_mapping_info
C
update  
channingss 已提交
541 542 543 544 545 546 547 548
    def Constant(self, node):
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        value = node.get_attr('value')
        dtype = np.dtype(value.dtype)
        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
R
root 已提交
549

C
update  
channingss 已提交
550
        shape = node.get_attr('shape', None)
R
root 已提交
551

C
update  
channingss 已提交
552
        if shape is None:
C
channingss 已提交
553
            shape = val_output.out_shapes[0]
C
update  
channingss 已提交
554 555
        if shape is None:
            shape = list(value.shape)
556 557 558
            _logger.warning('in (Constant -> %s): '
                            'attribute "shape" of %s not inferred, '
                            'using value as 1-D tensor may lead to fails',
S
SunAhong1993 已提交
559
                            val_output.name, val_output.name)
560
        if len(value) == 1:
C
channingss 已提交
561
            value = value.tolist()
C
update  
channingss 已提交
562
            value = value[0]
S
SunAhong1993 已提交
563 564 565 566 567 568 569
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=value)
C
channingss 已提交
570 571
        else:
            value = np.reshape(value, shape)
S
SunAhong1993 已提交
572 573 574 575 576 577 578 579 580
            self.params[node.name] = value
            self.paddle_graph.add_layer(
                kernel="paddle.static.create_parameter",
                inputs={},
                outputs=[node.name],
                dtype=string(dtype),
                shape=shape,
                name=string(node.name),
                default_initializer="paddle.nn.initializer.Constant(value=0.0)")
C
update  
channingss 已提交
581

582
    @print_mapping_info
C
update  
channingss 已提交
583
    def Resize(self, node):
584 585
        self._interpolate(node)

586
    @print_mapping_info
587 588 589
    def Upsample(self, node):
        self._interpolate(node)

590 591 592 593 594 595
    @print_mapping_info
    def InstanceNormalization(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = self.graph.get_input_node(node, idx=2, copy=True)
        epsilon = node.get_attr('epsilon', 1e-5)
S
SunAhong1993 已提交
596
        layer_attrs = {
S
fix  
SunAhong1993 已提交
597
            'eps': epsilon,
598
        }
S
SunAhong1993 已提交
599 600
        dim = len(val_x.out_shapes[0])
        if dim ==2 :
S
fix  
SunAhong1993 已提交
601
            layer_attrs["data_format"] = string("NC")
S
SunAhong1993 已提交
602
        elif dim == 3:
S
fix  
SunAhong1993 已提交
603
            layer_attrs["data_format"] = string("NCL")
S
SunAhong1993 已提交
604
        elif dim == 4:
S
fix  
SunAhong1993 已提交
605
            layer_attrs["data_format"] = string("NCHW")
S
SunAhong1993 已提交
606
        elif dim == 5:
S
fix  
SunAhong1993 已提交
607
            layer_attrs["data_format"] = string("NCDHW")
S
SunAhong1993 已提交
608 609 610
        else:
            raise Exception("The paddle only support 2D, 3D, 4D or 5D input in InstanceNormalization.")
        self.paddle_graph.add_layer(
S
fix  
SunAhong1993 已提交
611
            "paddle.nn.functional.instance_norm", 
S
SunAhong1993 已提交
612 613 614
            inputs={"x": val_x.name,
                    "weight": val_scale.name,
                    "bias": val_b.name}, 
S
fix  
SunAhong1993 已提交
615
            outputs=[node.name], 
S
SunAhong1993 已提交
616
            **layer_attrs)
617 618

    @print_mapping_info
619
    def Expand(self, node):
C
channingss 已提交
620
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
621
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
622
        val_x_dtype = val_x.dtype
S
SunAhong1993 已提交
623
        name_ones = node.name + '_ones'
C
Channingss 已提交
624
        attr_ones = {
S
SunAhong1993 已提交
625
            'shape': val_shape.name,
C
Channingss 已提交
626
            'dtype': string(val_x_dtype),
S
SunAhong1993 已提交
627
            'fill_value': 1
C
Channingss 已提交
628
        }
S
SunAhong1993 已提交
629 630 631 632 633 634 635 636 637 638 639
        self.paddle_graph.add_layer(
            'paddle.full',
            inputs={},
            outputs=[name_ones],
            **attr_ones)
        inputs_dict = {'x': name_ones, 
                       'y': val_x.name}
        self.paddle_graph.add_layer(
            'paddle.multiply',
            inputs=inputs_dict,
            outputs=[node.name])
C
update  
channingss 已提交
640

641
    @print_mapping_info
C
channingss 已提交
642 643 644 645
    def Gather(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        indices_shape = indices.out_shapes[0]
C
Channingss 已提交
646
        axis = node.get_attr('axis', 0)
647 648
        #assert len(
        #    indices_shape) <= 2, "Gather op don't support dim of indice >2 "
R
root 已提交
649
        if axis == 0 and len(indices_shape) <= 1:
C
Channingss 已提交
650
            if len(val_x.out_shapes[0]) <= 1:
S
SunAhong1993 已提交
651 652 653 654 655
                self.paddle_graph.add_layer(
                    'paddle.gather',
                    inputs={'x': val_x.name,
                            'index': indices.name},
                    outputs=[node.name])
C
Channingss 已提交
656 657
            elif len(val_x.out_shapes[0]) > 1:
                if len(indices_shape) == 0:
S
SunAhong1993 已提交
658 659 660 661 662 663 664 665 666 667 668
                    gather_ = node.name + '_1'
                    self.paddle_graph.add_layer(
                        'paddle.gather',
                        inputs={'x': val_x.name,
                                'index': indices.name},
                        outputs=[gather_])
                    self.paddle_graph.add_layer(
                        'paddle.squeeze',
                        inputs={'x': gather_},
                        outputs=[node.name],
                        axis=[0])
C
Channingss 已提交
669
                else:
S
SunAhong1993 已提交
670 671 672 673 674
                    self.paddle_graph.add_layer(
                        'paddle.gather',
                        inputs={'x': val_x.name,
                                'index': indices.name},
                        outputs=[node.name])
C
channingss 已提交
675 676
        elif axis > 0 and len(indices_shape) <= 1:
            perm = list(range(len(val_x.out_shapes[0])))
C
channingss 已提交
677
            perm = [axis] + perm[:axis] + perm[axis + 1:]
S
SunAhong1993 已提交
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693
            name_trans = val_x.name + '_trans'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": val_x.name},
                outputs=[name_trans],
                perm=perm)
            self.paddle_graph.add_layer(
                'paddle.gather',
                inputs={'x': name_trans,
                        'index': indices.name},
                outputs=[node.name])
            self.paddle_graph.add_layer(
                'paddle.transpose', 
                inputs={"x": node.name}, 
                outputs=[node.name], 
                perm=perm)
C
Channingss 已提交
694
            if len(indices_shape) < 1:
S
SunAhong1993 已提交
695 696 697 698 699
                self.paddle_graph.add_layer(
                    'paddle.squeeze',
                    inputs={'x': node.name},
                    outputs=[node.name],
                    axis=[axis])
700 701 702
        elif axis == 0 and len(indices_shape) > 1:
            if val_x.out_shapes[0] is not None and isinstance(
                    val_x, ONNXGraphDataNode):
S
SunAhong1993 已提交
703 704 705 706
                indices_cast = indices.name + '_cast'
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={"x": indices.name},
S
SunAhong1993 已提交
707
                    outputs=[indices_cast],
S
SunAhong1993 已提交
708 709
                    dtype=string('int64'))
                self.paddle_graph.add_layer(
S
for pad  
SunAhong1993 已提交
710 711 712 713
                    'paddle.nn.functional.embedding',
                    inputs={"x": indices_cast,
                            "weight": val_x.name},
                    outputs=[node.name])
714 715 716
            else:
                from functools import reduce
                reshape_shape = reduce(lambda x, y: x * y, indices_shape)
S
SunAhong1993 已提交
717 718 719 720 721 722
                indices_reshape = indices.name + '_shape'
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": indices.name},
                    outputs=[indices_reshape],
                    shape=[reshape_shape, ])
723 724

                perm = list(range(len(val_x.out_shapes[0])))
S
SunAhong1993 已提交
725 726 727
                self.paddle_graph.add_layer(
                    'paddle.gather',
                    inputs={'x': val_x.name,
728
                            'index': indices_reshape},
S
SunAhong1993 已提交
729
                    outputs=[node.name])
730 731 732 733 734 735
                val_x_shape = val_x.out_shapes[0]
                reshaped_shape = []
                for i in perm:
                    reshaped_shape.append(indices_shape[i])
                for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
                    reshaped_shape.append(i)
S
SunAhong1993 已提交
736 737 738 739 740
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": node.name},
                    outputs=[node.name],
                    shape=reshaped_shape)
741
        elif axis > 0 and len(indices_shape) > 1:
C
Channingss 已提交
742
            from functools import reduce
R
root 已提交
743
            reshape_shape = reduce(lambda x, y: x * y, indices_shape)
S
SunAhong1993 已提交
744 745 746 747 748 749
            indices_reshape = indices.name + '_shape'
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": indices.name},
                outputs=[indices_reshape],
                shape=[reshape_shape, ])
R
root 已提交
750

C
Channingss 已提交
751 752
            perm = list(range(len(val_x.out_shapes[0])))
            perm = [axis] + perm[:axis] + perm[axis + 1:]
S
SunAhong1993 已提交
753 754 755 756 757 758 759 760 761
            name_trans = val_x.name + '_transpose'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": val_x.name},
                outputs=[name_trans],
                perm=perm)
            self.paddle_graph.add_layer(
                'paddle.gather',
                inputs={'x': name_trans,
762
                        'index': indices_reshape},
S
SunAhong1993 已提交
763 764 765 766 767 768 769
                outputs=[node.name])
            input_transpose = node.name + '_transpose'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": node.name},
                outputs=[input_transpose],
                perm=perm)
C
Channingss 已提交
770 771 772 773 774 775
            val_x_shape = val_x.out_shapes[0]
            reshaped_shape = []
            for i in perm:
                reshaped_shape.append(indices_shape[i])
            for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
                reshaped_shape.append(i)
S
SunAhong1993 已提交
776 777 778 779 780
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": input_transpose},
                outputs=[node.name],
                shape=reshaped_shape)
781

C
Channingss 已提交
782 783 784 785 786 787
    @print_mapping_info
    def ScatterND(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        updates = self.graph.get_input_node(node, idx=2, copy=True)
        if len(indices.out_shapes[0]) == 1:
S
SunAhong1993 已提交
788 789 790 791 792 793
            self.paddle_graph.add_layer(
                'paddle.scatter',
                inputs={'x': val_x.name,
                        'index': indices.name,
                        'updates': updates.name},
                outputs=[node.name])
C
Channingss 已提交
794
        else:
S
SunAhong1993 已提交
795
            input_inner_indices = node.name + '_input_inner_indices'
796
            shape = val_x.out_shapes[0]
S
SunAhong1993 已提交
797 798 799 800 801 802 803 804 805 806 807 808 809
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": indices.name},
                outputs=[indices.name],
                shape=indices.out_shapes[0])

            zeros_like_val_x = val_x.name + '_zeros'
            self.paddle_graph.add_layer(
                'paddle.zeros_like',
                inputs={"x": val_x.name},
                outputs=[zeros_like_val_x])
            self.paddle_graph.add_layer(
                'paddle.scatter_nd_add',
C
Channingss 已提交
810
                inputs={
S
SunAhong1993 已提交
811 812 813
                    'x': zeros_like_val_x,
                    'index': indices.name,
                    'updates': updates.name
C
Channingss 已提交
814
                },
S
SunAhong1993 已提交
815 816 817
                outputs=[input_inner_indices])
            indices_mask = node.name + '_indices_mask'
            constant_minus_one = node.name + '_constant_minus_one'
818
            # full_like support create tensor shape like input tensor
S
SunAhong1993 已提交
819 820 821 822 823 824 825 826
            self.paddle_graph.add_layer(
                'paddle.full_like',
                inputs={"x": updates.name},
                outputs=[constant_minus_one],
                dtype=string(updates.dtype),
                fill_value=-1)
            self.paddle_graph.add_layer(
                'paddle.scatter_nd_add',
C
Channingss 已提交
827
                inputs={
S
SunAhong1993 已提交
828 829
                    'x': zeros_like_val_x,
                    'index': indices.name,
C
Channingss 已提交
830 831
                    'updates': constant_minus_one
                },
S
SunAhong1993 已提交
832 833
                outputs=[indices_mask])
            constant_one = node.name + '_constant_1'
834
            # full_like support create tensor shape like input tensor
S
SunAhong1993 已提交
835 836 837 838 839 840 841 842 843
            self.paddle_graph.add_layer(
                'paddle.full_like',
                inputs={"x": val_x.name},
                outputs=[constant_one],
                dtype=string(val_x.dtype),
                fill_value=1)
            input_out_indices_mask = node.name + '_input_out_indices_mask'
            self.paddle_graph.add_layer(
                "paddle.add",
C
Channingss 已提交
844
                inputs={"x": indices_mask,
845
                        "y": constant_one},
S
SunAhong1993 已提交
846
                outputs=[input_out_indices_mask])
C
Channingss 已提交
847

S
SunAhong1993 已提交
848 849 850 851
            input_out_indices = node.name + '_input_out_indices'
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs={"x": val_x.name,
C
Channingss 已提交
852
                        "y": input_out_indices_mask},
S
SunAhong1993 已提交
853
                outputs=[input_out_indices])
C
Channingss 已提交
854

S
SunAhong1993 已提交
855 856
            self.paddle_graph.add_layer(
                "paddle.add",
C
Channingss 已提交
857 858
                inputs={"x": input_inner_indices,
                        "y": input_out_indices},
S
SunAhong1993 已提交
859
                outputs=[node.name])
C
Channingss 已提交
860

861 862 863 864 865 866
    @print_mapping_info
    def Range(self, node):
        val_start = self.graph.get_input_node(node, idx=0, copy=True)
        val_limit = self.graph.get_input_node(node, idx=1, copy=True)
        val_delta = self.graph.get_input_node(node, idx=2, copy=True)
        dtype = val_start.dtype
S
SunAhong1993 已提交
867 868 869 870 871
        inputs = {'start': val_start.name, 
                  'end': val_limit.name, 
                  'step': val_delta.name}
        self.paddle_graph.add_layer(
            'paddle.arange',
872
            inputs=inputs,
S
SunAhong1993 已提交
873 874
            outputs=[node.name],
            dtype=string(dtype))
875 876

    @print_mapping_info
C
channingss 已提交
877
    def Slice(self, node):
C
channingss 已提交
878
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
channings 已提交
879
        starts, ends, axes, steps = None, None, None, None
S
SunAhong1993 已提交
880
        layer_attrs = {}
C
channingss 已提交
881 882 883
        if len(node.inputs) > 1:
            starts = self.graph.get_input_node(node, idx=1, copy=True)
            ends = self.graph.get_input_node(node, idx=2, copy=True)
C
Channingss 已提交
884
            starts_value = _const_weight_or_none(starts)
S
for pad  
SunAhong1993 已提交
885 886
            if starts_value is not None:
                starts_value = starts_value.tolist()
C
Channingss 已提交
887
            ends_value = _const_weight_or_none(ends)
S
for pad  
SunAhong1993 已提交
888 889 890 891 892
            if ends_value is not None:
                ends_value = ends_value.tolist()
            if len(node.inputs) > 2:
                s_len = len(val_x.out_shapes[0])
                axes = list(range(s_len))
R
root 已提交
893
            if len(node.inputs) > 3:
S
for pad  
SunAhong1993 已提交
894 895
                axes_node = self.graph.get_input_node(node, idx=3, copy=True)
                axes = _const_weight_or_none(axes_node, necessary=True).tolist()
R
root 已提交
896
            if len(node.inputs) > 4:
C
channings 已提交
897
                steps = self.graph.get_input_node(node, idx=4, copy=True)
S
for pad  
SunAhong1993 已提交
898 899
                steps = _const_weight_or_none(steps).tolist()
            
S
SunAhong1993 已提交
900
            layer_attrs = {
901
                "axes": axes,
S
SunAhong1993 已提交
902 903
                "starts": starts.name,
                "ends": ends.name
904
            }
S
SunAhong1993 已提交
905
            if starts_value is not None and ends_value is not None and axes is not None:
C
Channingss 已提交
906
                starts_value = starts_value.copy()
907
                ends_value = ends_value.copy()
908 909 910 911
                #for idx in range(len(ends_value)):
                #    if ends_value[idx] > 2**31 - 1:
                #        ends_value[idx] = 2**31 - 1
                #print(val_x.out_shapes)
912
                for idx in range(len(ends_value)):
913 914
                    if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
                        starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
C
Channingss 已提交
915
                        ends_value[idx] = val_x.out_shapes[0][axes[idx]]
916
                        starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
C
Channingss 已提交
917
                    elif ends_value[idx] > 2**31 - 1:
918
                        ends_value[idx] = 2**31 - 1
S
SunAhong1993 已提交
919
                layer_attrs = {
920 921 922 923 924 925
                    "axes": axes,
                    "starts": starts_value,
                    "ends": ends_value
                }
            else:
                if starts.dtype != 'int32':
S
SunAhong1993 已提交
926 927 928 929 930 931 932
                    starts_cast = starts.name + '_cast'
                    self.paddle_graph.add_layer(
                        'paddle.cast',
                        inputs={"x": starts.name},
                        outputs=[starts_cast],
                        dtype=string('int32'))
                    layer_attrs['starts'] = starts_cast
933
                if ends.dtype != 'int32':
S
SunAhong1993 已提交
934
                    ends_cast = ends.name + '_cast'
S
for pad  
SunAhong1993 已提交
935 936
                else:
                    ends_cast = ends.name
S
SunAhong1993 已提交
937 938 939 940 941 942
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={"x": ends.name},
                    outputs=[ends_cast],
                    dtype=string('int32'))
                layer_attrs['ends'] = ends_cast
C
channingss 已提交
943 944 945 946
        else:
            starts = node.get_attr('starts')
            ends = node.get_attr('ends')
            axes = node.get_attr('axes')
947 948 949
            for idx in range(len(ends)):
                if ends[idx] > 2**31 - 1:
                    ends[idx] = 2**31 - 1
S
SunAhong1993 已提交
950
            layer_attrs = {"axes": axes, "starts": starts, "ends": ends}
C
channingss 已提交
951

S
for pad  
SunAhong1993 已提交
952

C
Channingss 已提交
953
        if steps is not None:
S
SunAhong1993 已提交
954 955 956 957 958 959
            layer_attrs['strides'] = steps
            self.paddle_graph.add_layer(
                'paddle.strided_slice', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
C
Channingss 已提交
960
        else:
S
SunAhong1993 已提交
961 962 963 964 965
            self.paddle_graph.add_layer(
                'paddle.slice', 
                inputs={"input": val_x.name}, 
                outputs=[node.name],  
                **layer_attrs)
C
channingss 已提交
966

967
    @print_mapping_info
C
update  
channingss 已提交
968
    def ConstantOfShape(self, node):
C
channingss 已提交
969
        val_shape = self.graph.get_input_node(node, idx=0, copy=True)
C
channingss 已提交
970
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
C
update  
channingss 已提交
971 972 973 974

        value = node.get_attr('value')
        dtype = value.dtype
        value = value.tolist()
975 976
        assert len(value) == 1, ('given value not Scalar, shape of value > 1, '
                                 'this is not supported')
C
update  
channingss 已提交
977 978
        if len(value) == 1:
            value = value[0]
S
SunAhong1993 已提交
979 980
            layer_attrs = {
                'shape': val_shape.name,
981
                'dtype': string(dtype),
S
SunAhong1993 已提交
982
                'fill_value': value
983
            }
S
SunAhong1993 已提交
984 985 986 987 988
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                **layer_attrs)
C
update  
channingss 已提交
989

C
Channingss 已提交
990 991 992 993 994 995 996 997
    @print_mapping_info
    def Clip(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        max_value, min_value = None, None
        if len(node.inputs) == 1:
            max_value = node.get_attr('max')
            min_value = node.get_attr('min')
S
SunAhong1993 已提交
998
            layer_attrs = {
C
Channingss 已提交
999 1000 1001
                'max': max_value,
                'min': min_value,
            }
S
SunAhong1993 已提交
1002 1003 1004 1005 1006
            self.paddle_graph.add_layer(
                'paddle.clip', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
C
Channingss 已提交
1007
        else:
S
fix  
SunAhong1993 已提交
1008 1009
            min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
            max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
S
SunAhong1993 已提交
1010
            min_value = _const_weight_or_none(min_ipt)
S
fix  
SunAhong1993 已提交
1011
            max_value = _const_weight_or_none(max_ipt)
1012
            if max_value.shape == (1, ):
C
Channingss 已提交
1013
                max_value = max_value[0]
1014
            if min_value.shape == (1, ):
C
Channingss 已提交
1015 1016
                min_value = min_value[0]
        if max_value is not None and min_value is not None:
S
SunAhong1993 已提交
1017 1018 1019 1020 1021 1022
            layer_attrs = {'max': max_value, 'min': min_value}
            self.paddle_graph.add_layer(
                'paddle.clip', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
C
Channingss 已提交
1023 1024 1025
        else:
            raise

1026
    @print_mapping_info
C
update  
channingss 已提交
1027
    def Split(self, node):
C
channingss 已提交
1028
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
SunAhong1993 已提交
1029
        paddle_op = 'split'
C
channingss 已提交
1030
        split = node.get_attr('split')
C
update  
channingss 已提交
1031
        axis = node.get_attr('axis', 0)
S
SunAhong1993 已提交
1032
        layer_attrs = {
C
channingss 已提交
1033
            'num_or_sections': split,
S
SunAhong1993 已提交
1034
            'axis': axis,
C
channingss 已提交
1035
        }
S
SunAhong1993 已提交
1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046
        outputs_list = list()
        if isinstance(split, list) or isinstance(split, tuple):
            for i in range(len(split)):
                outputs_list.append("{}_p{}".format(node.layer_name, i))
        else:
            outputs_list.append(node.name)
        self.paddle_graph.add_layer(
            'paddle.split', 
            inputs={"x": val_x.name}, 
            outputs=outputs_list, 
            **layer_attrs)
C
update  
channingss 已提交
1047

1048
    @print_mapping_info
C
update  
channingss 已提交
1049
    def Reshape(self, node):
C
channingss 已提交
1050 1051
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
C
update  
channingss 已提交
1052
        val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
1053 1054 1055 1056
        shape_value = _const_weight_or_none(val_shape)
        shape_dims = len(val_shape.out_shapes[0])

        if shape_value is not None:
S
SunAhong1993 已提交
1057 1058 1059 1060 1061
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name},
                outputs=[node.name],
                shape=shape_value.tolist())
C
Channingss 已提交
1062 1063
        elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
                0]):
S
SunAhong1993 已提交
1064 1065 1066 1067 1068
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
1069
        else:
1070 1071
            # shape may be [], come form Gather by scalar indices
            if len(val_shape.out_shapes[0]) > 0:
S
SunAhong1993 已提交
1072 1073 1074 1075 1076
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={'x': val_shape.name},
                    outputs=[val_shape.name],
                    shape=val_shape.out_shapes[0])
S
for pad  
SunAhong1993 已提交
1077 1078 1079 1080 1081 1082
            if val_shape.dtype != "int32":
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={'x': val_shape.name},
                    outputs=[val_shape.name],
                    dtype=string("int32"))
S
SunAhong1993 已提交
1083 1084 1085 1086
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name,
                        'shape': val_shape.name},
S
SunAhong1993 已提交
1087
                outputs=[node.name])
1088 1089

    @print_mapping_info
C
update  
channingss 已提交
1090
    def Cast(self, node):
C
channingss 已提交
1091
        val_input = self.graph.get_input_node(node, idx=0, copy=True)
C
update  
channingss 已提交
1092 1093 1094 1095 1096 1097 1098 1099 1100
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        dtype = node.get_attr('to')
        if not isinstance(dtype, np.dtype):
            dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]

        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, 'dtype of to unmatches output'
S
SunAhong1993 已提交
1101 1102 1103 1104 1105
        self.paddle_graph.add_layer(
            'paddle.cast', 
            inputs={'x': val_input.name}, 
            outputs=[node.name], 
            dtype=string(dtype))
C
update  
channingss 已提交
1106

C
Channingss 已提交
1107 1108 1109
    @print_mapping_info
    def Not(self, node):
        val_input = self.graph.get_input_node(node, idx=0, copy=True)
S
SunAhong1993 已提交
1110 1111 1112
        self.paddle_graph.add_layer('paddle.logical_not', 
                                    inputs={'x': val_input.name}, 
                                    outputs=[node.name])
C
Channingss 已提交
1113

1114
    @print_mapping_info
C
update  
channingss 已提交
1115
    def AveragePool(self, node):
C
channingss 已提交
1116
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
channingss 已提交
1117 1118

        auto_pad = node.get_attr('auto_pad', 'NOTSET')
C
update  
channingss 已提交
1119 1120 1121 1122 1123 1124
        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))
        pads = node.get_attr('pads', [0] * (poolnd * 2))
C
channingss 已提交
1125

C
channingss 已提交
1126 1127
        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

C
channingss 已提交
1128
        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
C
channingss 已提交
1129
            input_shape = val_x.out_shapes[0]
C
Channingss 已提交
1130 1131 1132 1133 1134
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w
C
channingss 已提交
1135

S
SunAhong1993 已提交
1136 1137
        paddle_op = 'paddle.nn.functional.avg_pool{}d'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only avg_pool1d, avg_pool2d and avg_pool3d are supported'
S
SunAhong1993 已提交
1138
        layer_attrs = {
S
SunAhong1993 已提交
1139 1140 1141
            "kernel_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
C
update  
channingss 已提交
1142
            "ceil_mode": ceil_mode,
S
SunAhong1993 已提交
1143
            "exclusive": True,
S
SunAhong1993 已提交
1144
            "name": string(node.name)
C
update  
channingss 已提交
1145
        }
S
SunAhong1993 已提交
1146 1147
        self.paddle_graph.add_layer(
            paddle_op, 
S
SunAhong1993 已提交
1148
            inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, 
S
SunAhong1993 已提交
1149 1150
            outputs=[node.name], 
            **layer_attrs)
C
update  
channingss 已提交
1151

1152
    @print_mapping_info
C
update  
channingss 已提交
1153
    def Concat(self, node):
S
SunAhong1993 已提交
1154
        inputs_list = []
C
Channingss 已提交
1155
        dtypes = set()
C
update  
channingss 已提交
1156
        for i in range(len(node.layer.input)):
C
channingss 已提交
1157
            ipt = self.graph.get_input_node(node, idx=i, copy=True)
S
SunAhong1993 已提交
1158 1159
            inputs_list.append(ipt.name)
            dtypes.add(ipt.dtype)
C
Channingss 已提交
1160 1161
        if len(dtypes) > 1:
            assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
C
update  
channingss 已提交
1162
        axis = node.get_attr('axis')
S
SunAhong1993 已提交
1163 1164 1165 1166 1167
        self.paddle_graph.add_layer(
            'paddle.concat', 
            inputs={"x": inputs_list}, 
            outputs=[node.name], 
            axis=axis)
C
update  
channingss 已提交
1168

1169
    @print_mapping_info
C
update  
channingss 已提交
1170
    def Flatten(self, node):
C
channingss 已提交
1171
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
SunAhong1993 已提交
1172
        output_shape = node.out_shapes[0]
C
update  
channingss 已提交
1173
        axis = node.get_attr('axis', 1)
S
SunAhong1993 已提交
1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187
        shape_list = [1, 1]
        if axis == 0:
            for s in output_shape:
                shape_list[1] *= s
        else:
            for s in output_shape[:axis]:
                shape_list[0] *= s
            for s in output_shape[axis:]:
                shape_list[1] *= s
        self.paddle_graph.add_layer(
            'paddle.reshape', 
            inputs={"x": val_x.name}, 
            outputs=[node.name],
            shape=shape_list)
C
update  
channingss 已提交
1188

1189
    @print_mapping_info
C
update  
channingss 已提交
1190
    def Gemm(self, node):
C
channingss 已提交
1191 1192 1193
        val_a = self.graph.get_input_node(node, idx=0, copy=True)
        val_b = self.graph.get_input_node(node, idx=1, copy=True)
        val_c = self.graph.get_input_node(node, idx=2, copy=True)
C
update  
channingss 已提交
1194 1195 1196 1197 1198

        alpha = node.get_attr('alpha', 1.)  # optional
        beta = node.get_attr('beta', 1.)  # optional
        trans_a = bool(node.get_attr('transA', 0))  # optional
        trans_b = bool(node.get_attr('transB', 0))  # optional
S
SunAhong1993 已提交
1199 1200 1201
        val_mm = node.name + '_mm'
        matmul_inputs = {"x": val_a.name, 
                         "y": val_b.name}
C
update  
channingss 已提交
1202 1203 1204 1205
        attr_matmul = {
            "transpose_x": trans_a,
            "transpose_y": trans_b,
        }
S
SunAhong1993 已提交
1206 1207
        self.paddle_graph.add_layer(
            'paddle.matmul',
1208
            inputs=matmul_inputs,
S
SunAhong1993 已提交
1209 1210 1211 1212 1213 1214 1215
            outputs=[val_mm],
            **attr_matmul)
        self.paddle_graph.add_layer(
            "paddle.scale", 
            inputs={"x": val_mm}, 
            outputs=[val_mm],
            scale=alpha)
C
channingss 已提交
1216

C
update  
channingss 已提交
1217 1218
        if beta != 0:
            if beta == 1.:
S
SunAhong1993 已提交
1219 1220 1221 1222
                add_inputs = {"x": val_mm, 
                              "y": val_c.name}
                self.paddle_graph.add_layer(
                    "paddle.add",
1223
                    inputs=add_inputs,
S
SunAhong1993 已提交
1224
                    outputs=[node.name])
C
update  
channingss 已提交
1225
            else:
S
SunAhong1993 已提交
1226 1227 1228 1229 1230 1231
                var_beta = node.name + '_beta'
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": val_c.name},
                    outputs=[var_beta],
                    scale=beta)
C
channingss 已提交
1232
                add_inputs = {"x": val_mm, "y": var_beta}
S
SunAhong1993 已提交
1233 1234
                self.paddle_graph.add_layer(
                    "paddle.add",
1235
                    inputs=add_inputs,
S
SunAhong1993 已提交
1236
                    outputs=[node.name])
C
update  
channingss 已提交
1237

1238
    @print_mapping_info
C
update  
channingss 已提交
1239
    def Sum(self, node):
1240
        val_inps = node.layer.input
S
SunAhong1993 已提交
1241
        inputs_dict = {
1242
            "x": self.graph.get_input_node(
S
SunAhong1993 已提交
1243
                node, idx=0, copy=True).name,
1244
            "y": self.graph.get_input_node(
S
SunAhong1993 已提交
1245
                node, idx=1, copy=True).name,
1246
        }
S
SunAhong1993 已提交
1247 1248 1249
        self.paddle_graph.add_layer("paddle.add", 
                                    inputs=inputs_dict, 
                                    outputs=[node.name])
1250

C
channingss 已提交
1251 1252
        for idx, ipt in enumerate(val_inps[2:]):
            y = self.graph.get_input_node(node, idx=idx, copy=True)
S
SunAhong1993 已提交
1253 1254 1255
            inputs_dict = {
                "x": node.name,
                "y": y.name,
1256
            }
S
SunAhong1993 已提交
1257 1258 1259 1260
            self.paddle_graph.add_layer(
                "paddle.add", 
                inputs=inputs_dict, 
                outputs=[node.name])
C
update  
channingss 已提交
1261

1262
    @print_mapping_info
C
update  
channingss 已提交
1263
    def MatMul(self, node):
C
channingss 已提交
1264 1265
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
C
Channingss 已提交
1266 1267
        x_shape = val_x.out_shapes[0]
        y_shape = val_y.out_shapes[0]
S
SunAhong1993 已提交
1268 1269
        inputs_dict = {"x": val_x.name, 
                       "y": val_y.name}
C
Channingss 已提交
1270
        if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
S
SunAhong1993 已提交
1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281
            y_squeeze = val_y.name + '_squeeze'
            self.paddle_graph.add_layer(
                "paddle.squeeze",
                inputs={"x": val_y.name},
                outputs=[y_squeeze],
                axis=[0])
            inputs_dict['y'] = y_squeeze
            self.paddle_graph.add_layer(
                "paddle.matmul", 
                inputs=inputs_dict, 
                outputs=[node.name])
C
Channingss 已提交
1282
        else:
S
SunAhong1993 已提交
1283 1284 1285 1286 1287
            self.paddle_graph.add_layer(
                "paddle.matmul", 
                inputs=inputs_dict, 
                outputs=[node.name])
            
1288
    @print_mapping_info
C
update  
channingss 已提交
1289
    def BatchNormalization(self, node):
C
channingss 已提交
1290 1291 1292 1293 1294
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = self.graph.get_input_node(node, idx=2, copy=True)
        val_mean = self.graph.get_input_node(node, idx=3, copy=True)
        val_var = self.graph.get_input_node(node, idx=4, copy=True)
C
update  
channingss 已提交
1295 1296 1297 1298

        momentum = node.get_attr('momentum', .9)
        epsilon = node.get_attr('epsilon', 1e-5)

C
channingss 已提交
1299 1300
        # Attribute: spatial is used in BatchNormalization-1,6,7
        spatial = bool(node.get_attr('spatial'))
S
SunAhong1993 已提交
1301
        layer_attrs = {
C
update  
channingss 已提交
1302 1303 1304
            "momentum": momentum,
            "epsilon": epsilon,
        }
S
SunAhong1993 已提交
1305 1306 1307 1308 1309 1310 1311 1312 1313 1314
        self.paddle_graph.add_layer(
            "paddle.nn.functional.batch_norm", 
            inputs={"x": val_x.name,
                    "weight": val_scale.name,
                    "bias": val_b.name,
                    "running_mean": val_mean.name,
                    "running_var": val_var.name}, 
            outputs=[node.name], 
            **layer_attrs)
        
1315
    @print_mapping_info
C
update  
channingss 已提交
1316
    def Transpose(self, node):
C
channingss 已提交
1317
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
for pad  
SunAhong1993 已提交
1318 1319 1320 1321
        s_len = len(val_x.out_shapes[0])
        perm_default = list(range(s_len))
        perm_default.reverse()
        perm = node.get_attr('perm', perm_default)
S
SunAhong1993 已提交
1322 1323 1324 1325 1326
        self.paddle_graph.add_layer(
            "paddle.transpose", 
            inputs={"x": val_x.name},
            outputs=[node.name], 
            perm=perm)
C
update  
channingss 已提交
1327

1328
    @print_mapping_info
C
update  
channingss 已提交
1329
    def PRelu(self, node):
C
channingss 已提交
1330 1331
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_slope = self.graph.get_input_node(node, idx=1, copy=True)
C
update  
channingss 已提交
1332

C
channingss 已提交
1333 1334
        mode = 'channel'
        shape_slope = val_slope.out_shapes[0]
C
Channingss 已提交
1335
        if shape_slope == [1]:
C
channingss 已提交
1336
            mode = 'all'
C
Channingss 已提交
1337

S
SunAhong1993 已提交
1338 1339 1340 1341 1342 1343 1344 1345
        if mode == "element":
            self.paddle_graph.add_layer(
                "paddle.static.nn.prelu", 
                inputs={"x": val_x.name,
                        "param_attr": val_slope.name}, 
                outputs=[node.name],
                mode="element")
        else:
S
SunAhong1993 已提交
1346 1347 1348 1349 1350 1351 1352
            if mode == 'channel':
                if len(shape_slope) > 1:
                    self.paddle_graph.add_layer(
                        "paddle.reshape", 
                        inputs={"x": val_slope.name}, 
                        outputs=[val_slope.name],
                        shape=[shape_slope[0]])
S
SunAhong1993 已提交
1353 1354 1355 1356 1357
            self.paddle_graph.add_layer(
                "paddle.nn.functional.prelu", 
                inputs={"x": val_x.name,
                        "weight": val_slope.name}, 
                outputs=[node.name])
C
update  
channingss 已提交
1358

1359
    @print_mapping_info
C
update  
channingss 已提交
1360
    def Squeeze(self, node):
C
channingss 已提交
1361 1362
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axes = node.get_attr('axes')
1363
        if len(val_x.out_shapes[0]) == 1:
S
SunAhong1993 已提交
1364 1365 1366 1367 1368
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": val_x.name},
                outputs=[node.name],
                dtype=string(val_x.dtype))
1369
        else:
S
SunAhong1993 已提交
1370 1371 1372 1373 1374
            self.paddle_graph.add_layer(
                "paddle.squeeze", 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                axis=axes)
R
root 已提交
1375

1376
    @print_mapping_info
C
channings 已提交
1377 1378 1379
    def Equal(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
S
SunAhong1993 已提交
1380 1381 1382 1383 1384
        self.paddle_graph.add_layer(
            "paddle.equal",
            inputs={'x': val_x.name,
                    'y': val_y.name},
            outputs=[node.name])
1385

C
Channingss 已提交
1386 1387 1388 1389
    @print_mapping_info
    def Greater(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
S
SunAhong1993 已提交
1390 1391 1392 1393 1394
        self.paddle_graph.add_layer(
            "paddle.greater_than",
            inputs={'x': val_x.name,
                    'y': val_y.name},
            outputs=node,
C
Channingss 已提交
1395 1396
            param_attr=None)

1397
    @print_mapping_info
C
channings 已提交
1398 1399 1400 1401
    def Where(self, node):
        condition = self.graph.get_input_node(node, idx=0, copy=True)
        val_x = self.graph.get_input_node(node, idx=1, copy=True)
        val_y = self.graph.get_input_node(node, idx=2, copy=True)
R
root 已提交
1402

S
SunAhong1993 已提交
1403 1404 1405 1406 1407
        not_condition = condition.name + '_not'
        self.paddle_graph.add_layer(
            "paddle.logical_not",
            inputs={"x": condition.name},
            outputs=[not_condition])
R
root 已提交
1408
        cast_not_condition = not_condition + '_cast'
S
SunAhong1993 已提交
1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423
        self.paddle_graph.add_layer(
            "paddle.cast",
            inputs={"x": not_condition},
            outputs=[cast_not_condition],
            dtype=string(val_x.dtype))
        cast_condition = condition.name + '_cast'
        self.paddle_graph.add_layer(
            "paddle.cast",
            inputs={"x": condition.name},
            outputs=[cast_condition],
            dtype=string(val_x.dtype))
        mul_val_x = val_x.name + '_mul'
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={'x': val_x.name,
1424
                    'y': cast_condition},
S
SunAhong1993 已提交
1425 1426 1427 1428 1429
            outputs=[mul_val_x])
        mul_val_y = val_y.name + '_mul'
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={'x': val_y.name,
1430
                    'y': cast_not_condition},
S
SunAhong1993 已提交
1431
            outputs=[mul_val_y])
1432

S
SunAhong1993 已提交
1433 1434
        self.paddle_graph.add_layer(
            "paddle.add",
1435 1436
            inputs={'x': mul_val_x,
                    'y': mul_val_y},
S
SunAhong1993 已提交
1437
            outputs=[node.name])
1438 1439

    @print_mapping_info
R
root 已提交
1440 1441
    def NonZero(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
1442 1443
        val_x_dim = len(val_x.out_shapes[0])
        if val_x_dim == 1:
S
SunAhong1993 已提交
1444 1445 1446 1447 1448 1449 1450 1451 1452
            self.paddle_graph.add_layer(
                "paddle.nonzero", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name])
            self.paddle_graph.add_layer(
                "paddle.transpose",
                inputs={"x": val_x.name},
                outputs=[node.layer_naem],
                perm=[1, 0])
1453
        if val_x_dim > 1:
S
SunAhong1993 已提交
1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467
            self.paddle_graph.add_layer(
                "paddle.nonzero", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name])
            self.paddle_graph.add_layer(
                "paddle.split",
                inputs={"x": val_x.name}, 
                outputs=[val_x.name],
                num_or_sections=1,
                axis=val_x_dim)
            self.paddle_graph.add_layer(
                "paddle.concat", 
                inputs={"x": val_x.name}, 
                outputs=[node.name])
1468 1469

    @print_mapping_info
C
update  
channingss 已提交
1470
    def Identity(self, node):
C
channingss 已提交
1471
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
SunAhong1993 已提交
1472 1473 1474 1475 1476
        self.paddle_graph.add_layer(
            "paddle.assign", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])
        
1477
    @print_mapping_info
C
channings 已提交
1478 1479 1480 1481
    def Tile(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
        repeats = _const_weight_or_none(val_repeats)
R
root 已提交
1482

1483
        if repeats is None:
S
SunAhong1993 已提交
1484
            repeats = val_repeats.name
J
jiangjiajun 已提交
1485
            if val_repeats.dtype != 'int32':
S
SunAhong1993 已提交
1486 1487 1488 1489 1490
                self.paddle_graph.add_layer(
                    "paddle.cast",
                    inputs={"x": repeats},
                    outputs=["{}.tmp".format(repeats)],
                    dtype=string("int32"))
J
jiangjiajun 已提交
1491 1492
                repeats = "{}.tmp".format(repeats)

1493
        elif isinstance(repeats, int):
C
channings 已提交
1494
            repeats = [repeats]
R
root 已提交
1495

C
channings 已提交
1496
        attr = {
R
root 已提交
1497
            'expand_times': repeats,
S
SunAhong1993 已提交
1498
            "name": string(node.name),
C
channings 已提交
1499
        }
S
SunAhong1993 已提交
1500 1501 1502 1503 1504
        self.paddle_graph.add_layer(
            "paddle.tile", 
            inputs={"x": val_x.name}, 
                    outputs=[node.name], 
                    repeat_times=repeats)
R
root 已提交
1505

1506
    @print_mapping_info
C
update  
channingss 已提交
1507
    def MaxPool(self, node):
C
channingss 已提交
1508
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
C
channingss 已提交
1509
        auto_pad = node.get_attr('auto_pad', 'NOTSET')
C
update  
channingss 已提交
1510 1511 1512 1513 1514 1515 1516 1517 1518
        assert node.get_attr(
            "dilations") is None, 'only dilations = 0 is supported'  # optional

        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))  # optional
        pads = node.get_attr('pads', [0] * (poolnd * 2))  # optional
S
SunAhong1993 已提交
1519 1520
        paddle_op = 'paddle.nn.functional.max_pool{}d'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only max_pool1d, max_pool2d and max_pool3d are supported'
C
channingss 已提交
1521

C
channingss 已提交
1522 1523
        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

C
channingss 已提交
1524
        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
C
channingss 已提交
1525
            input_shape = val_x.out_shapes[0]
C
Channingss 已提交
1526 1527 1528 1529 1530
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w
S
SunAhong1993 已提交
1531 1532 1533 1534 1535
            
        layer_attrs = {
            "kernel_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
C
update  
channingss 已提交
1536 1537
            "ceil_mode": ceil_mode,
        }
S
SunAhong1993 已提交
1538 1539 1540 1541 1542
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, 
            outputs=[node.name], 
            **layer_attrs)
R
root 已提交
1543

1544
    @print_mapping_info
C
channings 已提交
1545
    def GlobalMaxPool(self, node):
S
SunAhong1993 已提交
1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = val_x.out_shapes[0]
        if len(input_shape) == 4:
            poolnd = 2
        elif len(input_shape) == 5:
            poolnd = 3
        elif len(input_shape) == 3:
            poolnd = 1
        paddle_op = 'paddle.nn.functional.adaptive_max_pool{}d'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only adaptive_max_pool1d, adaptive_max_pool2d and adaptive_max_pool3d are supported'
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x.name}, 
            outputs=[node.name], 
            output_size=output_shape[2:])
        
1563
    @print_mapping_info
C
channings 已提交
1564
    def GlobalAveragePool(self, node):
S
SunAhong1993 已提交
1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = val_x.out_shapes[0]
        if len(input_shape) == 4:
            poolnd = 2
        elif len(input_shape) == 5:
            poolnd = 3
        elif len(input_shape) == 3:
            poolnd = 1
        paddle_op = 'paddle.nn.functional.adaptive_avg_pool{}d'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x.name}, 
            outputs=[node.name], 
            output_size=output_shape[2:])
R
root 已提交
1581

1582
    @print_mapping_info
C
update  
channingss 已提交
1583
    def Conv(self, node):
C
channingss 已提交
1584 1585
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
C
update  
channingss 已提交
1586 1587
        has_bias = len(node.layer.input) == 3
        if has_bias:
C
channingss 已提交
1588
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
C
update  
channingss 已提交
1589 1590
        auto_pad = node.get_attr('auto_pad', 'NOTSET')

C
channingss 已提交
1591
        kernel_shape = node.get_attr('kernel_shape')
C
update  
channingss 已提交
1592 1593
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
C
Channingss 已提交
1594
        num_out_channels = val_w.out_shapes[0][0]
S
SunAhong1993 已提交
1595 1596
        num_in_channels = val_w.out_shapes[0][1]
        paddle_op = 'paddle.nn.functional.conv{}d'.format(convnd)
C
update  
channingss 已提交
1597 1598

        num_groups = node.get_attr('group', 1)
C
Channingss 已提交
1599 1600 1601
        strides = node.get_attr('strides', [1] * convnd)
        dilations = node.get_attr('dilations', [1] * convnd)
        pads = node.get_attr('pads', [0] * (convnd * 2))
C
update  
channingss 已提交
1602

C
channingss 已提交
1603
        input_shape = val_x.out_shapes[0]
C
update  
channingss 已提交
1604 1605
        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

C
channingss 已提交
1606
        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
C
Channingss 已提交
1607 1608 1609 1610 1611
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w
C
update  
channingss 已提交
1612

S
SunAhong1993 已提交
1613
        layer_attrs = {
C
update  
channingss 已提交
1614 1615 1616 1617
            "stride": strides,
            "padding": paddings,
            "dilation": dilations,
            "groups": num_groups,
S
SunAhong1993 已提交
1618 1619 1620 1621
        }
        layer_inputs = {
            "x": val_x.name,
            "weight": val_w.name
C
update  
channingss 已提交
1622 1623
        }
        if has_bias:
S
SunAhong1993 已提交
1624
            layer_inputs["bias"] = val_b.name
S
fix  
SunAhong1993 已提交
1625 1626 1627 1628 1629 1630 1631 1632 1633 1634
        input_shape = val_x.out_shapes[0]
        if reduce(lambda x,y:x*y, input_shape) in [1, -1] and 1 not in input_shape:
            input_shape[1] = num_in_channels * num_groups
            input_shape[0] = 0
            input_shape[2] = 0
            self.paddle_graph.add_layer(
                "paddle.reshape", 
                inputs={"x": layer_inputs["x"]}, 
                outputs=[layer_inputs["x"]], 
                shape=input_shape)
S
SunAhong1993 已提交
1635 1636 1637 1638 1639
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs=layer_inputs, 
            outputs=[node.name], 
            **layer_attrs)
C
channingss 已提交
1640

1641
    @print_mapping_info
C
channingss 已提交
1642
    def ConvTranspose(self, node):
C
channingss 已提交
1643 1644
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
C
channingss 已提交
1645
        val_b = None
R
root 已提交
1646
        if len(node.layer.input) > 2:
C
channingss 已提交
1647
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
C
channingss 已提交
1648 1649
        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        out_padding = node.get_attr('output_padding', [0, 0])
C
channingss 已提交
1650
        kernel_shape = node.get_attr('kernel_shape')
C
channingss 已提交
1651 1652 1653
        assert kernel_shape, 'kernel_shape not inferred'
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
S
SunAhong1993 已提交
1654
        num_in_channels = val_w.out_shapes[0][0]
C
channingss 已提交
1655
        num_out_channels = val_w.out_shapes[0][1]
S
SunAhong1993 已提交
1656
        paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
C
channingss 已提交
1657

C
channingss 已提交
1658 1659 1660 1661 1662
        num_groups = node.get_attr('group', 1)
        strides = node.get_attr('strides', [1] * convnd)
        dilations = node.get_attr('dilations', [1] * convnd)
        output_size = node.get_attr('output_shape', [])
        pads = node.get_attr('pads', [0] * (convnd * 2))
C
channingss 已提交
1663 1664 1665 1666

        paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)

        output_size = [0, 0]
C
channingss 已提交
1667

1668 1669
        output_size[0] = (val_x.out_shapes[0][2] - 1
                          ) * strides[0] - 2 * paddings[0] + dilations[0] * (
C
channingss 已提交
1670
                              kernel_shape[0] - 1) + 1 + out_padding[0]
1671 1672
        output_size[1] = (val_x.out_shapes[0][3] - 1
                          ) * strides[1] - 2 * paddings[1] + dilations[1] * (
C
channingss 已提交
1673
                              kernel_shape[1] - 1) + 1 + out_padding[1]
S
SunAhong1993 已提交
1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684
        layer_inputs = {'x': val_x.name,
                       "weight": val_w.name}
        layer_attrs = {
            "stride": strides,
            "dilation": dilations,
            "padding": paddings,
            "groups": num_groups,
            "output_size": node.out_shapes[0][2:]}
        if val_b is not None:
            layer_inputs["bias"] = val_b.name
        self.paddle_graph.add_layer(
S
fix  
SunAhong1993 已提交
1685
            kernel=paddle_op,
S
SunAhong1993 已提交
1686 1687
            inputs=layer_inputs,
            outputs=[node.name],
S
fix  
SunAhong1993 已提交
1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700
            **layer_attrs)
        
    @print_mapping_info
    def ArgMax(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axis = node.get_attr('axis')
        keepdims = False if node.get_attr('keepdims') == 0 else True
        layer_attrs = {'axis': axis,
                      'keepdim': keepdims}
        self.paddle_graph.add_layer(
            'paddle.argmax', 
            inputs={"x": val_x.name}, 
            outputs=[node.name],
S
SunAhong1993 已提交
1701 1702 1703 1704 1705 1706 1707
            **layer_attrs)
        
    @print_mapping_info
    def Size(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.shape", 
S
for pad  
SunAhong1993 已提交
1708
            inputs={"input": val_x.name}, 
S
SunAhong1993 已提交
1709
            outputs=[node.name])
S
for pad  
SunAhong1993 已提交
1710 1711 1712 1713 1714
        self.paddle_graph.add_layer(
            'paddle.cast',
            inputs={"x": node.name},
            outputs=[node.name],
            dtype=string('int64'))  
S
SunAhong1993 已提交
1715 1716 1717 1718 1719 1720 1721 1722
        self.paddle_graph.add_layer(
            "paddle.prod",
            inputs={"x": node.name},
            outputs=[node.name])

    @print_mapping_info
    def Sign(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
S
for pad  
SunAhong1993 已提交
1723 1724 1725 1726 1727 1728
        if node.dtype not in ["float16", "float32", "float64"]:
            self.paddle_graph.add_layer(
                "paddle.cast", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name],
                dtype=string("float32"))
S
SunAhong1993 已提交
1729 1730 1731 1732
        self.paddle_graph.add_layer(
            "paddle.sign", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])
S
for pad  
SunAhong1993 已提交
1733 1734 1735 1736 1737 1738
        if node.dtype not in ["float16", "float32", "float64"]:
            self.paddle_graph.add_layer(
                "paddle.cast", 
                inputs={"x": node.name}, 
                outputs=[node.name],
                dtype=string(node.dtype))
S
SunAhong1993 已提交
1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760

    @print_mapping_info
    def OneHot(self, node):
        indices = self.graph.get_input_node(node, idx=0, copy=True)
        depth = self.graph.get_input_node(node, idx=1, copy=True)
        values = self.graph.get_input_node(node, idx=2, copy=True)
        axis = node.get_attr('axis', -1)
        self.paddle_graph.add_layer(
            "custom_layer:one_hot", 
            inputs={"indices": indices.name,
                    "depth": depth.name,
                    "values": values.name}, 
            outputs=[node.name],
            axis=axis)

    @print_mapping_info
    def Reciprocal(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.reciprocal", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])