提交 595ffae0 编写于 作者: C ceci3

fix bug

上级 33653657
......@@ -99,7 +99,6 @@ class SAController(EvolutionaryController):
if self._checkpoints != None:
self._save_checkpoint(self._checkpoints)
def next_tokens(self, control_token=None):
"""
Get next tokens.
......@@ -121,22 +120,19 @@ class SAController(EvolutionaryController):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
new_tokens[index] = np.random.randint(
self._range_table[0][index],
self._range_table[1][index])
self._range_table[0][index], self._range_table[1][index])
else:
break
return new_tokens
def _save_checkpoint(self, output_file):
def _save_checkpoint(self, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
file_path = os.path.join(output_dir, 'sanas.checkpoints')
scene = dict()
for key in self.__dict__():
for key in self.__dict__:
if key in ['_checkpoints']:
continue
scene[key] = self.__dict__[key]
f = open(file_path, 'w')
json.dump(scene)
f.close()
with open(file_path, 'w') as f:
json.dump(scene, f)
......@@ -81,10 +81,14 @@ class SANAS(object):
_logger.info("range table: {}".format(range_table))
if load_checkpoint != None:
assert os.path.exists(load_checkpoint) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
scene = json.load(checkpoint_path)
preinit_tokens = scene['_init_tokens']
assert os.path.exists(
load_checkpoint
) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
checkpoint_path = os.path.join(load_checkpoint,
'sanas.checkpoints')
with open(checkpoint_path, 'r') as f:
scene = json.load(f)
preinit_tokens = scene['_tokens']
prereward = scene['_reward']
premax_reward = scene['_max_reward']
prebest_tokens = scene['_best_tokens']
......@@ -95,17 +99,17 @@ class SANAS(object):
premax_reward = -1
prebest_tokens = None
preiter = 0
controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=preinit_tokens,
reward = prereward,
max_reward = premax_reward,
iters = preiter,
best_tokens = prebest_tokens,
reward=prereward,
max_reward=premax_reward,
iters=preiter,
best_tokens=prebest_tokens,
constrain_func=None,
checkpoints=save_checkpoint)
......@@ -123,8 +127,6 @@ class SANAS(object):
server_ip, server_port, key=self._key)
if is_server and load_checkpoint != None:
checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
scene = json.load(checkpoint_path)
self._iter = scene['_iter']
else:
self._iter = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册