# prepara_data.py
# Prepares the CCF remote-sensing segmentation dataset for PaddleX: downloads
# the source images, cuts them into tiles and writes the file lists.
import os
import os.path as osp
import numpy as np
import cv2
import shutil
from PIL import Image
import paddlex as pdx

# Training tiles: sliding-window size and step, both given as (W, H).
# The stride is smaller than the window, so neighbouring training tiles overlap.
train_tile_size = (1024,) * 2
train_stride = (512,) * 2
# Validation tiles: sliding-window size and step, both given as (W, H).
# The stride equals the window, so validation tiles do not overlap.
val_tile_size = (769,) * 2
val_stride = (769,) * 2

# Download the high-resolution remote-sensing imagery from the 2015 CCF
# big-data competition and extract it into the current working directory.
# The steps below expect the extracted files under ccf_remote_dataset/.
ccf_remote_dataset = (
    'https://bj.bcebos.com/paddlex/examples/remote_sensing/datasets/ccf_remote_dataset.tar.gz'
)
pdx.utils.download_and_decompress(
    ccf_remote_dataset,
    path='./',
)

# Output layout: image tiles go to dataset/JPEGImages and label maps go to
# dataset/Annotations. exist_ok=True replaces the exists()-then-makedirs check:
# it has no check-then-create race, and reruns of the script stay harmless.
os.makedirs('./dataset/JPEGImages', exist_ok=True)
os.makedirs('./dataset/Annotations', exist_ok=True)

# Images 1-4 make up the training set. Each full-size image and its label map
# is copied into the dataset tree and listed in train_list.txt. The first
# iteration truncates the list, so every run rebuilds it from scratch.
for train_id in range(1, 5):
    for src_tpl, dst_tpl in (
            ("ccf_remote_dataset/{}.png", "./dataset/JPEGImages/{}.png"),
            ("ccf_remote_dataset/{}_class.png",
             "./dataset/Annotations/{}_class.png"),
    ):
        shutil.copyfile(src_tpl.format(train_id), dst_tpl.format(train_id))
    list_mode = 'a'
    if train_id == 1:
        list_mode = 'w'
    with open('./dataset/train_list.txt', list_mode) as f:
        entry = "JPEGImages/{}.png Annotations/{}_class.png\n".format(
            train_id, train_id)
        f.write(entry)

# Cut each training image (ids 1-4) and its label map into tiles.
# - The sliding window is train_tile_size, moved by train_stride (both (W, H)).
# - Tiles overlap whenever the stride is smaller than the window.
# - Tiles touching the right/bottom border are clipped to the image, so they
#   may be smaller than the window.
# Every tile pair is saved into the dataset tree and appended to
# train_list.txt. The list file is opened once for the whole loop.
with open('./dataset/train_list.txt', 'a') as list_file:
    for train_id in range(1, 5):
        image_path = 'ccf_remote_dataset/{}.png'.format(train_id)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (instead of raising) for a missing or
            # unreadable file; fail with a clear error, not on image.shape.
            raise FileNotFoundError('cannot read image {}'.format(image_path))
        with Image.open('ccf_remote_dataset/{}_class.png'.format(
                train_id)) as label:
            H, W = image.shape[:2]
            train_tile_id = 1
            for h in range(0, H, train_stride[1]):
                for w in range(0, W, train_stride[0]):
                    left = w
                    upper = h
                    # Window of exactly train_tile_size, clipped to the image.
                    right = min(w + train_tile_size[0], W)
                    lower = min(h + train_tile_size[1], H)
                    tile_image = image[upper:lower, left:right, :]
                    tile_path = "./dataset/JPEGImages/{}_{}.png".format(
                        train_id, train_tile_id)
                    # cv2.imwrite reports failure by returning False.
                    if not cv2.imwrite(tile_path, tile_image):
                        raise IOError('cannot write tile {}'.format(tile_path))
                    cut_label = label.crop((left, upper, right, lower))
                    cut_label.save(
                        "./dataset/Annotations/{}_class_{}.png".format(
                            train_id, train_tile_id))
                    list_file.write(
                        "JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
                        format(train_id, train_tile_id, train_id,
                               train_tile_id))
                    train_tile_id += 1

# Image 5 is the validation image.
# Its full-size image and label map are copied into the dataset tree but not
# listed in val_list.txt (presumably kept for whole-image use; TODO confirm).
# The image is then cut with a val_tile_size window moved by val_stride, and
# every tile pair is saved and listed in val_list.txt.
val_id = 5
val_tile_id = 1
shutil.copyfile("ccf_remote_dataset/{}.png".format(val_id),
                "./dataset/JPEGImages/{}.png".format(val_id))
shutil.copyfile("ccf_remote_dataset/{}_class.png".format(val_id),
                "./dataset/Annotations/{}_class.png".format(val_id))
val_image_path = 'ccf_remote_dataset/{}.png'.format(val_id)
image = cv2.imread(val_image_path)
if image is None:
    # cv2.imread returns None (instead of raising) for a missing or
    # unreadable file; fail with a clear error, not on image.shape.
    raise FileNotFoundError('cannot read image {}'.format(val_image_path))
# val_list.txt is opened once in 'w' mode, so each run rewrites it from scratch
# (the file exists even if no tile is produced). The label map is closed
# when the block ends.
with Image.open('ccf_remote_dataset/{}_class.png'.format(val_id)) as label, \
        open('./dataset/val_list.txt', 'w') as list_file:
    H, W = image.shape[:2]
    for h in range(0, H, val_stride[1]):
        for w in range(0, W, val_stride[0]):
            left = w
            upper = h
            right = min(w + val_tile_size[0], W)
            lower = min(h + val_tile_size[1], H)
            cut_image = image[upper:lower, left:right, :]
            tile_path = "./dataset/JPEGImages/{}_{}.png".format(
                val_id, val_tile_id)
            # cv2.imwrite reports failure by returning False.
            if not cv2.imwrite(tile_path, cut_image):
                raise IOError('cannot write tile {}'.format(tile_path))
            cut_label = label.crop((left, upper, right, lower))
            cut_label.save("./dataset/Annotations/{}_class_{}.png".format(
                val_id, val_tile_id))
            list_file.write(
                "JPEGImages/{}_{}.png Annotations/{}_class_{}.png\n".
                format(val_id, val_tile_id, val_id, val_tile_id))
            val_tile_id += 1

# Generate labels.txt with one class name per line, in list order.
# Presumably line i names pixel value i in the annotation maps; TODO confirm.
label_list = ['background', 'vegetation', 'building', 'water', 'road']
# Make sure the output directory exists, so this step also works on its own.
os.makedirs('./dataset', exist_ok=True)
# Mode 'w' makes reruns overwrite the file instead of appending duplicate class
# names. The last name has no trailing newline, matching the previous output.
with open('./dataset/labels.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(label_list))