# prepare_data.py (3.6 KB) -- committed by wangxinxin08
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import argparse
from convert import load_dota_infos, data_to_coco
from slicebase import SliceBase

# Base list of 15 category names; the larger label sets below extend it.
# NOTE(review): the order presumably determines category ids downstream --
# confirm against data_to_coco before reordering.
wordname_15 = [
    'plane',
    'baseball-diamond',
    'bridge',
    'ground-track-field',
    'small-vehicle',
    'large-vehicle',
    'ship',
    'tennis-court',
    'basketball-court',
    'storage-tank',
    'soccer-ball-field',
    'roundabout',
    'harbor',
    'swimming-pool',
    'helicopter',
]

# 16-class set: the base list plus one extra category.
wordname_16 = wordname_15 + ['container-crane']

# 18-class set: the 16-class list plus two more categories.
wordname_18 = wordname_16 + ['airport', 'helipad']

# Maps the value of --data_type to the class-name list handed to data_to_coco.
DATA_CLASSES = dict(
    dota10=wordname_15,
    dota15=wordname_16,
    dota20=wordname_18, )


def parse_args(argv=None):
    """Parse the command-line options of the data preparation script.

    Args:
        argv (list[str], optional): Argument strings to parse. When None
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``parse_args()`` behaves as before.

    Returns:
        argparse.Namespace: Parsed options with the attributes
            ``input_dirs``, ``output_dir``, ``coco_json_file``, ``subsize``,
            ``gap``, ``data_type``, ``rates``, ``nproc``, ``iof_thr`` and
            ``image_only``. ``input_dirs`` and ``output_dir`` are None when
            not given on the command line.
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # is a description, so pass it by keyword to keep the usage line showing
    # the real program name.
    parser = argparse.ArgumentParser(description='prepare data for training')

    parser.add_argument(
        '--input_dirs',
        nargs='+',
        type=str,
        default=None,
        help='input dirs which contain image and labelTxt dir')

    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='output dirs which contain image and labelTxt dir and coco style json file'
    )

    # An empty string (the default) disables the coco json export in main().
    parser.add_argument(
        '--coco_json_file',
        type=str,
        default='',
        help='coco json annotation files')

    parser.add_argument('--subsize', type=int, default=1024, help='patch size')

    parser.add_argument('--gap', type=int, default=200, help='step size')

    parser.add_argument(
        '--data_type', type=str, default='dota10', help='data type')

    parser.add_argument(
        '--rates',
        nargs='+',
        type=float,
        default=[1.],
        help='scales for multi-scale training')

    parser.add_argument(
        '--nproc', type=int, default=8, help='the processor number')

    parser.add_argument(
        '--iof_thr',
        type=float,
        default=0.5,
        help='the minimal iof between a object and a window')

    parser.add_argument(
        '--image_only',
        action='store_true',
        default=False,
        help='only processing image')

    return parser.parse_args(argv)


def load_dataset(input_dir, nproc, data_type):
    """Load the annotation infos found in one input directory.

    Args:
        input_dir (str): Directory forwarded to ``load_dota_infos``.
        nproc (int): Process count forwarded to ``load_dota_infos``.
        data_type (str): Dataset type name; it is accepted when it contains
            ``'dota'``, compared case-insensitively.

    Returns:
        Whatever ``load_dota_infos(input_dir, nproc)`` returns (the caller
        concatenates the results, so presumably a list).

    Raises:
        ValueError: If ``data_type`` does not name a dota dataset.
    """
    is_dota = 'dota' in data_type.lower()
    if not is_dota:
        raise ValueError('only dota dataset is supported now')

    return load_dota_infos(input_dir, nproc)


def main():
    """Slice the input datasets and optionally export a coco style json file.

    Steps:
        1. Parse the command line and validate the options that would
           otherwise only fail later on.
        2. Load the infos of every directory in ``--input_dirs``.
        3. Hand them to ``SliceBase.slice_data`` together with ``--rates``
           and ``--output_dir``.
        4. When ``--coco_json_file`` is non-empty, reload the infos from
           ``--output_dir`` and pass them to ``data_to_coco`` with the json
           path ``<output_dir>/<coco_json_file>``.

    Raises:
        ValueError: If ``--input_dirs`` is missing, or if a coco json file is
            requested while ``--output_dir`` is missing or ``--data_type``
            has no entry in ``DATA_CLASSES``.
    """
    args = parse_args()

    # --input_dirs defaults to None; iterating it would raise an opaque
    # TypeError, so report the missing option explicitly.
    if not args.input_dirs:
        raise ValueError('--input_dirs must be specified')

    # Resolve everything the coco export needs before slicing starts, so a
    # bad option cannot waste a complete slicing run.
    class_names = None
    if args.coco_json_file:
        if args.output_dir is None:
            raise ValueError(
                '--output_dir must be specified when --coco_json_file is set')
        # load_dataset matches the data type case-insensitively; look the
        # class list up the same way so both checks agree.
        class_names = DATA_CLASSES.get(args.data_type.lower())
        if class_names is None:
            raise ValueError(
                'unsupported data type {!r} for coco json export, expected '
                'one of {}'.format(args.data_type, sorted(DATA_CLASSES)))

    infos = []
    for input_dir in args.input_dirs:
        infos += load_dataset(input_dir, args.nproc, args.data_type)

    slicer = SliceBase(
        args.gap,
        args.subsize,
        args.iof_thr,
        num_process=args.nproc,
        image_only=args.image_only)
    slicer.slice_data(infos, args.rates, args.output_dir)

    if args.coco_json_file:
        # The json is built from the infos reloaded from the output
        # directory, not from the infos loaded from the input directories.
        infos = load_dota_infos(args.output_dir, args.nproc)
        coco_json_file = os.path.join(args.output_dir, args.coco_json_file)
        data_to_coco(infos, coco_json_file, class_names, args.nproc)


if __name__ == '__main__':
    main()