未验证 提交 446243af 编写于 作者: H Hongyu Liu 提交者: GitHub

Add orc readme (#2341)

* add pbt lm; test=develop

* add dynamic ocr recognition; test=develop

* add unility.py; test=develop

* add readme; test=develop

* fix bug; test=develop

* fix bug; test=develop

* fix bug; test=develop
上级 63deaed1
DyGraph模式下ocr recognition实现
========
简介
--------
ocr任务是识别图片单行的字母信息,在动态图下使用了带attention的seq2seq结构,静态图实现可以参考([ocr recognition](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/ocr_recognition)
运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。
## 代码结构
```
└── train.py # 训练脚本。
└── data_reader.py # 数据读取。
└── utility # 基础的函数。
```
## 使用的数据
教程中使用`ocr attention`数据集作为训练数据,该数据集通过`paddle.dataset`模块自动下载到本地。
## 训练测试ocr recognition
在GPU单卡上训练ocr recognition:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
```
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
## 效果
在test测试集合上,最好的效果为82.0%
......@@ -17,7 +17,6 @@ DATA_SHAPE = [1, 48, 512]
DATA_MD5 = "7256b1d5420d8c3e74815196e58cdad5"
DATA_URL = "http://paddle-ocr-data.bj.bcebos.com/data.tar.gz"
CACHE_DIR_NAME = "ctc_data"
SAVED_FILE_NAME = "data.tar.gz"
DATA_DIR_NAME = "data"
TRAIN_DATA_DIR_NAME = "train_images"
......@@ -27,15 +26,14 @@ TEST_LIST_FILE_NAME = "test.list"
class DataGenerator(object):
def __init__(self, model="crnn_ctc"):
self.model = model
def __init__(self):
pass
def train_reader(self,
img_root_dir,
img_label_list,
batchsize,
cycle,
max_length,
shuffle=True):
'''
Reader interface for training.
......@@ -89,11 +87,6 @@ class DataGenerator(object):
label = [int(c) for c in items[-1].split(',')]
max_len = max(max_len, len(label))
#print( "max len", max_len, i)
max_length = max_len
#mask = np.zeros( (batchsize, max_length)).astype('float32')
for j in range(batchsize):
line = img_label_lines[i * batchsize + j]
items = line.split(' ')
......@@ -102,11 +95,11 @@ class DataGenerator(object):
mask = np.zeros((max_len)).astype('float32')
mask[:len(label) + 1] = 1.0
#mask[ j, :len(label) + 1] = 1.0
if max_length > len(label) + 1:
extend_label = [EOS] * (max_length - len(label) - 1)
if max_len > len(label) + 1:
extend_label = [EOS] * (max_len - len(label) - 1)
label.extend(extend_label)
else:
label = label[0:max_length - 1]
label = label[0:max_len - 1]
img = Image.open(os.path.join(img_root_dir, items[
2])).convert('L')
if j == 0:
......@@ -121,85 +114,6 @@ class DataGenerator(object):
return reader
def test_reader(self, img_root_dir, img_label_list):
'''
Reader interface for inference.
:param img_root_dir: The root path of the images for training.
:type img_root_dir: str
:param img_label_list: The path of the <image_name, label> file for testing.
:type img_label_list: str
'''
def reader():
for line in open(img_label_list):
# h, w, img_name, labels
items = line.split(' ')
label = [int(c) for c in items[-1].split(',')]
img = Image.open(os.path.join(img_root_dir, items[2])).convert(
'L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
if self.model == "crnn_ctc":
yield img, label
else:
yield img, [SOS] + label, label + [EOS]
return reader
def infer_reader(self, img_root_dir=None, img_label_list=None, cycle=False):
'''A reader interface for inference.
:param img_root_dir: The root path of the images for training.
:type img_root_dir: str
:param img_label_list: The path of the <image_name, label> file for
inference. It should be the path of <image_path> file if img_root_dir
was None. If img_label_list was set to None, it will read image path
from stdin.
:type img_root_dir: str
:param cycle: If number of iterations is greater than dataset_size /
batch_size it reiterates dataset over as many times as necessary.
:type cycle: bool
'''
def reader():
def yield_img_and_label(lines):
for line in lines:
if img_root_dir is not None:
# h, w, img_name, labels
img_name = line.split(' ')[2]
img_path = os.path.join(img_root_dir, img_name)
else:
img_path = line.strip("\t\n\r")
img = Image.open(img_path).convert('L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
label = [int(c) for c in line.split(' ')[3].split(',')]
yield img, label
if img_label_list is not None:
lines = []
with open(img_label_list) as f:
lines = f.readlines()
for img, label in yield_img_and_label(lines):
yield img, label
while cycle:
for img, label in yield_img_and_label(lines):
yield img, label
else:
while True:
img_path = input("Please input the path of image: ")
img = Image.open(img_path).convert('L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
yield img, [[0]]
return reader
def num_classes():
'''Get classes number of this dataset.
......@@ -213,51 +127,31 @@ def data_shape():
return DATA_SHAPE
def train(batch_size,
max_length,
train_images_dir=None,
train_list_file=None,
def data_reader(batch_size,
images_dir=None,
list_file=None,
cycle=False,
shuffle=False,
model="crnn_ctc"):
generator = DataGenerator(model)
if train_images_dir is None:
data_dir = download_data()
train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
if train_list_file is None:
train_list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
return generator.train_reader(
train_images_dir,
train_list_file,
batch_size,
cycle,
max_length,
shuffle=shuffle)
data_type="train"):
generator = DataGenerator()
def test(batch_size=1,
test_images_dir=None,
test_list_file=None,
model="crnn_ctc"):
generator = DataGenerator(model)
if test_images_dir is None:
if data_type == "train":
if images_dir is None:
data_dir = download_data()
test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
if test_list_file is None:
test_list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
return paddle.batch(
generator.test_reader(test_images_dir, test_list_file), batch_size)
def inference(batch_size=1,
infer_images_dir=None,
infer_list_file=None,
cycle=False,
model="crnn_ctc"):
generator = DataGenerator(model)
return paddle.batch(
generator.infer_reader(infer_images_dir, infer_list_file, cycle),
batch_size)
images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
if list_file is None:
list_file = path.join(data_dir, TRAIN_LIST_FILE_NAME)
elif data_type == "test":
if images_dir is None:
data_dir = download_data()
images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
if list_file is None:
list_file = path.join(data_dir, TEST_LIST_FILE_NAME)
else:
print("data type only support train | test")
raise Exception("data type only support train | test")
return generator.train_reader(
images_dir, list_file, batch_size, cycle, shuffle=shuffle)
def download_data():
......
......@@ -42,7 +42,6 @@ add_arg('train_images', str, None, "The directory of images to be u
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('model', str, "attention", "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
......@@ -78,10 +77,6 @@ class Config(object):
# special label for start and end
SOS = 0
EOS = 1
# settings for ctc data, not use in unittest
DATA_DIR_NAME = "./dataset/ctc_data/data"
TRAIN_DATA_DIR_NAME = "train_images"
TRAIN_LIST_FILE_NAME = "train.list"
# data shape for input image
DATA_SHAPE = [1, 48, 512]
......@@ -478,24 +473,18 @@ def train(args):
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0 )
train_reader = data_reader.train(
train_reader = data_reader.data_reader(
Config.batch_size,
max_length=Config.max_length,
train_images_dir=args.train_images,
train_list_file=args.train_list,
cycle=args.total_step > 0,
shuffle=True,
model=args.model)
data_type='train')
infer_image= './data/data/test_images/'
infer_files = './data/data/test.list'
test_reader = data_reader.train(
test_reader = data_reader.data_reader(
Config.batch_size,
1000,
train_images_dir= infer_image,
train_list_file= infer_files,
cycle=False,
model=args.model)
data_type="test")
def eval():
ocr_attention.eval()
total_loss = 0.0
......@@ -578,10 +567,6 @@ def train(args):
total_loss = 0.0
if total_step > 0 and total_step % 2000 == 0:
model_value = ocr_attention.state_dict()
np.savez( "model/" + str(total_step), **model_value )
ocr_attention.eval()
eval()
ocr_attention.train()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册