未验证 提交 2ab4725c 编写于 作者: W whs 提交者: GitHub

Make ocr model support for variable height. (#3114)

Make support for variable height.
1. Fix readme in ocr model.
2. Fix data reader to support variable shape of input.
上级 3a33a0bb
......@@ -75,9 +75,9 @@
在训练时,我们通过选项`--train_images``--train_list` 分别设置准备好的`train_images``train_list`
`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整训练数据的高度。用户可以根据自己准备的训练数据,设置合适的`DATA_SHAPE`。如果使用默认的示例数据,则使用默认的`DATA_SHAPE`即可。
>**注:** 如果`--train_images` 和 `--train_list`都未设置或设置为None, reader.py会自动下载使用[示例数据](http://paddle-ocr-data.bj.bcebos.com/data.tar.gz),并将其缓存到`$HOME/.cache/paddle/dataset/ctc_data/data/` 路径下。
>**注:** 如果`--train_images` 和 `--train_list`都未设置或设置为None, data_reader.py会自动下载使用[示例数据](http://paddle-ocr-data.bj.bcebos.com/data.tar.gz),并将其缓存到`$HOME/.cache/paddle/dataset/ctc_data/data/` 路径下。
**B. 测试集和评估集**
......@@ -85,6 +85,8 @@
在训练阶段,测试集的路径通过train.py的选项`--test_images``--test_list` 来设置。
在评估时,评估集的路径通过eval.py的选项`--input_images_dir``--input_images_list` 来设置。
`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整测试图像的高度,所以测试图像可以有不同高度。但是,`DATA_SHAPE`需要和训练模型时保持严格一致。
**C. 待预测数据集**
预测支持三种形式的输入:
......@@ -108,6 +110,8 @@ data/test_images/00003.jpg
第三种:从stdin读入一张图片的path,然后进行一次inference.
`data_reader.py`中,会按照用户设置的`DATA_SHAPE`调整预测图像的高度,所以预测图像可以有不同高度。但是,`DATA_SHAPE`需要和训练模型时保持严格一致。
## 模型训练与预测
### 训练
......
......@@ -136,10 +136,10 @@ class DataGenerator(object):
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L') #zhuanhuidu
2])).convert('L')
if j == 0:
sz = img.size
img = img.resize((sz[0], sz[1]))
img = img.resize((sz[0], DATA_SHAPE[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
if self.model == "crnn_ctc":
......@@ -171,6 +171,8 @@ class DataGenerator(object):
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[2])).convert(
'L')
img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
if self.model == "crnn_ctc":
......@@ -207,6 +209,7 @@ class DataGenerator(object):
else:
img_path = line.strip("\t\n\r")
img = Image.open(img_path).convert('L')
img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
label = [int(c) for c in line.split(' ')[3].split(',')]
......@@ -225,6 +228,7 @@ class DataGenerator(object):
while True:
img_path = input("Please input the path of image: ")
img = Image.open(img_path).convert('L')
img = img.resize((img.size[0], DATA_SHAPE[1])) # resize height
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
yield img, [[0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册