From 52ca7b75358c216094651bbba865fb4e3bfa004e Mon Sep 17 00:00:00 2001
From: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Date: Tue, 3 Mar 2020 20:42:03 +0800
Subject: [PATCH] Refine ocr dygraph code, add infer module (#4182)

* refine code, add infer module

* update readme
---
 dygraph/ocr_recognition/README.md      |  22 +-
 dygraph/ocr_recognition/data_reader.py |  70 +++-
 dygraph/ocr_recognition/debug.sh       |   4 -
 dygraph/ocr_recognition/eval.py        | 101 +++++
 dygraph/ocr_recognition/infer.py       |  91 +++++
 dygraph/ocr_recognition/nets.py        | 346 +++++++++++++++++
 dygraph/ocr_recognition/train.py       | 506 +++----------------------
 7 files changed, 657 insertions(+), 483 deletions(-)
 delete mode 100644 dygraph/ocr_recognition/debug.sh
 create mode 100644 dygraph/ocr_recognition/eval.py
 create mode 100644 dygraph/ocr_recognition/infer.py
 create mode 100644 dygraph/ocr_recognition/nets.py

diff --git a/dygraph/ocr_recognition/README.md b/dygraph/ocr_recognition/README.md
index 85ce6541..164807b2 100644
--- a/dygraph/ocr_recognition/README.md
+++ b/dygraph/ocr_recognition/README.md
@@ -25,11 +25,27 @@ The OCR task recognizes the characters in a single-line image; the dygraph implementation uses an atten
 Train ocr recognition on a single GPU:
 
 ```
-env CUDA_VISIBLE_DEVICES=0 python train.py
+CUDA_VISIBLE_DEVICES=0 python train.py
 ```
 Here `CUDA_VISIBLE_DEVICES=0` runs the job on GPU card 0; adjust this setting to your own environment.
 
-## Results
+## Evaluate ocr recognition
 
-The best accuracy on the test set is 82.0%.
+
+```
+CUDA_VISIBLE_DEVICES=0 python eval.py --pretrained_model your_trained_model_path
+```
+
+## Inference
+
+
+```
+CUDA_VISIBLE_DEVICES=0 python -u infer.py --pretrained_model your_trained_model_path --image_path your_img_path
+```
+
+## Pretrained model
+
+|Model| Accuracy|
+|- |:-: |
+|[ocr_attention_params](https://paddle-ocr-models.bj.bcebos.com/ocr_attention_dygraph.tar) | 82.46%|

diff --git a/dygraph/ocr_recognition/data_reader.py b/dygraph/ocr_recognition/data_reader.py
index 7b1b9b5f..619d7306 100644
--- a/dygraph/ocr_recognition/data_reader.py
+++ b/dygraph/ocr_recognition/data_reader.py
@@ -2,13 +2,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import os
-import cv2
 import tarfile
 import numpy as np
 from PIL import Image
 from os import path
-from paddle.dataset.image import load_image
 import paddle
+import random
 
 SOS = 0
 EOS = 1
@@ -53,24 +52,53 @@ class DataGenerator(object):
 
         img_label_lines = []
         to_file = "tmp.txt"
-        if not shuffle:
-            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' > " + to_file
-        elif batchsize == 1:
-            cmd = "cat " + img_label_list + " | awk '{print $1,$2,$3,$4;}' | shuf > " + to_file
-        else:
-            #cmd1: partial shuffle
-            cmd = "cat " + img_label_list + " | awk '{printf(\"%04d%.4f %s\\n\", $1, rand(), $0)}' | sort | sed 1,$((1 + RANDOM % 100))d | "
-            #cmd2: batch merge and shuffle
-            cmd += "awk '{printf $2\" \"$3\" \"$4\" \"$5\" \"; if(NR % " + str(
-                batchsize) + " == 0) print \"\";}' | shuf | "
-            #cmd3: batch split
-            cmd += "awk '{if(NF == " + str(
-                batchsize
-            ) + " * 4) {for(i = 0; i < " + str(
-                batchsize
-            ) + "; i++) print $(4*i+1)\" \"$(4*i+2)\" \"$(4*i+3)\" \"$(4*i+4);}}' > " + to_file
-        os.system(cmd)
-        print("finish batch shuffle")
+
+        def _shuffle_data(input_file_path, output_file_path, shuffle,
+                          batchsize):
+            def _write_file(file_path, lines_to_write):
+                open(file_path, 'w').writelines(
+                    ["{}\n".format(item) for item in lines_to_write])
+
+            input_file = open(input_file_path, 'r')
+            lines_to_shuf = [line.strip() for line in input_file.readlines()]
+
+            if not shuffle:
+                _write_file(output_file_path, lines_to_shuf)
+            elif batchsize == 1:
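+                # batchsize == 1: batch composition does not matter, so a
+                # plain whole-list shuffle is sufficient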
+                random.shuffle(lines_to_shuf)
+                _write_file(output_file_path, lines_to_shuf)
+            else:
+                #partial shuffle
+                for i in range(len(lines_to_shuf)):
+                    str_i = lines_to_shuf[i]
+                    list_i = str_i.strip().split(' ')
+                    str_i_ = "%04d%.4f " % (int(list_i[0]),
+                                            random.random()) + str_i
+                    lines_to_shuf[i] = str_i_
+                lines_to_shuf.sort()
+                delete_num = random.randint(1, 100)
+                del lines_to_shuf[0:delete_num]
+
+                #batch merge and shuffle
+                lines_concat = []
+                for i in range(0, len(lines_to_shuf), batchsize):
+                    lines_concat.append(' '.join(
+                        lines_to_shuf[i:i + batchsize]))
+                random.shuffle(lines_concat)
+
+                #batch split
+                out_file = open(output_file_path, 'w')
+                for i in range(len(lines_concat)):
+                    tmp_list = lines_concat[i].split(' ')
+                    for j in range(int(len(tmp_list) / 5)):
+                        out_file.write("{} {} {} {}\n".format(
+                            tmp_list[5 * j + 1], tmp_list[5 * j + 2],
+                            tmp_list[5 * j + 3], tmp_list[5 * j + 4]))
+                out_file.close()
+            input_file.close()
+
+        _shuffle_data(img_label_list, to_file, shuffle, batchsize)
 
         img_label_lines = open(to_file, 'r').readlines()
 
     def reader():
@@ -95,7 +123,7 @@ class DataGenerator(object):
 
             mask = np.zeros((max_len)).astype('float32')
             mask[:len(label) + 1] = 1.0
-            #mask[ j, :len(label) + 1] = 1.0
+
             if max_len > len(label) + 1:
                 extend_label = [EOS] * (max_len - len(label) - 1)
                 label.extend(extend_label)

diff --git a/dygraph/ocr_recognition/debug.sh b/dygraph/ocr_recognition/debug.sh
deleted file mode 100644
index 076a52aa..00000000
--- a/dygraph/ocr_recognition/debug.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-
-export CUDA_VISIBLE_DEVICES=0
-
-python train.py

diff --git a/dygraph/ocr_recognition/eval.py b/dygraph/ocr_recognition/eval.py
new file mode 100644
index 00000000..18ef150f
--- /dev/null
+++ b/dygraph/ocr_recognition/eval.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
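+# eval.py reports sequence-level accuracy: a prediction only counts as
+# correct when the entire decoded label sequence matches the reference.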
+from __future__ import print_function
+
+import argparse
+import functools
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+import data_reader
+from nets import OCRAttention
+from paddle.fluid.dygraph.base import to_variable
+from utility import add_arguments, print_arguments, get_attention_feeder_data
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 32, "Minibatch size.")
+add_arg('pretrained_model', str, "", "Pretrained model.")
+add_arg('test_images', str, None, "The directory of images to be used for test.")
+add_arg('test_list', str, None, "The list file of images to be used for test.")
+# model hyper parameters
+add_arg('encoder_size', int, 200, "Encoder size.")
+add_arg('decoder_size', int, 128, "Decoder size.")
+add_arg('word_vector_dim', int, 128, "Word vector dim.")
+add_arg('num_classes', int, 95, "Number of classes.")
+add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
+
+
+def evaluate(model, test_reader, batch_size):
+    model.eval()
+
+    total_step = 0.0
+    equal_size = 0
+    for data in test_reader():
+        data_dict = get_attention_feeder_data(data)
+
+        label_in = to_variable(data_dict["label_in"])
+        label_out = to_variable(data_dict["label_out"])
+
+        label_out.stop_gradient = True
+
+        img = to_variable(data_dict["pixel"])
+
+        prediction = model(img, label_in)
+        prediction = fluid.layers.reshape(
+            prediction, [label_out.shape[0] * label_out.shape[1], -1],
+            inplace=False)
+
+        score, topk = layers.topk(prediction, 1)
+
+        seq = topk.numpy()
+        seq = seq.reshape((batch_size, -1))
+
+        mask = data_dict['mask'].reshape((batch_size, -1))
+        seq_len = np.sum(mask, -1)
+
+        trans_ref = data_dict["label_out"].reshape((batch_size, -1))
+        for i in range(batch_size):
+            length = int(seq_len[i] - 1)
+            trans = seq[i][:length - 1]
+            ref = trans_ref[i][:length - 1]
+            if np.array_equal(trans, ref):
+                equal_size += 1
+
+        total_step += batch_size
+    accuracy = equal_size / total_step
+    print("eval accuracy:", accuracy)
+    return accuracy
+
+
+def eval(args):
+    with fluid.dygraph.guard():
+        ocr_attention = OCRAttention(
+            batch_size=args.batch_size,
+            encoder_size=args.encoder_size,
+            decoder_size=args.decoder_size,
+            num_classes=args.num_classes,
+            word_vector_dim=args.word_vector_dim)
+        restore, _ = fluid.load_dygraph(args.pretrained_model)
+        ocr_attention.set_dict(restore)
+
+        test_reader = data_reader.data_reader(
+            args.batch_size,
+            images_dir=args.test_images,
+            list_file=args.test_list,
+            data_type="test")
+        evaluate(ocr_attention, test_reader, args.batch_size)
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    print_arguments(args)
+
+    eval(args)
\ No newline at end of file

diff --git a/dygraph/ocr_recognition/infer.py b/dygraph/ocr_recognition/infer.py
new file mode 100644
index 00000000..7e6aa146
--- /dev/null
+++ b/dygraph/ocr_recognition/infer.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import print_function
+
+import os
+import numpy as np
+
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
+
+import argparse
+import functools
+from utility import add_arguments, print_arguments
+from PIL import Image
+from nets import OCRAttention
+
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('image_path', str, "", "Image path.")
+add_arg('pretrained_model', str, "", "Pretrained model.")
+add_arg('max_length', int, 100, "Max predict length.")
+add_arg('encoder_size', int, 200, "Encoder size.")
+add_arg('decoder_size', int, 128, "Decoder size.")
+add_arg('word_vector_dim', int, 128, "Word vector dim.")
+add_arg('num_classes', int, 95, "Number of classes.")
+add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
+
+
+def inference(args):
+    img = Image.open(os.path.join(args.image_path)).convert('L')
+    with fluid.dygraph.guard():
+        ocr_attention = OCRAttention(
+            batch_size=1,
+            encoder_size=args.encoder_size,
+            decoder_size=args.decoder_size,
+            num_classes=args.num_classes,
+            word_vector_dim=args.word_vector_dim)
+        restore, _ = fluid.load_dygraph(args.pretrained_model)
+        ocr_attention.set_dict(restore)
+        ocr_attention.eval()
+        print(img.size)
+        img = img.resize((img.size[0], 48), Image.BILINEAR)
+        img = np.array(img).astype('float32') - 127.5
+        img = img[np.newaxis, np.newaxis, ...]
+        img = to_variable(img)
+
+        gru_backward, encoded_vector, encoded_proj = ocr_attention.encoder_net(img)
+        backward_first = fluid.layers.slice(
+            gru_backward, axes=[1], starts=[0], ends=[1])
+        backward_first = fluid.layers.reshape(
+            backward_first, [-1, backward_first.shape[2]], inplace=False)
+
+        decoder_boot = ocr_attention.fc(backward_first)
+        label_in = fluid.layers.zeros([1], dtype='int64')
+        result = ''
+        for i in range(args.max_length):
+            trg_embedding = ocr_attention.embedding(label_in)
+            trg_embedding = fluid.layers.reshape(
+                trg_embedding, [1, -1, trg_embedding.shape[1]],
+                inplace=False)
+
+            prediction, decoder_boot = ocr_attention.gru_decoder_with_attention(
+                trg_embedding, encoded_vector, encoded_proj, decoder_boot,
+                inference=True)
+            prediction = fluid.layers.reshape(prediction, [args.num_classes + 2])
+            score, idx = fluid.layers.topk(prediction, 1)
+
+            idx_np = idx.numpy()[0]
+            if idx_np == 1:
+                print('End token met, prediction finished!')
+                break
+
+            label_in = fluid.layers.reshape(idx, [1])
+            result += chr(int(idx_np + 33))
+        print('predict result:', result)
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    print_arguments(args)
+
+    inference(args)
\ No newline at end of file

diff --git a/dygraph/ocr_recognition/nets.py b/dygraph/ocr_recognition/nets.py
new file mode 100644
index 00000000..c84df4ed
--- /dev/null
+++ b/dygraph/ocr_recognition/nets.py
@@ -0,0 +1,346 @@
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
+from paddle.fluid.dygraph.base import to_variable
+
+
+class ConvBNPool(fluid.dygraph.Layer):
+    def __init__(self,
+                 out_ch,
+                 channels,
+                 act="relu",
+                 is_test=False,
+                 pool=True,
+                 use_cudnn=True):
+        super(ConvBNPool, self).__init__()
+        self.pool = pool
+
+        filter_size = 3
+        conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
+        conv_param_0 = fluid.ParamAttr(
+            initializer=fluid.initializer.Normal(0.0, conv_std_0))
+
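+        # Same He-style initialization as above: std = sqrt(2 / (k*k*C_in))
+        # for a k x k convolution over C_in input channels (here k = 3).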
+        conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
+        conv_param_1 = fluid.ParamAttr(
+            initializer=fluid.initializer.Normal(0.0, conv_std_1))
+
+        self.conv_0_layer = Conv2D(
+            channels[0],
+            out_ch[0],
+            3,
+            padding=1,
+            param_attr=conv_param_0,
+            bias_attr=False,
+            act=None,
+            use_cudnn=use_cudnn)
+        self.bn_0_layer = BatchNorm(out_ch[0], act=act, is_test=is_test)
+        self.conv_1_layer = Conv2D(
+            out_ch[0],
+            num_filters=out_ch[1],
+            filter_size=3,
+            padding=1,
+            param_attr=conv_param_1,
+            bias_attr=False,
+            act=None,
+            use_cudnn=use_cudnn)
+        self.bn_1_layer = BatchNorm(out_ch[1], act=act, is_test=is_test)
+
+        if self.pool:
+            self.pool_layer = Pool2D(
+                pool_size=2,
+                pool_type='max',
+                pool_stride=2,
+                use_cudnn=use_cudnn,
+                ceil_mode=True)
+
+    def forward(self, inputs):
+        conv_0 = self.conv_0_layer(inputs)
+        bn_0 = self.bn_0_layer(conv_0)
+        conv_1 = self.conv_1_layer(bn_0)
+        bn_1 = self.bn_1_layer(conv_1)
+        if self.pool:
+            bn_pool = self.pool_layer(bn_1)
+            return bn_pool
+        return bn_1
+
+
+class OCRConv(fluid.dygraph.Layer):
+    def __init__(self, is_test=False, use_cudnn=True):
+        super(OCRConv, self).__init__()
+        self.conv_bn_pool_1 = ConvBNPool(
+            [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn)
+        self.conv_bn_pool_2 = ConvBNPool(
+            [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn)
+        self.conv_bn_pool_3 = ConvBNPool(
+            [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn)
+        self.conv_bn_pool_4 = ConvBNPool(
+            [128, 128], [64, 128],
+            is_test=is_test,
+            pool=False,
+            use_cudnn=use_cudnn)
+
+    def forward(self, inputs):
+        inputs_1 = self.conv_bn_pool_1(inputs)
+        inputs_2 = self.conv_bn_pool_2(inputs_1)
+        inputs_3 = self.conv_bn_pool_3(inputs_2)
+        inputs_4 = self.conv_bn_pool_4(inputs_3)
+
+        return inputs_4
+
+
+class DynamicGRU(fluid.dygraph.Layer):
+    def __init__(self,
+                 size,
+                 param_attr=None,
+                 bias_attr=None,
+                 is_reverse=False,
+                 gate_activation='sigmoid',
+                 candidate_activation='tanh',
+                 h_0=None,
+                 origin_mode=False,
+                 init_size=None):
+        super(DynamicGRU, self).__init__()
+
+        self.gru_unit = GRUUnit(
+            size * 3,
+            param_attr=param_attr,
+            bias_attr=bias_attr,
+            activation=candidate_activation,
+            gate_activation=gate_activation,
+            origin_mode=origin_mode)
+
+        self.size = size
+        self.h_0 = h_0
+        self.is_reverse = is_reverse
+
+    def forward(self, inputs):
+        hidden = self.h_0
+        res = []
+
+        for i in range(inputs.shape[1]):
+            if self.is_reverse:
+                i = inputs.shape[1] - 1 - i
+
+            input_ = inputs[:, i:i + 1, :]
+            input_ = fluid.layers.reshape(
+                input_, [-1, input_.shape[2]], inplace=False)
+            hidden, reset, gate = self.gru_unit(input_, hidden)
+            hidden_ = fluid.layers.reshape(
+                hidden, [-1, 1, hidden.shape[1]], inplace=False)
+            res.append(hidden_)
+
+        if self.is_reverse:
+            res = res[::-1]
+        res = fluid.layers.concat(res, axis=1)
+        return res
+
+
+class EncoderNet(fluid.dygraph.Layer):
+    def __init__(self,
+                 batch_size,
+                 decoder_size,
+                 rnn_hidden_size=200,
+                 is_test=False,
+                 use_cudnn=True):
+        super(EncoderNet, self).__init__()
+        self.rnn_hidden_size = rnn_hidden_size
+        para_attr = fluid.ParamAttr(
+            initializer=fluid.initializer.Normal(0.0, 0.02))
+        bias_attr = fluid.ParamAttr(
+            initializer=fluid.initializer.Normal(0.0, 0.02),
+            learning_rate=2.0)
+        if fluid.framework.in_dygraph_mode():
+            h_0 = np.zeros((batch_size, rnn_hidden_size), dtype="float32")
+            h_0 = to_variable(h_0)
+        else:
+            h_0 = fluid.layers.fill_constant(
+                shape=[batch_size, rnn_hidden_size],
+                dtype='float32',
+                value=0)
+        self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)
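+
+        # The input width 768 comes from the conv stack: three 2x2 poolings
+        # reduce the 48-pixel image height to 6, and the last block emits
+        # 128 channels, so each horizontal step carries 6 * 128 = 768 values.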
+        self.fc_1_layer = Linear(
+            768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
+        self.fc_2_layer = Linear(
+            768, rnn_hidden_size * 3, param_attr=para_attr, bias_attr=False)
+        self.gru_forward_layer = DynamicGRU(
+            size=rnn_hidden_size,
+            h_0=h_0,
+            param_attr=para_attr,
+            bias_attr=bias_attr,
+            candidate_activation='relu')
+        self.gru_backward_layer = DynamicGRU(
+            size=rnn_hidden_size,
+            h_0=h_0,
+            param_attr=para_attr,
+            bias_attr=bias_attr,
+            candidate_activation='relu',
+            is_reverse=True)
+
+        self.encoded_proj_fc = Linear(
+            rnn_hidden_size * 2, decoder_size, bias_attr=False)
+
+    def forward(self, inputs):
+        conv_features = self.ocr_convs(inputs)
+        transpose_conv_features = fluid.layers.transpose(
+            conv_features, perm=[0, 3, 1, 2])
+        sliced_feature = fluid.layers.reshape(
+            transpose_conv_features,
+            [-1, transpose_conv_features.shape[1],
+             transpose_conv_features.shape[2] * transpose_conv_features.shape[3]],
+            inplace=False)
+
+        fc_1 = self.fc_1_layer(sliced_feature)
+        fc_2 = self.fc_2_layer(sliced_feature)
+
+        gru_forward = self.gru_forward_layer(fc_1)
+        gru_backward = self.gru_backward_layer(fc_2)
+
+        encoded_vector = fluid.layers.concat(
+            input=[gru_forward, gru_backward], axis=2)
+        encoded_proj = self.encoded_proj_fc(encoded_vector)
+
+        return gru_backward, encoded_vector, encoded_proj
+
+
+class SimpleAttention(fluid.dygraph.Layer):
+    def __init__(self, decoder_size):
+        super(SimpleAttention, self).__init__()
+
+        self.fc_1 = Linear(
+            decoder_size, decoder_size, act=None, bias_attr=False)
+        self.fc_2 = Linear(decoder_size, 1, act=None, bias_attr=False)
+
+    def forward(self, encoder_vec, encoder_proj, decoder_state):
+        decoder_state_fc = self.fc_1(decoder_state)
+        decoder_state_proj_reshape = fluid.layers.reshape(
+            decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]],
+            inplace=False)
+        decoder_state_expand = fluid.layers.expand(
+            decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
+        concated = fluid.layers.elementwise_add(encoder_proj,
+                                                decoder_state_expand)
+        concated = fluid.layers.tanh(x=concated)
+        attention_weight = self.fc_2(concated)
+        weights_reshape = fluid.layers.reshape(
+            x=attention_weight, shape=[concated.shape[0], -1], inplace=False)
+        weights_reshape = fluid.layers.softmax(weights_reshape)
+        scaled = fluid.layers.elementwise_mul(
+            x=encoder_vec, y=weights_reshape, axis=0)
+        context = fluid.layers.reduce_sum(scaled, dim=1)
+
+        return context
+
+
+class GRUDecoderWithAttention(fluid.dygraph.Layer):
+    def __init__(self, encoder_size, decoder_size, num_classes):
+        super(GRUDecoderWithAttention, self).__init__()
+        self.simple_attention = SimpleAttention(decoder_size)
+
+        self.fc_1_layer = Linear(
+            input_dim=encoder_size * 2,
+            output_dim=decoder_size * 3,
+            bias_attr=False)
+        self.fc_2_layer = Linear(
+            input_dim=decoder_size,
+            output_dim=decoder_size * 3,
+            bias_attr=False)
+        self.gru_unit = GRUUnit(
+            size=decoder_size * 3, param_attr=None, bias_attr=None)
+        self.out_layer = Linear(
+            input_dim=decoder_size,
+            output_dim=num_classes + 2,
+            bias_attr=None,
+            act='softmax')
+
+        self.decoder_size = decoder_size
+
+    def forward(self, current_word, encoder_vec, encoder_proj,
+                decoder_boot, inference=False):
+        current_word = fluid.layers.reshape(
+            current_word, [-1, current_word.shape[2]], inplace=False)
+
+        context = self.simple_attention(encoder_vec, encoder_proj,
+                                        decoder_boot)
+        fc_1 = self.fc_1_layer(context)
+        fc_2 = self.fc_2_layer(current_word)
+        decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
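+
+        # GRUUnit expects gate inputs of width decoder_size * 3; the attention
+        # context and the current-word embedding are each projected to that
+        # width and summed before the recurrent update.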
+        h, _, _ = self.gru_unit(decoder_inputs, decoder_boot)
+        out = self.out_layer(h)
+
+        return out, h
+
+
+class OCRAttention(fluid.dygraph.Layer):
+    def __init__(self, batch_size, num_classes, encoder_size, decoder_size,
+                 word_vector_dim):
+        super(OCRAttention, self).__init__()
+        self.encoder_net = EncoderNet(batch_size, decoder_size)
+        self.fc = Linear(
+            input_dim=encoder_size,
+            output_dim=decoder_size,
+            bias_attr=False,
+            act='relu')
+        self.embedding = Embedding(
+            [num_classes + 2, word_vector_dim], dtype='float32')
+        self.gru_decoder_with_attention = GRUDecoderWithAttention(
+            encoder_size, decoder_size, num_classes)
+        self.batch_size = batch_size
+
+    def forward(self, inputs, label_in):
+        gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
+        backward_first = fluid.layers.slice(
+            gru_backward, axes=[1], starts=[0], ends=[1])
+        backward_first = fluid.layers.reshape(
+            backward_first, [-1, backward_first.shape[2]], inplace=False)
+
+        decoder_boot = self.fc(backward_first)
+
+        label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
+        trg_embedding = self.embedding(label_in)
+        trg_embedding = fluid.layers.reshape(
+            trg_embedding, [self.batch_size, -1, trg_embedding.shape[1]],
+            inplace=False)
+
+        pred_temp = []
+        for i in range(trg_embedding.shape[1]):
+            current_word = fluid.layers.slice(
+                trg_embedding, axes=[1], starts=[i], ends=[i + 1])
+            out, decoder_boot = self.gru_decoder_with_attention(
+                current_word, encoded_vector, encoded_proj, decoder_boot)
+            pred_temp.append(out)
+        pred_temp = fluid.layers.concat(pred_temp, axis=1)
+
+        batch_size = trg_embedding.shape[0]
+        seq_len = trg_embedding.shape[1]
+        prediction = fluid.layers.reshape(
+            pred_temp, shape=[batch_size, seq_len, -1])
+
+        return prediction

diff --git a/dygraph/ocr_recognition/train.py b/dygraph/ocr_recognition/train.py
index 821c5b92..6e5792a1 100644
--- a/dygraph/ocr_recognition/train.py
+++ b/dygraph/ocr_recognition/train.py
@@ -13,422 +13,48 @@
 # limitations under the License.
 from __future__ import print_function
 
-import sys
+import os
 
-import numpy as np
 import paddle.fluid.profiler as profiler
 import paddle.fluid as fluid
-import paddle.fluid.layers as layers
+
 import data_reader
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
+
 from paddle.fluid.dygraph.base import to_variable
 import argparse
 import functools
 from utility import add_arguments, print_arguments, get_attention_feeder_data
-import time
-from paddle.fluid import framework
+from nets import OCRAttention
+from eval import evaluate
 
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
 add_arg('batch_size', int, 32, "Minibatch size.")
-add_arg('total_step', int, 720000, "The number of iterations. Zero or less means whole training set. More than 0 means the training set might be looped until # of iterations is reached.")
-add_arg('log_period', int, 1000, "Log period.")
-add_arg('save_model_period', int, 15000, "Save model period. '-1' means never saving the model.")
-add_arg('eval_period', int, 15000, "Evaluate period. '-1' means never evaluating the model.")
-add_arg('save_model_dir', str, "./models", "The directory the model to be saved to.")
+add_arg('epoch_num', int, 30, "Epoch number.")
+add_arg('lr', float, 0.001, "Learning rate.")
+add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
+add_arg('log_period', int, 200, "Log period.")
+add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
+add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
+add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
 add_arg('train_images', str, None, "The directory of images to be used for training.")
 add_arg('train_list', str, None, "The list file of images to be used for training.")
 add_arg('test_images', str, None, "The directory of images to be used for test.")
 add_arg('test_list', str, None, "The list file of images to be used for training.")
 add_arg('init_model', str, None, "The init model file of directory.")
 add_arg('use_gpu', bool, True, "Whether use GPU to train.")
-add_arg('min_average_window',int, 10000, "Min average window.")
-add_arg('max_average_window',int, 12500, "Max average window. It is proposed to be set as the number of minibatch in a pass.")
-add_arg('average_window', float, 0.15, "Average window.")
 add_arg('parallel', bool, False, "Whether use parallel training.")
 add_arg('profile', bool, False, "Whether to use profiling.")
 add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
 add_arg('skip_test', bool, False, "Whether to skip test phase.")
-
-
-class Config(object):
-    '''
-    config for training
-    '''
-    # encoder rnn hidden_size
-    encoder_size = 200
-    # decoder size for decoder stage
-    decoder_size = 128
-    # size for word embedding
-    word_vector_dim = 128
-    # max length for label padding
-    max_length = 100
-    gradient_clip = 10
-    LR = 1.0
-    beam_size = 2
-    learning_rate_decay = None
-
-    # batch size to train
-    batch_size = 32
-    # class number to classify
-    num_classes = 95
-
-    use_gpu = False
-    # special label for start and end
-    SOS = 0
-    EOS = 1
-
-    # data shape for input image
-    DATA_SHAPE = [1, 48, 512]
-
-
-class ConvBNPool(fluid.dygraph.Layer):
-    def __init__(self,
-                 group,
-                 out_ch,
-                 channels,
-                 act="relu",
-                 is_test=False,
-                 pool=True,
-                 use_cudnn=True):
-        super(ConvBNPool, self).__init__()
-        self.group = group
-        self.pool = pool
-
-        filter_size = 3
-        conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
-        conv_param_0 = fluid.ParamAttr(
-            initializer=fluid.initializer.Normal(0.0, conv_std_0))
-
-        conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
-        conv_param_1 = fluid.ParamAttr(
-            initializer=fluid.initializer.Normal(0.0, conv_std_1))
-
-        self.conv_0_layer = Conv2D(
-            channels[0],
-            out_ch[0],
-            3,
-            padding=1,
-            param_attr=conv_param_0,
-            bias_attr=False,
-            act=None,
-            use_cudnn=use_cudnn)
-        self.bn_0_layer = BatchNorm(
-            out_ch[0], act=act, is_test=is_test)
-        self.conv_1_layer = Conv2D(
-            out_ch[0],
-            num_filters=out_ch[1],
-            filter_size=3,
-            padding=1,
-            param_attr=conv_param_1,
-            bias_attr=False,
-            act=None,
-            use_cudnn=use_cudnn)
-        self.bn_1_layer = BatchNorm(
-            out_ch[1], act=act, is_test=is_test)
-
-        if self.pool:
-            self.pool_layer = Pool2D(
-                pool_size=2,
-                pool_type='max',
-                pool_stride=2,
-                use_cudnn=use_cudnn,
-                ceil_mode=True)
-
-    def forward(self, inputs):
-        conv_0 = self.conv_0_layer(inputs)
-        bn_0 = self.bn_0_layer(conv_0)
-        conv_1 = self.conv_1_layer(bn_0)
-        bn_1 = self.bn_1_layer(conv_1)
-        if self.pool:
-            bn_pool = self.pool_layer(bn_1)
-            return bn_pool
-        return bn_1
-
-
-class OCRConv(fluid.dygraph.Layer):
-    def __init__(self, is_test=False, use_cudnn=True):
-        super(OCRConv, self).__init__()
-        self.conv_bn_pool_1 = ConvBNPool(
-            2, [16, 16], [1, 16],
-            is_test=is_test,
-            use_cudnn=use_cudnn)
-        self.conv_bn_pool_2 = ConvBNPool(
-            2, [32, 32], [16, 32],
-            is_test=is_test,
-            use_cudnn=use_cudnn)
-        self.conv_bn_pool_3 = ConvBNPool(
-            2, [64, 64], [32, 64],
-            is_test=is_test,
-            use_cudnn=use_cudnn)
-        self.conv_bn_pool_4 = ConvBNPool(
-            2, [128, 128], [64, 128],
-            is_test=is_test,
-            pool=False,
-            use_cudnn=use_cudnn)
-
-    def forward(self, inputs):
-        inputs_1 = self.conv_bn_pool_1(inputs)
-        inputs_2 = self.conv_bn_pool_2(inputs_1)
-        inputs_3 = self.conv_bn_pool_3(inputs_2)
-        inputs_4 = self.conv_bn_pool_4(inputs_3)
-
-        return inputs_4
-
-
-class DynamicGRU(fluid.dygraph.Layer):
-    def __init__(self,
-                 size,
-                 param_attr=None,
-                 bias_attr=None,
-                 is_reverse=False,
-                 gate_activation='sigmoid',
-                 candidate_activation='tanh',
-                 h_0=None,
-                 origin_mode=False,
-                 init_size = None):
-        super(DynamicGRU, self).__init__()
-
-        self.gru_unit = GRUUnit(
-            size * 3,
-            param_attr=param_attr,
-            bias_attr=bias_attr,
-            activation=candidate_activation,
-            gate_activation=gate_activation,
-            origin_mode=origin_mode)
-
-        self.size = size
-        self.h_0 = h_0
-        self.is_reverse = is_reverse
-
-    def forward(self, inputs):
-        hidden = self.h_0
-        res = []
-
-        for i in range(inputs.shape[1]):
-            if self.is_reverse:
-                i = inputs.shape[1] - 1 - i
-
-            input_ = inputs[ :, i:i+1, :]
-            input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
-            hidden, reset, gate = self.gru_unit(input_, hidden)
-            hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
-            res.append(hidden_)
-
-        if self.is_reverse:
-            res = res[::-1]
-        res = fluid.layers.concat(res, axis=1)
-        return res
-
-
-class EncoderNet(fluid.dygraph.Layer):
-    def __init__(self,
-                 rnn_hidden_size=Config.encoder_size,
-                 is_test=False,
-                 use_cudnn=True):
-        super(EncoderNet, self).__init__()
-        self.rnn_hidden_size = rnn_hidden_size
-        para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0,
-                                                                         0.02))
-        bias_attr = fluid.ParamAttr(
-            initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
-        if fluid.framework.in_dygraph_mode():
-            h_0 = np.zeros(
-                (Config.batch_size, rnn_hidden_size), dtype="float32")
-            h_0 = to_variable(h_0)
-        else:
-            h_0 = fluid.layers.fill_constant(
-                shape=[Config.batch_size, rnn_hidden_size],
-                dtype='float32',
-                value=0)
-        self.ocr_convs = OCRConv(
-            is_test=is_test, use_cudnn=use_cudnn)
-
-        self.fc_1_layer = Linear( 768,
-                      rnn_hidden_size * 3,
-                      param_attr=para_attr,
-                      bias_attr=False )
-        print( "weight", self.fc_1_layer.weight.shape )
-        self.fc_2_layer = Linear( 768,
-                      rnn_hidden_size * 3,
-                      param_attr=para_attr,
-                      bias_attr=False )
-        self.gru_forward_layer = DynamicGRU(
-            size=rnn_hidden_size,
-            h_0=h_0,
-            param_attr=para_attr,
-            bias_attr=bias_attr,
-            candidate_activation='relu')
-        self.gru_backward_layer = DynamicGRU(
-            size=rnn_hidden_size,
-            h_0=h_0,
-            param_attr=para_attr,
-            bias_attr=bias_attr,
-            candidate_activation='relu',
-            is_reverse=True)
-
-        self.encoded_proj_fc = Linear( rnn_hidden_size * 2,
-                         Config.decoder_size,
-                         bias_attr=False )
-
-    def forward(self, inputs):
-        conv_features = self.ocr_convs(inputs)
-        transpose_conv_features = fluid.layers.transpose(conv_features, perm=[0,3,1,2])
-        sliced_feature = fluid.layers.reshape(
-            transpose_conv_features,
-            [-1, transpose_conv_features.shape[1],
-             transpose_conv_features.shape[2] * transpose_conv_features.shape[3]],
-            inplace=False)
-
-        fc_1 = self.fc_1_layer(sliced_feature)
-
-        fc_2 = self.fc_2_layer(sliced_feature)
-
-        gru_forward = self.gru_forward_layer(fc_1)
-
-        gru_backward = self.gru_backward_layer(fc_2)
-
-        encoded_vector = fluid.layers.concat(
-            input=[gru_forward, gru_backward], axis=2)
-
-        encoded_proj = self.encoded_proj_fc(encoded_vector)
-
-        return gru_backward, encoded_vector, encoded_proj
-
-
-class SimpleAttention(fluid.dygraph.Layer):
-    def __init__(self, decoder_size):
-        super(SimpleAttention, self).__init__()
-
-        self.fc_1 = Linear( decoder_size,
-                    decoder_size,
-                    act=None,
-                    bias_attr=False)
-        self.fc_2 = Linear( decoder_size,
-                    1,
-                    act=None,
-                    bias_attr=False)
-
-    def forward(self, encoder_vec, encoder_proj, decoder_state):
-
-        decoder_state_fc = self.fc_1(decoder_state)
-
-        decoder_state_proj_reshape = fluid.layers.reshape(
-            decoder_state_fc, [-1, 1, decoder_state_fc.shape[1]], inplace=False)
-        decoder_state_expand = fluid.layers.expand(
-            decoder_state_proj_reshape, [1, encoder_proj.shape[1], 1])
-        concated = fluid.layers.elementwise_add(encoder_proj,
-                                                decoder_state_expand)
-        concated = fluid.layers.tanh(x=concated)
-        attention_weight = self.fc_2(concated)
-        weights_reshape = fluid.layers.reshape(
-            x=attention_weight, shape=[ concated.shape[0], -1], inplace=False)
-
-        weights_reshape = fluid.layers.softmax( weights_reshape )
-        scaled = fluid.layers.elementwise_mul(
-            x=encoder_vec, y=weights_reshape, axis=0)
-
-        context = fluid.layers.reduce_sum(scaled, dim=1)
-
-        return context
-
-
-class GRUDecoderWithAttention(fluid.dygraph.Layer):
-    def __init__(self, decoder_size, num_classes):
-        super(GRUDecoderWithAttention, self).__init__()
-        self.simple_attention = SimpleAttention(decoder_size)
-
-        self.fc_1_layer = Linear( input_dim = Config.encoder_size * 2,
-                      output_dim=decoder_size * 3,
-                      bias_attr=False)
-        self.fc_2_layer = Linear( input_dim = decoder_size,
-                      output_dim=decoder_size * 3,
-                      bias_attr=False)
-        self.gru_unit = GRUUnit(
-            size=decoder_size * 3,
-            param_attr=None,
-            bias_attr=None)
-        self.out_layer = Linear( input_dim = decoder_size,
-                     output_dim =num_classes + 2,
-                     bias_attr=None,
-                     act='softmax')
-
-        self.decoder_size = decoder_size
-
-    def forward(self, target_embedding, encoder_vec, encoder_proj,
-                decoder_boot):
-        res = []
-        hidden_mem = decoder_boot
-        for i in range(target_embedding.shape[1]):
-            current_word = fluid.layers.slice(
-                target_embedding, axes=[1], starts=[i], ends=[i + 1])
-            current_word = fluid.layers.reshape(
-                current_word, [-1, current_word.shape[2]], inplace=False)
-
-            context = self.simple_attention(encoder_vec, encoder_proj,
-                                            hidden_mem)
-            fc_1 = self.fc_1_layer(context)
-            fc_2 = self.fc_2_layer(current_word)
-            decoder_inputs = fluid.layers.elementwise_add(x=fc_1, y=fc_2)
-
-            h, _, _ = self.gru_unit(decoder_inputs, hidden_mem)
-            hidden_mem = h
-            out = self.out_layer(h)
-
-            res.append(out)
-
-        res1 = fluid.layers.concat(res, axis=1)
-
-        batch_size = target_embedding.shape[0]
-        seq_len = target_embedding.shape[1]
-        res1 = layers.reshape( res1, shape=[batch_size, seq_len, -1])
-
-        return res1
-
-
-class OCRAttention(fluid.dygraph.Layer):
-    def __init__(self):
-        super(OCRAttention, self).__init__()
-        self.encoder_net = EncoderNet()
-        self.fc = Linear( input_dim = Config.encoder_size,
-                  output_dim =Config.decoder_size,
-                  bias_attr=False,
-                  act='relu')
-        self.embedding = Embedding(
-            [Config.num_classes + 2, Config.word_vector_dim],
-            dtype='float32')
-        self.gru_decoder_with_attention = GRUDecoderWithAttention(
-            Config.decoder_size, Config.num_classes)
-
-    def forward(self, inputs, label_in):
-        gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
-        backward_first = fluid.layers.slice(
-            gru_backward, axes=[1], starts=[0], ends=[1])
-        backward_first = fluid.layers.reshape(
-            backward_first, [-1, backward_first.shape[2]], inplace=False)
-
-        decoder_boot = self.fc(backward_first)
-
-        label_in = fluid.layers.reshape(label_in, [-1], inplace=False)
-        trg_embedding = self.embedding(label_in)
-
-        trg_embedding = fluid.layers.reshape(
-            trg_embedding, [Config.batch_size, -1, trg_embedding.shape[1]],
-            inplace=False)
-
-        prediction = self.gru_decoder_with_attention(
-            trg_embedding, encoded_vector, encoded_proj, decoder_boot)
-
-        return prediction
+# model hyper parameters
+add_arg('encoder_size', int, 200, "Encoder size.")
+add_arg('decoder_size', int, 128, "Decoder size.")
+add_arg('word_vector_dim', int, 128, "Word vector dim.")
+add_arg('num_classes', int, 95, "Number of classes.")
+add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
 
 
 def train(args):
@@ -436,74 +62,41 @@ def train(args):
     with fluid.dygraph.guard():
         backward_strategy = fluid.dygraph.BackwardStrategy()
         backward_strategy.sort_sum_gradient = True
-        ocr_attention = OCRAttention()
-        if Config.learning_rate_decay == "piecewise_decay":
-            learning_rate = fluid.layers.piecewise_decay(
-                [50000], [Config.LR, Config.LR * 0.01])
+        ocr_attention = OCRAttention(batch_size=args.batch_size,
+            encoder_size=args.encoder_size, decoder_size=args.decoder_size,
+            num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
+
+        LR = args.lr
+        if args.lr_decay_strategy == "piecewise_decay":
+            learning_rate = fluid.layers.piecewise_decay([200000, 250000], [LR, LR * 0.1, LR * 0.01])
         else:
-            learning_rate = Config.LR
-        optimizer = fluid.optimizer.Adam(learning_rate=0.001, parameter_list=ocr_attention.parameters())
-        dy_param_init_value = {}
+            learning_rate = LR
 
-        grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0 )
+        optimizer = fluid.optimizer.Adam(learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
+        grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)
 
         train_reader = data_reader.data_reader(
-            Config.batch_size,
-            cycle=args.total_step > 0,
+            args.batch_size,
             shuffle=True,
+            images_dir=args.train_images,
+            list_file=args.train_list,
             data_type='train')
 
-        infer_image= './data/data/test_images/'
-        infer_files = './data/data/test.list'
         test_reader = data_reader.data_reader(
-            Config.batch_size,
-            cycle=False,
+            args.batch_size,
+            images_dir=args.test_images,
+            list_file=args.test_list,
             data_type="test")
-        def eval():
-            ocr_attention.eval()
-            total_loss = 0.0
-            total_step = 0.0
-            equal_size = 0
-            for data in test_reader():
-                data_dict = get_attention_feeder_data(data)
-
-                label_in = to_variable(data_dict["label_in"])
-                label_out = to_variable(data_dict["label_out"])
-
-                label_out.stop_gradient = True
-
-                img = to_variable(data_dict["pixel"])
-
-                prediction = ocr_attention(img, label_in)
-                prediction = fluid.layers.reshape( prediction, [label_out.shape[0] * label_out.shape[1], -1], inplace=False)
-
-                score, topk = layers.topk( prediction, 1)
-
-                seq = topk.numpy()
-
-                seq = seq.reshape( ( args.batch_size, -1))
-
-                mask = data_dict['mask'].reshape( (args.batch_size, -1))
-                seq_len = np.sum( mask, -1)
-
-                trans_ref = data_dict["label_out"].reshape( (args.batch_size, -1))
-                for i in range( args.batch_size ):
-                    length = int(seq_len[i] -1 )
-                    trans = seq[i][:length - 1]
-                    ref = trans_ref[i][ : length - 1]
-                    if np.array_equal( trans, ref ):
-                        equal_size += 1
-
-                total_step += args.batch_size
-            print( "eval cost", equal_size / total_step )
+        if not os.path.exists(args.save_model_dir):
+            os.makedirs(args.save_model_dir)
 
         total_step = 0
-        epoch_num = 20
+        epoch_num = args.epoch_num
         for epoch in range(epoch_num):
             batch_id = 0
-            total_loss = 0.0
+
             for data in train_reader():
 
                 total_step += 1
@@ -524,7 +117,7 @@ def train(args):
 
                 mask = to_variable(data_dict["mask"])
 
-                loss = layers.elementwise_mul( loss, mask, axis=0)
+                loss = fluid.layers.elementwise_mul( loss, mask, axis=0)
                 avg_loss = fluid.layers.reduce_sum(loss)
 
                 total_loss += avg_loss.numpy()
@@ -532,21 +125,24 @@ def train(args):
                 optimizer.minimize(avg_loss, grad_clip=grad_clip)
                 ocr_attention.clear_gradients()
 
-                if batch_id > 0 and batch_id % 1000 == 0:
-                    print("epoch: {}, batch_id: {}, loss {}".format(epoch, batch_id, total_loss / args.batch_size / 1000))
+                if batch_id > 0 and batch_id % args.log_period == 0:
+                    print("epoch: {}, batch_id: {}, lr: {}, loss {}".format(epoch, batch_id,
+                        optimizer._global_learning_rate().numpy(),
+                        total_loss / args.batch_size / args.log_period))
                     total_loss = 0.0
 
-                if total_step > 0 and total_step % 2000 == 0:
+                if total_step > 0 and total_step % args.save_model_period == 0:
+                    if fluid.dygraph.parallel.Env().dev_id == 0:
+                        model_file = os.path.join(args.save_model_dir, 'step_{}'.format(total_step))
+                        fluid.save_dygraph(ocr_attention.state_dict(), model_file)
+                        print('step_{}.pdparams saved!'.format(total_step))
+
+                if total_step > 0 and total_step % args.eval_period == 0:
                     ocr_attention.eval()
-                    eval()
+                    evaluate(ocr_attention, test_reader, args.batch_size)
                     ocr_attention.train()
 
-                batch_id +=1
-
-
-
-
+                batch_id += 1
 
 
 if __name__ == '__main__':
-- 
GitLab
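Note: the `_shuffle_data` helper introduced in `data_reader.py` replaces an awk/shuf shell pipeline with pure Python. The following standalone sketch (illustrative only; `batch_shuffle` is a hypothetical helper, not part of this patch, and unlike the patch it keeps any final partial batch) shows the same three-step strategy for whitespace-separated records whose first field is a numeric id:

```python
import random

def batch_shuffle(lines, batch_size):
    # 1. Partial shuffle: key each line by (first field, random value), sort,
    #    then drop a random prefix so successive epochs start at different offsets.
    keyed = sorted("%04d%.4f %s" % (int(l.split()[0]), random.random(), l)
                   for l in lines)
    keyed = keyed[random.randint(1, 100):]
    # 2. Batch merge and shuffle: group consecutive lines into batches and
    #    shuffle whole batches, keeping each batch's samples together.
    batches = [keyed[i:i + batch_size]
               for i in range(0, len(keyed), batch_size)]
    random.shuffle(batches)
    # 3. Batch split: flatten back to individual lines, stripping the sort key.
    return [line.split(' ', 1)[1] for batch in batches for line in batch]
```

Keeping a minibatch's samples together while randomizing batch order is what lets the reader pad each batch to a common label length without reshuffling samples of very different lengths into the same batch every epoch.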