resnet.py 5.7 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
C
update  
ceci3 已提交
25
from .utils import check_points
C
ceci3 已提交
26 27 28

__all__ = ["ResNetSpace"]

W
wanghaoshuang 已提交
29

C
ceci3 已提交
30 31
@SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase):
    """
    Search space for a ResNet-style network built from four bottleneck stages.

    A token list holds eight indices laid out as
    ``[filter1, repeat1, filter2, repeat2, filter3, repeat3, filter4,
    repeat4]``. Each ``filterN`` index picks a channel count from
    ``self.filter_numN`` and each ``repeatN`` index picks, from
    ``self.repeatN``, how many bottleneck blocks stage N stacks.
    """

    def __init__(self, input_size, output_size, block_num, block_mask=None):
        """
        Forward the constructor arguments to ``SearchSpaceBase`` and create
        the candidate tables that the tokens index into.
        """
        super(ResNetSpace, self).__init__(input_size, output_size, block_num,
                                          block_mask)
        # Candidate convolution channel counts, one table per stage.
        self.filter_num1 = np.array([48, 64, 96, 128, 160, 192, 224])  # 7 choices
        self.filter_num2 = np.array([64, 96, 128, 160, 192, 256, 320])  # 7 choices
        self.filter_num3 = np.array([128, 160, 192, 256, 320, 384])  # 6 choices
        self.filter_num4 = np.array([192, 256, 384, 512, 640])  # 5 choices
        # Candidate block repeat counts (network depth), one table per stage.
        self.repeat1 = [2, 3, 4, 5, 6]  # 5 choices
        self.repeat2 = [2, 3, 4, 5, 6, 7]  # 6 choices
        self.repeat3 = [2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24]  # 13 choices
        self.repeat4 = [2, 3, 4, 5, 6, 7]  # 6 choices

    def init_tokens(self):
        """
        Return the initial tokens: eight zeros, i.e. the first candidate of
        every filter and repeat table.
        """
        return [0] * 8

    def range_table(self):
        """
        Return the number of candidates available at each token position,
        in token order (filter1, repeat1, ..., filter4, repeat4).
        """
        candidate_tables = (
            self.filter_num1, self.repeat1,
            self.filter_num2, self.repeat2,
            self.filter_num3, self.repeat3,
            self.filter_num4, self.repeat4,
        )
        return [len(table) for table in candidate_tables]

    def token2arch(self, tokens=None):
        """
        Decode ``tokens`` and return a function that builds the network.

        Args:
            tokens: eight indices in the layout described on the class;
                when None, the result of ``init_tokens()`` is used.

        Returns:
            The ``net_arch`` function closed over the decoded per-stage
            channel counts and repeat counts.
        """
        if tokens is None:
            tokens = self.init_tokens()

        # (filter table, repeat table) for each of the four stages; stage k
        # reads tokens[2 * k] and tokens[2 * k + 1].
        stage_tables = (
            (self.filter_num1, self.repeat1),
            (self.filter_num2, self.repeat2),
            (self.filter_num3, self.repeat3),
            (self.filter_num4, self.repeat4),
        )
        num_filters = []
        depth = []
        for stage, (filter_table, repeat_table) in enumerate(stage_tables):
            num_filters.append(filter_table[tokens[2 * stage]])
            depth.append(repeat_table[tokens[2 * stage + 1]])

        # The stem convolution uses the channel count chosen for stage 1.
        stem_filters = num_filters[0]

        def net_arch(input, return_block=None, end_points=None):
            """
            Build the network on ``input``.

            ``return_block`` and ``end_points`` are handed to
            ``check_points`` together with a position counter; that helper
            decides whether the position is selected.

            Returns:
                ``(conv, decode_ends)`` as soon as a position selected by
                ``end_points`` is reached, otherwise only the final
                ``conv``. Note that ``decode_ends`` is therefore returned
                only when an ``end_points`` position matches.
            """
            decode_ends = {}

            conv = conv_bn_layer(
                input=input,
                num_filters=stem_filters,
                filter_size=5,
                stride=2,
                act='relu',
                name='resnet_conv0')

            # Starts at 1 after the stride-2 stem and grows by one each time
            # a stride-2 bottleneck is about to be added.
            layer_count = 1
            for block, (repeat, channels) in enumerate(
                    zip(depth, num_filters)):
                for i in range(repeat):
                    # Only the first bottleneck of stages 2-4 downsamples.
                    downsample = i == 0 and block != 0
                    if downsample:
                        layer_count += 1
                    stride = 2 if downsample else 1

                    # Checks run on the feature map *before* this bottleneck.
                    position = layer_count - 1
                    if check_points(position, return_block):
                        decode_ends[position] = conv

                    if check_points(position, end_points):
                        return conv, decode_ends

                    conv = self._bottleneck_block(
                        input=conv,
                        num_filters=channels,
                        stride=stride,
                        name='resnet_depth{}_block{}'.format(i, block))

            if check_points(layer_count, end_points):
                return conv, decode_ends

            return conv

        return net_arch

    def _shortcut(self, input, ch_out, stride, name=None):
        """
        Return the residual branch for ``input``.

        ``input`` is returned untouched when ``input.shape[1]`` already
        equals ``ch_out`` and ``stride`` is 1; otherwise it is projected
        with a 1x1 ``conv_bn_layer``. ``name`` is concatenated with a
        suffix on the projection path, so it must be a string there.
        """
        ch_in = input.shape[1]
        if ch_in == ch_out and stride == 1:
            return input
        return conv_bn_layer(
            input=input,
            num_filters=ch_out,
            filter_size=1,
            stride=stride,
            name=name + '_conv')

    def _bottleneck_block(self, input, num_filters, stride, name=None):
        """
        Build one bottleneck block: 1x1 conv, 3x3 conv (carrying ``stride``),
        1x1 conv to ``num_filters * 4`` channels, then a ReLU-activated
        element-wise sum with the shortcut branch.

        ``name`` is concatenated with suffixes to name every layer, so it
        must be a string.
        """
        # Both the last convolution and the shortcut use this width.
        expanded_filters = num_filters * 4

        reduce_conv = conv_bn_layer(
            input=input,
            filter_size=1,
            num_filters=num_filters,
            act='relu',
            name=name + '_bottleneck_conv0')
        spatial_conv = conv_bn_layer(
            input=reduce_conv,
            filter_size=3,
            num_filters=num_filters,
            stride=stride,
            act='relu',
            name=name + '_bottleneck_conv1')
        expand_conv = conv_bn_layer(
            input=spatial_conv,
            filter_size=1,
            num_filters=expanded_filters,
            act=None,
            name=name + '_bottleneck_conv2')

        short = self._shortcut(
            input, expanded_filters, stride, name=name + '_shortcut')

        return fluid.layers.elementwise_add(
            x=short,
            y=expand_conv,
            act='relu',
            name=name + '_bottleneck_add')