From d95da67bad0a6ca7bb96bb6b908c5090838f662b Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 17 Jan 2020 12:57:01 +0800 Subject: [PATCH] Remove redundant code in paddleslim/prune (#44) --- paddleslim/common/__init__.py | 2 +- paddleslim/common/controller_server.py | 2 +- paddleslim/{prune => common}/lock.py | 0 paddleslim/common/lock_utils.py | 38 -------- paddleslim/prune/__init__.py | 6 -- paddleslim/prune/auto_pruner.py | 6 +- paddleslim/prune/controller_client.py | 66 ------------- paddleslim/prune/controller_server.py | 128 ------------------------- 8 files changed, 5 insertions(+), 243 deletions(-) rename paddleslim/{prune => common}/lock.py (100%) delete mode 100644 paddleslim/common/lock_utils.py delete mode 100644 paddleslim/prune/controller_client.py delete mode 100644 paddleslim/prune/controller_server.py diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index ccf9b76d..e146e004 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -16,7 +16,7 @@ from .sa_controller import SAController from .log_helper import get_logger from .controller_server import ControllerServer from .controller_client import ControllerClient -from .lock_utils import lock, unlock +from .lock import lock, unlock from .cached_reader import cached_reader __all__ = [ diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 3331d6ae..97328200 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -18,7 +18,7 @@ import socket import time from .log_helper import get_logger from threading import Thread -from .lock_utils import lock, unlock +from .lock import lock, unlock __all__ = ['ControllerServer'] diff --git a/paddleslim/prune/lock.py b/paddleslim/common/lock.py similarity index 100% rename from paddleslim/prune/lock.py rename to paddleslim/common/lock.py diff --git a/paddleslim/common/lock_utils.py b/paddleslim/common/lock_utils.py deleted file mode 100644 index 9daf4f3f..00000000 --- a/paddleslim/common/lock_utils.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -__all__ = ['lock', 'unlock'] - -if os.name == 'nt': - - def lock(file): - raise NotImplementedError('Windows is not supported.') - - def unlock(file): - raise NotImplementedError('Windows is not supported.') - -elif os.name == 'posix': - from fcntl import flock, LOCK_EX, LOCK_UN - - def lock(file): - """Lock the file in local file system.""" - flock(file.fileno(), LOCK_EX) - - def unlock(file): - """Unlock the file in local file system.""" - flock(file.fileno(), LOCK_UN) -else: - raise RuntimeError("File Locker only support NT and Posix platforms!") diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 1ef3cbc7..6ab266c4 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -15,10 +15,6 @@ from .pruner import * import pruner from .auto_pruner import * import auto_pruner -from .controller_server import * -import controller_server -from .controller_client import * -import controller_client from .sensitive_pruner import * import sensitive_pruner from .sensitive import * @@ -32,8 +28,6 @@ __all__ = [] __all__ += pruner.__all__ __all__ += auto_pruner.__all__ -__all__ += controller_server.__all__ -__all__ += controller_client.__all__ __all__ += sensitive_pruner.__all__ __all__ += sensitive.__all__ __all__ += prune_walker.__all__ diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py index 575d93c5..672ce78d 100644 --- a/paddleslim/prune/auto_pruner.py +++ b/paddleslim/prune/auto_pruner.py @@ -157,7 +157,7 @@ class AutoPruner(object): def _constrain_func(self, tokens): ratios = self._tokens2ratios(tokens) - pruned_program = self._pruner.prune( + pruned_program, _, _ = self._pruner.prune( self._program, self._scope, self._params, @@ -183,7 +183,7 @@ class AutoPruner(object): Program: The pruned program. """ self._current_ratios = self._next_ratios() - pruned_program = self._pruner.prune( + pruned_program, _, _ = self._pruner.prune( program, self._scope, self._params, @@ -193,7 +193,7 @@ class AutoPruner(object): param_backup=self._param_backup) pruned_val_program = None if eval_program is not None: - pruned_val_program = self._pruner.prune( + pruned_val_program, _, _ = self._pruner.prune( program, self._scope, self._params, diff --git a/paddleslim/prune/controller_client.py b/paddleslim/prune/controller_client.py deleted file mode 100644 index f133e8b2..00000000 --- a/paddleslim/prune/controller_client.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import socket -from ..common import get_logger - -__all__ = ['ControllerClient'] - -_logger = get_logger(__name__, level=logging.INFO) - - -class ControllerClient(object): - """ - Controller client. - """ - - def __init__(self, server_ip=None, server_port=None, key=None): - """ - Args: - server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. - server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. - key(str): The key used to identify legal agent for controller server. Default: "light-nas" - """ - self.server_ip = server_ip - self.server_port = server_port - self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._key = key - - def update(self, tokens, reward): - """ - Update the controller according to latest tokens and reward. - Args: - tokens(list): The tokens generated in last step. - reward(float): The reward of tokens. - """ - socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_client.connect((self.server_ip, self.server_port)) - tokens = ",".join([str(token) for token in tokens]) - socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) - .encode()) - tokens = socket_client.recv(1024).decode() - tokens = [int(token) for token in tokens.strip("\n").split(",")] - return tokens - - def next_tokens(self): - """ - Get next tokens. - """ - socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - socket_client.connect((self.server_ip, self.server_port)) - socket_client.send("next_tokens".encode()) - tokens = socket_client.recv(1024).decode() - tokens = [int(token) for token in tokens.strip("\n").split(",")] - return tokens diff --git a/paddleslim/prune/controller_server.py b/paddleslim/prune/controller_server.py deleted file mode 100644 index 5fc97844..00000000 --- a/paddleslim/prune/controller_server.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import logging -import socket -from ..common import get_logger -from threading import Thread -from .lock import lock, unlock - -__all__ = ['ControllerServer'] - -_logger = get_logger(__name__, level=logging.INFO) - - -class ControllerServer(object): - """ - The controller wrapper with a socket server to handle the request of search agent. - """ - - def __init__(self, - controller=None, - address=('', 0), - max_client_num=100, - search_steps=None, - key=None): - """ - Args: - controller(slim.searcher.Controller): The controller used to generate tokens. - address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). - which means setting ip automatically - max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. - search_steps(int): The total steps of searching. None means never stopping. Default: None - """ - self._controller = controller - self._address = address - self._max_client_num = max_client_num - self._search_steps = search_steps - self._closed = False - self._port = address[1] - self._ip = address[0] - self._key = key - self._socket_file = "./controller_server.socket" - - def start(self): - open(self._socket_file, 'a').close() - socket_file = open(self._socket_file, 'r+') - lock(socket_file) - tid = socket_file.readline() - if tid == '': - _logger.info("start controller server...") - tid = self._start() - socket_file.write("tid: {}\nip: {}\nport: {}\n".format( - tid, self._ip, self._port)) - _logger.info("started controller server...") - unlock(socket_file) - socket_file.close() - - def _start(self): - self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket_server.bind(self._address) - self._socket_server.listen(self._max_client_num) - self._port = self._socket_server.getsockname()[1] - self._ip = self._socket_server.getsockname()[0] - _logger.info("ControllerServer - listen on: [{}:{}]".format( - self._ip, self._port)) - thread = Thread(target=self.run) - thread.start() - return str(thread) - - def close(self): - """Close the server.""" - self._closed = True - os.remove(self._socket_file) - _logger.info("server closed!") - - def port(self): - """Get the port.""" - return self._port - - def ip(self): - """Get the ip.""" - return self._ip - - def run(self): - _logger.info("Controller Server run...") - try: - while ((self._search_steps is None) or - (self._controller._iter < - (self._search_steps))) and not self._closed: - conn, addr = self._socket_server.accept() - message = conn.recv(1024).decode() - if message.strip("\n") == "next_tokens": - tokens = self._controller.next_tokens() - tokens = ",".join([str(token) for token in tokens]) - conn.send(tokens.encode()) - else: - _logger.debug("recv message from {}: [{}]".format(addr, - message)) - messages = message.strip('\n').split("\t") - if (len(messages) < 3) or (messages[0] != self._key): - _logger.debug("recv noise from {}: [{}]".format( - addr, message)) - continue - tokens = messages[1] - reward = messages[2] - tokens = [int(token) for token in tokens.split(",")] - self._controller.update(tokens, float(reward)) - tokens = self._controller.next_tokens() - tokens = ",".join([str(token) for token in tokens]) - conn.send(tokens.encode()) - _logger.debug("send message to {}: [{}]".format(addr, - tokens)) - conn.close() - finally: - self._socket_server.close() - self.close() -- GitLab