提交 51c23929 编写于 作者: C ceci3

fix

上级 e598cb9c
...@@ -20,6 +20,10 @@ import mobilenetv1 ...@@ -20,6 +20,10 @@ import mobilenetv1
from .mobilenetv1 import * from .mobilenetv1 import *
import resnet import resnet
from .resnet import * from .resnet import *
import resnet_block
from .resnet_block import *
import inception_block
from .inception_block import *
import search_space_registry import search_space_registry
from search_space_registry import * from search_space_registry import *
import search_space_factory import search_space_factory
......
...@@ -29,6 +29,7 @@ __all__ = ["CombineSearchSpace"] ...@@ -29,6 +29,7 @@ __all__ = ["CombineSearchSpace"]
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
class CombineSearchSpace(object): class CombineSearchSpace(object):
""" """
Combine Search Space. Combine Search Space.
...@@ -42,11 +43,13 @@ class CombineSearchSpace(object): ...@@ -42,11 +43,13 @@ class CombineSearchSpace(object):
for config_list in config_lists: for config_list in config_lists:
if isinstance(config_list, tuple): if isinstance(config_list, tuple):
key, config = config_list key, config = config_list
if isinstance(config_list, str): elif isinstance(config_list, str):
key = config_list key = config_list
config = None config = None
else: else:
raise NotImplementedError('the type of config is Error!!! Please check the config information. Receive the type of config is {}'.format(type(config_list))) raise NotImplementedError(
'the type of config is Error!!! Please check the config information. Receive the type of config is {}'.
format(type(config_list)))
self.spaces.append(self._get_single_search_space(key, config)) self.spaces.append(self._get_single_search_space(key, config))
self.init_tokens() self.init_tokens()
...@@ -61,6 +64,8 @@ class CombineSearchSpace(object): ...@@ -61,6 +64,8 @@ class CombineSearchSpace(object):
model space(class) model space(class)
""" """
cls = SEARCHSPACE.get(key) cls = SEARCHSPACE.get(key)
assert cls != None, '{} is NOT a correct space, the space we support is {}'.format(
key, SEARCHSPACE)
if config is None: if config is None:
block_mask = None block_mask = None
...@@ -69,21 +74,27 @@ class CombineSearchSpace(object): ...@@ -69,21 +74,27 @@ class CombineSearchSpace(object):
block_num = None block_num = None
else: else:
if 'Block' not in cls.__name__: if 'Block' not in cls.__name__:
_logger.warn('if space is not a Block space, config is useless, current space is {}'.format(cls.__name__)) _logger.warn(
'if space is not a Block space, config is useless, current space is {}'.
block_mask = config['block_mask'] if 'block_mask' in config else None format(cls.__name__))
input_size = config['input_size'] if 'input_size' in config else None
output_size = config['output_size'] if 'output_size' in config else None block_mask = config[
'block_mask'] if 'block_mask' in config else None
input_size = config[
'input_size'] if 'input_size' in config else None
output_size = config[
'output_size'] if 'output_size' in config else None
block_num = config['block_num'] if 'block_num' in config else None block_num = config['block_num'] if 'block_num' in config else None
if 'Block' in cls.__name__: if 'Block' in cls.__name__:
if block_mask == None and (self.block_num == None or self.input_size == None or self.output_size == None): if block_mask == None and (self.block_num == None or
raise NotImplementedError("block_mask or (block num and input_size and output_size) can NOT be None at the same time in Block SPACE!") self.input_size == None or
self.output_size == None):
space = cls(input_size, raise NotImplementedError(
output_size, "block_mask or (block num and input_size and output_size) can NOT be None at the same time in Block SPACE!"
block_num, )
block_mask=block_mask)
space = cls(input_size, output_size, block_num, block_mask=block_mask)
return space return space
def init_tokens(self): def init_tokens(self):
......
...@@ -81,7 +81,7 @@ class InceptionABlockSpace(SearchSpaceBase): ...@@ -81,7 +81,7 @@ class InceptionABlockSpace(SearchSpaceBase):
range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.filter_num))
range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.filter_num))
range_table_base.append(len(self.k_size)) range_table_base.append(len(self.k_size))
range_table_base.append(len(self.pooltype)) range_table_base.append(len(self.pool_type))
return range_table_base return range_table_base
...@@ -97,51 +97,69 @@ class InceptionABlockSpace(SearchSpaceBase): ...@@ -97,51 +97,69 @@ class InceptionABlockSpace(SearchSpaceBase):
if self.block_mask != None: if self.block_mask != None:
for i in range(len(self.block_mask)): for i in range(len(self.block_mask)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[i * 9], self.filter_num[i * 9 + 1], (self.filter_num[tokens[i * 9]],
self.filter_num[i * 9 + 2], self.filter_num[i * 9 + 3], self.filter_num[tokens[i * 9 + 1]],
self.filter_num[i * 9 + 4], self.filter_num[i * 9 + 5], self.filter_num[tokens[i * 9 + 2]],
self.filter_num[i * 9 + 6], self.k_size[i * 9 + 7], 2 if self.filter_num[tokens[i * 9 + 3]],
self.block_mask == 1 else 1, self.pool_type[i * 9 + 8])) self.filter_num[tokens[i * 9 + 4]],
self.filter_num[tokens[i * 9 + 5]],
self.filter_num[tokens[i * 9 + 6]],
self.k_size[tokens[i * 9 + 7]], 2 if self.block_mask == 1
else 1, self.pool_type[tokens[i * 9 + 8]]))
else: else:
repeat_num = self.block_num / self.downsample_num repeat_num = self.block_num / self.downsample_num
num_minus = self.block_num % self.downsample_num num_minus = self.block_num % self.downsample_num
### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers ### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers
for i in range(self.downsample_num): for i in range(self.downsample_num):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[i * 9], self.filter_num[i * 9 + 1], (self.filter_num[tokens[i * 9]],
self.filter_num[i * 9 + 2], self.filter_num[i * 9 + 3], self.filter_num[tokens[i * 9 + 1]],
self.filter_num[i * 9 + 4], self.filter_num[i * 9 + 5], self.filter_num[tokens[i * 9 + 2]],
self.filter_num[i * 9 + 6], self.k_size[i * 9 + 7], 2, self.filter_num[tokens[i * 9 + 3]],
self.pool_type[i * 9 + 8])) self.filter_num[tokens[i * 9 + 4]],
self.filter_num[tokens[i * 9 + 5]],
self.filter_num[tokens[i * 9 + 6]],
self.k_size[tokens[i * 9 + 7]], 2,
self.pool_type[tokens[i * 9 + 8]]))
### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block ### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block
for k in range(repeat_num - 1): for k in range(repeat_num - 1):
kk = k * self.downsample_num + i kk = k * self.downsample_num + i
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[kk * 9], self.filter_num[kk * 9 + 1], (self.filter_num[tokens[kk * 9]],
self.filter_num[kk * 9 + 2], self.filter_num[tokens[kk * 9 + 1]],
self.filter_num[kk * 9 + 3], self.filter_num[tokens[kk * 9 + 2]],
self.filter_num[kk * 9 + 4], self.filter_num[tokens[kk * 9 + 3]],
self.filter_num[kk * 9 + 5], self.filter_num[tokens[kk * 9 + 4]],
self.filter_num[kk * 9 + 6], self.k_size[kk * 9 + 7], self.filter_num[tokens[kk * 9 + 5]],
1, self.pool_type[kk * 9 + 8])) self.filter_num[tokens[kk * 9 + 6]],
self.k_size[tokens[kk * 9 + 7]], 1,
self.pool_type[tokens[kk * 9 + 8]]))
if self.downsample_num - i <= num_minus: if self.downsample_num - i <= num_minus:
j = self.downsample_num * repeat_num + i j = self.downsample_num * repeat_num + i
self.bottleneck_params_list.append(( self.bottleneck_params_list.append(
self.filter_num[j * 9], self.filter_num[j * 9 + 1], (self.filter_num[tokens[j * 9]],
self.filter_num[j * 9 + 2], self.filter_num[j * 9 + 3], self.filter_num[tokens[j * 9 + 1]],
self.filter_num[j * 9 + 4], self.filter_num[j * 9 + 5], self.filter_num[tokens[j * 9 + 2]],
self.filter_num[j * 9 + 6], self.k_size[j * 9 + 7], 1, self.filter_num[tokens[j * 9 + 3]],
self.pool_type[j * 9 + 8])) self.filter_num[tokens[j * 9 + 4]],
self.filter_num[tokens[j * 9 + 5]],
self.filter_num[tokens[j * 9 + 6]],
self.k_size[tokens[j * 9 + 7]], 1,
self.pool_type[tokens[j * 9 + 8]]))
if self.downsample_num == 0 and self.block_num != 0: if self.downsample_num == 0 and self.block_num != 0:
for i in range(len(self.block_num)): for i in range(len(self.block_num)):
self.bottleneck_params_list.append(( self.bottleneck_params_list.append(
self.filter_num[i * 9], self.filter_num[i * 9 + 1], (self.filter_num[tokens[i * 9]],
self.filter_num[i * 9 + 2], self.filter_num[i * 9 + 3], self.filter_num[tokens[i * 9 + 1]],
self.filter_num[i * 9 + 4], self.filter_num[i * 9 + 5], self.filter_num[tokens[i * 9 + 2]],
self.filter_num[i * 9 + 6], self.k_size[i * 9 + 7], 1, self.filter_num[tokens[i * 9 + 3]],
self.pool_type[i * 9 + 8])) self.filter_num[tokens[i * 9 + 4]],
self.filter_num[tokens[i * 9 + 5]],
self.filter_num[tokens[i * 9 + 6]],
self.k_size[tokens[i * 9 + 7]], 1,
self.pool_type[tokens[i * 9 + 8]]))
def net_arch(input, return_mid_layer=False, return_block=[]): def net_arch(input, return_mid_layer=False, return_block=[]):
assert isinstance(return_block, assert isinstance(return_block,
...@@ -169,7 +187,7 @@ class InceptionABlockSpace(SearchSpaceBase): ...@@ -169,7 +187,7 @@ class InceptionABlockSpace(SearchSpaceBase):
if return_mid_layer: if return_mid_layer:
return input, mid_layer return input, mid_layer
else: else:
return input return input,
return net_arch return net_arch
...@@ -247,7 +265,7 @@ class InceptionABlockSpace(SearchSpaceBase): ...@@ -247,7 +265,7 @@ class InceptionABlockSpace(SearchSpaceBase):
@SEARCHSPACE.register @SEARCHSPACE.register
class InceptionCBlockSpace(SearchSpaceBase): class InceptionCBlockSpace(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, block_mask): def __init__(self, input_size, output_size, block_num, block_mask):
super(InceptionABlockSpace, self).__init__(input_size, output_size, super(InceptionCBlockSpace, self).__init__(input_size, output_size,
block_num, block_mask) block_num, block_mask)
if self.block_mask == None: if self.block_mask == None:
# use input_size and output_size to compute self.downsample_num # use input_size and output_size to compute self.downsample_num
...@@ -274,9 +292,9 @@ class InceptionCBlockSpace(SearchSpaceBase): ...@@ -274,9 +292,9 @@ class InceptionCBlockSpace(SearchSpaceBase):
The initial token. The initial token.
""" """
if self.block_mask != None: if self.block_mask != None:
return [0] * (len(self.block_mask) * 9) return [0] * (len(self.block_mask) * 11)
else: else:
return [0] * (self.block_num * 9) return [0] * (self.block_num * 11)
def range_table(self): def range_table(self):
""" """
...@@ -297,7 +315,7 @@ class InceptionCBlockSpace(SearchSpaceBase): ...@@ -297,7 +315,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.filter_num))
range_table_base.append(len(self.filter_num)) range_table_base.append(len(self.filter_num))
range_table_base.append(len(self.k_size)) range_table_base.append(len(self.k_size))
range_table_base.append(len(self.pooltype)) range_table_base.append(len(self.pool_type))
return range_table_base return range_table_base
...@@ -313,63 +331,79 @@ class InceptionCBlockSpace(SearchSpaceBase): ...@@ -313,63 +331,79 @@ class InceptionCBlockSpace(SearchSpaceBase):
if self.block_mask != None: if self.block_mask != None:
for i in range(len(self.block_mask)): for i in range(len(self.block_mask)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[i * 11], self.filter_num[i * 11 + 1], (self.filter_num[tokens[i * 11]],
self.filter_num[i * 11 + 2], self.filter_num[i * 11 + 3], self.filter_num[tokens[i * 11 + 1]],
self.filter_num[i * 11 + 4], self.filter_num[i * 11 + 5], self.filter_num[tokens[i * 11 + 2]],
self.filter_num[i * 11 + 6], self.filter_num[i * 11 + 7], self.filter_num[tokens[i * 11 + 3]],
self.filter_num[i * 11 + 8], self.k_size[i * 11 + 9], 2 if self.filter_num[tokens[i * 11 + 4]],
self.block_mask == 1 else 1, self.pool_type[i * 11 + 10])) self.filter_num[tokens[i * 11 + 5]],
self.filter_num[tokens[i * 11 + 6]],
self.filter_num[tokens[i * 11 + 7]],
self.filter_num[tokens[i * 11 + 8]],
self.k_size[tokens[i * 11 + 9]], 2 if self.block_mask == 1
else 1, self.pool_type[tokens[i * 11 + 10]]))
else: else:
repeat_num = self.block_num / self.downsample_num repeat_num = self.block_num / self.downsample_num
num_minus = self.block_num % self.downsample_num num_minus = self.block_num % self.downsample_num
### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers ### if block_num > downsample_num, add stride=1 block at last (block_num-downsample_num) layers
for i in range(self.downsample_num): for i in range(self.downsample_num):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[i * 11], self.filter_num[i * 11 + 1], (self.filter_num[tokens[i * 11]],
self.filter_num[i * 11 + 2], self.filter_num[i * 11 + 3], self.filter_num[tokens[i * 11 + 1]],
self.filter_num[i * 11 + 4], self.filter_num[i * 11 + 5], self.filter_num[tokens[i * 11 + 2]],
self.filter_num[i * 11 + 6], self.filter_num[i * 11 + 7], self.filter_num[tokens[i * 11 + 3]],
self.filter_num[i * 11 + 8], self.k_size[i * 11 + 9], 2, self.filter_num[tokens[i * 11 + 4]],
self.pool_type[i * 11 + 10])) self.filter_num[tokens[i * 11 + 5]],
self.filter_num[tokens[i * 11 + 6]],
self.filter_num[tokens[i * 11 + 7]],
self.filter_num[tokens[i * 11 + 8]],
self.k_size[tokens[i * 11 + 9]], 2,
self.pool_type[tokens[i * 11 + 10]]))
### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block ### if block_num / downsample_num > 1, add (block_num / downsample_num) times stride=1 block
for k in range(repeat_num - 1): for k in range(repeat_num - 1):
kk = k * self.downsample_num + i kk = k * self.downsample_num + i
self.bottleneck_params_list.append(( self.bottleneck_params_list.append(
self.filter_num[kk * 11], self.filter_num[kk * 11 + 1], (self.filter_num[tokens[kk * 11]],
self.filter_num[kk * 11 + 2], self.filter_num[tokens[kk * 11 + 1]],
self.filter_num[kk * 11 + 3], self.filter_num[tokens[kk * 11 + 2]],
self.filter_num[kk * 11 + 4], self.filter_num[tokens[kk * 11 + 3]],
self.filter_num[kk * 11 + 5], self.filter_num[tokens[kk * 11 + 4]],
self.filter_num[kk * 11 + 6], self.filter_num[tokens[kk * 11 + 5]],
self.filter_num[kk * 11 + 7], self.filter_num[tokens[kk * 11 + 6]],
self.filter_num[kk * 11 + 8], self.k_size[kk * 11 + 9], self.filter_num[tokens[kk * 11 + 7]],
1, self.pool_type[kk * 11 + 10])) self.filter_num[tokens[kk * 11 + 8]],
self.k_size[tokens[kk * 11 + 9]], 1,
self.pool_type[tokens[kk * 11 + 10]]))
if self.downsample_num - i <= num_minus: if self.downsample_num - i <= num_minus:
j = self.downsample_num * repeat_num + i j = self.downsample_num * repeat_num + i
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[j * 11], self.filter_num[j * 11 + 1], (self.filter_num[tokens[j * 11]],
self.filter_num[j * 11 + 2], self.filter_num[tokens[j * 11 + 1]],
self.filter_num[j * 11 + 3], self.filter_num[tokens[j * 11 + 2]],
self.filter_num[j * 11 + 4], self.filter_num[tokens[j * 11 + 3]],
self.filter_num[j * 11 + 5], self.filter_num[tokens[j * 11 + 4]],
self.filter_num[j * 11 + 6], self.filter_num[tokens[j * 11 + 5]],
self.filter_num[j * 11 + 7], self.filter_num[tokens[j * 11 + 6]],
self.filter_num[j * 11 + 8], self.k_size[j * 11 + 9], self.filter_num[tokens[j * 11 + 7]],
1, self.pool_type[j * 11 + 10])) self.filter_num[tokens[j * 11 + 8]],
self.k_size[tokens[j * 11 + 9]], 1,
self.pool_type[tokens[j * 11 + 10]]))
if self.downsample_num == 0 and self.block_num != 0: if self.downsample_num == 0 and self.block_num != 0:
for i in range(len(self.block_num)): for i in range(len(self.block_num)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.filter_num[i * 11], self.filter_num[i * 11 + 1], (self.filter_num[tokens[i * 11]],
self.filter_num[i * 11 + 2], self.filter_num[tokens[i * 11 + 1]],
self.filter_num[i * 11 + 3], self.filter_num[tokens[i * 11 + 2]],
self.filter_num[i * 11 + 4], self.filter_num[tokens[i * 11 + 3]],
self.filter_num[i * 11 + 5], self.filter_num[tokens[i * 11 + 4]],
self.filter_num[i * 11 + 6], self.filter_num[tokens[i * 11 + 5]],
self.filter_num[i * 11 + 7], self.filter_num[tokens[i * 11 + 6]],
self.filter_num[i * 11 + 8], self.k_size[i * 11 + 9], self.filter_num[tokens[i * 11 + 7]],
1, self.pool_type[i * 11 + 10])) self.filter_num[tokens[i * 11 + 8]],
self.k_size[tokens[i * 11 + 9]], 1,
self.pool_type[tokens[i * 11 + 10]]))
def net_arch(input, return_mid_layer=False, return_block=[]): def net_arch(input, return_mid_layer=False, return_block=[]):
assert isinstance(return_block, assert isinstance(return_block,
...@@ -397,7 +431,7 @@ class InceptionCBlockSpace(SearchSpaceBase): ...@@ -397,7 +431,7 @@ class InceptionCBlockSpace(SearchSpaceBase):
if return_mid_layer: if return_mid_layer:
return input, mid_layer return input, mid_layer
else: else:
return input return input,
return net_arch return net_arch
......
...@@ -334,7 +334,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase): ...@@ -334,7 +334,7 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
if tokens == None: if tokens == None:
tokens = self.init_tokens() tokens = self.init_tokens()
self.bottleneck_param_list = [] self.bottleneck_params_list = []
if self.block_mask != None: if self.block_mask != None:
for i in range(len(self.block_mask)): for i in range(len(self.block_mask)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
...@@ -391,7 +391,6 @@ class MobileNetV1BlockSpace(SearchSpaceBase): ...@@ -391,7 +391,6 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
input=input, input=input,
num_filters1=filter_num1, num_filters1=filter_num1,
num_filters2=filter_num2, num_filters2=filter_num2,
num_groups=filter_num1,
stride=stride, stride=stride,
scale=self.scale, scale=self.scale,
kernel_size=kernel_size, kernel_size=kernel_size,
...@@ -408,17 +407,17 @@ class MobileNetV1BlockSpace(SearchSpaceBase): ...@@ -408,17 +407,17 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
input, input,
num_filters1, num_filters1,
num_filters2, num_filters2,
num_groups,
stride, stride,
scale, scale,
kernel_size, kernel_size,
name=None): name=None):
num_groups = input.shape[1]
depthwise_conv = conv_bn_layer( depthwise_conv = conv_bn_layer(
input=input, input=input,
filter_size=kernel_size, filter_size=kernel_size,
num_filters=int(num_filters1 * scale), num_filters=int(num_filters1 * scale),
stride=stride, stride=stride,
num_groups=int(num_groups * scale), num_groups=num_groups,
use_cudnn=False, use_cudnn=False,
name=name + '_dw') name=name + '_dw')
pointwise_conv = conv_bn_layer( pointwise_conv = conv_bn_layer(
......
...@@ -29,9 +29,9 @@ __all__ = ["ResNetBlockSpace"] ...@@ -29,9 +29,9 @@ __all__ = ["ResNetBlockSpace"]
@SEARCHSPACE.register @SEARCHSPACE.register
class ResNetBlockSpace(SearchSpaceBase): class ResNetBlockSpace(SearchSpaceBase):
def __init__(input_size, output_size, block_num, block_mask): def __init__(self, input_size, output_size, block_num, block_mask=None):
super(ResNetSpace, self).__init__(input_size, output_size, block_num, super(ResNetBlockSpace, self).__init__(input_size, output_size,
block_mask) block_num, block_mask)
# use input_size and output_size to compute self.downsample_num # use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size, self.downsample_num = compute_downsample_num(self.input_size,
self.output_size) self.output_size)
...@@ -52,6 +52,14 @@ class ResNetBlockSpace(SearchSpaceBase): ...@@ -52,6 +52,14 @@ class ResNetBlockSpace(SearchSpaceBase):
def range_table(self): def range_table(self):
range_table_base = [] range_table_base = []
if self.block_mask != None:
range_table_length = len(self.block_mask)
else:
range_table_length = self.block_mum
for i in range(range_table_length):
range_table_base.append(len(self.filter_num))
range_table_base.append(len(self.k_size))
return range_table_base return range_table_base
...@@ -63,39 +71,39 @@ class ResNetBlockSpace(SearchSpaceBase): ...@@ -63,39 +71,39 @@ class ResNetBlockSpace(SearchSpaceBase):
if self.block_mask != None: if self.block_mask != None:
for i in range(len(self.block_mask)): for i in range(len(self.block_mask)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.num_filters[tokens[i * 2]], (self.filter_num[tokens[i * 2]],
self.kernel_size[tokens[i * 2 + 1]], 2 self.k_size[tokens[i * 2 + 1]], 2
if self.block_mask[i] == 1 else 1)) if self.block_mask[i] == 1 else 1))
else: else:
repeat_num = self.block_num / self.downsample_num repeat_num = self.block_num / self.downsample_num
num_minus = self.block_num % self.downsample_num num_minus = self.block_num % self.downsample_num
for i in range(self.downsample_num): for i in range(self.downsample_num):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
self.num_filters[tokens[i * 2]], self.filter_num[tokens[i * 2]],
self.kernel_size[tokens[i * 2 + 1]], 2) self.k_size[tokens[i * 2 + 1]], 2)
for k in range(repeat_num - 1): for k in range(repeat_num - 1):
kk = k * self.downsample_num + i kk = k * self.downsample_num + i
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
self.num_filters[tokens[kk * 2]], self.filter_num[tokens[kk * 2]],
self.kernel_size[tokens[kk * 2 + 1]], 1) self.k_size[tokens[kk * 2 + 1]], 1)
if self.downsample_num - i <= num_minus: if self.downsample_num - i <= num_minus:
j = self.downsample_num * repeat_num + i j = self.downsample_num * repeat_num + i
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
self.num_filters[tokens[j * 2]], self.filter_num[tokens[j * 2]],
self.kernel_size[tokens[j * 2 + 1]], 1) self.k_size[tokens[j * 2 + 1]], 1)
if self.downsample_num == 0 and self.block_num != 0: if self.downsample_num == 0 and self.block_num != 0:
for i in range(len(self.block_num)): for i in range(len(self.block_num)):
self.bottleneck_params_list.append( self.bottleneck_params_list.append(
self.num_filters[tokens[i * 2]], self.filter_num[tokens[i * 2]],
self.kernel_size[tokens[i * 2 + 1]], 1) self.k_size[tokens[i * 2 + 1]], 1)
def net_arch(input, return_mid_layer=False, return_block=[]): def net_arch(input, return_mid_layer=False, return_block=[]):
assert isinstance(return_block, assert isinstance(return_block,
list), 'return_block must be a list.' list), 'return_block must be a list.'
layer_count = 0 layer_count = 0
mid_layer = dict() mid_layer = dict()
for layer_setting in self.bottleneck_params_list: for i, layer_setting in enumerate(self.bottleneck_params_list):
filter_num, k_size, stride = layer_setting filter_num, k_size, stride = layer_setting
if stride == 2: if stride == 2:
layer_count += 1 layer_count += 1
...@@ -138,6 +146,7 @@ class ResNetBlockSpace(SearchSpaceBase): ...@@ -138,6 +146,7 @@ class ResNetBlockSpace(SearchSpaceBase):
input=input, input=input,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
stride=1,
act='relu', act='relu',
name=name + '_bottleneck_conv0') name=name + '_bottleneck_conv0')
conv1 = conv_bn_layer( conv1 = conv_bn_layer(
...@@ -151,6 +160,7 @@ class ResNetBlockSpace(SearchSpaceBase): ...@@ -151,6 +160,7 @@ class ResNetBlockSpace(SearchSpaceBase):
input=conv1, input=conv1,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
stride=1,
act=None, act=None,
name=name + '_bottleneck_conv2') name=name + '_bottleneck_conv2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册