diff --git a/PaddleCV/ocr_recognition/README.md b/PaddleCV/ocr_recognition/README.md index 63d34f42c2cc410631d435809ccef562a8b8c344..dcacefa0956d83201f628a2bfbaa4c0bf45a3e1c 100644 --- a/PaddleCV/ocr_recognition/README.md +++ b/PaddleCV/ocr_recognition/README.md @@ -75,9 +75,9 @@ 在训练时,我们通过选项`--train_images` 和 `--train_list` 分别设置准备好的`train_images` 和`train_list`。 +在`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整训练数据的高度。用户可以根据自己准备的训练数据,设置合适的`DATA_SHAPE`。如果使用默认的示例数据,则使用默认的`DATA_SHAPE`即可。 ->**注:** 如果`--train_images` 和 `--train_list`都未设置或设置为None, reader.py会自动下载使用[示例数据](http://paddle-ocr-data.bj.bcebos.com/data.tar.gz),并将其缓存到`$HOME/.cache/paddle/dataset/ctc_data/data/` 路径下。 - +>**注:** 如果`--train_images` 和 `--train_list`都未设置或设置为None, data_reader.py会自动下载使用[示例数据](http://paddle-ocr-data.bj.bcebos.com/data.tar.gz),并将其缓存到`$HOME/.cache/paddle/dataset/ctc_data/data/` 路径下。 **B. 测试集和评估集** @@ -85,6 +85,8 @@ 在训练阶段,测试集的路径通过train.py的选项`--test_images` 和 `--test_list` 来设置。 在评估时,评估集的路径通过eval.py的选项`--input_images_dir` 和`--input_images_list` 来设置。 +在`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整测试图像的高度,所以测试图像可以有不同高度。但是,`DATA_SHAPE`需要和训练模型时保持严格一致。 + **C. 待预测数据集** 预测支持三种形式的输入: @@ -108,6 +110,8 @@ data/test_images/00003.jpg 第三种:从stdin读入一张图片的path,然后进行一次inference. +在`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整预测图像的高度,所以预测图像可以有不同高度。但是,`DATA_SHAPE`需要和训练模型时保持严格一致。 + ## 模型训练与预测 ### 训练 diff --git a/PaddleCV/ocr_recognition/data_reader.py b/PaddleCV/ocr_recognition/data_reader.py index 113412ffbddc9e53b1665a106fb4c8dbc44d86cd..f1b529391d9fb2ba8f2c43ce2257b29ac971374b 100644 --- a/PaddleCV/ocr_recognition/data_reader.py +++ b/PaddleCV/ocr_recognition/data_reader.py @@ -136,10 +136,10 @@ class DataGenerator(object): label = [int(c) for c in items[-1].split(',')] img = Image.open(os.path.join(img_root_dir, items[ - 2])).convert('L') #zhuanhuidu + 2])).convert('L') if j == 0: sz = img.size - img = img.resize((sz[0], sz[1])) + img = img.resize((sz[0], DATA_SHAPE[1])) img = np.array(img) - 127.5 img = img[np.newaxis, ...] if self.model == "crnn_ctc": @@ -171,6 +171,8 @@ class DataGenerator(object): label = [int(c) for c in items[-1].split(',')] img = Image.open(os.path.join(img_root_dir, items[2])).convert( 'L') + + img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height img = np.array(img) - 127.5 img = img[np.newaxis, ...] if self.model == "crnn_ctc": @@ -207,6 +209,7 @@ class DataGenerator(object): else: img_path = line.strip("\t\n\r") img = Image.open(img_path).convert('L') + img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height img = np.array(img) - 127.5 img = img[np.newaxis, ...] label = [int(c) for c in line.split(' ')[3].split(',')] @@ -225,6 +228,7 @@ class DataGenerator(object): while True: img_path = input("Please input the path of image: ") img = Image.open(img_path).convert('L') + img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height img = np.array(img) - 127.5 img = img[np.newaxis, ...] yield img, [[0]]