提交 915ef9d5 编写于 作者: C ceci3

update save load

上级 d5081f7d
......@@ -13,6 +13,7 @@
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import os
import copy
import math
import logging
......@@ -127,12 +128,15 @@ class SAController(EvolutionaryController):
return new_tokens
def _save_checkpoint(self, output_file):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
file_path = os.path.join(output_dir, 'sanas.checkpoints')
scene = dict()
for key in self.__dict__():
if key in ['_checkpoints']:
continue
scene[key] = self.__dict__[key]
f = open(output_file, 'w')
f = open(file_path, 'w')
json.dump(scene)
f.close()
......@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import logging
import numpy as np
import json
import hashlib
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
......@@ -39,6 +41,8 @@ class SANAS(object):
reduce_rate=0.85,
search_steps=300,
key="sa_nas",
save_checkpoint=None,
load_checkpoint=None,
is_server=False):
"""
Search a group of ratios used to prune program.
......@@ -75,13 +79,35 @@ class SANAS(object):
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
if load_checkpoint != None:
assert os.path.exists(load_checkpoint) == True, 'load checkpoint file NOT EXIST!!! Please check the directory of checkpoint!!!'
checkpoint_path = os.path.join(load_checkpoint, 'sanas.checkpoints')
scene = json.load(checkpoint_path)
preinit_tokens = scene['_init_tokens']
prereward = scene['_reward']
premax_reward = scene['_max_reward']
prebest_tokens = scene['_best_tokens']
preiter = scene['_iter']
else:
preinit_tokens = None
prereward = -1
premax_reward = -1
prebest_tokens = init_tokens
preiter = 0
controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
init_tokens=preinit_tokens,
reward = prereward,
max_reward = premax_reward,
iters = preiter,
best_tokens = prebest_tokens,
constrain_func=None,
checkpoints=save_checkpoint)
max_client_num = 100
self._controller_server = ControllerServer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册