未验证 提交 c034bac2 编写于 作者: W whs 提交者: GitHub

Add attention training model for ocr. (#1034)

* Add attention training model for ocr.

* Add beam search for infer.

* Fix data reader.

* Fix inference.

* Prune result of inference.

* Fix README

* update README

* Rsize figure.

* Resize image and fix format.
上级 6274cc99
export ce_mode=1
python ctc_train.py --batch_size=32 --total_step=1 --eval_period=1 --log_period=1 --use_gpu=True 1> ./tmp.log
python train.py --batch_size=32 --total_step=1 --eval_period=1 --log_period=1 --use_gpu=True 1> ./tmp.log
cat tmp.log | python _ce.py
rm tmp.log
......@@ -5,8 +5,9 @@
## 代码结构
```
├── ctc_reader.py # 下载、读取、处理数据。
├── crnn_ctc_model.py # 定义了训练网络、预测网络和evaluate网络。
├── ctc_train.py # 用于模型的训练。
├── crnn_ctc_model.py # 定义了OCR CTC model的网络结构。
├── attention_model.py # 定义了OCR attention model的网络结构。
├── train.py # 用于模型的训练。
├── infer.py # 加载训练好的模型文件,对新数据进行预测。
├── eval.py # 评估模型在指定数据集上的效果。
└── utils.py # 定义通用的函数。
......@@ -15,9 +16,16 @@
## 简介
本章的任务是识别含有单行汉语字符图片,首先采用卷积将图片转为特征图, 然后使用`im2sequence op`将特征图转为序列,通过`双向GRU`学习到序列特征。训练过程选用的损失函数为CTC(Connectionist Temporal Classification) loss,最终的评估指标为样本级别的错误率
本章的任务是识别图片中单行英文字符,这里我们分别使用CTC model和attention model两种不同的模型来完成该任务
这两种模型的有相同的编码部分,首先采用卷积将图片转为特征图, 然后使用`im2sequence op`将特征图转为序列,通过`双向GRU`学习到序列特征。
两种模型的解码部分和使用的损失函数区别如下:
- CTC model: 训练过程选用的损失函数为CTC(Connectionist Temporal Classification) loss, 预测阶段采用的是贪婪策略和CTC解码策略。
- Attention model: 训练过程选用的是带注意力机制的解码策略和交叉信息熵损失函数,预测阶段采用的是柱搜索策略。
训练以上两种模型的评估指标为样本级别的错误率。
## 数据
......@@ -124,15 +132,23 @@ env OMP_NUM_THREADS=<num_of_physical_cores> python ctc_train.py --use_gpu False
env CUDA_VISIABLE_DEVICES=0,1,2,3 python ctc_train.py --parallel=True
```
默认使用的是`CTC model`, 可以通过选项`--model="attention"`切换为`attention model`
执行`python ctc_train.py --help`可查看更多使用方式和参数详细说明。
图2为使用默认参数和默认数据集训练的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为样本级错误率。其中,蓝线为训练集上的样本错误率,红线为测试集上的样本错误率。在60轮迭代训练中,测试集上最低错误率为第32轮的22.0%.
图2为使用默认参数在默认数据集上训练`CTC model`的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为样本级错误率。其中,蓝线为训练集上的样本错误率,红线为测试集上的样本错误率。测试集上最低错误率为22.0%.
<p align="center">
<img src="images/train.jpg" width="620" hspace='10'/> <br/>
<img src="images/train.jpg" width="400" hspace='10'/> <br/>
<strong>图 2</strong>
</p>
图3为使用默认参数在默认数据集上训练`attention model`的收敛曲线,其中横坐标轴为训练迭代次数,纵轴为样本级错误率。其中,蓝线为训练集上的样本错误率,红线为测试集上的样本错误率。测试集上最低错误率为16.25%.
<p align="center">
<img src="images/train_attention.jpg" width="400" hspace='10'/> <br/>
<strong>图 3</strong>
</p>
## 测试
......
import paddle.fluid as fluid
decoder_size = 128
word_vector_dim = 128
max_length = 100
sos = 0
eos = 1
gradient_clip = 10
LR = 1.0
beam_size = 2
learning_rate_decay = None
def conv_bn_pool(input,
group,
out_ch,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
tmp = input
for i in xrange(group):
filter_size = 3
conv_std = (2.0 / (filter_size**2 * tmp.shape[1]))**0.5
conv_param = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std))
tmp = fluid.layers.conv2d(
input=tmp,
num_filters=out_ch[i],
filter_size=3,
padding=1,
bias_attr=False,
param_attr=conv_param,
act=None, # LinearActivation
use_cudnn=use_cudnn)
tmp = fluid.layers.batch_norm(input=tmp, act=act, is_test=is_test)
if pool == True:
tmp = fluid.layers.pool2d(
input=tmp,
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
return tmp
def ocr_convs(input, is_test=False, use_cudnn=True):
tmp = input
tmp = conv_bn_pool(tmp, 2, [16, 16], is_test=is_test, use_cudnn=use_cudnn)
tmp = conv_bn_pool(tmp, 2, [32, 32], is_test=is_test, use_cudnn=use_cudnn)
tmp = conv_bn_pool(tmp, 2, [64, 64], is_test=is_test, use_cudnn=use_cudnn)
tmp = conv_bn_pool(
tmp, 2, [128, 128], is_test=is_test, pool=False, use_cudnn=use_cudnn)
return tmp
def encoder_net(images, rnn_hidden_size=200, is_test=False, use_cudnn=True):
conv_features = ocr_convs(images, is_test=is_test, use_cudnn=use_cudnn)
sliced_feature = fluid.layers.im2sequence(
input=conv_features,
stride=[1, 1],
filter_size=[conv_features.shape[2], 1])
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0, 0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
fc_1 = fluid.layers.fc(input=sliced_feature,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
fc_2 = fluid.layers.fc(input=sliced_feature,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
gru_forward = fluid.layers.dynamic_gru(
input=fc_1,
size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
gru_backward = fluid.layers.dynamic_gru(
input=fc_2,
size=rnn_hidden_size,
is_reverse=True,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=1)
encoded_proj = fluid.layers.fc(input=encoded_vector,
size=decoder_size,
bias_attr=False)
return gru_backward, encoded_vector, encoded_proj
def gru_decoder_with_attention(target_embedding, encoder_vec, encoder_proj,
decoder_boot, decoder_size, num_classes):
def simple_attention(encoder_vec, encoder_proj, decoder_state):
decoder_state_proj = fluid.layers.fc(input=decoder_state,
size=decoder_size,
bias_attr=False)
decoder_state_expand = fluid.layers.sequence_expand(
x=decoder_state_proj, y=encoder_proj)
concated = encoder_proj + decoder_state_expand
concated = fluid.layers.tanh(x=concated)
attention_weights = fluid.layers.fc(input=concated,
size=1,
act=None,
bias_attr=False)
attention_weights = fluid.layers.sequence_softmax(
input=attention_weights)
weigths_reshape = fluid.layers.reshape(x=attention_weights, shape=[-1])
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weigths_reshape, axis=0)
context = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
return context
rnn = fluid.layers.DynamicRNN()
with rnn.block():
current_word = rnn.step_input(target_embedding)
encoder_vec = rnn.static_input(encoder_vec)
encoder_proj = rnn.static_input(encoder_proj)
hidden_mem = rnn.memory(init=decoder_boot, need_reorder=True)
context = simple_attention(encoder_vec, encoder_proj, hidden_mem)
fc_1 = fluid.layers.fc(input=context,
size=decoder_size * 3,
bias_attr=False)
fc_2 = fluid.layers.fc(input=current_word,
size=decoder_size * 3,
bias_attr=False)
decoder_inputs = fc_1 + fc_2
h, _, _ = fluid.layers.gru_unit(
input=decoder_inputs, hidden=hidden_mem, size=decoder_size * 3)
rnn.update_memory(hidden_mem, h)
out = fluid.layers.fc(input=h,
size=num_classes + 2,
bias_attr=True,
act='softmax')
rnn.output(out)
return rnn()
def attention_train_net(args, data_shape, num_classes):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label_in = fluid.layers.data(
name='label_in', shape=[1], dtype='int32', lod_level=1)
label_out = fluid.layers.data(
name='label_out', shape=[1], dtype='int32', lod_level=1)
gru_backward, encoded_vector, encoded_proj = encoder_net(images)
backward_first = fluid.layers.sequence_pool(
input=gru_backward, pool_type='first')
decoder_boot = fluid.layers.fc(input=backward_first,
size=decoder_size,
bias_attr=False,
act="relu")
label_in = fluid.layers.cast(x=label_in, dtype='int64')
trg_embedding = fluid.layers.embedding(
input=label_in,
size=[num_classes + 2, word_vector_dim],
dtype='float32')
prediction = gru_decoder_with_attention(trg_embedding, encoded_vector,
encoded_proj, decoder_boot,
decoder_size, num_classes)
fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip))
label_out = fluid.layers.cast(x=label_out, dtype='int64')
_, maxid = fluid.layers.topk(input=prediction, k=1)
error_evaluator = fluid.evaluator.EditDistance(
input=maxid, label=label_out, ignored_tokens=[sos, eos])
inference_program = fluid.default_main_program().clone(for_test=True)
cost = fluid.layers.cross_entropy(input=prediction, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
if learning_rate_decay == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay([50000], [LR, LR * 0.01])
else:
learning_rate = LR
optimizer = fluid.optimizer.Adadelta(
learning_rate=learning_rate, epsilon=1.0e-6, rho=0.9)
optimizer.minimize(sum_cost)
model_average = None
if args.average_window > 0:
model_average = fluid.optimizer.ModelAverage(
args.average_window,
min_average_window=args.min_average_window,
max_average_window=args.max_average_window)
return sum_cost, error_evaluator, inference_program, model_average
def simple_attention(encoder_vec, encoder_proj, decoder_state, decoder_size):
decoder_state_proj = fluid.layers.fc(input=decoder_state,
size=decoder_size,
bias_attr=False)
decoder_state_expand = fluid.layers.sequence_expand(
x=decoder_state_proj, y=encoder_proj)
concated = fluid.layers.elementwise_add(encoder_proj, decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weights = fluid.layers.fc(input=concated,
size=1,
act=None,
bias_attr=False)
attention_weights = fluid.layers.sequence_softmax(input=attention_weights)
weigths_reshape = fluid.layers.reshape(x=attention_weights, shape=[-1])
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weigths_reshape, axis=0)
context = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
return context
def attention_infer(images, num_classes, use_cudnn=True):
max_length = 20
gru_backward, encoded_vector, encoded_proj = encoder_net(
images, is_test=True, use_cudnn=use_cudnn)
backward_first = fluid.layers.sequence_pool(
input=gru_backward, pool_type='first')
decoder_boot = fluid.layers.fc(input=backward_first,
size=decoder_size,
bias_attr=False,
act="relu")
init_state = decoder_boot
array_len = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=max_length)
counter = fluid.layers.zeros(shape=[1], dtype='int64', force_cpu=True)
# fill the first element with init_state
state_array = fluid.layers.create_array('float32')
fluid.layers.array_write(init_state, array=state_array, i=counter)
# ids, scores as memory
ids_array = fluid.layers.create_array('int64')
scores_array = fluid.layers.create_array('float32')
init_ids = fluid.layers.data(
name="init_ids", shape=[1], dtype="int64", lod_level=2)
init_scores = fluid.layers.data(
name="init_scores", shape=[1], dtype="float32", lod_level=2)
fluid.layers.array_write(init_ids, array=ids_array, i=counter)
fluid.layers.array_write(init_scores, array=scores_array, i=counter)
cond = fluid.layers.less_than(x=counter, y=array_len)
while_op = fluid.layers.While(cond=cond)
with while_op.block():
pre_ids = fluid.layers.array_read(array=ids_array, i=counter)
pre_state = fluid.layers.array_read(array=state_array, i=counter)
pre_score = fluid.layers.array_read(array=scores_array, i=counter)
pre_ids_emb = fluid.layers.embedding(
input=pre_ids,
size=[num_classes + 2, word_vector_dim],
dtype='float32')
context = simple_attention(encoded_vector, encoded_proj, pre_state,
decoder_size)
# expand the recursive_sequence_lengths of pre_state to be the same with pre_score
pre_state_expanded = fluid.layers.sequence_expand(pre_state, pre_score)
context_expanded = fluid.layers.sequence_expand(context, pre_score)
fc_1 = fluid.layers.fc(input=context_expanded,
size=decoder_size * 3,
bias_attr=False)
fc_2 = fluid.layers.fc(input=pre_ids_emb,
size=decoder_size * 3,
bias_attr=False)
decoder_inputs = fc_1 + fc_2
current_state, _, _ = fluid.layers.gru_unit(
input=decoder_inputs,
hidden=pre_state_expanded,
size=decoder_size * 3)
current_state_with_lod = fluid.layers.lod_reset(
x=current_state, y=pre_score)
# use score to do beam search
current_score = fluid.layers.fc(input=current_state_with_lod,
size=num_classes + 2,
bias_attr=True,
act='softmax')
topk_scores, topk_indices = fluid.layers.topk(
current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = fluid.layers.elementwise_add(
x=fluid.layers.log(topk_scores),
y=fluid.layers.reshape(
pre_score, shape=[-1]),
axis=0)
selected_ids, selected_scores = fluid.layers.beam_search(
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
1, # end_id
#level=0
)
fluid.layers.increment(x=counter, value=1, in_place=True)
# update the memories
fluid.layers.array_write(current_state, array=state_array, i=counter)
fluid.layers.array_write(selected_ids, array=ids_array, i=counter)
fluid.layers.array_write(selected_scores, array=scores_array, i=counter)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = fluid.layers.less_than(x=counter, y=array_len)
finish_cond = fluid.layers.logical_not(
fluid.layers.is_empty(x=selected_ids))
fluid.layers.logical_and(x=length_cond, y=finish_cond, out=cond)
ids, scores = fluid.layers.beam_search_decode(ids_array, scores_array,
beam_size, eos)
return ids
def attention_eval(data_shape, num_classes):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label_in = fluid.layers.data(
name='label_in', shape=[1], dtype='int32', lod_level=1)
label_out = fluid.layers.data(
name='label_out', shape=[1], dtype='int32', lod_level=1)
label_out = fluid.layers.cast(x=label_out, dtype='int64')
label_in = fluid.layers.cast(x=label_in, dtype='int64')
gru_backward, encoded_vector, encoded_proj = encoder_net(
images, is_test=True)
backward_first = fluid.layers.sequence_pool(
input=gru_backward, pool_type='first')
decoder_boot = fluid.layers.fc(input=backward_first,
size=decoder_size,
bias_attr=False,
act="relu")
trg_embedding = fluid.layers.embedding(
input=label_in,
size=[num_classes + 2, word_vector_dim],
dtype='float32')
prediction = gru_decoder_with_attention(trg_embedding, encoded_vector,
encoded_proj, decoder_boot,
decoder_size, num_classes)
_, maxid = fluid.layers.topk(input=prediction, k=1)
error_evaluator = fluid.evaluator.EditDistance(
input=maxid, label=label_out, ignored_tokens=[sos, eos])
cost = fluid.layers.cross_entropy(input=prediction, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
return error_evaluator, sum_cost
......@@ -166,13 +166,16 @@ def encoder_net(images,
return fc_out
def ctc_train_net(images, label, args, num_classes):
def ctc_train_net(args, data_shape, num_classes):
L2_RATE = 0.0004
LR = 1.0e-3
MOMENTUM = 0.9
learning_rate_decay = None
regularizer = fluid.regularizer.L2Decay(L2_RATE)
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
fc_out = encoder_net(
images,
num_classes,
......@@ -211,7 +214,10 @@ def ctc_infer(images, num_classes, use_cudnn):
return fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
def ctc_eval(images, label, num_classes, use_cudnn):
def ctc_eval(data_shape, num_classes, use_cudnn):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
fc_out = encoder_net(images, num_classes, is_test=True, use_cudnn=use_cudnn)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
......
......@@ -7,6 +7,8 @@ from os import path
from paddle.dataset.image import load_image
import paddle
SOS = 0
EOS = 1
NUM_CLASSES = 95
DATA_SHAPE = [1, 48, 512]
......@@ -22,8 +24,8 @@ TEST_LIST_FILE_NAME = "test.list"
class DataGenerator(object):
def __init__(self):
pass
def __init__(self, model="crnn_ctc"):
self.model = model
def train_reader(self,
img_root_dir,
......@@ -89,7 +91,10 @@ class DataGenerator(object):
img = img.resize((sz[0], sz[1]))
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
result.append([img, label])
if self.model == "crnn_ctc":
result.append([img, label])
else:
result.append([img, [SOS] + label, label + [EOS]])
yield result
if not cycle:
break
......@@ -117,7 +122,10 @@ class DataGenerator(object):
'L')
img = np.array(img) - 127.5
img = img[np.newaxis, ...]
yield img, label
if self.model == "crnn_ctc":
yield img, label
else:
yield img, [SOS] + label, label + [EOS]
return reader
......@@ -185,8 +193,12 @@ def data_shape():
return DATA_SHAPE
def train(batch_size, train_images_dir=None, train_list_file=None, cycle=False):
generator = DataGenerator()
def train(batch_size,
train_images_dir=None,
train_list_file=None,
cycle=False,
model="crnn_ctc"):
generator = DataGenerator(model)
if train_images_dir is None:
data_dir = download_data()
train_images_dir = path.join(data_dir, TRAIN_DATA_DIR_NAME)
......@@ -199,8 +211,11 @@ def train(batch_size, train_images_dir=None, train_list_file=None, cycle=False):
train_images_dir, train_list_file, batch_size, cycle, shuffle=shuffle)
def test(batch_size=1, test_images_dir=None, test_list_file=None):
generator = DataGenerator()
def test(batch_size=1,
test_images_dir=None,
test_list_file=None,
model="crnn_ctc"):
generator = DataGenerator(model)
if test_images_dir is None:
data_dir = download_data()
test_images_dir = path.join(data_dir, TEST_DATA_DIR_NAME)
......@@ -213,8 +228,9 @@ def test(batch_size=1, test_images_dir=None, test_list_file=None):
def inference(batch_size=1,
infer_images_dir=None,
infer_list_file=None,
cycle=False):
generator = DataGenerator()
cycle=False,
model="crnn_ctc"):
generator = DataGenerator(model)
return paddle.batch(
generator.infer_reader(infer_images_dir, infer_list_file, cycle),
batch_size)
......
import paddle.v2 as paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data
from attention_model import attention_eval
from crnn_ctc_model import ctc_eval
import ctc_reader
import data_reader
import argparse
import functools
import os
......@@ -11,27 +11,34 @@ import os
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model_path', str, None, "The model path to be used for inference.")
add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('model_path', str, "", "The model path to be used for inference.")
add_arg('input_images_dir', str, None, "The directory of images.")
add_arg('input_images_list', str, None, "The list file of images.")
add_arg('use_gpu', bool, True, "Whether use GPU to eval.")
# yapf: enable
def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
def evaluate(args):
"""OCR inference"""
if args.model == "crnn_ctc":
eval = ctc_eval
get_feeder_data = get_ctc_feeder_data
else:
eval = attention_eval
get_feeder_data = get_attention_feeder_data
num_classes = data_reader.num_classes()
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
evaluator, cost = eval(images, label, num_classes)
evaluator, cost = eval(data_shape, num_classes)
# data reader
test_reader = data_reader.test(
test_images_dir=args.input_images_dir,
test_list_file=args.input_images_list)
test_list_file=args.input_images_list,
model=args.model)
# prepare environment
place = fluid.CPUPlace()
......@@ -55,6 +62,7 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
for data in test_reader():
count += 1
exe.run(fluid.default_main_program(), feed=get_feeder_data(data, place))
print "Read %d samples;\r" % count,
avg_distance, avg_seq_error = evaluator.eval(exe)
print "Read %d samples; avg_distance: %s; avg_seq_error: %s" % (
count, avg_distance, avg_seq_error)
......@@ -63,7 +71,7 @@ def evaluate(args, eval=ctc_eval, data_reader=ctc_reader):
def main():
args = parser.parse_args()
print_arguments(args)
evaluate(args, data_reader=ctc_reader)
evaluate(args)
if __name__ == "__main__":
......
import paddle.v2 as paddle
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_for_infer
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_infer
from attention_model import attention_infer
import numpy as np
import ctc_reader
import data_reader
import argparse
import functools
import os
......@@ -13,6 +14,7 @@ import time
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('model_path', str, None, "The model path to be used for inference.")
add_arg('input_images_dir', str, None, "The directory of images.")
add_arg('input_images_list', str, None, "The list file of images.")
......@@ -25,20 +27,28 @@ add_arg('batch_size', int, 1, "The minibatch size.")
# yapf: enable
def inference(args, infer=ctc_infer, data_reader=ctc_reader):
def inference(args):
"""OCR inference"""
if args.model == "crnn_ctc":
infer = ctc_infer
get_feeder_data = get_ctc_feeder_data
else:
infer = attention_infer
get_feeder_data = get_attention_feeder_for_infer
eos = 1
sos = 0
num_classes = data_reader.num_classes()
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
sequence = infer(
images, num_classes, use_cudnn=True if args.use_gpu else False)
ids = infer(images, num_classes, use_cudnn=True if args.use_gpu else False)
# data reader
infer_reader = data_reader.inference(
batch_size=args.batch_size,
infer_images_dir=args.input_images_dir,
infer_list_file=args.input_images_list,
cycle=True if args.iterations > 0 else False)
cycle=True if args.iterations > 0 else False,
model=args.model)
# prepare environment
place = fluid.CPUPlace()
if args.use_gpu:
......@@ -68,6 +78,7 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
batch_times = []
iters = 0
for data in infer_reader():
feed_dict = get_feeder_data(data, place)
if args.iterations > 0 and iters == args.iterations + args.skip_batch_num:
break
if iters < args.skip_batch_num:
......@@ -77,14 +88,13 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
start = time.time()
result = exe.run(fluid.default_main_program(),
feed=get_feeder_data(
data, place, need_label=False),
fetch_list=[sequence],
feed=feed_dict,
fetch_list=[ids],
return_numpy=False)
indexes = prune(np.array(result[0]).flatten(), 0, 1)
batch_time = time.time() - start
fps = args.batch_size / batch_time
batch_times.append(batch_time)
indexes = np.array(result[0]).flatten()
if dict_map is not None:
print "Iteration %d, latency: %.5f s, fps: %f, result: %s" % (
iters,
......@@ -114,18 +124,29 @@ def inference(args, infer=ctc_infer, data_reader=ctc_reader):
print('average fps: %.5f, fps for 99pc latency: %.5f' % (fps_avg, fps_pc99))
def prune(words, sos, eos):
"""Remove unused tokens in prediction result."""
start_index = 0
end_index = len(words)
if sos in words:
start_index = np.where(words == sos)[0][0] + 1
if eos in words:
end_index = np.where(words == eos)[0][0]
return words[start_index:end_index]
def main():
args = parser.parse_args()
print_arguments(args)
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
inference(args, data_reader=ctc_reader)
inference(args)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
inference(args, data_reader=ctc_reader)
inference(args)
else:
inference(args, data_reader=ctc_reader)
inference(args)
if __name__ == "__main__":
......
"""Trainer for OCR CTC model."""
"""Trainer for OCR CTC or attention model."""
import paddle.fluid as fluid
from utility import add_arguments, print_arguments, to_lodtensor, get_ctc_feeder_data, get_attention_feeder_data
import paddle.fluid.profiler as profiler
from utility import add_arguments, print_arguments, to_lodtensor, get_feeder_data
from crnn_ctc_model import ctc_train_net
import ctc_reader
from attention_model import attention_train_net
import data_reader
import argparse
import functools
import sys
......@@ -20,6 +21,7 @@ add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('model', str, "crnn_ctc", "Which type of network to be used. 'crnn_ctc' or 'attention'")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
......@@ -32,8 +34,16 @@ add_arg('skip_test', bool, False, "Whether to skip test phase.")
# yapf: enable
def train(args, data_reader=ctc_reader):
"""OCR CTC training"""
def train(args):
"""OCR training"""
if args.model == "crnn_ctc":
train_net = ctc_train_net
get_feeder_data = get_ctc_feeder_data
else:
train_net = attention_train_net
get_feeder_data = get_attention_feeder_data
num_classes = None
train_images = None
train_list = None
......@@ -43,20 +53,18 @@ def train(args, data_reader=ctc_reader):
) if num_classes is None else num_classes
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
sum_cost, error_evaluator, inference_program, model_average = ctc_train_net(
images, label, args, num_classes)
sum_cost, error_evaluator, inference_program, model_average = train_net(
args, data_shape, num_classes)
# data reader
train_reader = data_reader.train(
args.batch_size,
train_images_dir=train_images,
train_list_file=train_list,
cycle=args.total_step > 0)
cycle=args.total_step > 0,
model=args.model)
test_reader = data_reader.test(
test_images_dir=test_images, test_list_file=test_list)
test_images_dir=test_images, test_list_file=test_list, model=args.model)
# prepare environment
place = fluid.CPUPlace()
......@@ -144,7 +152,7 @@ def train(args, data_reader=ctc_reader):
iter_num += 1
# training log
if iter_num % args.log_period == 0:
print "\nTime: %s; Iter[%d]; Avg Warp-CTC loss: %.3f; Avg seq err: %.3f" % (
print "\nTime: %s; Iter[%d]; Avg loss: %.3f; Avg seq err: %.3f" % (
time.time(), iter_num,
total_loss / (args.log_period * args.batch_size),
total_seq_error / (args.log_period * args.batch_size))
......@@ -155,7 +163,7 @@ def train(args, data_reader=ctc_reader):
total_loss = 0.0
total_seq_error = 0.0
# evaluate
# evaluate
if not args.skip_test and iter_num % args.eval_period == 0:
if model_average:
with model_average.apply(exe):
......@@ -195,12 +203,12 @@ def main():
if args.profile:
if args.use_gpu:
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
train(args, data_reader=ctc_reader)
train(args)
else:
with profiler.profiler("CPU", sorted_key='total') as cpuprof:
train(args, data_reader=ctc_reader)
train(args)
else:
train(args, data_reader=ctc_reader)
train(args)
if __name__ == "__main__":
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import distutils.util
import numpy as np
from paddle.fluid import core
import paddle.fluid as fluid
def print_arguments(args):
......@@ -77,7 +78,7 @@ def to_lodtensor(data, place):
return res
def get_feeder_data(data, place, need_label=True):
def get_ctc_feeder_data(data, place, need_label=True):
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
......@@ -88,3 +89,47 @@ def get_feeder_data(data, place, need_label=True):
return {"pixel": pixel_tensor, "label": label_tensor}
else:
return {"pixel": pixel_tensor}
def get_attention_feeder_data(data, place, need_label=True):
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
map(lambda x: x[0][np.newaxis, :], data), axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
label_in_tensor = to_lodtensor(map(lambda x: x[1], data), place)
label_out_tensor = to_lodtensor(map(lambda x: x[2], data), place)
if need_label:
return {
"pixel": pixel_tensor,
"label_in": label_in_tensor,
"label_out": label_out_tensor
}
else:
return {"pixel": pixel_tensor}
def get_attention_feeder_for_infer(data, place):
batch_size = len(data)
init_ids_data = np.array([0 for _ in range(batch_size)], dtype='int64')
init_scores_data = np.array(
[1. for _ in range(batch_size)], dtype='float32')
init_ids_data = init_ids_data.reshape((batch_size, 1))
init_scores_data = init_scores_data.reshape((batch_size, 1))
init_recursive_seq_lens = [1] * batch_size
init_recursive_seq_lens = [init_recursive_seq_lens, init_recursive_seq_lens]
init_ids = fluid.create_lod_tensor(init_ids_data, init_recursive_seq_lens,
place)
init_scores = fluid.create_lod_tensor(init_scores_data,
init_recursive_seq_lens, place)
pixel_tensor = core.LoDTensor()
pixel_data = None
pixel_data = np.concatenate(
map(lambda x: x[0][np.newaxis, :], data), axis=0).astype("float32")
pixel_tensor.set(pixel_data, place)
return {
"pixel": pixel_tensor,
"init_ids": init_ids,
"init_scores": init_scores
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册