提交 09d8cb6d 编写于 作者: T tink2123

update for srn

上级 1e8f4146
Global:
algorithm: SRN
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output/rec_pvam_withrotate
save_epoch_step: 1
eval_batch_step: 8000
train_batch_size_per_card: 64
test_batch_size_per_card: 1
image_shape: [1, 64, 256]
max_text_length: 25
character_type: en
loss_type: srn
num_heads: 8
average_window: 0.15
max_average_window: 15625
min_average_window: 10000
reader_yml: ./configs/rec/rec_srn_reader.yml
pretrain_weights:
checkpoints:
save_inference_dir:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.rec_srn_all_head,SRNPredict
encoder_type: rnn
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512
SeqRNN:
hidden_size: 256
Loss:
function: ppocr.modeling.losses.rec_srn_loss,SRNLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.0001
beta1: 0.9
beta2: 0.999
......@@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger()
from .img_tools import process_image, get_img_data
from .img_tools import process_image, process_image_srn, get_img_data
class LMDBReader(object):
......@@ -40,6 +40,7 @@ class LMDBReader(object):
self.image_shape = params['image_shape']
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.num_heads = params['num_heads']
self.mode = params['mode']
self.drop_last = False
self.use_tps = False
......@@ -117,14 +118,36 @@ class LMDBReader(object):
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if self.loss_type == 'srn':
norm_img = process_image_srn(
img=img,
image_shape=self.image_shape,
num_heads=self.num_heads,
max_text_length=self.max_text_length
)
else:
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img
elif self.mode == 'test':
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
infer_mode=True
)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
......@@ -144,14 +167,16 @@ class LMDBReader(object):
if sample_info is None:
continue
img, label = sample_info
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length,
distort=self.use_distort)
outs = []
if self.loss_type == "srn":
outs = process_image_srn(img, self.image_shape, self.num_heads,
self.max_text_length, label,
self.char_ops, self.loss_type)
else:
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
......@@ -159,7 +184,6 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
......@@ -167,9 +191,8 @@ class LMDBReader(object):
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
if len(batch_outs) != 0:
yield batch_outs
if self.infer_img is None:
return batch_iter_reader
......@@ -288,4 +311,4 @@ class SimpleReader(object):
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader
return sample_iter_reader
\ No newline at end of file
......@@ -381,3 +381,84 @@ def process_image(img,
assert False, "Unsupport loss_type %s in process_image"\
% loss_type
return (norm_img)
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape,
num_heads,
max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
def process_image_srn(img,
image_shape,
num_heads,
max_text_length,
label=None,
char_ops=None,
loss_type=None):
norm_img = resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(image_shape, num_heads, max_text_length)
if label is not None:
char_num = char_ops.get_char_num()
text = char_ops.encode(label)
if len(text) == 0 or len(text) > max_text_length:
return None
else:
if loss_type == "srn":
text_padded = [37] * max_text_length
for i in range(len(text)):
text_padded[i] = text[i]
lbl_weight[i] = [1.0]
text_padded = np.array(text_padded)
text = text_padded.reshape(-1, 1)
return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
else:
assert False, "Unsupport loss_type %s in process_image"\
% loss_type
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
......@@ -58,6 +58,7 @@ class RecModel(object):
self.loss_type = global_params['loss_type']
self.image_shape = global_params['image_shape']
self.max_text_length = global_params['max_text_length']
self.num_heads = global_params["num_heads"]
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
......@@ -77,6 +78,18 @@ class RecModel(object):
lod_level=1)
feed_list = [image, label_in, label_out]
labels = {'label_in': label_in, 'label_out': label_out}
elif self.loss_type == "srn":
encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64')
label = fluid.data(
name='label', shape=[-1, 1], dtype='int32', lod_level=1)
feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight]
labels = {'label': label, 'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight}
else:
label = fluid.data(
name='label', shape=[None, 1], dtype='int32', lod_level=1)
......@@ -88,6 +101,8 @@ class RecModel(object):
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
......@@ -97,9 +112,15 @@ class RecModel(object):
"We set img_shape to be the same , it may affect the inference effect"
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None
loader = None
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "srn":
encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2}
return image, labels, loader
def __call__(self, mode):
......@@ -117,9 +138,15 @@ class RecModel(object):
label = labels['label_out']
else:
label = labels['label']
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
if self.loss_type == 'srn':
total_loss, img_loss, word_loss = self.loss(predicts, labels)
outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss,
'decoded_out':decoded_out, 'label':label}
else:
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
elif mode == "export":
predict = predicts['predict']
if self.loss_type == "ctc":
......@@ -129,4 +156,4 @@ class RecModel(object):
predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
return loader, {'decoded_out': decoded_out, 'predicts': predict}
\ No newline at end of file
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
Trainable = True
w_nolr = fluid.ParamAttr(
trainable = Trainable)
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, params):
self.layers = params['layers']
self.params = train_parameters
def __call__(self, input):
layers = self.layers
supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
stride_list = [(2,2),(2,2),(1,1),(1,1)]
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")
F = []
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=stride_list[block] if i == 0 else 1, name=conv_name)
F.append(conv)
base = F[-1]
for i in [-2, -3]:
b, c, w, h = F[i].shape
if (w,h) == base.shape[2:]:
base = base
else:
base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2,
padding=1,act=None,
param_attr=w_nolr,
bias_attr=w_nolr)
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.concat([base, F[i]], axis=1)
base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr)
return base
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size= 2 if stride==(1,1) else filter_size,
dilation = 2 if stride==(1,1) else 1,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights",trainable = Trainable),
bias_attr=False,
name=name + '.conv2d.output.1')
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(input=conv,
act=act,
name=bn_name + '.output.1',
param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable),
bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance', )
def shortcut(self, input, ch_out, stride, is_first, name):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1 or is_first == True:
if stride == (1,1):
return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
else: #stride == (2,2)
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name):
conv0 = self.conv_bn_layer(
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name):
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
name=name + "_branch2b")
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
......@@ -32,7 +32,7 @@ class ResNet():
def __init__(self, params):
self.layers = params['layers']
self.is_3x3 = True
supported_layers = [18, 34, 50, 101, 152, 200]
supported_layers = [18, 34, 50, 101, 152]
assert self.layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
#from .rec_seq_encoder import SequenceEncoder
#from ..common_functions import get_para_bias_attr
import numpy as np
from .self_attention.model import wrap_encoder
from .self_attention.model import wrap_encoder_forFeature
gradient_clip = 10
class SRNPredict(object):
def __init__(self, params):
super(SRNPredict, self).__init__()
self.char_num = params['char_num']
self.max_length = params['max_text_length']
self.num_heads = params['num_heads']
self.num_encoder_TUs = params['num_encoder_TUs']
self.num_decoder_TUs = params['num_decoder_TUs']
self.hidden_dims = params['hidden_dims']
def pvam(self, inputs, others):
b, c, h, w = inputs.shape
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
#===== Transformer encoder =====
b, t, c = conv_features.shape
encoder_word_pos = others["encoder_word_pos"]
gsrm_word_pos = others["gsrm_word_pos"]
enc_inputs = [conv_features, encoder_word_pos, None]
word_features = wrap_encoder_forFeature(src_vocab_size=-1,
max_length=t,
n_layer=self.num_encoder_TUs,
n_head=self.num_heads,
d_key= int(self.hidden_dims / self.num_heads),
d_value= int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs,
)
fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip))
#===== Parallel Visual Attention Module =====
b, t, c = word_features.shape
word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2)
word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c])
word_features_ = fluid.layers.expand(word_features_, [1, self.max_length, 1, 1])
word_pos_feature = fluid.layers.embedding(gsrm_word_pos, [self.max_length, c])
word_pos_ = fluid.layers.reshape(word_pos_feature, [-1, self.max_length, 1, c])
word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1])
temp = fluid.layers.elementwise_add(word_features_, word_pos_, act='tanh')
attention_weight = fluid.layers.fc(input=temp, size=1, num_flatten_dims=3, bias_attr=False)
attention_weight = fluid.layers.reshape(x=attention_weight, shape=[-1, self.max_length, t])
attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1)
pvam_features = fluid.layers.matmul(attention_weight, word_features)#[b, max_length, c]
return pvam_features
def gsrm(self, pvam_features, others):
#===== GSRM Visual-to-semantic embedding block =====
b, t, c = pvam_features.shape
word_out = fluid.layers.fc(input=fluid.layers.reshape(pvam_features, [-1, c]),
size=self.char_num,
act="softmax")
#word_out.stop_gradient = True
word_ids = fluid.layers.argmax(word_out, axis=1)
word_ids.stop_gradient = True
word_ids = fluid.layers.reshape(x=word_ids, shape=[-1, t, 1])
#===== GSRM Semantic reasoning block =====
"""
This module is achieved through bi-transformers,
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
"""
pad_idx = self.char_num
gsrm_word_pos = others["gsrm_word_pos"]
gsrm_slf_attn_bias1 = others["gsrm_slf_attn_bias1"]
gsrm_slf_attn_bias2 = others["gsrm_slf_attn_bias2"]
def prepare_bi(word_ids):
"""
prepare bi for gsrm
word1 for forward; word2 for backward
"""
word1 = fluid.layers.cast(word_ids, "float32")
word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0], pad_value=1.0 * pad_idx)
word1 = fluid.layers.cast(word1, "int64")
word1 = word1[:, :-1, :]
word2 = word_ids
return word1, word2
word1, word2 = prepare_bi(word_ids)
word1.stop_gradient = True
word2.stop_gradient = True
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
gsrm_feature1 = wrap_encoder(src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs_1,
)
gsrm_feature2 = wrap_encoder(src_vocab_size=self.char_num + 1,
max_length=self.max_length,
n_layer=self.num_decoder_TUs,
n_head=self.num_heads,
d_key=int(self.hidden_dims / self.num_heads),
d_value=int(self.hidden_dims / self.num_heads),
d_model=self.hidden_dims,
d_inner_hid=self.hidden_dims,
prepostprocess_dropout=0.1,
attention_dropout=0.1,
relu_dropout=0.1,
preprocess_cmd="n",
postprocess_cmd="da",
weight_sharing=True,
enc_inputs=enc_inputs_2,
)
gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], pad_value=0.)
gsrm_feature2 = gsrm_feature2[:, 1:, ]
gsrm_features = gsrm_feature1 + gsrm_feature2
b, t, c = gsrm_features.shape
gsrm_out = fluid.layers.matmul(
x=gsrm_features,
y=fluid.default_main_program().global_block().var("src_word_emb_table"),
transpose_y=True)
b,t,c = gsrm_out.shape
gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, [-1, c]))
return gsrm_features, word_out, gsrm_out
def vsfd(self, pvam_features, gsrm_features):
#===== Visual-Semantic Fusion Decoder Module =====
b, t, c1 = pvam_features.shape
b, t, c2 = gsrm_features.shape
combine_features_ = fluid.layers.concat([pvam_features, gsrm_features], axis=2)
img_comb_features_ = fluid.layers.reshape(x=combine_features_, shape=[-1, c1 + c2])
img_comb_features_map = fluid.layers.fc(input=img_comb_features_, size=c1, act="sigmoid")
img_comb_features_map = fluid.layers.reshape(x=img_comb_features_map, shape=[-1, t, c1])
combine_features = img_comb_features_map * pvam_features + (1.0 - img_comb_features_map) * gsrm_features
img_comb_features = fluid.layers.reshape(x=combine_features, shape=[-1, c1])
fc_out = fluid.layers.fc(input=img_comb_features,
size=self.char_num,
act="softmax")
return fc_out
def __call__(self, inputs, others, mode=None):
pvam_features = self.pvam(inputs, others)
gsrm_features, word_out, gsrm_out = self.gsrm(pvam_features, others)
final_out = self.vsfd(pvam_features, gsrm_features)
_, decoded_out = fluid.layers.topk(input=final_out, k=1)
predicts = {'predict': final_out, 'decoded_out': decoded_out,
'word_out': word_out, 'gsrm_out': gsrm_out}
return predicts
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from .desc import *
from .config import ModelHyperParams,TrainTaskConfig
def wrap_layer_with_block(layer, block_idx):
"""
Make layer define support indicating block, by which we can add layers
to other blocks within current block. This will make it easy to define
cache among while loop.
"""
class BlockGuard(object):
"""
BlockGuard class.
BlockGuard class is used to switch to the given block in a program by
using the Python `with` keyword.
"""
def __init__(self, block_idx=None, main_program=None):
self.main_program = fluid.default_main_program(
) if main_program is None else main_program
self.old_block_idx = self.main_program.current_block().idx
self.new_block_idx = block_idx
def __enter__(self):
self.main_program.current_block_idx = self.new_block_idx
def __exit__(self, exc_type, exc_val, exc_tb):
self.main_program.current_block_idx = self.old_block_idx
if exc_type is not None:
return False # re-raise exception
return True
def layer_wrapper(*args, **kwargs):
with BlockGuard(block_idx):
return layer(*args, **kwargs)
return layer_wrapper
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
gather_idx=None,
static_kv=False):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
fc_layer = wrap_layer_with_block(
layers.fc, fluid.default_main_program().current_block()
.parent_idx) if cache is not None and static_kv else layers.fc
k = fc_layer(
input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = fc_layer(
input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Reshape input tensors at the last dimension to split multi-heads
and then transpose. Specifically, transform the input tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] to the output tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped_q = layers.reshape(
x=queries, shape=[0, 0, n_head, d_key], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
# For encoder-decoder attention in inference, insert the ops and vars
# into global block to use as cache among beam search.
reshape_layer = wrap_layer_with_block(
layers.reshape,
fluid.default_main_program().current_block()
.parent_idx) if cache is not None and static_kv else layers.reshape
transpose_layer = wrap_layer_with_block(
layers.transpose,
fluid.default_main_program().current_block().
parent_idx) if cache is not None and static_kv else layers.transpose
reshaped_k = reshape_layer(
x=keys, shape=[0, 0, n_head, d_key], inplace=True)
k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = reshape_layer(
x=values, shape=[0, 0, n_head, d_value], inplace=True)
v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3])
if cache is not None: # only for faster inference
if static_kv: # For encoder-decoder attention in inference
cache_k, cache_v = cache["static_k"], cache["static_v"]
# To init the static_k and static_v in cache.
# Maybe we can use condition_op(if_else) to do these at the first
# step in while loop to replace these, however it might be less
# efficient.
static_cache_init = wrap_layer_with_block(
layers.assign,
fluid.default_main_program().current_block().parent_idx)
static_cache_init(k, cache_k)
static_cache_init(v, cache_v)
else: # For decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
# gather cell states corresponding to selected parent
select_k = layers.gather(cache_k, index=gather_idx)
select_v = layers.gather(cache_v, index=gather_idx)
if not static_kv:
# For self attention in inference, use cache and concat time steps.
select_k = layers.concat([select_k, k], axis=2)
select_v = layers.concat([select_v, v], axis=2)
# update cell states(caches) cached in global block
layers.assign(select_k, cache_k)
layers.assign(select_v, cache_v)
return q, select_k, select_v
return q, k, v
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
# print(q)
# print(k)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
return proj_out
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False)
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=dropout_seed,
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def prepare_encoder(src_word,#[b,t,c]
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
bos_idx=0,
word_emb_param_name=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb =src_word#layers.concat(res,axis=1)
src_word_emb=layers.cast(src_word_emb,'float32')
# print("src_word_emb",src_word_emb)
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
is_test=False) if dropout_rate else enc_input
def prepare_decoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
bos_idx=0,
word_emb_param_name=None,
pos_enc_param_name=None):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
# print("target_word_emb",src_word_emb)
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name, trainable=False))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
is_test=False) if dropout_rate else enc_input
# prepare_encoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1])
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output = multi_head_attention(
pre_process_layer(enc_input, preprocess_cmd,
prepostprocess_dropout), None, None, attn_bias, d_key,
d_value, d_model, n_head, attention_dropout)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
enc_input = enc_output
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None,
gather_idx=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx)
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
enc_attn_output = multi_head_attention(
pre_process_layer(slf_attn_output, preprocess_cmd,
prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx,
static_kv=True)
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
ffd_output = positionwise_feed_forward(
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model,
relu_dropout, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout, )
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None,
gather_idx=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i],
gather_idx=gather_idx)
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[input_descs[input_field][0] for input_field in input_fields],
dtypes=[input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0
for input_field in input_fields
])
return layers.read_file(reader), reader
def transformer(src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
label_smooth_eps,
bos_idx=0,
use_py_reader=False,
is_test=False):
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names,
is_test)
else:
all_inputs = make_all_inputs(data_input_names)
# print("all inputs",all_inputs)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
label = all_inputs[-2]
weights = all_inputs[-1]
enc_output = wrap_encoder(
src_vocab_size,
ModelHyperParams.src_seq_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs,
enc_output, )
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if label_smooth_eps:
label = layers.label_smooth(
label=layers.one_hot(
input=label, depth=trg_vocab_size),
epsilon=label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None
def wrap_encoder_forFeature(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the encoder.
img, src_pos, src_slf_attn_bias = enc_inputs
img
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
conv_features, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
conv_features, src_pos, src_slf_attn_bias = enc_inputs#
b,t,c = conv_features.shape
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input = prepare_encoder(
conv_features,
src_pos,
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
return enc_output
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the encoder.
img, src_pos, src_slf_attn_bias = enc_inputs
img
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias = enc_inputs#
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input = prepare_decoder(
src_word,
src_pos,
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name=word_emb_param_names[0])
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=None,
enc_output=None,
caches=None,
gather_idx=None,
bos_idx=0):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
make_all_inputs(decoder_data_input_fields)
else:
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
bos_idx=bos_idx,
word_emb_param_name=word_emb_param_names[0]
if weight_sharing else word_emb_param_names[1])
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches,
gather_idx=gather_idx)
return dec_output
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
if weight_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
word_emb_param_names[0]),
transpose_y=True)
else:
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False)
if dec_inputs is None:
# Return probs for independent decoder program.
predict = layers.softmax(predict)
return predict
def fast_decode(src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
bos_idx,
eos_idx,
use_py_reader=False):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]#enc_inputs tensor
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]#dec_inputs tensor
enc_output = wrap_encoder(
src_vocab_size,
ModelHyperParams.src_seq_len,##to do !!!!!????
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
enc_inputs,
bos_idx=bos_idx)
start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs
def beam_search():
max_len = layers.fill_constant(
shape=[1],
dtype=start_tokens.dtype,
value=max_out_len,
force_cpu=True)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True)
cond = layers.less_than(x=step_idx, y=max_len) # default force_cpu=True
while_op = layers.While(cond)
# array states will be stored for each step.
ids = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states will be overwrited at each step.
# caches contains states of history steps in decoder self-attention
# and static encoder output projections in encoder-decoder attention
# to reduce redundant computation.
caches = [
{
"k": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_key],
dtype=enc_output.dtype,
value=0),
"v": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_value],
dtype=enc_output.dtype,
value=0),
"static_k": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype),
"static_v": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype)
} for i in range(n_layer)
]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
# Since beam_search_op dosen't enforce pre_ids' shape, we can do
# inplace reshape here which actually change the shape of pre_ids.
pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
pre_scores = layers.array_read(array=scores, i=step_idx)
# gather cell states corresponding to selected parent
pre_src_attn_bias = layers.gather(
trg_src_attn_bias, index=parent_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_src_attn_bias, # cann't use lod tensor here
value=1,
shape=[-1, 1, 1],
dtype=pre_ids.dtype),
y=step_idx,
axis=0)
logits = wrap_decoder(
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=enc_output,
caches=caches,
gather_idx=parent_idx,
bos_idx=bos_idx)
# intra-beam topK
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches.
accu_scores = layers.lod_reset(accu_scores, pre_ids)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx,
return_parent_idx=True)
layers.increment(x=step_idx, value=1.0, in_place=True)
# cell states(caches) have been updated in wrap_decoder,
# only need to update beam search states here.
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(gather_idx, parent_idx)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=beam_size, end_id=eos_idx)
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores, reader if use_py_reader else None
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
class SRNLoss(object):
def __init__(self, params):
super(SRNLoss, self).__init__()
self.char_num = params['char_num']
def __call__(self, predicts, others):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = others['label']
lbl_weight = others['lbl_weight']
casted_label = fluid.layers.cast(x=label, dtype='int64')
cost_word = fluid.layers.cross_entropy(input=word_predict, label=casted_label)
cost_gsrm = fluid.layers.cross_entropy(input=gsrm_predict, label=casted_label)
cost_vsfd = fluid.layers.cross_entropy(input=predict, label=casted_label)
#cost_word = cost_word * lbl_weight
#cost_gsrm = cost_gsrm * lbl_weight
#cost_vsfd = cost_vsfd * lbl_weight
cost_word = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_word), shape=[1])
cost_gsrm = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_gsrm), shape=[1])
cost_vsfd = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_vsfd), shape=[1])
sum_cost = fluid.layers.sum([cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
#sum_cost = fluid.layers.sum([cost_word * 3.0, cost_vsfd, cost_gsrm * 0.15])
#sum_cost = cost_word
#fluid.layers.Print(cost_word,message="word_cost")
#fluid.layers.Print(cost_vsfd,message="img_cost")
return [sum_cost,cost_vsfd,cost_word]
#return [sum_cost, cost_vsfd, cost_word]
......@@ -25,6 +25,7 @@ class CharacterOps(object):
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
......@@ -54,6 +55,8 @@ class CharacterOps(object):
self.end_str = "eos"
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
......@@ -146,6 +149,48 @@ def cal_predicts_accuracy(char_ops,
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def cal_predicts_accuracy_srn(char_ops,
preds,
labels,
max_text_len,
is_debug=False):
acc_num = 0
img_num = 0
total_len = preds.shape[0]
img_num = int(total_len / max_text_len)
#print (img_num)
for i in range(img_num):
cur_label = []
cur_pred = []
for j in range(max_text_len):
if labels[j + i * max_text_len] != 37: #0
cur_label.append(labels[j + i * max_text_len][0])
else:
break
if is_debug:
for j in range(max_text_len):
if preds[j + i * max_text_len] != 37: #0
cur_pred.append(preds[j + i * max_text_len][0])
else:
break
print ("cur_label: ", cur_label)
print ("cur_pred: ", cur_pred)
for j in range(max_text_len + 1):
if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]:
break
elif j == len(cur_label) and j == max_text_len:
acc_num += 1
break
elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
acc_num += 1
break
acc = acc_num * 1.0 / img_num
return acc, acc_num, img_num
def convert_rec_attention_infer_res(preds):
img_num = preds.shape[0]
......
......@@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
from ppocr.utils.character import cal_predicts_accuracy
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
from ppocr.utils.character import convert_rec_label_to_lod
from ppocr.utils.character import convert_rec_attention_infer_res
from ppocr.utils.utility import create_module
......@@ -60,19 +60,52 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
for ino in range(img_num):
img_list.append(data[ino][0])
label_list.append(data[ino][1])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
if config['Global']['loss_type'] != "srn":
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
preds = np.array(outs[0])
if preds.shape[1] != 1:
preds, preds_lod = convert_rec_attention_infer_res(preds)
preds = np.array(outs[0])
if preds.shape[1] != 1:
preds, preds_lod = convert_rec_attention_infer_res(preds)
else:
preds_lod = outs[0].lod()[0]
labels, labels_lod = convert_rec_label_to_lod(label_list)
acc, acc_num, sample_num = cal_predicts_accuracy(
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
else:
preds_lod = outs[0].lod()[0]
labels, labels_lod = convert_rec_label_to_lod(label_list)
acc, acc_num, sample_num = cal_predicts_accuracy(
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
for ino in range(img_num):
encoder_word_pos_list.append(data[ino][2])
gsrm_word_pos_list.append(data[ino][3])
gsrm_slf_attn_bias1_list.append(data[ino][4])
gsrm_slf_attn_bias2_list.append(data[ino][5])
img_list = np.concatenate(img_list, axis=0)
label_list = np.concatenate(label_list, axis=0)
encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
labels = label_list
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list,
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
preds = np.array(outs[0])
acc, acc_num, sample_num = cal_predicts_accuracy_srn(
char_ops, preds, labels, config['Global']['max_text_length'])
total_acc_num += acc_num
total_sample_num += sample_num
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
......@@ -85,8 +118,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict):
" Evaluate lmdb dataset "
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', \
'IC13_857', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir']
total_evaluation_data_number = 0
total_correct_number = 0
......
......@@ -32,7 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run
from ppocr.utils.save_load import save_model
import numpy as np
from ppocr.utils.character import cal_predicts_accuracy, CharacterOps
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
class ArgsParser(ArgumentParser):
def __init__(self):
......@@ -176,8 +176,16 @@ def build(config, main_prog, startup_prog, mode):
fetch_name_list = list(outputs.keys())
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
opt_loss_name = None
model_average = None
img_loss_name = None
word_loss_name = None
if mode == "train":
opt_loss = outputs['total_loss']
# srn loss
#img_loss = outputs['img_loss']
#word_loss = outputs['word_loss']
#img_loss_name = img_loss.name
#word_loss_name = word_loss.name
opt_params = config['Optimizer']
optimizer = create_module(opt_params['function'])(opt_params)
optimizer.minimize(opt_loss)
......@@ -185,7 +193,13 @@ def build(config, main_prog, startup_prog, mode):
global_lr = optimizer._global_learning_rate()
fetch_name_list.insert(0, "lr")
fetch_varname_list.insert(0, global_lr.name)
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
if config['Global']["loss_type"] == 'srn':
model_average = fluid.optimizer.ModelAverage(
config['Global']['average_window'],
min_average_window=config['Global']['min_average_window'],
max_average_window=config['Global']['max_average_window'])
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average)
def build_export(config, main_prog, startup_prog):
......@@ -329,14 +343,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
preds_idx = fetch_map['decoded_out']
preds = np.array(train_outs[preds_idx])
preds_lod = train_outs[preds_idx].lod()[0]
labels_idx = fetch_map['label']
labels = np.array(train_outs[labels_idx])
labels_lod = train_outs[labels_idx].lod()[0]
acc, acc_num, img_num = cal_predicts_accuracy(
config['Global']['char_ops'], preds, preds_lod, labels,
labels_lod)
if config['Global']['loss_type'] != 'srn':
preds_lod = train_outs[preds_idx].lod()[0]
labels_lod = train_outs[labels_idx].lod()[0]
acc, acc_num, img_num = cal_predicts_accuracy(
config['Global']['char_ops'], preds, preds_lod, labels,
labels_lod)
else:
acc, acc_num, img_num = cal_predicts_accuracy_srn(
config['Global']['char_ops'], preds, labels,
config['Global']['max_text_length'])
t2 = time.time()
train_batch_elapse = t2 - t1
stats = {'loss': loss, 'acc': acc}
......@@ -350,6 +370,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
if train_batch_id > 0 and\
train_batch_id % eval_batch_step == 0:
model_average = train_info_dict['model_average']
if model_average != None:
model_average.apply(exe)
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
eval_acc = metrics['avg_acc']
eval_sample_num = metrics['total_sample_num']
......
......@@ -52,6 +52,7 @@ def main():
train_fetch_name_list = train_build_outputs[1]
train_fetch_varname_list = train_build_outputs[2]
train_opt_loss_name = train_build_outputs[3]
model_average = train_build_outputs[-1]
eval_program = fluid.Program()
eval_build_outputs = program.build(
......@@ -85,7 +86,8 @@ def main():
'train_program':train_program,\
'reader':train_loader,\
'fetch_name_list':train_fetch_name_list,\
'fetch_varname_list':train_fetch_varname_list}
'fetch_varname_list':train_fetch_varname_list,\
'model_average': model_average}
eval_info_dict = {'program':eval_program,\
'reader':eval_reader,\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册