未验证 提交 52ca7b75 编写于 作者: L LielinJiang 提交者: GitHub

Refine ocr dygraph code, add infer module (#4182)

* refine code, add infer module

* update readme
上级 6967b7ab
...@@ -25,11 +25,27 @@ ocr任务是识别图片单行的字母信息,在动态图下使用了带atten ...@@ -25,11 +25,27 @@ ocr任务是识别图片单行的字母信息,在动态图下使用了带atten
在GPU单卡上训练ocr recognition: 在GPU单卡上训练ocr recognition:
``` ```
env CUDA_VISIBLE_DEVICES=0 python train.py CUDA_VISIBLE_DEVICES=0 python train.py
``` ```
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。 这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
## 效果 ## 测试ocr recognition
在test测试集合上,最好的效果为82.0%
```
CUDA_VISIBLE_DEVICES=0 python eval.py --pretrained_model your_trained_model_path
```
## 预测
```
CUDA_VISIBLE_DEVICES=0 python -u infer.py --pretrained_model your_trained_model_path --image_path your_img_path
```
## 预训练模型
|模型| 准确率|
|- |:-: |
|[ocr_attention_params](https://paddle-ocr-models.bj.bcebos.com/ocr_attention_dygraph.tar) | 82.46%|
...@@ -2,13 +2,12 @@ from __future__ import absolute_import ...@@ -2,13 +2,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import cv2
import tarfile import tarfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from os import path from os import path
from paddle.dataset.image import load_image
import paddle import paddle
import random
SOS = 0 SOS = 0
EOS = 1 EOS = 1
...@@ -53,24 +52,53 @@ class DataGenerator(object): ...@@ -53,24 +52,53 @@ class DataGenerator(object):
img_label_lines = [] img_label_lines = []
to_file = "tmp.txt" to_file = "tmp.txt"
if not shuffle:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file def _shuffle_data(input_file_path, output_file_path, shuffle,
elif batchsize == 1: batchsize):
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file def _write_file(file_path, lines_to_write):
else: open(file_path, 'w').writelines(
#cmd1: partial shuffle ["{}\n".format(item) for item in lines_to_write])
cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
#cmd2: batch merge and shuffle input_file = open(input_file_path, 'r')
cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str( lines_to_shuf = [line.strip() for line in input_file.readlines()]
batchsize) + " == 0) print \"\";}' | shuf | "
#cmd3: batch split if not shuffle:
cmd += "awk '{if(NF == " + str( _write_file(output_file_path, lines_to_shuf)
batchsize elif batchsize == 1:
) + " * 4) {for(i = 0; i < " + str( random.shuffle(lines_to_shuf)
batchsize _write_file(output_file_path, lines_to_shuf)
) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file else:
os.system(cmd) #partial shuffle
print("finish batch shuffle") for i in range(len(lines_to_shuf)):
str_i = lines_to_shuf[i]
list_i = str_i.strip().split(' ')
str_i_ = "%04d%.4f " % (int(list_i[0]), random.random()
) + str_i
lines_to_shuf[i] = str_i_
lines_to_shuf.sort()
delete_num = random.randint(1, 100)
del lines_to_shuf[0:delete_num]
#batch merge and shuffle
lines_concat = []
for i in range(0, len(lines_to_shuf), batchsize):
lines_concat.append(' '.join(lines_to_shuf[i:i +
batchsize]))
random.shuffle(lines_concat)
#batch split
out_file = open(output_file_path, 'w')
for i in range(len(lines_concat)):
tmp_list = lines_concat[i].split(' ')
for j in range(int(len(tmp_list) / 5)):
out_file.write("{} {} {} {}\n".format(tmp_list[
5 * j + 1], tmp_list[5 * j + 2], tmp_list[
5 * j + 3], tmp_list[5 * j + 4]))
out_file.close()
input_file.close()
_shuffle_data(img_label_list, to_file, shuffle, batchsize)
img_label_lines = open(to_file, 'r').readlines() img_label_lines = open(to_file, 'r').readlines()
def reader(): def reader():
...@@ -95,7 +123,7 @@ class DataGenerator(object): ...@@ -95,7 +123,7 @@ class DataGenerator(object):
mask = np.zeros((max_len)).astype('float32') mask = np.zeros((max_len)).astype('float32')
mask[:len(label) + 1] = 1.0 mask[:len(label) + 1] = 1.0
#mask[ j, :len(label) + 1] = 1.0
if max_len > len(label) + 1: if max_len > len(label) + 1:
extend_label = [EOS] * (max_len - len(label) - 1) extend_label = [EOS] * (max_len - len(label) - 1)
label.extend(extend_label) label.extend(extend_label)
......
export CUDA_VISIBLE_DEVICES=0
python train.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import functools
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import data_reader
from nets import OCRAttention
from paddle.fluid.dygraph.base import to_variable
from utility import add_arguments, print_arguments, get_attention_feeder_data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def evaluate(model, test_reader, batch_size):
model.eval()
total_step = 0.0
equal_size = 0
for data in test_reader():
data_dict = get_attention_feeder_data(data)
label_in = to_variable(data_dict["label_in"])
label_out = to_variable(data_dict["label_out"])
label_out.stop_gradient = True
img = to_variable(data_dict["pixel"])
prediction = model(img, label_in)
prediction = fluid.layers.reshape(prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
score, topk = layers.topk(prediction, 1)
seq = topk.numpy()
seq = seq.reshape((batch_size, -1))
mask = data_dict['mask'].reshape((batch_size, -1))
seq_len = np.sum(mask, -1)
trans_ref = data_dict["label_out"].reshape((batch_size, -1))
for i in range(batch_size):
length = int(seq_len[i] - 1)
trans = seq[i][:length - 1]
ref = trans_ref[i][: length - 1]
if np.array_equal(trans, ref):
equal_size += 1
total_step += batch_size
accuracy = equal_size / total_step
print("eval accuracy:", accuracy)
return accuracy
def eval(args):
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=args.batch_size,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
test_reader = data_reader.data_reader(
args.batch_size,
images_dir=args.test_images,
list_file=args.test_list,
data_type="test")
evaluate(ocr_attention, test_reader, args.batch_size)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
eval(args)
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments
from PIL import Image
from nets import OCRAttention
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('image_path', str, "", "image path")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('max_length', int, 100, "Max predict length.")
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def inference(args):
img = Image.open(os.path.join(args.image_path)).convert('L')
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=1,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
ocr_attention.eval()
print(img.size)
img = img.resize((img.size[0], 48), Image.BILINEAR)
img = np.array(img).astype('float32') - 127.5
img = img[np.newaxis, np.newaxis, ...]
img = to_variable(img)
gru_backward, encoded_vector, encoded_proj = ocr_attention.encoder_net(img)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = ocr_attention.fc(backward_first)
label_in = fluid.layers.zeros([1], dtype='int64')
result = ''
for i in range(args.max_length):
trg_embedding = ocr_attention.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [1, -1, trg_embedding.shape[1]],
inplace=False)
prediction, decoder_boot = ocr_attention.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot, inference=True)
prediction = fluid.layers.reshape(prediction, [args.num_classes + 2])
score, idx = fluid.layers.topk(prediction, 1)
idx_np = idx.numpy()[0]
if idx_np == 1:
print('met end character, predict finish!')
break
label_in = fluid.layers.reshape(idx, [1])
result += chr(int(idx_np + 33))
print('predict result:', result)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
inference(args)
\ No newline at end of file
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
out_ch,
channels,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.pool = pool
filter_size = 3
conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
conv_param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_0))
conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
conv_param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_1))
self.conv_0_layer = Conv2D(
channels[0],
out_ch[0],
3,
padding=1,
param_attr=conv_param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_0_layer = BatchNorm(
out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
out_ch[0],
num_filters=out_ch[1],
filter_size=3,
padding=1,
param_attr=conv_param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_1_layer = BatchNorm(
out_ch[1], act=act, is_test=is_test)
if self.pool:
self.pool_layer = Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
conv_0 = self.conv_0_layer(inputs)
bn_0 = self.bn_0_layer(conv_0)
conv_1 = self.conv_1_layer(bn_0)
bn_1 = self.bn_1_layer(conv_1)
if self.pool:
bn_pool = self.pool_layer(bn_1)
return bn_pool
return bn_1
class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
[16, 16], [1, 16],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
[32, 32], [16, 32],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
[64, 64], [32, 64],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
[128, 128], [64, 128],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
def forward(self, inputs):
inputs_1 = self.conv_bn_pool_1(inputs)
inputs_2 = self.conv_bn_pool_2(inputs_1)
inputs_3 = self.conv_bn_pool_3(inputs_2)
inputs_4 = self.conv_bn_pool_4(inputs_3)
return inputs_4
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[:, i: i + 1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
batch_size,
decoder_size,
rnn_hidden_size=200,
is_test=False,
use_cudnn=True):
super(EncoderNet, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
if fluid.framework.in_dygraph_mode():
h_0 = np.zeros(
(batch_size, rnn_hidden_size), dtype="float32")
h_0 = to_variable(h_0)
else:
h_0 = fluid.layers.fill_constant(
shape=[batch_size, rnn_hidden_size],
dtype='float32',
value=0)
self.ocr_convs = OCRConv(
is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.fc_2_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
self.encoded_proj_fc = Linear(rnn_hidden_size * 2,
decoder_size,
bias_attr=False)
def forward(self, inputs):
conv_features = self.ocr_convs(inputs)
transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [-1, transpose_conv_features.shape[1] , transpose_conv_features.shape[2]*transpose_conv_features.shape[3]], inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear( decoder_size,
decoder_size,
act=None,
bias_attr=False)
self.fc_2 = Linear( decoder_size,
1,
act=None,
bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax( weights_reshape )
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, encoder_size, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear(input_dim=encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear(input_dim=decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3,
param_attr=None,
bias_attr=None)
self.out_layer = Linear(input_dim=decoder_size,
output_dim =num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self, current_word, encoder_vec, encoder_proj,
decoder_boot, inference=False):
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
decoder_boot)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, decoder_boot)
out = self.out_layer(h)
return out, h
class OCRAttention(fluid.dygraph.Layer):
def __init__(self, batch_size, num_classes, encoder_size, decoder_size, word_vector_dim):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet(batch_size, decoder_size)
self.fc = Linear(input_dim=encoder_size,
output_dim=decoder_size,
bias_attr=False,
act='relu')
self.embedding = Embedding(
[num_classes + 2, word_vector_dim],
dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(encoder_size, decoder_size,
num_classes)
self.batch_size = batch_size
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
trg_embedding = self.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [self.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
pred_temp = []
for i in range(trg_embedding.shape[1]):
current_word = fluid.layers.slice(
trg_embedding, axes=[1], starts=[i], ends=[i + 1])
out, decoder_boot = self.gru_decoder_with_attention(
current_word, encoded_vector, encoded_proj, decoder_boot
)
pred_temp.append(out)
pred_temp = fluid.layers.concat(pred_temp, axis=1)
batch_size = trg_embedding.shape[0]
seq_len = trg_embedding.shape[1]
prediction = fluid.layers.reshape(pred_temp, shape=[batch_size, seq_len, -1])
return prediction
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册