提交 6baed993 编写于 作者: J jiangjiajun

modify dataset num_samples

上级 aebd5798
......@@ -254,3 +254,11 @@ class Dataset:
buffer_size=self.buffer_size,
batch_size=batch_size,
drop_last=drop_last)
def set_num_samples(self, num_samples):
    """Set how many samples of this dataset will be used.

    If the requested count exceeds the actual dataset size, it is
    clamped to ``len(self.file_list)`` and a warning is logged instead
    of raising, so training can proceed best-effort.

    Args:
        num_samples (int): Desired number of samples. Values greater
            than the dataset size are reduced to the dataset size.
    """
    total = len(self.file_list)
    if num_samples > total:
        # Fixed grammar of the original message ("You want set ...").
        logging.warning(
            "You want to set num_samples to {}, but your dataset only has "
            "{} samples, so we will keep your dataset num_samples as {}"
            .format(num_samples, total, total))
        num_samples = total
    self.num_samples = num_samples
......@@ -417,7 +417,7 @@ class BaseAPI:
earlystop = EarlyStop(early_stop_patience, thresh)
best_accuracy_key = ""
best_accuracy = -1.0
best_model_epoch = 1
best_model_epoch = -1
for i in range(num_epochs):
records = list()
step_start_time = time.time()
......@@ -490,7 +490,7 @@ class BaseAPI:
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if not osp.isdir(current_save_dir):
os.makedirs(current_save_dir)
if eval_dataset is not None:
if eval_dataset is not None and eval_dataset.num_samples > 0:
self.eval_metrics, self.eval_details = self.evaluate(
eval_dataset=eval_dataset,
batch_size=eval_batch_size,
......@@ -522,10 +522,11 @@ class BaseAPI:
self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
.format(best_model_epoch, best_accuracy_key,
best_accuracy))
if best_model_epoch > 0:
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
.format(best_model_epoch, best_accuracy_key,
best_accuracy))
if eval_dataset is not None and early_stop:
if earlystop(current_accuracy):
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册