提交 338dd135 编写于 作者: W wanghaoshuang

Add voc2012 dataset for image segment

上级 91e9a25e
......@@ -24,8 +24,11 @@ import conll05
import uci_housing
import sentiment
import wmt14
import mq2007
import flowers
import voc_seg
# Names exported by ``from <this package> import *``.
#
# One entry per line, each followed by a comma: Python concatenates adjacent
# string literals, so a missing comma silently merges two names into one
# (e.g. 'sentiment' 'uci_housing' becomes 'sentimentuci_housing') and neither
# module is exported.
__all__ = [
    'mnist',
    'imikolov',
    'imdb',
    'cifar',
    'movielens',
    'conll05',
    'sentiment',
    'uci_housing',
    'wmt14',
    'mq2007',
    'flowers',
    'voc_seg',
]
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.voc_seg
import unittest
class TestVOC(unittest.TestCase):
    """Checks the sample counts and image/label sizes of the voc_seg readers."""

    def check_reader(self, reader):
        """Consume *reader* once, validate each sample and count the samples.

        Args:
            reader: Zero-argument callable returning an iterable of samples.
                Each sample is indexable, with the image at index 0 and the
                label at index 1; both must expose a ``size`` attribute.

        Returns:
            The number of samples the reader produced.
        """
        count = 0
        for sample in reader():
            image, label = sample[0], sample[1]
            # Image and label must hold the same number of elements.
            self.assertEqual(image.size, label.size)
            count += 1
        return count

    def test_train(self):
        """train() is expected to yield 2913 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.train())
        self.assertEqual(count, 2913)

    def test_test(self):
        """test() is expected to yield 1464 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.test())
        self.assertEqual(count, 1464)

    def test_val(self):
        """val() is expected to yield 1449 samples."""
        count = self.check_reader(paddle.v2.dataset.voc_seg.val())
        self.assertEqual(count, 1449)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PASCAL VOC 2012 image dataset for segmentation.

The 2012 dataset contains images from 2008-2011 for which additional
segmentations have been prepared. As in previous years, the assignment to
training/test sets has been maintained. The total number of images with
segmentation has been increased from 7,062 to 9,993.
"""
import tarfile
import numpy as np
from common import download
from paddle.v2.image import *
__all__ = ['train', 'test', 'val']
VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'
VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'
def reader_creator(filename, sub_name):
    """Create a reader over one split of the VOC segmentation archive.

    Args:
        filename: Path of the tar archive to read from.
        sub_name: Name of the split list; it is substituted into ``SET_FILE``
            to locate the text file that enumerates the samples.

    Returns:
        A zero-argument callable. Each call returns a generator that yields
        ``(data, label)`` pairs, where both elements are the results of
        ``load_image_bytes`` applied to the raw JPEG / PNG bytes.

    Raises:
        KeyError: During iteration, if the split list or a file it references
            is not a member of the archive.
    """

    def reader():
        # The archive is opened per pass (not once at creation time) so the
        # file handle is released when the generator finishes or is
        # discarded, instead of staying open for the life of the process.
        with tarfile.open(filename) as tarobject:
            # Index members by name once; looking each one up through the
            # tarfile API would rescan the member list for every sample.
            name2mem = {ele.name: ele for ele in tarobject.getmembers()}

            set_file = SET_FILE.format(sub_name)
            sets = tarobject.extractfile(name2mem[set_file])
            for line in sets:
                line = line.strip()
                if not line:
                    # A blank line would otherwise build a bogus member path
                    # and fail with a confusing KeyError.
                    continue
                data_file = DATA_FILE.format(line)
                label_file = LABEL_FILE.format(line)
                data = tarobject.extractfile(name2mem[data_file]).read()
                label = tarobject.extractfile(name2mem[label_file]).read()
                data = load_image_bytes(data)
                # NOTE(review): the label PNG goes through the same decoder
                # and defaults as the photo; confirm that this yields the
                # label encoding downstream code expects.
                label = load_image_bytes(label)
                yield data, label

    return reader
def train():
    """
    Build the training reader.

    The archive is obtained through ``download`` and the reader walks its
    'trainval' split list, which is expected to contain 2913 images.
    """
    archive = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive, 'trainval')
def test():
    """
    Build the test reader.

    The archive is obtained through ``download`` and the reader walks its
    'train' split list, which is expected to contain 1464 images.
    """
    archive = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive, 'train')
def val():
    """
    Build the validation reader.

    The archive is obtained through ``download`` and the reader walks its
    'val' split list, which is expected to contain 1449 images.
    """
    archive = download(VOC_URL, 'voc_seg', VOC_MD5)
    return reader_creator(archive, 'val')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册