Commit 47e063cb authored by wanghaoshuang

Merge branch 'develop' of http://gitlab.baidu.com/PaddlePaddle/PaddleSlim into auto_prune

......@@ -17,7 +17,16 @@ import sa_controller
from sa_controller import *
import log_helper
from log_helper import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
import lock_utils
from lock_utils import *
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += lock_utils.__all__
......@@ -19,6 +19,7 @@ import numpy as np
__all__ = ['EvolutionaryController']
class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method.
"""
......@@ -48,6 +49,3 @@ class EvolutionaryController(object):
"""Generate new tokens.
"""
raise NotImplementedError('Abstract method.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from log_helper import get_logger
__all__ = ['ControllerClient']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object):
"""
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
server_ip(str): The ip that the controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that the controller server listens on. 0 means getting a usable port automatically. Default: None.
key(str): The key used to identify a legal agent to the controller server. Default: None.
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
Update the controller according to the latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in the last step.
reward(float): The reward of the tokens.
Returns:
list<int>: The next tokens generated by the controller.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
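For reference, a minimal usage sketch of ControllerClient, assuming the package is importable as paddleslim and a ControllerServer is already listening on the given address (the ip, port, key, and reward below are placeholders):

from paddleslim.common import ControllerClient

client = ControllerClient(server_ip="127.0.0.1", server_port=8888, key="light-nas")
tokens = client.next_tokens()                # ask the server for a candidate
reward = 0.9                                 # placeholder: evaluate the candidate
next_tokens = client.update(tokens, reward)  # report the reward, get the next candidate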
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import socket
from .log_helper import get_logger
from threading import Thread
from .lock import lock, unlock
__all__ = ['ControllerServer']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle requests from search agents.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The (ip, port) address that the server binds to. Default: ('', 0),
where an empty ip means the ip is chosen automatically and port 0 means a usable port is chosen automatically.
max_client_num(int): The maximum number of clients connecting to the server simultaneously. Default: 100.
search_steps(int): The total number of search steps. None means never stopping. Default: None.
key(str): The key used to identify legal clients. Default: None.
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
self._socket_file = "./controller_server.socket"
def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
try:
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
finally:
self._socket_server.close()
self.close()
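For clarity, the protocol handled above is one tab-separated message per connection. An update request from a client looks like this (illustrative values):

light-nas\t4,5,1,0\t0.92    (key \t comma-separated tokens \t reward)

and the server replies with the next tokens as a comma-separated string such as 4,5,2,0. A bare next_tokens message asks for a candidate without reporting a reward.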
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__all__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
......@@ -14,6 +14,9 @@
import search_space
from search_space import *
import sa_nas
from sa_nas import *
__all__ = []
__all__ += search_space.__all__
__all__ += sa_nas.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import logging
import numpy as np
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
from .search_space import SearchSpaceFactory
__all__ = ["SANAS"]
_logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
def __init__(self,
configs,
max_flops=None,
max_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300,
key="sa_nas",
is_server=True):
"""
Search a group of network architectures via simulated annealing.
Args:
configs(list<tuple>): A list of search space configurations with format (key, config).
`key` is the name of a registered search space (str) and `config` is a dict holding
the `input_size`, `output_size`, and `block_num` of the searched sub-network.
max_flops(int): The maximum flops of the searched network. None means no constraint. Default: None.
max_latency(float): The maximum latency of the searched network. None means no constraint. Default: None.
server_addr(tuple): A tuple of (ip, port) for the controller server. Default: ("", 0).
init_temperature(float): The initial temperature used in the simulated annealing search strategy. Default: 100.
reduce_rate(float): The decay rate used in the simulated annealing search strategy. Default: 0.85.
max_try_number(int): The maximum number of attempts to generate legal tokens. Default: 300.
max_client_num(int): The maximum number of clients connecting to the controller server. Default: 10.
search_steps(int): The total number of search steps. Default: 300.
key(str): The identity used in communication between the controller server and clients. Default: "sa_nas".
is_server(bool): Whether the current host is the controller server. Default: True.
"""
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
print(range_table)
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
if server_ip is None or server_ip == "":
server_ip = self._get_host_ip()
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server
if self._is_server:
self._controller_server.start()
self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
self._iter = 0
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens):
if (self._max_flops is None) and (self._max_latency is None):
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
return flops(main_program) < self._max_flops
def next_archs(self):
"""
Get next network architectures.
Returns:
list<function>: A list of functions that define networks.
"""
self._current_tokens = self._controller_client.next_tokens()
archs = self._search_space.token2arch(self._current_tokens)
return archs
def reward(self, score):
"""
Return reward of current searched network.
Args:
score(float): The score of current searched network.
"""
self._controller_client.update(self._current_tokens, score)
self._iter += 1
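Putting the pieces together, a minimal search loop looks like the sketch below (a single search space; the reward of 1.0 is a placeholder for a real validation score, much like the unit test at the end of this diff):

import paddle.fluid as fluid
from paddleslim.nas import SANAS

configs = [('MobileNetV2Space',
            {'input_size': 224, 'output_size': 7, 'block_num': 5})]
sa_nas = SANAS(configs, search_steps=3)

for _ in range(3):
    archs = sa_nas.next_archs()          # one net_arch function per config
    program, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(program, startup):
        data = fluid.data(
            name="data", shape=[None, 3, 224, 224], dtype="float32")
        out = archs[0](data)             # build the candidate network
    sa_nas.reward(1.0)                   # placeholder for a real score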
......@@ -14,6 +14,8 @@
import mobilenetv2
from .mobilenetv2 import *
import resnet
from .resnet import *
import search_space_registry
from search_space_registry import *
import search_space_factory
......@@ -26,3 +28,4 @@ __all__ += mobilenetv2.__all__
__all__ += search_space_registry.__all__
__all__ += search_space_factory.__all__
__all__ += search_space_base.__all__
......@@ -59,5 +59,7 @@ def conv_bn_layer(input,
moving_variance_name=bn_name + '_variance')
if act == 'relu6':
return fluid.layers.relu6(bn)
elif act == 'sigmoid':
return fluid.layers.sigmoid(bn)
else:
return bn
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .search_space_registry import SEARCHSPACE
from .base_layer import conv_bn_layer
__all__ = ["CombineSearchSpace"]
class CombineSearchSpace(object):
"""
Combine Search Space.
Args:
config_lists(list<tuple>): a list of (key, config) tuples, one per sub-space.
"""
def __init__(self, config_lists):
self.lens = len(config_lists)
self.spaces = []
for config_list in config_lists:
key, config = config_list
self.spaces.append(self._get_single_search_space(key, config))
def _get_single_search_space(self, key, config):
"""
Get a specific model space based on key and config.
Args:
key(str): model space name.
config(dict): basic config information.
Returns:
model space (class instance).
"""
cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'],
config['block_num'])
return space
def init_tokens(self):
"""
Combine init tokens.
"""
tokens = []
self.single_token_num = []
for space in self.spaces:
tokens.extend(space.init_tokens())
self.single_token_num.append(len(space.init_tokens()))
return tokens
def range_table(self):
"""
Combine range table.
"""
range_tables = []
for space in self.spaces:
range_tables.extend(space.range_table())
return range_tables
def token2arch(self, tokens=None):
"""
Convert tokens into model architecture functions, one per sub-space.
"""
if tokens is None:
tokens = self.init_tokens()
token_list = []
start_idx = 0
end_idx = 0
for i in range(len(self.single_token_num)):
end_idx += self.single_token_num[i]
token_list.append(tokens[start_idx:end_idx])
start_idx = end_idx
model_archs = []
for space, token in zip(self.spaces, token_list):
model_archs.append(space.token2arch(token))
return model_archs
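For example, combining two sub-spaces yields one concatenated token list that token2arch() splits back per sub-space. A sketch (the import path is an assumption; within the package the class is imported relatively):

from paddleslim.nas.search_space.combine_search_space import CombineSearchSpace

cfg0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
cfg1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
space = CombineSearchSpace([('MobileNetV2Space', cfg0), ('ResNetSpace', cfg1)])

tokens = space.init_tokens()        # both spaces' init tokens, concatenated
archs = space.token2arch(tokens)    # one net_arch function per sub-space

Note that token2arch() slices by self.single_token_num, which init_tokens() populates, so init_tokens() (or token2arch() with tokens=None) must run first.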
......@@ -60,7 +60,7 @@ class MobileNetV2Space(SearchSpaceBase):
"""
# original MobileNetV2
# yapf: disable
return [4, # 1, 16, 1
init_token_base = [4, # 1, 16, 1
4, 5, 1, 0, # 6, 24, 1
4, 5, 1, 0, # 6, 24, 2
4, 4, 2, 0, # 6, 32, 3
......@@ -70,13 +70,21 @@ class MobileNetV2Space(SearchSpaceBase):
4, 9, 0, 0] # 6, 320, 1
# yapf: enable
if self.block_num < 5:
self.token_len = 1 + (self.block_num - 1) * 4
else:
self.token_len = 1 + (self.block_num + 2 *
(self.block_num - 5)) * 4
return init_token_base[:self.token_len]
def range_table(self):
"""
get range table of current search space
"""
# head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size]
# yapf: disable
return [7,
range_table_base = [7,
5, 8, 6, 2,
5, 8, 6, 2,
5, 8, 6, 2,
......@@ -84,48 +92,51 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2,
5, 10, 6, 2,
5, 12, 6, 2]
range_table_base = list(np.array(range_table_base) - 1)
# yapf: enable
return range_table_base[:self.token_len]
def token2arch(self, tokens=None):
"""
return net_arch function
"""
assert self.block_num < 7, 'block number must be less than 7, but received {}'.format(
self.block_num)
if tokens is None:
tokens = self.init_tokens()
base_bottleneck_params_list = [
(1, self.head_num[tokens[0]], 1, 1, 3),
bottleneck_params_list = []
if self.block_num >= 1:
bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2:
bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]),
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3:
bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]),
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4:
bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]),
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5:
bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]),
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]),
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]),
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]),
]
assert self.block_num < 7, 'block number must be less than 7, but received {}'.format(
self.block_num)
# stride = 2 means the convolution downsamples the feature map, so block_num is decremented
# only when stride = 2; otherwise the layer is appended to params_list directly.
bottleneck_params_list = []
for param_list in base_bottleneck_params_list:
if param_list[3] == 1:
bottleneck_params_list.append(param_list)
else:
if self.block_num > 1:
bottleneck_params_list.append(param_list)
self.block_num -= 1
else:
break
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input):
#conv1
......@@ -137,7 +148,7 @@ class MobileNetV2Space(SearchSpaceBase):
stride=2,
padding='SAME',
act='relu6',
name='conv1_1')
name='mobilenetv2_conv1_1')
# bottleneck sequences
i = 1
......@@ -145,7 +156,7 @@ class MobileNetV2Space(SearchSpaceBase):
for layer_setting in bottleneck_params_list:
t, c, n, s, k = layer_setting
i += 1
input = self.invresi_blocks(
input = self._invresi_blocks(
input=input,
in_c=in_c,
t=t,
......@@ -153,7 +164,7 @@ class MobileNetV2Space(SearchSpaceBase):
n=n,
s=s,
k=k,
name='conv' + str(i))
name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale)
# if output_size is 1, add fc layer in the end
......@@ -161,8 +172,8 @@ class MobileNetV2Space(SearchSpaceBase):
input = fluid.layers.fc(
input=input,
size=self.class_dim,
param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
param_attr=ParamAttr(name='mobilenetv2_fc_weights'),
bias_attr=ParamAttr(name='mobilenetv2_fc_offset'))
else:
assert self.output_size == input.shape[2], \
("output_size must EQUAL to input_size / (2^block_num)."
......@@ -173,7 +184,7 @@ class MobileNetV2Space(SearchSpaceBase):
return net_arch
def shortcut(self, input, data_residual):
def _shortcut(self, input, data_residual):
"""Build shortcut layer.
Args:
input(Variable): input.
......@@ -183,7 +194,7 @@ class MobileNetV2Space(SearchSpaceBase):
"""
return fluid.layers.elementwise_add(input, data_residual)
def inverted_residual_unit(self,
def _inverted_residual_unit(self,
input,
num_in_filter,
num_filters,
......@@ -240,10 +251,10 @@ class MobileNetV2Space(SearchSpaceBase):
name=name + '_linear')
out = linear_out
if ifshortcut:
out = self.shortcut(input=input, data_residual=out)
out = self._shortcut(input=input, data_residual=out)
return out
def invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
"""Build inverted residual blocks.
Args:
input: Variable, input.
......@@ -257,7 +268,7 @@ class MobileNetV2Space(SearchSpaceBase):
Returns:
Variable, layers output.
"""
first_block = self.inverted_residual_unit(
first_block = self._inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
......@@ -271,7 +282,7 @@ class MobileNetV2Space(SearchSpaceBase):
last_c = c
for i in range(1, n):
last_residual_block = self.inverted_residual_unit(
last_residual_block = self._inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=c,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .search_space_base import SearchSpaceBase
from .base_layer import conv_bn_layer
from .search_space_registry import SEARCHSPACE
__all__ = ["ResNetSpace"]
@SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase):
def __init__(self,
input_size,
output_size,
block_num,
scale=1.0,
class_dim=1000):
super(ResNetSpace, self).__init__(input_size, output_size, block_num)
pass
def init_tokens(self):
return [0, 0, 0, 0, 0, 0]
def range_table(self):
return [2, 2, 2, 2, 2, 2]
def token2arch(self, tokens=None):
if tokens is None:
self.init_tokens()
def net_arch(input):
input = conv_bn_layer(
input,
num_filters=32,
filter_size=3,
stride=2,
padding='SAME',
act='sigmoid',
name='resnet_conv1_1')
return input
return net_arch
......@@ -39,6 +39,6 @@ class SearchSpaceBase(object):
Args:
tokens(list<int>): The tokens which represent a network.
Return:
list<layers>
model arch
"""
raise NotImplementedError('Abstract method.')
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from search_space_registry import SEARCHSPACE
from .combine_search_space import CombineSearchSpace
__all__ = ["SearchSpaceFactory"]
......@@ -21,18 +21,11 @@ class SearchSpaceFactory(object):
def __init__(self):
pass
def get_search_space(self, key, config):
def get_search_space(self, config_lists):
"""
get specific model space based on key and config.
Get model spaces based on a list of (key, config) tuples.
Args:
key(str): model space name.
config(dict): basic config information.
return:
model space(class)
"""
cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'],
config['block_num'])
assert isinstance(config_lists, list), "configs must be a list"
return space
return CombineSearchSpace(config_lists)
......@@ -22,8 +22,8 @@ from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from .controller_server import ControllerServer
from .controller_client import ControllerClient
from ..common import ControllerServer
from ..common import ControllerClient
__all__ = ["AutoPruner"]
......
......@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .quanter import quant_aware, quant_post, convert
from .quant_embedding import quant_embedding
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core
WEIGHT_QUANTIZATION_TYPES = ['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
ACTIVATION_QUANTIZATION_TYPES = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
VALID_DTYPES = ['int8']
_quant_config_default = {
# weight quantize type, default is 'abs_max'
'weight_quantize_type': 'abs_max',
# activation quantize type, default is 'abs_max'
'activation_quantize_type': 'abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization; currently only 'int8' is supported (see VALID_DTYPES). default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
# if set quant_weight_only True, then only quantize parameters of layers which need to be quantized,
# and activations will not be quantized.
'quant_weight_only': False
}
def _parse_configs(user_config):
"""
Check whether the user config is valid, and fill in default values for options the user did not set.
Args:
user_config(dict): the config provided by the user.
Return:
configs(dict): the final configs that will be used.
"""
configs = copy.deepcopy(_quant_config_default)
configs.update(user_config)
# check configs is valid
assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
"Unknown weight_quantize_type: '{}'. It can only be one of: {}".format(
configs['weight_quantize_type'], " ".join(WEIGHT_QUANTIZATION_TYPES))
assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
"Unknown activation_quantize_type: '{}'. It can only be one of: {}".format(
configs['activation_quantize_type'], " ".join(ACTIVATION_QUANTIZATION_TYPES))
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value."
assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
"weight_bits should be between 1 and 16."
assert isinstance(configs['activation_bits'], int), \
"activation_bits must be int value."
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \
"not_quant_pattern must be a list"
assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list"
assert isinstance(configs['dtype'], str), \
"dtype must be a str."
assert (configs['dtype'] in VALID_DTYPES), \
"dtype can only be " + " ".join(VALID_DTYPES)
assert isinstance(configs['window_size'], int), \
"window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quant_weight_only'], bool), \
"quant_weight_only must be bool value, if set quant_weight_only True, " \
"then only quantize parameters of layers which need to be quantized, " \
" and activations will not be quantized."
return configs
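A small sketch of how a partial user config is merged with the defaults above (_parse_configs is module-internal; the values are illustrative):

user_config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'not_quant_pattern': ['skip_quant', 'last_fc'],
}
configs = _parse_configs(user_config)
# Keys the user left unset keep their defaults, e.g.:
# configs['activation_quantize_type'] == 'abs_max'
# configs['weight_bits'] == 8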
def quant_aware(program, place, config, scope=None, for_test=False):
"""
Add quantization ops to the program for quantization-aware training.
Args:
program(fluid.Program): the program to quantize.
place(fluid.CPUPlace or fluid.CUDAPlace): place.
config(dict): configs for quantization; default values are in the _quant_config_default dict.
scope(fluid.Scope): the scope that stores the variables; it should be the scope of the program, usually fluid.global_scope(). Default: None.
for_test(bool): set True if the program is a test program, else False.
Return:
fluid.Program: a quantization program that the user can finetune to improve accuracy.
"""
scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict"
assert 'weight_quantize_type' in config.keys(
), 'weight_quantize_type must be configured'
assert 'activation_quantize_type' in config.keys(
), 'activation_quantize_type must be configured'
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=config['quantize_op_types'],
skip_pattern=config['not_quant_pattern'])
transform_pass.apply(main_graph)
if for_test:
quant_program = main_graph.to_program()
else:
quant_program = fluid.CompiledProgram(main_graph.graph)
return quant_program
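A typical call sequence, sketched under the assumption that train_program and test_program already define the float model:

import paddle.fluid as fluid

place = fluid.CPUPlace()
quant_config = {
    'weight_quantize_type': 'abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
}
# Training graph: returns a CompiledProgram to finetune.
quant_train = quant_aware(train_program, place, quant_config, for_test=False)
# Test graph: returns a plain fluid.Program.
quant_test = quant_aware(test_program, place, quant_config, for_test=True)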
def quant_post(program, place, config, scope=None):
"""
Add quantization ops to the program. The returned program is not trainable.
Args:
program(fluid.Program): the program to quantize.
place(fluid.CPUPlace or fluid.CUDAPlace): place.
config(dict): configs for quantization; default values are in the _quant_config_default dict.
scope(fluid.Scope): the scope that stores the variables; it should be the scope of the program, usually fluid.global_scope(). Default: None.
Return:
fluid.Program: the quantized program, which is not trainable.
"""
pass
def convert(program, scope, place, config, save_int8=False):
"""
Freeze the quantization ops in the program after training; the returned program is for inference only.
Args:
program(fluid.Program): the program to convert.
scope(fluid.Scope): the scope that stores the variables; when None, fluid.global_scope() is used.
place(fluid.CPUPlace or fluid.CUDAPlace): place.
config(dict): configs for quantization; default values are in the _quant_config_default dict.
save_int8(bool): whether to also export a frozen int8 program.
Return:
fluid.Program: a frozen program that can be used for inference;
its parameters are float32, but their values lie in the int8 range.
fluid.Program: a frozen int8 program that can be used for inference;
only returned when save_int8 is True.
"""
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type=config['weight_quantize_type'])
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
if save_int8:
convert_int8_pass = ConvertToInt8Pass(
scope=fluid.global_scope(), place=place)
convert_int8_pass.apply(test_graph)
freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8
else:
return freezed_program
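After finetuning, the test-mode program can be frozen for inference; a sketch reusing quant_test, place, and quant_config from the example above:

float_prog, int8_prog = convert(
    quant_test,               # test-mode program returned by quant_aware
    fluid.global_scope(),
    place,
    quant_config,
    save_int8=True)           # also export the frozen int8 variant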
......@@ -25,7 +25,7 @@ class TestSearchSpace(unittest.TestCase):
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
my_space = space.get_search_space('MobileNetV2Space', config)
my_space = space.get_search_space([('MobileNetV2Space', config)])
model_arch = my_space.token2arch()
train_prog = fluid.Program()
......@@ -37,12 +37,33 @@ class TestSearchSpace(unittest.TestCase):
shape=[1, 3, input_size, input_size],
dtype='float32',
append_batch_size=False)
predict = model_arch(model_input)
predict = model_arch[0](model_input)
self.assertTrue(predict.shape[2] == config['output_size'])
#for op in train_prog.global_block().ops:
# print(op.type)
class TestMultiSearchSpace(unittest.TestCase):
space = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
my_space = space.get_search_space(
[('MobileNetV2Space', config0), ('ResNetSpace', config1)])
model_archs = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size = config0['input_size']
model_input = fluid.layers.data(
name='model_in',
shape=[1, 3, input_size, input_size],
dtype='float32',
append_batch_size=False)
for model_arch in model_archs:
predict = model_arch(model_input)
model_input = predict
print(predict)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SANAS
from paddleslim.nas import SearchSpaceFactory
from paddleslim.analysis import flops
class TestSANAS(unittest.TestCase):
def test_nas(self):
factory = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
space = factory.get_search_space([('MobileNetV2Space', config0)])
origin_arch = space.token2arch()[0]
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[None, 3, 224, 224], dtype="float32")
origin_arch(input)
base_flops = flops(main_program)
search_steps = 3
sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps)
for i in range(search_steps):
archs = sa_nas.next_archs()
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[None, 3, 224, 224], dtype="float32")
archs[0](input)
sa_nas.reward(1)
self.assertTrue(flops(main_program) < base_flops)
if __name__ == '__main__':
unittest.main()