“60e23e952dd035a3e6f294d9570516cc8cca37f9”上不存在“paddle/fluid/lite/api/paddle_use_kernels.h”
未验证 提交 52ca7b75 编写于 作者: L LielinJiang 提交者: GitHub

Refine ocr dygraph code, add infer module (#4182)

* refine code, add infer module

* update readme
上级 6967b7ab
......@@ -25,11 +25,27 @@ ocr任务是识别图片单行的字母信息,在动态图下使用了带atten
在GPU单卡上训练ocr recognition:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
CUDA_VISIBLE_DEVICES=0 python train.py
```
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
## 效果
## 测试ocr recognition
在test测试集合上,最好的效果为82.0%
```
CUDA_VISIBLE_DEVICES=0 python eval.py --pretrained_model your_trained_model_path
```
## 预测
```
CUDA_VISIBLE_DEVICES=0 python -u infer.py --pretrained_model your_trained_model_path --image_path your_img_path
```
## 预训练模型
|模型| 准确率|
|- |:-: |
|[ocr_attention_params](https://paddle-ocr-models.bj.bcebos.com/ocr_attention_dygraph.tar) | 82.46%|
......@@ -2,13 +2,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import tarfile
import numpy as np
from PIL import Image
from os import path
from paddle.dataset.image import load_image
import paddle
import random
SOS = 0
EOS = 1
......@@ -53,24 +52,53 @@ class DataGenerator(object):
img_label_lines = []
to_file = "tmp.txt"
def _shuffle_data(input_file_path, output_file_path, shuffle,
batchsize):
def _write_file(file_path, lines_to_write):
open(file_path, 'w').writelines(
["{}\n".format(item) for item in lines_to_write])
input_file = open(input_file_path, 'r')
lines_to_shuf = [line.strip() for line in input_file.readlines()]
if not shuffle:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file
_write_file(output_file_path, lines_to_shuf)
elif batchsize == 1:
cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
random.shuffle(lines_to_shuf)
_write_file(output_file_path, lines_to_shuf)
else:
#cmd1: partial shuffle
cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
#cmd2: batch merge and shuffle
cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
batchsize) + " == 0) print \"\";}' | shuf | "
#cmd3: batch split
cmd += "awk '{if(NF == " + str(
batchsize
) + " * 4) {for(i = 0; i < " + str(
batchsize
) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
os.system(cmd)
print("finish batch shuffle")
#partial shuffle
for i in range(len(lines_to_shuf)):
str_i = lines_to_shuf[i]
list_i = str_i.strip().split(' ')
str_i_ = "%04d%.4f " % (int(list_i[0]), random.random()
) + str_i
lines_to_shuf[i] = str_i_
lines_to_shuf.sort()
delete_num = random.randint(1, 100)
del lines_to_shuf[0:delete_num]
#batch merge and shuffle
lines_concat = []
for i in range(0, len(lines_to_shuf), batchsize):
lines_concat.append(' '.join(lines_to_shuf[i:i +
batchsize]))
random.shuffle(lines_concat)
#batch split
out_file = open(output_file_path, 'w')
for i in range(len(lines_concat)):
tmp_list = lines_concat[i].split(' ')
for j in range(int(len(tmp_list) / 5)):
out_file.write("{} {} {} {}\n".format(tmp_list[
5 * j + 1], tmp_list[5 * j + 2], tmp_list[
5 * j + 3], tmp_list[5 * j + 4]))
out_file.close()
input_file.close()
_shuffle_data(img_label_list, to_file, shuffle, batchsize)
img_label_lines = open(to_file, 'r').readlines()
def reader():
......@@ -95,7 +123,7 @@ class DataGenerator(object):
mask = np.zeros((max_len)).astype('float32')
mask[:len(label) + 1] = 1.0
#mask[ j, :len(label) + 1] = 1.0
if max_len > len(label) + 1:
extend_label = [EOS] * (max_len - len(label) - 1)
label.extend(extend_label)
......
export CUDA_VISIBLE_DEVICES=0
python train.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import functools
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import data_reader
from nets import OCRAttention
from paddle.fluid.dygraph.base import to_variable
from utility import add_arguments, print_arguments, get_attention_feeder_data
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def evaluate(model, test_reader, batch_size):
model.eval()
total_step = 0.0
equal_size = 0
for data in test_reader():
data_dict = get_attention_feeder_data(data)
label_in = to_variable(data_dict["label_in"])
label_out = to_variable(data_dict["label_out"])
label_out.stop_gradient = True
img = to_variable(data_dict["pixel"])
prediction = model(img, label_in)
prediction = fluid.layers.reshape(prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
score, topk = layers.topk(prediction, 1)
seq = topk.numpy()
seq = seq.reshape((batch_size, -1))
mask = data_dict['mask'].reshape((batch_size, -1))
seq_len = np.sum(mask, -1)
trans_ref = data_dict["label_out"].reshape((batch_size, -1))
for i in range(batch_size):
length = int(seq_len[i] - 1)
trans = seq[i][:length - 1]
ref = trans_ref[i][: length - 1]
if np.array_equal(trans, ref):
equal_size += 1
total_step += batch_size
accuracy = equal_size / total_step
print("eval accuracy:", accuracy)
return accuracy
def eval(args):
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=args.batch_size,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
test_reader = data_reader.data_reader(
args.batch_size,
images_dir=args.test_images,
list_file=args.test_list,
data_type="test")
evaluate(ocr_attention, test_reader, args.batch_size)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
eval(args)
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments
from PIL import Image
from nets import OCRAttention
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('image_path', str, "", "image path")
add_arg('pretrained_model', str, "", "pretrained_model.")
add_arg('max_length', int, 100, "Max predict length.")
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def inference(args):
img = Image.open(os.path.join(args.image_path)).convert('L')
with fluid.dygraph.guard():
ocr_attention = OCRAttention(batch_size=1,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
restore, _ = fluid.load_dygraph(args.pretrained_model)
ocr_attention.set_dict(restore)
ocr_attention.eval()
print(img.size)
img = img.resize((img.size[0], 48), Image.BILINEAR)
img = np.array(img).astype('float32') - 127.5
img = img[np.newaxis, np.newaxis, ...]
img = to_variable(img)
gru_backward, encoded_vector, encoded_proj = ocr_attention.encoder_net(img)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = ocr_attention.fc(backward_first)
label_in = fluid.layers.zeros([1], dtype='int64')
result = ''
for i in range(args.max_length):
trg_embedding = ocr_attention.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [1, -1, trg_embedding.shape[1]],
inplace=False)
prediction, decoder_boot = ocr_attention.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot, inference=True)
prediction = fluid.layers.reshape(prediction, [args.num_classes + 2])
score, idx = fluid.layers.topk(prediction, 1)
idx_np = idx.numpy()[0]
if idx_np == 1:
print('met end character, predict finish!')
break
label_in = fluid.layers.reshape(idx, [1])
result += chr(int(idx_np + 33))
print('predict result:', result)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
inference(args)
\ No newline at end of file
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
out_ch,
channels,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.pool = pool
filter_size = 3
conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
conv_param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_0))
conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
conv_param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_1))
self.conv_0_layer = Conv2D(
channels[0],
out_ch[0],
3,
padding=1,
param_attr=conv_param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_0_layer = BatchNorm(
out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
out_ch[0],
num_filters=out_ch[1],
filter_size=3,
padding=1,
param_attr=conv_param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_1_layer = BatchNorm(
out_ch[1], act=act, is_test=is_test)
if self.pool:
self.pool_layer = Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
conv_0 = self.conv_0_layer(inputs)
bn_0 = self.bn_0_layer(conv_0)
conv_1 = self.conv_1_layer(bn_0)
bn_1 = self.bn_1_layer(conv_1)
if self.pool:
bn_pool = self.pool_layer(bn_1)
return bn_pool
return bn_1
class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
[16, 16], [1, 16],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
[32, 32], [16, 32],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
[64, 64], [32, 64],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
[128, 128], [64, 128],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
def forward(self, inputs):
inputs_1 = self.conv_bn_pool_1(inputs)
inputs_2 = self.conv_bn_pool_2(inputs_1)
inputs_3 = self.conv_bn_pool_3(inputs_2)
inputs_4 = self.conv_bn_pool_4(inputs_3)
return inputs_4
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[:, i: i + 1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
batch_size,
decoder_size,
rnn_hidden_size=200,
is_test=False,
use_cudnn=True):
super(EncoderNet, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
if fluid.framework.in_dygraph_mode():
h_0 = np.zeros(
(batch_size, rnn_hidden_size), dtype="float32")
h_0 = to_variable(h_0)
else:
h_0 = fluid.layers.fill_constant(
shape=[batch_size, rnn_hidden_size],
dtype='float32',
value=0)
self.ocr_convs = OCRConv(
is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.fc_2_layer = Linear(768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
self.encoded_proj_fc = Linear(rnn_hidden_size * 2,
decoder_size,
bias_attr=False)
def forward(self, inputs):
conv_features = self.ocr_convs(inputs)
transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [-1, transpose_conv_features.shape[1] , transpose_conv_features.shape[2]*transpose_conv_features.shape[3]], inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear( decoder_size,
decoder_size,
act=None,
bias_attr=False)
self.fc_2 = Linear( decoder_size,
1,
act=None,
bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax( weights_reshape )
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, encoder_size, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear(input_dim=encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear(input_dim=decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3,
param_attr=None,
bias_attr=None)
self.out_layer = Linear(input_dim=decoder_size,
output_dim =num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self, current_word, encoder_vec, encoder_proj,
decoder_boot, inference=False):
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
decoder_boot)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, decoder_boot)
out = self.out_layer(h)
return out, h
class OCRAttention(fluid.dygraph.Layer):
def __init__(self, batch_size, num_classes, encoder_size, decoder_size, word_vector_dim):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet(batch_size, decoder_size)
self.fc = Linear(input_dim=encoder_size,
output_dim=decoder_size,
bias_attr=False,
act='relu')
self.embedding = Embedding(
[num_classes + 2, word_vector_dim],
dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(encoder_size, decoder_size,
num_classes)
self.batch_size = batch_size
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
trg_embedding = self.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [self.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
pred_temp = []
for i in range(trg_embedding.shape[1]):
current_word = fluid.layers.slice(
trg_embedding, axes=[1], starts=[i], ends=[i + 1])
out, decoder_boot = self.gru_decoder_with_attention(
current_word, encoded_vector, encoded_proj, decoder_boot
)
pred_temp.append(out)
pred_temp = fluid.layers.concat(pred_temp, axis=1)
batch_size = trg_embedding.shape[0]
seq_len = trg_embedding.shape[1]
prediction = fluid.layers.reshape(pred_temp, shape=[batch_size, seq_len, -1])
return prediction
......@@ -13,422 +13,48 @@
# limitations under the License.
from __future__ import print_function
import sys
import os
import numpy as np
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import data_reader
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments, get_attention_feeder_data
import time
from paddle.fluid import framework
from nets import OCRAttention
from eval import evaluate
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('total_step', int, 720000, "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
add_arg('log_period', int, 1000, "Log period.")
add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
add_arg('epoch_num', int, 30, "Epoch number.")
add_arg('lr', float, 0.001, "Learning rate.")
add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
add_arg('log_period', int, 200, "Log period.")
add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('min_average_window',int, 10000, "Min average window.")
add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
add_arg('average_window', float, 0.15, "Average window.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test', bool, False, "Whether to skip test phase.")
class Config(object):
'''
config for training
'''
# encoder rnn hidden_size
encoder_size = 200
# decoder size for decoder stage
decoder_size = 128
# size for word embedding
word_vector_dim = 128
# max length for label padding
max_length = 100
gradient_clip = 10
LR = 1.0
beam_size = 2
learning_rate_decay = None
# batch size to train
batch_size = 32
# class number to classify
num_classes = 95
use_gpu = False
# special label for start and end
SOS = 0
EOS = 1
# data shape for input image
DATA_SHAPE = [1, 48, 512]
class ConvBNPool(fluid.dygraph.Layer):
def __init__(self,
group,
out_ch,
channels,
act="relu",
is_test=False,
pool=True,
use_cudnn=True):
super(ConvBNPool, self).__init__()
self.group = group
self.pool = pool
filter_size = 3
conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
conv_param_0 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_0))
conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
conv_param_1 = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, conv_std_1))
self.conv_0_layer = Conv2D(
channels[0],
out_ch[0],
3,
padding=1,
param_attr=conv_param_0,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_0_layer = BatchNorm(
out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
out_ch[0],
num_filters=out_ch[1],
filter_size=3,
padding=1,
param_attr=conv_param_1,
bias_attr=False,
act=None,
use_cudnn=use_cudnn)
self.bn_1_layer = BatchNorm(
out_ch[1], act=act, is_test=is_test)
if self.pool:
self.pool_layer = Pool2D(
pool_size=2,
pool_type='max',
pool_stride=2,
use_cudnn=use_cudnn,
ceil_mode=True)
def forward(self, inputs):
conv_0 = self.conv_0_layer(inputs)
bn_0 = self.bn_0_layer(conv_0)
conv_1 = self.conv_1_layer(bn_0)
bn_1 = self.bn_1_layer(conv_1)
if self.pool:
bn_pool = self.pool_layer(bn_1)
return bn_pool
return bn_1
class OCRConv(fluid.dygraph.Layer):
def __init__(self, is_test=False, use_cudnn=True):
super(OCRConv, self).__init__()
self.conv_bn_pool_1 = ConvBNPool(
2, [16, 16], [1, 16],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_2 = ConvBNPool(
2, [32, 32], [16, 32],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_3 = ConvBNPool(
2, [64, 64], [32, 64],
is_test=is_test,
use_cudnn=use_cudnn)
self.conv_bn_pool_4 = ConvBNPool(
2, [128, 128], [64, 128],
is_test=is_test,
pool=False,
use_cudnn=use_cudnn)
def forward(self, inputs):
inputs_1 = self.conv_bn_pool_1(inputs)
inputs_2 = self.conv_bn_pool_2(inputs_1)
inputs_3 = self.conv_bn_pool_3(inputs_2)
inputs_4 = self.conv_bn_pool_4(inputs_3)
return inputs_4
class DynamicGRU(fluid.dygraph.Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
def forward(self, inputs):
hidden = self.h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[ :, i:i+1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class EncoderNet(fluid.dygraph.Layer):
def __init__(self,
rnn_hidden_size=Config.encoder_size,
is_test=False,
use_cudnn=True):
super(EncoderNet, self).__init__()
self.rnn_hidden_size = rnn_hidden_size
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
if fluid.framework.in_dygraph_mode():
h_0 = np.zeros(
(Config.batch_size, rnn_hidden_size), dtype="float32")
h_0 = to_variable(h_0)
else:
h_0 = fluid.layers.fill_constant(
shape=[Config.batch_size, rnn_hidden_size],
dtype='float32',
value=0)
self.ocr_convs = OCRConv(
is_test=is_test, use_cudnn=use_cudnn)
self.fc_1_layer = Linear( 768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False )
print( "weight", self.fc_1_layer.weight.shape )
self.fc_2_layer = Linear( 768,
rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False )
self.gru_forward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
self.gru_backward_layer = DynamicGRU(
size=rnn_hidden_size,
h_0=h_0,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu',
is_reverse=True)
self.encoded_proj_fc = Linear( rnn_hidden_size * 2,
Config.decoder_size,
bias_attr=False )
def forward(self, inputs):
conv_features = self.ocr_convs(inputs)
transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
sliced_feature = fluid.layers.reshape(
transpose_conv_features, [-1, transpose_conv_features.shape[1] , transpose_conv_features.shape[2]*transpose_conv_features.shape[3]], inplace=False)
fc_1 = self.fc_1_layer(sliced_feature)
fc_2 = self.fc_2_layer(sliced_feature)
gru_forward = self.gru_forward_layer(fc_1)
gru_backward = self.gru_backward_layer(fc_2)
encoded_vector = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=2)
encoded_proj = self.encoded_proj_fc(encoded_vector)
return gru_backward, encoded_vector, encoded_proj
class SimpleAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size):
super(SimpleAttention, self).__init__()
self.fc_1 = Linear( decoder_size,
decoder_size,
act=None,
bias_attr=False)
self.fc_2 = Linear( decoder_size,
1,
act=None,
bias_attr=False)
def forward(self, encoder_vec, encoder_proj, decoder_state):
decoder_state_fc = self.fc_1(decoder_state)
decoder_state_proj_reshape = fluid.layers.reshape(
decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
decoder_state_expand = fluid.layers.expand(
decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
concated = fluid.layers.elementwise_add(encoder_proj,
decoder_state_expand)
concated = fluid.layers.tanh(x=concated)
attention_weight = self.fc_2(concated)
weights_reshape = fluid.layers.reshape(
x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
weights_reshape = fluid.layers.softmax( weights_reshape )
scaled = fluid.layers.elementwise_mul(
x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.reduce_sum(scaled, dim=1)
return context
class GRUDecoderWithAttention(fluid.dygraph.Layer):
def __init__(self, decoder_size, num_classes):
super(GRUDecoderWithAttention, self).__init__()
self.simple_attention = SimpleAttention(decoder_size)
self.fc_1_layer = Linear( input_dim = Config.encoder_size * 2,
output_dim=decoder_size * 3,
bias_attr=False)
self.fc_2_layer = Linear( input_dim = decoder_size,
output_dim=decoder_size * 3,
bias_attr=False)
self.gru_unit = GRUUnit(
size=decoder_size * 3,
param_attr=None,
bias_attr=None)
self.out_layer = Linear( input_dim = decoder_size,
output_dim =num_classes + 2,
bias_attr=None,
act='softmax')
self.decoder_size = decoder_size
def forward(self, target_embedding, encoder_vec, encoder_proj,
decoder_boot):
res = []
hidden_mem = decoder_boot
for i in range(target_embedding.shape[1]):
current_word = fluid.layers.slice(
target_embedding, axes=[1], starts=[i], ends=[i + 1])
current_word = fluid.layers.reshape(
current_word, [-1, current_word.shape[2]], inplace=False)
context = self.simple_attention(encoder_vec, encoder_proj,
hidden_mem)
fc_1 = self.fc_1_layer(context)
fc_2 = self.fc_2_layer(current_word)
decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
h, _, _ = self.gru_unit(decoder_inputs, hidden_mem)
hidden_mem = h
out = self.out_layer(h)
res.append(out)
res1 = fluid.layers.concat(res, axis=1)
batch_size = target_embedding.shape[0]
seq_len = target_embedding.shape[1]
res1 = layers.reshape( res1, shape=[batch_size, seq_len, -1])
return res1
class OCRAttention(fluid.dygraph.Layer):
def __init__(self):
super(OCRAttention, self).__init__()
self.encoder_net = EncoderNet()
self.fc = Linear( input_dim = Config.encoder_size,
output_dim =Config.decoder_size,
bias_attr=False,
act='relu')
self.embedding = Embedding(
[Config.num_classes + 2, Config.word_vector_dim],
dtype='float32')
self.gru_decoder_with_attention = GRUDecoderWithAttention(
Config.decoder_size, Config.num_classes)
def forward(self, inputs, label_in):
gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
backward_first = fluid.layers.slice(
gru_backward, axes=[1], starts=[0], ends=[1])
backward_first = fluid.layers.reshape(
backward_first, [-1, backward_first.shape[2]], inplace=False)
decoder_boot = self.fc(backward_first)
label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
trg_embedding = self.embedding(label_in)
trg_embedding = fluid.layers.reshape(
trg_embedding, [Config.batch_size, -1, trg_embedding.shape[1]],
inplace=False)
prediction = self.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot)
return prediction
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
def train(args):
......@@ -436,74 +62,41 @@ def train(args):
with fluid.dygraph.guard():
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
ocr_attention = OCRAttention()
if Config.learning_rate_decay == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay(
[50000], [Config.LR, Config.LR * 0.01])
ocr_attention = OCRAttention(batch_size=args.batch_size,
encoder_size=args.encoder_size, decoder_size=args.decoder_size,
num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
LR = args.lr
if args.lr_decay_strategy == "piecewise_decay":
learning_rate = fluid.layers.piecewise_decay([200000, 250000], [LR, LR * 0.1, LR * 0.01])
else:
learning_rate = Config.LR
optimizer = fluid.optimizer.Adam(learning_rate=0.001, parameter_list=ocr_attention.parameters())
dy_param_init_value = {}
learning_rate = LR
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0 )
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)
train_reader = data_reader.data_reader(
Config.batch_size,
cycle=args.total_step > 0,
args.batch_size,
shuffle=True,
images_dir=args.train_images,
list_file=args.train_list,
data_type='train')
infer_image= './data/data/test_images/'
infer_files = './data/data/test.list'
test_reader = data_reader.data_reader(
Config.batch_size,
cycle=False,
args.batch_size,
images_dir=args.test_images,
list_file=args.test_list,
data_type="test")
def eval():
ocr_attention.eval()
total_loss = 0.0
total_step = 0.0
equal_size = 0
for data in test_reader():
data_dict = get_attention_feeder_data(data)
label_in = to_variable(data_dict["label_in"])
label_out = to_variable(data_dict["label_out"])
label_out.stop_gradient = True
img = to_variable(data_dict["pixel"])
prediction = ocr_attention(img, label_in)
prediction = fluid.layers.reshape( prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
score, topk = layers.topk( prediction, 1)
seq = topk.numpy()
seq = seq.reshape( ( args.batch_size, -1))
mask = data_dict['mask'].reshape( (args.batch_size, -1))
seq_len = np.sum( mask, -1)
trans_ref = data_dict["label_out"].reshape( (args.batch_size, -1))
for i in range( args.batch_size ):
length = int(seq_len[i] -1 )
trans = seq[i][:length - 1]
ref = trans_ref[i][ : length - 1]
if np.array_equal( trans, ref ):
equal_size += 1
total_step += args.batch_size
print( "eval cost", equal_size / total_step )
if not os.path.exists(args.save_model_dir):
os.makedirs(args.save_model_dir)
total_step = 0
epoch_num = 20
epoch_num = args.epoch_num
for epoch in range(epoch_num):
batch_id = 0
total_loss = 0.0
for data in train_reader():
total_step += 1
......@@ -524,7 +117,7 @@ def train(args):
mask = to_variable(data_dict["mask"])
loss = layers.elementwise_mul( loss, mask, axis=0)
loss = fluid.layers.elementwise_mul( loss, mask, axis=0)
avg_loss = fluid.layers.reduce_sum(loss)
total_loss += avg_loss.numpy()
......@@ -532,21 +125,24 @@ def train(args):
optimizer.minimize(avg_loss, grad_clip=grad_clip)
ocr_attention.clear_gradients()
if batch_id > 0 and batch_id % 1000 == 0:
print("epoch: {}, batch_id: {}, loss {}".format(epoch, batch_id, total_loss / args.batch_size / 1000))
if batch_id > 0 and batch_id % args.log_period == 0:
print("epoch: {}, batch_id: {}, lr: {}, loss {}".format(epoch, batch_id,
optimizer._global_learning_rate().numpy(),
total_loss / args.batch_size / args.log_period))
total_loss = 0.0
if total_step > 0 and total_step % 2000 == 0:
if total_step > 0 and total_step % args.save_model_period == 0:
if fluid.dygraph.parallel.Env().dev_id == 0:
model_file = os.path.join(args.save_model_dir, 'step_{}'.format(total_step))
fluid.save_dygraph(ocr_attention.state_dict(), model_file)
print('step_{}.pdparams saved!'.format(total_step))
if total_step > 0 and total_step % args.eval_period == 0:
ocr_attention.eval()
eval()
evaluate(ocr_attention, test_reader, args.batch_size)
ocr_attention.train()
batch_id +=1
batch_id += 1
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册