提交 6baed993 编写于 作者: J jiangjiajun

modify dataset num_samples

上级 aebd5798
...@@ -254,3 +254,11 @@ class Dataset: ...@@ -254,3 +254,11 @@ class Dataset:
buffer_size=self.buffer_size, buffer_size=self.buffer_size,
batch_size=batch_size, batch_size=batch_size,
drop_last=drop_last) drop_last=drop_last)
def set_num_samples(self, num_samples):
if num_samples > len(self.file_list):
logging.warning(
"You want set num_samples to {}, but your dataset only has {} samples, so we will keep your dataset num_samples as {}"
.format(num_samples, len(self.file_list), len(self.file_list)))
num_samples = len(self.file_list)
self.num_samples = num_samples
...@@ -417,7 +417,7 @@ class BaseAPI: ...@@ -417,7 +417,7 @@ class BaseAPI:
earlystop = EarlyStop(early_stop_patience, thresh) earlystop = EarlyStop(early_stop_patience, thresh)
best_accuracy_key = "" best_accuracy_key = ""
best_accuracy = -1.0 best_accuracy = -1.0
best_model_epoch = 1 best_model_epoch = -1
for i in range(num_epochs): for i in range(num_epochs):
records = list() records = list()
step_start_time = time.time() step_start_time = time.time()
...@@ -490,7 +490,7 @@ class BaseAPI: ...@@ -490,7 +490,7 @@ class BaseAPI:
current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1)) current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
if not osp.isdir(current_save_dir): if not osp.isdir(current_save_dir):
os.makedirs(current_save_dir) os.makedirs(current_save_dir)
if eval_dataset is not None: if eval_dataset is not None and eval_dataset.num_samples > 0:
self.eval_metrics, self.eval_details = self.evaluate( self.eval_metrics, self.eval_details = self.evaluate(
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
batch_size=eval_batch_size, batch_size=eval_batch_size,
...@@ -522,10 +522,11 @@ class BaseAPI: ...@@ -522,10 +522,11 @@ class BaseAPI:
self.save_model(save_dir=current_save_dir) self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time() eval_epoch_start_time = time.time()
logging.info( if best_model_epoch > 0:
'Current evaluated best model in eval_dataset is epoch_{}, {}={}' logging.info(
.format(best_model_epoch, best_accuracy_key, 'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
best_accuracy)) .format(best_model_epoch, best_accuracy_key,
best_accuracy))
if eval_dataset is not None and early_stop: if eval_dataset is not None and early_stop:
if earlystop(current_accuracy): if earlystop(current_accuracy):
break break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册