未验证 提交 e165897c 编写于 作者: L littletomatodonkey 提交者: GitHub

fix drop last for training process (#713)

上级 9c0f0496
...@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset): ...@@ -224,9 +224,11 @@ class MultiLabelDataset(Dataset):
labels = label_str.split(',') labels = label_str.split(',')
labels = [int(i) for i in labels] labels = [int(i) for i in labels]
return (transform(img, self.ops), np.array(labels).astype("float32")) return (transform(img, self.ops),
np.array(labels).astype("float32"))
except Exception as e: except Exception as e:
logger.error("data read failed: {}, exception info: {}".format(line, e)) logger.error("data read failed: {}, exception info: {}".format(
line, e))
return self.__getitem__(random.randint(0, len(self))) return self.__getitem__(random.randint(0, len(self)))
def __len__(self): def __len__(self):
...@@ -291,7 +293,7 @@ class Reader: ...@@ -291,7 +293,7 @@ class Reader:
dataset, dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=self.shuffle and is_train, shuffle=self.shuffle and is_train,
drop_last=is_train) drop_last=False)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
......
...@@ -72,6 +72,10 @@ def main(args, return_dict={}): ...@@ -72,6 +72,10 @@ def main(args, return_dict={}):
init_model(config, net, optimizer=None) init_model(config, net, optimizer=None)
valid_dataloader = Reader(config, 'valid', places=place)() valid_dataloader = Reader(config, 'valid', places=place)()
if len(valid_dataloader) <= 0:
logger.error(
"valid dataloader is empty, please check your data config again!")
sys.exit(-1)
net.eval() net.eval()
with paddle.no_grad(): with paddle.no_grad():
if not multilabel: if not multilabel:
......
...@@ -88,9 +88,18 @@ def main(args): ...@@ -88,9 +88,18 @@ def main(args):
init_model(config, net, optimizer) init_model(config, net, optimizer)
train_dataloader = Reader(config, 'train', places=place)() train_dataloader = Reader(config, 'train', places=place)()
if len(train_dataloader) <= 0:
logger.error(
"train dataloader is empty, please check your data config again!")
sys.exit(-1)
if config.validate: if config.validate:
valid_dataloader = Reader(config, 'valid', places=place)() valid_dataloader = Reader(config, 'valid', places=place)()
if len(valid_dataloader) <= 0:
logger.error(
"valid dataloader is empty, please check your data config again!"
)
sys.exit(-1)
last_epoch_id = config.get("last_epoch", -1) last_epoch_id = config.get("last_epoch", -1)
best_top1_acc = 0.0 # best top1 acc record best_top1_acc = 0.0 # best top1 acc record
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册