diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index 29ed2d1a23919d57e5e2daa6531846180bccdee9..4a5b39b5b8d11cce276e64e47ca209d6e2459c89 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -99,7 +99,6 @@ class SAController(EvolutionaryController): if self._checkpoints != None: self._save_checkpoint(self._checkpoints) - def next_tokens(self, control_token=None): """ Get next tokens. @@ -121,22 +120,19 @@ class SAController(EvolutionaryController): index = int(len(self._range_table[0]) * np.random.random()) new_tokens = tokens[:] new_tokens[index] = np.random.randint( - self._range_table[0][index], - self._range_table[1][index]) + self._range_table[0][index], self._range_table[1][index]) else: break return new_tokens - def _save_checkpoint(self, output_file): + def _save_checkpoint(self, output_dir): if not os.path.exists(output_dir): os.makedirs(output_dir) file_path = os.path.join(output_dir, 'sanas.checkpoints') scene = dict() - for key in self.__dict__(): + for key in self.__dict__: if key in ['_checkpoints']: continue scene[key] = self.__dict__[key] - f = open(file_path, 'w') - json.dump(scene) - f.close() - + with open(file_path, 'w') as f: + json.dump(scene, f) diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 6a3fb80fe0c95298840d9909cbc67ef3f425fe66..b151195420bd812b3cf36e4bf2e28de652045d8c 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -81,10 +81,14 @@ class SANAS(object): _logger.info("range table: {}".format(range_table)) if load_checkpoint != None: - assert os.path.exists(load_checkpoint) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!' - checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints') - scene = json.load(checkpoint_path) - preinit_tokens = scene['_init_tokens'] + assert os.path.exists( + load_checkpoint + ) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!' + checkpoint_path = os.path.join(load_checkpoint, + 'sanas.checkpoints') + with open(checkpoint_path, 'r') as f: + scene = json.load(f) + preinit_tokens = scene['_tokens'] prereward = scene['_reward'] premax_reward = scene['_max_reward'] prebest_tokens = scene['_best_tokens'] @@ -95,17 +99,17 @@ class SANAS(object): premax_reward = -1 prebest_tokens = None preiter = 0 - + controller = SAController( range_table, self._reduce_rate, self._init_temperature, max_try_times=None, init_tokens=preinit_tokens, - reward = prereward, - max_reward = premax_reward, - iters = preiter, - best_tokens = prebest_tokens, + reward=prereward, + max_reward=premax_reward, + iters=preiter, + best_tokens=prebest_tokens, constrain_func=None, checkpoints=save_checkpoint) @@ -123,8 +127,6 @@ class SANAS(object): server_ip, server_port, key=self._key) if is_server and load_checkpoint != None: - checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints') - scene = json.load(checkpoint_path) self._iter = scene['_iter'] else: self._iter = 0