提交 ab06f63e 编写于 作者: S Steffy-zxf 提交者: wuzewu

Add inference model (#327)

* add-inference-model
上级 0197f9a2
...@@ -74,10 +74,6 @@ def save_checkpoint(checkpoint_dir, ...@@ -74,10 +74,6 @@ def save_checkpoint(checkpoint_dir,
ckpt = checkpoint_pb2.CheckPoint() ckpt = checkpoint_pb2.CheckPoint()
model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step) model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
logger.info("Saving model checkpoint to {}".format(model_saved_dir))
fluid.io.save_persistables(
exe, dirname=model_saved_dir, main_program=main_program)
ckpt.current_epoch = current_epoch ckpt.current_epoch = current_epoch
ckpt.global_step = global_step ckpt.global_step = global_step
ckpt.latest_model_dir = model_saved_dir ckpt.latest_model_dir = model_saved_dir
......
...@@ -662,11 +662,7 @@ class BaseTask(object): ...@@ -662,11 +662,7 @@ class BaseTask(object):
"best_model") "best_model")
logger.eval("best model saved to %s [best %s=%.5f]" % logger.eval("best model saved to %s [best %s=%.5f]" %
(model_saved_dir, main_metric, main_value)) (model_saved_dir, main_metric, main_value))
self.save_inference_model(dirname=model_saved_dir)
save_result = fluid.io.save_persistables(
executor=self.exe,
dirname=model_saved_dir,
main_program=self.main_program)
def _default_log_interval_event(self, run_states): def _default_log_interval_event(self, run_states):
scores, avg_loss, run_speed = self._calculate_metrics(run_states) scores, avg_loss, run_speed = self._calculate_metrics(run_states)
...@@ -717,6 +713,10 @@ class BaseTask(object): ...@@ -717,6 +713,10 @@ class BaseTask(object):
# NOTE: current saved checkpoint machanism is not completed, # NOTE: current saved checkpoint machanism is not completed,
# it can't restore dataset training status # it can't restore dataset training status
def save_checkpoint(self): def save_checkpoint(self):
model_saved_dir = os.path.join(self.config.checkpoint_dir,
"step_%d" % self.current_step)
logger.info("Saving model checkpoint to {}".format(model_saved_dir))
self.save_inference_model(dirname=model_saved_dir)
save_checkpoint( save_checkpoint(
checkpoint_dir=self.config.checkpoint_dir, checkpoint_dir=self.config.checkpoint_dir,
current_epoch=self.current_epoch, current_epoch=self.current_epoch,
......
...@@ -317,7 +317,7 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -317,7 +317,7 @@ class MultiLabelClassifierTask(ClassifierTask):
def fetch_list(self): def fetch_list(self):
if self.is_train_phase or self.is_test_phase: if self.is_train_phase or self.is_test_phase:
return [metric.name for metric in self.metrics] + [self.loss.name] return [metric.name for metric in self.metrics] + [self.loss.name]
return self.outputs return [output.name for output in self.outputs]
def _postprocessing(self, run_states): def _postprocessing(self, run_states):
results = [] results = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册