caffe_op_mapper.py 45.4 KB
Newer Older
S
SunAhong1993 已提交
1
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
S
SunAhong1993 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

S
SunAhong1993 已提交
15
import sys
S
SunAhong1993 已提交
16 17 18 19 20
import numbers
import numpy as np
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
from x2paddle.core.program import PaddleGraph 
S
SunAhong1993 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
from x2paddle.decoder.caffe_decoder import CaffeGraphNode


def _adjust_parameters(node):
    data = node.data
    # When using the protobuf-backend, each parameter initially has four dimensions.
    # In certain cases (like FC layers), we want to eliminate the singleton dimensions.
    # This implementation takes care of the common cases. However, it does leave the
    # potential for future issues.
    # The Caffe-backend does not suffer from this problem.
    data = list(data)

    squeeze_indices = [1]  # Squeeze biases.
    if node.layer_type == 'InnerProduct':
        squeeze_indices.append(0)  # Squeeze FC.

    for idx in squeeze_indices:
        if idx >= len(data):
            continue

        d = data[idx]
        assert len(
            d.shape
        ) == 4, 'invalid shape[%s] from caffe when adjust_parameters' % (
            str(d.shape))

        shape_old = d.shape
        sq_axis = None
        if idx == 0:
            sq_axis = (0, 1)
        elif idx == 1:
            sq_axis = (0, 1, 2)
        else:
            continue

        data[idx] = np.squeeze(d, axis=sq_axis)
        shape_new = data[idx].shape
    return data

def _get_kernel_parameters(kind, params):
    assert kind in ["Convolution", "Pooling", "Deconvolution", "ConvolutionDepthwise"]
    [k_h, k_w] = [1, 1]
    if isinstance(params.kernel_size, numbers.Number):
        [k_h, k_w] = [params.kernel_size] * 2
    elif len(params.kernel_size) > 0:
        k_h = params.kernel_h if params.kernel_h > 0 else params.kernel_size[
            0]
        k_w = params.kernel_w if params.kernel_w > 0 else params.kernel_size[
            len(params.kernel_size) - 1]
    elif params.kernel_h > 0 or params.kernel_w > 0:
        k_h = params.kernel_h
        k_w = params.kernel_w
    [s_h, s_w] = [1, 1]
    if isinstance(params.stride, numbers.Number):
        [s_h, s_w] = [params.stride] * 2
    elif len(params.stride) > 0:
        s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
        s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
            params.stride) - 1]
    elif params.stride_h > 0 or params.stride_w > 0:
        s_h = params.stride_h
        s_w = params.stride_w
    [p_h, p_w] = [0, 0]
    if isinstance(params.pad, numbers.Number):
        [p_h, p_w] = [params.pad] * 2
    elif len(params.pad) > 0:
        p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
        p_w = params.pad_w if params.pad_w > 0 else params.pad[len(
            params.pad) - 1]
    elif params.pad_h > 0 or params.pad_w > 0:
        p_h = params.pad_h
        p_w = params.pad_w
    dila_h = dila_w = 1
    group = 1
    c_o = 1
    if kind in ["Convolution", "Deconvolution", "ConvolutionDepthwise"]:
        if kind in ["Convolution", "Deconvolution"]:
            c_o = params.num_output
        dila_len = len(params.dilation)
        if dila_len == 2:
            dila_h = params.dilation[0]
            dila_w = params.dilation[1]
        elif dila_len == 1:
            dila_h = dila_w = params.dilation[0]
        else:
            assert dila_len == 0, "invalid length[%s] of dilation in convolution" % (
                dila_len)
    if kind in ['Convolution', 'Deconvolution']:
        group = params.group
    kernel = [k_h, k_w]
    stride = [s_h, s_w]
    pad = [p_h, p_w]
    dilation = [dila_h, dila_w]
    return c_o, kernel, stride, pad, dilation, group
S
SunAhong1993 已提交
115 116 117 118


class CaffeOpMapper(OpMapper):
    # Caffe layer types converted one-to-one by ``directly_map``: a single
    # input, no attributes. The first list entry is the Paddle kernel path.
    # ``paddle.nn.Sigmoid`` (not ``paddle.nn.layer.Sigmoid``) keeps the
    # generated sub-layer name free of dots.
    directly_map_ops = {
        'Sigmoid': ['paddle.nn.Sigmoid'],
        'TanH': ['paddle.nn.Tanh'],
    }

    def __init__(self, decoder):
        """Convert every node of ``decoder.caffe_graph`` into ``self.paddle_graph``.

        Nodes are visited in ``topo_sort`` order. Each node is dispatched to
        the method named after its Caffe layer type, or to ``directly_map``
        when the type is listed in ``directly_map_ops``.

        Raises:
            Exception: if ``op_checker`` reports unsupported layer types.
        """
        super(CaffeOpMapper, self).__init__()
        self.graph = decoder.caffe_graph
        if not self.op_checker():
            raise Exception("Model is not supported yet.")
        self.params = {}
        self.paddle_graph = PaddleGraph(
            parent_layer=None, graph_type="dygraph", source_type="caffe")
        self.paddle_graph.outputs = self.graph.output_nodes
        # Index used to name model inputs "x0", "x1", ... in ``Input``.
        self.input_index = 0
        self.inputs_info = {}
        # Bookkeeping dict handed to ``name_generator`` when naming sub-layers.
        self.nn_name2id = {}

        caffe_node_count = sum(
            1 for graph_node in self.graph.node_map.values()
            if isinstance(graph_node, CaffeGraphNode))
        print("Total nodes: {}".format(caffe_node_count))
        print("Nodes converting ...")
        for position, node_name in enumerate(self.graph.topo_sort, start=1):
            sys.stderr.write("\rConverting node {} ...     ".format(position))
            node = self.graph.get_node(node_name)
            op_type = node.layer_type
            # A dedicated handler method wins over the generic mapping table.
            if hasattr(self, op_type):
                getattr(self, op_type)(node)
            elif op_type in self.directly_map_ops:
                self.directly_map(node)
        print("\nNodes converted.")

        self.paddle_graph.set_name(self.graph.graph_name)
        self.paddle_graph.set_parameters(self.params)
        self.paddle_graph.set_inputs_info(self.inputs_info)
