提交 1dca49fb 编写于 作者: W wanghaoshuang

Merge branch 'auto_prune' into 'develop'

Add auto pruner based sa search strategy.

See merge request !15
......@@ -11,4 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pruner import Pruner
import pruner
from pruner import *
import auto_pruner
from auto_pruner import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
__all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
__all__ = ["AutoPruner"]
_logger = get_logger(__name__, level=logging.INFO)
class AutoPruner(object):
def __init__(self,
program,
scope,
place,
params=[],
init_ratios=None,
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner",
is_server=True):
"""
Search a group of ratios used to prune program.
Args:
program(Program): The program to be pruned.
scope(Scope): The scope to be pruned.
place(fluid.Place): The device place of parameters.
params(list<str>): The names of parameters to be pruned.
init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
List means ratios used for pruning each parameter in `params`.
The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
pruned_latency(float): The percent of latency to be pruned. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`. List means min ratios for each parameter in `params`.
The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
If it is a scalar, it will used for all the parameters in `params`.
key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True.
"""
self._program = program
self._scope = scope
self._place = place
self._params = params
self._init_ratios = init_ratios
self._pruned_flops = pruned_flops
self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios)
self._pruner = Pruner()
if self._pruned_flops:
self._base_flops = flops(program)
_logger.info("AutoPruner - base flops: {};".format(
self._base_flops))
if self._pruned_latency:
self._base_latency = latency(program)
if self._init_ratios is None:
self._init_ratios = self._get_init_ratios(
self, _program, self._params, self._pruned_flops,
self._pruned_latency)
init_tokens = self._ratios2tokens(self._init_ratios)
controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server
if self._is_server:
self._controller_server.start()
self._controller_client = ControllerClient(
self._controller_server.ip(),
self._controller_server.port(),
key=key)
self._iter = 0
self._param_backup = {}
def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname())
def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
pass
def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
min_ratios = min_ratios if isinstance(min_ratios,
list) else [min_ratios]
max_ratios = max_ratios if isinstance(max_ratios,
list) else [max_ratios]
min_tokens = self._ratios2tokens(min_ratios)
max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens)
def _constrain_func(self, tokens):
ratios = self._tokens2ratios(tokens)
pruned_program = self._pruner.prune(
self._program,
self._scope,
self._params,
ratios,
place=self._place,
only_graph=True)
return flops(pruned_program) < self._base_flops * (
1 - self._pruned_flops)
def prune(self, program):
"""
Prune program with latest tokens generated by controller.
Args:
program(fluid.Program): The program to be pruned.
Returns:
Program: The pruned program.
"""
self._current_ratios = self._next_ratios()
pruned_program = self._pruner.prune(
program,
self._scope,
self._params,
self._current_ratios,
place=self._place,
param_backup=self._param_backup)
_logger.info("AutoPruner - pruned ratios: {}".format(
self._current_ratios))
return pruned_program
def reward(self, score):
"""
Return reward of current pruned program.
Args:
score(float): The score of pruned program.
"""
self._restore(self._scope)
self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score)
self._iter += 1
def _restore(self, scope):
for param_name in self._param_backup.keys():
param_t = scope.find_var(param_name).get_tensor()
param_t.set(self._param_backup[param_name], self._place)
def _next_ratios(self):
tokens = self._controller_client.next_tokens()
return self._tokens2ratios(tokens)
def _ratios2tokens(self, ratios):
"""Convert pruned ratios to tokens.
"""
return [int(ratio / 0.01) for ratio in ratios]
def _tokens2ratios(self, tokens):
"""Convert tokens to pruned ratios.
"""
return [token * 0.01 for token in tokens]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from ..common import get_logger
__all__ = ['ControllerClient']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerClient(object):
"""
Controller client.
"""
def __init__(self, server_ip=None, server_port=None, key=None):
"""
Args:
server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
key(str): The key used to identify legal agent for controller server. Default: "light-nas"
"""
self.server_ip = server_ip
self.server_port = server_port
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
def next_tokens(self):
"""
Get next tokens.
"""
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
socket_client.send("next_tokens".encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import socket
from ..common import get_logger
from threading import Thread
from .lock import lock, unlock
__all__ = ['ControllerServer']
_logger = get_logger(__name__, level=logging.INFO)
class ControllerServer(object):
"""
The controller wrapper with a socket server to handle the request of search agent.
"""
def __init__(self,
controller=None,
address=('', 0),
max_client_num=100,
search_steps=None,
key=None):
"""
Args:
controller(slim.searcher.Controller): The controller used to generate tokens.
address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
which means setting ip automatically
max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
search_steps(int): The total steps of searching. None means never stopping. Default: None
"""
self._controller = controller
self._address = address
self._max_client_num = max_client_num
self._search_steps = search_steps
self._closed = False
self._port = address[1]
self._ip = address[0]
self._key = key
self._socket_file = "./controller_server.socket"
def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
self._port = self._socket_server.getsockname()[1]
self._ip = self._socket_server.getsockname()[0]
_logger.info("ControllerServer - listen on: [{}:{}]".format(
self._ip, self._port))
thread = Thread(target=self.run)
thread.start()
return str(thread)
def close(self):
"""Close the server."""
self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
"""Get the port."""
return self._port
def ip(self):
"""Get the ip."""
return self._ip
def run(self):
_logger.info("Controller Server run...")
try:
while ((self._search_steps is None) or
(self._controller._iter <
(self._search_steps))) and not self._closed:
conn, addr = self._socket_server.accept()
message = conn.recv(1024).decode()
if message.strip("\n") == "next_tokens":
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
else:
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
finally:
self._socket_server.close()
self.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
__All__ = ['lock', 'unlock']
if os.name == 'nt':
def lock(file):
raise NotImplementedError('Windows is not supported.')
def unlock(file):
raise NotImplementedError('Windows is not supported.')
elif os.name == 'posix':
from fcntl import flock, LOCK_EX, LOCK_UN
def lock(file):
"""Lock the file in local file system."""
flock(file.fileno(), LOCK_EX)
def unlock(file):
"""Unlock the file in local file system."""
flock(file.fileno(), LOCK_UN)
else:
raise RuntimeError("File Locker only support NT and Posix platforms!")
......@@ -14,9 +14,10 @@
import numpy as np
import paddle.fluid as fluid
from core import VarWrapper, OpWrapper, GraphWrapper
import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper
__all__ = ["prune"]
__all__ = ["Pruner"]
class Pruner():
......@@ -64,10 +65,10 @@ class Pruner():
params,
ratios,
place,
lazy=False,
only_graph=False,
param_backup=None,
param_shape_backup=None)
lazy=lazy,
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
return graph.program
def _prune_filters_by_ratio(self,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append('..')
import unittest
import paddle.fluid as fluid
from nas.search_space_factory import SearchSpaceFactory
class TestSearchSpace(unittest.TestCase):
def test_searchspace(self):
# if output_size is 1, the model will add fc layer in the end.
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
my_space = space.get_search_space('MobileNetV2Space', config)
model_arch = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size= config['input_size']
model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False)
predict = model_arch(model_input)
self.assertTrue(predict.shape[2] == config['output_size'])
#for op in train_prog.global_block().ops:
# print(op.type)
if __name__ == '__main__':
unittest.main()
......@@ -39,6 +39,8 @@ packages = [
'paddleslim.nas',
'paddleslim.analysis',
'paddleslim.quant',
'paddleslim.core',
'paddleslim.common',
]
setup(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune import AutoPruner
from paddleslim.analysis import flops
from layers import conv_bn_layer
class TestPrune(unittest.TestCase):
def test_prune(self):
main_program = fluid.Program()
startup_program = fluid.Program()
# X X O X O
# conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
# | ^ | ^
# |____________| |____________________|
#
# X: prune output channels
# O: prune input channels
with fluid.program_guard(main_program, startup_program):
input = fluid.data(name="image", shape=[None, 3, 16, 16])
conv1 = conv_bn_layer(input, 8, 3, "conv1")
conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
sum1 = conv1 + conv2
conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
sum2 = conv4 + sum1
conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
shapes = {}
for param in main_program.global_block().all_parameters():
shapes[param.name] = param.shape
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
pruned_flops = 0.5
pruner = AutoPruner(
main_program,
scope,
place,
params=["conv4_weights"],
init_ratios=[0.5],
pruned_flops=0.5,
pruned_latency=None,
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=2,
max_ratios=[0.9],
min_ratios=[0],
key="auto_pruner")
base_flops = flops(main_program)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
program = pruner.prune(main_program)
self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
pruner.reward(1)
if __name__ == '__main__':
unittest.main()
......@@ -11,15 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
sys.path.append('..')
import unittest
import paddle.fluid as fluid
from paddleslim.nas import SearchSpaceFactory
from nas.search_space_factory import SearchSpaceFactory
class TestSearchSpaceFactory(unittest.TestCase):
def test_factory(self):
class TestSearchSpace(unittest.TestCase):
def test_searchspace(self):
# if output_size is 1, the model will add fc layer in the end.
config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
space = SearchSpaceFactory()
......@@ -39,23 +40,30 @@ class TestSearchSpaceFactory(unittest.TestCase):
predict = model_arch[0](model_input)
self.assertTrue(predict.shape[2] == config['output_size'])
class TestMultiSearchSpace(unittest.TestCase):
space = SearchSpaceFactory()
config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
my_space = space.get_search_space([('MobileNetV2Space', config0), ('ResNetSpace', config1)])
my_space = space.get_search_space(
[('MobileNetV2Space', config0), ('ResNetSpace', config1)])
model_archs = my_space.token2arch()
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
input_size= config0['input_size']
model_input = fluid.layers.data(name='model_in', shape=[1, 3, input_size, input_size], dtype='float32', append_batch_size=False)
input_size = config0['input_size']
model_input = fluid.layers.data(
name='model_in',
shape=[1, 3, input_size, input_size],
dtype='float32',
append_batch_size=False)
for model_arch in model_archs:
predict = model_arch(model_input)
model_input = predict
print(predict)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册