Commit fc6c613b authored by S Steffy-zxf and committed by wuzewu

Fix visualdl add record bug (#53)

* Fix visualdl add record bug
Parent 5a44a789
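For context: the hunks below tighten how the task event callbacks talk to VisualDL's scalar logging. The following is a minimal standalone sketch of the 1.x-era LogWriter pattern the code relies on; the log directory and the `sync_cycle` value are illustrative assumptions, not taken from this commit, while `mode()`, `scalar(tag=...)` and `add_record(step, value)` mirror the calls visible in the diff.

```python
# Minimal sketch of the VisualDL 1.x scalar-logging pattern used by the tasks.
# Assumption: LogWriter takes a log directory and a sync_cycle argument;
# "./vdl_log" and sync_cycle=10 are illustrative values only.
from visualdl import LogWriter

log_writer = LogWriter("./vdl_log", sync_cycle=10)

# A scalar is created under a mode ("train", "dev", ...) before any record is
# added; add_record(step, value) expects a plain number, not an array.
with log_writer.mode("train") as logw:
    loss_scalar = logw.scalar(tag="Loss [train]")

for step in range(5):
    loss_value = 1.0 / (step + 1)  # stand-in for a real training loss
    loss_scalar.add_record(step, loss_value)
```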
@@ -389,16 +389,21 @@ class BasicTask(object):
     def _build_env_end_event(self):
         pass
 
+    def _calculate_metrics(self, run_states):
+        raise NotImplementedError
+
     def _eval_start_event(self):
         logger.info("Evaluation on {} dataset start".format(self.phase))
 
-    def _eval_end_event(self, run_state):
+    def _eval_end_event(self, run_states):
+        run_speed = self._calculate_metrics(run_states)
         logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
-                    (self.phase, run_state.run_speed))
+                    (self.phase, run_speed))
 
-    def _log_interval_event(self, run_state):
-        logger.info("step %d: [step/sec: %.2f]" % (self.current_step,
-                                                   run_state.run_speed))
+    def _log_interval_event(self, run_states):
+        run_speed = self._calculate_metrics(run_states)
+        logger.info(
+            "step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
 
     def _save_ckpt_interval_event(self):
         self.save_checkpoint(self.current_epoch, self.current_step)
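The hunk above makes `BasicTask` defer metric computation to its subclasses through a `_calculate_metrics` hook and reuses that hook in the default event handlers. A simplified sketch of the pattern follows; `BaseTask` and `ToyTask` are toy names, not the real PaddleHub classes.

```python
# Simplified sketch of the hook introduced above: the base class leaves
# _calculate_metrics to subclasses and only consumes its result in the
# default event handlers.
class BaseTask(object):
    def _calculate_metrics(self, run_states):
        raise NotImplementedError

    def _eval_end_event(self, run_states):
        run_speed = self._calculate_metrics(run_states)
        print("[eval] step/sec: %.2f" % run_speed)


class ToyTask(BaseTask):
    def _calculate_metrics(self, run_states):
        # A subclass decides what the collected run states mean; here the
        # "metric" is simply how many state objects were gathered.
        return float(len(run_states))


ToyTask()._eval_end_event(run_states=[object(), object()])
```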
@@ -446,6 +451,9 @@ class BasicTask(object):
             startup_program=self._base_startup_program,
             load_best_model=load_best_model)
 
+        if self.is_predict_phase or self.is_test_phase:
+            self.env.current_step = 0
+
     def finetune_and_eval(self):
         self.finetune(do_eval=True)
 
@@ -516,8 +524,8 @@ class BasicTask(object):
                 step_run_state.run_examples += num_batch_examples
                 step_run_state.update()
                 period_run_states += [step_run_state]
-                self.env.current_step += 1
                 if self.is_train_phase:
+                    self.env.current_step += 1
                     if self.current_step % self.config.log_interval == 0:
                         self._log_interval_event(period_run_states)
                         global_run_states += period_run_states
@@ -552,8 +560,8 @@ class BasicTask(object):
                 step_run_state.run_examples += num_batch_examples
                 step_run_state.update()
                 period_run_states += [step_run_state]
-                self.env.current_step += 1
                 if self.is_train_phase:
+                    self.env.current_step += 1
                     if self.current_step % self.config.log_interval == 0:
                         self._log_interval_event(period_run_states)
                         global_run_states += period_run_states
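The two hunks above apply the same change to both run loops: the step counter now advances only in the training phase, so evaluation and prediction passes no longer shift the step axis that the interval events and scalar records use. A distilled, standalone sketch of that pattern; `RunLoop`, `log_interval` and `run_batch` are simplified stand-ins, not PaddleHub APIs.

```python
# Distilled version of the run-loop change: the step counter advances and the
# interval event fires only in the training phase.
class RunLoop(object):
    def __init__(self, is_train_phase, log_interval=2):
        self.is_train_phase = is_train_phase
        self.log_interval = log_interval
        self.current_step = 0
        self.period_states = []

    def _log_interval_event(self, states):
        print("step %d: flushed %d run states" % (self.current_step, len(states)))

    def run_batch(self, state):
        self.period_states.append(state)
        if self.is_train_phase:
            self.current_step += 1
            if self.current_step % self.log_interval == 0:
                self._log_interval_event(self.period_states)
                self.period_states = []


# An eval- or predict-phase loop never advances current_step, so records added
# at current_step stay aligned with training steps.
train_loop = RunLoop(is_train_phase=True)
for batch in range(4):
    train_loop.run_batch({"batch": batch})
```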
@@ -629,10 +637,11 @@ class ClassifierTask(BasicTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.acc_scalar = logw.scalar(
-                tag="Accuracy [{}]".format(self.phase))
+            if not self.is_predict_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+                self.env.acc_scalar = logw.scalar(
+                    tag="Accuracy [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         loss_sum = acc_sum = run_examples = 0
@@ -664,9 +673,9 @@ class ClassifierTask(BasicTask):
         logger.info(
             "[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
             % (self.phase, eval_loss, eval_acc, run_speed))
-        self.env.loss_scalar.add_record(self.current_step, eval_loss)
-        self.env.acc_scalar.add_record(self.current_step, eval_acc)
         if self.phase in ["dev", "val"] and eval_acc > self.best_accuracy:
+            self.env.loss_scalar.add_record(self.current_step, eval_loss)
+            self.env.acc_scalar.add_record(self.current_step, eval_acc)
             self.best_accuracy = eval_acc
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
@@ -796,13 +805,19 @@ class SequenceLabelTask(BasicTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.f1_scalar = logw.scalar(tag="F1 [{}]".format(self.phase))
-            self.env.precision_scalar = logw.scalar(
-                tag="Precision [{}]".format(self.phase))
-            self.env.recall_scalar = logw.scalar(
-                tag="Recall [{}]".format(self.phase))
+            if self.is_train_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+
+            if self.phase in ["dev", "val"]:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+                self.env.f1_scalar = logw.scalar(
+                    tag="F1 [{}]".format(self.phase))
+                self.env.precision_scalar = logw.scalar(
+                    tag="Precision [{}]".format(self.phase))
+                self.env.recall_scalar = logw.scalar(
+                    tag="Recall [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         total_infer = total_label = total_correct = loss_sum = 0
@@ -838,6 +853,7 @@ class SequenceLabelTask(BasicTask):
     def _eval_end_event(self, run_states):
         precision, recall, f1, avg_loss, run_speed = self._calculate_metrics(
             run_states)
+        self.env.loss_scalar.add_record(self.current_step, avg_loss)
         self.env.f1_scalar.add_record(self.current_step, f1)
         self.env.precision_scalar.add_record(self.current_step, precision)
         self.env.recall_scalar.add_record(self.current_step, recall)
@@ -951,14 +967,16 @@ class MultiLabelClassifierTask(ClassifierTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.auc_scalar_list = []
-            for i in range(self.num_classes):
-                self.env.auc_scalar_list.append(
-                    logw.scalar(tag="AUC_{} [{}]".format(i, "train")))
-            self.env.avg_auc_scalar = logw.scalar(
-                tag="Average auc [{}]".format(self.phase))
+            if not self.is_predict_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+            if self.is_train_phase:
+                self.env.auc_scalar_list = []
+                for i in range(self.num_classes):
+                    self.env.auc_scalar_list.append(
+                        logw.scalar(tag="AUC_{} [{}]".format(i, "train")))
+                self.env.avg_auc_scalar = logw.scalar(
+                    tag="Average auc [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         loss_sum = acc_sum = run_examples = 0
@@ -978,33 +996,27 @@ class MultiLabelClassifierTask(ClassifierTask):
 
     def _log_interval_event(self, run_states):
         avg_loss, auc_list, run_speed = self._calculate_metrics(run_states)
-        if self.is_train_phase:
-            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
-                auc_scalar.add_record(self.current_step, auc_list[index])
         self.env.loss_scalar.add_record(self.current_step, avg_loss)
         avg_auc = np.mean(auc_list)
         self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
         logger.info("step %d: loss=%.5f avg_auc=%.5f [step/sec: %.2f]" %
                     (self.current_step, avg_loss, avg_auc, run_speed))
-        for index, auc in enumerate(auc_list):
+        for index, auc_scalar in enumerate(self.env.auc_scalar_list):
+            auc_scalar.add_record(self.current_step, auc_list[index][0])
             logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
 
     def _eval_end_event(self, run_states):
         eval_loss, auc_list, run_speed = self._calculate_metrics(run_states)
-        if self.is_train_phase:
-            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
-                auc_scalar.add_record(self.current_step, auc_list[index])
         avg_auc = np.mean(auc_list)
         logger.info(
             "[%s dataset evaluation result] loss=%.5f avg_auc=%.5f [step/sec: %.2f]"
             % (self.phase, eval_loss, avg_auc, run_speed))
         for index, auc in enumerate(auc_list):
             logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
-        self.env.loss_scalar.add_record(self.current_step, eval_loss)
-        self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
         if self.phase in ["dev", "val"] and avg_auc > self.best_avg_auc:
+            self.env.loss_scalar.add_record(self.current_step, eval_loss)
+            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
+                auc_scalar.add_record(self.current_step, auc_list[index])
+            self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
             self.best_avg_auc = avg_auc
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
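Most of the hunks above follow one rule: a scalar is created only for the phases that will record into it, and `add_record` is called only under the same phase predicate, so a test or predict pass never touches a scalar that was never built. A simplified illustration of that failure mode and guard follows, using hypothetical names rather than PaddleHub code.

```python
# Hypothetical illustration (not PaddleHub code) of the failure the per-phase
# guards avoid: calling add_record on a scalar that was never created for the
# current phase would raise AttributeError.
class PhaseEnv(object):
    """Per-phase container, playing the role of self.env in the diff."""
    pass


def build_env(env, phase):
    # Create the scalar only for phases that will actually record into it.
    if phase in ("train", "dev", "val"):
        env.loss_records = []  # stand-in for a VisualDL scalar


def eval_end(env, phase, step, loss):
    # Record under the same predicate used at creation time, so "test" and
    # "predict" runs never touch a missing attribute.
    if phase in ("dev", "val"):
        env.loss_records.append((step, loss))  # stand-in for add_record


predict_env = PhaseEnv()
build_env(predict_env, phase="predict")
eval_end(predict_env, phase="predict", step=0, loss=0.5)  # safely skipped
```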