提交 cd815b7e 编写于 作者: B BinLong

Merge branch 'develop' of github.com:PaddlePaddle/PaddleHub into develop

......@@ -417,8 +417,17 @@ class BasicTask(object):
def _build_env_end_event(self):
pass
def _calculate_metrics(self, run_states):
raise NotImplementedError
def _finetune_start_event(self):
logger.info("PaddleHub finetune start")
def _finetune_end_event(self, run_states):
logger.info("PaddleHub finetune finished.")
def _predict_start_event(self):
logger.info("PaddleHub predict start")
def _predict_end_event(self, run_states):
logger.info("PaddleHub predict finished.")
def _eval_start_event(self):
logger.info("Evaluation on {} dataset start".format(self.phase))
......@@ -434,7 +443,7 @@ class BasicTask(object):
"step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
def _save_ckpt_interval_event(self):
self.save_checkpoint(self.current_epoch, self.current_step)
self.save_checkpoint()
def _eval_interval_event(self):
self.eval(phase="dev")
......@@ -443,12 +452,6 @@ class BasicTask(object):
if self.is_predict_phase:
yield run_state.run_results
def _finetune_start_event(self):
logger.info("PaddleHub finetune start")
def _finetune_end_event(self, run_state):
logger.info("PaddleHub finetune finished.")
def _build_net(self):
raise NotImplementedError
......@@ -461,9 +464,12 @@ class BasicTask(object):
def _add_metrics(self):
raise NotImplementedError
def _calculate_metrics(self, run_states):
raise NotImplementedError
# NOTE: current saved checkpoint machanism is not completed,
# it can't restore dataset training status
def save_checkpoint(self, epoch, step):
def save_checkpoint(self):
save_checkpoint(
checkpoint_dir=self.config.checkpoint_dir,
current_epoch=self.current_epoch,
......@@ -506,7 +512,7 @@ class BasicTask(object):
self.env.current_epoch += 1
# Save checkpoint after finetune
self.save_checkpoint(self.current_epoch + 1, self.current_step)
self.save_checkpoint()
# Final evaluation
self.eval(phase="dev")
......@@ -529,7 +535,9 @@ class BasicTask(object):
"best_model")
self.load_parameters(best_model_path)
self._predict_data = data
self._predict_start_event()
run_states = self._run()
self._predict_end_event(run_states)
self._predict_data = None
return [run_state.run_results for run_state in run_states]
......@@ -946,10 +954,10 @@ class SequenceLabelTask(BasicTask):
class MultiLabelClassifierTask(ClassifierTask):
def __init__(self,
data_reader,
feature,
num_classes,
feed_list,
data_reader,
startup_program=None,
config=None,
hidden_units=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册