提交 b6672e43 编写于 作者: X xiaohang

add listdataset

上级 e79773da
......@@ -11,11 +11,13 @@ from warpctc_pytorch import CTCLoss
import os
import utils
import dataset
import time
import models.crnn as crnn
parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', required=True, help='path to dataset')
parser.add_argument('--trainroot', default="", help='path to dataset')
parser.add_argument('--trainlist', default="", help='path to train_list')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
......@@ -56,7 +58,13 @@ cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
train_dataset = dataset.lmdbDataset(root=opt.trainroot)
if opt.trainroot != "":
train_dataset = dataset.lmdbDataset(root=opt.trainroot)
elif opt.trainlist != "":
train_dataset = dataset.listDataset(list_file =opt.trainlist)
else:
print("no train data, exit")
exit(0)
assert train_dataset
if not opt.random_sample:
sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
......@@ -64,7 +72,7 @@ else:
sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=opt.batchSize,
shuffle=True, sampler=sampler,
shuffle=False, sampler=sampler,
num_workers=int(opt.workers),
collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
......@@ -153,7 +161,7 @@ def val(net, dataset, criterion, max_iter=100):
loss_avg.add(cost)
_, preds = preds.max(2)
preds = preds.squeeze(2)
#preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
for pred, target in zip(sim_preds, cpu_texts):
......@@ -186,6 +194,7 @@ def trainBatch(net, criterion, optimizer):
return cost
t0 = time.time()
for epoch in range(opt.niter):
train_iter = iter(train_loader)
i = 0
......@@ -202,6 +211,9 @@ for epoch in range(opt.niter):
print('[%d/%d][%d/%d] Loss: %f' %
(epoch, opt.niter, i, len(train_loader), loss_avg.val()))
loss_avg.reset()
t1 = time.time()
print('time elapsed %d' % (t1-t0))
t0 = time.time()
if i % opt.valInterval == 0:
val(crnn, test_dataset, criterion)
......
......@@ -13,6 +13,39 @@ from PIL import Image
import numpy as np
class listDataset(Dataset):
    """Dataset backed by a text file that lists one image path per line.

    Every non-blank line of ``list_file`` is stripped of surrounding
    whitespace and used as the path of an image.  The label of a sample is
    derived from the image's file name: the second ``_``-separated field,
    lower-cased (``dir/1_Hello_42.jpg`` -> ``hello``).

    Args:
        list_file: path of the text file that lists the image paths.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the label string.

    Raises:
        ValueError: if ``list_file`` is not given.
    """

    def __init__(self, list_file=None, transform=None, target_transform=None):
        if not list_file:
            raise ValueError('list_file is required')
        with open(list_file) as fp:
            # Skip blank lines (e.g. a trailing newline at the end of the
            # file) so they are not counted as samples.
            self.lines = [line.strip() for line in fp if line.strip()]
        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of image paths read from the list file."""
        return self.nSamples

    @staticmethod
    def _label_from_path(imgpath):
        """Return the lower-cased second '_'-separated field of the file name."""
        fields = imgpath.split('/')[-1].split('_')
        if len(fields) < 2:
            raise ValueError(
                'cannot parse label from file name: %s' % imgpath)
        return fields[1].lower()

    def __getitem__(self, index):
        """Return ``(img, label)`` for the sample at 0-based ``index``.

        The image is opened and converted to mode ``'L'``.  If it cannot be
        read, the following entries of the list are tried in order (wrapping
        around at the end) until one loads.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            IOError: if no image in the list can be read.
            ValueError: if the file name contains no ``_`` separator.
        """
        # Samplers may hand over non-builtin integer types; normalise once.
        index = int(index)
        if not 0 <= index < self.nSamples:
            raise IndexError('index range error')

        # self.lines is a plain 0-based list, so the index is used as is.
        for offset in range(self.nSamples):
            candidate = (index + offset) % self.nSamples
            imgpath = self.lines[candidate]
            try:
                # The context manager closes the underlying file promptly;
                # convert() returns an independent in-memory copy.
                with Image.open(imgpath) as raw:
                    img = raw.convert('L')
            except IOError:
                print('Corrupted image for %d' % candidate)
                continue
            break
        else:
            # Every entry failed to load; fail loudly instead of looping.
            raise IOError('no readable image in list')

        if self.transform is not None:
            img = self.transform(img)

        label = self._label_from_path(imgpath)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)
class lmdbDataset(Dataset):
def __init__(self, root=None, transform=None, target_transform=None):
......
nohup python crnn_main.py --trainroot ../../PyTorch/crnn/tool/data/train_lmdb/ --valroot ../../PyTorch/crnn/tool/data/test_lmdb/ --cuda --adam --lr=0.001 > log_adam.txt &
#python main.py --trainroot ../PyTorch/crnn/tool/data/train_lmdb/ --valroot ../PyTorch/crnn/tool/data/test_lmdb/ --cuda --adam --lr=0.001
python main.py --trainlist train_list.txt --valroot ../PyTorch/crnn/tool/data/test_lmdb/ --cuda --adam --lr=0.001 # train_list could be annotation_train.txt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册