S
SunAhong1993 已提交
153 154 155 156 157 158
                
    def op_checker(self):
        """Check that every layer type in the graph can be converted.

        A layer type is supported when this class has a method of that name
        or the type is a key of ``directly_map_ops``.

        Returns:
            bool: True if all types are supported. Otherwise the unsupported
            types are printed and False is returned.
        """
        unsupported_ops = set()
        for node_name in self.graph.topo_sort:
            op_type = self.graph.get_node(node_name).layer_type
            if not (hasattr(self, op_type) or op_type in self.directly_map_ops):
                unsupported_ops.add(op_type)
        if not unsupported_ops:
            return True
        print("\n========= {} OPs are not supported yet ===========".format(
            len(unsupported_ops)))
        for op_type in unsupported_ops:
            print("========== {} ============".format(op_type))
        return False
S
SunAhong1993 已提交
170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
        
    def directly_map(self, node):
        """Convert a single-input node whose type is in ``directly_map_ops``.

        For ``paddle.nn.*`` kernels a named sub-layer is created, and its name
        is derived from the class name. Any other kernel is emitted as a plain
        call.

        Raises:
            AssertionError: if the Caffe layer has more than one input.
        """
        inputs = node.layer.input
        assert len(inputs) == 1, 'directly_map error with multi inputs'
        op_info = self.directly_map_ops[node.layer_type]
        input = self.graph.get_input_node(node, 0)
        paddle_op = op_info[0]
        if paddle_op.startswith("paddle.nn"):
            # Use only the trailing class name (e.g. "paddle.nn.layer.Sigmoid"
            # -> "sigmoid"). Slicing off a fixed "paddle.nn." prefix would keep
            # intermediate module names and produce a dotted attribute name.
            op_name = paddle_op.rsplit(".", 1)[-1].lower()
            op_name = name_generator(op_name, self.nn_name2id)
            layer_outputs = [op_name, node.name]
        else:
            layer_outputs = [node.name]
        self.paddle_graph.add_layer(
            kernel=paddle_op,
            inputs={"x": input.name},
            outputs=layer_outputs)
