未验证 提交 52ca7b75 编写于 作者: L LielinJiang 提交者: GitHub

Refine ocr dygraph code, add infer module (#4182)

* refine code, add infer module

* update readme
上级 6967b7ab
...@@ -25,11 +25,27 @@ ocr任务是识别图片单行的字母信息,在动态图下使用了带atten ...@@ -25,11 +25,27 @@ ocr任务是识别图片单行的字母信息,在动态图下使用了带atten
在GPU单卡上训练ocr recognition: 在GPU单卡上训练ocr recognition:
``` ```
env CUDA_VISIBLE_DEVICES=0 python train.py CUDA_VISIBLE_DEVICES=0 python train.py
``` ```
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。 这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
## 效果 ## 测试ocr recognition
在test测试集合上,最好的效果为82.0%
```
CUDA_VISIBLE_DEVICES=0 python eval.py --pretrained_model your_trained_model_path
```
## 预测
```
CUDA_VISIBLE_DEVICES=0 python -u infer.py --pretrained_model your_trained_model_path --image_path your_img_path
```
## 预训练模型
|模型| 准确率|
|- |:-: |
|[ocr_attention_params](https://paddle-ocr-models.bj.bcebos.com/ocr_attention_dygraph.tar) | 82.46%|
...@@ -2,13 +2,12 @@ from __future__ import absolute_import ...@@ -2,13 +2,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import cv2
import tarfile import tarfile
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from os import path from os import path
from paddle.dataset.image import load_image
import paddle import paddle
import random
SOS = 0 SOS = 0
EOS = 1 EOS = 1
...@@ -53,24 +52,53 @@ class DataGenerator(object): ...@@ -53,24 +52,53 @@ class DataGenerator(object):
img_label_lines = [] img_label_lines = []
to_file = "tmp.txt" to_file = "tmp.txt"
def _shuffle_data(input_file_path, output_file_path, shuffle,
batchsize):
def _write_file(file_path, lines_to_write):
open(file_path, 'w').writelines(
["{}\n".format(item) for item in lines_to_write])
input_file = open(input_file_path, 'r')
lines_to_shuf = [line.strip() for line in input_file.readlines()]
if not shuffle: if not shuffle:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file _write_file(output_file_path, lines_to_shuf)
elif batchsize == 1: elif batchsize == 1:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file random.shuffle(lines_to_shuf)
_write_file(output_file_path, lines_to_shuf)
else: else:
#cmd1: partial shuffle #partial shuffle
cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | " for i in range(len(lines_to_shuf)):
#cmd2: batch merge and shuffle str_i = lines_to_shuf[i]
cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str( list_i = str_i.strip().split(' ')
batchsize) + " == 0) print \"\";}' | shuf | " str_i_ = "%04d%.4f " % (int(list_i[0]), random.random()
#cmd3: batch split ) + str_i
cmd += "awk '{if(NF == " + str( lines_to_shuf[i] = str_i_
batchsize lines_to_shuf.sort()
) + " * 4) {for(i = 0; i < " + str( delete_num = random.randint(1, 100)
batchsize del lines_to_shuf[0:delete_num]
) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
os.system(cmd) #batch merge and shuffle
print("finish batch shuffle") lines_concat = []
for i in range(0, len(lines_to_shuf), batchsize):
lines_concat.append(' '.join(lines_to_shuf[i:i +
batchsize]))
random.shuffle(lines_concat)
#batch split
out_file = open(output_file_path, 'w')
for i in range(len(lines_concat)):
tmp_list = lines_concat[i].split(' ')
for j in range(int(len(tmp_list) / 5)):
out_file.write("{} {} {} {}\n".format(tmp_list[
5 * j + 1], tmp_list[5 * j + 2], tmp_list[
5 * j + 3], tmp_list[5 * j + 4]))
out_file.close()
input_file.close()
_shuffle_data(img_label_list, to_file, shuffle, batchsize)
img_label_lines = open(to_file, 'r').readlines() img_label_lines = open(to_file, 'r').readlines()
def reader(): def reader():
...@@ -95,7 +123,7 @@ class DataGenerator(object): ...@@ -95,7 +123,7 @@ class DataGenerator(object):
mask = np.zeros((max_len)).astype('float32') mask = np.zeros((max_len)).astype('float32')
mask[:len(label) + 1] = 1.0 mask[:len(label) + 1] = 1.0
#mask[ j, :len(label) + 1] = 1.0
if max_len > len(label) + 1: if max_len > len(label) + 1:
extend_label = [EOS] * (max_len - len(label) - 1) extend_label = [EOS] * (max_len - len(label) - 1)
label.extend(extend_label) label.extend(extend_label)
......
export CUDA_VISIBLE_DEVICES=0
python train.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import functools
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import data_reader
from nets import OCRAttention
from paddle.fluid.dygraph.base import to_variable
from utility import add_arguments, print_arguments, get_attention_feeder_data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def evaluate(model, test_reader, batch_size):
model.eval()
total_step = 0.0
equal_size = 0
for data in test_reader():
data_dict = get_attention_feeder_data(data)
label_in = to_variable(data_dict["label_in"])
label_out = to_variable(data_dict["label_out"])
label_out.stop_gradient = True
img = to_variable(data_dict["pixel"])
prediction = model(img, label_in)
prediction = fluid.layers.reshape(prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
score, topk = layers.topk(prediction, 1)
seq = topk.numpy()
seq = seq.reshape((batch_size, -1))
mask = data_dict['mask'].reshape((batch_size, -1))
seq_len = np.sum(mask, -1)
trans_ref = data_dict["label_out"].reshape((batch_size, -1))
for i in range(batch_size):
length = int(seq_len[i] - 1)
trans = seq[i][:length - 1]
ref = trans_ref[i][: length - 1]
if np.array_equal(trans, ref):
equal_size += 1
total_step += batch_size
accuracy = equal_size / total_step
print("eval accuracy:", accuracy)
return accuracy
def eval(args):
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=args.batch_size,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
test_reader = data_reader.data_reader(
args.batch_size,
images_dir=args.test_images,
list_file=args.test_list,
data_type="test")
evaluate(ocr_attention, test_reader, args.batch_size)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
eval(args)
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments
from PIL import Image
from nets import OCRAttention
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('image_path', str, "", "image path")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('max_length', int, 100, "Max predict length.")
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def inference(args):
img = Image.open(os.path.join(args.image_path)).convert('L')
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=1,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
ocr_attention.eval()
print(img.size)
img = img.resize((img.size[0], 48), Image.BILINEAR)
img = np.array(img).astype('float32') - 127.5
img = img[np.newaxis, np.newaxis, ...]
img = to_variable(img)
gru_backward, encoded_vector, encoded_proj = ocr_attention.encoder_net(img)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = ocr_attention.fc(backward_first)
label_in = fluid.layers.zeros([1], dtype='int64')
result = ''
for i in range(args.max_length):
trg_embedding = ocr_attention.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [1, -1, trg_embedding.shape[1]],
inplace=False)
prediction, decoder_boot = ocr_attention.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot, inference=True)
prediction = fluid.layers.reshape(prediction, [args.num_classes + 2])
score, idx = fluid.layers.topk(prediction, 1)
idx_np = idx.numpy()[0]
if idx_np == 1:
print('met end character, predict finish!')
break
label_in = fluid.layers.reshape(idx, [1])
result += chr(int(idx_np + 33))
print('predict result:', result)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
inference(args)
\ No newline at end of file
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
out_ch,
channels,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.pool = pool
filter_size = 3
conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
conv_param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_0))
conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
conv_param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_1))
self.conv_0_layer = Conv2D(
channels[0],
out_ch[0],
3,
padding=1,
param_attr=conv_param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_0_layer = BatchNorm(
out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
out_ch[0],
num_filters=out_ch[1],
filter_size=3,
padding=1,
param_attr=conv_param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_1_layer = BatchNorm(
out_ch[1], act=act, is_test=is_test)
if self.pool:
self.pool_layer = Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
conv_0 = self.conv_0_layer(inputs)
bn_0 = self.bn_0_layer(conv_0)
conv_1 = self.conv_1_layer(bn_0)
bn_1 = self.bn_1_layer(conv_1)
if self.pool:
bn_pool = self.pool_layer(bn_1)
return bn_pool
return bn_1
class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
[16, 16], [1, 16],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
[32, 32], [16, 32],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
[64, 64], [32, 64],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
[128, 128], [64, 128],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
def forward(self, inputs):
inputs_1 = self.conv_bn_pool_1(inputs)
inputs_2 = self.conv_bn_pool_2(inputs_1)
inputs_3 = self.conv_bn_pool_3(inputs_2)
inputs_4 = self.conv_bn_pool_4(inputs_3)
return inputs_4
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[:, i: i + 1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
batch_size,
decoder_size,
rnn_hidden_size=200,
is_test=False,
use_cudnn=True):
super(EncoderNet, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
if fluid.framework.in_dygraph_mode():
h_0 = np.zeros(
(batch_size, rnn_hidden_size), dtype="float32")
h_0 = to_variable(h_0)
else:
h_0 = fluid.layers.fill_constant(
shape=[batch_size, rnn_hidden_size],
dtype='float32',
value=0)
self.ocr_convs = OCRConv(
is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.fc_2_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
self.encoded_proj_fc = Linear(rnn_hidden_size * 2,
decoder_size,
bias_attr=False)
def forward(self, inputs):
conv_features = self.ocr_convs(inputs)
transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [-1, transpose_conv_features.shape[1] , transpose_conv_features.shape[2]*transpose_conv_features.shape[3]], inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear( decoder_size,
decoder_size,
act=None,
bias_attr=False)
self.fc_2 = Linear( decoder_size,
1,
act=None,
bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax( weights_reshape )
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, encoder_size, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear(input_dim=encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear(input_dim=decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3,
param_attr=None,
bias_attr=None)
self.out_layer = Linear(input_dim=decoder_size,
output_dim =num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self, current_word, encoder_vec, encoder_proj,
decoder_boot, inference=False):
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
decoder_boot)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, decoder_boot)
out = self.out_layer(h)
return out, h
class OCRAttention(fluid.dygraph.Layer):
def __init__(self, batch_size, num_classes, encoder_size, decoder_size, word_vector_dim):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet(batch_size, decoder_size)
self.fc = Linear(input_dim=encoder_size,
output_dim=decoder_size,
bias_attr=False,
act='relu')
self.embedding = Embedding(
[num_classes + 2, word_vector_dim],
dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(encoder_size, decoder_size,
num_classes)
self.batch_size = batch_size
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
trg_embedding = self.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [self.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
pred_temp = []
for i in range(trg_embedding.shape[1]):
current_word = fluid.layers.slice(
trg_embedding, axes=[1], starts=[i], ends=[i + 1])
out, decoder_boot = self.gru_decoder_with_attention(
current_word, encoded_vector, encoded_proj, decoder_boot
)
pred_temp.append(out)
pred_temp = fluid.layers.concat(pred_temp, axis=1)
batch_size = trg_embedding.shape[0]
seq_len = trg_embedding.shape[1]
prediction = fluid.layers.reshape(pred_temp, shape=[batch_size, seq_len, -1])
return prediction
...@@ -13,422 +13,48 @@ ...@@ -13,422 +13,48 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import sys import os
import numpy as np
import paddle.fluid.profiler as profiler import paddle.fluid.profiler as profiler
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers
import data_reader import data_reader
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
import argparse import argparse
import functools import functools
from utility import add_arguments, print_arguments, get_attention_feeder_data from utility import add_arguments, print_arguments, get_attention_feeder_data
import time
from paddle.fluid import framework from nets import OCRAttention
from eval import evaluate
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.") add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('total_step', int, 720000, "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.") add_arg('epoch_num', int, 30, "Epoch number.")
add_arg('log_period', int, 1000, "Log period.") add_arg('lr', float, 0.001, "Learning rate.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.") add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.") add_arg('log_period', int, 200, "Log period.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.") add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.") add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.") add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.") add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.") add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('init_model', str, None, "The init model file of directory.") add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.") add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.") add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('profile', bool, False, "Whether to use profiling.") add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.") add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test', bool, False, "Whether to skip test phase.") add_arg('skip_test', bool, False, "Whether to skip test phase.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
class Config(object): add_arg('decoder_size', int, 128, "Decoder size.")
''' add_arg('word_vector_dim', int, 128, "Word vector dim.")
config for training add_arg('num_classes', int, 95, "Number classes.")
''' add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
# encoder rnn hidden_size
encoder_size = 200
# decoder size for decoder stage
decoder_size = 128
# size for word embedding
word_vector_dim = 128
# max length for label padding
max_length = 100
gradient_clip = 10
LR = 1.0
beam_size = 2
learning_rate_decay = None
# batch size to train
batch_size = 32
# class number to classify
num_classes = 95
use_gpu = False
# special label for start and end
SOS = 0
EOS = 1
# data shape for input image
DATA_SHAPE = [1, 48, 512]
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
group,
out_ch,
channels,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.group = group
self.pool = pool
filter_size = 3
conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
conv_param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_0))
conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
conv_param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_1))
self.conv_0_layer = Conv2D(
channels[0],
out_ch[0],
3,
padding=1,
param_attr=conv_param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_0_layer = BatchNorm(
out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
out_ch[0],
num_filters=out_ch[1],
filter_size=3,
padding=1,
param_attr=conv_param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_1_layer = BatchNorm(
out_ch[1], act=act, is_test=is_test)
if self.pool:
self.pool_layer = Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
conv_0 = self.conv_0_layer(inputs)
bn_0 = self.bn_0_layer(conv_0)
conv_1 = self.conv_1_layer(bn_0)
bn_1 = self.bn_1_layer(conv_1)
if self.pool:
bn_pool = self.pool_layer(bn_1)
return bn_pool
return bn_1
class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
2, [16, 16], [1, 16],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
2, [32, 32], [16, 32],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
2, [64, 64], [32, 64],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
2, [128, 128], [64, 128],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
def forward(self, inputs):
inputs_1 = self.conv_bn_pool_1(inputs)
inputs_2 = self.conv_bn_pool_2(inputs_1)
inputs_3 = self.conv_bn_pool_3(inputs_2)
inputs_4 = self.conv_bn_pool_4(inputs_3)
return inputs_4
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[ :, i:i+1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
rnn_hidden_size=Config.encoder_size,
is_test=False,
use_cudnn=True):
super(EncoderNet, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
if fluid.framework.in_dygraph_mode():
h_0 = np.zeros(
(Config.batch_size, rnn_hidden_size), dtype="float32")
h_0 = to_variable(h_0)
else:
h_0 = fluid.layers.fill_constant(
shape=[Config.batch_size, rnn_hidden_size],
dtype='float32',
value=0)
self.ocr_convs = OCRConv(
is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear( 768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False )
print( "weight", self.fc_1_layer.weight.shape )
self.fc_2_layer = Linear( 768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False )
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
self.encoded_proj_fc = Linear( rnn_hidden_size * 2,
Config.decoder_size,
bias_attr=False )
def forward(self, inputs):
conv_features = self.ocr_convs(inputs)
transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [-1, transpose_conv_features.shape[1] , transpose_conv_features.shape[2]*transpose_conv_features.shape[3]], inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear( decoder_size,
decoder_size,
act=None,
bias_attr=False)
self.fc_2 = Linear( decoder_size,
1,
act=None,
bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax( weights_reshape )
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear( input_dim = Config.encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear( input_dim = decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3,
param_attr=None,
bias_attr=None)
self.out_layer = Linear( input_dim = decoder_size,
output_dim =num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self, target_embedding, encoder_vec, encoder_proj,
decoder_boot):
res = []
hidden_mem = decoder_boot
for i in range(target_embedding.shape[1]):
current_word = fluid.layers.slice(
target_embedding, axes=[1], starts=[i], ends=[i + 1])
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
hidden_mem)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, hidden_mem)
hidden_mem = h
out = self.out_layer(h)
res.append(out)
res1 = fluid.layers.concat(res, axis=1)
batch_size = target_embedding.shape[0]
seq_len = target_embedding.shape[1]
res1 = layers.reshape( res1, shape=[batch_size, seq_len, -1])
return res1
class OCRAttention(fluid.dygraph.Layer):
def __init__(self):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet()
self.fc = Linear( input_dim = Config.encoder_size,
output_dim =Config.decoder_size,
bias_attr=False,
act='relu')
self.embedding = Embedding(
[Config.num_classes + 2, Config.word_vector_dim],
dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(
Config.decoder_size, Config.num_classes)
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
trg_embedding = self.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [Config.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
prediction = self.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot)
return prediction
def train(args): def train(args):
...@@ -436,74 +62,41 @@ def train(args): ...@@ -436,74 +62,41 @@ def train(args):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True backward_strategy.sort_sum_gradient = True
ocr_attention = OCRAttention()
if Config.learning_rate_decay == "piecewise_decay": ocr_attention = OCRAttention(batch_size=args.batch_size,
learning_rate = fluid.layers.piecewise_decay( encoder_size=args.encoder_size, decoder_size=args.decoder_size,
[50000], [Config.LR, Config.LR * 0.01]) num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
LR = args.lr
if args.lr_decay_strategy == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay([200000, 250000], [LR, LR * 0.1, LR * 0.01])
else: else:
learning_rate = Config.LR learning_rate = LR
optimizer = fluid.optimizer.Adam(learning_rate=0.001, parameter_list=ocr_attention.parameters())
dy_param_init_value = {}
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0 ) optimizer = fluid.optimizer.Adam(learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)
train_reader = data_reader.data_reader( train_reader = data_reader.data_reader(
Config.batch_size, args.batch_size,
cycle=args.total_step > 0,
shuffle=True, shuffle=True,
images_dir=args.train_images,
list_file=args.train_list,
data_type='train') data_type='train')
infer_image= './data/data/test_images/'
infer_files = './data/data/test.list'
test_reader = data_reader.data_reader( test_reader = data_reader.data_reader(
Config.batch_size, args.batch_size,
cycle=False, images_dir=args.test_images,
list_file=args.test_list,
data_type="test") data_type="test")
def eval():
ocr_attention.eval()
total_loss = 0.0
total_step = 0.0
equal_size = 0
for data in test_reader():
data_dict = get_attention_feeder_data(data)
label_in = to_variable(data_dict["label_in"])
label_out = to_variable(data_dict["label_out"])
label_out.stop_gradient = True
img = to_variable(data_dict["pixel"])
prediction = ocr_attention(img, label_in)
prediction = fluid.layers.reshape( prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
score, topk = layers.topk( prediction, 1)
seq = topk.numpy()
seq = seq.reshape( ( args.batch_size, -1))
mask = data_dict['mask'].reshape( (args.batch_size, -1))
seq_len = np.sum( mask, -1)
trans_ref = data_dict["label_out"].reshape( (args.batch_size, -1))
for i in range( args.batch_size ):
length = int(seq_len[i] -1 )
trans = seq[i][:length - 1]
ref = trans_ref[i][ : length - 1]
if np.array_equal( trans, ref ):
equal_size += 1
total_step += args.batch_size
print( "eval cost", equal_size / total_step )
if not os.path.exists(args.save_model_dir):
os.makedirs(args.save_model_dir)
total_step = 0 total_step = 0
epoch_num = 20 epoch_num = args.epoch_num
for epoch in range(epoch_num): for epoch in range(epoch_num):
batch_id = 0 batch_id = 0
total_loss = 0.0 total_loss = 0.0
for data in train_reader(): for data in train_reader():
total_step += 1 total_step += 1
...@@ -524,7 +117,7 @@ def train(args): ...@@ -524,7 +117,7 @@ def train(args):
mask = to_variable(data_dict["mask"]) mask = to_variable(data_dict["mask"])
loss = layers.elementwise_mul( loss, mask, axis=0) loss = fluid.layers.elementwise_mul( loss, mask, axis=0)
avg_loss = fluid.layers.reduce_sum(loss) avg_loss = fluid.layers.reduce_sum(loss)
total_loss += avg_loss.numpy() total_loss += avg_loss.numpy()
...@@ -532,21 +125,24 @@ def train(args): ...@@ -532,21 +125,24 @@ def train(args):
optimizer.minimize(avg_loss, grad_clip=grad_clip) optimizer.minimize(avg_loss, grad_clip=grad_clip)
ocr_attention.clear_gradients() ocr_attention.clear_gradients()
if batch_id > 0 and batch_id % 1000 == 0: if batch_id > 0 and batch_id % args.log_period == 0:
print("epoch: {}, batch_id: {}, loss {}".format(epoch, batch_id, total_loss / args.batch_size / 1000)) print("epoch: {}, batch_id: {}, lr: {}, loss {}".format(epoch, batch_id,
optimizer._global_learning_rate().numpy(),
total_loss / args.batch_size / args.log_period))
total_loss = 0.0 total_loss = 0.0
if total_step > 0 and total_step % 2000 == 0: if total_step > 0 and total_step % args.save_model_period == 0:
if fluid.dygraph.parallel.Env().dev_id == 0:
model_file = os.path.join(args.save_model_dir, 'step_{}'.format(total_step))
fluid.save_dygraph(ocr_attention.state_dict(), model_file)
print('step_{}.pdparams saved!'.format(total_step))
if total_step > 0 and total_step % args.eval_period == 0:
ocr_attention.eval() ocr_attention.eval()
eval() evaluate(ocr_attention, test_reader, args.batch_size)
ocr_attention.train() ocr_attention.train()
batch_id +=1 batch_id += 1
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册