提交 00ae4353 编写于 作者: L lvmengsi

Merge branch 'fix_nas' into 'develop'

update mobilenetv2 space

See merge request !50
...@@ -145,7 +145,18 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -145,7 +145,18 @@ class MobileNetV2Space(SearchSpaceBase):
(self.multiply[tokens[25]], self.filter_num6[tokens[26]], (self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input): def net_arch(input, end_points=None, decode_points=None):
decode_ends = dict()
def check_points(count, points):
if points is None:
return False
else:
if isinstance(points, list):
return (True if count in points else False)
else:
return (True if count == points else False)
#conv1 #conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic. # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input = conv_bn_layer( input = conv_bn_layer(
...@@ -156,6 +167,12 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -156,6 +167,12 @@ class MobileNetV2Space(SearchSpaceBase):
padding='SAME', padding='SAME',
act='relu6', act='relu6',
name='mobilenetv2_conv1_1') name='mobilenetv2_conv1_1')
layer_count = 1
if check_points(layer_count, decode_points):
decode_ends[layer_count] = input
if check_points(layer_count, end_points):
return input, decode_ends
# bottleneck sequences # bottleneck sequences
i = 1 i = 1
...@@ -173,6 +190,13 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -173,6 +190,13 @@ class MobileNetV2Space(SearchSpaceBase):
k=k, k=k,
name='mobilenetv2_conv' + str(i)) name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale) in_c = int(c * self.scale)
layer_count += n
if check_points(layer_count, decode_points):
decode_ends[layer_count] = depthwise_output
if check_points(layer_count, end_points):
return input, decode_ends
# last conv # last conv
input = conv_bn_layer( input = conv_bn_layer(
...@@ -266,6 +290,8 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -266,6 +290,8 @@ class MobileNetV2Space(SearchSpaceBase):
name=name + '_dwise', name=name + '_dwise',
use_cudnn=False) use_cudnn=False)
depthwise_output = bottleneck_conv
linear_out = conv_bn_layer( linear_out = conv_bn_layer(
input=bottleneck_conv, input=bottleneck_conv,
num_filters=num_filters, num_filters=num_filters,
...@@ -278,7 +304,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -278,7 +304,7 @@ class MobileNetV2Space(SearchSpaceBase):
out = linear_out out = linear_out
if ifshortcut: if ifshortcut:
out = self._shortcut(input=input, data_residual=out) out = self._shortcut(input=input, data_residual=out)
return out return out, depthwise_output
def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None): def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
"""Build inverted residual blocks. """Build inverted residual blocks.
...@@ -308,7 +334,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -308,7 +334,7 @@ class MobileNetV2Space(SearchSpaceBase):
last_c = c last_c = c
for i in range(1, n): for i in range(1, n):
last_residual_block = self._inverted_residual_unit( last_residual_block, depthwise_output = self._inverted_residual_unit(
input=last_residual_block, input=last_residual_block,
num_in_filter=last_c, num_in_filter=last_c,
num_filters=c, num_filters=c,
...@@ -317,4 +343,4 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -317,4 +343,4 @@ class MobileNetV2Space(SearchSpaceBase):
filter_size=k, filter_size=k,
expansion_factor=t, expansion_factor=t,
name=name + '_' + str(i + 1)) name=name + '_' + str(i + 1))
return last_residual_block return last_residual_block, depthwise_output
...@@ -19,16 +19,19 @@ class SearchSpaceBase(object): ...@@ -19,16 +19,19 @@ class SearchSpaceBase(object):
"""Controller for Neural Architecture Search. """Controller for Neural Architecture Search.
""" """
def __init__(self, def __init__(self, input_size, output_size, block_num, block_mask, *args):
input_size, """init model config
output_size, """
block_num,
block_mask=None,
*argss):
self.input_size = input_size self.input_size = input_size
self.output_size = output_size self.output_size = output_size
self.block_num = block_num self.block_num = block_num
self.block_mask = block_mask self.block_mask = block_mask
if self.block_mask is not None:
assert isinstance(self.block_mask,
list), 'Block_mask must be a list.'
print(
"If block_mask is NOT None, we will use block_mask as major configs!"
)
def init_tokens(self): def init_tokens(self):
"""Get init tokens in search space. """Get init tokens in search space.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册