# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import numbers
import numpy as np
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
from x2paddle.core.program import PaddleGraph 
from x2paddle.decoder.caffe_decoder import CaffeGraphNode


def _adjust_parameters(node):
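    """Squeeze the singleton dimensions the protobuf backend keeps on each
    blob (biases, and the FC weights of InnerProduct layers) and return the
    adjusted blobs as a list of numpy arrays."""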
    data = node.data
    # When using the protobuf-backend, each parameter initially has four dimensions.
    # In certain cases (like FC layers), we want to eliminate the singleton dimensions.
    # This implementation takes care of the common cases. However, it does leave the
    # potential for future issues.
    # The Caffe-backend does not suffer from this problem.
    data = list(data)

    squeeze_indices = [1]  # Squeeze biases.
    if node.layer_type == 'InnerProduct':
        squeeze_indices.append(0)  # Squeeze FC.

    for idx in squeeze_indices:
        if idx >= len(data):
            continue

        d = data[idx]
        assert len(
            d.shape
        ) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
            str(d.shape))

        sq_axis = None
        if idx == 0:
            sq_axis = (0, 1)
        elif idx == 1:
            sq_axis = (0, 1, 2)
        else:
            continue

        data[idx] = np.squeeze(d, axis=sq_axis)
    return data

def _get_kernel_parameters(kind, params):
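    """Normalize a Caffe convolution/pooling parameter message into
    (num_output, kernel, stride, pad, dilation, group), reconciling the
    scalar, repeated, and *_h/*_w field variants."""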
    assert kind in ["Convolution", "Pooling", "Deconvolution", "ConvolutionDepthwise"]
    [k_h, k_w] = [1, 1]
    if isinstance(params.kernel_size, numbers.Number):
        [k_h, k_w] = [params.kernel_size] * 2
    elif len(params.kernel_size) > 0:
        k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[
            0]
        k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
            len(params.kernel_size) - 1]
    elif params.kernel_h > 0 or params.kernel_w > 0:
        k_h = params.kernel_h
        k_w = params.kernel_w
    [s_h, s_w] = [1, 1]
    if isinstance(params.stride, numbers.Number):
        [s_h, s_w] = [params.stride] * 2
    elif len(params.stride) > 0:
        s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
        s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
            params.stride) - 1]
    elif params.stride_h > 0 or params.stride_w > 0:
        s_h = params.stride_h
        s_w = params.stride_w
    [p_h, p_w] = [0, 0]
    if isinstance(params.pad, numbers.Number):
        [p_h, p_w] = [params.pad] * 2
    elif len(params.pad) > 0:
        p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
        p_w = params.pad_w if params.pad_w > 0 else params.pad[len(
            params.pad) - 1]
    elif params.pad_h > 0 or params.pad_w > 0:
        p_h = params.pad_h
        p_w = params.pad_w
    dila_h = dila_w = 1
    group = 1
    c_o = 1
    if kind in ["Convolution", "Deconvolution", "ConvolutionDepthwise"]:
        if kind in ["Convolution", "Deconvolution"]:
            c_o = params.num_output
        dila_len = len(params.dilation)
        if dila_len == 2:
            dila_h = params.dilation[0]
            dila_w = params.dilation[1]
        elif dila_len == 1:
            dila_h = dila_w = params.dilation[0]
        else:
            assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
                dila_len)
    if kind in ['Convolution', 'Deconvolution']:
        group = params.group
    kernel = [k_h, k_w]
    stride = [s_h, s_w]
    pad = [p_h, p_w]
    dilation = [dila_h, dila_w]
    return c_o, kernel, stride, pad, dilation, group


class CaffeOpMapper(OpMapper):
    directly_map_ops = {
        'Sigmoid': ['paddle.nn.Sigmoid'],
        'TanH': ['paddle.nn.Tanh'],
    }

    def __init__(self, decoder):
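        """Verify that every layer is convertible, then walk the Caffe graph
        in topological order, dispatching each node to its mapper method to
        build the dygraph PaddleGraph."""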
        super(CaffeOpMapper, self).__init__()
        self.graph = decoder.caffe_graph
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
        self.params = dict()
        self.paddle_graph = PaddleGraph(
            parent_layer=None, graph_type="dygraph", source_type="caffe")
        self.paddle_graph.outputs = self.graph.output_nodes
        self.input_index = 0
        self.inputs_info = {}
        self.nn_name2id = {}
        print("Total nodes: {}".format(
            sum([
                isinstance(node, CaffeGraphNode)
                for name, node in self.graph.node_map.items()
            ])))
        print("Nodes converting ...")
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...     ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if hasattr(self, op):
                func = getattr(self, op)
                func(node)
            elif op in self.directly_map_ops:
                self.directly_map(node)
        print("\nNodes converted.")
        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.params)
        self.paddle_graph.set_inputs_info(self.inputs_info)
                
    def op_checker(self):
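        """Collect layer types that have neither a mapper method nor an
        entry in directly_map_ops; return True when the model is fully
        convertible."""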
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if not hasattr(self, op) and op not in self.directly_map_ops:
                unsupported_ops.add(op)
        if len(unsupported_ops) == 0:
            return True
        else:
            print("\n========= {} OPs are not supported yet ===========".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            return False
        
    def directly_map(self, node):
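        """Map a single-input Caffe layer straight onto the Paddle op
        registered in directly_map_ops."""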
        inputs = node.layer.input
        assert len(inputs) == 1, 'directly_map error with multi inputs'
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_input_node(node, 0)
        paddle_op = op_info[0]
        if paddle_op.startswith("paddle.nn"):
            op_name = paddle_op[10:].lower()
            op_name = name_generator(op_name, self.nn_name2id)
            output_name = node.name
            layer_outputs = [op_name, output_name]
            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=layer_outputs)
        else:
            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=[node.name])

    def Input(self, node):
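        """Register a graph input whose shape comes from the prototxt input
        parameter, with the batch dimension replaced by -1."""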
        self.paddle_graph.add_layer(
            "paddle.to_tensor",
            inputs={},
            outputs=[node.layer_name],
            data="x{}".format(self.input_index))
        shape = list(node.layer.input_param.shape[0].dim)[1:]
        self.inputs_info["x{}".format(self.input_index)] = [[-1] + shape, "float32"]
        self.input_index += 1
        
    def MemoryData(self, node):
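        """Treat MemoryData as a network input whose NCHW shape is taken
        from memory_data_param, optionally overridden by crop_size."""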
        params = node.layer.memory_data_param
        transform_params = node.layer.transform_param
        self.paddle_graph.add_layer(
            "paddle.to_tensor",
            inputs={},
            outputs=[node.layer_name],
            data="x{}".format(self.input_index))
        shape = list()
        shape.append(params.batch_size)
        shape.append(params.channels)
        # hasattr() is always True on a protobuf message, so test the value:
        # a crop_size of 0 means no cropping is configured.
        if transform_params.crop_size > 0:
            shape.append(transform_params.crop_size)
            shape.append(transform_params.crop_size)
        else:
            # NCHW ordering: height comes before width.
            shape.append(params.height)
            shape.append(params.width)
        self.inputs_info["x{}".format(self.input_index)] = [shape, "float32"]
        self.input_index += 1

    def Convolution(self, node):
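        """Map Caffe Convolution onto paddle.nn.Conv2D, registering the
        (adjusted) weight and optional bias blobs in self.params."""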
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Convolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": node.in_shapes[0][1],
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def DepthwiseConvolution(self, node):
        node.layer_type = "ConvolutionDepthwise"
        self.ConvolutionDepthwise(node)

    def Deconvolution(self, node):
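        """Map Caffe Deconvolution onto paddle.nn.Conv2DTranspose."""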
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Deconvolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": node.in_shapes[0][1],
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2DTranspose",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def ConvolutionDepthwise(self, node):
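        """Map depthwise convolution onto paddle.nn.Conv2D with the group
        count derived from the input/output channel ratio."""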
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        out_channel = params.num_output if params.num_output is not None \
            else node.in_shapes[0][1]
        in_channel = node.in_shapes[0][1]
        group = int(in_channel / (in_channel / out_channel)) if in_channel > out_channel \
            else int(in_channel / (out_channel / in_channel))
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of ConvolutionDepthwise node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": in_channel,
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

    def Pooling(self, node):
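        """Map Caffe Pooling onto MaxPool2D/AvgPool2D, or onto the adaptive
        variants when global_pooling is set."""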
        pool2d_name = name_generator("pool", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [pool2d_name, output_name]
        params = node.layer.pooling_param
        if hasattr(params, "ceil_mode"):
            ceil_mode = params.ceil_mode
        else:
            # Some Caffe variants expose round_mode (0 == CEIL) instead of
            # ceil_mode.
            ceil_mode = getattr(params, "round_mode", 0) == 0
        global_pool = getattr(params, "global_pooling", False)
        channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        assert len(
            node.inputs) == 1, "The count of Pooling node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        if global_pool:
            if kernel[0] == 0:
                kernel = [1, 1]
            if params.pool == 0:
                self.paddle_graph.add_layer(
                    "paddle.nn.AdaptiveMaxPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    output_size=kernel)
            else:
                self.paddle_graph.add_layer(
                    "paddle.nn.AdaptiveAvgPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    output_size=kernel)
        else:
            layer_attrs = {
                'kernel_size': kernel,
                'stride': stride,
                'padding': pad,
                'ceil_mode': ceil_mode,
            }
            if params.pool == 0:
                self.paddle_graph.add_layer(
                    "paddle.nn.MaxPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    **layer_attrs)
            else:
                self.paddle_graph.add_layer(
                    "paddle.nn.AvgPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    **layer_attrs)

    def LRN(self, node):
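        """Map Caffe LRN onto the fluid lrn op; alpha is divided by
        local_size to match Caffe's within-window averaging."""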
        assert len(node.inputs) == 1, "The count of LRN node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.lrn_param
        assert params.local_size % 2 == 1
        alpha = params.alpha / float(params.local_size)
        layer_attrs = {
            "n": params.local_size,
            "k": params.k,
            "alpha": alpha,
            "beta": params.beta,
        }
        self.paddle_graph.add_layer(
            "paddle.fluid.layers.lrn",
            inputs={"input": input.name},
            outputs=[node.layer_name],
            **layer_attrs)

    def InnerProduct(self, node):
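        """Map Caffe InnerProduct (fully connected) onto paddle.nn.Linear,
        transposing the Caffe weight layout (out, in) into Paddle's
        (in, out)."""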
        linear_name = name_generator("linear", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [linear_name, output_name]
        data = node.data
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.inner_product_param
        if data is None:
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            data = []
            data.append(
                np.zeros([node.in_shapes[0][1],
                          params.num_output]).astype("float32"))
            data.append(np.zeros([params.num_output]).astype("float32"))
        else:
            data = _adjust_parameters(node)
            # Reshape the parameters to Paddle's ordering
            transpose_order = (1, 0)
            w = data[0]
            fc_shape = w.shape
            output_channels = fc_shape[0]
            w = w.reshape((output_channels, -1))
            w = w.transpose(transpose_order)
            data[0] = w

        self.params[linear_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[linear_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of InnerProduct node\'s input is not 1."
        assert params.axis == 1
        assert params.bias_term == True
        layer_attrs = {
            "in_features": data[0].shape[0],
            "out_features": params.num_output
        }
        if len(data) == 1:
            layer_attrs["bias"] = False
        if node.in_shapes[0][-1] != data[0].shape[0]:
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": input.name},
                outputs=[output_name],
                shape=[-1, data[0].shape[0]])
            self.paddle_graph.add_layer(
                "paddle.nn.Linear",
                inputs={"input": output_name},
                outputs=layer_outputs,
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                "paddle.nn.Linear",
                inputs={"input": input.name},
                outputs=layer_outputs,
                **layer_attrs)
        
    def AbsVal(self, node):
        assert len(
            node.inputs
        ) >= 1, "The count of AbsVal node\'s input is less than 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.abs",
            inputs={"x": input.name},
            outputs=[node.layer_name])

    def Softmax(self, node):
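        """Map Caffe Softmax onto paddle.nn.Softmax with a normalized
        (non-negative) axis."""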
        softmax_name = name_generator("softmax", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [softmax_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Softmax node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.softmax_param
        axis = params.axis
        shape = node.in_shapes[0]
        dims = len(shape)
        axis = axis + dims if axis < 0 else axis
        layer_attrs = {'axis': axis}
        self.paddle_graph.add_layer(
            "paddle.nn.Softmax",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

    def Slice(self, node):
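        """Map Caffe Slice onto paddle.split, deriving the section sizes
        from the inferred output shapes."""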
        assert len(
            node.inputs) == 1, "The count of Slice node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.slice_param
        axis = params.axis
        slice_dim = params.slice_dim
        if slice_dim != 1 and axis == 1:
            axis = slice_dim
        output_shape = node.out_shapes
        sections_list = list()
        outputs_list = list()
        for i, s in enumerate(output_shape):
            sections_list.append(s[axis])
            outputs_list.append("{}_p{}".format(node.layer_name, i))
        layer_attrs = {
            'num_or_sections': sections_list,
            'axis': axis,
        }
        self.paddle_graph.add_layer(
            "paddle.split",
            inputs={"x": input.name},
            outputs=outputs_list,
            **layer_attrs)

    def Concat(self, node):
        assert len(
            node.inputs
        ) >= 1, "The count of Concat node\'s input is less than 1."
        inputs_list = list()
        for i in range(len(node.inputs)):
            input = self.graph.get_input_node(node, idx=i, copy=True)
            inputs_list.append(input.name)
        params = node.layer.concat_param
        axis = params.axis
        layer_attrs = {'axis': axis}
        self.paddle_graph.add_layer(
            "paddle.concat",
            inputs={"x": inputs_list},
            outputs=[node.layer_name],
            **layer_attrs)

    def ReLU(self, node):
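        """Map Caffe ReLU onto paddle.nn.ReLU, or paddle.nn.LeakyReLU when a
        non-zero negative_slope is configured."""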
        relu_name = name_generator("relu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [relu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of ReLU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.relu_param
        if params.HasField('negative_slope') and params.negative_slope != 0:
            negative_slope = float(params.negative_slope)
            layer_attrs = {'negative_slope': negative_slope}
            self.paddle_graph.add_layer(
                "paddle.nn.LeakyReLU",
                inputs={"input": input.name},
                outputs=layer_outputs,
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                "paddle.nn.ReLU",
                inputs={"input": input.name},
                outputs=layer_outputs)

    def PReLU(self, node):
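        """Map Caffe PReLU onto paddle.nn.PReLU; channel_shared selects one
        slope parameter in total or one per channel."""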
        prelu_name = name_generator("prelu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [prelu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of PReLU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.prelu_param
        mode_bool = params.channel_shared
        output_shape = node.out_shapes[0]
        if mode_bool:
            num_parameters = 1
        else:
            num_parameters = output_shape[1]
        data = node.data
        assert data is not None, "The parameter of {} (type is {}) is not set. You need to use the python package of caffe to set the default value.".format(
            node.layer_name, node.layer_type)
        self.params[prelu_name + '._weight'] = np.squeeze(data[0])
        self.paddle_graph.add_layer(
            "paddle.nn.PReLU",
            inputs={"input": input.name},
            outputs=layer_outputs,
            num_parameters=num_parameters)

    def Eltwise(self, node):
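        """Map Caffe Eltwise onto paddle.multiply, paddle.add, or
        paddle.maximum according to operation (PROD=0, SUM=1, MAX=2);
        SUM honors an optional pair of coefficients."""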
        assert len(
            node.inputs) == 2, "The count of Eltwise node\'s input is not 2."
        params = node.layer.eltwise_param
        mode = params.operation
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        input0_name = input0.name
        input1_name = input1.name
        if mode == 0:
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs=inputs_dict,
                outputs=[node.layer_name])
        elif mode == 1:
            if hasattr(params, 'coeff') and len(params.coeff) == 2:
                coeff = params.coeff
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input0_name},
                    outputs=[node.layer_name + '_mul0'],
                    scale=coeff[0])
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input1_name},
                    outputs=[node.layer_name + '_mul1'],
                    scale=coeff[1])
                inputs_dict = {}
                inputs_dict['x'] = node.layer_name + '_mul0'
                inputs_dict['y'] = node.layer_name + '_mul1'
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=inputs_dict,
                    outputs=[node.layer_name])
            else:
                inputs_dict = {}
                inputs_dict['x'] = input0_name
                inputs_dict['y'] = input1_name
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=inputs_dict,
                    outputs=[node.layer_name])
        else:
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            # paddle.max is a reduction; paddle.maximum is the elementwise op.
            self.paddle_graph.add_layer(
                "paddle.maximum",
                inputs=inputs_dict,
                outputs=[node.layer_name])

    def BatchNorm(self, node):
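        """Map Caffe BatchNorm onto paddle.nn.BatchNorm2D, rescaling the
        stored statistics by the Caffe scale factor; 2-D inputs are
        unsqueezed to NCHW and squeezed back afterwards."""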
        batchnorm_name = name_generator("batchnorm", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [batchnorm_name, output_name]
        assert len(
            node.inputs) == 1, "The count of BatchNorm node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.batch_norm_param
        if hasattr(params, "eps"):
            eps = params.eps
        else:
            eps = 1e-5
        if node.data is None or len(node.data) != 3:
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            mean = np.zeros([node.in_shapes[0][1], ]).astype("float32")
            variance = np.zeros([node.in_shapes[0][1], ]).astype("float32")
            scale = 0
        else:
            node.data = [np.squeeze(i).astype("float32") for i in node.data]
            mean, variance, scale = node.data
        # Prescale the stats
        scaling_factor = 1.0 / scale if scale != 0 else 0
        mean *= scaling_factor
        variance *= scaling_factor
        self.params[batchnorm_name + "._mean"] = mean
        self.params[batchnorm_name + '._variance'] = variance
        layer_attrs = {
            "num_features": node.in_shapes[0][1],
            "epsilon": eps,
            "weight_attr": False,
            "bias_attr": False,
        }
        if len(node.in_shapes[0]) == 2:
            self.paddle_graph.add_layer(
                "paddle.unsqueeze",
                inputs={"x": input.name},
                outputs=[input.name],
                axis=[2, 3])
        self.paddle_graph.add_layer(
            "paddle.nn.BatchNorm2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        if len(node.in_shapes[0]) == 2:
            self.paddle_graph.add_layer(
                "paddle.squeeze",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                axis=[2, 3])
   
    def Scale(self, node):
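        """Map Caffe Scale onto an elementwise multiply followed by an add.
        With two inputs the second input supplies the scale; otherwise the
        scale and shift are parameters created on the fly."""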
        if node.data is None:
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            self.params[node.layer_name + "_cparam1"] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
            self.params[node.layer_name + "_cparam2"] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
        else:
            self.params[node.layer_name + "_cparam1"] = np.squeeze(node.data[
                0]).astype("float32")
            if not node.layer.scale_param.bias_term:
                self.params[node.layer_name + "_cparam2"] = np.zeros([
                    node.in_shapes[0][1],
                ]).astype("float32")
            else:
                self.params[node.layer_name + "_cparam2"] = np.squeeze(node.data[
                    1]).astype("float32")
        params = node.layer.scale_param
        axis = params.axis
        if len(node.inputs) == 2:
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            input1 = self.graph.get_input_node(node, idx=1, copy=True)
            input0_name = input0.name
            input1_name = input1.name
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs=inputs_dict,
                outputs=[node.layer_name + "_mul"],
                axis=1)
        else:
            self.paddle_graph.add_layer(
                "self.create_parameter",
                inputs={},
                outputs=[node.layer_name + "_cparam1"],
                shape=self.params[node.layer_name + "_cparam1"].shape,
                attr=string(node.layer_name + "_cparam1"))
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            input0_name = input0.name
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = node.layer_name + "_cparam1"
            if len(node.in_shapes[0]) == 2:
                self.paddle_graph.add_layer(
                    "paddle.multiply",
                    inputs=inputs_dict,
                    outputs=[node.layer_name + "_mul"])
            else:
                self.paddle_graph.add_layer(
                    "paddle.multiply",
                    inputs=inputs_dict,
                    outputs=[node.layer_name + "_mul"],
                    axis=axis)
        self.paddle_graph.add_layer(
            "self.create_parameter",
            inputs={},
            outputs=[node.layer_name + "_cparam2"],
            shape=self.params[node.layer_name + "_cparam2"].shape,
            attr=string(node.layer_name + "_cparam2"))
        inputs_dict = {}
        inputs_dict['x'] = node.layer_name + "_mul"
        inputs_dict['y'] = node.layer_name + "_cparam2"
        output_shape = node.out_shapes[0]
        if axis == -1:
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
        else:
            if axis < 0:
                axis = axis + len(output_shape)
            param2_shape = self.params[node.layer_name + "_cparam2"].shape
            param2_shape_len = len(param2_shape)
            diff_len = len(output_shape) - axis - param2_shape_len
            new_shape = list(param2_shape) + [1] * diff_len
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": node.layer_name + "_cparam2"},
                outputs=[node.layer_name + "_cparam2"],
                shape=new_shape)
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
    def Reshape(self, node):
        input = self.graph.get_input_node(node, idx=0, copy=True)
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            "paddle.reshape",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=output_shape)


    def ArgMax(self, node):
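        """Map Caffe ArgMax onto paddle.topk; when out_max_val is set, the
        values and casted indices are concatenated along the axis."""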
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, "The count of ArgMax node\'s input and output is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = node.in_shapes[0]
        params = node.layer.argmax_param
        out_max_val = params.out_max_val if hasattr(params,
                                                    "out_max_val") else False
        top_k = params.top_k if hasattr(params, "top_k") else 1
        axis = params.axis if hasattr(params, "axis") else -1
        if axis < 0:
            axis += len(input_shape)
        if out_max_val is True:
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=[
                    node.layer_name + "_topk_var",
                    node.layer_name + "_index_var"
                ],
                k=top_k)
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": node.layer_name + "_index_var"},
                outputs=[node.layer_name + "_index_var"],
                dtype="{}_topk_var.dtype".format(node.layer_name))
            self.paddle_graph.add_layer(
                "paddle.concat",
                inputs={
                    "x": [
                        node.layer_name + "_topk_var",
                        node.layer_name + "_index_var"
                    ]
                },
                outputs=[node.layer_name],
                axis=axis)
        else:
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=["_", node.layer_name],
                k=top_k)
            
    def Axpy(self, node):
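        """Map the Axpy custom layer, out = a * x + y, whose three inputs
        are (a, x, y)."""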
        assert len(node.inputs) == 3 and len(
            node.outputs
        ) == 1, "The count of Axpy node\'s input is not 3 or output is not 1."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        input2 = self.graph.get_input_node(node, idx=2, copy=True)
        input0_name = input0.name
        input1_name = input1.name
        input2_name = input2.name
        inputs_dict = {}
        inputs_dict['x'] = input1_name
        inputs_dict['y'] = input0_name
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs=inputs_dict,
            outputs=[node.layer_name + "_mul"],
            axis=0)
        inputs_dict = {}
        inputs_dict['x'] = node.layer_name + "_mul"
        inputs_dict['y'] = input2_name
        self.paddle_graph.add_layer(
            "paddle.add",
            inputs=inputs_dict,
            outputs=[node.layer_name])
        

    def Crop(self, node):
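        """Map Caffe Crop onto paddle.crop, using the second input's shape
        and the configured per-axis offsets."""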
        assert len(
            node.inputs) == 2, "The count of Crop node\'s input is not 2."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        example = self.graph.get_input_node(node, idx=1, copy=True)
        params = node.layer.crop_param
        axis = params.axis
        input_shape = node.in_shapes[0]
        if axis < 0:
            axis += len(input_shape)
        offset_real = [0] * len(input_shape)
        if hasattr(params, "offset") and len(params.offset) > 0:
            offset = list(params.offset)
            assert (len(input_shape) - axis
                    ) == len(offset), "invalid offset[%s] in crop layer" % (
                        str(offset))
            offset_real = [0] * axis + offset
        self.paddle_graph.add_layer(
            "paddle.crop",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=node.in_shapes[1],
            offsets=list(offset_real))

    def Flatten(self, node):
        assert len(
            node.inputs) == 1, "The count of Flatten node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.reshape",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=node.out_shapes[0])

    def Power(self, node):
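        """Map Caffe Power, (scale * x + shift) ** power, onto paddle.scale
        followed by paddle.pow."""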
        assert len(
            node.inputs) == 1, "The count of Power node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.power_param
        layer_attrs = {
            'scale': params.scale,
            'bias': params.shift,
            'bias_after_scale': True
        }
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.pow",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            exponent=params.power)

    def Reduction(self, node):
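        """Map Caffe Reduction (SUM=1, ASUM=2, SUMSQ=3, MEAN=4) onto the
        matching Paddle reduction over the trailing axes, then apply coeff
        via paddle.scale."""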
        assert len(
            node.inputs) == 1, "The count of Reduction node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.reduction_param
        operation = params.operation
        axis = params.axis
        coeff = params.coeff
        assert operation >= 1 and operation <= 4, "invalid reduction operation [%s]" % (
            operation)
        input_len = len(node.in_shapes[0])
        if axis < 0:
            axis += input_len + 1
        dim = list(range(input_len))
        # operation = SUM
        if operation == 1:
            # paddle.sum takes x/axis/keepdim keyword names.
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = ASUM
        elif operation == 2:
            self.paddle_graph.add_layer(
                "paddle.abs",
                inputs={"x": input.name},
                outputs=[node.layer_name])
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = SUMSQ
        elif operation == 3:
            self.paddle_graph.add_layer(
                "paddle.pow",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                exponent=2.0)
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = MEAN
        else:
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.mean",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            scale=coeff)
        
    def DetectionOutput(self, node):
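        """Map Caffe DetectionOutput onto the custom DetectionOutput layer,
        tracing the confidence input back to its Softmax/Sigmoid
        producer."""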
        detection_output_name = name_generator("detection_output",
                                               self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [detection_output_name, output_name]
        assert len(
            node.inputs) == 3, "The count of DetectionOutput node\'s input is not 3."
        inputs_dict = dict()
        for i in range(len(node.inputs)):
            input = self.graph.get_input_node(node, idx=i, copy=True)
            if i == 1:
                input = self.graph.get_input_node(node, idx=i, copy=True)
                while input is not None \
                      and input.layer_type != 'Softmax' \
                      and input.layer_type != 'Sigmoid':
                    input = self.graph.get_input_node(input, idx=0, copy=True)
                assert input is not None, 'This kind of DetectionOutput is not supported!'
                input = self.graph.get_input_node(input, idx=0, copy=True)
            inputs_dict["x{}".format(i)] = input.name
        params = node.layer.detection_output_param
        nms_param = params.nms_param
        if nms_param is None:
            nms_param_dict = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
        else:
            nms_param_dict = dict()
            nms_param_dict["nms_threshold"] = nms_param.nms_threshold
            nms_param_dict["top_k"] = nms_param.top_k
            nms_param_dict["eta"] = nms_param.eta
        layer_attrs = {
            "background_label": params.background_label_id,
            "nms_threshold": nms_param_dict["nms_threshold"],
            "nms_top_k": nms_param_dict["top_k"],
            "keep_top_k": params.keep_top_k,
            "score_threshold": params.confidence_threshold,
            "nms_eta": nms_param_dict["eta"]
        }
        self.paddle_graph.add_layer(
            kernel="custom_layer:DetectionOutput",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
                    
    def Normalize(self, node):
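        """Map Caffe Normalize onto the custom Normalize layer, registering
        the learned scale blob as a parameter."""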
        normalize_name = name_generator("normalize", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [normalize_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Normalize node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.norm_param
        param_name = node.layer_name + "_scale"
        if node.data is None or len(node.data) != 1:
            print(
                "The parameters of {} (type: {}) are not set, so they are initialized to zero."
                .format(node.layer_name, node.layer_type))
            self.params[param_name] = \
                np.zeros([1] if params.channel_shared else [node.in_shapes[0][1]]).astype("float32")
        else:
            self.params[param_name] = _adjust_parameters(node)[0]

        self.paddle_graph.add_layer(
            "self.create_parameter",
            inputs={},
            outputs=[param_name],
            shape=self.params[param_name].shape,
            attr=string(param_name))
        layer_attrs = {"axis": -1 if params.channel_shared else 1}
        self.paddle_graph.add_layer(
            "custom_layer:Normalize",
            inputs={"x": input.name,
                    "param": param_name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def Permute(self, node):
        assert len(
            node.inputs) == 1, "The count of Permute node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.permute_param
        order = list(params.order)
        self.paddle_graph.add_layer(
            "paddle.transpose",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            perm=order)
        
    def PriorBox(self, node):
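        """Map Caffe PriorBox onto the custom PriorBox layer; step is
        normalized to a (step_h, step_w) tuple."""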
        priorbox_name = name_generator("priorbox", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [priorbox_name, output_name]
        assert len(
            node.inputs) == 2, "The count of PriorBox node\'s input is not 2."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {}
        inputs_dict["x0"] = input0.name
        inputs_dict["x1"] = input1.name
        params = node.layer.prior_box_param
        steps = tuple(params.step) if isinstance(params.step, (list, tuple)) \
            else (params.step, params.step)
        layer_attrs = {
            "min_sizes": params.min_size,
            "max_sizes": params.max_size,
            "aspect_ratios": params.aspect_ratio,
            "variance": params.variance,
            "flip": params.flip,
            "clip": params.clip,
            "steps": steps,
            "offset": params.offset,
            "min_max_aspect_ratios_order": True
        }
        self.paddle_graph.add_layer(
            "custom_layer:PriorBox",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
        
    def ReLU6(self, node):
        relu6_name = name_generator("relu6", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [relu6_name, output_name]
        assert len(
            node.inputs) == 1, "The count of ReLU6 node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.nn.ReLU6",
            inputs={"input": input.name},
            outputs=layer_outputs)
        
    def ROIPooling(self, node):
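        """Map Caffe ROIPooling onto the custom ROIPooling layer."""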
        roipooling_name = name_generator("roipooling", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [roipooling_name, output_name]
        assert len(
            node.inputs) == 2, "The count of ROIPooling node\'s input is not 2."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {}
        inputs_dict["x0"] = input0.name
        inputs_dict["x1"] = input1.name
        params = node.layer.roi_pooling_param
        layer_attrs = {
            "pooled_height": params.pooled_h,
            "pooled_width": params.pooled_w,
            "spatial_scale": params.spatial_scale
        }
        self.paddle_graph.add_layer(
            "custom_layer:ROIPooling",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
        
    def ShuffleChannel(self, node):
        assert len(
            node.inputs) == 1, "The count of ShuffleChannel node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.shuffle_channel_param
        self.paddle_graph.add_layer(
            "paddle.fluid.layers.shuffle_channel",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            group=params.group)
        
    def Upsample(self, node):
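        """Map Caffe Upsample onto nearest-neighbor interpolation."""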
        assert len(
            node.inputs) == 1, "The count of Upsample node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.upsample_param
        layer_attrs = {
            "align_corners": False,
            "scale_factor": params.scale,
            "mode": "nearest"
        }
        # paddle.nn.functional.interpolate takes its input tensor as `x`.
        self.paddle_graph.add_layer(
            "paddle.nn.functional.interpolate",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
    
    def Select(self, node):
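        """Map the Select custom layer, which slices the input along an axis
        between the configured slice points."""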
        select_name = name_generator("select", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [select_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Select node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = node.in_shapes[0]
        params = node.layer.select_param
        layer_attrs = {
            "input_shape": input_shape,
            "point": params.slice_point,
            "axis": params.axis
        }
        self.paddle_graph.add_layer(
            "custom_layer:Select",
            inputs={"x": input.name},
            outputs=layer_outputs,
            **layer_attrs)