diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py index c7b8e96156ce4a8b00b2f558c40c59fa51a6bdae..36231912715a29808d55158881ab3e918260f8b5 100644 --- a/paddleslim/nas/search_space/mobilenetv2.py +++ b/paddleslim/nas/search_space/mobilenetv2.py @@ -114,38 +114,57 @@ class MobileNetV2Space(SearchSpaceBase): if tokens is None: tokens = self.init_tokens() - bottleneck_params_list = [] + self.bottleneck_params_list = [] if self.block_num >= 1: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (1, self.head_num[tokens[0]], 1, 1, 3)) if self.block_num >= 2: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[1]], self.filter_num1[tokens[2]], self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) if self.block_num >= 3: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[5]], self.filter_num1[tokens[6]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) if self.block_num >= 4: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[9]], self.filter_num2[tokens[10]], self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) if self.block_num >= 5: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[17]], self.filter_num4[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) if self.block_num >= 6: - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[21]], self.filter_num5[tokens[22]], self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) - bottleneck_params_list.append( + self.bottleneck_params_list.append( (self.multiply[tokens[25]], self.filter_num6[tokens[26]], self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) - def net_arch(input, end_points=None, decode_points=None): + def _modify_bottle_params(output_stride=None): + if output_stride is not None and output_stride % 2 != 0: + raise Exception("output stride must to be even number") + if output_stride is None: + return + else: + stride = 2 + for i, layer_setting in enumerate(self.bottleneck_params_list): + t, c, n, s, ks = layer_setting + stride = stride * s + if stride > output_stride: + s = 1 + self.bottleneck_params_list[i] = (t, c, n, s, ks) + + def net_arch(input, + end_points=None, + decode_points=None, + output_stride=None): + _modify_bottle_params(output_stride) + decode_ends = dict() def check_points(count, points): @@ -177,10 +196,11 @@ class MobileNetV2Space(SearchSpaceBase): # bottleneck sequences i = 1 in_c = int(32 * self.scale) - for layer_setting in bottleneck_params_list: + for layer_setting in self.bottleneck_params_list: t, c, n, s, k = layer_setting i += 1 - input = self._invresi_blocks( + #print(input) + input, depthwise_output = self._invresi_blocks( input=input, in_c=in_c, t=t, @@ -190,8 +210,9 @@ class MobileNetV2Space(SearchSpaceBase): k=k, name='mobilenetv2_conv' + str(i)) in_c = int(c * self.scale) - layer_count += n + layer_count += 1 + ### decode_points and end_points means block num if check_points(layer_count, decode_points): decode_ends[layer_count] = depthwise_output @@ -320,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase): Returns: Variable, layers output. """ - first_block = self._inverted_residual_unit( + first_block, depthwise_output = self._inverted_residual_unit( input=input, num_in_filter=in_c, num_filters=c,