提交 68e50dc3 编写于 作者: C ceci3

update mobilenet_block

上级 0cd9aa5e
......@@ -38,8 +38,9 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
super(MobileNetV2BlockSpace, self).__init__(input_size, output_size,
block_num, block_mask)
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size, self.output_size)
if self.block_mask == None:
# use input_size and output_size to compute self.downsample_num
self.downsample_num = compute_downsample_num(self.input_size, self.output_size)
if self.block_num != None:
assert self.downsample_num <= self.block_num, 'downsample numeber must be LESS THAN OR EQUAL TO block_num, but NOW: downsample numeber is {}, block_num is {}'.format(self.downsample_num, self.block_num)
......@@ -111,12 +112,23 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
self.bottleneck_params_list.append(self.mutiply[tokens[j * 4], self.filter_num[tokens[j * 4 + 1]],
self.repeat[tokens[j * 4 + 2]], 1, self.k_size[tokens[j * 4 + 3]])
def net_arch(input, return_mid_layer=False):
def net_arch(input, return_mid_layer=False, return_block=[]):
assert isinstance(return_block, list), 'return_block must be a list.'
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
# bottleneck sequences
in_c = int(32 * self.scale)
mid_layer = dict()
layer_count = 0
depthwise_conv = None
for i, layer_setting in enumerate(self.bottleneck_params_list):
t, c, n, s, k = layer_setting
if s == 2:
layer_count += 1
if (layer_count - 1) in return_block:
mid_layer[layer_count] = depthwise_conv
input, depthwise_conv = self._invresi_blocks(
input=input,
in_c=in_c,
......@@ -129,7 +141,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase):
in_c = int(c * self.scale)
if return_mid_layer:
return input, depthwise_conv
return input, mid_layer
else:
return input,
......@@ -309,12 +321,19 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
self.bottleneck_params_list.append(self.filter_num[tokens[j * 3], self.filter_num[tokens[j * 3 + 1]],
1, self.k_size[tokens[j * 3 + 2]])
def net_arch(input, return_mid_layer=False):
def net_arch(input):
def net_arch(input, return_mid_layer=False, return_block=[]):
assert isinstance(return_block, list), 'return_block must be a list.'
mid_layer = dict()
layer_count = 0
for i, layer_setting in enumerate(self.bottleneck_params_list):
filter_num1, filter_num2, stride, kernel_size = layer_setting
if stride == 2:
layer_count += 1
if (layer_count - 1) in return_block:
mid_layer[layer_count] = input
input = self._depthwise_separable(
input=input,
num_filters1=filter_num1,
......@@ -324,7 +343,10 @@ class MobileNetV1BlockSpace(SearchSpaceBase):
scale=self.scale,
kernel_size=kernel_size,
name='mobilenetv1_{}'.format(str(i + 1)))
return input
if return_mid_layer:
return input, mid_layer
else:
return input
return net_arch
......
......@@ -31,8 +31,8 @@ class SearchSpaceBase(object):
"If block_mask is NOT None, we will use block_mask as major configs!"
)
self.block_num = None
if self.block_mask == None and self.block_num == None:
print("block_mask and block num can NOT be None at the same time!")
if self.block_mask == None and (self.block_num == None or self.input_size == None or self.output_size == None):
print("block_mask and (block num or input_size or output_size) can NOT be None at the same time!")
def init_tokens(self):
"""Get init tokens in search space.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册