voc2012.py 3.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image dataset for segmentation.

The 2012 dataset contains images from 2008-2011 for which additional
segmentations have been prepared. As in previous years the assignment
to training/test sets has been maintained. The total number of images
with segmentation has been increased from 7,062 to 9,993.
"""

22 23
from __future__ import print_function

24
import tarfile
25
import io
26
import numpy as np
27 28
from paddle.dataset.common import download
from paddle.dataset.image import *
29
import paddle.utils.deprecated as deprecated
30
from PIL import Image
31 32 33

# Names exported by a star-import of this module: the three reader
# factories defined further down in this file.
__all__ = ['train', 'test', 'val']

34
# URL of the PASCAL VOC 2012 train/val archive handed to ``download``.
# Built with implicit string concatenation rather than a backslash inside
# the literal, so stray whitespace or lines can never leak into the URL.
VOC_URL = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/'
           'VOCtrainval_11-May-2012.tar')

# Digest passed to ``download`` together with the URL (presumably used to
# verify the downloaded archive -- confirm against ``download``).
VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd'
# Archive-internal path templates. ``SET_FILE`` is formatted with a split
# name; ``DATA_FILE`` and ``LABEL_FILE`` are formatted with a sample id
# read from that split list.
SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt'
DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg'
LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png'

# Second argument given to ``download`` by the reader factories.
CACHE_DIR = 'voc2012'

44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60

def reader_creator(filename, sub_name):
    """
    Create a reader over one split of the VOC2012 segmentation archive.

    Args:
        filename: Path of the tar archive to read from.
        sub_name: Split name substituted into ``SET_FILE`` to locate the
            text file that lists the sample ids of the split.

    Returns:
        A zero-argument generator function. Iterating over its result
        yields ``(data, label)`` tuples of numpy arrays, decoded from the
        sample's JPEG image and PNG segmentation label respectively.

    Raises:
        KeyError: While iterating, if the split list or a file it refers
            to is not a member of the archive.
    """
    # The archive is opened once here and shared by every call of the
    # returned reader, so it is intentionally left open.
    tarobject = tarfile.open(filename)
    # Index the members by name so each lookup below is a dict access.
    name2mem = {ele.name: ele for ele in tarobject.getmembers()}

    def reader():
        set_file = SET_FILE.format(sub_name)
        sets = tarobject.extractfile(name2mem[set_file])
        for line in sets:
            # extractfile() gives a binary stream, so ``line`` is bytes.
            # Decode it before formatting; otherwise Python 3 would embed
            # the bytes repr (b'...') in the member name and the lookup
            # would fail with KeyError.
            sample_id = line.strip().decode('utf-8')
            if not sample_id:
                # Tolerate blank lines in the split list.
                continue
            data_file = DATA_FILE.format(sample_id)
            label_file = LABEL_FILE.format(sample_id)
            data = tarobject.extractfile(name2mem[data_file]).read()
            label = tarobject.extractfile(name2mem[label_file]).read()
            # Decode the raw file bytes with PIL, then convert to arrays.
            data = np.array(Image.open(io.BytesIO(data)))
            label = np.array(Image.open(io.BytesIO(label)))
            yield data, label

    return reader


70 71 72 73
@deprecated(
    since="2.0.0",
    update_to="paddle.vision.datasets.VOC2012",
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def train():
    """
    Build the train dataset reader (2913 images in HWC order).

    The archive is obtained through ``download`` and the 'trainval'
    split list is used.

    Returns:
        The reader function produced by ``reader_creator``.
    """
    archive_path = download(VOC_URL, CACHE_DIR, VOC_MD5)
    return reader_creator(archive_path, 'trainval')
79 80


81 82 83 84
@deprecated(
    since="2.0.0",
    update_to="paddle.vision.datasets.VOC2012",
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def test():
    """
    Build the test dataset reader (1464 images in HWC order).

    The archive is obtained through ``download``; note that the samples
    are taken from the 'train' split list of that archive.

    Returns:
        The reader function produced by ``reader_creator``.
    """
    archive_path = download(VOC_URL, CACHE_DIR, VOC_MD5)
    return reader_creator(archive_path, 'train')
90 91


92 93 94 95
@deprecated(
    since="2.0.0",
    update_to="paddle.vision.datasets.VOC2012",
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def val():
    """
    Build the val dataset reader (1449 images in HWC order).

    The archive is obtained through ``download`` and the 'val' split
    list is used.

    Returns:
        The reader function produced by ``reader_creator``.
    """
    archive_path = download(VOC_URL, CACHE_DIR, VOC_MD5)
    return reader_creator(archive_path, 'val')