提交 0073f078 编写于 作者: L lvmengsi

Merge branch 'fix_nas' into 'develop'

update mobilenetv2 for seg

See merge request !51
......@@ -114,38 +114,57 @@ class MobileNetV2Space(SearchSpaceBase):
if tokens is None:
tokens = self.init_tokens()
bottleneck_params_list = []
self.bottleneck_params_list = []
if self.block_num >= 1:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num4[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append(
self.bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input, end_points=None, decode_points=None):
def _modify_bottle_params(output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
for i, layer_setting in enumerate(self.bottleneck_params_list):
t, c, n, s, ks = layer_setting
stride = stride * s
if stride > output_stride:
s = 1
self.bottleneck_params_list[i] = (t, c, n, s, ks)
def net_arch(input,
end_points=None,
decode_points=None,
output_stride=None):
_modify_bottle_params(output_stride)
decode_ends = dict()
def check_points(count, points):
......@@ -177,10 +196,11 @@ class MobileNetV2Space(SearchSpaceBase):
# bottleneck sequences
i = 1
in_c = int(32 * self.scale)
for layer_setting in bottleneck_params_list:
for layer_setting in self.bottleneck_params_list:
t, c, n, s, k = layer_setting
i += 1
input = self._invresi_blocks(
#print(input)
input, depthwise_output = self._invresi_blocks(
input=input,
in_c=in_c,
t=t,
......@@ -190,8 +210,9 @@ class MobileNetV2Space(SearchSpaceBase):
k=k,
name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale)
layer_count += n
layer_count += 1
### decode_points and end_points means block num
if check_points(layer_count, decode_points):
decode_ends[layer_count] = depthwise_output
......@@ -320,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase):
Returns:
Variable, layers output.
"""
first_block = self._inverted_residual_unit(
first_block, depthwise_output = self._inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册