From 6baed993fbf1183ca8658705160d9270293c412a Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Tue, 12 May 2020 11:35:41 +0800
Subject: [PATCH] modify dataset num_samples

---
NOTE(review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Hunk line counts match the @@ headers, but
the leading whitespace of the Python lines is reconstructed from the code's
nesting -- confirm against the target tree, or apply with
`git apply --ignore-whitespace`.

 paddlex/cv/datasets/dataset.py |  8 ++++++++
 paddlex/cv/models/base.py      | 13 +++++++------
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/paddlex/cv/datasets/dataset.py b/paddlex/cv/datasets/dataset.py
index 09c0427..c3bec89 100644
--- a/paddlex/cv/datasets/dataset.py
+++ b/paddlex/cv/datasets/dataset.py
@@ -254,3 +254,11 @@ class Dataset:
             buffer_size=self.buffer_size,
             batch_size=batch_size,
             drop_last=drop_last)
+
+    def set_num_samples(self, num_samples):
+        if num_samples > len(self.file_list):
+            logging.warning(
+                "You want set num_samples to {}, but your dataset only has {} samples, so we will keep your dataset num_samples as {}"
+                .format(num_samples, len(self.file_list), len(self.file_list)))
+            num_samples = len(self.file_list)
+        self.num_samples = num_samples
diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py
index 3107d5f..112e3ef 100644
--- a/paddlex/cv/models/base.py
+++ b/paddlex/cv/models/base.py
@@ -417,7 +417,7 @@ class BaseAPI:
             earlystop = EarlyStop(early_stop_patience, thresh)
         best_accuracy_key = ""
         best_accuracy = -1.0
-        best_model_epoch = 1
+        best_model_epoch = -1
         for i in range(num_epochs):
             records = list()
             step_start_time = time.time()
@@ -490,7 +490,7 @@ class BaseAPI:
                 current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                 if not osp.isdir(current_save_dir):
                     os.makedirs(current_save_dir)
-                if eval_dataset is not None:
+                if eval_dataset is not None and eval_dataset.num_samples > 0:
                     self.eval_metrics, self.eval_details = self.evaluate(
                         eval_dataset=eval_dataset,
                         batch_size=eval_batch_size,
@@ -522,10 +522,11 @@ class BaseAPI:
                 self.save_model(save_dir=current_save_dir)
                 time_eval_one_epoch = time.time() - eval_epoch_start_time
                 eval_epoch_start_time = time.time()
-                logging.info(
-                    'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
-                    .format(best_model_epoch, best_accuracy_key,
-                            best_accuracy))
+                if best_model_epoch > 0:
+                    logging.info(
+                        'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
+                        .format(best_model_epoch, best_accuracy_key,
+                                best_accuracy))
                 if eval_dataset is not None and early_stop:
                     if earlystop(current_accuracy):
                         break
-- 
GitLab