提交 9ff3159e 编写于 作者: W wanghaoshuang

Merge branch 'develop' of http://gitlab.baidu.com/PaddlePaddle/PaddleSlim into auto_prune

...@@ -24,6 +24,7 @@ __all__ = ["SAController"] ...@@ -24,6 +24,7 @@ __all__ = ["SAController"]
_logger = get_logger(__name__, level=logging.INFO) _logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController): class SAController(EvolutionaryController):
"""Simulated annealing controller.""" """Simulated annealing controller."""
...@@ -45,7 +46,8 @@ class SAController(EvolutionaryController): ...@@ -45,7 +46,8 @@ class SAController(EvolutionaryController):
""" """
super(SAController, self).__init__() super(SAController, self).__init__()
self._range_table = range_table self._range_table = range_table
assert isinstance(self._range_table, tuple) and (len(self._range_table) == 2) assert isinstance(self._range_table, tuple) and (
len(self._range_table) == 2)
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_iter_number = max_iter_number self._max_iter_number = max_iter_number
...@@ -79,10 +81,9 @@ class SAController(EvolutionaryController): ...@@ -79,10 +81,9 @@ class SAController(EvolutionaryController):
if reward > self._max_reward: if reward > self._max_reward:
self._max_reward = reward self._max_reward = reward
self._best_tokens = tokens self._best_tokens = tokens
_logger.info("iter: {}; max_reward: {}; best_tokens: {}".format( _logger.info(
self._iter, self._max_reward, self._best_tokens)) "Controller - iter: {}; current_reward: {}; current tokens: {}".
_logger.info("current_reward: {}; current tokens: {}".format( format(self._iter, self._reward, self._tokens))
self._reward, self._tokens))
def next_tokens(self, control_token=None): def next_tokens(self, control_token=None):
""" """
...@@ -94,16 +95,19 @@ class SAController(EvolutionaryController): ...@@ -94,16 +95,19 @@ class SAController(EvolutionaryController):
tokens = self._tokens tokens = self._tokens
new_tokens = tokens[:] new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index], self._range_table[1][index]+1) new_tokens[index] = np.random.randint(self._range_table[0][index],
_logger.info("change index[{}] from {} to {}".format(index, tokens[ self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index])) index], new_tokens[index]))
if self._constrain_func is None: if self._constrain_func is None:
return new_tokens return new_tokens
for _ in range(self._max_iter_number): for _ in range(self._max_iter_number):
if not self._constrain_func(new_tokens): if not self._constrain_func(new_tokens):
index = int(len(self._range_table) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:] new_tokens = tokens[:]
new_tokens[index] = np.random.randint(self._range_table[index]) new_tokens[index] = np.random.randint(
self._range_table[0][index],
self._range_table[1][index] + 1)
else: else:
break break
return new_tokens return new_tokens
...@@ -14,4 +14,8 @@ ...@@ -14,4 +14,8 @@
from . import graph_wrapper from . import graph_wrapper
from .graph_wrapper import * from .graph_wrapper import *
from . import registry
from .registry import *
__all__ = graph_wrapper.__all__ __all__ = graph_wrapper.__all__
__all__ += registry.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect import inspect
__all__ = ["Registry"]
class Registry(object): class Registry(object):
def __init__(self, name): def __init__(self, name):
self._name = name self._name = name
self._module_dict = dict() self._module_dict = dict()
def __repr__(self): def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(self._name, list(self._module_dict.keys())) format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str return format_str
@property @property
def name(self): def name(self):
return self._name return self._name
@property @property
def module_dict(self): def module_dict(self):
return self._module_dict return self._module_dict
...@@ -20,12 +40,14 @@ class Registry(object): ...@@ -20,12 +40,14 @@ class Registry(object):
def _register_module(self, module_class): def _register_module(self, module_class):
if not inspect.isclass(module_class): if not inspect.isclass(module_class):
raise TypeError('module must be a class, but receive {}.'.format(type(module_class))) raise TypeError('module must be a class, but receive {}.'.format(
type(module_class)))
module_name = module_class.__name__ module_name = module_class.__name__
if module_name in self._module_dict: if module_name in self._module_dict:
raise KeyError('{} is already registered in {}.'.format(module_name, self.name)) raise KeyError('{} is already registered in {}.'.format(
module_name, self.name))
self._module_dict[module_name] = module_class self._module_dict[module_name] = module_class
def register_module(self, cls): def register(self, cls):
self._register_module(cls) self._register_module(cls)
return cls return cls
...@@ -11,3 +11,9 @@ ...@@ -11,3 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import search_space
from search_space import *
__all__ = []
__all__ += search_space.__all__
...@@ -11,3 +11,18 @@ ...@@ -11,3 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import mobilenetv2
from .mobilenetv2 import *
import search_space_registry
from search_space_registry import *
import search_space_factory
from search_space_factory import *
import search_space_base
from search_space_base import *
__all__ = []
__all__ += mobilenetv2.__all__
__all__ += search_space_registry.__all__
__all__ += search_space_factory.__all__
__all__ += search_space_base.__all__
...@@ -16,7 +16,15 @@ import paddle.fluid as fluid ...@@ -16,7 +16,15 @@ import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1, act=None, name=None, use_cudnn=True): def conv_bn_layer(input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
act=None,
name=None,
use_cudnn=True):
"""Build convolution and batch normalization layers. """Build convolution and batch normalization layers.
Args: Args:
input(Variable): input. input(Variable): input.
...@@ -31,11 +39,24 @@ def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1 ...@@ -31,11 +39,24 @@ def conv_bn_layer(input, filter_size, num_filters, stride, padding, num_groups=1
Returns: Returns:
Variable, layers output. Variable, layers output.
""" """
conv = fluid.layers.conv2d(input, num_filters=num_filters, filter_size=filter_size, stride=stride, padding=padding, conv = fluid.layers.conv2d(
groups=num_groups, act=None, use_cudnn=use_cudnn, param_attr=ParamAttr(name=name+'_weights'), bias_attr=False) input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn' bn_name = name + '_bn'
bn = fluid.layers.batch_norm(input=conv, param_attr=ParamAttr(name=bn_name+'_scale'), bias_attr=ParamAttr(name=bn_name+'_offset'), bn = fluid.layers.batch_norm(
moving_mean_name=bn_name+'_mean', moving_variance_name=bn_name+'_variance') input=conv,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(name=bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if act == 'relu6': if act == 'relu6':
return fluid.layers.relu6(bn) return fluid.layers.relu6(bn)
else: else:
......
...@@ -19,26 +19,38 @@ from __future__ import print_function ...@@ -19,26 +19,38 @@ from __future__ import print_function
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from ..search_space_base import SearchSpaceBase from .search_space_base import SearchSpaceBase
from .layer import conv_bn_layer from .base_layer import conv_bn_layer
from .registry import SEARCHSPACE from .search_space_registry import SEARCHSPACE
@SEARCHSPACE.register_module __all__ = ["MobileNetV2Space"]
@SEARCHSPACE.register
class MobileNetV2Space(SearchSpaceBase): class MobileNetV2Space(SearchSpaceBase):
def __init__(self, input_size, output_size, block_num, scale=1.0, class_dim=1000): def __init__(self,
super(MobileNetV2Space, self).__init__(input_size, output_size, block_num) input_size,
self.head_num = np.array([3,4,8,12,16,24,32]) #7 output_size,
self.filter_num1 = np.array([3,4,8,12,16,24,32,48]) #8 block_num,
self.filter_num2 = np.array([8,12,16,24,32,48,64,80]) #8 scale=1.0,
self.filter_num3 = np.array([16,24,32,48,64,80,96,128]) #8 class_dim=1000):
self.filter_num4 = np.array([24,32,48,64,80,96,128,144,160,192]) #10 super(MobileNetV2Space, self).__init__(input_size, output_size,
self.filter_num5 = np.array([32,48,64,80,96,128,144,160,192,224]) #10 block_num)
self.filter_num6 = np.array([64,80,96,128,144,160,192,224,256,320,384,512]) #12 self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7
self.k_size = np.array([3,5]) #2 self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48]) #8
self.multiply = np.array([1,2,3,4,6]) #5 self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80]) #8
self.repeat = np.array([1,2,3,4,5,6]) #6 self.filter_num3 = np.array([16, 24, 32, 48, 64, 80, 96, 128]) #8
self.scale=scale self.filter_num4 = np.array(
self.class_dim=class_dim [24, 32, 48, 64, 80, 96, 128, 144, 160, 192]) #10
self.filter_num5 = np.array(
[32, 48, 64, 80, 96, 128, 144, 160, 192, 224]) #10
self.filter_num6 = np.array(
[64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512]) #12
self.k_size = np.array([3, 5]) #2
self.multiply = np.array([1, 2, 3, 4, 6]) #5
self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6
self.scale = scale
self.class_dim = class_dim
def init_tokens(self): def init_tokens(self):
""" """
...@@ -47,6 +59,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -47,6 +59,7 @@ class MobileNetV2Space(SearchSpaceBase):
each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size] each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size]
""" """
# original MobileNetV2 # original MobileNetV2
# yapf: disable
return [4, # 1, 16, 1 return [4, # 1, 16, 1
4, 5, 1, 0, # 6, 24, 1 4, 5, 1, 0, # 6, 24, 1
4, 5, 1, 0, # 6, 24, 2 4, 5, 1, 0, # 6, 24, 2
...@@ -55,13 +68,15 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -55,13 +68,15 @@ class MobileNetV2Space(SearchSpaceBase):
4, 5, 2, 0, # 6, 96, 3 4, 5, 2, 0, # 6, 96, 3
4, 7, 2, 0, # 6, 160, 3 4, 7, 2, 0, # 6, 160, 3
4, 9, 0, 0] # 6, 320, 1 4, 9, 0, 0] # 6, 320, 1
# yapf: enable
def range_table(self): def range_table(self):
""" """
get range table of current search space get range table of current search space
""" """
# head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
return [7, # yapf: disable
return [7,
5, 8, 6, 2, 5, 8, 6, 2,
5, 8, 6, 2, 5, 8, 6, 2,
5, 8, 6, 2, 5, 8, 6, 2,
...@@ -69,6 +84,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -69,6 +84,7 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2, 5, 10, 6, 2,
5, 10, 6, 2, 5, 10, 6, 2,
5, 12, 6, 2] 5, 12, 6, 2]
# yapf: enable
def token2arch(self, tokens=None): def token2arch(self, tokens=None):
""" """
...@@ -79,16 +95,24 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -79,16 +95,24 @@ class MobileNetV2Space(SearchSpaceBase):
base_bottleneck_params_list = [ base_bottleneck_params_list = [
(1, self.head_num[tokens[0]], 1, 1, 3), (1, self.head_num[tokens[0]], 1, 1, 3),
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], self.repeat[tokens[3]], 2, self.k_size[tokens[4]]), (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
(self.multiply[tokens[5]], self.filter_num1[tokens[6]], self.repeat[tokens[7]], 2, self.k_size[tokens[8]]), self.repeat[tokens[3]], 2, self.k_size[tokens[4]]),
(self.multiply[tokens[9]], self.filter_num2[tokens[10]], self.repeat[tokens[11]], 2, self.k_size[tokens[12]]), (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
(self.multiply[tokens[13]], self.filter_num3[tokens[14]], self.repeat[tokens[15]], 2, self.k_size[tokens[16]]), self.repeat[tokens[7]], 2, self.k_size[tokens[8]]),
(self.multiply[tokens[17]], self.filter_num3[tokens[18]], self.repeat[tokens[19]], 1, self.k_size[tokens[20]]), (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
(self.multiply[tokens[21]], self.filter_num5[tokens[22]], self.repeat[tokens[23]], 2, self.k_size[tokens[24]]), self.repeat[tokens[11]], 2, self.k_size[tokens[12]]),
(self.multiply[tokens[25]], self.filter_num6[tokens[26]], self.repeat[tokens[27]], 1, self.k_size[tokens[28]]), (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]),
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]),
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]),
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]),
] ]
assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(self.block_num) assert self.block_num < 7, 'block number must less than 7, but receive block number is {}'.format(
self.block_num)
# the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1, # the stride = 2 means downsample feature map in the convolution, so only when stride=2, block_num minus 1,
# otherwise, add layers to params_list directly. # otherwise, add layers to params_list directly.
...@@ -134,10 +158,11 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -134,10 +158,11 @@ class MobileNetV2Space(SearchSpaceBase):
# if output_size is 1, add fc layer in the end # if output_size is 1, add fc layer in the end
if self.output_size == 1: if self.output_size == 1:
input = fluid.layers.fc(input=input, input = fluid.layers.fc(
size=self.class_dim, input=input,
param_attr=ParamAttr(name='fc10_weights'), size=self.class_dim,
bias_attr=ParamAttr(name='fc10_offset')) param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
else: else:
assert self.output_size == input.shape[2], \ assert self.output_size == input.shape[2], \
("output_size must EQUAL to input_size / (2^block_num)." ("output_size must EQUAL to input_size / (2^block_num)."
...@@ -148,7 +173,6 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -148,7 +173,6 @@ class MobileNetV2Space(SearchSpaceBase):
return net_arch return net_arch
def shortcut(self, input, data_residual): def shortcut(self, input, data_residual):
"""Build shortcut layer. """Build shortcut layer.
Args: Args:
...@@ -159,7 +183,6 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -159,7 +183,6 @@ class MobileNetV2Space(SearchSpaceBase):
""" """
return fluid.layers.elementwise_add(input, data_residual) return fluid.layers.elementwise_add(input, data_residual)
def inverted_residual_unit(self, def inverted_residual_unit(self,
input, input,
num_in_filter, num_in_filter,
...@@ -220,15 +243,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -220,15 +243,7 @@ class MobileNetV2Space(SearchSpaceBase):
out = self.shortcut(input=input, data_residual=out) out = self.shortcut(input=input, data_residual=out)
return out return out
def invresi_blocks(self, def invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
input,
in_c,
t,
c,
n,
s,
k,
name=None):
"""Build inverted residual blocks. """Build inverted residual blocks.
Args: Args:
input: Variable, input. input: Variable, input.
...@@ -266,5 +281,3 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -266,5 +281,3 @@ class MobileNetV2Space(SearchSpaceBase):
expansion_factor=t, expansion_factor=t,
name=name + '_' + str(i + 1)) name=name + '_' + str(i + 1))
return last_residual_block return last_residual_block
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
__all__ = ['SearchSpaceBase'] __all__ = ['SearchSpaceBase']
class SearchSpaceBase(object): class SearchSpaceBase(object):
"""Controller for Neural Architecture Search. """Controller for Neural Architecture Search.
""" """
...@@ -41,4 +42,3 @@ class SearchSpaceBase(object): ...@@ -41,4 +42,3 @@ class SearchSpaceBase(object):
list<layers> list<layers>
""" """
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
...@@ -12,7 +12,10 @@ ...@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from searchspace.registry import SEARCHSPACE from search_space_registry import SEARCHSPACE
__all__ = ["SearchSpaceFactory"]
class SearchSpaceFactory(object): class SearchSpaceFactory(object):
def __init__(self): def __init__(self):
...@@ -29,8 +32,7 @@ class SearchSpaceFactory(object): ...@@ -29,8 +32,7 @@ class SearchSpaceFactory(object):
model space(class) model space(class)
""" """
cls = SEARCHSPACE.get(key) cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'], config['block_num']) space = cls(config['input_size'], config['output_size'],
config['block_num'])
return space return space
...@@ -12,4 +12,8 @@ ...@@ -12,4 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .mobilenetv2_space import MobileNetV2Space from ...core import Registry
__all__ = ["SEARCHSPACE"]
SEARCHSPACE = Registry('searchspace')
from ..utils.registry import Registry
SEARCHSPACE = Registry('searchspace')
...@@ -12,19 +12,29 @@ ...@@ -12,19 +12,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import socket
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from .pruner import Pruner from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..search import ControllerServer from .controller_server import ControllerServer
from .controller_client import ControllerClient
__all__ = ["AutoPruner"] __all__ = ["AutoPruner"]
_logger = get_logger(__name__, level=logging.INFO)
class AutoPruner(object): class AutoPruner(object):
def __init__(self, def __init__(self,
program, program,
scope,
place,
params=[], params=[],
init_ratios=None, init_ratios=None,
pruned_flops=0.5, pruned_flops=0.5,
...@@ -32,13 +42,13 @@ class AutoPruner(object): ...@@ -32,13 +42,13 @@ class AutoPruner(object):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_iter_number=300, max_try_number=300,
max_client_num=10, max_client_num=10,
search_steps=300, search_steps=300,
max_ratios=[0.9], max_ratios=[0.9],
min_ratios=[0], min_ratios=[0],
key="auto_pruner" key="auto_pruner",
): is_server=True):
""" """
Search a group of ratios used to prune program. Search a group of ratios used to prune program.
Args: Args:
...@@ -54,8 +64,10 @@ class AutoPruner(object): ...@@ -54,8 +64,10 @@ class AutoPruner(object):
search_strategy(str): The search strategy. Default: 'sa'. search_strategy(str): The search strategy. Default: 'sa'.
""" """
# step1: Create controller server. And start server if current host match server_ip. # step1: Create controller server. And start server if current host match server_ip.
self._program = program self._program = program
self._scope = scope
self._place = place
self._params = params self._params = params
self._init_ratios = init_ratios self._init_ratios = init_ratios
self._pruned_flops = pruned_flops self._pruned_flops = pruned_flops
...@@ -63,40 +75,53 @@ class AutoPruner(object): ...@@ -63,40 +75,53 @@ class AutoPruner(object):
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number self._max_try_number = max_try_number
self._is_server = is_server
assert isinstance(self._max_ratios, float) or isinstance(self._max_ratios)
self._range_table = self._get_range_table(min_ratios, max_ratios) self._range_table = self._get_range_table(min_ratios, max_ratios)
self._pruner = Pruner() self._pruner = Pruner()
if self._pruned_flops: if self._pruned_flops:
self._base_flops = flops(program) self._base_flops = flops(program)
_logger.info("AutoPruner - base flops: {};".format(
self._base_flops))
if self._pruned_latency: if self._pruned_latency:
self._base_latency = latency(program) self._base_latency = latency(program)
if self._init_ratios is None: if self._init_ratios is None:
self._init_ratios = self._get_init_ratios( self._init_ratios = self._get_init_ratios(
self,_program, self._params, self._pruned_flops, self, _program, self._params, self._pruned_flops,
self._pruned_latency) self._pruned_latency)
init_tokens = self._ratios2tokens(self._init_ratios) init_tokens = self._ratios2tokens(self._init_ratios)
controller = SAController(self._range_table, controller = SAController(self._range_table, self._reduce_rate,
self._reduce_rate, self._init_temperature, self._max_try_number,
self._init_temperature, init_tokens, self._constrain_func)
self._max_try_number,
init_tokens, server_ip, server_port = server_addr
self._constrain_func) if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
self._controller_server = ControllerServer( self._controller_server = ControllerServer(
controller=controller, controller=controller,
addr=server_addr, address=(server_ip, server_port),
max_client_num, max_client_num=max_client_num,
search_steps, search_steps=search_steps,
key=key) key=key)
# create controller server
if self._is_server:
self._controller_server.start()
self._controller_client = ControllerClient(server_addr, key=key) self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
self._iter = 0 self._iter = 0
self._param_backup = {}
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _get_init_ratios(self, program, params, pruned_flops, pruned_latency): def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
pass pass
...@@ -104,34 +129,51 @@ class AutoPruner(object): ...@@ -104,34 +129,51 @@ class AutoPruner(object):
def _get_range_table(self, min_ratios, max_ratios): def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float) assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float) assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
min_ratios = min_ratios if isinstance(min_ratios, list) else [min_ratios] min_ratios = min_ratios if isinstance(min_ratios,
max_ratios = max_ratios if isinstance(max_ratios, list) else [max_ratios] list) else [min_ratios]
max_ratios = max_ratios if isinstance(max_ratios,
list) else [max_ratios]
min_tokens = self._ratios2tokens(min_ratios) min_tokens = self._ratios2tokens(min_ratios)
max_tokens = self._ratios2tokens(max_ratios) max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens) return (min_tokens, max_tokens)
def _constrain_func(self, tokens): def _constrain_func(self, tokens):
ratios = self._tokens2ratios(tokens) ratios = self._tokens2ratios(tokens)
pruned_program = self._pruner.prune( pruned_program = self._pruner.prune(
program, self._program,
scope, self._scope,
self._params, self._params,
self._current_ratios, ratios,
place=self._place,
only_graph=True) only_graph=True)
return flops(pruned_program) < self._base_flops return flops(pruned_program) < self._base_flops * (
1 - self._pruned_flops)
def prune(self, program, scope, place): def prune(self, program):
self._current_ratios = self._next_ratios() self._current_ratios = self._next_ratios()
pruned_program = self._pruner.prune(program, scope, self._params, pruned_program = self._pruner.prune(
self._current_ratios) program,
self._scope,
self._params,
self._current_ratios,
place=self._place,
param_backup=self._param_backup)
_logger.info("AutoPruner - pruned ratios: {}".format(
self._current_ratios))
return pruned_program return pruned_program
def reward(self, score): def reward(self, score):
tokens = self.ratios2tokens(self._current_ratios) self._restore(self._scope)
self._controller_client.reward(tokens, score) self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score)
self._iter += 1 self._iter += 1
def _restore(self, scope):
for param_name in self._param_backup.keys():
param_t = scope.find_var(param_name).get_tensor()
param_t.set(self._param_backup[param_name], self._place)
def _next_ratios(self): def _next_ratios(self):
tokens = self._controller_client.next_tokens() tokens = self._controller_client.next_tokens()
return self._tokens2ratios(tokens) return self._tokens2ratios(tokens)
...@@ -141,7 +183,7 @@ class AutoPruner(object): ...@@ -141,7 +183,7 @@ class AutoPruner(object):
""" """
return [int(ratio / 0.01) for ratio in ratios] return [int(ratio / 0.01) for ratio in ratios]
def _tokens2_ratios(self, tokens): def _tokens2ratios(self, tokens):
"""Convert tokens to pruned ratios. """Convert tokens to pruned ratios.
""" """
return [token * 0.01 for token in tokens] return [token * 0.01 for token in tokens]
...@@ -12,10 +12,14 @@ ...@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import socket import socket
from ..common import get_logger
__all__ = ['ControllerClient'] __all__ = ['ControllerClient']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object): class ControllerClient(object):
""" """
......
...@@ -12,11 +12,18 @@ ...@@ -12,11 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import logging
import socket import socket
from ..common import get_logger
from threading import Thread from threading import Thread
from .lock import lock, unlock
__all__ = ['ControllerServer'] __all__ = ['ControllerServer']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object): class ControllerServer(object):
""" """
The controller wrapper with a socket server to handle the request of search agent. The controller wrapper with a socket server to handle the request of search agent.
...@@ -44,14 +51,30 @@ class ControllerServer(object): ...@@ -44,14 +51,30 @@ class ControllerServer(object):
self._port = address[1] self._port = address[1]
self._ip = address[0] self._ip = address[0]
self._key = key self._key = key
self._socket_file = "./controller_server.socket"
def start(self): def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address) self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num) self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1] self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0] self._ip = self._socket_server.getsockname()[0]
_logger.info("listen on: [{}:{}]".format(self._ip, self._port)) _logger.info("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run) thread = Thread(target=self.run)
thread.start() thread.start()
return str(thread) return str(thread)
...@@ -59,6 +82,8 @@ class ControllerServer(object): ...@@ -59,6 +82,8 @@ class ControllerServer(object):
def close(self): def close(self):
"""Close the server.""" """Close the server."""
self._closed = True self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self): def port(self):
"""Get the port.""" """Get the port."""
...@@ -70,30 +95,34 @@ class ControllerServer(object): ...@@ -70,30 +95,34 @@ class ControllerServer(object):
def run(self): def run(self):
_logger.info("Controller Server run...") _logger.info("Controller Server run...")
while ((self._search_steps is None) or try:
(self._controller._iter < while ((self._search_steps is None) or
(self._search_steps))) and not self._closed: (self._controller._iter <
conn, addr = self._socket_server.accept() (self._search_steps))) and not self._closed:
message = conn.recv(1024).decode() conn, addr = self._socket_server.accept()
if message.strip("\n") == "next_tokens": message = conn.recv(1024).decode()
tokens = self._controller.next_tokens() if message.strip("\n") == "next_tokens":
tokens = ",".join([str(token) for token in tokens]) tokens = self._controller.next_tokens()
conn.send(tokens.encode()) tokens = ",".join([str(token) for token in tokens])
else: conn.send(tokens.encode())
_logger.info("recv message from {}: [{}]".format(addr, message)) else:
messages = message.strip('\n').split("\t") _logger.debug("recv message from {}: [{}]".format(addr,
if (len(messages) < 3) or (messages[0] != self._key): message))
_logger.info("recv noise from {}: [{}]".format(addr, messages = message.strip('\n').split("\t")
message)) if (len(messages) < 3) or (messages[0] != self._key):
continue _logger.debug("recv noise from {}: [{}]".format(
tokens = messages[1] addr, message))
reward = messages[2] continue
tokens = [int(token) for token in tokens.split(",")] tokens = messages[1]
self._controller.update(tokens, float(reward)) reward = messages[2]
tokens = self._controller.next_tokens() tokens = [int(token) for token in tokens.split(",")]
tokens = ",".join([str(token) for token in tokens]) self._controller.update(tokens, float(reward))
conn.send(tokens.encode()) tokens = self._controller.next_tokens()
_logger.info("send message to {}: [{}]".format(addr, tokens)) tokens = ",".join([str(token) for token in tokens])
conn.close() conn.send(tokens.encode())
self._socket_server.close() _logger.debug("send message to {}: [{}]".format(addr,
_logger.info("server closed!") tokens))
conn.close()
finally:
self._socket_server.close()
self.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__All__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
...@@ -14,9 +14,10 @@ ...@@ -14,9 +14,10 @@
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from core import VarWrapper, OpWrapper, GraphWrapper import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
__all__ = ["prune"] __all__ = ["Pruner"]
class Pruner(): class Pruner():
...@@ -64,10 +65,10 @@ class Pruner(): ...@@ -64,10 +65,10 @@ class Pruner():
params, params,
ratios, ratios,
place, place,
lazy=False, lazy=lazy,
only_graph=False, only_graph=only_graph,
param_backup=None, param_backup=param_backup,
param_shape_backup=None) param_shape_backup=param_shape_backup)
return graph.program return graph.program
def _prune_filters_by_ratio(self, def _prune_filters_by_ratio(self,
......
...@@ -39,6 +39,8 @@ packages = [ ...@@ -39,6 +39,8 @@ packages = [
'paddleslim.nas', 'paddleslim.nas',
'paddleslim.analysis', 'paddleslim.analysis',
'paddleslim.quant', 'paddleslim.quant',
'paddleslim.core',
'paddleslim.common',
] ]
setup( setup(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
def conv_bn_layer(input,
num_filters,
filter_size,
name,
stride=1,
groups=1,
act=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + "_out")
bn_name = name + "_bn"
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '_output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import AutoPruner
from paddleslim.analysis import flops
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruned_flops = 0.5
pruner = AutoPruner(
main_program,
scope,
place,
params=["conv4_weights"],
init_ratios=[0.5],
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=2,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner")
base_flops = flops(main_program)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruner = Pruner()
main_program = pruner.prune(
main_program,
scope,
params=["conv4_weights"],
ratios=[0.5],
place=place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
shapes = {
"conv1_weights": (4L, 3L, 3L, 3L),
"conv2_weights": (4L, 4L, 3L, 3L),
"conv3_weights": (8L, 4L, 3L, 3L),
"conv4_weights": (4L, 8L, 3L, 3L),
"conv5_weights": (8L, 4L, 3L, 3L),
"conv6_weights": (8L, 8L, 3L, 3L)
}
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
self.assertTrue(param.shape == shapes[param.name])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SearchSpaceFactory
class TestSearchSpaceFactory(unittest.TestCase):
def test_factory(self):
# if output_size is 1, the model will add fc layer in the end.
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
my_space = space.get_search_space('MobileNetV2Space', config)
model_arch = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size = config['input_size']
model_input = fluid.layers.data(
name='model_in',
shape=[1, 3, input_size, input_size],
dtype='float32',
append_batch_size=False)
print('input shape', model_input.shape)
predict = model_arch(model_input)
print('output shape', predict.shape)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册