# gen_ocr_train_val.py
# coding:utf8
import os
import shutil
import random
import argparse

def isCreateOrDeleteFolder(path, flag):
    """Return the absolute path of a freshly emptied ``<path>/<flag>`` folder.

    If the folder already exists it is deleted together with everything in
    it, then an empty folder is created again (missing parents included).
    This is used to discard a previous train/val split before regenerating it.
    """
    target = os.path.join(path, flag)
    # Throw away the old split so stale images cannot survive into the new one.
    if os.path.exists(target):
        shutil.rmtree(target)
    os.makedirs(target)
    return os.path.abspath(target)


def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, trainTxt, valTxt, flag, options=None):
    """Shuffle one annotated folder's records and split them into train/val.

    Reads the annotation file inside ``os.path.join(root, dir)`` (detection or
    recognition label file, selected by ``flag``), shuffles its records, copies
    each referenced image into the train or validation folder according to
    ``trainValRatio``, and writes ``"<copied image path>\\t<label>\\n"`` to the
    matching list file.

    Args:
        root: parent directory of the annotated folder.
        dir: name of the annotated folder inside ``root``.
        absTrainRootPath: absolute path of the folder training images are copied to.
        absValRootPath: absolute path of the folder validation images are copied to.
        trainTxt: writable text file receiving training records.
        valTxt: writable text file receiving validation records.
        flag: ``"det"`` for detection labels or ``"rec"`` for recognition labels.
        options: object with ``trainValRatio``, ``detLabelFileName``,
            ``recLabelFileName`` and ``recImageDirName`` attributes. When
            omitted, the module-level ``args`` parsed under ``__main__`` is used.

    Raises:
        ValueError: if ``flag`` is neither ``"det"`` nor ``"rec"``, or a
            non-blank record contains no tab separator.
    """
    # Fall back to the script's global argparse result for existing callers.
    opts = options if options is not None else args
    labelAbsPath = os.path.abspath(os.path.join(root, dir))
    if flag == "det":
        labelFilePath = os.path.join(labelAbsPath, opts.detLabelFileName)
    elif flag == "rec":
        labelFilePath = os.path.join(labelAbsPath, opts.recLabelFileName)
    else:
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    with open(labelFilePath, "r", encoding="UTF-8") as labelFileRead:
        # Skip blank lines (e.g. a trailing empty line): they have no tab to split on.
        # Strip the newline here and add it back on write, so a last line without
        # "\n" cannot be glued to the next folder's first record in the output.
        labelFileContent = [line.rstrip("\n") for line in labelFileRead if line.strip()]

    # Shuffle the records so the train/val split is random.
    random.shuffle(labelFileContent)
    labelRecordLen = len(labelFileContent)
    for index, labelRecordInfo in enumerate(labelFileContent):
        # Split on the first tab only, so labels containing tabs stay intact.
        imageRelativePath, sep, imageLabel = labelRecordInfo.partition("\t")
        if not sep:
            raise ValueError("malformed record in {}: {!r}".format(labelFilePath, labelRecordInfo))
        imageName = os.path.basename(imageRelativePath)
        if flag == "det":
            imagePath = os.path.join(labelAbsPath, imageName)
        else:
            # Use os.path.join rather than a hard-coded "\\" so this also works off Windows.
            imagePath = os.path.join(labelAbsPath, opts.recImageDirName, imageName)
        # Records whose position fraction is below trainValRatio go to training;
        # the rest go to validation.
        if index / labelRecordLen < opts.trainValRatio:
            destRoot, destTxt = absTrainRootPath, trainTxt
        else:
            destRoot, destTxt = absValRootPath, valTxt
        imageCopyPath = os.path.join(destRoot, imageName)
        shutil.copy(imagePath, imageCopyPath)
        destTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))


def genDetRecTrainVal(args):
    """Rebuild the detection and recognition train/val splits from scratch.

    Recreates empty ``train``/``val`` image folders under ``args.detRootPath``
    and ``args.recRootPath``, truncates (or creates) the ``train.txt`` and
    ``val.txt`` list files in those roots, then splits every immediate
    sub-folder of ``args.labelRootPath`` for both the detection and the
    recognition task.
    """
    detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
    detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
    recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
    recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
    detTrainTxtPath = os.path.join(args.detRootPath, "train.txt")
    detValTxtPath = os.path.join(args.detRootPath, "val.txt")
    recTrainTxtPath = os.path.join(args.recRootPath, "train.txt")
    recValTxtPath = os.path.join(args.recRootPath, "val.txt")
    # Mode "w" truncates existing list files and creates missing ones. The former
    # os.remove() + "a" combination crashed with FileNotFoundError on the first run.
    # The with-statement guarantees that all four files are flushed and closed.
    with open(detTrainTxtPath, "w", encoding="UTF-8") as detTrainTxt, \
            open(detValTxtPath, "w", encoding="UTF-8") as detValTxt, \
            open(recTrainTxtPath, "w", encoding="UTF-8") as recTrainTxt, \
            open(recValTxtPath, "w", encoding="UTF-8") as recValTxt:
        # Only the immediate sub-folders of labelRootPath are processed: the loop
        # stops after the first os.walk() entry instead of recursing.
        for root, dirs, files in os.walk(args.labelRootPath):
            for dir in dirs:
                splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detTrainTxt, detValTxt, "det")
                splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recTrainTxt, recValTxt, "rec")
            break


if __name__ == "__main__":
    # Purpose: split the detection and recognition datasets into separate
    # training and validation sets.
    # Notes: adjust the arguments to your own paths and needs. Image data is
    # often labelled in batches by several people, with each batch in its own
    # folder annotated with PPOCRLabel. Those labelled folders then need to be
    # merged and split into training and validation sets.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trainValRatio",
        type=float,
        default=0.8,
        help="fraction (0-1) of the records in each labelled folder assigned to the "
             "training set; the remaining records go to the validation set")
    parser.add_argument(
        "--labelRootPath",
        type=str,
        default="./train_data/label",
        help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
    )
    parser.add_argument(
        "--detRootPath",
        type=str,
        default="./train_data/det/demPanel",
        help="the path where the divided detection dataset is placed")
    parser.add_argument(
        "--recRootPath",
        type=str,
        default="./train_data/rec/demPanel",
        help="the path where the divided recognition dataset is placed"
    )
    parser.add_argument(
        "--detLabelFileName",
        type=str,
        default="Label.txt",
        help="the name of the detection annotation file")
    parser.add_argument(
        "--recLabelFileName",
        type=str,
        default="rec_gt.txt",
        help="the name of the recognition annotation file"
    )
    parser.add_argument(
        "--recImageDirName",
        type=str,
        default="crop_img",
        help="the name of the folder where the cropped recognition dataset is located"
    )
    args = parser.parse_args()
    # A value outside [0, 1] (or NaN) would silently send every record to a
    # single split, so reject it up front.
    if not 0.0 <= args.trainValRatio <= 1.0:
        parser.error("--trainValRatio must be between 0 and 1, got {}".format(args.trainValRatio))
    genDetRecTrainVal(args)