# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: voc_augment.py

This file use SBD(Semantic Boundaries Dataset) <http://home.bharathh.info/pubs/codes/SBD/download.html>
C
chenguowei01 已提交
18
to augment the Pascal VOC.
19 20 21 22 23 24 25 26 27 28 29
"""

import os
import argparse
from multiprocessing import Pool, cpu_count

import cv2
import numpy as np
from scipy.io import loadmat
import tqdm

from dygraph.utils.download import download_file_and_uncompress

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = 'http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz'


def parse_args():
    parser = argparse.ArgumentParser(
        description=
        'Convert SBD annotations to Pascal VOC format to augment the training set of Pascal VOC'
    )
    parser.add_argument(
        '--voc_path',
        dest='voc_path',
        help='Pascal VOC path',
        type=str,
        default=os.path.join(DATA_HOME, 'VOCdevkit'))

    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='How many processes are used for data conversion',
        type=int,
        default=cpu_count())
    return parser.parse_args()


def mat_to_png(mat_file, sbd_cls_dir, save_dir):
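    """Convert one SBD .mat annotation into a PNG label image in save_dir."""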
    mat_path = os.path.join(sbd_cls_dir, mat_file)
    mat = loadmat(mat_path)
    mask = mat['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    save_file = os.path.join(save_dir, mat_file.replace('mat', 'png'))
    cv2.imwrite(save_file, mask)


def main():
    args = parse_args()
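    # Download and extract SBD, then collect the ids of its train and val samples.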
    sbd_path = download_file_and_uncompress(
        url=URL,
        savepath=DATA_HOME,
        extrapath=DATA_HOME,
        extraname='benchmark_RELEASE')
    with open(os.path.join(sbd_path, 'dataset/train.txt'), 'r') as f:
        sbd_file_list = [line.strip() for line in f]
    with open(os.path.join(sbd_path, 'dataset/val.txt'), 'r') as f:
        sbd_file_list += [line.strip() for line in f]
    if not os.path.exists(args.voc_path):
        raise Exception(
            'There is no voc_path: {}. Please ensure that the Pascal VOC dataset has been downloaded correctly'
            .format(args.voc_path))
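    # Keep only the SBD samples that are not already labeled in the Pascal VOC trainval split.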
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/trainval.txt'),
            'r') as f:
        voc_file_list = [line.strip() for line in f]

    aug_file_list = list(set(sbd_file_list) - set(voc_file_list))
    with open(
            os.path.join(args.voc_path,
                         'VOC2012/ImageSets/Segmentation/aug.txt'), 'w') as f:
        f.writelines(''.join([line, '\n']) for line in aug_file_list)

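    # Convert every SBD .mat annotation to a PNG label under SegmentationClassAug,
    # using a pool of worker processes.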
    sbd_cls_dir = os.path.join(sbd_path, 'dataset/cls')
    save_dir = os.path.join(args.voc_path, 'VOC2012/SegmentationClassAug')
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    mat_file_list = os.listdir(sbd_cls_dir)
    p = Pool(args.num_workers)
    for f in tqdm.tqdm(mat_file_list):
        p.apply_async(mat_to_png, args=(f, sbd_cls_dir, save_dir))
    p.close()
    p.join()


if __name__ == '__main__':
    main()