From e63fbe49be01d0f8e82e03747d1a9460c3b2a1a7 Mon Sep 17 00:00:00 2001 From: Alchemist_W <2443176192@qq.com> Date: Thu, 10 Feb 2022 19:29:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=AE=8C=E6=88=90:=E5=88=92?= =?UTF-8?q?=E5=88=86det=E4=B8=8Erec=E6=95=B0=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PPOCRLabel/gen_ocr_train_val_test.py | 32 ++++++++++++++++------------ 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/PPOCRLabel/gen_ocr_train_val_test.py b/PPOCRLabel/gen_ocr_train_val_test.py index 64cba612..03ae566c 100644 --- a/PPOCRLabel/gen_ocr_train_val_test.py +++ b/PPOCRLabel/gen_ocr_train_val_test.py @@ -17,15 +17,14 @@ def isCreateOrDeleteFolder(path, flag): return flagAbsPath -def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag): +def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag): # 按照指定的比例划分训练集、验证集、测试集 - labelPath = os.path.join(root, dir) - labelAbsPath = os.path.abspath(labelPath) + dataAbsPath = os.path.abspath(root) if flag == "det": - labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName) + labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName) elif flag == "rec": - labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName) + labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName) labelFileRead = open(labelFilePath, "r", encoding="UTF-8") labelFileContent = labelFileRead.readlines() @@ -38,9 +37,9 @@ def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, imageName = os.path.basename(imageRelativePath) if flag == "det": - imagePath = os.path.join(labelAbsPath, imageName) + imagePath = os.path.join(dataAbsPath, imageName) elif flag == "rec": - imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName)) + imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName)) # 按预设的比例划分训练集、验证集、测试集 trainValTestRatio = args.trainValTestRatio.split(":") @@ -90,15 +89,20 @@ def genDetRecTrainVal(args): recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8") recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8") - for root, dirs, files in os.walk(args.labelRootPath): + splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt, + detTestTxt, "det") + + for root, dirs, files in os.walk(args.datasetRootPath): for dir in dirs: - splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt, - detTestTxt, "det") - splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt, - recTestTxt, "rec") + if dir == 'crop_img': + splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt, + recTestTxt, "rec") + else: + continue break + if __name__ == "__main__": # 功能描述:分别划分检测和识别的训练集、验证集、测试集 # 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注, @@ -110,9 +114,9 @@ if __name__ == "__main__": default="6:2:2", help="ratio of trainset:valset:testset") parser.add_argument( - "--labelRootPath", + "--datasetRootPath", type=str, - default="../train_data/label", + default="../train_data/", help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..." ) parser.add_argument( -- GitLab