From 68e50dc37277a5a3e41b57866a0f64d2f93818c4 Mon Sep 17 00:00:00 2001 From: ceci3 <592712189@qq.com> Date: Mon, 2 Dec 2019 07:03:29 +0000 Subject: [PATCH] update mobilenet_block --- .../nas/search_space/mobilenet_block.py | 38 +++++++++++++++---- .../nas/search_space/search_space_base.py | 4 +- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/paddleslim/nas/search_space/mobilenet_block.py b/paddleslim/nas/search_space/mobilenet_block.py index 39ce4050..93bdfb29 100644 --- a/paddleslim/nas/search_space/mobilenet_block.py +++ b/paddleslim/nas/search_space/mobilenet_block.py @@ -38,8 +38,9 @@ class MobileNetV2BlockSpace(SearchSpaceBase): super(MobileNetV2BlockSpace, self).__init__(input_size, output_size, block_num, block_mask) - # use input_size and output_size to compute self.downsample_num - self.downsample_num = compute_downsample_num(self.input_size, self.output_size) + if self.block_mask == None: + # use input_size and output_size to compute self.downsample_num + self.downsample_num = compute_downsample_num(self.input_size, self.output_size) if self.block_num != None: assert self.downsample_num <= self.block_num, 'downsample numeber must be LESS THAN OR EQUAL TO block_num, but NOW: downsample numeber is {}, block_num is {}'.format(self.downsample_num, self.block_num) @@ -111,12 +112,23 @@ class MobileNetV2BlockSpace(SearchSpaceBase): self.bottleneck_params_list.append(self.mutiply[tokens[j * 4], self.filter_num[tokens[j * 4 + 1]], self.repeat[tokens[j * 4 + 2]], 1, self.k_size[tokens[j * 4 + 3]]) - def net_arch(input, return_mid_layer=False): + def net_arch(input, return_mid_layer=False, return_block=[]): + assert isinstance(return_block, list), 'return_block must be a list.' # all padding is 'SAME' in the conv2d, can compute the actual padding automatic. # bottleneck sequences in_c = int(32 * self.scale) + mid_layer = dict() + layer_count = 0 + depthwise_conv = None + for i, layer_setting in enumerate(self.bottleneck_params_list): t, c, n, s, k = layer_setting + + if s == 2: + layer_count += 1 + if (layer_count - 1) in return_block: + mid_layer[layer_count] = depthwise_conv + input, depthwise_conv = self._invresi_blocks( input=input, in_c=in_c, @@ -129,7 +141,7 @@ class MobileNetV2BlockSpace(SearchSpaceBase): in_c = int(c * self.scale) if return_mid_layer: - return input, depthwise_conv + return input, mid_layer else: return input, @@ -309,12 +321,19 @@ class MobileNetV1BlockSpace(SearchSpaceBase): self.bottleneck_params_list.append(self.filter_num[tokens[j * 3], self.filter_num[tokens[j * 3 + 1]], 1, self.k_size[tokens[j * 3 + 2]]) - def net_arch(input, return_mid_layer=False): - - def net_arch(input): + def net_arch(input, return_mid_layer=False, return_block=[]): + assert isinstance(return_block, list), 'return_block must be a list.' + mid_layer = dict() + layer_count = 0 + for i, layer_setting in enumerate(self.bottleneck_params_list): filter_num1, filter_num2, stride, kernel_size = layer_setting + if stride == 2: + layer_count += 1 + if (layer_count - 1) in return_block: + mid_layer[layer_count] = input + input = self._depthwise_separable( input=input, num_filters1=filter_num1, @@ -324,7 +343,10 @@ class MobileNetV1BlockSpace(SearchSpaceBase): scale=self.scale, kernel_size=kernel_size, name='mobilenetv1_{}'.format(str(i + 1))) - return input + if return_mid_layer: + return input, mid_layer + else: + return input return net_arch diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index e75bc035..4f6e89bb 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -31,8 +31,8 @@ class SearchSpaceBase(object): "If block_mask is NOT None, we will use block_mask as major configs!" ) self.block_num = None - if self.block_mask == None and self.block_num == None: - print("block_mask and block num can NOT be None at the same time!") + if self.block_mask == None and (self.block_num == None or self.input_size == None or self.output_size == None): + print("block_mask and (block num or input_size or output_size) can NOT be None at the same time!") def init_tokens(self): """Get init tokens in search space. -- GitLab