diff --git a/paddlehub/finetune/task.py b/paddlehub/finetune/task.py
index 06d284ab5800330afb0da75b67509bae396f399e..083b2a88cf640ec1780eccc6cbb546337594aecf 100644
--- a/paddlehub/finetune/task.py
+++ b/paddlehub/finetune/task.py
@@ -389,16 +389,21 @@ class BasicTask(object):
     def _build_env_end_event(self):
         pass
 
+    def _calculate_metrics(self, run_states):
+        raise NotImplementedError
+
     def _eval_start_event(self):
         logger.info("Evaluation on {} dataset start".format(self.phase))
 
-    def _eval_end_event(self, run_state):
+    def _eval_end_event(self, run_states):
+        run_speed = self._calculate_metrics(run_states)
         logger.info("[%s dataset evaluation result] [step/sec: %.2f]" %
-                    (self.phase, run_state.run_speed))
+                    (self.phase, run_speed))
 
-    def _log_interval_event(self, run_state):
-        logger.info("step %d: [step/sec: %.2f]" % (self.current_step,
-                                                   run_state.run_speed))
+    def _log_interval_event(self, run_states):
+        run_speed = self._calculate_metrics(run_states)
+        logger.info(
+            "step %d: [step/sec: %.2f]" % (self.current_step, run_speed))
 
     def _save_ckpt_interval_event(self):
         self.save_checkpoint(self.current_epoch, self.current_step)
@@ -446,6 +451,9 @@ class BasicTask(object):
             startup_program=self._base_startup_program,
             load_best_model=load_best_model)
 
+        if self.is_predict_phase or self.is_test_phase:
+            self.env.current_step = 0
+
     def finetune_and_eval(self):
         self.finetune(do_eval=True)
 
@@ -516,8 +524,8 @@ class BasicTask(object):
                 step_run_state.run_examples += num_batch_examples
                 step_run_state.update()
                 period_run_states += [step_run_state]
+                self.env.current_step += 1
                 if self.is_train_phase:
-                    self.env.current_step += 1
                     if self.current_step % self.config.log_interval == 0:
                         self._log_interval_event(period_run_states)
                         global_run_states += period_run_states
@@ -552,8 +560,8 @@ class BasicTask(object):
                 step_run_state.run_examples += num_batch_examples
                 step_run_state.update()
                 period_run_states += [step_run_state]
+                self.env.current_step += 1
                 if self.is_train_phase:
-                    self.env.current_step += 1
                     if self.current_step % self.config.log_interval == 0:
                         self._log_interval_event(period_run_states)
                         global_run_states += period_run_states
@@ -629,10 +637,11 @@ class ClassifierTask(BasicTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.acc_scalar = logw.scalar(
-                tag="Accuracy [{}]".format(self.phase))
+            if not self.is_predict_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+                self.env.acc_scalar = logw.scalar(
+                    tag="Accuracy [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         loss_sum = acc_sum = run_examples = 0
@@ -664,9 +673,9 @@
         logger.info(
             "[%s dataset evaluation result] loss=%.5f acc=%.5f [step/sec: %.2f]"
             % (self.phase, eval_loss, eval_acc, run_speed))
+        self.env.loss_scalar.add_record(self.current_step, eval_loss)
+        self.env.acc_scalar.add_record(self.current_step, eval_acc)
         if self.phase in ["dev", "val"] and eval_acc > self.best_accuracy:
-            self.env.loss_scalar.add_record(self.current_step, eval_loss)
-            self.env.acc_scalar.add_record(self.current_step, eval_acc)
             self.best_accuracy = eval_acc
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
@@ -796,13 +805,19 @@ class SequenceLabelTask(BasicTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.f1_scalar = logw.scalar(tag="F1 [{}]".format(self.phase))
-            self.env.precision_scalar = logw.scalar(
-                tag="Precision [{}]".format(self.phase))
-            self.env.recall_scalar = logw.scalar(
-                tag="Recall [{}]".format(self.phase))
+            if self.is_train_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+
+            if self.phase in ["dev", "val"]:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+                self.env.f1_scalar = logw.scalar(
+                    tag="F1 [{}]".format(self.phase))
+                self.env.precision_scalar = logw.scalar(
+                    tag="Precision [{}]".format(self.phase))
+                self.env.recall_scalar = logw.scalar(
+                    tag="Recall [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         total_infer = total_label = total_correct = loss_sum = 0
@@ -838,6 +853,7 @@
     def _eval_end_event(self, run_states):
         precision, recall, f1, avg_loss, run_speed = self._calculate_metrics(
            run_states)
+        self.env.loss_scalar.add_record(self.current_step, avg_loss)
         self.env.f1_scalar.add_record(self.current_step, f1)
         self.env.precision_scalar.add_record(self.current_step, precision)
         self.env.recall_scalar.add_record(self.current_step, recall)
@@ -951,14 +967,16 @@ class MultiLabelClassifierTask(ClassifierTask):
 
     def _build_env_end_event(self):
         with self.log_writer.mode(self.phase) as logw:
-            self.env.loss_scalar = logw.scalar(
-                tag="Loss [{}]".format(self.phase))
-            self.env.auc_scalar_list = []
-            for i in range(self.num_classes):
-                self.env.auc_scalar_list.append(
-                    logw.scalar(tag="AUC_{} [{}]".format(i, "train")))
-            self.env.avg_auc_scalar = logw.scalar(
-                tag="Average auc [{}]".format(self.phase))
+            if not self.is_predict_phase:
+                self.env.loss_scalar = logw.scalar(
+                    tag="Loss [{}]".format(self.phase))
+            if self.is_train_phase:
+                self.env.auc_scalar_list = []
+                for i in range(self.num_classes):
+                    self.env.auc_scalar_list.append(
+                        logw.scalar(tag="AUC_{} [{}]".format(i, "train")))
+                self.env.avg_auc_scalar = logw.scalar(
+                    tag="Average auc [{}]".format(self.phase))
 
     def _calculate_metrics(self, run_states):
         loss_sum = acc_sum = run_examples = 0
@@ -978,33 +996,27 @@
 
     def _log_interval_event(self, run_states):
         avg_loss, auc_list, run_speed = self._calculate_metrics(run_states)
-        if self.is_train_phase:
-            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
-                auc_scalar.add_record(self.current_step, auc_list[index])
+        self.env.loss_scalar.add_record(self.current_step, avg_loss)
         avg_auc = np.mean(auc_list)
         self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
         logger.info("step %d: loss=%.5f avg_auc=%.5f [step/sec: %.2f]" %
                     (self.current_step, avg_loss, avg_auc, run_speed))
-        for index, auc in enumerate(auc_list):
+        for index, auc_scalar in enumerate(self.env.auc_scalar_list):
+            auc_scalar.add_record(self.current_step, auc_list[index][0])
             logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
 
     def _eval_end_event(self, run_states):
         eval_loss, auc_list, run_speed = self._calculate_metrics(run_states)
-        if self.is_train_phase:
-            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
-                auc_scalar.add_record(self.current_step, auc_list[index])
         avg_auc = np.mean(auc_list)
         logger.info(
             "[%s dataset evaluation result] loss=%.5f avg_auc=%.5f [step/sec: %.2f]"
             % (self.phase, eval_loss, avg_auc, run_speed))
         for index, auc in enumerate(auc_list):
             logger.info("label_%d_auc = %.5f" % (index, auc_list[index][0]))
+        self.env.loss_scalar.add_record(self.current_step, eval_loss)
+        self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
         if self.phase in ["dev", "val"] and avg_auc > self.best_avg_auc:
-            self.env.loss_scalar.add_record(self.current_step, eval_loss)
-            for index, auc_scalar in enumerate(self.env.auc_scalar_list):
-                auc_scalar.add_record(self.current_step, auc_list[index])
-            self.env.avg_auc_scalar.add_record(self.current_step, avg_auc)
             self.best_avg_auc = avg_auc
             model_saved_dir = os.path.join(self.config.checkpoint_dir,
                                            "best_model")
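
In short, the patch turns the shared event handlers into a template method: `BasicTask` declares `_calculate_metrics(run_states)` as an abstract hook, `_log_interval_event` and `_eval_end_event` call it over the accumulated list of run states, and each subclass supplies its own aggregation. A minimal standalone sketch of that pattern follows; `FakeRunState` and `SpeedOnlyTask` are hypothetical illustration names, not PaddleHub classes.

```python
class FakeRunState(object):
    """Hypothetical stand-in for PaddleHub's per-step RunState."""

    def __init__(self, run_examples, run_time_used):
        self.run_examples = run_examples    # samples processed in this step
        self.run_time_used = run_time_used  # seconds spent on this step


class BaseTask(object):
    def _calculate_metrics(self, run_states):
        # Abstract hook, mirroring the patched BasicTask._calculate_metrics.
        raise NotImplementedError

    def _log_interval_event(self, run_states):
        # The shared handler no longer reads run_state.run_speed directly;
        # it delegates aggregation over the whole period to the hook.
        run_speed = self._calculate_metrics(run_states)
        print("step/sec: %.2f" % run_speed)


class SpeedOnlyTask(BaseTask):
    def _calculate_metrics(self, run_states):
        # Aggregate across all accumulated states, not a single one.
        run_time = sum(state.run_time_used for state in run_states)
        return len(run_states) / run_time if run_time else 0.0


SpeedOnlyTask()._log_interval_event(
    [FakeRunState(32, 0.50), FakeRunState(32, 0.45)])
```

The same shape explains the other changes: because the handlers now receive a list of states rather than one state, `current_step` is advanced in every phase (and reset for test/predict runs), and each subclass only creates the VisualDL scalars its phase will actually record to.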