Commit 93f32775 authored by wanghaoshuang

Merge branch 'develop' of http://gitlab.baidu.com/PaddlePaddle/PaddleSlim into sen

......@@ -12,8 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from ..core import GraphWrapper
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ["sensitivity"]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import controller
from .controller import *
from . import sa_controller
from .sa_controller import *
from . import log_helper
from .log_helper import *
from . import controller_server
from .controller_server import *
from . import controller_client
from .controller_client import *
from . import lock_utils
from .lock_utils import *
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += lock_utils.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import numpy as np
__all__ = ['EvolutionaryController']
class EvolutionaryController(object):
"""Abstract controller for all evolutionary searching method.
"""
def __init__(self, *args, **kwargs):
pass
def update(self, tokens, reward):
"""Update the status of controller according current tokens and reward.
Args:
tokens(list<int>): A solution of searching task.
            reward(float): The reward of the tokens.
"""
raise NotImplementedError('Abstract method.')
def reset(self, range_table, constrain_func=None):
"""Reset the controller.
Args:
range_table(list<int>): It is used to define the searching space of controller.
The tokens[i] generated by controller should be in [0, range_table[i]).
constrain_func(function): It is used to check whether tokens meet the constraint.
None means there is no constraint. Default: None.
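                For example, range_table=[2, 3] means tokens[0] is drawn from [0, 2) and tokens[1] from [0, 3).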
"""
raise NotImplementedError('Abstract method.')
def next_tokens(self):
"""Generate new tokens.
"""
raise NotImplementedError('Abstract method.')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from .log_helper import get_logger
__all__ = ['ControllerClient']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object):
"""
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
            server_ip(str): The ip that the controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that the controller server listens on. Default: None.
            key(str): The key used to identify legal agents for the controller server. Default: None.
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
        Update the controller according to the latest tokens and reward.
        Args:
            tokens(list<int>): The tokens generated in the last step.
            reward(float): The reward of the tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
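# A hedged usage sketch (the address, key and reward values are illustrative):
# the client simply relays tokens and rewards to a running ControllerServer.
#
#     client = ControllerClient(server_ip="127.0.0.1", server_port=8989, key="light-nas")
#     tokens = client.next_tokens()
#     ...  # evaluate the solution defined by tokens
#     tokens = client.update(tokens, reward=0.5)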
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import socket
from .log_helper import get_logger
from threading import Thread
from .lock_utils import lock, unlock
__all__ = ['ControllerServer']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object):
"""
    The controller wrapper with a socket server that handles requests from search agents.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
            controller(slim.searcher.Controller): The controller used to generate tokens.
            address(tuple): The (ip, port) address that the server binds to. Default: ('', 0),
                which means the ip and port are picked automatically.
            max_client_num(int): The maximum number of clients connecting to the server simultaneously. Default: 100.
            search_steps(int): The total number of search steps. None means never stopping. Default: None.
            key(str): The key used to verify that requests come from legal agents. Default: None.
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
self._socket_file = "./controller_server.socket"
def start(self):
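        # A lock file serializes startup across processes: only the first
        # process actually starts the server; later ones find the recorded
        # tid/ip/port in the file and skip the start.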
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
try:
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
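                # The wire protocol is plain text: a client sends either the
                # literal string "next_tokens", or "key\ttokens\treward",
                # where tokens is a comma-separated list of ints.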
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
finally:
self._socket_server.close()
self.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__all__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import logging
__all__ = ['get_logger']
def get_logger(name, level, fmt=None):
"""
    Get a logger from the logging module with the given name, level and
    format, without calling logging.basicConfig: paddle configures
    basicConfig on import, so later basicConfig calls would have no effect.
Args:
name (str): The logger name.
level (logging.LEVEL): The base level of the logger
fmt (str): Format of logger output
Returns:
        logging.Logger: logging logger with the given settings
Examples:
.. code-block:: python
logger = log_helper.get_logger(__name__, logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
"""
logger = logging.getLogger(name)
logger.setLevel(level)
handler = logging.StreamHandler()
if fmt:
formatter = logging.Formatter(fmt=fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import copy
import math
import logging
import numpy as np
from .controller import EvolutionaryController
from .log_helper import get_logger
__all__ = ["SAController"]
_logger = get_logger(__name__, level=logging.INFO)
class SAController(EvolutionaryController):
"""Simulated annealing controller."""
def __init__(self,
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
max_iter_number=300,
init_tokens=None,
constrain_func=None):
"""Initialize.
Args:
            range_table(tuple): A tuple (min_values, max_values) of two lists that define
                the lower and upper bound of each token.
            reduce_rate(float): The decay rate of the temperature.
            init_temperature(float): The initial temperature.
            max_iter_number(int): The max number of attempts to generate tokens that meet the constraint.
            init_tokens(list<int>): The initial tokens.
            constrain_func(function): The callback function used to check whether the tokens meet the constraint. None means there is no constraint. Default: None.
"""
super(SAController, self).__init__()
self._range_table = range_table
assert isinstance(self._range_table, tuple) and (
len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_iter_number = max_iter_number
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
self._max_reward = -1
self._best_tokens = None
self._iter = 0
def __getstate__(self):
d = {}
for key in self.__dict__:
if key != "_constrain_func":
d[key] = self.__dict__[key]
return d
def update(self, tokens, reward):
"""
        Update the controller according to the latest tokens and reward.
        Args:
            tokens(list<int>): The tokens generated in the last step.
            reward(float): The reward of the tokens.
"""
self._iter += 1
temperature = self._init_temperature * self._reduce_rate**self._iter
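        # Metropolis acceptance: always accept a better reward; accept a worse
        # one with probability exp((reward - current) / temperature), which
        # shrinks as the temperature decays by reduce_rate each iteration.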
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
self._tokens = tokens
if reward > self._max_reward:
self._max_reward = reward
self._best_tokens = tokens
_logger.info(
"Controller - iter: {}; current_reward: {}; current tokens: {}".
format(self._iter, self._reward, self._tokens))
def next_tokens(self, control_token=None):
"""
Get next tokens.
"""
if control_token:
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
if self._constrain_func is None:
return new_tokens
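        # Re-sample a single dimension up to max_iter_number times until the
        # constraint is satisfied (or the retry budget is exhausted).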
for _ in range(self._max_iter_number):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(
self._range_table[0][index],
self._range_table[1][index] + 1)
else:
break
return new_tokens
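# A hedged usage sketch (the bounds and reward are illustrative): token i is
# sampled from [range_table[0][i], range_table[1][i]].
#
#     controller = SAController(range_table=([0, 0, 0], [5, 5, 5]),
#                               init_tokens=[0, 0, 0])
#     tokens = controller.next_tokens()
#     controller.update(tokens, reward=0.5)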
......@@ -14,6 +14,9 @@
from . import search_space
from .search_space import *
from . import sa_nas
from .sa_nas import *
__all__ = []
__all__ += search_space.__all__
__all__ += sa_nas.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import logging
import numpy as np
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
from .search_space import SearchSpaceFactory
__all__ = ["SANAS"]
_logger = get_logger(__name__, level=logging.INFO)
class SANAS(object):
def __init__(self,
configs,
max_flops=None,
max_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300,
key="sa_nas",
is_server=True):
"""
        Search for neural architectures with a simulated annealing strategy.
Args:
            configs(list<tuple>): A list of search space configurations with format (key, config).
                `key` is the name of the search space (str) and `config` is a dict with keys
                `input_size`, `output_size` and `block_num`: the input and output size of the
                searched sub-network and the number of blocks in it.
            max_flops(int): The max FLOPs of the searched network. None means no constraint. Default: None.
            max_latency(float): The max latency of the searched network. None means no constraint. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
            max_try_number(int): The max number of attempts to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True.
"""
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
        print(range_table)
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
        if not server_ip:
server_ip = self._get_host_ip()
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server
if self._is_server:
self._controller_server.start()
self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
self._iter = 0
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens):
if (self._max_flops is None) and (self._max_latency is None):
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
        return (self._max_flops is None) or (flops(main_program) < self._max_flops)
def next_archs(self):
"""
Get next network architectures.
Returns:
list<function>: A list of functions that define networks.
"""
self._current_tokens = self._controller_client.next_tokens()
archs = self._search_space.token2arch(self._current_tokens)
return archs
def reward(self, score):
"""
        Send the reward of the current searched network to the controller.
Args:
score(float): The score of current searched network.
"""
self._controller_client.update(self._current_tokens, score)
self._iter += 1
......@@ -52,7 +52,6 @@ class MobileNetV2Space(SearchSpaceBase):
self.scale = scale
self.class_dim = class_dim
def init_tokens(self):
"""
        The initial tokens sent to the controller.
......@@ -74,7 +73,8 @@ class MobileNetV2Space(SearchSpaceBase):
if self.block_num < 5:
self.token_len = 1 + (self.block_num - 1) * 4
else:
            self.token_len = 1 + (self.block_num + 2 *
                                  (self.block_num - 5)) * 4
return init_token_base[:self.token_len]
......@@ -92,6 +92,7 @@ class MobileNetV2Space(SearchSpaceBase):
5, 10, 6, 2,
5, 10, 6, 2,
5, 12, 6, 2]
range_table_base = list(np.array(range_table_base) - 1)
# yapf: enable
return range_table_base[:self.token_len]
......@@ -107,22 +108,34 @@ class MobileNetV2Space(SearchSpaceBase):
tokens = self.init_tokens()
bottleneck_params_list = []
if self.block_num >= 1:
bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2:
bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3:
bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4:
bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5:
bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num3[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6:
bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input):
......
......@@ -25,19 +25,25 @@ from .search_space_registry import SEARCHSPACE
__all__ = ["ResNetSpace"]
@SEARCHSPACE.register
class ResNetSpace(SearchSpaceBase):
def __init__(self,
input_size,
output_size,
block_num,
scale=1.0,
class_dim=1000):
super(ResNetSpace, self).__init__(input_size, output_size, block_num)
pass
def init_tokens(self):
return [0, 0, 0, 0, 0, 0]
def range_table(self):
return [2, 2, 2, 2, 2, 2]
def token2arch(self, tokens=None):
if tokens is None:
self.init_tokens()
......@@ -54,5 +60,3 @@ class ResNetSpace(SearchSpaceBase):
return input
return net_arch
......@@ -14,9 +14,9 @@
import numpy as np
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
__all__ = ["prune"]
__all__ = ["Pruner"]
class Pruner():
......
......@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .quanter import quant_aware, quant_post, convert
from .quant_embedding import quant_embedding
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core
WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
VALID_DTYPES = ['int8']
_quant_config_default = {
# weight quantize type, default is 'abs_max'
'weight_quantize_type': 'abs_max',
# activation quantize type, default is 'abs_max'
'activation_quantize_type': 'abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
    # data type after quantization; currently only 'int8' is supported. default is 'int8'
    'dtype': 'int8',
    # window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
    # if quant_weight_only is True, only the weights of the layers to be
    # quantized are quantized, and activations are left unquantized.
    'quant_weight_only': False
}
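# A minimal sketch of a user-supplied config; any key left out falls back to
# _quant_config_default via _parse_configs below.
#
#     user_config = {
#         'weight_quantize_type': 'channel_wise_abs_max',
#         'activation_quantize_type': 'moving_average_abs_max',
#         'not_quant_pattern': ['skip_quant'],
#     }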
def _parse_configs(user_config):
"""
    Check whether the user config is valid and fill in defaults for missing options.
    Args:
        user_config(dict): The config provided by the user.
    Return:
        configs(dict): The final config that will be used.
"""
configs = copy.deepcopy(_quant_config_default)
configs.update(user_config)
# check configs is valid
    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
        "Unknown weight_quantize_type: '%s'. It can only be " % configs['weight_quantize_type'] \
        + " ".join(WEIGHT_QUANTIZATION_TYPES)
    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
        "Unknown activation_quantize_type: '%s'. It can only be " % configs['activation_quantize_type'] \
        + " ".join(ACTIVATION_QUANTIZATION_TYPES)
assert isinstance(configs['weight_bits'], int), \
"weight_bits must be int value."
assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
"weight_bits should be between 1 and 16."
assert isinstance(configs['activation_bits'], int), \
"activation_bits must be int value."
assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
"activation_bits should be between 1 and 16."
assert isinstance(configs['not_quant_pattern'], list), \
"not_quant_pattern must be a list"
assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list"
assert isinstance(configs['dtype'], str), \
"dtype must be a str."
assert (configs['dtype'] in VALID_DTYPES), \
"dtype can only be " + " ".join(VALID_DTYPES)
assert isinstance(configs['window_size'], int), \
"window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."
assert isinstance(configs['moving_rate'], float), \
"moving_rate must be float value, The decay coefficient of moving average, default is 0.9."
assert isinstance(configs['quant_weight_only'], bool), \
"quant_weight_only must be bool value, if set quant_weight_only True, " \
"then only quantize parameters of layers which need to be quantized, " \
" and activations will not be quantized."
return configs
def quant_aware(program, place, config, scope=None, for_test=False):
"""
    Add trainable quantization ops to the program.
    Args:
        program(fluid.Program): The program to be quantized.
        place(fluid.CPUPlace or fluid.CUDAPlace): The device place.
        config(dict): Configs for quantization; default values are in the _quant_config_default dict.
        scope(fluid.Scope): The scope that stores the variables; it should be the scope of the program, usually fluid.global_scope(). Default: None.
        for_test(bool): Set True if the program is a test program. Default: False.
    Return:
        fluid.Program: The quantization program, which the user can finetune to recover accuracy.
"""
scope = fluid.global_scope() if not scope else scope
assert isinstance(config, dict), "config must be dict"
assert 'weight_quantize_type' in config.keys(
), 'weight_quantize_type must be configured'
assert 'activation_quantize_type' in config.keys(
), 'activation_quantize_type must be configured'
config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=config['quantize_op_types'],
skip_pattern=config['not_quant_pattern'])
transform_pass.apply(main_graph)
if for_test:
quant_program = main_graph.to_program()
else:
quant_program = fluid.CompiledProgram(main_graph.graph)
return quant_program
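# A hedged usage sketch (train_program and the place are illustrative); the
# returned program is finetuned like an ordinary fluid program:
#
#     place = fluid.CUDAPlace(0)
#     config = {'weight_quantize_type': 'abs_max',
#               'activation_quantize_type': 'moving_average_abs_max'}
#     train_quant_prog = quant_aware(train_program, place, config, for_test=False)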
def quant_post(program, place, config, scope=None):
"""
    Add quantization ops to the program; the returned program is not trainable.
    Args:
        program(fluid.Program): The program to be quantized.
        place(fluid.CPUPlace or fluid.CUDAPlace): The device place.
        config(dict): Configs for quantization; default values are in the _quant_config_default dict.
        scope(fluid.Scope): The scope that stores the variables; it should be the scope of the program, usually fluid.global_scope(). Default: None.
    Return:
        fluid.Program: The quantization program, which is not trainable.
"""
pass
def convert(program, scope, place, config, save_int8=False):
"""
    Convert a quantization-aware trained program into a frozen program for inference.
    Args:
        program(fluid.Program): The program to be converted.
        scope(fluid.Scope): The scope that stores the variables; when None, fluid.global_scope() is used.
        place(fluid.CPUPlace or fluid.CUDAPlace): The device place.
        config(dict): Configs for quantization; default values are in the _quant_config_default dict.
        save_int8(bool): Whether to additionally return an int8 frozen program. Default: False.
    Return:
        fluid.Program: The frozen program, which can be used for inference.
            Its parameters are float32, but their values lie in the int8 range.
        fluid.Program: The frozen int8 program, which can be used for inference.
            If save_int8 is False, this value is None.
"""
test_graph = IrGraph(core.Graph(program.desc), for_test=True)
# Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type=config['weight_quantize_type'])
freeze_pass.apply(test_graph)
freezed_program = test_graph.to_program()
if save_int8:
convert_int8_pass = ConvertToInt8Pass(
scope=fluid.global_scope(), place=place)
convert_int8_pass.apply(test_graph)
freezed_program_int8 = test_graph.to_program()
return freezed_program, freezed_program_int8
else:
return freezed_program
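# A hedged follow-on sketch: after quantization-aware training, freeze the
# test program for inference (config reuses the dict passed to quant_aware):
#
#     float_prog = convert(test_quant_prog, fluid.global_scope(), place, config)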
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SANAS
from paddleslim.nas import SearchSpaceFactory
from paddleslim.analysis import flops
class TestSANAS(unittest.TestCase):
def test_nas(self):
factory = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
space = factory.get_search_space([('MobileNetV2Space', config0)])
origin_arch = space.token2arch()[0]
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[None, 3, 224, 224], dtype="float32")
origin_arch(input)
base_flops = flops(main_program)
search_steps = 3
sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps)
for i in range(search_steps):
archs = sa_nas.next_archs()
main_program = fluid.Program()
s_program = fluid.Program()
with fluid.program_guard(main_program, s_program):
input = fluid.data(
name="input", shape=[None, 3, 224, 224], dtype="float32")
archs[0](input)
sa_nas.reward(1)
self.assertTrue(flops(main_program) < base_flops)
if __name__ == '__main__':
unittest.main()