diff --git a/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml b/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml new file mode 100755 index 0000000000000000000000000000000000000000..933a75133563727cd7eddc3914d4dcfb41a09e32 --- /dev/null +++ b/configs/rec/rec_r50fpn_vd_none_srn_pvam_test_all.yml @@ -0,0 +1,48 @@ +Global: + algorithm: SRN + use_gpu: true + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: output/rec_pvam_withrotate + save_epoch_step: 1 + eval_batch_step: 8000 + train_batch_size_per_card: 64 + test_batch_size_per_card: 1 + image_shape: [1, 64, 256] + max_text_length: 25 + character_type: en + loss_type: srn + num_heads: 8 + average_window: 0.15 + max_average_window: 15625 + min_average_window: 10000 + reader_yml: ./configs/rec/rec_srn_reader.yml + pretrain_weights: + checkpoints: + save_inference_dir: + +Architecture: + function: ppocr.modeling.architectures.rec_model,RecModel + +Backbone: + function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet + layers: 50 + +Head: + function: ppocr.modeling.heads.rec_srn_all_head,SRNPredict + encoder_type: rnn + num_encoder_TUs: 2 + num_decoder_TUs: 4 + hidden_dims: 512 + SeqRNN: + hidden_size: 256 + +Loss: + function: ppocr.modeling.losses.rec_srn_loss,SRNLoss + +Optimizer: + function: ppocr.optimizer,AdamDecay + base_lr: 0.0001 + beta1: 0.9 + beta2: 0.999 diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index ec3e9d867ce659b729c021c3a02acead73cacf52..7135fca55336283b175d2ccfbdf04c3b76e0f62a 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger from ppocr.utils.utility import get_image_file_list logger = initial_logger() -from .img_tools import process_image, get_img_data +from .img_tools import process_image, process_image_srn, get_img_data class LMDBReader(object): @@ -40,6 +40,7 @@ class LMDBReader(object): self.image_shape = params['image_shape'] self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] + self.num_heads = params['num_heads'] self.mode = params['mode'] self.drop_last = False self.use_tps = False @@ -117,14 +118,36 @@ class LMDBReader(object): image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1] == 1 or len(list(img.shape)) == 2: + if img.shape[-1]==1 or len(list(img.shape))==2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + if self.loss_type == 'srn': + norm_img = process_image_srn( + img=img, + image_shape=self.image_shape, + num_heads=self.num_heads, + max_text_length=self.max_text_length + ) + else: + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops, + tps=self.use_tps, + infer_mode=True) + yield norm_img + elif self.mode == 'test': + image_file_list = get_image_file_list(self.infer_img) + for single_img in image_file_list: + img = cv2.imread(single_img) + if img.shape[-1]==1 or len(list(img.shape))==2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image( img=img, image_shape=self.image_shape, char_ops=self.char_ops, tps=self.use_tps, - infer_mode=True) + infer_mode=True + ) yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() @@ -144,14 +167,16 @@ class LMDBReader(object): if sample_info is None: continue img, label = sample_info - outs = process_image( - img=img, - image_shape=self.image_shape, - label=label, - char_ops=self.char_ops, - loss_type=self.loss_type, - max_text_length=self.max_text_length, - distort=self.use_distort) + outs = [] + if self.loss_type == "srn": + outs = process_image_srn(img, self.image_shape, self.num_heads, + self.max_text_length, label, + self.char_ops, self.loss_type) + + else: + outs = process_image(img, self.image_shape, label, + self.char_ops, self.loss_type, + self.max_text_length) if outs is None: continue yield outs @@ -159,7 +184,6 @@ class LMDBReader(object): if finish_read_num == len(lmdb_sets): break self.close_lmdb_dataset(lmdb_sets) - def batch_iter_reader(): batch_outs = [] for outs in sample_iter_reader(): @@ -167,9 +191,8 @@ class LMDBReader(object): if len(batch_outs) == self.batch_size: yield batch_outs batch_outs = [] - if not self.drop_last: - if len(batch_outs) != 0: - yield batch_outs + if len(batch_outs) != 0: + yield batch_outs if self.infer_img is None: return batch_iter_reader @@ -288,4 +311,4 @@ class SimpleReader(object): if self.infer_img is None: return batch_iter_reader - return sample_iter_reader + return sample_iter_reader \ No newline at end of file diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index 0835603b5896a4f7ab946c4a694dcbec3f853a54..527e0266ee33ac81e29b5610ed05f401860078a4 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -381,3 +381,84 @@ def process_image(img, assert False, "Unsupport loss_type %s in process_image"\ % loss_type return (norm_img) + +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + +def srn_other_inputs(image_shape, + num_heads, + max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64') + + lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] + +def process_image_srn(img, + image_shape, + num_heads, + max_text_length, + label=None, + char_ops=None, + loss_type=None): + norm_img = resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + srn_other_inputs(image_shape, num_heads, max_text_length) + + if label is not None: + char_num = char_ops.get_char_num() + text = char_ops.encode(label) + if len(text) == 0 or len(text) > max_text_length: + return None + else: + if loss_type == "srn": + text_padded = [37] * max_text_length + for i in range(len(text)): + text_padded[i] = text[i] + lbl_weight[i] = [1.0] + text_padded = np.array(text_padded) + text = text_padded.reshape(-1, 1) + return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight) + else: + assert False, "Unsupport loss_type %s in process_image"\ + % loss_type + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index e80a50ab6504c80aa7f10759576208486caf7c3f..a030f362ee8a767658606a151112796fc490a4aa 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -58,6 +58,7 @@ class RecModel(object): self.loss_type = global_params['loss_type'] self.image_shape = global_params['image_shape'] self.max_text_length = global_params['max_text_length'] + self.num_heads = global_params["num_heads"] def create_feed(self, mode): image_shape = deepcopy(self.image_shape) @@ -77,6 +78,18 @@ class RecModel(object): lod_level=1) feed_list = [image, label_in, label_out] labels = {'label_in': label_in, 'label_out': label_out} + elif self.loss_type == "srn": + encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64") + gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64") + gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) + gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) + lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64') + label = fluid.data( + name='label', shape=[-1, 1], dtype='int32', lod_level=1) + feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight] + labels = {'label': label, 'encoder_word_pos': encoder_word_pos, + 'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, + 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight} else: label = fluid.data( name='label', shape=[None, 1], dtype='int32', lod_level=1) @@ -88,6 +101,8 @@ class RecModel(object): use_double_buffer=True, iterable=False) else: + labels = None + loader = None if self.char_type == "ch" and self.infer_img: image_shape[-1] = -1 if self.tps != None: @@ -97,9 +112,15 @@ class RecModel(object): "We set img_shape to be the same , it may affect the inference effect" ) image_shape = deepcopy(self.image_shape) - image = fluid.data(name='image', shape=image_shape, dtype='float32') - labels = None - loader = None + image = fluid.data(name='image', shape=image_shape, dtype='float32') + if self.loss_type == "srn": + encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64") + gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64") + gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) + gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length]) + feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] + labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos, + 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2} return image, labels, loader def __call__(self, mode): @@ -117,9 +138,15 @@ class RecModel(object): label = labels['label_out'] else: label = labels['label'] - outputs = {'total_loss':loss, 'decoded_out':\ - decoded_out, 'label':label} + if self.loss_type == 'srn': + total_loss, img_loss, word_loss = self.loss(predicts, labels) + outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss, + 'decoded_out':decoded_out, 'label':label} + else: + outputs = {'total_loss':loss, 'decoded_out':\ + decoded_out, 'label':label} return loader, outputs + elif mode == "export": predict = predicts['predict'] if self.loss_type == "ctc": @@ -129,4 +156,4 @@ class RecModel(object): predict = predicts['predict'] if self.loss_type == "ctc": predict = fluid.layers.softmax(predict) - return loader, {'decoded_out': decoded_out, 'predicts': predict} + return loader, {'decoded_out': decoded_out, 'predicts': predict} \ No newline at end of file diff --git a/ppocr/modeling/backbones/rec_resnet50_fpn.py b/ppocr/modeling/backbones/rec_resnet50_fpn.py new file mode 100755 index 0000000000000000000000000000000000000000..f6d72377fe4e2d3355a4510f070178ad48dd2a27 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet50_fpn.py @@ -0,0 +1,172 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr + + +__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] + +Trainable = True +w_nolr = fluid.ParamAttr( + trainable = Trainable) +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + +class ResNet(): + def __init__(self, params): + self.layers = params['layers'] + self.params = train_parameters + + + def __call__(self, input): + layers = self.layers + supported_layers = [18, 34, 50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + stride_list = [(2,2),(2,2),(1,1),(1,1)] + num_filters = [64, 128, 256, 512] + + conv = self.conv_bn_layer( + input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1") + F = [] + if layers >= 50: + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=stride_list[block] if i == 0 else 1, name=conv_name) + F.append(conv) + + base = F[-1] + for i in [-2, -3]: + b, c, w, h = F[i].shape + if (w,h) == base.shape[2:]: + base = base + else: + base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2, + padding=1,act=None, + param_attr=w_nolr, + bias_attr=w_nolr) + base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr) + base = fluid.layers.concat([base, F[i]], axis=1) + base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr) + base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr) + base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr) + + base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr) + + return base + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size= 2 if stride==(1,1) else filter_size, + dilation = 2 if stride==(1,1) else 1, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights",trainable = Trainable), + bias_attr=False, + name=name + '.conv2d.output.1') + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + return fluid.layers.batch_norm(input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable), + bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) + + def shortcut(self, input, ch_out, stride, is_first, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1 or is_first == True: + if stride == (1,1): + return self.conv_bn_layer(input, ch_out, 1, 1, name=name) + else: #stride == (2,2) + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + + else: + return input + + def bottleneck_block(self, input, num_filters, stride, name): + conv0 = self.conv_bn_layer( + input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + conv2 = self.conv_bn_layer( + input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c") + + short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1") + + return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5") + + def basic_block(self, input, num_filters, stride, is_first, name): + conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride, + name=name + "_branch2a") + conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None, + name=name + "_branch2b") + short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1") + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') diff --git a/ppocr/modeling/backbones/rec_resnet_vd.py b/ppocr/modeling/backbones/rec_resnet_vd.py index bc58c8ac13a108bc61e398aae8447b6fab966504..2c7cd4c7157e59e13858f4f1d14707150643d720 100755 --- a/ppocr/modeling/backbones/rec_resnet_vd.py +++ b/ppocr/modeling/backbones/rec_resnet_vd.py @@ -32,7 +32,7 @@ class ResNet(): def __init__(self, params): self.layers = params['layers'] self.is_3x3 = True - supported_layers = [18, 34, 50, 101, 152, 200] + supported_layers = [18, 34, 50, 101, 152] assert self.layers in supported_layers, \ "supported layers are {} but input layer is {}".format(supported_layers, self.layers) diff --git a/ppocr/modeling/heads/rec_srn_all_head.py b/ppocr/modeling/heads/rec_srn_all_head.py new file mode 100755 index 0000000000000000000000000000000000000000..bf1f4a44f4fec8a9ec4452c8562c49251d2e1b7d --- /dev/null +++ b/ppocr/modeling/heads/rec_srn_all_head.py @@ -0,0 +1,218 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +#from .rec_seq_encoder import SequenceEncoder +#from ..common_functions import get_para_bias_attr +import numpy as np +from .self_attention.model import wrap_encoder +from .self_attention.model import wrap_encoder_forFeature +gradient_clip = 10 + + + +class SRNPredict(object): + def __init__(self, params): + super(SRNPredict, self).__init__() + self.char_num = params['char_num'] + self.max_length = params['max_text_length'] + + self.num_heads = params['num_heads'] + self.num_encoder_TUs = params['num_encoder_TUs'] + self.num_decoder_TUs = params['num_decoder_TUs'] + self.hidden_dims = params['hidden_dims'] + + + def pvam(self, inputs, others): + + b, c, h, w = inputs.shape + conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w]) + conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1]) + + #===== Transformer encoder ===== + b, t, c = conv_features.shape + encoder_word_pos = others["encoder_word_pos"] + gsrm_word_pos = others["gsrm_word_pos"] + + + enc_inputs = [conv_features, encoder_word_pos, None] + word_features = wrap_encoder_forFeature(src_vocab_size=-1, + max_length=t, + n_layer=self.num_encoder_TUs, + n_head=self.num_heads, + d_key= int(self.hidden_dims / self.num_heads), + d_value= int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs, + ) + fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip)) + + #===== Parallel Visual Attention Module ===== + b, t, c = word_features.shape + + word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2) + word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c]) + word_features_ = fluid.layers.expand(word_features_, [1, self.max_length, 1, 1]) + word_pos_feature = fluid.layers.embedding(gsrm_word_pos, [self.max_length, c]) + word_pos_ = fluid.layers.reshape(word_pos_feature, [-1, self.max_length, 1, c]) + word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1]) + temp = fluid.layers.elementwise_add(word_features_, word_pos_, act='tanh') + + attention_weight = fluid.layers.fc(input=temp, size=1, num_flatten_dims=3, bias_attr=False) + attention_weight = fluid.layers.reshape(x=attention_weight, shape=[-1, self.max_length, t]) + attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1) + + pvam_features = fluid.layers.matmul(attention_weight, word_features)#[b, max_length, c] + + return pvam_features + + def gsrm(self, pvam_features, others): + + #===== GSRM Visual-to-semantic embedding block ===== + b, t, c = pvam_features.shape + word_out = fluid.layers.fc(input=fluid.layers.reshape(pvam_features, [-1, c]), + size=self.char_num, + act="softmax") + #word_out.stop_gradient = True + word_ids = fluid.layers.argmax(word_out, axis=1) + word_ids.stop_gradient = True + word_ids = fluid.layers.reshape(x=word_ids, shape=[-1, t, 1]) + + #===== GSRM Semantic reasoning block ===== + """ + This module is achieved through bi-transformers, + ngram_feature1 is the froward one, ngram_fetaure2 is the backward one + """ + pad_idx = self.char_num + gsrm_word_pos = others["gsrm_word_pos"] + gsrm_slf_attn_bias1 = others["gsrm_slf_attn_bias1"] + gsrm_slf_attn_bias2 = others["gsrm_slf_attn_bias2"] + + def prepare_bi(word_ids): + """ + prepare bi for gsrm + word1 for forward; word2 for backward + """ + word1 = fluid.layers.cast(word_ids, "float32") + word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0], pad_value=1.0 * pad_idx) + word1 = fluid.layers.cast(word1, "int64") + word1 = word1[:, :-1, :] + word2 = word_ids + return word1, word2 + + word1, word2 = prepare_bi(word_ids) + word1.stop_gradient = True + word2.stop_gradient = True + enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1] + enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2] + + gsrm_feature1 = wrap_encoder(src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs_1, + ) + gsrm_feature2 = wrap_encoder(src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True, + enc_inputs=enc_inputs_2, + ) + gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], pad_value=0.) + gsrm_feature2 = gsrm_feature2[:, 1:, ] + gsrm_features = gsrm_feature1 + gsrm_feature2 + + b, t, c = gsrm_features.shape + + gsrm_out = fluid.layers.matmul( + x=gsrm_features, + y=fluid.default_main_program().global_block().var("src_word_emb_table"), + transpose_y=True) + b,t,c = gsrm_out.shape + gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, [-1, c])) + + return gsrm_features, word_out, gsrm_out + + def vsfd(self, pvam_features, gsrm_features): + + #===== Visual-Semantic Fusion Decoder Module ===== + b, t, c1 = pvam_features.shape + b, t, c2 = gsrm_features.shape + combine_features_ = fluid.layers.concat([pvam_features, gsrm_features], axis=2) + img_comb_features_ = fluid.layers.reshape(x=combine_features_, shape=[-1, c1 + c2]) + img_comb_features_map = fluid.layers.fc(input=img_comb_features_, size=c1, act="sigmoid") + img_comb_features_map = fluid.layers.reshape(x=img_comb_features_map, shape=[-1, t, c1]) + combine_features = img_comb_features_map * pvam_features + (1.0 - img_comb_features_map) * gsrm_features + img_comb_features = fluid.layers.reshape(x=combine_features, shape=[-1, c1]) + + fc_out = fluid.layers.fc(input=img_comb_features, + size=self.char_num, + act="softmax") + return fc_out + + + def __call__(self, inputs, others, mode=None): + + pvam_features = self.pvam(inputs, others) + gsrm_features, word_out, gsrm_out = self.gsrm(pvam_features, others) + final_out = self.vsfd(pvam_features, gsrm_features) + + _, decoded_out = fluid.layers.topk(input=final_out, k=1) + predicts = {'predict': final_out, 'decoded_out': decoded_out, + 'word_out': word_out, 'gsrm_out': gsrm_out} + + return predicts + + + + + + + + diff --git a/ppocr/modeling/heads/self_attention/__init__.py b/ppocr/modeling/heads/self_attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ppocr/modeling/heads/self_attention/model.py b/ppocr/modeling/heads/self_attention/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aecd5f7dc34ac7a1167571b7c4c3f9befa38fe --- /dev/null +++ b/ppocr/modeling/heads/self_attention/model.py @@ -0,0 +1,1065 @@ +from functools import partial +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.layers as layers + +from .desc import * +from .config import ModelHyperParams,TrainTaskConfig + +def wrap_layer_with_block(layer, block_idx): + """ + Make layer define support indicating block, by which we can add layers + to other blocks within current block. This will make it easy to define + cache among while loop. + """ + + class BlockGuard(object): + """ + BlockGuard class. + + BlockGuard class is used to switch to the given block in a program by + using the Python `with` keyword. + """ + + def __init__(self, block_idx=None, main_program=None): + self.main_program = fluid.default_main_program( + ) if main_program is None else main_program + self.old_block_idx = self.main_program.current_block().idx + self.new_block_idx = block_idx + + def __enter__(self): + self.main_program.current_block_idx = self.new_block_idx + + def __exit__(self, exc_type, exc_val, exc_tb): + self.main_program.current_block_idx = self.old_block_idx + if exc_type is not None: + return False # re-raise exception + return True + + def layer_wrapper(*args, **kwargs): + with BlockGuard(block_idx): + return layer(*args, **kwargs) + + return layer_wrapper + + +def position_encoding_init(n_position, d_pos_vec): + """ + Generate the initial values for the sinusoid position encoding table. + """ + channels = d_pos_vec + position = np.arange(n_position) + num_timescales = channels // 2 + log_timescale_increment = (np.log(float(1e4) / float(1)) / + (num_timescales - 1)) + inv_timescales = np.exp(np.arange( + num_timescales)) * -log_timescale_increment + scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales, + 0) + signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) + signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant') + position_enc = signal + return position_enc.astype("float32") + + +def multi_head_attention(queries, + keys, + values, + attn_bias, + d_key, + d_value, + d_model, + n_head=1, + dropout_rate=0., + cache=None, + gather_idx=None, + static_kv=False): + """ + Multi-Head Attention. Note that attn_bias is added to the logit before + computing softmax activiation to mask certain selected positions so that + they will not considered in attention weights. + """ + keys = queries if keys is None else keys + values = keys if values is None else values + + if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): + raise ValueError( + "Inputs: quries, keys and values should all be 3-D tensors.") + + def __compute_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Add linear projection to queries, keys, and values. + """ + q = layers.fc(input=queries, + size=d_key * n_head, + bias_attr=False, + num_flatten_dims=2) + # For encoder-decoder attention in inference, insert the ops and vars + # into global block to use as cache among beam search. + fc_layer = wrap_layer_with_block( + layers.fc, fluid.default_main_program().current_block() + .parent_idx) if cache is not None and static_kv else layers.fc + k = fc_layer( + input=keys, + size=d_key * n_head, + bias_attr=False, + num_flatten_dims=2) + v = fc_layer( + input=values, + size=d_value * n_head, + bias_attr=False, + num_flatten_dims=2) + return q, k, v + + def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Reshape input tensors at the last dimension to split multi-heads + and then transpose. Specifically, transform the input tensor with shape + [bs, max_sequence_length, n_head * hidden_dim] to the output tensor + with shape [bs, n_head, max_sequence_length, hidden_dim]. + """ + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + reshaped_q = layers.reshape( + x=queries, shape=[0, 0, n_head, d_key], inplace=True) + # permuate the dimensions into: + # [batch_size, n_head, max_sequence_len, hidden_size_per_head] + q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3]) + # For encoder-decoder attention in inference, insert the ops and vars + # into global block to use as cache among beam search. + reshape_layer = wrap_layer_with_block( + layers.reshape, + fluid.default_main_program().current_block() + .parent_idx) if cache is not None and static_kv else layers.reshape + transpose_layer = wrap_layer_with_block( + layers.transpose, + fluid.default_main_program().current_block(). + parent_idx) if cache is not None and static_kv else layers.transpose + reshaped_k = reshape_layer( + x=keys, shape=[0, 0, n_head, d_key], inplace=True) + k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3]) + reshaped_v = reshape_layer( + x=values, shape=[0, 0, n_head, d_value], inplace=True) + v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3]) + + if cache is not None: # only for faster inference + if static_kv: # For encoder-decoder attention in inference + cache_k, cache_v = cache["static_k"], cache["static_v"] + # To init the static_k and static_v in cache. + # Maybe we can use condition_op(if_else) to do these at the first + # step in while loop to replace these, however it might be less + # efficient. + static_cache_init = wrap_layer_with_block( + layers.assign, + fluid.default_main_program().current_block().parent_idx) + static_cache_init(k, cache_k) + static_cache_init(v, cache_v) + else: # For decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + # gather cell states corresponding to selected parent + select_k = layers.gather(cache_k, index=gather_idx) + select_v = layers.gather(cache_v, index=gather_idx) + if not static_kv: + # For self attention in inference, use cache and concat time steps. + select_k = layers.concat([select_k, k], axis=2) + select_v = layers.concat([select_v, v], axis=2) + # update cell states(caches) cached in global block + layers.assign(select_k, cache_k) + layers.assign(select_v, cache_v) + return q, select_k, select_v + return q, k, v + + def __combine_heads(x): + """ + Transpose and then reshape the last two dimensions of inpunt tensor x + so that it becomes one dimension, which is reverse to __split_heads. + """ + if len(x.shape) != 4: + raise ValueError("Input(x) should be a 4-D Tensor.") + + trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + return layers.reshape( + x=trans_x, + shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]], + inplace=True) + + def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate): + """ + Scaled Dot-Product Attention + """ + # print(q) + # print(k) + + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5) + if attn_bias: + product += attn_bias + weights = layers.softmax(product) + if dropout_rate: + weights = layers.dropout( + weights, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False) + out = layers.matmul(weights, v) + return out + + q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value) + + ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model, + dropout_rate) + + out = __combine_heads(ctx_multiheads) + + # Project back to the model size. + proj_out = layers.fc(input=out, + size=d_model, + bias_attr=False, + num_flatten_dims=2) + return proj_out + + +def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate): + """ + Position-wise Feed-Forward Networks. + This module consists of two linear transformations with a ReLU activation + in between, which is applied to each position separately and identically. + """ + hidden = layers.fc(input=x, + size=d_inner_hid, + num_flatten_dims=2, + act="relu") + if dropout_rate: + hidden = layers.dropout( + hidden, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False) + out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2) + return out + + +def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): + """ + Add residual connection, layer normalization and droput to the out tensor + optionally according to the value of process_cmd. + This will be used before or after multi-head attention and position-wise + feed-forward networks. + """ + for cmd in process_cmd: + if cmd == "a": # add residual connection + out = out + prev_out if prev_out else out + elif cmd == "n": # add layer normalization + out = layers.layer_norm( + out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.initializer.Constant(1.), + bias_attr=fluid.initializer.Constant(0.)) + elif cmd == "d": # add dropout + if dropout_rate: + out = layers.dropout( + out, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False) + return out + + +pre_process_layer = partial(pre_post_process_layer, None) +post_process_layer = pre_post_process_layer + + +def prepare_encoder(src_word,#[b,t,c] + src_pos, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0., + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + """Add word embeddings and position encodings. + The output tensor has a shape of: + [batch_size, max_src_length_in_batch, d_model]. + This module is used at the bottom of the encoder stacks. + """ + + src_word_emb =src_word#layers.concat(res,axis=1) + src_word_emb=layers.cast(src_word_emb,'float32') + # print("src_word_emb",src_word_emb) + + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + param_attr=fluid.ParamAttr( + name=pos_enc_param_name, trainable=False)) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + return layers.dropout( + enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + is_test=False) if dropout_rate else enc_input + + +def prepare_decoder(src_word, + src_pos, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0., + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + """Add word embeddings and position encodings. + The output tensor has a shape of: + [batch_size, max_src_length_in_batch, d_model]. + This module is used at the bottom of the encoder stacks. + """ + src_word_emb = layers.embedding( + src_word, + size=[src_vocab_size, src_emb_dim], + padding_idx=bos_idx, # set embedding of bos to 0 + param_attr=fluid.ParamAttr( + name=word_emb_param_name, + initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) + # print("target_word_emb",src_word_emb) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5) + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + param_attr=fluid.ParamAttr( + name=pos_enc_param_name, trainable=False)) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + return layers.dropout( + enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + is_test=False) if dropout_rate else enc_input + +# prepare_encoder = partial( +# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0]) +# prepare_decoder = partial( +# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1]) + + +def encoder_layer(enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + """The encoder layers that can be stacked to form a deep encoder. + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. + """ + attn_output = multi_head_attention( + pre_process_layer(enc_input, preprocess_cmd, + prepostprocess_dropout), None, None, attn_bias, d_key, + d_value, d_model, n_head, attention_dropout) + attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd, + prepostprocess_dropout) + ffd_output = positionwise_feed_forward( + pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout), + d_inner_hid, d_model, relu_dropout) + return post_process_layer(attn_output, ffd_output, postprocess_cmd, + prepostprocess_dropout) + + +def encoder(enc_input, + attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + for i in range(n_layer): + enc_output = encoder_layer( + enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, ) + enc_input = enc_output + enc_output = pre_process_layer(enc_output, preprocess_cmd, + prepostprocess_dropout) + return enc_output + + +def decoder_layer(dec_input, + enc_output, + slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + cache=None, + gather_idx=None): + """ The layer to be stacked in decoder part. + The structure of this module is similar to that in the encoder part except + a multi-head attention is added to implement encoder-decoder attention. + """ + slf_attn_output = multi_head_attention( + pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout), + None, + None, + slf_attn_bias, + d_key, + d_value, + d_model, + n_head, + attention_dropout, + cache=cache, + gather_idx=gather_idx) + slf_attn_output = post_process_layer( + dec_input, + slf_attn_output, + postprocess_cmd, + prepostprocess_dropout, ) + enc_attn_output = multi_head_attention( + pre_process_layer(slf_attn_output, preprocess_cmd, + prepostprocess_dropout), + enc_output, + enc_output, + dec_enc_attn_bias, + d_key, + d_value, + d_model, + n_head, + attention_dropout, + cache=cache, + gather_idx=gather_idx, + static_kv=True) + enc_attn_output = post_process_layer( + slf_attn_output, + enc_attn_output, + postprocess_cmd, + prepostprocess_dropout, ) + ffd_output = positionwise_feed_forward( + pre_process_layer(enc_attn_output, preprocess_cmd, + prepostprocess_dropout), + d_inner_hid, + d_model, + relu_dropout, ) + dec_output = post_process_layer( + enc_attn_output, + ffd_output, + postprocess_cmd, + prepostprocess_dropout, ) + return dec_output + + +def decoder(dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + caches=None, + gather_idx=None): + """ + The decoder is composed of a stack of identical decoder_layer layers. + """ + for i in range(n_layer): + dec_output = decoder_layer( + dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + cache=None if caches is None else caches[i], + gather_idx=gather_idx) + dec_input = dec_output + dec_output = pre_process_layer(dec_output, preprocess_cmd, + prepostprocess_dropout) + return dec_output + + +def make_all_inputs(input_fields): + """ + Define the input data layers for the transformer model. + """ + inputs = [] + for input_field in input_fields: + input_var = layers.data( + name=input_field, + shape=input_descs[input_field][0], + dtype=input_descs[input_field][1], + lod_level=input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0, + append_batch_size=False) + inputs.append(input_var) + return inputs + + +def make_all_py_reader_inputs(input_fields, is_test=False): + reader = layers.py_reader( + capacity=20, + name="test_reader" if is_test else "train_reader", + shapes=[input_descs[input_field][0] for input_field in input_fields], + dtypes=[input_descs[input_field][1] for input_field in input_fields], + lod_levels=[ + input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0 + for input_field in input_fields + ]) + return layers.read_file(reader), reader + + +def transformer(src_vocab_size, + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + label_smooth_eps, + bos_idx=0, + use_py_reader=False, + is_test=False): + if weight_sharing: + assert src_vocab_size == trg_vocab_size, ( + "Vocabularies in source and target should be same for weight sharing." + ) + + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + + if use_py_reader: + all_inputs, reader = make_all_py_reader_inputs(data_input_names, + is_test) + else: + all_inputs = make_all_inputs(data_input_names) + # print("all inputs",all_inputs) + enc_inputs_len = len(encoder_data_input_fields) + dec_inputs_len = len(decoder_data_input_fields[:-1]) + enc_inputs = all_inputs[0:enc_inputs_len] + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len] + label = all_inputs[-2] + weights = all_inputs[-1] + + enc_output = wrap_encoder( + src_vocab_size, + ModelHyperParams.src_seq_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + enc_inputs) + + predict = wrap_decoder( + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs, + enc_output, ) + + # Padding index do not contribute to the total loss. The weights is used to + # cancel padding index in calculating the loss. + if label_smooth_eps: + label = layers.label_smooth( + label=layers.one_hot( + input=label, depth=trg_vocab_size), + epsilon=label_smooth_eps) + + cost = layers.softmax_with_cross_entropy( + logits=predict, + label=label, + soft_label=True if label_smooth_eps else False) + weighted_cost = cost * weights + sum_cost = layers.reduce_sum(weighted_cost) + token_num = layers.reduce_sum(weights) + token_num.stop_gradient = True + avg_cost = sum_cost / token_num + return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None + + +def wrap_encoder_forFeature(src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + enc_inputs=None, + bos_idx=0): + """ + The wrapper assembles together all needed layers for the encoder. + img, src_pos, src_slf_attn_bias = enc_inputs + img + """ + + if enc_inputs is None: + # This is used to implement independent encoder program in inference. + conv_features, src_pos, src_slf_attn_bias = make_all_inputs( + encoder_data_input_fields) + else: + conv_features, src_pos, src_slf_attn_bias = enc_inputs# + b,t,c = conv_features.shape + #""" + # insert cnn + #""" + #import basemodel + # feat = basemodel.resnet_50(img) + + # mycrnn = basemodel.CRNN() + # feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu) + # b, c, w, h = feat.shape + # src_word = layers.reshape(feat, shape=[-1, c, w * h]) + + #myconv8 = basemodel.conv8() + #feat = myconv8.net(img ) + #b , c, h, w = feat.shape#h=6 + #print(feat) + #layers.Print(feat,message="conv_feat",summarize=10) + + #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") + #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) + #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] + + #feat = layers.transpose(feat, [0,3,1,2]) + #src_word = layers.reshape(feat,[-1,w, c*h]) + #src_word = layers.im2sequence( + # input=feat, + # stride=[1, 1], + # filter_size=[feat.shape[2], 1]) + #layers.Print(src_word,message="src_word",summarize=10) + + # print('feat',feat) + #print("src_word",src_word) + + enc_input = prepare_encoder( + conv_features, + src_pos, + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name=word_emb_param_names[0]) + + enc_output = encoder( + enc_input, + src_slf_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, ) + return enc_output + +def wrap_encoder(src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + enc_inputs=None, + bos_idx=0): + """ + The wrapper assembles together all needed layers for the encoder. + img, src_pos, src_slf_attn_bias = enc_inputs + img + """ + if enc_inputs is None: + # This is used to implement independent encoder program in inference. + src_word, src_pos, src_slf_attn_bias = make_all_inputs( + encoder_data_input_fields) + else: + src_word, src_pos, src_slf_attn_bias = enc_inputs# + #""" + # insert cnn + #""" + #import basemodel + # feat = basemodel.resnet_50(img) + + # mycrnn = basemodel.CRNN() + # feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu) + # b, c, w, h = feat.shape + # src_word = layers.reshape(feat, shape=[-1, c, w * h]) + + #myconv8 = basemodel.conv8() + #feat = myconv8.net(img ) + #b , c, h, w = feat.shape#h=6 + #print(feat) + #layers.Print(feat,message="conv_feat",summarize=10) + + #feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu") + #feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1)) + #src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww] + + #feat = layers.transpose(feat, [0,3,1,2]) + #src_word = layers.reshape(feat,[-1,w, c*h]) + #src_word = layers.im2sequence( + # input=feat, + # stride=[1, 1], + # filter_size=[feat.shape[2], 1]) + #layers.Print(src_word,message="src_word",summarize=10) + + # print('feat',feat) + #print("src_word",src_word) + enc_input = prepare_decoder( + src_word, + src_pos, + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name=word_emb_param_names[0]) + + enc_output = encoder( + enc_input, + src_slf_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, ) + return enc_output + + +def wrap_decoder(trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs=None, + enc_output=None, + caches=None, + gather_idx=None, + bos_idx=0): + """ + The wrapper assembles together all needed layers for the decoder. + """ + if dec_inputs is None: + # This is used to implement independent decoder program in inference. + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \ + make_all_inputs(decoder_data_input_fields) + else: + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs + + dec_input = prepare_decoder( + trg_word, + trg_pos, + trg_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name=word_emb_param_names[0] + if weight_sharing else word_emb_param_names[1]) + dec_output = decoder( + dec_input, + enc_output, + trg_slf_attn_bias, + trg_src_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + caches=caches, + gather_idx=gather_idx) + return dec_output + # Reshape to 2D tensor to use GEMM instead of BatchedGEMM + dec_output = layers.reshape( + dec_output, shape=[-1, dec_output.shape[-1]], inplace=True) + if weight_sharing: + predict = layers.matmul( + x=dec_output, + y=fluid.default_main_program().global_block().var( + word_emb_param_names[0]), + transpose_y=True) + else: + predict = layers.fc(input=dec_output, + size=trg_vocab_size, + bias_attr=False) + if dec_inputs is None: + # Return probs for independent decoder program. + predict = layers.softmax(predict) + return predict + + +def fast_decode(src_vocab_size, + trg_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + beam_size, + max_out_len, + bos_idx, + eos_idx, + use_py_reader=False): + """ + Use beam search to decode. Caches will be used to store states of history + steps which can make the decoding faster. + """ + data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields + + if use_py_reader: + all_inputs, reader = make_all_py_reader_inputs(data_input_names) + else: + all_inputs = make_all_inputs(data_input_names) + + enc_inputs_len = len(encoder_data_input_fields) + dec_inputs_len = len(fast_decoder_data_input_fields) + enc_inputs = all_inputs[0:enc_inputs_len]#enc_inputs tensor + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]#dec_inputs tensor + + enc_output = wrap_encoder( + src_vocab_size, + ModelHyperParams.src_seq_len,##to do !!!!!???? + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + enc_inputs, + bos_idx=bos_idx) + start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs + + def beam_search(): + max_len = layers.fill_constant( + shape=[1], + dtype=start_tokens.dtype, + value=max_out_len, + force_cpu=True) + step_idx = layers.fill_constant( + shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True) + cond = layers.less_than(x=step_idx, y=max_len) # default force_cpu=True + while_op = layers.While(cond) + # array states will be stored for each step. + ids = layers.array_write( + layers.reshape(start_tokens, (-1, 1)), step_idx) + scores = layers.array_write(init_scores, step_idx) + # cell states will be overwrited at each step. + # caches contains states of history steps in decoder self-attention + # and static encoder output projections in encoder-decoder attention + # to reduce redundant computation. + caches = [ + { + "k": # for self attention + layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, n_head, 0, d_key], + dtype=enc_output.dtype, + value=0), + "v": # for self attention + layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, n_head, 0, d_value], + dtype=enc_output.dtype, + value=0), + "static_k": # for encoder-decoder attention + layers.create_tensor(dtype=enc_output.dtype), + "static_v": # for encoder-decoder attention + layers.create_tensor(dtype=enc_output.dtype) + } for i in range(n_layer) + ] + + with while_op.block(): + pre_ids = layers.array_read(array=ids, i=step_idx) + # Since beam_search_op dosen't enforce pre_ids' shape, we can do + # inplace reshape here which actually change the shape of pre_ids. + pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True) + pre_scores = layers.array_read(array=scores, i=step_idx) + # gather cell states corresponding to selected parent + pre_src_attn_bias = layers.gather( + trg_src_attn_bias, index=parent_idx) + pre_pos = layers.elementwise_mul( + x=layers.fill_constant_batch_size_like( + input=pre_src_attn_bias, # cann't use lod tensor here + value=1, + shape=[-1, 1, 1], + dtype=pre_ids.dtype), + y=step_idx, + axis=0) + logits = wrap_decoder( + trg_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias), + enc_output=enc_output, + caches=caches, + gather_idx=parent_idx, + bos_idx=bos_idx) + # intra-beam topK + topk_scores, topk_indices = layers.topk( + input=layers.softmax(logits), k=beam_size) + accu_scores = layers.elementwise_add( + x=layers.log(topk_scores), y=pre_scores, axis=0) + # beam_search op uses lod to differentiate branches. + accu_scores = layers.lod_reset(accu_scores, pre_ids) + # topK reduction across beams, also contain special handle of + # end beams and end sentences(batch reduction) + selected_ids, selected_scores, gather_idx = layers.beam_search( + pre_ids=pre_ids, + pre_scores=pre_scores, + ids=topk_indices, + scores=accu_scores, + beam_size=beam_size, + end_id=eos_idx, + return_parent_idx=True) + layers.increment(x=step_idx, value=1.0, in_place=True) + # cell states(caches) have been updated in wrap_decoder, + # only need to update beam search states here. + layers.array_write(selected_ids, i=step_idx, array=ids) + layers.array_write(selected_scores, i=step_idx, array=scores) + layers.assign(gather_idx, parent_idx) + layers.assign(pre_src_attn_bias, trg_src_attn_bias) + length_cond = layers.less_than(x=step_idx, y=max_len) + finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) + layers.logical_and(x=length_cond, y=finish_cond, out=cond) + + finished_ids, finished_scores = layers.beam_search_decode( + ids, scores, beam_size=beam_size, end_id=eos_idx) + return finished_ids, finished_scores + + finished_ids, finished_scores = beam_search() + return finished_ids, finished_scores, reader if use_py_reader else None diff --git a/ppocr/modeling/losses/rec_srn_loss.py b/ppocr/modeling/losses/rec_srn_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..68a480ac6a78125e6748ec7586485753d3f217ab --- /dev/null +++ b/ppocr/modeling/losses/rec_srn_loss.py @@ -0,0 +1,58 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + + +class SRNLoss(object): + def __init__(self, params): + super(SRNLoss, self).__init__() + self.char_num = params['char_num'] + + def __call__(self, predicts, others): + predict = predicts['predict'] + word_predict = predicts['word_out'] + gsrm_predict = predicts['gsrm_out'] + label = others['label'] + lbl_weight = others['lbl_weight'] + + casted_label = fluid.layers.cast(x=label, dtype='int64') + cost_word = fluid.layers.cross_entropy(input=word_predict, label=casted_label) + cost_gsrm = fluid.layers.cross_entropy(input=gsrm_predict, label=casted_label) + cost_vsfd = fluid.layers.cross_entropy(input=predict, label=casted_label) + + #cost_word = cost_word * lbl_weight + #cost_gsrm = cost_gsrm * lbl_weight + #cost_vsfd = cost_vsfd * lbl_weight + + cost_word = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_word), shape=[1]) + cost_gsrm = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_gsrm), shape=[1]) + cost_vsfd = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_vsfd), shape=[1]) + + sum_cost = fluid.layers.sum([cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15]) + + #sum_cost = fluid.layers.sum([cost_word * 3.0, cost_vsfd, cost_gsrm * 0.15]) + #sum_cost = cost_word + + #fluid.layers.Print(cost_word,message="word_cost") + #fluid.layers.Print(cost_vsfd,message="img_cost") + return [sum_cost,cost_vsfd,cost_word] + #return [sum_cost, cost_vsfd, cost_word] diff --git a/ppocr/utils/character.py b/ppocr/utils/character.py index 9a3db8dd92454c65256d1cadf7f155b6882ee171..79d6f5ca16cc017c4e194251cb7d6fcb884b02bb 100755 --- a/ppocr/utils/character.py +++ b/ppocr/utils/character.py @@ -25,6 +25,7 @@ class CharacterOps(object): def __init__(self, config): self.character_type = config['character_type'] self.loss_type = config['loss_type'] + self.max_text_len = config['max_text_length'] if self.character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -54,6 +55,8 @@ class CharacterOps(object): self.end_str = "eos" if self.loss_type == "attention": dict_character = [self.beg_str, self.end_str] + dict_character + elif self.loss_type == "srn": + dict_character = dict_character + [self.beg_str, self.end_str] self.dict = {} for i, char in enumerate(dict_character): self.dict[char] = i @@ -146,6 +149,48 @@ def cal_predicts_accuracy(char_ops, acc = acc_num * 1.0 / img_num return acc, acc_num, img_num +def cal_predicts_accuracy_srn(char_ops, + preds, + labels, + max_text_len, + is_debug=False): + acc_num = 0 + img_num = 0 + + total_len = preds.shape[0] + img_num = int(total_len / max_text_len) + #print (img_num) + for i in range(img_num): + cur_label = [] + cur_pred = [] + for j in range(max_text_len): + if labels[j + i * max_text_len] != 37: #0 + cur_label.append(labels[j + i * max_text_len][0]) + else: + break + + if is_debug: + for j in range(max_text_len): + if preds[j + i * max_text_len] != 37: #0 + cur_pred.append(preds[j + i * max_text_len][0]) + else: + break + print ("cur_label: ", cur_label) + print ("cur_pred: ", cur_pred) + + + for j in range(max_text_len + 1): + if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]: + break + elif j == len(cur_label) and j == max_text_len: + acc_num += 1 + break + elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37: + acc_num += 1 + break + acc = acc_num * 1.0 / img_num + return acc, acc_num, img_num + def convert_rec_attention_infer_res(preds): img_num = preds.shape[0] diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index aebb9f903a4d84288da0d7b042ea39e1759456b5..3d496bd3d70e90284951b6f8aaaa2271359c8bc6 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s' logging.basicConfig(level=logging.INFO, format=FORMAT) logger = logging.getLogger(__name__) -from ppocr.utils.character import cal_predicts_accuracy +from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn from ppocr.utils.character import convert_rec_label_to_lod from ppocr.utils.character import convert_rec_attention_infer_res from ppocr.utils.utility import create_module @@ -60,19 +60,52 @@ def eval_rec_run(exe, config, eval_info_dict, mode): for ino in range(img_num): img_list.append(data[ino][0]) label_list.append(data[ino][1]) - img_list = np.concatenate(img_list, axis=0) - outs = exe.run(eval_info_dict['program'], \ + + if config['Global']['loss_type'] != "srn": + img_list = np.concatenate(img_list, axis=0) + outs = exe.run(eval_info_dict['program'], \ feed={'image': img_list}, \ fetch_list=eval_info_dict['fetch_varname_list'], \ return_numpy=False) - preds = np.array(outs[0]) - if preds.shape[1] != 1: - preds, preds_lod = convert_rec_attention_infer_res(preds) + preds = np.array(outs[0]) + + if preds.shape[1] != 1: + preds, preds_lod = convert_rec_attention_infer_res(preds) + else: + preds_lod = outs[0].lod()[0] + labels, labels_lod = convert_rec_label_to_lod(label_list) + acc, acc_num, sample_num = cal_predicts_accuracy( + char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate) else: - preds_lod = outs[0].lod()[0] - labels, labels_lod = convert_rec_label_to_lod(label_list) - acc, acc_num, sample_num = cal_predicts_accuracy( - char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + for ino in range(img_num): + encoder_word_pos_list.append(data[ino][2]) + gsrm_word_pos_list.append(data[ino][3]) + gsrm_slf_attn_bias1_list.append(data[ino][4]) + gsrm_slf_attn_bias2_list.append(data[ino][5]) + + img_list = np.concatenate(img_list, axis=0) + label_list = np.concatenate(label_list, axis=0) + encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64) + gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32) + gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32) + + labels = label_list + + outs = exe.run(eval_info_dict['program'], \ + feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list, + 'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list, + 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \ + fetch_list=eval_info_dict['fetch_varname_list'], \ + return_numpy=False) + preds = np.array(outs[0]) + acc, acc_num, sample_num = cal_predicts_accuracy_srn( + char_ops, preds, labels, config['Global']['max_text_length']) + total_acc_num += acc_num total_sample_num += sample_num logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc)) @@ -85,8 +118,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode): def test_rec_benchmark(exe, config, eval_info_dict): " Evaluate lmdb dataset " - eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \ - 'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] + eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', \ + 'IC13_857', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80'] eval_data_dir = config['TestReader']['lmdb_sets_dir'] total_evaluation_data_number = 0 total_correct_number = 0 diff --git a/tools/program.py b/tools/program.py index 4ebc11670702ba627c89b060692c9827e6e163fd..64c827e7ce16072ffbc73b52fb0d40677575b2af 100755 --- a/tools/program.py +++ b/tools/program.py @@ -32,7 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run from ppocr.utils.save_load import save_model import numpy as np -from ppocr.utils.character import cal_predicts_accuracy, CharacterOps +from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps class ArgsParser(ArgumentParser): def __init__(self): @@ -176,8 +176,16 @@ def build(config, main_prog, startup_prog, mode): fetch_name_list = list(outputs.keys()) fetch_varname_list = [outputs[v].name for v in fetch_name_list] opt_loss_name = None + model_average = None + img_loss_name = None + word_loss_name = None if mode == "train": opt_loss = outputs['total_loss'] + # srn loss + #img_loss = outputs['img_loss'] + #word_loss = outputs['word_loss'] + #img_loss_name = img_loss.name + #word_loss_name = word_loss.name opt_params = config['Optimizer'] optimizer = create_module(opt_params['function'])(opt_params) optimizer.minimize(opt_loss) @@ -185,7 +193,13 @@ def build(config, main_prog, startup_prog, mode): global_lr = optimizer._global_learning_rate() fetch_name_list.insert(0, "lr") fetch_varname_list.insert(0, global_lr.name) - return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name) + if config['Global']["loss_type"] == 'srn': + model_average = fluid.optimizer.ModelAverage( + config['Global']['average_window'], + min_average_window=config['Global']['min_average_window'], + max_average_window=config['Global']['max_average_window']) + + return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average) def build_export(config, main_prog, startup_prog): @@ -329,14 +343,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): lr = np.mean(np.array(train_outs[fetch_map['lr']])) preds_idx = fetch_map['decoded_out'] preds = np.array(train_outs[preds_idx]) - preds_lod = train_outs[preds_idx].lod()[0] labels_idx = fetch_map['label'] labels = np.array(train_outs[labels_idx]) - labels_lod = train_outs[labels_idx].lod()[0] - acc, acc_num, img_num = cal_predicts_accuracy( - config['Global']['char_ops'], preds, preds_lod, labels, - labels_lod) + if config['Global']['loss_type'] != 'srn': + preds_lod = train_outs[preds_idx].lod()[0] + labels_lod = train_outs[labels_idx].lod()[0] + + acc, acc_num, img_num = cal_predicts_accuracy( + config['Global']['char_ops'], preds, preds_lod, labels, + labels_lod) + else: + acc, acc_num, img_num = cal_predicts_accuracy_srn( + config['Global']['char_ops'], preds, labels, + config['Global']['max_text_length']) t2 = time.time() train_batch_elapse = t2 - t1 stats = {'loss': loss, 'acc': acc} @@ -350,6 +370,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): if train_batch_id > 0 and\ train_batch_id % eval_batch_step == 0: + model_average = train_info_dict['model_average'] + if model_average != None: + model_average.apply(exe) metrics = eval_rec_run(exe, config, eval_info_dict, "eval") eval_acc = metrics['avg_acc'] eval_sample_num = metrics['total_sample_num'] diff --git a/tools/train.py b/tools/train.py index 68e792b7331a9d47ca6744ea1a9f362979d75542..2ea9d0e011d38f228106409f843c6ed41f10b844 100755 --- a/tools/train.py +++ b/tools/train.py @@ -52,6 +52,7 @@ def main(): train_fetch_name_list = train_build_outputs[1] train_fetch_varname_list = train_build_outputs[2] train_opt_loss_name = train_build_outputs[3] + model_average = train_build_outputs[-1] eval_program = fluid.Program() eval_build_outputs = program.build( @@ -85,7 +86,8 @@ def main(): 'train_program':train_program,\ 'reader':train_loader,\ 'fetch_name_list':train_fetch_name_list,\ - 'fetch_varname_list':train_fetch_varname_list} + 'fetch_varname_list':train_fetch_varname_list,\ + 'model_average': model_average} eval_info_dict = {'program':eval_program,\ 'reader':eval_reader,\