# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
from x2paddle.core.graph import GraphNode
from x2paddle.core.util import *
from functools import reduce
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
import logging as _logging
from collections import OrderedDict
import math
import os
import copy
import sys
import shutil

_logger = _logging.getLogger(__name__)


def _const_weight_or_none(node, necessary=False):
    if 'Constant' in node.layer_type:
        return node.value
    if isinstance(node, ONNXGraphDataNode):
        return node.weight
    if necessary:
        raise AssertionError(
            '{} should be an initializer or Constant operator.'.format(
                node.name))
    return None


def _rename_or_remove_weight(weights, origin_name, target_name=None, is_remove=True):
    '''
    Rename a parameter following Paddle's parameter naming rules.

    Args:
        weights(dict[String:np.ndarray]): Dict storing parameters; each key is a parameter name.
        origin_name(String): Name of the parameter to rename or remove.
        target_name(String, optional): If target_name is not None, add a new key-value pair
            {target_name: weights[origin_name]} to weights; target_name must follow Paddle's
            parameter naming rules. Default: None.
        is_remove(bool, optional): If True, remove the original key-value pair. Default: True.
    Returns:
        None
    '''
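    # Example (hypothetical names):
    #   _rename_or_remove_weight(weights, 'conv1_W', 'conv0.weight')
    # pops the array stored under 'conv1_W' and re-adds it as 'conv0.weight'.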
    if origin_name not in weights:
        raise KeyError('{} is not a key in weights'.format(origin_name))
    if is_remove:
        # remove weight
        data = weights.pop(origin_name)
    else:
        data = weights[origin_name]
    if target_name is not None:
        # rename weight
        weights[target_name] = data

def _is_static_shape(shape):
    negative_dims = 0
    error_dims = 0
    for dim in shape:
        if dim < 0:
            negative_dims += 1
        if dim < -1:
            error_dims += 1
    if negative_dims > 1:
        return False
    if error_dims > 0:
        return False
    return True
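# e.g. _is_static_shape([-1, 3, 224, 224]) -> True (a single unknown dim is
# tolerated); _is_static_shape([-1, -1, 224]) -> False, as is any dim < -1.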


def _get_same_padding(in_size, kernel_size, stride):
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]
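# e.g. _get_same_padding(in_size=7, kernel_size=3, stride=2):
# new_size = ceil(7 / 2) = 4, pad_size = (4 - 1) * 2 + 3 - 7 = 2 -> [1, 1].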


def print_mapping_info(func):
    def run_mapping(*args, **kwargs):
        node = args[1]
        try:
            res = func(*args, **kwargs)
        except Exception:
            print("convert failed for node: {}, op_type is {}".format(
                node.name[9:], node.layer_type))
            raise
        else:
            return res

    return run_mapping


class OpSet9():
    elementwise_ops = {
        'Add': 'paddle.add',
        'Div': 'paddle.divide',
        'Sub': 'paddle.subtract',
        'Mul': 'paddle.multiply',
        'Pow': 'paddle.pow',
    }

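    # Each directly_map_ops entry reads
    # [paddle_op, {onnx_attr: paddle_attr}, {onnx_attr: default_value}];
    # the two attribute dicts are optional (see directly_map below).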
    directly_map_ops = {
        'Ceil': ['paddle.ceil'],
        # reduce functions
        'ReduceMean': ['paddle.mean',
                       dict(axes='axis', keepdims='keepdim'), 
                       dict(keepdims=1)],
        'ReduceSum': ['paddle.sum', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdims=1)],
        'ReduceMin': ['paddle.min', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdims=1)],
        'ReduceMax': ['paddle.max', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdims=1)],
        'ReduceProd': ['paddle.prod', 
                      dict(axes='axis', keepdims='keepdim'), 
                      dict(keepdims=1)],
        # activation functions
        'Relu': ['paddle.nn.ReLU'],
        'LeakyRelu': ['paddle.nn.LeakyReLU', 
                      dict(alpha='negative_slope'), 
                      dict(negative_slope=.01)],
        'Elu': ['paddle.nn.functional.elu', 
                dict(alpha='alpha'), 
                dict(alpha=1.)],
        'ThresholdedRelu': ['paddle.nn.functional.thresholded_relu', 
                            dict(alpha='threshold'),
                            dict(alpha=1.)],
        'Tanh': ['paddle.nn.Tanh'],
        'Sigmoid': ['paddle.nn.Sigmoid'],
        'Softsign': ['paddle.nn.Softsign'],
        'Softplus': ['paddle.nn.Softplus', 
                     dict(threshold='threshold'), 
                     dict(threshold=float(sys.maxsize))],
        'Exp': ['paddle.exp'],
        'LogSoftmax': ['paddle.nn.functional.log_softmax', 
                    dict(axis='axis'), 
                    dict(axis=1)],
        'Softmax': ['paddle.nn.Softmax', 
                    dict(axis='axis'), 
                    dict(axis=1)],
        'Sqrt': ['paddle.sqrt'],
        'Floor': ['paddle.floor'],
        'Abs': ['paddle.abs'],
        'Erf': ['paddle.erf'],
    }

    def __init__(self, decoder, paddle_graph):
        super(OpSet9, self).__init__()
        self.graph = decoder.graph
        self.paddle_graph = paddle_graph
        self.input_index = 0
        self.inputs_info = dict()
        self.weights = dict()
        self.nn_name2id = dict()
        self.done_weight_list = list()

    @print_mapping_info
    def directly_map(self, node, *args, **kwargs):
        inputs = node.layer.input
        assert len(inputs) == 1, 'directly_map error with multi inputs'
        input = self.graph.get_input_node(node, idx=0, copy=True)
        onnx_attrs = node.attr_map
        if '' in onnx_attrs:
            onnx_attrs.pop('')
        if '_' in onnx_attrs:
            onnx_attrs.pop('_')
        op_info = self.directly_map_ops[node.layer_type]
        paddle_op = op_info[0]
        layer_attrs = dict()
        if len(op_info) > 1:
            attrs_name_map_dict = op_info[1]
            for onnx_attr_name, pd_attr_name in attrs_name_map_dict.items():
                if onnx_attr_name in onnx_attrs:
                    layer_attrs[pd_attr_name] = onnx_attrs[onnx_attr_name]
                else:
                    layer_attrs[pd_attr_name] = op_info[2][onnx_attr_name]
        if paddle_op.startswith("paddle.nn") and 'functional' not in paddle_op:
            op_name = paddle_op[10:].lower()
            op_name = name_generator(op_name, self.nn_name2id)
            output_name = node.name
            layer_outputs = [op_name, output_name]

            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=layer_outputs,
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=[node.name],
                **layer_attrs)

    @print_mapping_info
    def elementwise_map(self, node):
        op_type = self.elementwise_ops[node.layer_type]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {'x': val_x.name, 
                       'y': val_y.name}
        self.paddle_graph.add_layer(
            op_type, 
            inputs=inputs_dict, 
            outputs=[node.name])

    @print_mapping_info
    def place_holder(self, node):
        shape = node.out_shapes[0]
        for i, dim_shape in enumerate(shape):
            if dim_shape == 0 and i == 0:
                shape[i] = 1
            if dim_shape == 0 and i != 0:
                raise Exception('shape of input is not assigned')
        self.paddle_graph.add_layer(
            kernel="paddle.to_tensor",
            inputs={},
            outputs=[node.name],
            data="x{}".format(self.input_index))
        self.inputs_info["x{}".format(self.input_index)] = [shape, node.dtype]
        self.input_index += 1

    @print_mapping_info
    def create_parameter(self, node, parameter=None):
        if parameter is not None:
            node = parameter
        dtype = node.dtype
        shape = node.out_shapes[0]
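        # 0-d (scalar) weights are materialized with paddle.full instead of
        # being registered as parameters; everything else goes through
        # self.create_parameter in the generated code.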
        if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=node.weight)
        else:
            self.weights[node.name] = node.weight
            self.paddle_graph.add_layer(
                "self.create_parameter",
                inputs={},
                outputs=[node.name],
                shape=shape,
                attr=string(node.name),
                dtype=string(dtype),
                default_initializer="paddle.nn.initializer.Constant(value=0.0)")       

    def _pad_if_asymmetric(self, node, pads, val_name):  # pads: SSEE
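        # ONNX pads are laid out as [x1_begin, x2_begin, ..., x1_end, x2_end]
        # ("SSEE"). e.g. pads=[1, 2, 1, 2] is symmetric and yields
        # ([1, 2], val_name); asymmetric pads are folded into an explicit Pad
        # op and zero padding is returned instead.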
        assert len(pads) & 1 == 0
        symmetric = True
        ndims = len(pads) // 2
        for idx_dim in range(ndims):
            if pads[idx_dim] != pads[ndims + idx_dim]:
                symmetric = False
                break
        if symmetric:
            return pads[:ndims], val_name
        val_padded = self.Pad(node, op_independent=False)
        return [0] * ndims, val_padded

    def _interpolate(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        inputs = {'x': val_x.name}
        attrs = dict()
        if node.layer_type == 'Resize':
            if len(node.layer.input) == 2:
                # opset 10
                val_scales = self.graph.get_input_node(node, idx=1, copy=True)
                # TODO(syf): paddle.nn.functional.interpolate will support the length  
                # which is the same as the rank of input.
                attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
            elif len(node.layer.input) == 3:
                # opset 11
                val_scales = self.graph.get_input_node(node, idx=2, copy=True)
                # TODO(syf): paddle.nn.functional.interpolate will support the length  
                # which is the same as the rank of input.
                attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
            elif len(node.layer.input) == 4:
                # opset 11
                val_sizes = self.graph.get_input_node(node, idx=3, copy=True)
                var_nc, var_hw = val_sizes.name + '_nc', val_sizes.name + '_hw'
                self.paddle_graph.add_layer(
                    'paddle.split',
                    inputs={"x": val_sizes.name},
                    outputs=[var_nc, var_hw],
                    num_or_sections=[2, 2],
                    axis=0)
                self.paddle_graph.add_layer(
                    "paddle.cast",
                    inputs={"x": var_hw},
                    outputs=[var_hw],
                    dtype=string('int32'))
                inputs['size'] = var_hw
                attrs = {"align_corners": False,
                         "mode": string(node.get_attr('mode', 'nearest'))}
                self.paddle_graph.add_layer(
                    kernel="paddle.nn.functional.interpolate",
                    inputs=inputs,
                    outputs=[node.name],
                    **attrs)
                return
        elif node.layer_type == 'Upsample':
            val_scales = self.graph.get_input_node(node, idx=1, copy=True)
            inputs['scale_factor'] = val_scales.name

        mode = node.get_attr('mode', 'nearest')
        attrs.update({"align_corners": False,
                      "mode": string(mode),
                      "align_mode": 1})
        val_x_shape = val_x.out_shapes[0]
        if mode == "linear" and len(val_x_shape) == 4:
            attrs["mode"] = string("bilinear")
            attrs["align_corners"] = True
        self.paddle_graph.add_layer(
            kernel="paddle.nn.functional.interpolate",
            inputs=inputs,
            outputs=[node.name],
            **attrs)
        
    @print_mapping_info
    def HardSigmoid(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        alpha = node.get_attr('alpha', 0.2)
        beta = node.get_attr('beta', 0.5)
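        # HardSigmoid(x) = clip(alpha * x + beta, 0, 1), composed below from
        # paddle.scale followed by paddle.clip.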
        self.paddle_graph.add_layer(
            kernel="paddle.scale",
            inputs={"x": val_x.name},
            outputs=[node.name + "_val"],
            scale=alpha,
            bias=beta)
        self.paddle_graph.add_layer(
            kernel="paddle.clip",
            inputs={"x": node.name + "_val"},
            outputs=[node.name],
            min=0.0,
            max=1.0)  
        
    @print_mapping_info
    def Shape(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            kernel="paddle.shape",
            inputs={"input": val_x.name},
            outputs=[node.name])
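        # ONNX Shape produces int64 values, hence the extra cast below.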
        self.paddle_graph.add_layer(
                'paddle.cast',
                inputs={"x": node.name},
                outputs=[node.name],
                dtype=string('int64'))   

    @print_mapping_info
    def RoiAlign(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)

        pooled_height = node.get_attr('output_height')
        pooled_width = node.get_attr('output_width')
        spatial_scale = node.get_attr('spatial_scale')
        sampling_ratio = node.get_attr('sampling_ratio')
        layer_attrs = {
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
            'sampling_ratio': sampling_ratio,
        }
        self.paddle_graph.add_layer(
            'paddle.fluid.layers.roi_align',
            inputs={'input': val_x.name,
                    'rois': val_rois.name},
            outputs=[node.name],
            **layer_attrs)
                       

    @print_mapping_info
    def MaxRoiPool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)

        spatial_scale = node.get_attr('spatial_scale')
        pooled_height, pooled_width = node.get_attr('pooled_shape')
        layer_attrs = {
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
        }
        self.paddle_graph.add_layer(
            'paddle.fluid.layers.roi_pool',
            inputs={'input': val_x.name,
                    'rois': val_rois.name},
            outputs=[node.name],
            **layer_attrs)

    @print_mapping_info
    def Pad(self, node, op_independent=True):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        pads = node.get_attr('pads')
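        # Pad opset <= 10 stores pads as an attribute; opset >= 11 passes them
        # as a second input, which the branch below falls back to.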
        is_pads_attr = True
        if pads is None:
            val_pad = self.graph.get_input_node(node, idx=1, copy=True)
            pad_shape = val_pad.out_shapes[0]
            is_pads_attr = False
            pads = _const_weight_or_none(val_pad)
            if pads is not None:
                is_pads_attr = True
        mode = node.get_attr('mode', 'constant')
        value = node.get_attr('value', 0.)
        data_shape = val_x.out_shapes[0]
        output_shape = node.out_shapes[0]
        assume_pad = False
        layer_attrs = {}
        layer_attrs['mode'] = string(mode)
        layer_attrs['value'] = value
        if not op_independent:
            output_name = node.name + '_padded'
        else:
            output_name = node.name
        nn_op_name = name_generator("pad", self.nn_name2id)
        layer_outputs = [nn_op_name, output_name]
        if is_pads_attr:
            paddings = []
            if len(pads) in [2, 4, 6]:
                if data_shape:
                    assume_pad |= data_shape and 2 * (len(data_shape) - 2) == len(pads) # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * (len(output_shape) - 2) == len(pads)  # NCHW
                if assume_pad:
                    paddle_op = 'paddle.nn.Pad{}D'.format(len(output_shape) - 2)
                    paddings = np.array(pads).reshape(
                        (2, -1)).transpose().astype("int32")
                    paddings = np.flip(paddings, axis=0).flatten().tolist()
                    layer_attrs['padding'] = paddings
                else:
                    if data_shape:
                        assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
                    if output_shape:
                        assume_pad |= output_shape and 2 * len(output_shape) == len(pads)  # NCHW
                    if assume_pad:
                        paddle_op = 'paddle.nn.functional.pad'
                        paddings = np.array(pads).reshape(
                            (2, -1)).transpose().astype("int32").flatten().tolist()
                        layer_attrs['pad'] = paddings
                    else:
                        raise Exception("The padding value {} is wrong!".format(pads))
            elif len(pads) == 8:
                if data_shape:
                    assume_pad |= data_shape and 2 * len(data_shape) == len(pads) # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * len(output_shape) == len(pads)  # NCHW
                if assume_pad:
                    paddle_op = 'paddle.nn.Pad2D'
                    paddings = np.array(pads).reshape(
                        (2, -1)).transpose().astype("int32")
                    paddings = np.flip(paddings, axis=0).flatten().tolist()
                    if sum(paddings[:4]) == 0:
                        paddings = paddings[4:]
                        layer_attrs['padding'] = paddings
                    else:
                        layer_attrs["pad"] = paddings
                        paddle_op = "custom_layer:PadAllDim4WithOneInput"
            else:
                raise Exception("The padding value {} is wrong!".format(pads))
            self.paddle_graph.add_layer(
                paddle_op, 
                inputs={'x': val_x.name}, 
                outputs=layer_outputs[1:] if paddle_op == 'paddle.nn.functional.pad' else layer_outputs, 
                **layer_attrs)
            if not op_independent:
                return node.name + '_padded'
        else:
            pads_len = val_pad.out_shapes[0][0]
            if pads_len in [2, 4, 6]:
                if data_shape:
                    assume_pad |= data_shape and 2 * (len(data_shape) - 2) == pads_len # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * (len(output_shape) - 2) == pads_len  # NCHW 
                if assume_pad:
                    if pads_len == 2:
                        data_format = "NCL"
                    elif pads_len == 4:
                        data_format = "NCHW"
                    else:
                        data_format = "NCDHW"
                    self.paddle_graph.add_layer(
                        "custom_layer:PadWithTwoInput", 
                        inputs={'x': val_x.name, 'pad': val_pad.name}, 
                        outputs=layer_outputs,
                        value=value,
                        mode=string(mode),
                        data_format=string(data_format))
                else:
                    if data_shape:
                        assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
                    if output_shape:
                        assume_pad |= output_shape and 2 * len(output_shape) == pads_len  # NCHW
                    if assume_pad:
                        if pads_len == 4:
                            self.paddle_graph.add_layer(
                                "custom_layer:PadAllDim2", 
                                inputs={'x': val_x.name, 'pad': val_pad.name}, 
                                outputs=layer_outputs, 
                                value=value,
                                mode=string(mode))
                        else:
                            raise Exception("The padding value is wrong!")
            elif pads_len == 8:
                if data_shape:
                    assume_pad |= data_shape and 2 * len(data_shape) == pads_len # NCHW
                if output_shape:
                    assume_pad |= output_shape and 2 * len(output_shape) == pads_len  # NCHW
                if assume_pad:
                    self.paddle_graph.add_layer(
                        "custom_layer:PadAllDim4", 
                        inputs={'x': val_x.name, 'pad': val_pad.name}, 
                        outputs=layer_outputs, 
                        value=value,
                        mode=string(mode))
            else:
                raise Exception(
                    "The padding length {} is wrong!".format(pads_len))
            if not op_independent:
                return node.name + '_padded'

    @print_mapping_info
    def Unsqueeze(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axes = node.get_attr('axes')
        layer_attrs = {'axis': axes}
        if len(val_x.out_shapes[0]) == 0:
            if node.name:
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": val_x.name},
                    outputs=[node.name],
                    shape=[1])
        else:
            self.paddle_graph.add_layer(
                'paddle.unsqueeze', 
                inputs={"x": val_x.name}, 
                outputs=[node.name],
                **layer_attrs)

    @print_mapping_info
    def Shrink(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        bias = node.get_attr('bias')
        lambd = node.get_attr('lambd')
        assert bias == 0.0, 'bias != 0 is not supported'
        self.paddle_graph.add_layer(
            'paddle.nn.functional.hardshrink', 
            inputs={"x": val_x.name}, 
            outputs=[node.name], 
            threshold=lambd)

    @print_mapping_info
    def Constant(self, node):
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        value = node.get_attr('value')
        dtype = np.dtype(value.dtype)
        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, 'tensor dtype does not match storage dtype'

        shape = node.get_attr('shape', None)

        if shape is None:
            shape = val_output.out_shapes[0]
        if shape is None:
            shape = list(value.shape)
            _logger.warning('in (Constant -> %s): '
                            'attribute "shape" of %s not inferred, '
                            'using value as a 1-D tensor may lead to failure',
                            val_output.name, val_output.name)
        if len(value) == 1:
            value = value.tolist()
            value = value[0]
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                dtype=string(dtype),
                shape=[1],
                fill_value=value)
        else:
            value = np.reshape(value, shape)
            self.weights[node.name] = value
            self.paddle_graph.add_layer(
                "self.create_parameter",
                inputs={},
                outputs=[node.name],
                shape=shape,
                attr=string(node.name),
                dtype=string(dtype),
                default_initializer="paddle.nn.initializer.Constant(value=0.0)")

    @print_mapping_info
    def Resize(self, node):
        self._interpolate(node)

    @print_mapping_info
    def Upsample(self, node):
        self._interpolate(node)

    @print_mapping_info
    def InstanceNormalization(self, node):
        op_name = name_generator("instance_norm", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = self.graph.get_input_node(node, idx=2, copy=True)
        epsilon = node.get_attr('epsilon', 1e-5)
        self.weights[op_name+'.scale'] = self.weights[val_scale.name]
        self.weights[op_name+'.bias'] = self.weights[val_b.name]
        layer_attrs = {
            'num_features': node.out_shapes[0][1],
            'epsilon': epsilon,
        }
        dim = len(val_x.out_shapes[0])
        if dim == 3:
            paddle_op = "paddle.nn.InstanceNorm1D"
        elif dim == 4:
            paddle_op = "paddle.nn.InstanceNorm2D"
        elif dim == 5:
            paddle_op = "paddle.nn.InstanceNorm3D"
        else:
            raise Exception(
                "Paddle only supports 3D, 4D or 5D input for InstanceNormalization.")
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={"x": val_x.name}, 
            outputs=layer_outputs, 
            **layer_attrs)

    @print_mapping_info
    def Expand(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
        val_x_dtype = val_x.dtype
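        # Expand is emulated via broadcasting: multiply x by a ones tensor of
        # the requested shape.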
        name_ones = node.name + '_ones'
        attr_ones = {
            'shape': val_shape.name,
            'dtype': string(val_x_dtype),
            'fill_value': 1
        }
        self.paddle_graph.add_layer(
            'paddle.full',
            inputs={},
            outputs=[name_ones],
            **attr_ones)
        inputs_dict = {'x': name_ones, 
                       'y': val_x.name}
        self.paddle_graph.add_layer(
            'paddle.multiply',
            inputs=inputs_dict,
            outputs=[node.name])

    @print_mapping_info
    def Gather(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        indices_shape = indices.out_shapes[0]
        axis = node.get_attr('axis', 0)
        #assert len(
        #    indices_shape) <= 2, "Gather op don't support dim of indice >2 "
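        # The gather below is only applied along axis 0 with 1-D indices;
        # other axes and index ranks are emulated with transpose/reshape, and
        # gathering rows of an initializer becomes paddle.nn.Embedding.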
        if axis == 0 and len(indices_shape) <= 1:
            if len(val_x.out_shapes[0]) <= 1:
                self.paddle_graph.add_layer(
                    'paddle.gather',
                    inputs={'x': val_x.name,
                            'index': indices.name},
                    outputs=[node.name])
            elif len(val_x.out_shapes[0]) > 1:
                if len(indices_shape) == 0:
                    gather_ = node.name + '_1'
                    self.paddle_graph.add_layer(
                        'paddle.gather',
                        inputs={'x': val_x.name,
                                'index': indices.name},
                        outputs=[gather_])
                    self.paddle_graph.add_layer(
                        'paddle.squeeze',
                        inputs={'x': gather_},
                        outputs=[node.name],
                        axis=[0])
                else:
                    self.paddle_graph.add_layer(
                        'paddle.gather',
                        inputs={'x': val_x.name,
                                'index': indices.name},
                        outputs=[node.name])
        elif axis > 0 and len(indices_shape) <= 1:
            perm = list(range(len(val_x.out_shapes[0])))
            perm = [axis] + perm[:axis] + perm[axis + 1:]
            name_trans = val_x.name + '_trans'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": val_x.name},
                outputs=[name_trans],
                perm=perm)
            self.paddle_graph.add_layer(
                'paddle.gather',
                inputs={'x': name_trans,
                        'index': indices.name},
                outputs=[node.name])
            # transpose back using the inverse permutation of perm
            inverse_perm = [0] * len(perm)
            for i, p in enumerate(perm):
                inverse_perm[p] = i
            self.paddle_graph.add_layer(
                'paddle.transpose', 
                inputs={"x": node.name}, 
                outputs=[node.name], 
                perm=inverse_perm)
            if len(indices_shape) < 1:
                self.paddle_graph.add_layer(
                    'paddle.squeeze',
                    inputs={'x': node.name},
                    outputs=[node.name],
                    axis=[axis])
        elif axis == 0 and len(indices_shape) > 1:
            if val_x.out_shapes[0] is not None and isinstance(
                    val_x, ONNXGraphDataNode):
                indices_cast = indices.name + '_cast'
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={"x": indices.name},
                    outputs=[indices_cast],
                    dtype=string('int64'))
                op_name = name_generator("embedding", self.nn_name2id)
                output_name = node.name
                layer_outputs = [op_name, output_name]
                self.weights[op_name + '.weight'] = _const_weight_or_none(val_x)
                self.paddle_graph.add_layer(
                    'paddle.nn.Embedding',
                    inputs={"x": indices_cast},
                    outputs=layer_outputs,
                    num_embeddings=val_x.out_shapes[0][0],
                    embedding_dim=val_x.out_shapes[0][1])
            else:
                from functools import reduce
                reshape_shape = reduce(lambda x, y: x * y, indices_shape)
                indices_reshape = indices.name + '_shape'
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": indices.name},
                    outputs=[indices_reshape],
                    shape=[reshape_shape, ])

                perm = list(range(len(val_x.out_shapes[0])))
                self.paddle_graph.add_layer(
                    'paddle.gather',
                    inputs={'x': val_x.name,
                            'index': indices_reshape},
                    outputs=[node.name])
                val_x_shape = val_x.out_shapes[0]
                reshaped_shape = []
                for i in perm:
                    reshaped_shape.append(indices_shape[i])
                for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
                    reshaped_shape.append(i)
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={"x": node.name},
                    outputs=[node.name],
                    shape=reshaped_shape)
        elif axis > 0 and len(indices_shape) > 1:
            from functools import reduce
            reshape_shape = reduce(lambda x, y: x * y, indices_shape)
            indices_reshape = indices.name + '_shape'
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": indices.name},
                outputs=[indices_reshape],
                shape=[reshape_shape, ])

            perm = list(range(len(val_x.out_shapes[0])))
            perm = [axis] + perm[:axis] + perm[axis + 1:]
            name_trans = val_x.name + '_transpose'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": val_x.name},
                outputs=[name_trans],
                perm=perm)
            self.paddle_graph.add_layer(
                'paddle.gather',
                inputs={'x': name_trans,
                        'index': indices_reshape},
                outputs=[node.name])
            input_transpose = node.name + '_transpose'
            self.paddle_graph.add_layer(
                'paddle.transpose',
                inputs={"x": node.name},
                outputs=[input_transpose],
                perm=perm)
            val_x_shape = val_x.out_shapes[0]
            reshaped_shape = []
            for i in perm:
                reshaped_shape.append(indices_shape[i])
            for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
                reshaped_shape.append(i)
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": input_transpose},
                outputs=[node.name],
                shape=reshaped_shape)

    @print_mapping_info
    def ScatterND(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        updates = self.graph.get_input_node(node, idx=2, copy=True)
        if len(indices.out_shapes[0]) == 1:
            self.paddle_graph.add_layer(
                'paddle.scatter',
                inputs={'x': val_x.name,
                        'index': indices.name,
                        'updates': updates.name},
                outputs=[node.name])
        else:
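            # Emulate ScatterND with two scatter_nd_add calls:
            #   scattered = scatter_nd_add(zeros_like(x), index, updates)
            #   mask = scatter_nd_add(zeros_like(x), index, -1) + 1
            # mask is 0 at the updated positions and 1 elsewhere, so the
            # result is x * mask + scattered.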
            input_inner_indices = node.name + '_input_inner_indices'
            shape = val_x.out_shapes[0]
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": indices.name},
                outputs=[indices.name],
                shape=indices.out_shapes[0])

            zeros_like_val_x = val_x.name + '_zeros'
            self.paddle_graph.add_layer(
                'paddle.zeros_like',
                inputs={"x": val_x.name},
                outputs=[zeros_like_val_x])
            self.paddle_graph.add_layer(
                'paddle.scatter_nd_add',
                inputs={
                    'x': zeros_like_val_x,
                    'index': indices.name,
                    'updates': updates.name
                },
                outputs=[input_inner_indices])
            indices_mask = node.name + '_indices_mask'
            constant_minus_one = node.name + '_constant_minus_one'
            # full_like support create tensor shape like input tensor
            self.paddle_graph.add_layer(
                'paddle.full_like',
                inputs={"x": updates.name},
                outputs=[constant_minus_one],
                dtype=string(updates.dtype),
                fill_value=-1)
            self.paddle_graph.add_layer(
                'paddle.scatter_nd_add',
                inputs={
                    'x': zeros_like_val_x,
                    'index': indices.name,
                    'updates': constant_minus_one
                },
                outputs=[indices_mask])
            constant_one = node.name + '_constant_1'
            # full_like support create tensor shape like input tensor
            self.paddle_graph.add_layer(
                'paddle.full_like',
                inputs={"x": val_x.name},
                outputs=[constant_one],
                dtype=string(val_x.dtype),
                fill_value=1)
            input_out_indices_mask = node.name + '_input_out_indices_mask'
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs={"x": indices_mask,
                        "y": constant_one},
                outputs=[input_out_indices_mask])

            input_out_indices = node.name + '_input_out_indices'
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs={"x": val_x.name,
                        "y": input_out_indices_mask},
                outputs=[input_out_indices])

            self.paddle_graph.add_layer(
                "paddle.add",
                inputs={"x": input_inner_indices,
                        "y": input_out_indices},
                outputs=[node.name])

    @print_mapping_info
    def Range(self, node):
        val_start = self.graph.get_input_node(node, idx=0, copy=True)
        val_limit = self.graph.get_input_node(node, idx=1, copy=True)
        val_delta = self.graph.get_input_node(node, idx=2, copy=True)
        dtype = val_start.dtype
        inputs = {'start': val_start.name, 
                  'end': val_limit.name, 
                  'step': val_delta.name}
        self.paddle_graph.add_layer(
            'paddle.arange',
            inputs=inputs,
            outputs=[node.name],
            dtype=string(dtype))

    @print_mapping_info
    def Slice(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        starts, ends, axes, steps = None, None, None, None
        layer_attrs = {}
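        # Slice opset >= 10 passes starts/ends/axes/steps as inputs;
        # opset <= 9 stores them as attributes (the else branch below).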
        if len(node.inputs) > 1:
            starts = self.graph.get_input_node(node, idx=1, copy=True)
            ends = self.graph.get_input_node(node, idx=2, copy=True)
            starts_value = _const_weight_or_none(starts)
            if starts_value is not None:
                starts_value = starts_value.tolist()
            ends_value = _const_weight_or_none(ends)
            if ends_value is not None:
                ends_value = ends_value.tolist()
            if len(node.inputs) > 2:
                s_len = len(val_x.out_shapes[0])
                axes = list(range(s_len))
            if len(node.inputs) > 3:
                axes_node = self.graph.get_input_node(node, idx=3, copy=True)
                axes = _const_weight_or_none(axes_node, necessary=True).tolist()
            if len(node.inputs) > 4:
                steps = self.graph.get_input_node(node, idx=4, copy=True)
                steps = _const_weight_or_none(steps).tolist()
            
            layer_attrs = {
                "axes": axes,
                "starts": starts.name,
                "ends": ends.name
            }
            if starts_value is not None and ends_value is not None and axes is not None:
                starts_value = starts_value.copy()
                ends_value = ends_value.copy()
                for idx in range(len(ends_value)):
                    if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
                        starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
                        ends_value[idx] = val_x.out_shapes[0][axes[idx]]
                    elif ends_value[idx] > 2**31 - 1:
                        ends_value[idx] = 2**31 - 1
                layer_attrs = {
                    "axes": axes,
                    "starts": starts_value,
                    "ends": ends_value
                }
            else:
                if starts.dtype != 'int32':
                    starts_cast = starts.name + '_cast'
                    self.paddle_graph.add_layer(
                        'paddle.cast',
                        inputs={"x": starts.name},
                        outputs=[starts_cast],
                        dtype=string('int32'))
                    layer_attrs['starts'] = starts_cast
                if ends.dtype != 'int32':
                    ends_cast = ends.name + '_cast'
                else:
                    ends_cast = ends.name
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={"x": ends.name},
                    outputs=[ends_cast],
                    dtype=string('int32'))
                layer_attrs['ends'] = ends_cast
        else:
            starts = node.get_attr('starts')
            ends = node.get_attr('ends')
            axes = node.get_attr('axes')
            for idx in range(len(ends)):
                if ends[idx] > 2**31 - 1:
                    ends[idx] = 2**31 - 1
            layer_attrs = {"axes": axes, "starts": starts, "ends": ends}

        if steps is not None:
            layer_attrs['strides'] = steps
            self.paddle_graph.add_layer(
                'paddle.strided_slice', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                'paddle.slice', 
                inputs={"input": val_x.name}, 
                outputs=[node.name],  
                **layer_attrs)

    @print_mapping_info
    def ConstantOfShape(self, node):
        val_shape = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)

        value = node.get_attr('value')
        dtype = value.dtype
        value = value.tolist()
        assert len(value) == 1, ('given value not Scalar, shape of value > 1, '
                                 'this is not supported')
        if len(value) == 1:
            value = value[0]
            layer_attrs = {
                'shape': val_shape.name,
                'dtype': string(dtype),
                'fill_value': value
            }
            self.paddle_graph.add_layer(
                "paddle.full", 
                inputs={}, 
                outputs=[node.name],
                **layer_attrs)

    @print_mapping_info
    def Clip(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        max_value, min_value = None, None
        if len(node.inputs) == 1:
            max_value = node.get_attr('max')
            min_value = node.get_attr('min')
            layer_attrs = {
                'max': max_value,
                'min': min_value,
            }
            
            self.paddle_graph.add_layer(
                'paddle.clip', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
        else:
            min_ipt = self.graph.get_input_node(node, idx=1, copy=True)
            max_ipt = self.graph.get_input_node(node, idx=2, copy=True)
            min_value = _const_weight_or_none(min_ipt)
            max_value = _const_weight_or_none(max_ipt)
            if max_value is not None and max_value.shape == (1, ):
                max_value = max_value[0]
            if min_value is not None and min_value.shape == (1, ):
                min_value = min_value[0]
        if max_value is not None and min_value is not None:
            layer_attrs = {'max': max_value, 'min': min_value}
            self.paddle_graph.add_layer(
                'paddle.clip', 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                **layer_attrs)
        else:
            raise Exception("Clip only supports constant min and max values.")

    @print_mapping_info
    def Split(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        paddle_op = 'split'
        split = node.get_attr('split')
        axis = node.get_attr('axis', 0)
        layer_attrs = {
            'num_or_sections': split,
            'axis': axis,
        }
        outputs_list = list()
        if isinstance(split, list) or isinstance(split, tuple):
            for i in range(len(split)):
                outputs_list.append("{}_p{}".format(node.layer_name, i))
        else:
            outputs_list.append(node.name)
        self.paddle_graph.add_layer(
            'paddle.split', 
            inputs={"x": val_x.name}, 
            outputs=outputs_list, 
            **layer_attrs)

    @print_mapping_info
    def Reshape(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
        val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
        shape_value = _const_weight_or_none(val_shape)
        shape_dims = len(val_shape.out_shapes[0])

        if shape_value is not None:
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name},
                outputs=[node.name],
                shape=shape_value.tolist())
        elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
                0]):
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name},
                outputs=[node.name],
                shape=node.out_shapes[0])
        else:
            # shape may be [], come form Gather by scalar indices
            if len(val_shape.out_shapes[0]) > 0:
                self.paddle_graph.add_layer(
                    'paddle.reshape',
                    inputs={'x': val_shape.name},
                    outputs=[val_shape.name],
                    shape=val_shape.out_shapes[0])
            if val_shape.dtype != "int32":
                self.paddle_graph.add_layer(
                    'paddle.cast',
                    inputs={'x': val_shape.name},
                    outputs=[val_shape.name],
                    dtype=string("int32"))
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={'x': val_x.name,
                        'shape': val_shape.name},
                outputs=[node.name])

    @print_mapping_info
    def Cast(self, node):
        val_input = self.graph.get_input_node(node, idx=0, copy=True)
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        dtype = node.get_attr('to')
        if not isinstance(dtype, np.dtype):
            dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]

        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, "dtype of 'to' does not match output dtype"
        self.paddle_graph.add_layer(
            'paddle.cast', 
            inputs={'x': val_input.name}, 
            outputs=[node.name], 
            dtype=string(dtype))

    @print_mapping_info
    def Not(self, node):
        val_input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer('paddle.logical_not', 
                                    inputs={'x': val_input.name}, 
                                    outputs=[node.name])

    @print_mapping_info
    def AveragePool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)

        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))
        pads = node.get_attr('pads', [0] * (poolnd * 2))

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

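        # For auto_pad SAME_UPPER/SAME_LOWER, recompute symmetric paddings
        # from the input's spatial size; this branch assumes 2-D (NCHW) input.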
        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
            input_shape = val_x.out_shapes[0]
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w

        op_name = name_generator("pool", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        paddle_op = 'paddle.nn.AvgPool{}D'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
        layer_attrs = {
            "kernel_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
            "ceil_mode": ceil_mode,
            "exclusive": 'True',
        }
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x.name}, 
            outputs=layer_outputs, 
            **layer_attrs)

    @print_mapping_info
    def Concat(self, node):
        inputs_list = []
        dtypes = set()
        for i in range(len(node.layer.input)):
            ipt = self.graph.get_input_node(node, idx=i, copy=True)
            inputs_list.append(ipt.name)
            dtypes.add(ipt.dtype)
        if len(dtypes) > 1:
            assert 'Unspported situation happened, please create issue on https://github.com/PaddlePaddle/X2Paddle/issues.'
        axis = node.get_attr('axis')
        self.paddle_graph.add_layer(
            'paddle.concat', 
            inputs={"x": inputs_list}, 
            outputs=[node.name], 
            axis=axis)

    @print_mapping_info
    def Flatten(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        output_shape = node.out_shapes[0]
        axis = node.get_attr('axis', 1)
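        # ONNX Flatten collapses dims [0, axis) into the first output dim and
        # [axis, rank) into the second, so the result is always 2-D.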
        shape_list = [1, 1]
        if axis == 0:
            for s in output_shape:
                shape_list[1] *= s
        else:
            for s in output_shape[:axis]:
                shape_list[0] *= s
            for s in output_shape[axis:]:
                shape_list[1] *= s
        self.paddle_graph.add_layer(
            'paddle.reshape', 
            inputs={"x": val_x.name}, 
            outputs=[node.name],
            shape=shape_list)

    @print_mapping_info
    def Gemm(self, node):
        val_a = self.graph.get_input_node(node, idx=0, copy=True)
        val_b = self.graph.get_input_node(node, idx=1, copy=True)
        val_c = self.graph.get_input_node(node, idx=2, copy=True)

        alpha = node.get_attr('alpha', 1.)  # optional
        beta = node.get_attr('beta', 1.)  # optional
        trans_a = bool(node.get_attr('transA', 0))  # optional
        trans_b = bool(node.get_attr('transB', 0))  # optional
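        # Gemm: Y = alpha * op(A) @ op(B) + beta * C, emitted as
        # matmul -> scale(alpha) -> add, scaling C by beta first when beta != 1.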
        val_mm = node.name + '_mm'
        matmul_inputs = {"x": val_a.name, 
                         "y": val_b.name}
        attr_matmul = {
            "transpose_x": trans_a,
            "transpose_y": trans_b,
        }
        self.paddle_graph.add_layer(
            'paddle.matmul',
            inputs=matmul_inputs,
            outputs=[val_mm],
            **attr_matmul)
        self.paddle_graph.add_layer(
            "paddle.scale", 
            inputs={"x": val_mm}, 
            # when beta == 0, the scaled matmul is already the final output
            outputs=[val_mm if beta != 0 else node.name],
            scale=alpha)

        if beta != 0:
            if beta == 1.:
                add_inputs = {"x": val_mm, 
                              "y": val_c.name}
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=add_inputs,
                    outputs=[node.name])
            else:
                var_beta = node.name + '_beta'
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": val_c.name},
                    outputs=[var_beta],
                    scale=beta)
                add_inputs = {"x": val_mm, "y": var_beta}
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=add_inputs,
                    outputs=[node.name])

    @print_mapping_info
    def Sum(self, node):
        val_inps = node.layer.input
        inputs_dict = {
            "x": self.graph.get_input_node(
                node, idx=0, copy=True).name,
            "y": self.graph.get_input_node(
                node, idx=1, copy=True).name,
        }
        self.paddle_graph.add_layer("paddle.add", 
                                    inputs=inputs_dict, 
                                    outputs=[node.name])

        for idx in range(2, len(val_inps)):
            y = self.graph.get_input_node(node, idx=idx, copy=True)
            inputs_dict = {
                "x": node.name,
                "y": y.name,
            }
            self.paddle_graph.add_layer(
                "paddle.add", 
                inputs=inputs_dict, 
                outputs=[node.name])

    @print_mapping_info
    def MatMul(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        x_shape = val_x.out_shapes[0]
        y_shape = val_y.out_shapes[0]
        inputs_dict = {"x": val_x.name, 
                       "y": val_y.name}
        if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
            y_squeeze = val_y.name + '_squeeze'
            self.paddle_graph.add_layer(
                "paddle.squeeze",
                inputs={"x": val_y.name},
                outputs=[y_squeeze],
                axis=[0])
            inputs_dict['y'] = y_squeeze
            self.paddle_graph.add_layer(
                "paddle.matmul", 
                inputs=inputs_dict, 
                outputs=[node.name])
        else:
            self.paddle_graph.add_layer(
                "paddle.matmul", 
                inputs=inputs_dict, 
                outputs=[node.name])

    @print_mapping_info
    def BatchNormalization(self, node):
        op_name = name_generator("batchnorm", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = self.graph.get_input_node(node, idx=2, copy=True)
        val_mean = self.graph.get_input_node(node, idx=3, copy=True)
        val_var = self.graph.get_input_node(node, idx=4, copy=True)

        momentum = node.get_attr('momentum', .9)
        epsilon = node.get_attr('epsilon', 1e-5)
        c = val_x.out_shapes[0][1]

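        # Map the ONNX scale/B/mean/var initializers onto the parameter names
        # paddle.nn.BatchNorm expects: .weight, .bias, ._mean and ._variance.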
        _rename_or_remove_weight(self.weights, val_scale.name, op_name+'.weight')
        _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
        _rename_or_remove_weight(self.weights, val_var.name, op_name+'._variance')
        _rename_or_remove_weight(self.weights, val_mean.name, op_name+'._mean')

        # Attribute: spatial is used in BatchNormalization-1,6,7
        spatial = bool(node.get_attr('spatial'))
        layer_attrs = {
            "num_channels": c,
            "momentum": momentum,
            "epsilon": epsilon,
            "is_test": True,
            "use_global_stats": False,
        }
        self.paddle_graph.add_layer(
            "paddle.nn.BatchNorm", 
            inputs={"x": val_x.name}, 
            outputs=layer_outputs, 
            **layer_attrs)

    @print_mapping_info
    def Transpose(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        s_len = len(val_x.out_shapes[0])
        perm_default = list(range(s_len))
        perm_default.reverse()
        perm = node.get_attr('perm', perm_default)
        self.paddle_graph.add_layer(
            "paddle.transpose", 
            inputs={"x": val_x.name},
            outputs=[node.name], 
            perm=perm)

    @print_mapping_info
    def PRelu(self, node):
        op_name = name_generator("prelu", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_slope = self.graph.get_input_node(node, idx=1, copy=True)

        mode = 'channel'
        shape_slope = val_slope.out_shapes[0]
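        # A slope of shape [1, 1, ...] is a single shared coefficient ('all');
        # anything else is treated as one coefficient per channel ('channel').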
        if shape_slope == [1] * len(shape_slope):
            mode = 'all'

        if mode == "element":
            self.paddle_graph.add_layer(
                "paddle.zeros",
                inputs={}, 
                outputs=[output_name + "__zeros"], 
                shape=shape_slope,
                dtype=string(node.dtype))
            self.paddle_graph.add_layer(
                "paddle.maximum",
                inputs={"x": val_x.name, 
                        "y": output_name + "__zeros"}, 
                outputs=[output_name + "__max"])
            self.paddle_graph.add_layer(
                "paddle.minimum",
                inputs={"x": val_x.name, 
                        "y": output_name + "__zeros"}, 
                outputs=[output_name + "__max"])
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs={"x": val_slope.name, 
                        "y": output_name + "__min"}, 
                outputs=[output_name + "__mul"])
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs={"x": output_name + "__max", 
                        "y": output_name + "__mul"}, 
                outputs=[output_name])
        else:
            if mode == 'channel':
                slope_data = _const_weight_or_none(val_slope)
                _rename_or_remove_weight(self.weights, val_slope.name)
                # flatten the slope to a 1-D weight, matching paddle.nn.PReLU
                self.weights[op_name + '._weight'] = np.reshape(slope_data, [-1])
                num_parameters = val_x.out_shapes[0][1]
            else:
                num_parameters = 1
                slope_data = self.weights[val_slope.name]
                _rename_or_remove_weight(self.weights, val_slope.name)
                self.weights[op_name + '._weight'] = np.reshape(slope_data, [1])
            self.paddle_graph.add_layer(
                "paddle.nn.PReLU", 
                inputs={"x": val_x.name}, 
                outputs=layer_outputs, 
                num_parameters=num_parameters)

    @print_mapping_info
    def Squeeze(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axes = node.get_attr('axes')
        if len(val_x.out_shapes[0]) == 1:
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": val_x.name},
                outputs=[node.name],
                dtype=string(val_x.dtype))
        else:
            self.paddle_graph.add_layer(
                "paddle.squeeze", 
                inputs={"x": val_x.name}, 
                outputs=[node.name], 
                axis=axes)

    @print_mapping_info
    def Equal(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        self.paddle_graph.add_layer(
            "paddle.equal",
            inputs={'x': val_x.name,
                    'y': val_y.name},
            outputs=[node.name])

    @print_mapping_info
    def Greater(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        self.paddle_graph.add_layer(
            "paddle.greater_than",
            inputs={'x': val_x.name,
                    'y': val_y.name},
            outputs=[node.name])

    @print_mapping_info
    def Where(self, node):
        condition = self.graph.get_input_node(node, idx=0, copy=True)
        val_x = self.graph.get_input_node(node, idx=1, copy=True)
        val_y = self.graph.get_input_node(node, idx=2, copy=True)

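        # Lower Where to arithmetic:
        #   out = cast(condition) * x + cast(!condition) * y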
        not_condition = condition.name + '_not'
        self.paddle_graph.add_layer(
            "paddle.logical_not",
            inputs={"x": condition.name},
            outputs=[not_condition])
        cast_not_condition = not_condition + '_cast'
        self.paddle_graph.add_layer(
            "paddle.cast",
            inputs={"x": not_condition},
            outputs=[cast_not_condition],
            dtype=string(val_x.dtype))
        cast_condition = condition.name + '_cast'
        self.paddle_graph.add_layer(
            "paddle.cast",
            inputs={"x": condition.name},
            outputs=[cast_condition],
            dtype=string(val_x.dtype))
        mul_val_x = val_x.name + '_mul'
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={'x': val_x.name,
                    'y': cast_condition},
            outputs=[mul_val_x])
        mul_val_y = val_y.name + '_mul'
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={'x': val_y.name,
                    'y': cast_not_condition},
            outputs=[mul_val_y])

        self.paddle_graph.add_layer(
            "paddle.add",
            inputs={'x': mul_val_x,
                    'y': mul_val_y},
            outputs=[node.name])

    @print_mapping_info
    def NonZero(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_x_dim = len(val_x.out_shapes[0])
        if val_x_dim == 1:
            self.paddle_graph.add_layer(
                "paddle.nonzero", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name])
            self.paddle_graph.add_layer(
                "paddle.transpose",
                inputs={"x": val_x.name},
                outputs=[node.name],
                perm=[1, 0])
        if val_x_dim > 1:
            self.paddle_graph.add_layer(
                "paddle.nonzero", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name])
            self.paddle_graph.add_layer(
                "paddle.split",
                inputs={"x": val_x.name}, 
                outputs=[val_x.name],
                num_or_sections=1,
                axis=val_x_dim)
            self.paddle_graph.add_layer(
                "paddle.concat", 
                inputs={"x": val_x.name}, 
                outputs=[node.name])

    @print_mapping_info
    def Identity(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.assign", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])

    @print_mapping_info
    def Tile(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
        repeats = _const_weight_or_none(val_repeats)
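        # repeats is either a compile-time constant (inlined below) or a
        # runtime tensor, which is cast to int32 before use.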

        if repeats is None:
            repeats = val_repeats.name
            if val_repeats.dtype != 'int32':
                self.paddle_graph.add_layer(
                    "paddle.cast",
                    inputs={"x": repeats},
                    outputs=["{}.tmp".format(repeats)],
                    dtype=string("int32"))
                repeats = "{}.tmp".format(repeats)

        elif isinstance(repeats, int):
            repeats = [repeats]

        self.paddle_graph.add_layer(
            "paddle.tile", 
            inputs={"x": val_x.name}, 
                    outputs=[node.name], 
S
SunAhong1993 已提交
1577 1578 1579 1580 1581
                    repeat_times=repeats)

    @print_mapping_info
    def MaxPool(self, node):
        op_name = name_generator("pool", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        assert node.get_attr(
            "dilations") is None, 'only the default dilations (all ones) are supported'  # optional

        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))  # optional
        pads = node.get_attr('pads', [0] * (poolnd * 2))  # optional
        paddle_op = 'paddle.nn.MaxPool{}D'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
            input_shape = val_x.out_shapes[0]
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w
            
        layer_attrs = {
            "kernel_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
            "ceil_mode": ceil_mode,
        }
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, 
            outputs=layer_outputs, 
            **layer_attrs)

    @print_mapping_info
    def GlobalMaxPool(self, node):
        op_name = name_generator("pool", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = val_x.out_shapes[0]
        if len(input_shape) == 4:
            poolnd = 2
        elif len(input_shape) == 5:
            poolnd = 3
        elif len(input_shape) == 3:
            poolnd = 1
        paddle_op = 'paddle.nn.AdaptiveMaxPool{}D'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x.name}, 
            outputs=layer_outputs, 
            output_size=output_shape[2:])

    @print_mapping_info
    def GlobalAveragePool(self, node):
        op_name = name_generator("pool", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = val_x.out_shapes[0]
        if len(input_shape) == 4:
            poolnd = 2
        elif len(input_shape) == 5:
            poolnd = 3
        elif len(input_shape) == 3:
            poolnd = 1
        paddle_op = 'paddle.nn.AdaptiveAvgPool{}D'.format(poolnd)
        assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs={'x': val_x.name}, 
            outputs=layer_outputs, 
            output_size=output_shape[2:])

    @print_mapping_info
    def Conv(self, node):
        op_name = name_generator("conv", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
        has_bias = len(node.layer.input) == 3
        if has_bias:
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
        auto_pad = node.get_attr('auto_pad', 'NOTSET')

        kernel_shape = node.get_attr('kernel_shape')
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only Conv2D and Conv3D are supported'
        num_out_channels = val_w.out_shapes[0][0]
        num_in_channels = val_w.out_shapes[0][1]
        paddle_op = 'paddle.nn.Conv{}D'.format(convnd)

        num_groups = node.get_attr('group', 1)
        strides = node.get_attr('strides', [1] * convnd)
        dilations = node.get_attr('dilations', [1] * convnd)
        pads = node.get_attr('pads', [0] * (convnd * 2))

        input_shape = val_x.out_shapes[0]
        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
            pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
            pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
                                      strides[1])
            paddings = pad_h + pad_w

        layer_inputs = {'x': val_x if isinstance(val_x, str) else val_x.name}
        layer_attrs = {
            "in_channels": num_in_channels * num_groups,
            "out_channels": num_out_channels,
            "kernel_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
            "dilation": dilations,
            "groups": num_groups,
        }
        remove_weight = True if val_w.name in self.done_weight_list else False
        if remove_weight:
            self.done_weight_list.append(val_w.name)
        _rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight',
                                 remove_weight)
        if has_bias:
            remove_bias = True if val_b.name in self.done_weight_list else False
            if remove_bias:
                self.done_weight_list.append(val_b.name)
            _rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias',
                                     remove_bias)
        else:
            layer_attrs["bias_attr"] = False
        input_shape = val_x.out_shapes[0]
        if reduce(lambda x, y: x * y, input_shape) in [1, -1] and 1 not in input_shape:
            input_shape[1] = num_in_channels * num_groups
            input_shape[0] = 0
            input_shape[2] = 0
            self.paddle_graph.add_layer(
                "paddle.reshape", 
                inputs=layer_inputs, 
                outputs=[layer_inputs["x"]], 
                shape=input_shape)
        self.paddle_graph.add_layer(
            paddle_op, 
            inputs=layer_inputs, 
            outputs=layer_outputs, 
            **layer_attrs)

    @print_mapping_info
    def ConvTranspose(self, node):
        op_name = name_generator("conv_trans", self.nn_name2id)
        output_name = node.name
        layer_outputs = [op_name, output_name]
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = None
        if len(node.layer.input) > 2:
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        out_padding = node.get_attr('output_padding', [0, 0])
        kernel_shape = node.get_attr('kernel_shape')
        assert kernel_shape, 'kernel_shape not inferred'
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only Conv2DTranspose and Conv3DTranspose are supported'
        num_in_channels = val_w.out_shapes[0][0]
        num_out_channels = val_w.out_shapes[0][1]
        paddle_op = 'paddle.nn.Conv{}DTranspose'.format(convnd)

        num_groups = node.get_attr('group', 1)
        strides = node.get_attr('strides', [1] * convnd)
        dilations = node.get_attr('dilations', [1] * convnd)
        output_size = node.get_attr('output_shape', [])
        pads = node.get_attr('pads', [0] * (convnd * 2))

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

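        # Output size per spatial dim:
        #   (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1 + output_padding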
        output_size = [0, 0]

        output_size[0] = (val_x.out_shapes[0][2] - 1
                          ) * strides[0] - 2 * paddings[0] + dilations[0] * (
                              kernel_shape[0] - 1) + 1 + out_padding[0]
        output_size[1] = (val_x.out_shapes[0][3] - 1
                          ) * strides[1] - 2 * paddings[1] + dilations[1] * (
                              kernel_shape[1] - 1) + 1 + out_padding[1]

        # Conv2DTranspose has no output_size attribute; it can only be passed in at forward() time
        inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name}
        layer_attrs = {
            "in_channels": num_in_channels,
            "out_channels": num_out_channels,
            "kernel_size": kernel_shape,
            "stride": strides,
            "dilation": dilations,
            "padding": paddings,
            "groups": num_groups,
            "output_padding":out_padding}
            
        _rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight')
        if val_b is not None:
            _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
        self.paddle_graph.add_layer(
            kernel=paddle_op,
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
        
    @print_mapping_info
    def ArgMax(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axis = node.get_attr('axis')
        keepdims = False if node.get_attr('keepdims') == 0 else True
        layer_attrs = {'axis': axis,
                      'keepdim': keepdims}
        self.paddle_graph.add_layer(
            'paddle.argmax', 
            inputs={"x": val_x.name}, 
            outputs=[node.name],
            **layer_attrs)

    @print_mapping_info
    def Size(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
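        # ONNX Size is the total element count: prod(shape), returned as int64.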
        self.paddle_graph.add_layer(
            "paddle.shape", 
            inputs={"input": val_x.name}, 
            outputs=[node.name])
        self.paddle_graph.add_layer(
            'paddle.cast',
            inputs={"x": node.name},
            outputs=[node.name],
            dtype=string('int64'))  
        self.paddle_graph.add_layer(
            "paddle.prod",
            inputs={"x": node.name},
            outputs=[node.name])
        
    @print_mapping_info
    def Sign(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
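        # paddle.sign is applied on float data here, so non-float inputs are
        # round-tripped through float32 and cast back afterwards.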
        if node.dtype not in ["float16", "float32", "float64"]:
            self.paddle_graph.add_layer(
                "paddle.cast", 
                inputs={"x": val_x.name}, 
                outputs=[val_x.name],
                dtype=string("float32"))
        self.paddle_graph.add_layer(
            "paddle.sign", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])
        if node.dtype not in ["float16", "float32", "float64"]:
            self.paddle_graph.add_layer(
                "paddle.cast", 
                inputs={"x": node.name}, 
                outputs=[node.name],
                dtype=string(node.dtype))
        
    @print_mapping_info
    def OneHot(self, node):
        nn_op_name = name_generator("onehot", self.nn_name2id)
        output_name = node.name
        layer_outputs = [nn_op_name, output_name]
        indices = self.graph.get_input_node(node, idx=0, copy=True)
        depth = self.graph.get_input_node(node, idx=1, copy=True)
        values = self.graph.get_input_node(node, idx=2, copy=True)
        axis = node.get_attr('axis', -1)
        self.paddle_graph.add_layer(
            "custom_layer:OneHot", 
            inputs={"indices": indices.name,
                    "depth": depth.name,
                    "values": values.name}, 
            outputs=layer_outputs,
            axis=axis)
    
    @print_mapping_info
    def Reciprocal(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.reciprocal", 
            inputs={"x": val_x.name}, 
            outputs=[node.name])

    @print_mapping_info
    def LSTM(self, node):
        x = self.graph.get_input_node(node, idx=0, copy=True)
        input_weight = self.graph.get_input_node(node, idx=1, copy=True)
        hidden_weight = self.graph.get_input_node(node, idx=2, copy=True)

        input_nums = len(node.layer.input)
        exist_input_nums = 3
        have_bias = False
        if input_nums > 3 and node.layer.input[3] != '':
            bias = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
            have_bias = True
            exist_input_nums += 1
        if input_nums > 4 and node.layer.input[4] != '':
            sequence_lens = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
            exist_input_nums += 1
        if input_nums > 5 and node.layer.input[5] != '':
            init_h = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": init_h.name},
                outputs=[init_h.name],
                shape=init_h.out_shapes[0]
                )
            exist_input_nums += 1
        if input_nums > 6 and node.layer.input[6] != '':
            init_c = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
            self.paddle_graph.add_layer(
                'paddle.reshape',
                inputs={"x": init_c.name},
                outputs=[init_c.name],
                shape=init_c.out_shapes[0]
                )

        input_weight_np = _const_weight_or_none(input_weight)
        _rename_or_remove_weight(self.weights, input_weight.name)
        hidden_size = node.get_attr('hidden_size', input_weight_np.shape[1] // 4)
        input_size = input_weight_np.shape[2]
        hidden_weight_np = _const_weight_or_none(hidden_weight)
        _rename_or_remove_weight(self.weights, hidden_weight.name)
        bias_np = _const_weight_or_none(bias)
        _rename_or_remove_weight(self.weights, bias.name)
        input_bias_np = bias_np[:, :4*hidden_size]
        hidden_bias_np = bias_np[:, 4*hidden_size:]

        # parameter order in paddle.nn.LSTM:
        # 1. gate order in paddle is: input, forget, cell, output.
        # 2. gate order in onnx is: input, output, forget, cell.

        def reform_weights(w, n, intervals):
            slices = [w[:, x * n:y * n] for x, y in intervals]
            return np.concatenate(slices, axis=1)

        def transform_weight_with_bias(weights, n, intervals):
            return [reform_weights(w, n, intervals) for w in weights]

        reform_permutation = [(0, 1), (2, 4), (1, 2)]
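        # With n = hidden_size, ONNX packs gate weights as [i, o, f, c];
        # taking column blocks (0,1), (2,4), (1,2) re-packs them in paddle's
        # [i, f, c, o] order.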

        weights = transform_weight_with_bias(
            [input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np],
            hidden_size, reform_permutation)

        op_name = name_generator("lstm", self.nn_name2id)
        y_out = node.output(0)
        yh_out = node.output(1) 
        yc_out = node.output(2)
        direction = node.get_attr('direction', 'forward')

        def generate_paddle_param_names(op_name, suffix=''):
            param_names = []
            param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}'])
            if have_bias:
                param_names.append('{}.bias_ih_l0{}')
                param_names.append('{}.bias_hh_l0{}')
            param_names = [x.format(op_name, suffix) for x in param_names]
            return param_names

        def assign_params(op_name, weights, weight_idx=0, suffix=''):
            param_names = generate_paddle_param_names(op_name, suffix)
            for param_name, weight in zip(param_names, weights):
                self.weights[param_name] = weight[weight_idx]

        if direction == 'backward':
            raise Exception("LSTM supports 'forward' or 'bidirectional', but got '{}'.".format(direction))
        else:
            assign_params(op_name, weights)
            if direction == 'bidirectional':
                assign_params(op_name, weights, 1, '_reverse')

        self.paddle_graph.add_layer(
            'paddle.nn.LSTM', 
            inputs={'input': x.name, 'initial_states': (init_h.name, init_c.name)},
            outputs=[op_name, y_out, yh_out, yc_out],
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            direction=string(direction),
            time_major=True)

        self.paddle_graph.add_layer(
            'paddle.reshape',
            inputs={"x": y_out},
            outputs=[y_out],
            shape=[0, 0, -1, hidden_size]
            )
        self.paddle_graph.add_layer(
            'paddle.transpose',
            inputs={"x": y_out},
            outputs=[y_out],
            perm=[0,2,1,3]
            )