提交 3baa25ed 编写于 作者: X xiaohang

works for mnt data

上级 c70bc0f6
......@@ -36,3 +36,9 @@ Train a new model
Stable commits
--------------
dbe73da0dd7efb8bd76dbd7f0ac3856e742b98d4: support image list with label and alphabet
Train for VGG text data
--------------
1. create a link to mnt folder
2. python data/create_mnt_list.py
3. python main.py --trainlist data/train_list.txt --vallist data/test_list.txt --cuda --adam --lr=0.001
with open('data/mnt/ramdisk/max/90kDICT32px/annotation_train.txt') as fp:
lines = fp.readlines()
train_fp = open('data/train_list.txt', 'w')
for line in lines:
imgpath = line.strip().split(' ')[0]
label = imgpath.split('/')[-1].split('_')[1].lower()
label = ':'.join(label)
imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
output = ' '.join([imgpath, label])
print >> train_fp, output
train_fp.close()
with open('data/mnt/ramdisk/max/90kDICT32px/annotation_test.txt') as fp:
lines = fp.readlines()
test_fp = open('data/test_list.txt', 'w')
for line in lines:
imgpath = line.strip().split(' ')[0]
label = imgpath.split('/')[-1].split('_')[1].lower()
label = ':'.join(label)
imgpath = 'data/mnt/ramdisk/max/90kDICT32px/%s' % imgpath
output = ' '.join([imgpath, label])
print >> test_fp, output
test_fp.close()
......@@ -46,59 +46,6 @@ class listDataset(Dataset):
return (img, label)
class lmdbDataset(Dataset):
def __init__(self, root=None, transform=None, target_transform=None):
self.env = lmdb.open(
root,
max_readers=1,
readonly=True,
lock=False,
readahead=False,
meminit=False)
if not self.env:
print('cannot creat lmdb from %s' % (root))
sys.exit(0)
with self.env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'))
self.nSamples = nSamples
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return self.nSamples
def __getitem__(self, index):
assert index <= len(self), 'index range error'
index += 1
with self.env.begin(write=False) as txn:
img_key = 'image-%09d' % index
imgbuf = txn.get(img_key)
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
try:
img = Image.open(buf).convert('L')
except IOError:
print('Corrupted image for %d' % index)
return self[index + 1]
if self.transform is not None:
img = self.transform(img)
label_key = 'label-%09d' % index
label = str(txn.get(label_key))
if self.target_transform is not None:
label = self.target_transform(label)
return (img, label)
class resizeNormalize(object):
def __init__(self, size, interpolation=Image.BILINEAR):
......
......@@ -2,14 +2,4 @@
# imagepath label
# /ab/cd/image.jpg a:b:c:d
#python main.py --trainlist train_list2.txt --vallist test_list2.txt --cuda --adam --lr=0.001
#python main_for_music.py --trainlist data/train_list2.txt --vallist data/test_list2.txt --cuda --adam --lr=0.001
# loss ~ 2
nohup python main_for_music.py --trainlist data/train_list3.txt --vallist data/test_list3.txt --cuda --adam --lr=0.001 > log3.txt &
# loss ~ 2
python main.py --trainlist data/train_list.txt --vallist data/test_list.txt --cuda --adam --lr=0.001
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册