提交 5234589f 编写于 作者: T tangwei12

add save stage

上级 405ec7ed
...@@ -31,17 +31,13 @@ logger = logging.getLogger("fluid") ...@@ -31,17 +31,13 @@ logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def need_save(epoch_id, epoch_interval, is_last=False):
if is_last:
return True
return epoch_id % epoch_interval == 0
class SingleTrainer(Trainer): class SingleTrainer(Trainer):
def __init__(self, config=None): def __init__(self, config=None):
Trainer.__init__(self, config) Trainer.__init__(self, config)
self.inference_models = []
self.increment_models = []
self.exe = fluid.Executor(fluid.CPUPlace()) self.exe = fluid.Executor(fluid.CPUPlace())
self.regist_context_processor('uninit', self.instance) self.regist_context_processor('uninit', self.instance)
...@@ -116,44 +112,48 @@ class SingleTrainerWithDataset(SingleTrainer): ...@@ -116,44 +112,48 @@ class SingleTrainerWithDataset(SingleTrainer):
dataset.set_filelist(file_list) dataset.set_filelist(file_list)
return dataset return dataset
def save(self, epoch_id): def save(self, epoch_id, namespace):
def need_save(epoch_id, epoch_interval, is_last=False):
if is_last:
return True
if epoch_id == -1:
return False
return epoch_id % epoch_interval == 0
def save_inference_model(): def save_inference_model():
is_save_inference = envs.get_global_env("save.inference", False) save_interval = envs.get_global_env("save.inference.epoch_interval", -1, namespace)
if not is_save_inference:
return
save_interval = envs.get_global_env("save.inference.epoch_interval", 1)
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None) print("save inference model is not supported now.")
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None) return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames] fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None) dirname = envs.get_global_env("save.inference.dirname", None, namespace)
assert dirname is not None assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe) fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
self.inference_models.append((epoch_id, dirname))
def save_persistables(): def save_persistables():
is_save_increment = envs.get_global_env("save.increment", False) save_interval = envs.get_global_env("save.increment.epoch_interval", -1, namespace)
if not is_save_increment:
return
save_interval = envs.get_global_env("save.increment.epoch_interval", 1)
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
dirname = envs.get_global_env("save.inference.dirname", None) dirname = envs.get_global_env("save.increment.dirname", None, namespace)
assert dirname is not None assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
fluid.io.save_persistables(self.exe, dirname) fluid.io.save_persistables(self.exe, dirname)
self.increment_models.append((epoch_id, dirname))
is_save = envs.get_global_env("save", False)
if not is_save:
return
save_persistables() save_persistables()
save_inference_model() save_inference_model()
...@@ -169,8 +169,14 @@ class SingleTrainerWithDataset(SingleTrainer): ...@@ -169,8 +169,14 @@ class SingleTrainerWithDataset(SingleTrainer):
fetch_list=self.metric_extras[0], fetch_list=self.metric_extras[0],
fetch_info=self.metric_extras[1], fetch_info=self.metric_extras[1],
print_period=self.metric_extras[2]) print_period=self.metric_extras[2])
self.save(i, "train")
context['status'] = 'infer_pass' context['status'] = 'infer_pass'
def infer(self, context): def infer(self, context):
context['status'] = 'terminal_pass' context['status'] = 'terminal_pass'
def terminal(self, context):
for model in self.increment_models:
print("epoch :{}, dir: {}".format(model[0], model[1]))
context['is_exit'] = True
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册