提交 5c922020 编写于 作者: C ceci3

update mobilenetv2 for seg

上级 661d48af
...@@ -114,38 +114,57 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -114,38 +114,57 @@ class MobileNetV2Space(SearchSpaceBase):
if tokens is None: if tokens is None:
tokens = self.init_tokens() tokens = self.init_tokens()
bottleneck_params_list = [] self.bottleneck_params_list = []
if self.block_num >= 1: if self.block_num >= 1:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3)) (1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2: if self.block_num >= 2:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3: if self.block_num >= 3:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]], (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4: if self.block_num >= 4:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]], (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5: if self.block_num >= 5:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]], (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num4[tokens[18]], (self.multiply[tokens[17]], self.filter_num4[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6: if self.block_num >= 6:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]], (self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]], (self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input, end_points=None, decode_points=None): def _modify_bottle_params(output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
for i, layer_setting in enumerate(self.bottleneck_params_list):
t, c, n, s, ks = layer_setting
stride = stride * s
if stride > output_stride:
s = 1
self.bottleneck_params_list[i] = (t, c, n, s, ks)
def net_arch(input,
end_points=None,
decode_points=None,
output_stride=None):
_modify_bottle_params(output_stride)
decode_ends = dict() decode_ends = dict()
def check_points(count, points): def check_points(count, points):
...@@ -177,10 +196,11 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -177,10 +196,11 @@ class MobileNetV2Space(SearchSpaceBase):
# bottleneck sequences # bottleneck sequences
i = 1 i = 1
in_c = int(32 * self.scale) in_c = int(32 * self.scale)
for layer_setting in bottleneck_params_list: for layer_setting in self.bottleneck_params_list:
t, c, n, s, k = layer_setting t, c, n, s, k = layer_setting
i += 1 i += 1
input = self._invresi_blocks( #print(input)
input, depthwise_output = self._invresi_blocks(
input=input, input=input,
in_c=in_c, in_c=in_c,
t=t, t=t,
...@@ -190,8 +210,9 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -190,8 +210,9 @@ class MobileNetV2Space(SearchSpaceBase):
k=k, k=k,
name='mobilenetv2_conv' + str(i)) name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale) in_c = int(c * self.scale)
layer_count += n layer_count += 1
### decode_points and end_points means block num
if check_points(layer_count, decode_points): if check_points(layer_count, decode_points):
decode_ends[layer_count] = depthwise_output decode_ends[layer_count] = depthwise_output
...@@ -320,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -320,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase):
Returns: Returns:
Variable, layers output. Variable, layers output.
""" """
first_block = self._inverted_residual_unit( first_block, depthwise_output = self._inverted_residual_unit(
input=input, input=input,
num_in_filter=in_c, num_in_filter=in_c,
num_filters=c, num_filters=c,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册