提交 dfb48029 编写于 作者: X xiaohang

also support vallist

上级 b6672e43
......@@ -18,7 +18,8 @@ import models.crnn as crnn
parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', default="", help='path to dataset')
parser.add_argument('--trainlist', default="", help='path to train_list')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--valroot', default="", help='path to dataset')
parser.add_argument('--vallist', default="", help='path to val_list')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
......@@ -65,6 +66,7 @@ elif opt.trainlist != "":
else:
print("no train data, exit")
exit(0)
assert train_dataset
if not opt.random_sample:
sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
......@@ -75,8 +77,15 @@ train_loader = torch.utils.data.DataLoader(
shuffle=False, sampler=sampler,
num_workers=int(opt.workers),
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))
if opt.valroot != "":
test_dataset = dataset.lmdbDataset(
root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))
elif opt.vallist != "":
test_dataset = dataset.listDataset(list_file =opt.vallist, transform=dataset.resizeNormalize((100, 32)))
else:
print("no val data, exit")
exit(0)
nclass = len(opt.alphabet) + 1
nc = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册