# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
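
"""Caffe-to-Paddle operator mapper.

Each mapper method on CaffeOpMapper below (Convolution, Pooling, ReLU, ...)
takes a CaffeGraphNode and appends the equivalent Paddle layer(s) to a
PaddleGraph.
"""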

import sys
import numbers
import numpy as np
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
from x2paddle.core.program import PaddleGraph 
from x2paddle.decoder.caffe_decoder import CaffeGraphNode


def _adjust_parameters(node):
    data = node.data
    # When using the protobuf-backend, each parameter initially has four dimensions.
    # In certain cases (like FC layers), we want to eliminate the singleton dimensions.
    # This implementation takes care of the common cases. However, it does leave the
    # potential for future issues.
    # The Caffe-backend does not suffer from this problem.
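    # For example (illustrative shapes, not from a real model): an InnerProduct
    # weight blob of shape (1, 1, 4096, 25088) is squeezed over axes (0, 1) to
    # (4096, 25088), and its bias blob of shape (1, 1, 1, 4096) is squeezed
    # over axes (0, 1, 2) to (4096,).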
    data = list(data)

    squeeze_indices = [1]  # Squeeze biases.
    if node.layer_type == 'InnerProduct':
        squeeze_indices.append(0)  # Squeeze FC.

    for idx in squeeze_indices:
        if idx >= len(data):
            continue

        d = data[idx]
        assert len(
            d.shape
        ) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
            str(d.shape))

        shape_old = d.shape
        sq_axis = None
        if idx == 0:
            sq_axis = (0, 1)
        elif idx == 1:
            sq_axis = (0, 1, 2)
        else:
            continue

        data[idx] = np.squeeze(d, axis=sq_axis)
        shape_new = data[idx].shape
    return data

def _get_kernel_parameters(kind, params):
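    """Return (c_o, kernel, stride, pad, dilation, group) for a Caffe layer.

    Normalizes the scalar / repeated / *_h / *_w spellings of each field into
    plain Python values. A sketch of the common case: kernel_size=3, stride=2,
    pad=1, num_output=64 on a "Convolution" yields
    (64, [3, 3], [2, 2], [1, 1], [1, 1], 1).
    """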
    assert kind in ["Convolution", "Pooling", "Deconvolution", "ConvolutionDepthwise"]
    [k_h, k_w] = [1, 1]
    if isinstance(params.kernel_size, numbers.Number):
        [k_h, k_w] = [params.kernel_size] * 2
    elif len(params.kernel_size) > 0:
        k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[
            0]
        k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
            len(params.kernel_size) - 1]
    elif params.kernel_h > 0 or params.kernel_w > 0:
        k_h = params.kernel_h
        k_w = params.kernel_w
    [s_h, s_w] = [1, 1]
    if isinstance(params.stride, numbers.Number):
        [s_h, s_w] = [params.stride] * 2
    elif len(params.stride) > 0:
        s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
        s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
            params.stride) - 1]
    elif params.stride_h > 0 or params.stride_w > 0:
        s_h = params.stride_h
        s_w = params.stride_w
    [p_h, p_w] = [0, 0]
    if isinstance(params.pad, numbers.Number):
        [p_h, p_w] = [params.pad] * 2
    elif len(params.pad) > 0:
        p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
        p_w = params.pad_w if params.pad_w > 0 else params.pad[len(
            params.pad) - 1]
    elif params.pad_h > 0 or params.pad_w > 0:
        p_h = params.pad_h
        p_w = params.pad_w
    dila_h = dila_w = 1
    group = 1
    c_o = 1
    if kind in ["Convolution", "Deconvolution", "ConvolutionDepthwise"]:
        if kind in ["Convolution", "Deconvolution"]:
            c_o = params.num_output
        dila_len = len(params.dilation)
        if dila_len == 2:
            dila_h = params.dilation[0]
            dila_w = params.dilation[1]
        elif dila_len == 1:
            dila_h = dila_w = params.dilation[0]
        else:
            assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
                dila_len)
    if kind in ['Convolution', 'Deconvolution']:
        group = params.group
    kernel = [k_h, k_w]
    stride = [s_h, s_w]
    pad = [p_h, p_w]
    dilation = [dila_h, dila_w]
    return c_o, kernel, stride, pad, dilation, group


class CaffeOpMapper(OpMapper):
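    # Caffe layer types that translate 1:1 into a single Paddle op; consumed
    # by directly_map() below.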
    directly_map_ops = {
        'Sigmoid': ['paddle.nn.layer.Sigmoid'],
        'TanH': ['paddle.nn.Tanh'],
    }

    def __init__(self, decoder):
        super(CaffeOpMapper, self).__init__()
        self.graph = decoder.caffe_graph
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
        self.params = dict()
        self.paddle_graph = PaddleGraph(parent_layer=None, graph_type="dygraph", source_type="caffe")
        self.paddle_graph.outputs = self.graph.output_nodes
        self.input_index = 0 
        self.inputs_info = {}
        self.nn_name2id = {}
        print("Total nodes: {}".format(
            sum([
                isinstance(node, CaffeGraphNode)
                for name, node in self.graph.node_map.items()
            ])))
        print("Nodes converting ...")
        for i, node_name in enumerate(self.graph.topo_sort):
            sys.stderr.write("\rConverting node {} ...     ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if hasattr(self, op):
                func = getattr(self, op)
                func(node)
            elif op in self.directly_map_ops:
                self.directly_map(node)
        print("\nNodes converted.")
        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.params)
        self.paddle_graph.set_inputs_info(self.inputs_info)
                
    def op_checker(self):
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if not hasattr(self, op) and op not in self.directly_map_ops:
                unsupported_ops.add(op)
        if len(unsupported_ops) == 0:
            return True
        else:
            print("\n========= {} OPs are not supported yet ===========".format(
                len(unsupported_ops)))
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            return False
        
    def directly_map(self, node):
        inputs = node.layer.input
        assert len(inputs) == 1, 'directly_map error with multi inputs'
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_input_node(node, 0)
        paddle_op = op_info[0]
        if paddle_op.startswith("paddle.nn"):
            # Strip the "paddle.nn." prefix (10 characters) to get the op name.
            op_name = paddle_op[10:].lower()
            op_name = name_generator(op_name, self.nn_name2id)
            output_name = node.name
            layer_outputs = [op_name, output_name]
            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=layer_outputs)
        else:
            self.paddle_graph.add_layer(
                kernel=paddle_op,
                inputs={"x": input.name},
                outputs=[node.name])

    def Input(self, node):
        self.paddle_graph.add_layer(
            "paddle.to_tensor",
            inputs={},
            outputs=[node.layer_name],
            data="x{}".format(self.input_index))
        shape = list(node.layer.input_param.shape[0].dim)[1:]
        self.inputs_info["x{}".format(self.input_index)] = [[-1] + shape, "float32"]
        self.input_index += 1
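        # A sketch of the code this emits for the first network input
        # (assuming PaddleGraph's code generator):
        #     x0 = paddle.to_tensor(data=x0)
        # with inputs_info["x0"] recording its shape [-1, C, H, W] and dtype.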
        
    def MemoryData(self, node):
        params = node.layer.memory_data_param
        transform_params = node.layer.transform_param
        self.paddle_graph.add_layer(
            "paddle.to_tensor",
            inputs={},
            outputs=[node.layer_name],
            data="x{}".format(self.input_index))
        shape = list()
        shape.append(params.batch_size)
        shape.append(params.channels)
        # A protobuf message always exposes its fields, so test the value
        # rather than hasattr(); crop_size defaults to 0 when unset.
        if transform_params.crop_size > 0:
            shape.append(transform_params.crop_size)
            shape.append(transform_params.crop_size)
        else:
            shape.append(params.height)
            shape.append(params.width)
        self.inputs_info["x{}".format(self.input_index)] = [shape, "float32"]
        self.input_index += 1

    def Convolution(self, node):
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0], kernel[1]]).astype(
                    'float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Convolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": node.in_shapes[0][1],
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def DepthwiseConvolution(self, node):
        node.layer_type = "ConvolutionDepthwise"
        self.ConvolutionDepthwise(node)

    def Deconvolution(self, node):
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0], kernel[1]]).astype(
                    'float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Deconvolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": node.in_shapes[0][1],
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2DTranspose",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def ConvolutionDepthwise(self, node):
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        out_channel = params.num_output if params.num_output is not None else node.in_shapes[0][1]
        in_channel = node.in_shapes[0][1]
        group = int(in_channel / (in_channel / out_channel)) if in_channel > out_channel \
            else int(in_channel / (out_channel / in_channel))
        if data is None:
            data = []
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            data.append(
                np.zeros([out_channel, node.in_shapes[0][1], kernel[0], kernel[1]]).astype(
                    'float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of ConvolutionDepthwise node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": in_channel,
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

    def Pooling(self, node):
        pool2d_name = name_generator("pool", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [pool2d_name, output_name]
        params = node.layer.pooling_param
        # Older Caffe protos expose ceil_mode directly; newer ones use
        # round_mode, where 0 means CEIL and 1 means FLOOR.
        ceil_mode = getattr(params, "ceil_mode", True)
        if not hasattr(params, 'ceil_mode'):
            ceil_mode = getattr(params, "round_mode", 0) == 0
        global_pool = getattr(params, "global_pooling", False)
        kernel_default = [1, 1]
        channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        if params.pool == 0:
            pool_type = "max"
        else:
            pool_type = "avg"
        assert len(
            node.inputs) == 1, "The count of Pooling node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        if global_pool:
            if kernel[0] == 0:
                kernel = [1, 1]
            if params.pool == 0:
                self.paddle_graph.add_layer(
                    "paddle.nn.AdaptiveMaxPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    output_size=kernel)
            else:
                self.paddle_graph.add_layer(
                    "paddle.nn.AdaptiveAvgPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    output_size=kernel)
        else:
            layer_attrs = {
                'kernel_size': kernel,
                'stride': stride,
                'padding': pad,
                'ceil_mode': ceil_mode,
            }
            if params.pool == 0:
                self.paddle_graph.add_layer(
                    "paddle.nn.MaxPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    **layer_attrs)
            else:
                self.paddle_graph.add_layer(
                    "paddle.nn.AvgPool2D",
                    inputs={"input": input.name},
                    outputs=layer_outputs,
                    **layer_attrs)
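        # Note: global pooling maps to AdaptiveMaxPool2D / AdaptiveAvgPool2D
        # (output_size falls back to [1, 1] when no kernel is given); all other
        # cases become a plain MaxPool2D / AvgPool2D with the translated
        # kernel, stride, padding and ceil_mode.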

    def LRN(self, node):
        lrn_name = name_generator("lrn", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [lrn_name, output_name]
        assert len(node.inputs) == 1, "The count of LRN node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.lrn_param
        assert params.local_size % 2 == 1
        alpha = params.alpha / float(params.local_size)
        layer_attrs = {
            "n": params.local_size,
            "k": params.k,
            "alpha": alpha,
            "beta": params.beta,
        }
        self.paddle_graph.add_layer(
            "paddle.fluid.layers.lrn", 
            inputs={"input": input.name},
            outputs=[node.layer_name],
            **layer_attrs)

    def InnerProduct(self, node):
        linear_name = name_generator("linear", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [linear_name, output_name]
        data = node.data
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.inner_product_param
        if data is None:
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            data = []
            data.append(
                np.zeros([node.in_shapes[0][1], params.num_output]).astype(
                    "float32"))
            data.append(
                np.zeros([params.num_output]).astype("float32"))
        else:
            data = _adjust_parameters(node)
            # Reshape the parameters to Paddle's ordering
            transpose_order = (1, 0)
            w = data[0]
            fc_shape = w.shape
            output_channels = fc_shape[0]
            w = w.reshape((output_channels, -1))
            w = w.transpose(transpose_order)
            data[0] = w

        self.params[linear_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[linear_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of InnerProduct node\'s input is not 1."
        assert params.axis == 1
        assert params.bias_term == True
        layer_attrs = {
            "in_features": data[0].shape[0],
            "out_features": params.num_output           
        }
        if len(data) == 1:
            layer_attrs["bias"] = False
        if node.in_shapes[0][-1] != data[0].shape[0]:
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": input.name},
                outputs=[output_name],
                shape=[-1, data[0].shape[0]])
            self.paddle_graph.add_layer(
                "paddle.nn.Linear",
                inputs={"input": output_name},
                outputs=layer_outputs,
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                "paddle.nn.Linear",
                inputs={"input": input.name},
                outputs=layer_outputs,
                **layer_attrs)
        
    def AbsVal(self, node):
        assert len(
            node.inputs
        ) >= 1, "The count of AbsVal node\'s input is less than 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.abs",
            inputs={"x": input.name},
            outputs=[node.layer_name])

    def Softmax(self, node):
        softmax_name = name_generator("softmax", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [softmax_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Softmax node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.softmax_param
        axis = params.axis
        shape = node.in_shapes[0]
        dims = len(shape)
        axis = axis + dims if axis < 0 else axis
        layer_attrs = {'axis': axis}
        self.paddle_graph.add_layer(
            "paddle.nn.Softmax",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

    def Slice(self, node):
        assert len(
            node.inputs) == 1, "The count of Slice node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        top_len = len(node.layer.top)
        params = node.layer.slice_param
        axis = params.axis
        slice_dim = params.slice_dim
        if slice_dim != 1 and axis == 1:
            axis = slice_dim
        output_shape = node.out_shapes
        sections_list = list()
        outputs_list = list()
        for i, s in enumerate(output_shape):
            sections_list.append(s[axis])
            outputs_list.append("{}_p{}".format(node.layer_name, i))
        layer_attrs = {
            'num_or_sections': sections_list,
            'axis': axis,
        }
        self.paddle_graph.add_layer(
            "paddle.split",
            inputs={"x": input.name},
            outputs=outputs_list,
            **layer_attrs)

    def Concat(self, node):
        assert len(
            node.inputs
        ) >= 1, "The count of Concat node\'s input is less than 1."
        inputs_list = list()
        for i in range(len(node.inputs)):
            input = self.graph.get_input_node(node, idx=i, copy=True)
            inputs_list.append(input.name)
        params = node.layer.concat_param
        axis = params.axis
        layer_attrs = {'axis': axis}
        self.paddle_graph.add_layer(
            "paddle.concat",
            inputs={"x": inputs_list},
            outputs=[node.layer_name],
            **layer_attrs)

    def ReLU(self, node):
        relu_name = name_generator("relu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [relu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of ReLU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.relu_param
        if params.HasField('negative_slope') and params.negative_slope != 0:
            negative_slope = float(params.negative_slope)
            # paddle.nn.LeakyReLU takes the slope as `negative_slope`.
            layer_attrs = {'negative_slope': negative_slope}
            self.paddle_graph.add_layer(
                "paddle.nn.LeakyReLU",
                inputs={"input": input.name},
                outputs=layer_outputs,
                **layer_attrs)
        else:
            self.paddle_graph.add_layer(
                "paddle.nn.ReLU",
                inputs={"input": input.name},
                outputs=layer_outputs)

    def PReLU(self, node):
        prelu_name = name_generator("prelu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [prelu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of PReLU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.prelu_param
        mode_bool = params.channel_shared
        output_shape = node.out_shapes[0]
        if mode_bool:
            num_parameters = 1
        else:
            num_parameters = output_shape[1]
        data = node.data
        assert data is not None, "The parameter of {} (type: {}) is not set. Use the caffe python package to set its default value.".format(
            node.layer_name, node.layer_type)
        self.params[prelu_name + '._weight'] = np.squeeze(data[0])
        self.paddle_graph.add_layer(
            "paddle.nn.PReLU",
            inputs={"input": input.name},
            outputs=layer_outputs,
            num_parameters=num_parameters)

    def Eltwise(self, node):
        assert len(
            node.inputs) == 2, "The count of Eltwise node\'s input is not 2."
        params = node.layer.eltwise_param
        # Caffe EltwiseOp: 0 = PROD, 1 = SUM, 2 = MAX.
        mode = params.operation
        inputs = []
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        input0_name = input0.name
        input1_name = input1.name
        if mode == 0:
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs=inputs_dict,
                outputs=[node.layer_name])
        elif mode == 1:
            if hasattr(params, 'coeff') and len(params.coeff) == 2:
                coeff = params.coeff
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input0_name},
                    outputs=[node.layer_name + '_mul0'],
                    scale=coeff[0])
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input1_name},
                    outputs=[node.layer_name + '_mul1'],
                    scale=coeff[1])
                inputs_dict = {}
                inputs_dict['x'] = node.layer_name + '_mul0'
                inputs_dict['y'] = node.layer_name + '_mul1'
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=inputs_dict,
                    outputs=[node.layer_name])
            else:
                inputs_dict = {}
                inputs_dict['x'] = input0_name
                inputs_dict['y'] = input1_name
                self.paddle_graph.add_layer(
                    "paddle.add",
                    inputs=inputs_dict,
                    outputs=[node.layer_name])
        else:
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            self.paddle_graph.add_layer(
                "paddle.maximum",
                inputs=inputs_dict,
                outputs=[node.layer_name])

    def BatchNorm(self, node):
        batchnorm_name = name_generator("batchnorm", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [batchnorm_name, output_name]
        assert len(
            node.inputs) == 1, "The count of BatchNorm node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.batch_norm_param
        if hasattr(params, "eps"):
            eps = params.eps
        else:
            eps = 1e-5
        if node.data is None or len(node.data) != 3:
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            mean = np.zeros([node.in_shapes[0][1], ]).astype("float32")
            variance = np.zeros([node.in_shapes[0][1], ]).astype("float32")
            scale = 0
        else:

            node.data = [np.squeeze(i).astype("float32") for i in node.data]
            mean, variance, scale = node.data
        # Prescale the stats
        scaling_factor = 1.0 / scale if scale != 0 else 0
        mean *= scaling_factor
        variance *= scaling_factor
        self.params[batchnorm_name + "._mean"] = mean
        self.params[batchnorm_name + '._variance'] = variance
        layer_attrs = {
            "num_features": node.in_shapes[0][1],
            "epsilon": eps,
            "weight_attr": False,
            "bias_attr": False,
        }
        self.paddle_graph.add_layer(
            "paddle.nn.BatchNorm2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
   
    def Scale(self, node):
        if node.data is None:
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            self.params[node.layer_name + "_cparam1"] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
            self.params[node.layer_name + "_cparam2"] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
        else:
            self.params[node.layer_name + "_cparam1"] = np.squeeze(node.data[
                0]).astype("float32")
            self.params[node.layer_name + "_cparam2"] = np.squeeze(node.data[
                1]).astype("float32")
        params = node.layer.scale_param
        axis = params.axis
        inputs = []
        if len(node.inputs) == 2:
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            input1 = self.graph.get_input_node(node, idx=1, copy=True)
            input0_name = input0.name
            input1_name = input1.name
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = input1_name
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs=inputs_dict,
                outputs=[node.layer_name + "_mul"],
                axis=1)
        else:
            self.paddle_graph.add_layer(
                "self.create_parameter",
                inputs={},
                outputs=[node.layer_name + "_cparam1"],
                shape=self.params[node.layer_name + "_cparam1"].shape,
                attr=string(node.layer_name + "_cparam1"))
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            input0_name = input0.name
            inputs_dict = {}
            inputs_dict['x'] = input0_name
            inputs_dict['y'] = node.layer_name + "_cparam1"
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs=inputs_dict,
                outputs=[node.layer_name + "_mul"],
                axis=axis)
        self.paddle_graph.add_layer(
            "self.create_parameter",
            inputs={},
            outputs=[node.layer_name + "_cparam2"],
            shape=self.params[node.layer_name + "_cparam2"].shape,
            attr=string(node.layer_name + "_cparam2"))
        inputs_dict = {}
        inputs_dict['x'] = node.layer_name + "_mul"
        inputs_dict['y'] = node.layer_name + "_cparam2"
        output_shape = node.out_shapes[0]
        if axis == -1:
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
        else:
            if axis < 0:
                axis = axis + len(output_shape)
            param2_shape = self.params[node.layer_name + "_cparam2"].shape
            param2_shape_len = len(param2_shape)
            diff_len = len(output_shape) - axis - param2_shape_len
            new_shape = list(param2_shape) + [1] * diff_len
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": node.layer_name + "_cparam2"},
                outputs=[node.layer_name + "_cparam2"],
                shape=new_shape)
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
            
    def Reshape(self, node):
        input = self.graph.get_input_node(node, idx=0, copy=True)
        output_shape = node.out_shapes[0]
        self.paddle_graph.add_layer(
            "paddle.reshape",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=output_shape)


    def ArgMax(self, node):
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, "The count of ArgMax node\'s input and output is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = node.in_shapes[0]
        params = node.layer.argmax_param
        out_max_val = params.out_max_val if hasattr(params,
                                                    "out_max_val") else False
        top_k = params.top_k if hasattr(params, "top_k") else 1
        axis = params.axis if hasattr(params, "axis") else -1
        if axis < 0:
            axis += len(input_shape)
        if out_max_val is True:
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=[node.layer_name + "_topk_var", node.layer_name + "_index_var"],
                k=top_k)
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": node.layer_name + "_index_var"},
                outputs=[node.layer_name + "_index_var"],
                dtype="{}_topk_var.dtype".format(node.layer_name))
            self.paddle_graph.add_layer(
                "paddle.concat",
                inputs={"x": [node.layer_name + "_topk_var", node.layer_name + "_index_var"]},
                outputs=[node.layer_name],
                axis=axis)
        else:
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=["_", node.layer_name],
                k=top_k)
            
    def Axpy(self, node):
        assert len(node.inputs) == 3 and len(
            node.outputs
        ) == 1, "The count of Axpy node\'s input is not 3 or output is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.axpy_param
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        input2 = self.graph.get_input_node(node, idx=2, copy=True)
        input0_name = input0.name
        input1_name = input1.name
        input2_name = input2.name
        inputs_dict = {}
        inputs_dict['x'] = input1_name
        inputs_dict['y'] = input0_name
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs=inputs_dict,
            outputs=[node.layer_name + "_mul"],
            axis=0)
        inputs_dict = {}
        inputs_dict['x'] = node.layer_name + "_mul"
        inputs_dict['y'] = input2_name
        self.paddle_graph.add_layer(
            "paddle.add",
            inputs=inputs_dict,
            outputs=[node.layer_name])
        

    def Crop(self, node):
        assert len(
            node.inputs) == 2, "The count of Crop node\'s input is not 2."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        example = self.graph.get_input_node(node, idx=1, copy=True)
        params = node.layer.crop_param
        axis = params.axis
        input_shape = node.in_shapes[0]
        if axis < 0:
            axis += len(input_shape)
        offset_real = [0] * len(input_shape)
        if hasattr(params, "offset") and len(params.offset) > 0:
            offset = list(params.offset)
            assert (len(input_shape) - axis
                    ) == len(offset), "invalid offset[%s] in crop layer" % (
                        str(offset))
            offset_real = [0] * axis + offset
        self.paddle_graph.add_layer(
                "paddle.crop",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                shape=node.in_shapes[1],
                offsets=list(offset_real))

    def Flatten(self, node):
        assert len(
            node.inputs) == 1, "The count of Flatten node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.reshape",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=node.out_shapes[0])

    def Power(self, node):
        assert len(
            node.inputs) == 1, "The count of Power node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.power_param
        layer_attrs = {
            'scale': params.scale,
            'bias': params.shift,
            'bias_after_scale': True
        }
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.pow",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            exponent=params.power)

    def Reduction(self, node):
        assert len(
            node.inputs) == 1, "The count of Reduction node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.reduction_param
        operation = params.operation
        axis = params.axis
        coeff = params.coeff
        # Caffe ReductionOp: 1 = SUM, 2 = ASUM, 3 = SUMSQ, 4 = MEAN.
        assert operation >= 1 and operation <= 4, "invalid reduction operation [%s]" % (
            operation)
        input_len = len(node.in_shapes[0])
        if axis < 0:
            axis += input_len + 1
        dim = list(range(input_len))
        # operation = SUM
        if operation == 1:  
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = ASUM
        elif operation == 2:  
            self.paddle_graph.add_layer(
                "paddle.abs",
                inputs={"x": input.name},
                outputs=[node.layer_name])
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = SUMSQ
        elif operation == 3: 
            self.paddle_graph.add_layer(
                "paddle.pow",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                exponent=2.0)
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        # operation = MEAN
        else: 
            layer_attrs = {
                "axis": dim[axis:],
                "keepdim": False,
            }
            self.paddle_graph.add_layer(
                "paddle.mean",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            scale=coeff)
        
    def DetectionOutput(self, node):
        detection_output_name = name_generator("detection_output", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [detection_output_name, output_name]
        assert len(
            node.inputs) == 3, "The count of DetectionOutput node\'s input is not 3."
        inputs_dict = dict()
        for i in range(len(node.inputs)):
            input = self.graph.get_input_node(node, idx=i, copy=True)
            if i == 1:
                input = self.graph.get_input_node(node, idx=i, copy=True)
                while input is not None \
                      and input.layer_type != 'Softmax' \
                      and input.layer_type != 'Sigmoid':
                    input = self.graph.get_input_node(input, idx=0, copy=True)
                assert input is not None, 'This kind of DetectionOutput is not supported!'
                input = self.graph.get_input_node(input, idx=0, copy=True)
            inputs_dict["x{}".format(i)] = input.name
        params = node.layer.detection_output_param
        nms_param = params.nms_param
        nms_param_dict = dict()
        # Read nms_param only after checking it exists; fall back to defaults.
        if nms_param is None:
            nms_param_dict = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
        else:
            nms_param_dict["nms_threshold"] = nms_param.nms_threshold
            nms_param_dict["top_k"] = nms_param.top_k
            nms_param_dict["eta"] = nms_param.eta
        default = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
        for f in default.keys():
            if f not in nms_param_dict:
                nms_param_dict[f] = default[f]
        layer_attrs = {
            "background_label": params.background_label_id,
            "nms_threshold": nms_param_dict["nms_threshold"],
            "nms_top_k": nms_param_dict["top_k"],
            "keep_top_k": params.keep_top_k,
            "score_threshold": params.confidence_threshold,
            "nms_eta": nms_param_dict["eta"]}
        self.paddle_graph.add_layer(
            kernel="custom_layer:DetectionOutput",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
                    
    def Normalize(self, node):
        normalize_name = name_generator("normalize", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [normalize_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Normalize node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.norm_param
        param_name = node.layer_name + "_scale"
        if node.data is None or len(node.data) != 1:
            print(
                "The parameters of {} (type: {}) are not set, so we initialize them to zero."
                .format(node.layer_name, node.layer_type))
            self.params[param_name] = \
                np.zeros([1] if params.channel_shared else [node.in_shapes[0][1]]).astype("float32")
        else:
            self.params[param_name] = _adjust_parameters(node)[0]
        self.paddle_graph.add_layer(
            "self.create_parameter",
            inputs={},
            outputs=[param_name],
            shape=self.params[param_name].shape,
            attr=string(param_name))
        inputs_dict = {}
        layer_attrs = {
            "axis": -1 if params.channel_shared else 1}
        self.paddle_graph.add_layer(
            "custom_layer:Normalize",
            inputs={"x": input.name,
                    "param": param_name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def Permute(self, node):
        assert len(
            node.inputs) == 1, "The count of Permute node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.permute_param
        order = list(params.order)    
        self.paddle_graph.add_layer(
            "paddle.transpose",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            perm=order)
        
    def PriorBox(self, node):
        priorbox_name = name_generator("priorbox", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [priorbox_name, output_name]
        assert len(
            node.inputs) == 2, "The count of PriorBox node\'s input is not 2."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {}
        inputs_dict["x0"] = input0.name
        inputs_dict["x1"] = input1.name
        params = node.layer.prior_box_param
        steps = tuple(params.step) if type(params.step) \
                is list or type(params.step) is tuple \
                else (params.step, params.step)
        layer_attrs = {
            "min_sizes": params.min_size,
            "max_sizes": params.max_size,
            "aspect_ratios": params.aspect_ratio,
            "variance": params.variance,
            "flip": params.flip,
            "clip": params.clip,
            "steps": steps,
            "offset": params.offset,
            "min_max_aspect_ratios_order": True}
        self.paddle_graph.add_layer(
            "custom_layer:PriorBox",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
    def ReLU6(self, node):
        relu6_name = name_generator("relu6", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [relu6_name, output_name]
        assert len(
            node.inputs) == 1, "The count of ReLU6 node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.nn.ReLU6",
            inputs={"input": input.name},
            outputs=layer_outputs)
        
    def ROIPooling(self, node):
        roipooling_name = name_generator("roipooling", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [roipooling_name, output_name]
        assert len(
            node.inputs) == 2, "The count of ROIPooling node\'s input is not 2."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {}
        inputs_dict["x0"] = input0.name
        inputs_dict["x1"] = input1.name
        params = node.layer.roi_pooling_param
        layer_attrs = {
            "pooled_height": params.pooled_h,
            "pooled_width": params.pooled_w,
            "spatial_scale": params.spatial_scale}
        self.paddle_graph.add_layer(
            "custom_layer:ROIPooling",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
        
    def ShuffleChannel(self, node):
        assert len(
            node.inputs) == 1, "The count of ShuffleChannel node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.shuffle_channel_param
        self.paddle_graph.add_layer(
            "paddle.fluid.layers.shuffle_channel",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            group=params.group)
        
    def Upsample(self, node):
        assert len(
            node.inputs) == 1, "The count of Upsample node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.upsample_param
        layer_attrs = {
            "align_corners": False,
            "scale_factor": params.scale,
            "mode": "nearest"}
        self.paddle_graph.add_layer(
            "paddle.nn.functional.interpolate",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
    
    def Select(self, node):
        select_name = name_generator("select", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [select_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Select node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = node.in_shapes[0]
        params = node.layer.select_param
        layer_attrs = {
            "input_shape": input_shape,
            "point": params.slice_point,
            "axis": params.axis}
        self.paddle_graph.add_layer(
            "custom_layer:Select",
            inputs={"x": input.name},
            outputs=layer_outputs,
            **layer_attrs)
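
# Typical driver flow, for orientation (a sketch; the exact x2paddle entry
# points may differ):
#     decoder = CaffeDecoder(prototxt_path, caffemodel_path)  # hypothetical args
#     mapper = CaffeOpMapper(decoder)
#     mapper.paddle_graph.gen_model(save_dir)  # hypothetical method name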
        
