提交 907c7473 编写于 作者: C ceci3

update mobilenetv2

上级 b4314f6d
......@@ -70,7 +70,13 @@ class MobileNetV2Space(SearchSpaceBase):
4, 7, 2, 0, # 6, 160, 3
4, 9, 0, 0] # 6, 320, 1
# yapf: enable
return init_token_base#[:self.tokens_lens]
if self.block_num < 5:
self.token_len = 1 + (self.block_num - 1) * 4
else:
self.token_len = 1 + (self.block_num + 2 * (self.block_num - 5)) * 4
return init_token_base[:self.token_len]
def range_table(self):
"""
......@@ -87,50 +93,37 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2,
5, 12, 6, 2]
# yapf: enable
return range_table_base#[:self.tokens_lens]
return range_table_base[:self.token_len]
def token2arch(self, tokens=None):
"""
return net_arch function
"""
if tokens is None:
tokens = self.init_tokens()
base_bottleneck_params_list = [
(1, self.head_num[tokens[0]], 1, 1, 3),
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]),
(self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]),
(self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]),
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]),
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]),
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]),
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]),
]
assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(
self.block_num)
# the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1,
# otherwise, add layers to params_list directly.
bottleneck_params_list = []
for param_list in base_bottleneck_params_list:
if param_list[3] == 1:
bottleneck_params_list.append(param_list)
else:
if self.block_num > 1:
bottleneck_params_list.append(param_list)
self.block_num -= 1
else:
break
if tokens is None:
tokens = self.init_tokens()
self.tokens_lens = 1 + (len(bottleneck_params_list) - 1) * 4
bottleneck_params_list = []
if self.block_num >= 1: bottleneck_params_list.append((1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2: bottleneck_params_list.append((self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3: bottleneck_params_list.append((self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4: bottleneck_params_list.append((self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5:
bottleneck_params_list.append((self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append((self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append((self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append((self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input):
#conv1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册