提交 595ffae0 编写于 作者: C ceci3

fix bug

上级 33653657
...@@ -99,7 +99,6 @@ class SAController(EvolutionaryController): ...@@ -99,7 +99,6 @@ class SAController(EvolutionaryController):
if self._checkpoints != None: if self._checkpoints != None:
self._save_checkpoint(self._checkpoints) self._save_checkpoint(self._checkpoints)
def next_tokens(self, control_token=None): def next_tokens(self, control_token=None):
""" """
Get next tokens. Get next tokens.
...@@ -121,22 +120,19 @@ class SAController(EvolutionaryController): ...@@ -121,22 +120,19 @@ class SAController(EvolutionaryController):
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:] new_tokens = tokens[:]
new_tokens[index] = np.random.randint( new_tokens[index] = np.random.randint(
self._range_table[0][index], self._range_table[0][index], self._range_table[1][index])
self._range_table[1][index])
else: else:
break break
return new_tokens return new_tokens
def _save_checkpoint(self, output_file): def _save_checkpoint(self, output_dir):
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
file_path = os.path.join(output_dir, 'sanas.checkpoints') file_path = os.path.join(output_dir, 'sanas.checkpoints')
scene = dict() scene = dict()
for key in self.__dict__(): for key in self.__dict__:
if key in ['_checkpoints']: if key in ['_checkpoints']:
continue continue
scene[key] = self.__dict__[key] scene[key] = self.__dict__[key]
f = open(file_path, 'w') with open(file_path, 'w') as f:
json.dump(scene) json.dump(scene, f)
f.close()
...@@ -81,10 +81,14 @@ class SANAS(object): ...@@ -81,10 +81,14 @@ class SANAS(object):
_logger.info("range table: {}".format(range_table)) _logger.info("range table: {}".format(range_table))
if load_checkpoint != None: if load_checkpoint != None:
assert os.path.exists(load_checkpoint) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!' assert os.path.exists(
checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints') load_checkpoint
scene = json.load(checkpoint_path) ) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
preinit_tokens = scene['_init_tokens'] checkpoint_path = os.path.join(load_checkpoint,
'sanas.checkpoints')
with open(checkpoint_path, 'r') as f:
scene = json.load(f)
preinit_tokens = scene['_tokens']
prereward = scene['_reward'] prereward = scene['_reward']
premax_reward = scene['_max_reward'] premax_reward = scene['_max_reward']
prebest_tokens = scene['_best_tokens'] prebest_tokens = scene['_best_tokens']
...@@ -102,10 +106,10 @@ class SANAS(object): ...@@ -102,10 +106,10 @@ class SANAS(object):
self._init_temperature, self._init_temperature,
max_try_times=None, max_try_times=None,
init_tokens=preinit_tokens, init_tokens=preinit_tokens,
reward = prereward, reward=prereward,
max_reward = premax_reward, max_reward=premax_reward,
iters = preiter, iters=preiter,
best_tokens = prebest_tokens, best_tokens=prebest_tokens,
constrain_func=None, constrain_func=None,
checkpoints=save_checkpoint) checkpoints=save_checkpoint)
...@@ -123,8 +127,6 @@ class SANAS(object): ...@@ -123,8 +127,6 @@ class SANAS(object):
server_ip, server_port, key=self._key) server_ip, server_port, key=self._key)
if is_server and load_checkpoint != None: if is_server and load_checkpoint != None:
checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
scene = json.load(checkpoint_path)
self._iter = scene['_iter'] self._iter = scene['_iter']
else: else:
self._iter = 0 self._iter = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册