diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 80ff6295c34e853d8f69b9e78719af23a56d1fbb..cdd85cce37188fe51f8f2de0b9d5a22365ecfe95 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -24,8 +24,11 @@ import conll05 import uci_housing import sentiment import wmt14 +import mq2007 +import flowers +import voc_seg __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14' + 'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc_seg' ] diff --git a/python/paddle/v2/dataset/tests/vocseg_test.py b/python/paddle/v2/dataset/tests/vocseg_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a773fa18b2a7ced385ad881c9cec186904ef350 --- /dev/null +++ b/python/paddle/v2/dataset/tests/vocseg_test.py @@ -0,0 +1,42 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.voc_seg +import unittest + + +class TestVOC(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + for l in reader(): + self.assertEqual(l[0].size, l[1].size) + sum += 1 + return sum + + def test_train(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.train()) + self.assertEqual(count, 2913) + + def test_test(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.test()) + self.assertEqual(count, 1464) + + def test_val(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.val()) + self.assertEqual(count, 1449) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/dataset/voc_seg.py b/python/paddle/v2/dataset/voc_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..9b79f726d235b722f9fabf88386553116d0945c6 --- /dev/null +++ b/python/paddle/v2/dataset/voc_seg.py @@ -0,0 +1,74 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image dataset for segmentation. +The 2012 dataset contains images from 2008-2011 for which additional segmentations have been prepared. As in previous years the assignment to training/test sets has been maintained. The total number of images with segmentation has been increased from 7,062 to 9,993. +""" + +import tarfile +import numpy as np +from common import download +from paddle.v2.image import * + +__all__ = ['train', 'test', 'val'] + +VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' +VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd' +SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt' +DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg' +LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png' + + +def reader_creator(filename, sub_name): + + tarobject = tarfile.open(filename) + name2mem = {} + for ele in tarobject.getmembers(): + name2mem[ele.name] = ele + + def reader(): + set_file = SET_FILE.format(sub_name) + sets = tarobject.extractfile(name2mem[set_file]) + for line in sets: + line = line.strip() + data_file = DATA_FILE.format(line) + label_file = LABEL_FILE.format(line) + data = tarobject.extractfile(name2mem[data_file]).read() + label = tarobject.extractfile(name2mem[label_file]).read() + data = load_image_bytes(data) + label = load_image_bytes(label) + yield data, label + + return reader + + +def train(): + """ + Create a train dataset reader containing 2913 images. + """ + return reader_creator(download(VOC_URL, 'voc_seg', VOC_MD5), 'trainval') + + +def test(): + """ + Create a test dataset reader containing 1464 images. + """ + return reader_creator(download(VOC_URL, 'voc_seg', VOC_MD5), 'train') + + +def val(): + """ + Create a val dataset reader containing 1449 images. + """ + return reader_creator(download(VOC_URL, 'voc_seg', VOC_MD5), 'val')