S
SunAhong1993 已提交
191 192

    def Input(self, node):
        """Map a Caffe ``Input`` layer to a model input named ``x<index>``.

        The declared shape is recorded in ``self.inputs_info``, with its first
        dimension replaced by -1 and dtype "float32".
        """
        input_key = "x{}".format(self.input_index)
        self.paddle_graph.add_layer(
            "paddle.to_tensor",
            inputs={},
            outputs=[node.layer_name],
            data=input_key)
        declared_dims = list(node.layer.input_param.shape[0].dim)
        self.inputs_info[input_key] = [[-1] + declared_dims[1:], "float32"]
        self.input_index += 1

    def Convolution(self, node):
        """Map a Caffe ``Convolution`` layer to ``paddle.nn.Conv2D``.

        Weights come from ``node.data`` (squeezed by ``_adjust_parameters``).
        When the node has no weights, zero-filled weight and bias arrays are
        generated so conversion can proceed. A node with a single blob gets
        ``bias_attr=False``.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        in_channel = node.in_shapes[0][1]
        if data is None:
            data = []
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            # Conv2D weights are [out_channels, in_channels // groups, k_h, k_w];
            # ignoring groups would produce a mismatched weight for grouped convs.
            data.append(
                np.zeros([out_channel, in_channel // group, kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Convolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": in_channel,
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
S
SunAhong1993 已提交
243 244 245 246
        
    def DepthwiseConvolution(self, node):
        """Alias handler: convert ``DepthwiseConvolution`` as ``ConvolutionDepthwise``.

        ``node.layer_type`` is rewritten in place before delegating. The
        delegate forwards it to ``_get_kernel_parameters``, whose assert only
        lists the "ConvolutionDepthwise" spelling.
        """
        node.layer_type = "ConvolutionDepthwise"
        return self.ConvolutionDepthwise(node)
S
SunAhong1993 已提交
247 248

    def Deconvolution(self, node):
        """Map a Caffe ``Deconvolution`` layer to ``paddle.nn.Conv2DTranspose``.

        Weights come from ``node.data`` (squeezed by ``_adjust_parameters``).
        When the node has no weights, zero-filled weight and bias arrays are
        generated so conversion can proceed. A node with a single blob gets
        ``bias_attr=False``.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        out_channel, kernel, stride, pad, dilation, group = _get_kernel_parameters(
            node.layer_type, params)
        in_channel = node.in_shapes[0][1]
        if data is None:
            data = []
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            # Conv2DTranspose weights are [in_channels, out_channels // groups,
            # k_h, k_w]; the leading two dims are swapped relative to Conv2D.
            data.append(
                np.zeros([in_channel, out_channel // group, kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of Deconvolution node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": in_channel,
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2DTranspose",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)
        
    def ConvolutionDepthwise(self, node):
        """Map a Caffe ``ConvolutionDepthwise`` layer to a grouped ``paddle.nn.Conv2D``.

        The group count is derived from the input/output channel counts, not
        from the prototxt. When the node has no weights, zero-filled weight
        and bias arrays are generated.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        conv2d_name = name_generator("conv", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [conv2d_name, output_name]
        data = node.data
        params = node.layer.convolution_param
        # _get_kernel_parameters returns placeholder channel/group values for
        # this layer type; both are recomputed below.
        _, kernel, stride, pad, dilation, _ = _get_kernel_parameters(
            node.layer_type, params)
        in_channel = node.in_shapes[0][1]
        # Protobuf scalar fields are never None (an unset num_output presumably
        # reads as 0), so test truthiness. A 0 here previously led to a
        # division by zero in the group computation.
        out_channel = params.num_output if params.num_output else in_channel
        # Integer form of the former float heuristic in / (in / out) and
        # in / (out / in). Float round-off could truncate the group count.
        # NOTE(review): for out_channel > in_channel this gives
        # in_channel ** 2 // out_channel rather than in_channel — confirm.
        if in_channel > out_channel:
            group = out_channel
        else:
            group = (in_channel * in_channel) // out_channel
        if data is None:
            data = []
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            # Conv2D weights are [out_channels, in_channels // groups, k_h, k_w].
            data.append(
                np.zeros([out_channel, in_channel // group, kernel[0],
                          kernel[1]]).astype('float32'))
            data.append(np.zeros([out_channel, ]).astype('float32'))
        else:
            data = _adjust_parameters(node)
        self.params[conv2d_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[conv2d_name + ".bias"] = data[1]
        assert len(node.inputs
                   ) == 1, "The count of ConvolutionDepthwise node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        layer_attrs = {
            "in_channels": in_channel,
            "out_channels": out_channel,
            "kernel_size": kernel,
            "stride": stride,
            "padding": pad,
            "dilation": dilation,
            "groups": group
        }
        if len(data) == 1:
            layer_attrs["bias_attr"] = False
        self.paddle_graph.add_layer(
            "paddle.nn.Conv2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

    def Pooling(self, node):
        """Map a Caffe ``Pooling`` layer to a Paddle 2-D pooling layer.

        ``pool == 0`` maps to max pooling; any other value maps to average
        pooling. Global pooling becomes an adaptive pool with a 1x1 output.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        pool2d_name = name_generator("pool", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [pool2d_name, output_name]
        params = node.layer.pooling_param
        # Read the real ``ceil_mode`` field (the old lookup used the misspelled
        # "ceil_mod" and so always fell back to True). Defaults to True when
        # the proto has no such field.
        ceil_mode = getattr(params, "ceil_mode", True)
        global_pool = getattr(params, "global_pooling", False)
        _, kernel, stride, pad, _, _ = _get_kernel_parameters(
            node.layer_type, params)
        is_max = params.pool == 0
        assert len(
            node.inputs) == 1, "The count of Pooling node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        if global_pool:
            # Global pooling reduces every spatial position, so the output is
            # always 1x1 regardless of any kernel size in the prototxt.
            self.paddle_graph.add_layer(
                "paddle.nn.AdaptiveMaxPool2D" if is_max
                else "paddle.nn.AdaptiveAvgPool2D",
                inputs={"input": input.name},
                outputs=layer_outputs,
                output_size=[1, 1])
        else:
            layer_attrs = {
                'kernel_size': kernel,
                'stride': stride,
                'padding': pad,
                'ceil_mode': ceil_mode,
            }
            self.paddle_graph.add_layer(
                "paddle.nn.MaxPool2D" if is_max else "paddle.nn.AvgPool2D",
                inputs={"input": input.name},
                outputs=layer_outputs,
                **layer_attrs)
S
SunAhong1993 已提交
387 388

    def LRN(self, node):
        """Map a Caffe ``LRN`` layer to ``paddle.nn.LocalResponseNorm``.

        Raises:
            AssertionError: if the node does not have exactly one input or
                ``local_size`` is even.
        """
        lrn_name = name_generator("lrn", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [lrn_name, output_name]
        assert len(node.inputs) == 1, "The count of LRN node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.lrn_param
        assert params.local_size % 2 == 1
        # paddle.nn.LocalResponseNorm already divides alpha by ``size`` (its
        # window sum is an average), as Caffe does. Passing alpha / size here
        # would apply the 1/size factor twice.
        layer_attrs = {
            "size": params.local_size,
            "k": params.k,
            "alpha": params.alpha,
            "beta": params.beta
        }
        self.paddle_graph.add_layer(
            "paddle.nn.LocalResponseNorm",
            inputs={"input": input.name},
            outputs=layer_outputs,
            **layer_attrs)

S
SunAhong1993 已提交
409

S
SunAhong1993 已提交
410
    def InnerProduct(self, node):
        """Map a Caffe ``InnerProduct`` layer to ``paddle.nn.Linear``.

        The Caffe weight is reshaped to (out, -1) and transposed to Paddle's
        (in, out) layout. A non-2-D input, or one whose last dimension differs
        from the weight's input size, is first reshaped to
        ``[-1, in_features]``. When the node has no weights, zero-filled
        weight and bias arrays are generated.

        Raises:
            AssertionError: if the node does not have exactly one input or
                ``axis`` is not 1.
        """
        linear_name = name_generator("linear", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [linear_name, output_name]
        data = node.data
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.inner_product_param
        assert len(node.inputs
                   ) == 1, "The count of InnerProduct node\'s input is not 1."
        assert params.axis == 1
        in_shape = node.in_shapes[0]
        if data is None:
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0."
                .format(node.layer_name, node.layer_type))
            # With axis == 1 every non-batch dimension is flattened into the
            # feature axis (see the reshape below), so the weight's input size
            # is the product of in_shape[1:], not just the channel count.
            in_features = int(np.prod(in_shape[1:]))
            data = [
                np.zeros([in_features, params.num_output]).astype("float32"),
                np.zeros([params.num_output]).astype("float32"),
            ]
        else:
            data = _adjust_parameters(node)
            # Caffe's weight is (out, in...); Paddle's Linear expects (in, out).
            w = data[0]
            data[0] = w.reshape((w.shape[0], -1)).transpose((1, 0))

        self.params[linear_name + ".weight"] = data[0]
        if len(data) == 2:
            self.params[linear_name + ".bias"] = data[1]
        in_features = data[0].shape[0]
        layer_attrs = {
            "in_features": in_features,
            "out_features": params.num_output
        }
        if len(data) == 1:
            # Linear's keyword is ``bias_attr``; ``bias`` is not accepted.
            layer_attrs["bias_attr"] = False

        linear_input = input.name
        # Flatten any N-D input, even when its last dim happens to match.
        if len(in_shape) != 2 or in_shape[-1] != in_features:
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": input.name},
                outputs=[output_name],
                shape=[-1, in_features])
            linear_input = output_name
        self.paddle_graph.add_layer(
            "paddle.nn.Linear",
            inputs={"input": linear_input},
            outputs=layer_outputs,
            **layer_attrs)
        
    def AbsVal(self, node):
        """Map a Caffe ``AbsVal`` layer to ``paddle.abs`` on its first input.

        Raises:
            AssertionError: if the node has no input.
        """
        assert len(
            node.inputs
        ) >= 1, "The count of AbsVal node\'s input is less than 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        # paddle.abs takes its tensor as ``x``, like the other functional
        # kernels in this mapper (paddle.split, paddle.concat, paddle.reshape).
        self.paddle_graph.add_layer(
            "paddle.abs",
            inputs={"x": input.name},
            outputs=[node.layer_name])

    def Softmax(self, node):
        """Map a Caffe ``Softmax`` layer to ``paddle.nn.Softmax``.

        A negative Caffe axis is converted to its non-negative equivalent
        using the rank of the input shape.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        softmax_name = name_generator("softmax", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [softmax_name, output_name]
        assert len(
            node.inputs) == 1, "The count of Softmax node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        axis = node.layer.softmax_param.axis
        if axis < 0:
            axis += len(node.in_shapes[0])
        self.paddle_graph.add_layer(
            "paddle.nn.Softmax",
            inputs={"input": input.name},
            outputs=layer_outputs,
            axis=axis)

    def Slice(self, node):
        """Map a Caffe ``Slice`` layer to ``paddle.split``.

        Section sizes are read from the inferred output shapes along the
        slice axis. Output ``i`` is named ``<layer_name>_p<i>``.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        assert len(
            node.inputs) == 1, "The count of Slice node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.slice_param
        axis = params.axis
        # A non-1 ``slice_dim`` overrides ``axis`` only while ``axis`` is 1.
        if params.slice_dim != 1 and axis == 1:
            axis = params.slice_dim
        output_shapes = node.out_shapes
        sections_list = [shape[axis] for shape in output_shapes]
        outputs_list = [
            "{}_p{}".format(node.layer_name, i)
            for i in range(len(output_shapes))
        ]
        self.paddle_graph.add_layer(
            "paddle.split",
            inputs={"x": input.name},
            outputs=outputs_list,
            num_or_sections=sections_list,
            axis=axis)

    def Concat(self, node):
        """Map a Caffe ``Concat`` layer to ``paddle.concat`` over all inputs.

        Raises:
            AssertionError: if the node has no input.
        """
        assert len(
            node.inputs
        ) >= 1, "The count of Concat node\'s input is less than 1."
        inputs_list = [
            self.graph.get_input_node(node, idx=i, copy=True).name
            for i in range(len(node.inputs))
        ]
        params = node.layer.concat_param
        axis = params.axis
        # Like Slice with ``slice_dim``, honour the legacy ``concat_dim`` when
        # it differs from 1 while ``axis`` is still 1 (0 is a valid value).
        concat_dim = getattr(params, "concat_dim", 1)
        if concat_dim != 1 and axis == 1:
            axis = concat_dim
        self.paddle_graph.add_layer(
            "paddle.concat",
            inputs={"x": inputs_list},
            outputs=[node.layer_name],
            axis=axis)

    def ReLU(self, node):
        """Map a Caffe ``ReLU`` layer to ``paddle.nn.ReLU`` or ``paddle.nn.LeakyReLU``.

        A set, non-zero ``negative_slope`` selects LeakyReLU with that slope.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        relu_name = name_generator("relu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [relu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of RelU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.relu_param
        if params.HasField('negative_slope') and params.negative_slope != 0:
            # paddle.nn.LeakyReLU's slope keyword is ``negative_slope``.
            self.paddle_graph.add_layer(
                "paddle.nn.LeakyReLU",
                inputs={"input": input.name},
                outputs=layer_outputs,
                negative_slope=float(params.negative_slope))
        else:
            self.paddle_graph.add_layer(
                "paddle.nn.ReLU",
                inputs={"input": input.name},
                outputs=layer_outputs)

    def PReLU(self, node):
        """Map a Caffe ``PReLU`` layer to ``paddle.nn.PReLU``.

        ``channel_shared`` selects a single slope. Otherwise there is one slope
        per channel of the output (dimension 1).

        Raises:
            AssertionError: if the node does not have exactly one input or
                carries no slope blob.
        """
        prelu_name = name_generator("prelu", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [prelu_name, output_name]
        assert len(
            node.inputs) == 1, "The count of PReLU node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.prelu_param
        output_shape = node.out_shapes[0]
        num_parameters = 1 if params.channel_shared else output_shape[1]
        data = node.data
        # Validate before indexing so a missing blob produces this message
        # instead of a TypeError from ``None[0]``.
        assert data is not None, "The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.".format(
            node.layer_name, node.layer_type)
        # Flatten rather than squeeze: a shared slope must stay 1-D with shape
        # [1] (squeeze would give a 0-d array); per-channel slopes become [C].
        self.params[prelu_name + '._weight'] = np.reshape(data[0], [-1])
        self.paddle_graph.add_layer(
            "paddle.nn.PReLU",
            inputs={"input": input.name},
            outputs=layer_outputs,
            num_parameters=num_parameters)
S
SunAhong1993 已提交
587 588 589 590 591 592 593

    def Eltwise(self, node):
        """Map a two-input Caffe ``Eltwise`` layer to an element-wise Paddle op.

        ``operation`` 0 -> ``paddle.multiply``, 1 -> ``paddle.add``, anything
        else -> ``paddle.maximum``. For operation 1 with exactly two ``coeff``
        values, each input is first scaled by its coefficient.

        Raises:
            AssertionError: if the node does not have exactly two inputs.
        """
        assert len(
            node.inputs) == 2, "The count of Eltwise node\'s input is not 2."
        params = node.layer.eltwise_param
        mode = params.operation
        input0_name = self.graph.get_input_node(node, idx=0, copy=True).name
        input1_name = self.graph.get_input_node(node, idx=1, copy=True).name
        if mode == 0:
            kernel = "paddle.multiply"
        elif mode == 1:
            if hasattr(params, 'coeff') and len(params.coeff) == 2:
                coeff = params.coeff
                mul0_name = node.layer_name + '_mul0'
                mul1_name = node.layer_name + '_mul1'
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input0_name},
                    outputs=[mul0_name],
                    scale=coeff[0])
                # coeff holds exactly two entries; index 1 scales the second
                # input (index 2 would raise IndexError).
                self.paddle_graph.add_layer(
                    "paddle.scale",
                    inputs={"x": input1_name},
                    outputs=[mul1_name],
                    scale=coeff[1])
                input0_name, input1_name = mul0_name, mul1_name
            kernel = "paddle.add"
        else:
            # Element-wise maximum of the two tensors; paddle.max is a
            # reduction and does not take a second tensor.
            kernel = "paddle.maximum"
        self.paddle_graph.add_layer(
            kernel,
            inputs={'x': input0_name, 'y': input1_name},
            outputs=[node.layer_name])

    def BatchNorm(self, node):
        """Map a Caffe ``BatchNorm`` layer to ``paddle.nn.BatchNorm2D``.

        With exactly three blobs, they are read as (mean, variance, scale
        factor); mean and variance are divided by the factor, or zeroed when
        it is 0. Otherwise zero statistics are used. ``node.data`` is replaced
        by squeezed float32 copies of its blobs. The Paddle layer is created
        without affine weight/bias.

        Raises:
            AssertionError: if the node does not have exactly one input.
        """
        batchnorm_name = name_generator("batchnorm", self.nn_name2id)
        output_name = node.layer_name
        layer_outputs = [batchnorm_name, output_name]
        assert len(
            node.inputs) == 1, "The count of BatchNorm node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.batch_norm_param
        eps = params.eps if hasattr(params, "eps") else 1e-5
        has_stats = node.data is not None and len(node.data) == 3
        if has_stats:
            node.data = [np.squeeze(blob).astype("float32") for blob in node.data]
            mean, variance, scale = node.data
        else:
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            num_channels = node.in_shapes[0][1]
            mean = np.zeros([num_channels, ]).astype("float32")
            variance = np.zeros([num_channels, ]).astype("float32")
            scale = 0
        # Undo the stored scale factor ("prescale" the statistics).
        scaling_factor = 0 if scale == 0 else 1.0 / scale
        mean *= scaling_factor
        variance *= scaling_factor
        self.params[batchnorm_name + "._mean"] = mean
        self.params[batchnorm_name + '._variance'] = variance
        self.paddle_graph.add_layer(
            "paddle.nn.BatchNorm2D",
            inputs={"input": input.name},
            outputs=layer_outputs,
            num_features=node.in_shapes[0][1],
            epsilon=eps,
            weight_attr=False,
            bias_attr=False)
   
    def Scale(self, node):
        """Map a Caffe Scale layer onto ``paddle.multiply`` + ``paddle.add``.

        Two arrays are registered in ``self.params``: ``<layer>_cparam1``
        (the scale) and ``<layer>_cparam2`` (the bias). A missing blob is
        replaced by a zero vector of length ``in_shapes[0][1]``.

        * Two inputs: input 0 is multiplied by input 1 with ``axis=1``.
        * One input: ``_cparam1`` is created as a parameter and multiplied
          with input 0 along ``scale_param.axis``.

        The bias ``_cparam2`` is then created and added to the product.
        Unless ``axis == -1``, it is first reshaped with trailing singleton
        dimensions so that it lines up with ``axis`` of the output.
        """
        scale_key = node.layer_name + "_cparam1"
        bias_key = node.layer_name + "_cparam2"
        mul_name = node.layer_name + "_mul"
        if node.data is None:
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            self.params[scale_key] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
            self.params[bias_key] = np.zeros([
                node.in_shapes[0][1],
            ]).astype("float32")
        else:
            self.params[scale_key] = np.squeeze(node.data[0]).astype(
                "float32")
            if len(node.data) > 1:
                self.params[bias_key] = np.squeeze(node.data[1]).astype(
                    "float32")
            else:
                # Only one blob is stored (e.g. a Scale layer without a bias
                # term). Indexing node.data[1] would raise IndexError, so
                # fall back to a zero bias, as the no-data branch does.
                self.params[bias_key] = np.zeros([
                    node.in_shapes[0][1],
                ]).astype("float32")
        params = node.layer.scale_param
        axis = params.axis
        if len(node.inputs) == 2:
            # The scale comes from the second input tensor.
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            input1 = self.graph.get_input_node(node, idx=1, copy=True)
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs={"x": input0.name,
                        "y": input1.name},
                outputs=[mul_name],
                axis=1)
        else:
            # The scale is a learned parameter.
            self.paddle_graph.add_layer(
                "self.create_parameter",
                inputs={},
                outputs=[scale_key],
                shape=self.params[scale_key].shape,
                attr=string(scale_key))
            input0 = self.graph.get_input_node(node, idx=0, copy=True)
            self.paddle_graph.add_layer(
                "paddle.multiply",
                inputs={"x": input0.name,
                        "y": scale_key},
                outputs=[mul_name],
                axis=axis)
        self.paddle_graph.add_layer(
            "self.create_parameter",
            inputs={},
            outputs=[bias_key],
            shape=self.params[bias_key].shape,
            attr=string(bias_key))
        inputs_dict = {"x": mul_name, "y": bias_key}
        output_shape = node.out_shapes[0]
        if axis == -1:
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
        else:
            if axis < 0:
                axis = axis + len(output_shape)
            param2_shape = self.params[bias_key].shape
            # Pad the bias with trailing 1s so it broadcasts starting at `axis`.
            diff_len = len(output_shape) - axis - len(param2_shape)
            new_shape = list(param2_shape) + [1] * diff_len
            self.paddle_graph.add_layer(
                "paddle.reshape",
                inputs={"x": bias_key},
                outputs=[bias_key],
                shape=new_shape)
            self.paddle_graph.add_layer(
                "paddle.add",
                inputs=inputs_dict,
                outputs=[node.layer_name])
            
S
SunAhong1993 已提交
765
    def Reshape(self, node):
        """Emit ``paddle.reshape`` to the node's first inferred output shape."""
        source = self.graph.get_input_node(node, idx=0, copy=True)
        target_shape = node.out_shapes[0]
        reshape_kwargs = {
            "inputs": {"x": source.name},
            "outputs": [node.layer_name],
            "shape": target_shape,
        }
        self.paddle_graph.add_layer("paddle.reshape", **reshape_kwargs)
S
SunAhong1993 已提交
773 774 775 776 777 778


    def ArgMax(self, node):
        """Map a Caffe ArgMax layer onto ``paddle.topk``.

        ``top_k`` (default 1) entries are selected along ``argmax_param.axis``
        (negative values are counted from the end). With ``out_max_val`` the
        output is the values concatenated with their indices (cast to the
        values' dtype) along that axis. Otherwise only the indices are emitted.
        """
        assert len(node.inputs) == 1 and len(
            node.outputs
        ) == 1, "The count of ArgMax node\'s input and output is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        input_shape = node.in_shapes[0]
        params = node.layer.argmax_param
        # hasattr needs the attribute *name*. The previous code passed unbound
        # locals (and misspelled `params`), so it raised on every call.
        out_max_val = params.out_max_val if hasattr(params,
                                                    "out_max_val") else False
        top_k = params.top_k if hasattr(params, "top_k") else 1
        # NOTE(review): with a protobuf message hasattr is always True, so an
        # unset axis reads as its proto default rather than -1; confirm.
        axis = params.axis if hasattr(params, "axis") else -1
        if axis < 0:
            axis += len(input_shape)
        if out_max_val:
            # topk must run along the same axis the results are concatenated on.
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=[node.layer_name + "_topk_var", node.layer_name + "_index_var"],
                k=top_k,
                axis=axis)
            self.paddle_graph.add_layer(
                "paddle.cast",
                inputs={"x": node.layer_name + "_index_var"},
                outputs=[node.layer_name + "_index_var"],
                dtype="{}_topk_var.dtype".format(node.layer_name))
            self.paddle_graph.add_layer(
                "paddle.concat",
                inputs={"x": [node.layer_name + "_topk_var", node.layer_name + "_index_var"]},
                outputs=[node.layer_name],
                axis=axis)
        else:
            self.paddle_graph.add_layer(
                "paddle.topk",
                inputs={"x": input.name},
                outputs=["_", node.layer_name],
                k=top_k,
                axis=axis)
            
    def Axpy(self, node):
        """Map an Axpy layer, computing ``input0 * input1 + input2``.

        Input 1 is multiplied by input 0 with ``axis=0`` broadcasting, and
        input 2 is then added. The sum is written to the node's own output
        name so that downstream layers can consume it.
        """
        # The body reads three inputs; the previous `== 1` check always failed.
        assert len(node.inputs) == 3 and len(
            node.outputs
        ) == 1, "The count of Axpy node\'s input is not 3 or output is not 1."
        input0 = self.graph.get_input_node(node, idx=0, copy=True)
        input1 = self.graph.get_input_node(node, idx=1, copy=True)
        input2 = self.graph.get_input_node(node, idx=2, copy=True)
        mul_name = node.layer_name + "_mul"
        self.paddle_graph.add_layer(
            "paddle.multiply",
            inputs={"x": input1.name,
                    "y": input0.name},
            outputs=[mul_name],
            axis=0)
        # Previously this wrote to `<name>_mul`, so the node's output was never produced.
        self.paddle_graph.add_layer(
            "paddle.add",
            inputs={"x": mul_name,
                    "y": input2.name},
            outputs=[node.layer_name])
        

    def Crop(self, node):
        """Map a Caffe Crop layer onto ``paddle.crop``.

        Input 0 is cropped and input 1 is the reference whose shape is matched.
        Dimensions before ``crop_param.axis`` keep the size of input 0, and
        dimensions from ``axis`` onward take the reference's size.
        ``crop_param.offset`` may hold a single value, which is applied to every
        cropped dimension, or one value per dimension from ``axis`` onward.
        Uncropped dimensions use offset 0.
        """
        assert len(
            node.inputs) == 2, "The count of Crop node\'s input is not 2."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.crop_param
        axis = params.axis
        input_shape = node.in_shapes[0]
        reference_shape = node.in_shapes[1]
        if axis < 0:
            axis += len(input_shape)
        num_cropped = len(input_shape) - axis
        offset_real = [0] * len(input_shape)
        if hasattr(params, "offset") and len(params.offset) > 0:
            offset = list(params.offset)
            if len(offset) == 1:
                # Caffe broadcasts a single offset to all cropped dimensions.
                offset = offset * num_cropped
            assert num_cropped == len(
                offset), "invalid offset[%s] in crop layer" % (
                    str(list(params.offset)))
            offset_real = [0] * axis + offset
        # Only dimensions from `axis` onward follow the reference blob.
        crop_shape = list(input_shape[:axis]) + list(reference_shape[axis:])
        self.paddle_graph.add_layer(
            "paddle.crop",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=crop_shape,
            offsets=list(offset_real))

    def Flatten(self, node):
        """Map a Caffe Flatten layer onto ``paddle.reshape``.

        The target is the node's first inferred output shape.
        """
        # The message previously named "DetectionOutput"; it now names this layer.
        assert len(
            node.inputs) == 1, "The count of Flatten node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.reshape",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            shape=node.out_shapes[0])
S
SunAhong1993 已提交
874 875 876 877

    def Power(self, node):
        """Map a Caffe Power layer, computing ``(scale * x + shift) ** power``.

        ``paddle.scale`` applies the scale first and then adds the bias
        (``bias_after_scale=True``). ``paddle.pow`` then raises the result
        to ``power_param.power`` in place.
        """
        # The message previously named "Permute"; it now names this layer.
        assert len(
            node.inputs) == 1, "The count of Power node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.power_param
        layer_attrs = {
            'scale': params.scale,
            'bias': params.shift,
            'bias_after_scale': True
        }
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.pow",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            exponent=params.power)

    def Reduction(self, node):
        """Map a Caffe Reduction layer onto Paddle reduce ops.

        All dimensions from ``reduction_param.axis`` to the end are reduced
        (without keeping them). The operations are 1=SUM, 2=ASUM (sum of
        absolute values), 3=SUMSQ (sum of squares) and 4=MEAN. The result is
        then multiplied by ``reduction_param.coeff``.
        """
        assert len(
            node.inputs) == 1, "The count of Reduction node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.reduction_param
        operation = params.operation
        axis = params.axis
        coeff = params.coeff
        assert operation >= 1 and operation <= 4, "reduction reduction [%s] error" % (
            operation)
        input_len = len(node.in_shapes[0])
        if axis < 0:
            # Caffe canonicalizes a negative axis as axis + num_axes
            # (the previous "+ 1" made axis=-1 reduce over no dimension).
            axis += input_len
        # Paddle 2.0 reduce ops take `x`, `axis` and `keepdim`.
        layer_attrs = {
            "axis": list(range(input_len))[axis:],
            "keepdim": False,
        }
        if operation == 1:  # SUM
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        elif operation == 2:  # ASUM
            self.paddle_graph.add_layer(
                "paddle.abs",
                inputs={"x": input.name},
                outputs=[node.layer_name])
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        elif operation == 3:  # SUMSQ
            self.paddle_graph.add_layer(
                "paddle.pow",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                exponent=2.0)
            self.paddle_graph.add_layer(
                "paddle.sum",
                inputs={"x": node.layer_name},
                outputs=[node.layer_name],
                **layer_attrs)
        else:  # MEAN
            self.paddle_graph.add_layer(
                "paddle.mean",
                inputs={"x": input.name},
                outputs=[node.layer_name],
                **layer_attrs)
        self.paddle_graph.add_layer(
            "paddle.scale",
            inputs={"x": node.layer_name},
            outputs=[node.layer_name],
            scale=coeff)
        
    def DetectionOutput(self, node):
        """Map a DetectionOutput layer onto ``custom_layer:DetectionOutput``.

        The three inputs are forwarded as ``x0``, ``x1`` and ``x2``. For input
        1 the graph is walked upward, through each node's first input, until a
        Softmax or Sigmoid layer is found, and that layer's own input is used.
        NMS settings come from ``nms_param``. Defaults (threshold 0.3,
        top_k 10, eta 1.0) are used when ``nms_param`` is absent.
        """
        detection_output_name = name_generator("detection_output", self.nn_name2id)
        layer_outputs = [detection_output_name, node.layer_name]
        assert len(
            node.inputs) == 3, "The count of DetectionOutput node\'s input is not 3."
        inputs_dict = dict()
        for i in range(len(node.inputs)):
            src = self.graph.get_input_node(node, idx=i, copy=True)
            if i == 1:
                while src is not None \
                      and src.layer_type != 'Softmax' \
                      and src.layer_type != 'Sigmoid':
                    src = self.graph.get_input_node(src, idx=0, copy=True)
                assert src is not None, 'This kind of DetectionOutput is not supported!'
                src = self.graph.get_input_node(src, idx=0, copy=True)
            inputs_dict["x{}".format(i)] = src.name
        params = node.layer.detection_output_param
        nms_param = params.nms_param
        # Check for a missing nms_param *before* reading its fields; the old
        # code dereferenced it first, which made the fallback unreachable.
        if nms_param is None:
            nms_param_dict = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
        else:
            nms_param_dict = {
                "nms_threshold": nms_param.nms_threshold,
                "top_k": nms_param.top_k,
                "eta": nms_param.eta,
            }
        layer_attrs = {
            "background_label": params.background_label_id,
            "nms_threshold": nms_param_dict["nms_threshold"],
            "nms_top_k": nms_param_dict["top_k"],
            "keep_top_k": params.keep_top_k,
            "score_threshold": params.confidence_threshold,
            "nms_eta": nms_param_dict["eta"]}
        self.paddle_graph.add_layer(
            kernel="custom_layer:DetectionOutput",
            inputs=inputs_dict,
            outputs=layer_outputs,
            **layer_attrs)
                    
    def Normalize(self, node):
        """Map a Normalize layer onto ``custom_layer:Normalize``.

        The scale is stored in ``self.params`` under ``<layer>.scale``. It is
        taken from the node's single blob via ``_adjust_parameters``. When the
        node has no single blob, a zero array is used instead, shaped ``[1]``
        if ``channel_shared`` else ``[1, 1, 1, C]``. The custom layer receives
        ``axis=-1`` when the scale is channel-shared and ``axis=1`` otherwise.
        """
        normalize_name = name_generator("normalize", self.nn_name2id)
        layer_outputs = [normalize_name, node.layer_name]
        assert len(
            node.inputs) == 1, "The count of Normalize node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.norm_param
        scale_name = node.layer_name + ".scale"
        # Fixed typos: `self.parmas` / `self.pd_pdgraph` raised AttributeError.
        if node.data is None or len(node.data) != 1:
            print(
                "The parameter of {} (type is {}) is not set. So we set the parameters as 0"
                .format(node.layer_name, node.layer_type))
            self.params[scale_name] = \
                np.zeros([1] if params.channel_shared else [1, 1, 1, node.in_shapes[0][1]]).astype("float32")
        else:
            self.params[scale_name] = _adjust_parameters(node)[0]

        layer_attrs = {
            "axis": -1 if params.channel_shared else 1,
            "param_name": scale_name,
            "param_shape": self.params[scale_name].shape}
        self.paddle_graph.add_layer(
            "custom_layer:Normalize",
            inputs={"x": input.name},
            outputs=layer_outputs,
            **layer_attrs)
S
SunAhong1993 已提交
1039 1040 1041 1042
        
    def Permute(self, node):
        """Translate a Caffe Permute layer into ``paddle.transpose``.

        The permutation is copied from ``permute_param.order``.
        """
        assert len(
            node.inputs) == 1, "The count of Permute node\'s input is not 1."
        src = self.graph.get_input_node(node, idx=0, copy=True)
        perm = list(node.layer.permute_param.order)
        self.paddle_graph.add_layer(
            "paddle.transpose",
            inputs={"x": src.name},
            outputs=[node.layer_name],
            perm=perm)
        
    def PriorBox(self, node):
        """Translate a PriorBox layer into ``custom_layer:PriorBox``.

        The two inputs are forwarded as ``x0``/``x1``. A scalar ``step`` is
        expanded to ``(step, step)``, while a list or tuple ``step`` is passed
        as a tuple. The remaining attributes are copied from
        ``prior_box_param``.
        """
        priorbox_name = name_generator("priorbox", self.nn_name2id)
        layer_outputs = [priorbox_name, node.layer_name]
        assert len(
            node.inputs) == 2, "The count of PriorBox node\'s input is not 2."
        first = self.graph.get_input_node(node, idx=0, copy=True)
        second = self.graph.get_input_node(node, idx=1, copy=True)
        inputs_dict = {"x0": first.name, "x1": second.name}

        pb = node.layer.prior_box_param
        if type(pb.step) in (list, tuple):
            steps = tuple(pb.step)
        else:
            steps = (pb.step, pb.step)
        self.paddle_graph.add_layer(
            "custom_layer:PriorBox",
            inputs=inputs_dict,
            outputs=layer_outputs,
            min_sizes=pb.min_size,
            max_sizes=pb.max_size,
            aspect_ratios=pb.aspect_ratio,
            variance=pb.variance,
            flip=pb.flip,
            clip=pb.clip,
            steps=steps,
            offset=pb.offset,
            min_max_aspect_ratios_order=True)
S
SunAhong1993 已提交
1082
        
S
SunAhong1993 已提交
1083
    def ReLU6(self, node):
        """Translate a ReLU6 layer into a ``paddle.nn.ReLU6`` module."""
        module_name = name_generator("relu6", self.nn_name2id)
        outputs = [module_name, node.layer_name]
        assert len(
            node.inputs) == 1, "The count of RelU6 node\'s input is not 1."
        src = self.graph.get_input_node(node, idx=0, copy=True)
        self.paddle_graph.add_layer(
            "paddle.nn.ReLU6", inputs={"input": src.name}, outputs=outputs)
        
    def ROIPooling(self, node):
        """Translate an ROIPooling layer into ``custom_layer:ROIPooling``.

        The two inputs are forwarded as ``x0``/``x1``. The pooled size and
        spatial scale are copied from ``roi_pooling_param``.
        """
        roipooling_name = name_generator("roipooling", self.nn_name2id)
        layer_outputs = [roipooling_name, node.layer_name]
        assert len(
            node.inputs) == 2, "The count of ROIPooling node\'s input is not 2."
        first = self.graph.get_input_node(node, idx=0, copy=True)
        second = self.graph.get_input_node(node, idx=1, copy=True)
        roi_param = node.layer.roi_pooling_param
        self.paddle_graph.add_layer(
            "custom_layer:ROIPooling",
            inputs={"x0": first.name, "x1": second.name},
            outputs=layer_outputs,
            pooled_height=roi_param.pooled_h,
            pooled_width=roi_param.pooled_w,
            spatial_scale=roi_param.spatial_scale)
        
    def ShuffleChannel(self, node):
        """Translate a ShuffleChannel layer into ``fluid.layers.shuffle_channel``.

        The group count is copied from ``shuffle_channel_param.group``.
        """
        assert len(
            node.inputs) == 1, "The count of ShuffleChannel node\'s input is not 1."
        src = self.graph.get_input_node(node, idx=0, copy=True)
        group_count = node.layer.shuffle_channel_param.group
        self.paddle_graph.add_layer(
            "paddle.fluid.layers.shuffle_channel",
            inputs={"x": src.name},
            outputs=[node.layer_name],
            group=group_count)
        
    def Upsample(self, node):
        """Map an Upsample layer onto ``paddle.nn.functional.interpolate``.

        Nearest-neighbour interpolation is used with ``align_corners=False``.
        The scale factor comes from ``upsample_param.scale``.
        """
        assert len(
            node.inputs) == 1, "The count of Upsample node\'s input is not 1."
        input = self.graph.get_input_node(node, idx=0, copy=True)
        params = node.layer.upsample_param
        layer_attrs = {
            "align_corners": False,
            "scale_factor": params.scale,
            "mode": "nearest"}
        # Fixed the misspelled kernel path ("functioanl"). In Paddle 2.0 the
        # tensor argument of interpolate is named `x`.
        self.paddle_graph.add_layer(
            "paddle.nn.functional.interpolate",
            inputs={"x": input.name},
            outputs=[node.layer_name],
            **layer_attrs)
    
    def Select(self, node):
        """Translate a Select layer into ``custom_layer:Select``.

        The custom layer receives the input shape, the slice points
        (``select_param.slice_point``) and the axis (``select_param.axis``).
        """
        select_name = name_generator("select", self.nn_name2id)
        layer_outputs = [select_name, node.layer_name]
        assert len(
            node.inputs) == 1, "The count of Select node\'s input is not 1."
        src = self.graph.get_input_node(node, idx=0, copy=True)
        in_shape = node.in_shapes[0]
        select_param = node.layer.select_param
        self.paddle_graph.add_layer(
            "custom_layer:Select",
            inputs={"x": src.name},
            outputs=layer_outputs,
            input_shape=in_shape,
            point=select_param.slice_point,
            axis=select_param.axis)
        

S
SunAhong1993 已提交
1163
    
S
SunAhong1993 已提交
1164