提交 d5081f7d 编写于 作者: C ceci3

add save_load for sanas

上级 477d503b
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import math import math
import logging import logging
import numpy as np import numpy as np
import json
from .controller import EvolutionaryController from .controller import EvolutionaryController
from log_helper import get_logger from log_helper import get_logger
...@@ -34,7 +35,12 @@ class SAController(EvolutionaryController): ...@@ -34,7 +35,12 @@ class SAController(EvolutionaryController):
init_temperature=1024, init_temperature=1024,
max_try_times=None, max_try_times=None,
init_tokens=None, init_tokens=None,
constrain_func=None): reward=-1,
max_reward=-1,
iters=0,
best_tokens=None,
constrain_func=None,
checkpoints=None):
"""Initialize. """Initialize.
Args: Args:
range_table(list<int>): Range table. range_table(list<int>): Range table.
...@@ -43,6 +49,7 @@ class SAController(EvolutionaryController): ...@@ -43,6 +49,7 @@ class SAController(EvolutionaryController):
max_try_times(int): max try times before get legal tokens. max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens. init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
checkpoints(str): if checkpoint is None, donnot save checkpoints, else save scene to checkpoints file.
""" """
super(SAController, self).__init__() super(SAController, self).__init__()
self._range_table = range_table self._range_table = range_table
...@@ -51,12 +58,13 @@ class SAController(EvolutionaryController): ...@@ -51,12 +58,13 @@ class SAController(EvolutionaryController):
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_times = max_try_times self._max_try_times = max_try_times
self._reward = -1 self._reward = reward
self._tokens = init_tokens self._tokens = init_tokens
self._constrain_func = constrain_func self._constrain_func = constrain_func
self._max_reward = -1 self._max_reward = max_reward
self._best_tokens = None self._best_tokens = best_tokens
self._iter = 0 self._iter = iters
self._checkpoints = checkpoints
def __getstate__(self): def __getstate__(self):
d = {} d = {}
...@@ -84,8 +92,12 @@ class SAController(EvolutionaryController): ...@@ -84,8 +92,12 @@ class SAController(EvolutionaryController):
self._max_reward = reward self._max_reward = reward
self._best_tokens = tokens self._best_tokens = tokens
_logger.info( _logger.info(
"Controller - iter: {}; current_reward: {}; current tokens: {}". "Controller - iter: {}; best_reward: {}, best tokens: {}, current_reward: {}; current tokens: {}".
format(self._iter, self._reward, self._tokens)) format(self._iter, self._reward, self._tokens, reward, tokens))
if self._checkpoints != None:
self._save_checkpoint(self._checkpoints)
def next_tokens(self, control_token=None): def next_tokens(self, control_token=None):
""" """
...@@ -113,3 +125,14 @@ class SAController(EvolutionaryController): ...@@ -113,3 +125,14 @@ class SAController(EvolutionaryController):
else: else:
break break
return new_tokens return new_tokens
def _save_checkpoint(self, output_file):
scene = dict()
for key in self.__dict__():
if key in ['_checkpoints']:
continue
scene[key] = self.__dict__[key]
f = open(output_file, 'w')
json.dump(scene)
f.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册