diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index 4f1771dfc89564ec0c3f1253c761af863d09c840..9596a0552a5f268846d5def1c7d9c2b09491d7a6 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -13,10 +13,12 @@ # limitations under the License. """The controller used to search hyperparameters or neural architecture""" +import os import copy import math import logging import numpy as np +import json from .controller import EvolutionaryController from log_helper import get_logger @@ -34,15 +36,25 @@ class SAController(EvolutionaryController): init_temperature=1024, max_try_times=None, init_tokens=None, - constrain_func=None): + reward=-1, + max_reward=-1, + iters=0, + best_tokens=None, + constrain_func=None, + checkpoints=None): """Initialize. Args: range_table(list): Range table. reduce_rate(float): The decay rate of temperature. init_temperature(float): Init temperature. max_try_times(int): max try times before get legal tokens. - init_tokens(list): The initial tokens. + init_tokens(list): The initial tokens. Default: None. + reward(float): The reward of current tokens. Default: -1. + max_reward(float): The max reward in the search of sanas, in general, best tokens get max reward. Default: -1. + iters(int): The iteration of sa controller. Default: 0. + best_tokens(list): The best tokens in the search of sanas, in general, best tokens get max reward. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. + checkpoints(str): If checkpoints is None, do not save checkpoints; otherwise save the scene to the 'sanas.checkpoints' file under this directory. Default: None. 
""" super(SAController, self).__init__() self._range_table = range_table @@ -51,12 +63,13 @@ class SAController(EvolutionaryController): self._reduce_rate = reduce_rate self._init_temperature = init_temperature self._max_try_times = max_try_times - self._reward = -1 + self._reward = reward self._tokens = init_tokens self._constrain_func = constrain_func - self._max_reward = -1 - self._best_tokens = None - self._iter = 0 + self._max_reward = max_reward + self._best_tokens = best_tokens + self._iter = iters + self._checkpoints = checkpoints def __getstate__(self): d = {} @@ -84,8 +97,11 @@ class SAController(EvolutionaryController): self._max_reward = reward self._best_tokens = tokens _logger.info( - "Controller - iter: {}; current_reward: {}; current tokens: {}". - format(self._iter, self._reward, self._tokens)) + "Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}". + format(self._iter, self._max_reward, self._best_tokens, reward, tokens)) + + if self._checkpoints is not None: + self._save_checkpoint(self._checkpoints) def next_tokens(self, control_token=None): """ @@ -108,8 +124,19 @@ class SAController(EvolutionaryController): index = int(len(self._range_table[0]) * np.random.random()) new_tokens = tokens[:] new_tokens[index] = np.random.randint( - self._range_table[0][index], - self._range_table[1][index]) + self._range_table[0][index], self._range_table[1][index]) else: break return new_tokens + + def _save_checkpoint(self, output_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + file_path = os.path.join(output_dir, 'sanas.checkpoints') + scene = dict() + for key in self.__dict__: + if key in ['_checkpoints', '_constrain_func']: + continue + scene[key] = self.__dict__[key] + with open(file_path, 'w') as f: + json.dump(scene, f) diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 1563e141b098fb59e6b2f5abf6c4b5f8cd92d275..90e65d42aae47d0fd7d3c15955c3d2f7bd0693a1 100644 --- a/paddleslim/nas/sa_nas.py +++ 
b/paddleslim/nas/sa_nas.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import socket import logging import numpy as np +import json import hashlib import paddle.fluid as fluid from ..core import VarWrapper, OpWrapper, GraphWrapper @@ -39,6 +41,8 @@ class SANAS(object): reduce_rate=0.85, search_steps=300, key="sa_nas", + save_checkpoint='nas_checkpoint', + load_checkpoint=None, is_server=False): """ Search a group of ratios used to prune program. @@ -51,6 +55,8 @@ class SANAS(object): reduce_rate(float): The decay rate used in simulated annealing search strategy. search_steps(int): The steps of searching. key(str): Identity used in communication between controller server and clients. + save_checkpoint(string|None): The directory in which checkpoints are saved; if set to None, no checkpoint is saved. Default: 'nas_checkpoint'. + load_checkpoint(string|None): The directory from which a checkpoint is loaded; if set to None, no checkpoint is loaded. Default: None. is_server(bool): Whether current host is controller server. Default: True. """ if not is_server: @@ -75,13 +81,39 @@ class SANAS(object): range_table = self._search_space.range_table() range_table = (len(range_table) * [0], range_table) _logger.info("range table: {}".format(range_table)) + + if load_checkpoint is not None: + assert os.path.exists( + load_checkpoint + ), 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!' 
+ checkpoint_path = os.path.join(load_checkpoint, + 'sanas.checkpoints') + with open(checkpoint_path, 'r') as f: + scene = json.load(f) + preinit_tokens = scene['_tokens'] + prereward = scene['_reward'] + premax_reward = scene['_max_reward'] + prebest_tokens = scene['_best_tokens'] + preiter = scene['_iter'] + else: + preinit_tokens = init_tokens + prereward = -1 + premax_reward = -1 + prebest_tokens = None + preiter = 0 + controller = SAController( range_table, self._reduce_rate, self._init_temperature, max_try_times=None, - init_tokens=init_tokens, - constrain_func=None) + init_tokens=preinit_tokens, + reward=prereward, + max_reward=premax_reward, + iters=preiter, + best_tokens=prebest_tokens, + constrain_func=None, + checkpoints=save_checkpoint) max_client_num = 100 self._controller_server = ControllerServer( @@ -96,7 +128,10 @@ class SANAS(object): self._controller_client = ControllerClient( server_ip, server_port, key=self._key) - self._iter = 0 + if is_server and load_checkpoint is not None: + self._iter = scene['_iter'] + else: + self._iter = 0 def _get_host_ip(self): return socket.gethostbyname(socket.gethostname())