提交 fc6c613b 作者: Steffy-zxf 提交者: wuzewu

Fix VisualDL add_record bug (#53)

* Fix VisualDL add_record bug
上级 5a44a789
...@@ -389,16 +389,21 @@ class BasicTask(object): ...@@ -389,16 +389,21 @@ class BasicTask(object):
def _build_env_end_event(self): def _build_env_end_event(self):
pass pass
def _calculate_metrics(self, run_states):
raise NotImplementedError
def _eval_start_event(self): def _eval_start_event(self):
logger.info("Evaluation on {} dataset start".format(self.phase)) logger.info("Evaluation on {} dataset start".format(self.phase))
def _eval_end_event(self, run_state): def _eval_end_event(self, run_states):
run_speed = self._calculate_metrics(run_states)
logger.info("[%s dataset evaluation result] [step/sec: %.2f]" % logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
(self.phase, run_state.run_speed)) (self.phase, run_speed))
def _log_interval_event(self, run_state): def _log_interval_event(self, run_states):
logger.info("step %d: [step/sec: %.2f]" % (self.current_step, run_speed = self._calculate_metrics(run_states)
run_state.run_speed)) logger.info(
"step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
def _save_ckpt_interval_event(self): def _save_ckpt_interval_event(self):
self.save_checkpoint(self.current_epoch, self.current_step) self.save_checkpoint(self.current_epoch, self.current_step)
...@@ -446,6 +451,9 @@ class BasicTask(object): ...@@ -446,6 +451,9 @@ class BasicTask(object):
startup_program=self._base_startup_program, startup_program=self._base_startup_program,
load_best_model=load_best_model) load_best_model=load_best_model)
if self.is_predict_phase or self.is_test_phase:
self.env.current_step = 0
def finetune_and_eval(self): def finetune_and_eval(self):
self.finetune(do_eval=True) self.finetune(do_eval=True)
...@@ -516,8 +524,8 @@ class BasicTask(object): ...@@ -516,8 +524,8 @@ class BasicTask(object):
step_run_state.run_examples += num_batch_examples step_run_state.run_examples += num_batch_examples
step_run_state.update() step_run_state.update()
period_run_states += [step_run_state] period_run_states += [step_run_state]
if self.is_train_phase:
self.env.current_step += 1 self.env.current_step += 1
if self.is_train_phase:
if self.current_step % self.config.log_interval == 0: if self.current_step % self.config.log_interval == 0:
self._log_interval_event(period_run_states) self._log_interval_event(period_run_states)
global_run_states += period_run_states global_run_states += period_run_states
...@@ -552,8 +560,8 @@ class BasicTask(object): ...@@ -552,8 +560,8 @@ class BasicTask(object):
step_run_state.run_examples += num_batch_examples step_run_state.run_examples += num_batch_examples
step_run_state.update() step_run_state.update()
period_run_states += [step_run_state] period_run_states += [step_run_state]
if self.is_train_phase:
self.env.current_step += 1 self.env.current_step += 1
if self.is_train_phase:
if self.current_step % self.config.log_interval == 0: if self.current_step % self.config.log_interval == 0:
self._log_interval_event(period_run_states) self._log_interval_event(period_run_states)
global_run_states += period_run_states global_run_states += period_run_states
...@@ -629,6 +637,7 @@ class ClassifierTask(BasicTask): ...@@ -629,6 +637,7 @@ class ClassifierTask(BasicTask):
def _build_env_end_event(self): def _build_env_end_event(self):
with self.log_writer.mode(self.phase) as logw: with self.log_writer.mode(self.phase) as logw:
if not self.is_predict_phase:
self.env.loss_scalar = logw.scalar( self.env.loss_scalar = logw.scalar(
tag="Loss [{}]".format(self.phase)) tag="Loss [{}]".format(self.phase))
self.env.acc_scalar = logw.scalar( self.env.acc_scalar = logw.scalar(
...@@ -664,9 +673,9 @@ class ClassifierTask(BasicTask): ...@@ -664,9 +673,9 @@ class ClassifierTask(BasicTask):
logger.info( logger.info(
"[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]" "[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
% (self.phase, eval_loss, eval_acc, run_speed)) % (self.phase, eval_loss, eval_acc, run_speed))
if self.phase in ["dev", "val"] and eval_acc > self.best_accuracy:
self.env.loss_scalar.add_record(self.current_step, eval_loss) self.env.loss_scalar.add_record(self.current_step, eval_loss)
self.env.acc_scalar.add_record(self.current_step, eval_acc) self.env.acc_scalar.add_record(self.current_step, eval_acc)
if self.phase in ["dev", "val"] and eval_acc > self.best_accuracy:
self.best_accuracy = eval_acc self.best_accuracy = eval_acc
model_saved_dir = os.path.join(self.config.checkpoint_dir, model_saved_dir = os.path.join(self.config.checkpoint_dir,
"best_model") "best_model")
...@@ -796,9 +805,15 @@ class SequenceLabelTask(BasicTask): ...@@ -796,9 +805,15 @@ class SequenceLabelTask(BasicTask):
def _build_env_end_event(self): def _build_env_end_event(self):
with self.log_writer.mode(self.phase) as logw: with self.log_writer.mode(self.phase) as logw:
if self.is_train_phase:
self.env.loss_scalar = logw.scalar( self.env.loss_scalar = logw.scalar(
tag="Loss [{}]".format(self.phase)) tag="Loss [{}]".format(self.phase))
self.env.f1_scalar = logw.scalar(tag="F1 [{}]".format(self.phase))
if self.phase in ["dev", "val"]:
self.env.loss_scalar = logw.scalar(
tag="Loss [{}]".format(self.phase))
self.env.f1_scalar = logw.scalar(
tag="F1 [{}]".format(self.phase))
self.env.precision_scalar = logw.scalar( self.env.precision_scalar = logw.scalar(
tag="Precision [{}]".format(self.phase)) tag="Precision [{}]".format(self.phase))
self.env.recall_scalar = logw.scalar( self.env.recall_scalar = logw.scalar(
...@@ -838,6 +853,7 @@ class SequenceLabelTask(BasicTask): ...@@ -838,6 +853,7 @@ class SequenceLabelTask(BasicTask):
def _eval_end_event(self, run_states): def _eval_end_event(self, run_states):
precision, recall, f1, avg_loss, run_speed = self._calculate_metrics( precision, recall, f1, avg_loss, run_speed = self._calculate_metrics(
run_states) run_states)
self.env.loss_scalar.add_record(self.current_step, avg_loss)
self.env.f1_scalar.add_record(self.current_step, f1) self.env.f1_scalar.add_record(self.current_step, f1)
self.env.precision_scalar.add_record(self.current_step, precision) self.env.precision_scalar.add_record(self.current_step, precision)
self.env.recall_scalar.add_record(self.current_step, recall) self.env.recall_scalar.add_record(self.current_step, recall)
...@@ -951,8 +967,10 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -951,8 +967,10 @@ class MultiLabelClassifierTask(ClassifierTask):
def _build_env_end_event(self): def _build_env_end_event(self):
with self.log_writer.mode(self.phase) as logw: with self.log_writer.mode(self.phase) as logw:
if not self.is_predict_phase:
self.env.loss_scalar = logw.scalar( self.env.loss_scalar = logw.scalar(
tag="Loss [{}]".format(self.phase)) tag="Loss [{}]".format(self.phase))
if self.is_train_phase:
self.env.auc_scalar_list = [] self.env.auc_scalar_list = []
for i in range(self.num_classes): for i in range(self.num_classes):
self.env.auc_scalar_list.append( self.env.auc_scalar_list.append(
...@@ -978,33 +996,27 @@ class MultiLabelClassifierTask(ClassifierTask): ...@@ -978,33 +996,27 @@ class MultiLabelClassifierTask(ClassifierTask):
def _log_interval_event(self, run_states): def _log_interval_event(self, run_states):
avg_loss, auc_list, run_speed = self._calculate_metrics(run_states) avg_loss, auc_list, run_speed = self._calculate_metrics(run_states)
if self.is_train_phase:
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
self.env.loss_scalar.add_record(self.current_step, avg_loss) self.env.loss_scalar.add_record(self.current_step, avg_loss)
avg_auc = np.mean(auc_list) avg_auc = np.mean(auc_list)
self.env.avg_auc_scalar.add_record(self.current_step, avg_auc) self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
logger.info("step %d: loss=%.5f avg_auc=%.5f [step/sec: %.2f]" % logger.info("step %d: loss=%.5f avg_auc=%.5f [step/sec: %.2f]" %
(self.current_step, avg_loss, avg_auc, run_speed)) (self.current_step, avg_loss, avg_auc, run_speed))
for index, auc in enumerate(auc_list): for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index][0])
logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0])) logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
def _eval_end_event(self, run_states): def _eval_end_event(self, run_states):
eval_loss, auc_list, run_speed = self._calculate_metrics(run_states) eval_loss, auc_list, run_speed = self._calculate_metrics(run_states)
if self.is_train_phase:
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
avg_auc = np.mean(auc_list) avg_auc = np.mean(auc_list)
logger.info( logger.info(
"[%s dataset evaluation result] loss=%.5f avg_auc=%.5f [step/sec: %.2f]" "[%s dataset evaluation result] loss=%.5f avg_auc=%.5f [step/sec: %.2f]"
% (self.phase, eval_loss, avg_auc, run_speed)) % (self.phase, eval_loss, avg_auc, run_speed))
for index, auc in enumerate(auc_list): for index, auc in enumerate(auc_list):
logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0])) logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
if self.phase in ["dev", "val"] and avg_auc > self.best_avg_auc:
self.env.loss_scalar.add_record(self.current_step, eval_loss) self.env.loss_scalar.add_record(self.current_step, eval_loss)
for index, auc_scalar in enumerate(self.env.auc_scalar_list):
auc_scalar.add_record(self.current_step, auc_list[index])
self.env.avg_auc_scalar.add_record(self.current_step, avg_auc) self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
if self.phase in ["dev", "val"] and avg_auc > self.best_avg_auc:
self.best_avg_auc = avg_auc self.best_avg_auc = avg_auc
model_saved_dir = os.path.join(self.config.checkpoint_dir, model_saved_dir = os.path.join(self.config.checkpoint_dir,
"best_model") "best_model")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册