#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from x2paddle.core.graph import GraphNode
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.fluid_code import Layer
from x2paddle.core.fluid_code import FluidCode
from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
from x2paddle.op_mapper.onnx_directly_map import default_op_mapping_field_values
from x2paddle.op_mapper.onnx_directly_map import default_op_mapping
from x2paddle.op_mapper.onnx_directly_map import default_ioa_constraint
from x2paddle.op_mapper.onnx_custom_layer import *
from x2paddle.core.util import string
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
import logging as _logging
from collections import OrderedDict as _dict
import math
import os
import shutil

_logger = _logging.getLogger(__name__)


def _const_weight_or_none(node):
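    # Return the value of a Constant node or the weight of an initializer data
    # node; return None when the node is not a compile-time constant.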
    if 'Constant' in node.layer_type:
        return node.value
    if isinstance(node, ONNXGraphDataNode):
        return node.weight
    return None


def get_same_padding(in_size, kernel_size, stride):
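    # SAME-style padding: choose pads so that output size == ceil(in_size / stride),
    # returned as [pad_begin, pad_end] for one spatial dimension.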
    new_size = int(math.ceil(in_size * 1.0 / stride))
    pad_size = (new_size - 1) * stride + kernel_size - in_size
    pad0 = int(pad_size / 2)
    pad1 = pad_size - pad0
    return [pad0, pad1]


class ONNXOpMapper(OpMapper):
    elementwise_ops = {
        'Add': 'elementwise_add',
        'Div': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Mul': 'elementwise_mul',
        'Pow': 'elementwise_pow',
    }

    def __init__(self, decoder, save_dir):
        super(ONNXOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.onnx_graph
        self.input_shapes = []
        self.weights = dict()
        self.omit_nodes = list()
        self.used_custom_layers = dict()
        self.is_inference = False
        self.tmp_data_dir = os.path.join(save_dir, 'tmp_data')
        self.get_output_shapes()
        
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
            
        # map ONNX ops to fluid layers
        print("Total nodes: {}".format(
            sum([
                isinstance(node, ONNXGraphNode)
                for name, node in self.graph.node_map.items()
            ])))
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if hasattr(self, op):
                func = getattr(self, op)
                func(node)
            elif op in default_op_mapping:
                self.directly_map(node)
            elif op in custom_layers:
                self.deal_custom_layer(node)
            elif op in self.elementwise_ops:
                self.elementwise_map(node)

        self.remove_tmp_data()

    def op_checker(self):
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if not hasattr(self, op) and \
                op not in default_op_mapping and \
                op not in custom_layers and \
                op not in self.elementwise_ops:
                unsupported_ops.add(op)
        if len(unsupported_ops) == 0:
            return True
        else:
            print("There are {} ops not supported yet, list as below".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print(op)
            return False

    def get_results_of_inference(self, model, value_infos, data_nodes):
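        # Dump a random tensor for every graph input, expose all value_infos as
        # graph outputs, and run the external `onnx_infer` tool so that each
        # intermediate tensor is saved to tmp_data_dir as a .npy file.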
        if not os.path.exists(self.tmp_data_dir):
            os.makedirs(self.tmp_data_dir)
            
        for data_node in data_nodes:
            value_info = value_infos[data_node]
            shape = value_info['shape']
            for i, dim_shape in enumerate(shape):
                if dim_shape == 0 and i == 0:
                    shape[i] = 1
                if dim_shape == 0 and i != 0:
                    assert False, 'shape of input is not assigned'
            ipt = np.random.random(shape).astype(
                value_info['dtype'])
            np.save(os.path.join(self.tmp_data_dir, data_node), ipt)
            
        model = onnx.shape_inference.infer_shapes(model)
        outputs = []
        for value_info in model.graph.value_info:
            outputs.append(value_info)

        model.graph.ClearField('output')
        model.graph.output.MergeFrom(outputs)
        onnx.save(model, os.path.join(self.tmp_data_dir,
                                      'onnx_model_infer.onnx'))
        os.system('onnx_infer --save_dir=' + self.tmp_data_dir)
        return

    def get_dynamic_shape(self, layer):
        """
        get the value, dtype and shape of a layer's output from the inference results
        """
        path = os.path.join(self.tmp_data_dir, layer + '.npy')
        if not os.path.exists(path):
            return [None, None, None]
        output = np.load(path)
        return output.tolist(), output.dtype, output.shape

    def get_output_shapes(self):
        """
        collect the output shape and dtype of every node in the ONNX model
        """
        nodes = self.decoder.model.graph.node
        node_map = self.decoder.onnx_graph.node_map
        value_infos = self.decoder.onnx_graph.value_infos
        onnx_model = self.decoder.model
        for layer in nodes:
            node = node_map[layer.name]
            for opt in layer.output:
                if opt in value_infos:
                    value_info = value_infos[opt]
                    if len(value_info['shape']) == 0 or value_info[
                            'dtype'] is None or 0 in value_info['shape']:
                        if not self.is_inference:
                            self.get_results_of_inference(
                                onnx_model, value_infos,
                                self.decoder.onnx_graph.place_holder_nodes)
                            self.is_inference = True
                        _, dtype, shape = self.get_dynamic_shape(opt)
                        node.out_shapes.append(shape)
                        node.dtype = dtype
                    else:
                        node.dtype = value_info['dtype']
                        node.out_shapes.append(value_info['shape'])
                else:
                    if not self.is_inference:
                        self.get_results_of_inference(
                            onnx_model, value_infos,
                            self.decoder.onnx_graph.place_holder_nodes)
                        self.is_inference = True
                    _, dtype, shape = self.get_dynamic_shape(opt)
                    node.dtype = dtype
                    node.out_shapes.append(shape)

    def remove_tmp_data(self):
        """
        remove temporarily generated file
        """
        if os.path.exists(self.tmp_data_dir):
            shutil.rmtree(self.tmp_data_dir)

    def directly_map(self, node, name='', *args, **kwargs):
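        # Table-driven mapping: look up the fluid op and its attribute/IO mapping
        # in default_op_mapping and emit a single equivalent layer.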
        inputs = node.layer.input
        outputs = node.layer.output
        op_type = node.layer_type
        attrs = node.attr_map
        info = default_op_mapping[op_type]
        info.extend(list(default_op_mapping_field_values.values())[len(info):])
        (
            fluid_op,
            fluid_input_args,
            fluid_output_args,
            attr_mapping,
            default_attrs,
            input_perm,
            output_perm,
            fill_name_field,
        ) = info

        if fluid_op in default_ioa_constraint:
            for predicate, message in default_ioa_constraint[fluid_op]:
                assert predicate(inputs, outputs, attrs), message

        mapped_attrs = {
            attr_mapping.get(key, key): value
            for key, value in attrs.items()
        }
        if '' in mapped_attrs:
            mapped_attrs.pop('')
        if '_' in mapped_attrs:
            mapped_attrs.pop('_')
        fluid_attrs = default_attrs.copy()
        fluid_attrs.update(mapped_attrs)
        inputs = inputs if input_perm is None else list(
            map(lambda i: inputs[i], input_perm))
        val_inps = []
        for idx, ipt in enumerate(inputs):
            val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))

        val_outs = outputs if output_perm is None else list(
            map(lambda i: outputs[i], output_perm))
        attr = fluid_attrs
        assert len(val_inps) == 1, 'directly_map error with multi inputs'
        if fluid_op not in ['shape']:
            attr['name'] = string(node.layer_name)
        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_inps[0],
                                  output=val_outs[0],
                                  param_attr=attr)

    def deal_custom_layer(self, node):
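        # Ops without a direct fluid equivalent go through generated custom-layer
        # code; the generated snippet is registered once in self.used_custom_layers.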
        op = node.layer_type
        custom_code, func = make_custom_layer(node)
        child_func_code, child_func = make_custom_child_func(node)
        params = get_params(node.layer, node.layer_type)
        arg_names, kwargs = set_args(func, params)
        kwargs['name'] = string(node.layer_name)
        node.fluid_code.add_layer(func.__code__.co_name,
                                  inputs=node.inputs,
                                  output=node,
                                  param_attr=kwargs,
                                  is_custom_layer=True)
        if op not in self.used_custom_layers:
            self.used_custom_layers[op] = custom_code
            if op + '_child_func' not in self.used_custom_layers:
                if child_func_code is not None:
                    self.used_custom_layers[op +
                                            '_child_func'] = child_func_code

    def elementwise_map(self, node):
        assert node.layer_type in self.elementwise_ops
        op_type = self.elementwise_ops[node.layer_type]
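        # Make x the higher-rank operand; when y's shape does not already line up
        # with x's, strip y's leading size-1 dims with a reshape so fluid's
        # implicit broadcasting can match the trailing dimensions.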
        
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        val_y_shape = val_y.out_shapes[0]
        val_x_shape = val_x.out_shapes[0]
        
        if len(val_x_shape) < len(val_y_shape):
            val_x, val_y = val_y, val_x

        str_y_shape = ','.join(str(e) for e in val_y_shape)
        str_x_shape = ','.join(str(e) for e in val_x_shape)
        slice_idx = 0
        if str_y_shape not in str_x_shape:
            for dim in val_y_shape:
                if dim == 1:
                    slice_idx += 1
                else:
                    break
        attr = {"name": string(node.layer_name)}
        if slice_idx < len(val_y_shape) and slice_idx > 0:
            val_y_reshaped = val_y_shape[slice_idx:]
            var_y_reshaped = val_y.layer_name + '_reshaped'
            attr_reshaped = {
                'shape': val_y_reshaped,
                'name': string(var_y_reshaped)
            }
            node.fluid_code.add_layer('reshape',
                                      inputs=val_y,
                                      output=var_y_reshaped,
                                      param_attr=attr_reshaped)
            inputs = {'x': val_x, 'y': var_y_reshaped}
            node.fluid_code.add_layer(op_type,
                                      inputs=inputs,
                                      output=node,
                                      param_attr=attr)
        else:
            inputs = {'x': val_x, 'y': val_y}
            node.fluid_code.add_layer(op_type,
                                      inputs=inputs,
                                      output=node,
                                      param_attr=attr)

    def place_holder(self, node):
        self.input_shapes.append(node.out_shapes[0])
        
        shape = node.out_shapes[0]
        for i, dim_shape in enumerate(shape):
            if dim_shape == 0 and i == 0:
                shape[i] = 1
            if dim_shape == 0 and i != 0:
                assert False, 'shape of input is not assigned'
        attr = {
            "dtype": string(node.dtype),
            "shape": shape,
            "name": string(node.layer_name),
            "append_batch_size": 'False'
        }

        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def create_parameter(self, node, parameter=None):
        if parameter is not None:
            node = parameter
        dtype = node.dtype
        shape = node.out_shapes[0]

        self.weights[node.layer_name] = node.weight
        attr = {
            'dtype': string(dtype),
            'shape': shape,
            'name': string(node.layer_name),
            'attr': string(node.layer_name),
            'default_initializer': 'Constant(0.0)'
        }
        node.fluid_code.add_layer("create_parameter",
                                  inputs=None,
                                  output=node,
                                  param_attr=attr)

    def _pad_if_asymmetric(self, node, pads, val_name):  # pads: SSEE
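        # If every (begin, end) pad pair is symmetric, return the per-dim pads and
        # the original input; otherwise emit an explicit Pad layer and return zero
        # pads together with the padded tensor's name.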
        assert len(pads) & 1 == 0
        symmetric = True
        ndims = len(pads) // 2
        for idx_dim in range(ndims):
            if pads[idx_dim] != pads[ndims + idx_dim]:
                symmetric = False
                break
        if symmetric:
            return pads[:ndims], val_name
        val_padded = self.Pad(node, op_independent=False)
        return [0] * ndims, val_padded

    def _interpolate(self, node):
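        # Shared lowering for Resize/Upsample: use the 'scale' attribute when the
        # scales input is usable, otherwise fall back to an explicit out_shape.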
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scales = self.graph.get_input_node(node, idx=1, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        
        out_shape = val_y.out_shapes[0]
        if out_shape is not None:
            assert len(out_shape) == 4, 'only 4-D Tensor as X and Y supported'
            out_shape = out_shape[2:]
            
        scales = _const_weight_or_none(val_scales)
    
        if isinstance(val_scales, ONNXGraphNode):
            scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)

        attr = {'name': string(node.layer_name)}
        use_scales = True
        if scales is not None:
            try:
                assert len(scales) == 4, 'only 4-D Tensor as X and Y supported'
                assert scales[0] == 1 and scales[
                    1] == 1, 'only scale on (NC)HW supported'
                assert scales[2] == scales[
                    3], 'only aspect-ratio-invariant scale supported'
            except AssertionError:
                use_scales = False
        scale = scales[2] if scales else None
        if scale is None:
            assert out_shape, 'neither scales nor output shape is available'
        else:
            if out_shape is None:
                in_shape = val_x.out_shapes[0]
                assert in_shape is not None, 'out_shape required but not inferrable'
                assert len(
                    in_shape) == 4, 'only 4-D Tensor as X and Y supported'
                out_shape = [in_shape[2] * scale, in_shape[3] * scale]

        mode = node.get_attr('mode', 'nearest')
        
        fluid_op = 'resize_{}'.format(mode)
        if 'linear' in mode:
            print('Warning: Paddle does not support resize with mode=linear, using bilinear instead')
            fluid_op = 'resize_bilinear'
        
        if use_scales and scale is not None:
            attr['scale'] = scale
        else:   
            attr['out_shape'] = out_shape

        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)
    def RoiAlign(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)

        pooled_height = node.get_attr('output_height')
        pooled_width = node.get_attr('output_width')
        spatial_scale = node.get_attr('spatial_scale')
        sampling_ratio = node.get_attr('sampling_ratio')
        attr = {
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
            'sampling_ratio': sampling_ratio,
        }
        node.fluid_code.add_layer('roi_align',
                                  inputs={'input': val_x, 'rois': val_rois},
                                  output=node,
                                  param_attr=attr)

    def MaxRoiPool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_rois = self.graph.get_input_node(node, idx=1, copy=True)

        spatial_scale = node.get_attr('spatial_scale')
        pooled_height, pooled_width = node.get_attr('pooled_shape')
        attr = {
            'pooled_height': pooled_height,
            'pooled_width': pooled_width,
            'spatial_scale': spatial_scale,
        }
        node.fluid_code.add_layer('roi_pool',
                                  inputs={'input': val_x, 'rois': val_rois},
                                  output=node,
                                  param_attr=attr)
            
    def Pad(self, node, op_independent=True):
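        # ONNX pads are ordered begins-then-ends (SSEE); pick pad2d for the 4-D
        # NCHW case and the generic pad op otherwise, interleaving to SESE order.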
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        pads = node.get_attr('pads')
        mode = node.get_attr('mode', 'constant')
        value = node.get_attr('value', 0.)
        data_shape = val_x.out_shapes[0]
        output_shape = node.out_shapes[0]
        assume_pad2d = False
        attr = {}
        if len(pads) == 4:
            assume_pad2d |= mode != 'constant'
            if data_shape:
                assume_pad2d |= data_shape and len(data_shape) == 4  # NCHW
            if output_shape:
                assume_pad2d |= output_shape and len(output_shape) == 4  # NCHW
        if assume_pad2d:
            fluid_op = 'pad2d'
            attr['data_format'] = string('NCHW')
            attr['mode'] = string(mode)
        else:
            attr = {'pad_value': value}
            fluid_op = 'pad'
        if len(pads) == 4:
            paddings = np.array(pads).reshape(
                (-1, 2)).transpose().flatten().tolist()  # SSEE -> SESE
        elif len(pads) == 8:
            paddings = np.array(pads).reshape(
                (-1, 4)).transpose().flatten().tolist()  # SSEE -> SESE
            if sum(paddings[:4]) == 0:
                fluid_op = 'pad2d'
                paddings = paddings[4:]
                attr['mode'] = string(mode)
        attr['paddings'] = paddings
        if op_independent:
            attr['name'] = string(node.layer_name)
            node.fluid_code.add_layer(fluid_op,
                                      inputs=val_x,
                                      output=node,
                                      param_attr=attr)
        else:
            attr['name'] = string(node.layer_name + '_paded')
            node.fluid_code.add_layer(fluid_op,
                                      inputs=val_x,
                                      output=node.layer_name + '_paded',
                                      param_attr=attr)
            return node.layer_name + '_paded'

    def Unsqueeze(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axes = node.get_attr('axes')
        if len(val_x.out_shapes[0]) == 0:
            node.fluid_code.add_layer('assign',
                                      inputs=val_x,
                                      output=node,
                                      param_attr=None)
        else:
            attr = {'axes': axes, 'name': string(node.layer_name)}
            node.fluid_code.add_layer('unsqueeze',
                                      inputs=val_x,
                                      output=node,
                                      param_attr=attr)

    def Shrink(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        bias = node.get_attr('bias')
        lambd = node.get_attr('lambd')
        assert bias == 0.0, 'bias != 0 is not supported'
        attr = {'threshold': lambd, 'name': string(node.layer_name)}
        node.fluid_code.add_layer('hard_shrink',
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Constant(self, node):
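        # A scalar constant becomes a fill_constant layer; larger tensors are
        # stored in self.weights and materialized with create_parameter.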
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        value = node.get_attr('value')
        dtype = np.dtype(value.dtype)
        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, 'tensor dtype does not match storage dtype'
        
        shape = node.get_attr('shape', None)
        
        if shape is None:
            shape = val_output.out_shapes[0]
        if shape is None:
            shape = list(value.shape)
            _logger.warning(
                'in (Constant -> %s): '
                'attribute "shape" of %s not inferred, '
                'using value as 1-D tensor may lead to fails',
                val_output.layer_name, val_output.layer_name)

        if len(value) == 1:
            value = value.tolist()
            shape = [1]
            value = value[0]
            if dtype.name == 'int64':
                dtype = 'int32'
            attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
            node.fluid_code.add_layer('fill_constant',
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)
        else:
            value = np.reshape(value, shape)
            self.weights[node.layer_name] = value
            attr = {
                'dtype': string(dtype),
                'shape': shape,
                'name': string(node.layer_name),
                'attr': string(node.layer_name),
                'default_initializer': 'Constant(0.0)'
            }
            node.fluid_code.add_layer("create_parameter",
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)
572 573

    def Resize(self, node):
        self._interpolate(node)

    def Upsample(self, node):
        self._interpolate(node)

    def Expand(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
        
        if len(val_shape.outputs) == 1:
            self.omit_nodes.append(val_shape.layer_name)

        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        out_shape = node.out_shapes[0]
        val_x_dtype = val_x.dtype
        
        name_ones = node.layer_name + '_ones'
        attr_ones = {
            'shape': out_shape,
            'dtype': string(val_x_dtype)
        }
        node.fluid_code.add_layer('ones',
                                  inputs=None,
                                  output=name_ones,
                                  param_attr=attr_ones)
        inputs = {'x': name_ones, 'y': val_x}
        attr = {'name': string(node.layer_name)}
        node.fluid_code.add_layer('elementwise_mul',
                                  inputs=inputs,
                                  output=node.layer_name,
                                  param_attr=attr)

    def Gather(self, node):
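        # Gather along a non-zero axis is emulated with transpose -> gather ->
        # transpose; multi-dimensional indices are flattened to 1-D first and the
        # result is reshaped back afterwards.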
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        indices = self.graph.get_input_node(node, idx=1, copy=True)
        indices_shape = indices.out_shapes[0]
        axis = node.get_attr('axis', 0)
        assert len(
            indices_shape) <= 2, "Gather op don't support dim of indice >2 "
        if axis == 0 and len(indices_shape) <= 1:
            node.fluid_code.add_layer('gather',
                                      inputs={
                                          'input': val_x,
                                          'index': indices
                                      },
                                      output=node,
                                      param_attr=None)
        elif axis > 0 and len(indices_shape) <= 1:
            perm = list(range(len(val_x.out_shapes[0])))
            perm = [axis] + perm[:axis] + perm[axis + 1:]
            attr_trans = {'perm': perm}
            name_trans = val_x.layer_name + '_trans'
            node.fluid_code.add_layer('transpose',
                                      inputs=val_x,
                                      output=name_trans,
                                      param_attr=attr_trans)
            node.fluid_code.add_layer('gather',
                                      inputs={
                                          'input': name_trans,
                                          'index': indices
                                      },
                                      output=node,
                                      param_attr=None)
            node.fluid_code.add_layer('transpose',
                                      inputs=node,
                                      output=node,
                                      param_attr=attr_trans)
        elif len(indices_shape) > 1:
            from functools import reduce
            reshape_shape = reduce(lambda x, y: x * y, indices_shape)
            node.fluid_code.add_layer('reshape',
                                      inputs=indices,
                                      output=indices,
                                      param_attr={'shape': [reshape_shape]})
            
            perm = list(range(len(val_x.out_shapes[0])))
            perm = [axis] + perm[:axis] + perm[axis + 1:]
            attr_trans = {'perm': perm}
            name_trans = val_x.layer_name + '_trans'
            node.fluid_code.add_layer('transpose',
                                      inputs=val_x,
                                      output=name_trans,
                                      param_attr=attr_trans)
            node.fluid_code.add_layer('gather',
                                      inputs={
                                          'input': name_trans,
                                          'index': indices
                                      },
                                      output=node,
                                      param_attr=None)
            node.fluid_code.add_layer('transpose',
                                      inputs=node,
                                      output=node,
                                      param_attr=attr_trans)
            val_x_shape = val_x.out_shapes[0]
            reshaped_shape = []
            for i in perm:
                reshaped_shape.append(indices_shape[i])
            for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
                reshaped_shape.append(i)
            node.fluid_code.add_layer('reshape',
                                      inputs=node,
                                      output=node,
                                      param_attr={'shape': reshaped_shape})

    def Slice(self, node):
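        # starts/ends/axes/steps come either from extra inputs (newer opsets) or
        # from node attributes (older opsets); starts and ends are clamped to the
        # input shape when it is known.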
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        starts, ends, axes, steps = None, None, None, None
        if len(node.inputs) > 1:
            starts = self.graph.get_input_node(node, idx=1, copy=True)
            ends = self.graph.get_input_node(node, idx=2, copy=True)
            if len(node.inputs) > 3:
                axes = self.graph.get_input_node(node, idx=3, copy=True)
                self.omit_nodes.append(axes.layer_name)
                axes = _const_weight_or_none(axes)
            if len(node.inputs) > 4:
                steps = self.graph.get_input_node(node, idx=4, copy=True)
                self.omit_nodes.append(steps.layer_name)
                steps = _const_weight_or_none(steps)
 
            self.omit_nodes.append(starts.layer_name)
            self.omit_nodes.append(ends.layer_name)
            starts = _const_weight_or_none(starts)
            ends = _const_weight_or_none(ends)
        else:
            starts = node.get_attr('starts')
            ends = node.get_attr('ends')
            axes = node.get_attr('axes')

        val_y = self.graph.get_node(node.layer.output[0], copy=True)

        shape = val_x.out_shapes[0]

        if shape is not None:
            for idx, value in enumerate(starts):
                if value > shape[axes[idx]]:
                    starts[idx] = shape[axes[idx]]
            for idx, value in enumerate(ends):
                if value > shape[axes[idx]]:
                    ends[idx] = shape[axes[idx]]
        attr = {"axes": axes, "starts": starts, "ends": ends}
        node.fluid_code.add_layer('slice',
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def ConstantOfShape(self, node):
        val_shape = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        shape = _const_weight_or_none(val_shape)

        if shape is None:
            shape = node.out_shapes[0]

        assert shape is not None, (
            'given shape is neither const value nor deductible from output, '
            'this is not supported')

        value = node.get_attr('value')
        dtype = value.dtype
        value = value.tolist()
        if len(value) == 1:
            shape = [1]
            value = value[0]
            if dtype.name == 'int64':
                dtype = 'int32'
            attr = {'shape': shape, 'dtype': string(dtype), 'value': value}
            node.fluid_code.add_layer('fill_constant',
                                      inputs=None,
                                      output=node,
                                      param_attr=attr)

    def Split(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)

        fluid_op = 'split'
        split = node.get_attr('split')
        axis = node.get_attr('axis', 0)
        attr = {
            'num_or_sections': split,
            'dim': axis,
            'name': string(node.layer_name)
        }
        
        node.fluid_code.add_layer('split',
                                  inputs=val_x,
                                  output=val_y,
                                  param_attr=attr)

    def Reshape(self, node):
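        # A constant target shape is folded into the 'shape' attribute; a computed
        # shape tensor is wired in through 'actual_shape' (cast to int32 when the
        # ONNX tensor is int64).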
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
        val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
        shape = None

        if isinstance(val_shape, ONNXGraphDataNode):
            self.omit_nodes.append(val_shape.layer_name)
            
        attr = {'name': string(node.layer_name)}
        # catch dynamic graph shape
        if isinstance(val_shape, ONNXGraphNode):
            shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
            if val_shape.dtype == 'int64':
                val_shape_cast = val_shape.layer_name + '_cast'
                node.fluid_code.add_layer('cast',
                                          inputs=val_shape,
                                          output=val_shape_cast,
                                          param_attr={'dtype': string('int32')})
            
                attr['actual_shape'] = val_shape_cast
            else:
                attr['actual_shape'] = val_shape

        if shape is None:
            shape = val_reshaped.out_shapes[0]

        if shape is None:
            shape = [1, -1]
            _logger.warning(
                'in %s(%s -> Reshape -> %s): '
                'input "shape" not inferred, use [1, -1] as dummy value, '
                'the behavior of Paddle fluid maybe undefined', node.layer_name,
                val_x.layer_name, val_reshaped.layer_name)
        
        attr['shape'] = shape
        node.fluid_code.add_layer('reshape',
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Cast(self, node):
        val_input = self.graph.get_input_node(node, idx=0, copy=True)
        val_output = self.graph.get_node(node.layer.output[0], copy=True)

        dtype = node.get_attr('to')
        if not isinstance(dtype, np.dtype):
            dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]

        output_dtype = val_output.dtype
        if output_dtype:
            assert dtype == output_dtype, 'dtype of "to" does not match output dtype'
        attr = {'dtype': string(dtype)}
        node.fluid_code.add_layer('cast',
                                  inputs=val_input,
                                  output=node,
                                  param_attr=attr)

    def AveragePool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)

        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))
        pads = node.get_attr('pads', [0] * (poolnd * 2))
        fluid_op = 'pool{}d'.format(poolnd)
        assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
            input_shape = val_x.out_shapes[0]
            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                     strides[0])
            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
                                     strides[1])
            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}

        attr = {
            "pool_size": kernel_shape,
            "pool_type": string('avg'),
            "pool_stride": strides,
            "pool_padding": paddings,
            "ceil_mode": ceil_mode,
            "exclusive": 'True',
            "name": string(node.layer_name)
        }

        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Concat(self, node):
        inputs = []
        for i in range(len(node.layer.input)):
            ipt = self.graph.get_input_node(node, idx=i, copy=True)
            if isinstance(ipt, str):
                inputs.append(ipt)
            else:
                inputs.append(ipt.layer_name)
        axis = node.get_attr('axis')
        attr = {'axis': axis}
        node.fluid_code.add_layer('concat',
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def Flatten(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axis = node.get_attr('axis', 1)
        attr = {"axis": str(axis), "name": string(node.layer_name)}
        node.fluid_code.add_layer('flatten',
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Gemm(self, node):
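        # Gemm is decomposed into matmul(alpha, transA, transB) followed by an
        # elementwise_add with C, scaled by beta when beta != 1.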
        val_a = self.graph.get_input_node(node, idx=0, copy=True)
        val_b = self.graph.get_input_node(node, idx=1, copy=True)
        val_c = self.graph.get_input_node(node, idx=2, copy=True)

        alpha = node.get_attr('alpha', 1.)  # optional
        beta = node.get_attr('beta', 1.)  # optional
        trans_a = bool(node.get_attr('transA', 0))  # optional
        trans_b = bool(node.get_attr('transB', 0))  # optional
        val_mm = node.layer_name + '_mm'
        matmul_inputs = {"x": val_a, "y": val_b}
        attr_matmul = {
            "transpose_x": trans_a,
            "transpose_y": trans_b,
            "alpha": alpha,
            "name": string(val_mm)
        }
        node.fluid_code.add_layer('matmul',
                                  inputs=matmul_inputs,
                                  output=val_mm,
                                  param_attr=attr_matmul)

        if beta != 0:
            if beta == 1.:
                add_inputs = {"x": val_mm, "y": val_c}
                attr = {"name": string(node.layer_name)}
                node.fluid_code.add_layer("elementwise_add",
                                          inputs=add_inputs,
                                          output=node,
                                          param_attr=attr)
            else:
                var_beta = node.layer_name + '_beta'
                matmul_beta_inputs = {"x": val_c, "y": var_beta}
                node.fluid_code.add_layer("Constant",
                                          inputs=matmul_beta_inputs,
                                          output=var_beta,
                                          param_attr={'value': beta})

                add_inputs = {"x": val_mm, "y": var_beta}
                attr = {"name": string(node.layer_name)}
                node.fluid_code.add_layer("elementwise_add",
                                          inputs=add_inputs,
                                          output=node,
                                          param_attr=attr)

    def Sum(self, node):
        val_inps = node.layer.input
        inputs = {
            "x": self.graph.get_input_node(node, idx=0, copy=True),
            "y": self.graph.get_input_node(node, idx=1, copy=True),
        }
        node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)

        for idx, ipt in enumerate(val_inps[2:]):
            y = self.graph.get_input_node(node, idx=idx + 2, copy=True)
            inputs = {
                "x": node.layer_name,
                "y": y,
            }
            node.fluid_code.add_layer("elementwise_add",
                                      inputs=inputs,
                                      output=node)

    def MatMul(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        inputs = {"x": val_x, "y": val_y}
        attr = {"name": string(node.layer_name)}
        node.fluid_code.add_layer("matmul",
                                  inputs=inputs,
                                  output=node,
                                  param_attr=attr)

    def BatchNormalization(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = self.graph.get_input_node(node, idx=2, copy=True)
        val_mean = self.graph.get_input_node(node, idx=3, copy=True)
        val_var = self.graph.get_input_node(node, idx=4, copy=True)

        self.omit_nodes.append(val_scale.layer_name)
        self.omit_nodes.append(val_b.layer_name)
        self.omit_nodes.append(val_mean.layer_name)
        self.omit_nodes.append(val_var.layer_name)

        momentum = node.get_attr('momentum', .9)
        epsilon = node.get_attr('epsilon', 1e-5)

        # Attribute: spatial is used in BatchNormalization-1,6,7
        spatial = bool(node.get_attr('spatial'))
        attr = {
            "momentum": momentum,
            "epsilon": epsilon,
            "data_layout": string('NCHW'),
            "is_test": True,
            "param_attr": string(val_scale.layer_name),
            "bias_attr": string(val_b.layer_name),
            "moving_mean_name": string(val_mean.layer_name),
            "moving_variance_name": string(val_var.layer_name),
            "use_global_stats": spatial,
            "name": string(node.layer_name)
        }
        node.fluid_code.add_layer("batch_norm",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Transpose(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        perm = node.get_attr('perm')
        attr = {'perm': perm, "name": string(node.layer_name)}
        node.fluid_code.add_layer("transpose",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Relu(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        attr = {"name": string(node.layer_name)}
        node.fluid_code.add_layer("relu",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def PRelu(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_slope = self.graph.get_input_node(node, idx=1, copy=True)

        mode = 'channel'
        shape_slope = val_slope.out_shapes[0]
        if len(shape_slope) == 1:
            mode = 'all'
        elif len(shape_slope) > 2:
            mode = 'element'
        attr = {
            "param_attr": string(val_slope.layer_name),
            'mode': string(mode)
        }
        node.fluid_code.add_layer("prelu",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def Squeeze(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        axes = node.get_attr('axes')
        attr = {'axes': axes, "name": string(node.layer_name)}
        node.fluid_code.add_layer("squeeze",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)
        
    def Equal(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_input_node(node, idx=1, copy=True)
        node.fluid_code.add_layer("equal",
                                  inputs={'x': val_x, 'y': val_y},
                                  output=node,
                                  param_attr=None)

    def Where(self, node):
        condition = self.graph.get_input_node(node, idx=0, copy=True)
        val_x = self.graph.get_input_node(node, idx=1, copy=True)
        val_y = self.graph.get_input_node(node, idx=2, copy=True)

        not_condition = condition.layer_name + '_not'
        node.fluid_code.add_layer("logical_not",
                                  inputs=condition,
                                  output=not_condition,
                                  param_attr=None)
        cast_not_condition = not_condition + '_cast'
        node.fluid_code.add_layer("cast",
                                  inputs=not_condition,
                                  output=cast_not_condition,
                                  param_attr={'dtype': string(val_x.dtype)})
        cast_condition = condition.layer_name + '_cast'
        node.fluid_code.add_layer("cast",
                                  inputs=condition,
                                  output=cast_condition,
                                  param_attr={'dtype': string(val_x.dtype)})
        mul_val_x = val_x.layer_name + '_mul'
        node.fluid_code.add_layer("elementwise_mul",
                                  inputs={'x': val_x, 'y': cast_condition},
                                  output=mul_val_x,
                                  param_attr=None)

        mul_val_y = val_y.layer_name + '_mul'
        node.fluid_code.add_layer("elementwise_mul",
                                  inputs={'x': val_y, 'y': cast_not_condition},
                                  output=mul_val_y,
                                  param_attr=None)

        node.fluid_code.add_layer("elementwise_add",
                                  inputs={'x': mul_val_x, 'y': mul_val_y},
                                  output=node,
                                  param_attr=None)
        
    def Identity(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        node.fluid_code.add_layer("assign", inputs=val_x, output=node)
        
    def Tile(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
        repeats = _const_weight_or_none(val_repeats)
        assert repeats is not None, 'for OP:Tile, only const repeats supported'
        
        if isinstance(repeats, int):
            repeats = [repeats]
            
        attr = {
            'expand_times': repeats,
            "name": string(node.layer_name),
        }
        node.fluid_code.add_layer("expand",
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)
        
    def MaxPool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)

        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        assert node.get_attr(
            "dilations") is None, 'only dilations = 0 is supported'  # optional

        kernel_shape = node.get_attr("kernel_shape")
        poolnd = len(kernel_shape)
        strides = node.get_attr("strides")
        pad_mode = node.get_attr("pads")
        ceil_mode = bool(node.get_attr('ceil_mode', 0))  # optional
        pads = node.get_attr('pads', [0] * (poolnd * 2))  # optional
        fluid_op = 'pool{}d'.format(poolnd)
        assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
            input_shape = val_x.out_shapes[0]
            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                     strides[0])
            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
                                     strides[1])
            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}

        attr = {
            "pool_size": kernel_shape,
            "pool_type": string("max"),
            "pool_stride": strides,
            "pool_padding": paddings,
            "ceil_mode": ceil_mode,
            "name": string(node.layer_name),
            "exclusive": False
        }
        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def _global_pool(self, node):
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)
        input_shape = val_x.out_shapes[0]
        output_shape = val_y.out_shapes[0]
        assert input_shape is not None or output_shape is not None, 'poolnd not inferred'  # N
        if input_shape:
            poolnd = len(input_shape) - 2  # NC...
        elif output_shape:
            poolnd = len(output_shape) - 2  # NC...
        assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
        fluid_op = 'pool{}d'.format(poolnd)
        
        pool_type = None
        if node.layer.op_type == 'GlobalMaxPool':
            pool_type = 'max'
        elif node.layer.op_type == 'GlobalAveragePool':
            pool_type = 'avg'

        attr = {
            "pool_type": string(pool_type),
            "global_pooling": True,
            "name": string(node.layer_name)
        }
        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)
        
    def GlobalMaxPool(self, node):
        self._global_pool(node)
        
    def GlobalAveragePool(self, node):
        self._global_pool(node)
        
    def Conv(self, node):
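        # Lower ONNX Conv to fluid conv2d/conv3d. The weight (and optional bias)
        # initializers are skipped as graph nodes (omit_nodes) and referenced by
        # name through param_attr/bias_attr instead.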
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
        val_y = self.graph.get_node(node.layer.output[0], copy=True)

        self.omit_nodes.append(val_w.layer_name)

        has_bias = len(node.layer.input) == 3
        if has_bias:
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
            self.omit_nodes.append(val_b.layer_name)
        auto_pad = node.get_attr('auto_pad', 'NOTSET')

        kernel_shape = node.get_attr('kernel_shape')
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
        num_out_channels = val_w.out_shapes[0][0]  # OI...
        fluid_op = 'conv{}d'.format(convnd)

        num_groups = node.get_attr('group', 1)
        strides = node.get_attr('strides', [1] * convnd)  # optional
        dilations = node.get_attr('dilations', [1] * convnd)  # optional
        pads = node.get_attr('pads', [0] * (convnd * 2))  # optional

        input_shape = val_x.out_shapes[0]
        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
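            # Same auto_pad handling as in MaxPool: derive explicit symmetric
            # paddings from the input spatial size, kernel and stride.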
            pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                     strides[0])
            pad_w = get_same_padding(input_shape[3], kernel_shape[1],
                                     strides[1])
            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}

        attr = {
            "num_filters": num_out_channels,
            "filter_size": kernel_shape,
            "stride": strides,
            "padding": paddings,
            "dilation": dilations,
            "groups": num_groups,
            'param_attr': string(val_w.layer_name),
            "name": string(node.layer_name)
        }
        if has_bias:
            attr["bias_attr"] = string(val_b.layer_name)
        else:
            attr["bias_attr"] = False
        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def ConvTranspose(self, node):
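        # Lower ONNX ConvTranspose to fluid conv2d_transpose/conv3d_transpose;
        # the output spatial size is recomputed from the input shape, strides,
        # paddings, dilations, kernel size and output_padding.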
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
        val_b = None
        if len(node.layer.input)>2:
            val_b = self.graph.get_input_node(node, idx=2, copy=True)
            self.omit_nodes.append(val_b.layer_name)
        self.omit_nodes.append(val_w.layer_name)

        val_y = self.graph.get_node(node.layer.output[0], copy=True)

        auto_pad = node.get_attr('auto_pad', 'NOTSET')
        out_padding = node.get_attr('output_padding', [0, 0])
        kernel_shape = node.get_attr('kernel_shape')
        assert kernel_shape, 'kernel_shape not inferred'
        convnd = len(kernel_shape)
        assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
        num_out_channels = val_w.out_shapes[0][1]
        fluid_op = 'conv{}d_transpose'.format(convnd)

        num_groups = node.get_attr('group', 1)
        strides = node.get_attr('strides', [1] * convnd)
        dilations = node.get_attr('dilations', [1] * convnd)
        output_size = node.get_attr('output_shape', [])
        pads = node.get_attr('pads', [0] * (convnd * 2))

        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

        output_size = [0, 0]

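        # Transposed-convolution output size:
        # out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + 1 + output_padding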
        output_size[0] = (val_x.out_shapes[0][2] -
                          1) * strides[0] - 2 * paddings[0] + dilations[0] * (
                              kernel_shape[0] - 1) + 1 + out_padding[0]
        output_size[1] = (val_x.out_shapes[0][3] -
                          1) * strides[1] - 2 * paddings[1] + dilations[1] * (
                              kernel_shape[1] - 1) + 1 + out_padding[1]
        attr = {
            'num_filters': num_out_channels,
            'output_size': output_size or None,
            'filter_size': kernel_shape,
            'padding': paddings,
            'stride': strides,
            'dilation': dilations,
            'groups': num_groups,
            'param_attr': string(val_w.layer_name),
            'bias_attr': None if val_b is None else string(val_b.layer_name),
            'name': string(node.layer_name),
        }
        node.fluid_code.add_layer(fluid_op,
                                  inputs=val_x,
                                  output=node,
                                  param_attr=attr)

    def GRU(self, node):
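        # Decompose ONNX GRU into fluid ops: squeeze the num_directions axis off
        # X/W/R, precompute the input projection with matmul (plus optional bias),
        # then run dynamic_gru with the transposed recurrent weight R as
        # param_attr. Only batch_size = 1, a single direction, the default
        # activations and linear_before_reset = 0 are supported (see asserts).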
        val_x = self.graph.get_input_node(node, idx=0, copy=True)
        val_w = self.graph.get_input_node(node, idx=1, copy=True)
        val_r = self.graph.get_input_node(node, idx=2, copy=True)
        
        val_b = None
        val_len = None
        val_xh = None
        miss_arg_num = 0
        num_ipt = len(node.layer.input)
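        # Optional GRU inputs follow the ONNX order B (3), sequence_lens (4),
        # initial_h (5); miss_arg_num records how many earlier optional inputs
        # are absent so the index passed to get_input_node can be adjusted.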
        if num_ipt>3 and node.layer.input[3] != '':
            val_b = self.graph.get_input_node(node, idx=3, copy=True)
        else:
            miss_arg_num += 1
        if num_ipt>4 and node.layer.input[4] != '':
            val_len = self.graph.get_input_node(node, idx=4-miss_arg_num, copy=True)
        else:
            miss_arg_num += 1
        if num_ipt>5 and node.layer.input[5] != '':
            val_xh = self.graph.get_input_node(node, idx=5-miss_arg_num, copy=True)
            
        data, dtype, shape = self.get_dynamic_shape(val_x.layer_name)
        
        x_shape = val_x.out_shapes[0]
        
        assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
        assert node.get_attr('clip', None) is None, 'clipping not supported'

        hidden_size = node.get_attr('hidden_size', None)
        if hidden_size is None:
            r_shape = val_r.out_shapes[0]
            if r_shape:
                hidden_size = r_shape[-1]
        if hidden_size is None:
            w_shape = val_w.out_shapes[0]
            if w_shape:
                hidden_size = w_shape[-2] // 3
        if hidden_size is None and val_b:
            b_shape = val_b.out_shapes[0]
            if b_shape:
                hidden_size = b_shape[-1] // 6
        if hidden_size is None and val_xh:
            xh_shape = val_xh.out_shapes[0]
            if xh_shape:
                hidden_size = xh_shape[-1]
            
        direction = node.get_attr('direction', 'forward') 
        assert direction != 'bidirectional', 'direction = bidirectional not supported'
        
        activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
        assert len(activations) == 2, 'bidirectional operation not supported'
        
        assert node.get_attr(
            'linear_before_reset',
            0) == 0, 'only linear_before_reset = 0 supported'
        
        activations = [s.lower() for s in activations]
        gate_activation, candidate_activation = activations
        is_reverse = direction == 'reverse'
        
        var_x0 = node.layer_name + '_x0'
        node.fluid_code.add_layer('squeeze',
                                  inputs=val_x,
                                  output=var_x0,
                                  param_attr={'axes': [1],'name':string(var_x0)})
        
        var_w0 = node.layer_name + '_w0'
        node.fluid_code.add_layer('squeeze',
                                  inputs=val_w,
                                  output=var_w0,
                                  param_attr={'axes': [0],'name':string(var_w0)})
        
        var_fc = node.layer_name + '_fc'
        var_mm = (node.layer_name + '_mm') if val_b else var_fc
        node.fluid_code.add_layer('matmul',
                                  inputs={'x':var_x0, 'y':var_w0},
                                  output=var_mm,
                                  param_attr={'transpose_x': 0,'transpose_y': 1,'name':string(var_mm)})
        
        var_r0 = node.layer_name + '_r0'
        node.fluid_code.add_layer('squeeze',
                                  inputs=val_r,
                                  output=var_r0,
                                  param_attr={'axes': [0],'name':string(var_r0)})
        
        var_r0t = node.layer_name + '_r0t' 
        
        node.fluid_code.add_layer('transpose',
                                  inputs=var_r0,
                                  output=var_r0t,
                                  param_attr={'perm': [1, 0],'name':string(var_r0t)})
        if val_b:
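            # ONNX packs the bias as B = [Wb[zrh], Rb[zrh]]; split it into the
            # input-projection bias (var_bi) and the recurrent bias (var_bh),
            # each of width 3 * hidden_size.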
            var_bi = node.layer_name + '_bi'
            var_bh = node.layer_name + '_bh'
            node.fluid_code.add_layer('split',
                                  inputs=val_b,
                                  output=var_bi+','+var_bh,
                                  param_attr={'axis': 1,
                                              'split': [hidden_size * 3, hidden_size * 3],
                                              'name':string(node.layer_name+'.b/split')})
            var_bi0 = node.layer_name + '_bi0'
            node.fluid_code.add_layer('squeeze',
                                  inputs=var_bi,
                                  output=var_bi0,
                                  param_attr={'axes': [0],'name':string(var_bi0)})
            
            node.fluid_code.add_layer('elementwise_add',
                                  inputs=[var_mm, var_bi0],
                                  output=var_fc,
                                  param_attr={'axis': 1,'name':string(node.layer_name+'.i/bias')})

        if val_xh:
            var_xh0 = node.layer_name + '_xh0'
            node.fluid_code.add_layer('squeeze',
                                  inputs=val_xh,
                                  output=var_xh0,
                                  param_attr={'axes': [1],'name':string(var_xh0)})
        var_y00 = node.layer_name + '_y00'
        
        attr={
            'origin_mode':True,
            'h_0': var_xh0 if val_xh else None,
            'is_reverse':is_reverse,
            'gate_activation':string(gate_activation),
            'candidate_activation':string(candidate_activation),
            'param_attr':string(var_r0t),
            'bias_attr':string(var_bh) if val_b else False,
        }
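        # dynamic_gru consumes the precomputed input projection var_fc; the
        # recurrent weight and hidden bias are passed by name via param_attr /
        # bias_attr, and origin_mode=True is assumed here to match the ONNX GRU
        # gate formulation.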
        node.fluid_code.add_layer('dynamic_gru',
                                  inputs=var_fc +','+ str(hidden_size),
                                  output=var_y00,
                                  param_attr=attr)
        
        num_opt = len(node.layer.output)
        
        if num_opt>0 and node.layer.output[0] != '':
            node.fluid_code.add_layer('unsqueeze',
                                  inputs=var_y00,
                                  output=node.layer.output[0],
                                  param_attr={'axes': [1, 1],'name':string(node.layer.output[0])})
        if num_opt>1 and node.layer.output[1] != '':
            node.fluid_code.add_layer('unsqueeze',
                                  inputs=var_y00,
                                  output=node.layer.output[1],
                                  param_attr={'axes': [1, 1],'name':string(node.layer.output[1])})