diff --git a/PPOCRLabel/gen_ocr_train_val_test.py b/PPOCRLabel/gen_ocr_train_val_test.py index 64cba612ae267835dd47aedc2b0356c9df462038..03ae566c6ec64d7ade229fb9571b0cd89ec189d4 100644 --- a/PPOCRLabel/gen_ocr_train_val_test.py +++ b/PPOCRLabel/gen_ocr_train_val_test.py @@ -17,15 +17,14 @@ def isCreateOrDeleteFolder(path, flag): return flagAbsPath -def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag): +def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag): # 按照指定的比例划分训练集、验证集、测试集 - labelPath = os.path.join(root, dir) - labelAbsPath = os.path.abspath(labelPath) + dataAbsPath = os.path.abspath(root) if flag == "det": - labelFilePath = os.path.join(labelAbsPath, args.detLabelFileName) + labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName) elif flag == "rec": - labelFilePath = os.path.join(labelAbsPath, args.recLabelFileName) + labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName) labelFileRead = open(labelFilePath, "r", encoding="UTF-8") labelFileContent = labelFileRead.readlines() @@ -38,9 +37,9 @@ def splitTrainVal(root, dir, absTrainRootPath, absValRootPath, absTestRootPath, imageName = os.path.basename(imageRelativePath) if flag == "det": - imagePath = os.path.join(labelAbsPath, imageName) + imagePath = os.path.join(dataAbsPath, imageName) elif flag == "rec": - imagePath = os.path.join(labelAbsPath, "{}\\{}".format(args.recImageDirName, imageName)) + imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName)) # 按预设的比例划分训练集、验证集、测试集 trainValTestRatio = args.trainValTestRatio.split(":") @@ -90,15 +89,20 @@ def genDetRecTrainVal(args): recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8") recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8") - for root, dirs, files in os.walk(args.labelRootPath): + splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt, + detTestTxt, "det") + + for root, dirs, files in os.walk(args.datasetRootPath): for dir in dirs: - splitTrainVal(root, dir, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt, - detTestTxt, "det") - splitTrainVal(root, dir, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt, - recTestTxt, "rec") + if dir == 'crop_img': + splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt, + recTestTxt, "rec") + else: + continue break + if __name__ == "__main__": # 功能描述:分别划分检测和识别的训练集、验证集、测试集 # 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注, @@ -110,9 +114,9 @@ if __name__ == "__main__": default="6:2:2", help="ratio of trainset:valset:testset") parser.add_argument( - "--labelRootPath", + "--datasetRootPath", type=str, - default="../train_data/label", + default="../train_data/", help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..." ) parser.add_argument(