未验证 提交 30d908b6 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #4316 from Topdu/release/2.3

pick fix nrtr export inference model from drgraph to release/2.3
...@@ -46,7 +46,7 @@ Architecture: ...@@ -46,7 +46,7 @@ Architecture:
name: Transformer name: Transformer
d_model: 512 d_model: 512
num_encoder_layers: 6 num_encoder_layers: 6
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation. beam_size: -1 # When Beam size is greater than 0, it means to use beam search when evaluation.
Loss: Loss:
...@@ -65,7 +65,7 @@ Train: ...@@ -65,7 +65,7 @@ Train:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
...@@ -85,7 +85,7 @@ Eval: ...@@ -85,7 +85,7 @@ Eval:
name: LMDBDataSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation/ data_dir: ./train_data/data_lmdb_release/evaluation/
transforms: transforms:
- NRTRDecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- NRTRLabelEncode: # Class handling label - NRTRLabelEncode: # Class handling label
......
...@@ -174,21 +174,26 @@ class NRTRLabelEncode(BaseRecLabelEncode): ...@@ -174,21 +174,26 @@ class NRTRLabelEncode(BaseRecLabelEncode):
super(NRTRLabelEncode, super(NRTRLabelEncode,
self).__init__(max_text_length, character_dict_path, self).__init__(max_text_length, character_dict_path,
character_type, use_space_char) character_type, use_space_char)
def __call__(self, data): def __call__(self, data):
text = data['label'] text = data['label']
text = self.encode(text) text = self.encode(text)
if text is None: if text is None:
return None return None
if len(text) >= self.max_text_len - 1:
return None
data['length'] = np.array(len(text)) data['length'] = np.array(len(text))
text.insert(0, 2) text.insert(0, 2)
text.append(3) text.append(3)
text = text + [0] * (self.max_text_len - len(text)) text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text) data['label'] = np.array(text)
return data return data
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
return dict_character return dict_character
class CTCLabelEncode(BaseRecLabelEncode): class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -44,12 +44,33 @@ class ClsResizeImg(object): ...@@ -44,12 +44,33 @@ class ClsResizeImg(object):
class NRTRRecResizeImg(object): class NRTRRecResizeImg(object):
def __init__(self, image_shape, resize_type, **kwargs): def __init__(self, image_shape, resize_type, padding=False, **kwargs):
self.image_shape = image_shape self.image_shape = image_shape
self.resize_type = resize_type self.resize_type = resize_type
self.padding = padding
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
image_shape = self.image_shape
if self.padding:
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
norm_img = np.expand_dims(resized_image, -1)
norm_img = norm_img.transpose((2, 0, 1))
resized_image = norm_img.astype(np.float32) / 128. - 1.
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
data['image'] = padding_im
return data
if self.resize_type == 'PIL': if self.resize_type == 'PIL':
image_pil = Image.fromarray(np.uint8(img)) image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize(self.image_shape, Image.ANTIALIAS) img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
......
...@@ -15,7 +15,6 @@ import numpy as np ...@@ -15,7 +15,6 @@ import numpy as np
import os import os
import random import random
from paddle.io import Dataset from paddle.io import Dataset
from .imaug import transform, create_operators from .imaug import transform, create_operators
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from paddle import nn from paddle import nn
import paddle
class MTB(nn.Layer): class MTB(nn.Layer):
...@@ -40,7 +41,8 @@ class MTB(nn.Layer): ...@@ -40,7 +41,8 @@ class MTB(nn.Layer):
x = self.block(images) x = self.block(images)
if self.cnn_num == 2: if self.cnn_num == 2:
# (b, w, h, c) # (b, w, h, c)
x = x.transpose([0, 3, 2, 1]) x = paddle.transpose(x, [0, 3, 2, 1])
x_shape = x.shape x_shape = paddle.shape(x)
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]]) x = paddle.reshape(
x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
return x return x
...@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer): ...@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer):
value, value,
key_padding_mask=None, key_padding_mask=None,
incremental_state=None, incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None): attn_mask=None):
""" """
Inputs of forward function Inputs of forward function
...@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer): ...@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer):
attn_output: [target length, batch size, embed dim] attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length] attn_output_weights: [batch size, target length, sequence length]
""" """
tgt_len, bsz, embed_dim = query.shape q_shape = paddle.shape(query)
assert embed_dim == self.embed_dim src_shape = paddle.shape(key)
assert list(query.shape) == [tgt_len, bsz, embed_dim]
assert key.shape == value.shape
q = self._in_proj_q(query) q = self._in_proj_q(query)
k = self._in_proj_k(key) k = self._in_proj_k(key)
v = self._in_proj_v(value) v = self._in_proj_v(value)
q *= self.scaling q *= self.scaling
q = paddle.transpose(
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose( paddle.reshape(
[1, 0, 2]) q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( [1, 2, 0, 3])
[1, 0, 2]) k = paddle.transpose(
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose( paddle.reshape(
[1, 0, 2]) k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
src_len = k.shape[1] v = paddle.transpose(
paddle.reshape(
v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
[1, 2, 0, 3])
if key_padding_mask is not None: if key_padding_mask is not None:
assert key_padding_mask.shape[0] == bsz assert key_padding_mask.shape[0] == q_shape[1]
assert key_padding_mask.shape[1] == src_len assert key_padding_mask.shape[1] == src_shape[0]
attn_output_weights = paddle.matmul(q,
attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1])) paddle.transpose(k, [0, 1, 3, 2]))
assert list(attn_output_weights.
shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None: if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0) attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
attn_output_weights += attn_mask attn_output_weights += attn_mask
if key_padding_mask is not None: if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape( attn_output_weights = paddle.reshape(
[bsz, self.num_heads, tgt_len, src_len]) attn_output_weights,
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32') [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf') key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
key = paddle.cast(key, 'float32')
y = paddle.full(
shape=paddle.shape(key), dtype='float32', fill_value='-inf')
y = paddle.where(key == 0., key, y) y = paddle.where(key == 0., key, y)
attn_output_weights += y attn_output_weights += y
attn_output_weights = attn_output_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax( attn_output_weights = F.softmax(
attn_output_weights.astype('float32'), attn_output_weights.astype('float32'),
axis=-1, axis=-1,
...@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer): ...@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer):
attn_output_weights = F.dropout( attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training) attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v) attn_output = paddle.matmul(attn_output_weights, v)
assert list(attn_output. attn_output = paddle.reshape(
shape) == [bsz * self.num_heads, tgt_len, self.head_dim] paddle.transpose(attn_output, [2, 0, 1, 3]),
attn_output = attn_output.transpose([1, 0, 2]).reshape( [q_shape[0], q_shape[1], self.embed_dim])
[tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
if need_weights: return attn_output
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(
axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_q(self, query): def _in_proj_q(self, query):
query = query.transpose([1, 2, 0]) query = paddle.transpose(query, [1, 2, 0])
query = paddle.unsqueeze(query, axis=2) query = paddle.unsqueeze(query, axis=2)
res = self.conv1(query) res = self.conv1(query)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
def _in_proj_k(self, key): def _in_proj_k(self, key):
key = key.transpose([1, 2, 0]) key = paddle.transpose(key, [1, 2, 0])
key = paddle.unsqueeze(key, axis=2) key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key) res = self.conv2(key)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
def _in_proj_v(self, value): def _in_proj_v(self, value):
value = value.transpose([1, 2, 0]) #(1, 2, 0) value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2) value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value) res = self.conv3(value)
res = paddle.squeeze(res, axis=2) res = paddle.squeeze(res, axis=2)
res = res.transpose([2, 0, 1]) res = paddle.transpose(res, [2, 0, 1])
return res return res
...@@ -61,12 +61,12 @@ class Transformer(nn.Layer): ...@@ -61,12 +61,12 @@ class Transformer(nn.Layer):
custom_decoder=None, custom_decoder=None,
in_channels=0, in_channels=0,
out_channels=0, out_channels=0,
dst_vocab_size=99,
scale_embedding=True): scale_embedding=True):
super(Transformer, self).__init__() super(Transformer, self).__init__()
self.out_channels = out_channels + 1
self.embedding = Embeddings( self.embedding = Embeddings(
d_model=d_model, d_model=d_model,
vocab=dst_vocab_size, vocab=self.out_channels,
padding_idx=0, padding_idx=0,
scale_embedding=scale_embedding) scale_embedding=scale_embedding)
self.positional_encoding = PositionalEncoding( self.positional_encoding = PositionalEncoding(
...@@ -96,9 +96,10 @@ class Transformer(nn.Layer): ...@@ -96,9 +96,10 @@ class Transformer(nn.Layer):
self.beam_size = beam_size self.beam_size = beam_size
self.d_model = d_model self.d_model = d_model
self.nhead = nhead self.nhead = nhead
self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False) self.tgt_word_prj = nn.Linear(
d_model, self.out_channels, bias_attr=False)
w0 = np.random.normal(0.0, d_model**-0.5, w0 = np.random.normal(0.0, d_model**-0.5,
(d_model, dst_vocab_size)).astype(np.float32) (d_model, self.out_channels)).astype(np.float32)
self.tgt_word_prj.weight.set_value(w0) self.tgt_word_prj.weight.set_value(w0)
self.apply(self._init_weights) self.apply(self._init_weights)
...@@ -156,46 +157,41 @@ class Transformer(nn.Layer): ...@@ -156,46 +157,41 @@ class Transformer(nn.Layer):
return self.forward_test(src) return self.forward_test(src)
def forward_test(self, src): def forward_test(self, src):
bs = src.shape[0] bs = paddle.shape(src)[0]
if self.encoder is not None: if self.encoder is not None:
src = self.positional_encoding(src.transpose([1, 0, 2])) src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
memory = self.encoder(src) memory = self.encoder(src)
else: else:
memory = src.squeeze(2).transpose([2, 0, 1]) memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64) dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
for len_dec_seq in range(1, 25): for len_dec_seq in range(1, 25):
src_enc = memory.clone() dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
dec_seq_embed = self.positional_encoding(dec_seq_embed) dec_seq_embed = self.positional_encoding(dec_seq_embed)
tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[ tgt_mask = self.generate_square_subsequent_mask(
0]) paddle.shape(dec_seq_embed)[0])
output = self.decoder( output = self.decoder(
dec_seq_embed, dec_seq_embed,
src_enc, memory,
tgt_mask=tgt_mask, tgt_mask=tgt_mask,
memory_mask=None, memory_mask=None,
tgt_key_padding_mask=tgt_key_padding_mask, tgt_key_padding_mask=None,
memory_key_padding_mask=None) memory_key_padding_mask=None)
dec_output = output.transpose([1, 0, 2]) dec_output = paddle.transpose(output, [1, 0, 2])
dec_output = dec_output[:, -1, :]
dec_output = dec_output[:, word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
-1, :] # Pick the last step: (bh * bm) * d_h preds_idx = paddle.argmax(word_prob, axis=1)
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
word_prob = word_prob.reshape([1, bs, -1])
preds_idx = word_prob.argmax(axis=2)
if paddle.equal_all( if paddle.equal_all(
preds_idx[-1], preds_idx,
paddle.full( paddle.full(
preds_idx[-1].shape, 3, dtype='int64')): paddle.shape(preds_idx), 3, dtype='int64')):
break break
preds_prob = paddle.max(word_prob, axis=1)
preds_prob = word_prob.max(axis=2)
dec_seq = paddle.concat( dec_seq = paddle.concat(
[dec_seq, preds_idx.reshape([-1, 1])], axis=1) [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
dec_prob = paddle.concat(
return dec_seq [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
return [dec_seq, dec_prob]
def forward_beam(self, images): def forward_beam(self, images):
''' Translation work in one batch ''' ''' Translation work in one batch '''
...@@ -211,14 +207,15 @@ class Transformer(nn.Layer): ...@@ -211,14 +207,15 @@ class Transformer(nn.Layer):
n_prev_active_inst, n_bm): n_prev_active_inst, n_bm):
''' Collect tensor parts associated to active instances. ''' ''' Collect tensor parts associated to active instances. '''
_, *d_hs = beamed_tensor.shape beamed_tensor_shape = paddle.shape(beamed_tensor)
n_curr_active_inst = len(curr_active_inst_idx) n_curr_active_inst = len(curr_active_inst_idx)
new_shape = (n_curr_active_inst * n_bm, *d_hs) new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
beamed_tensor_shape[2])
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1]) beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
beamed_tensor = beamed_tensor.index_select( beamed_tensor = beamed_tensor.index_select(
paddle.to_tensor(curr_active_inst_idx), axis=0) curr_active_inst_idx, axis=0)
beamed_tensor = beamed_tensor.reshape([*new_shape]) beamed_tensor = beamed_tensor.reshape(new_shape)
return beamed_tensor return beamed_tensor
...@@ -249,44 +246,26 @@ class Transformer(nn.Layer): ...@@ -249,44 +246,26 @@ class Transformer(nn.Layer):
b.get_current_state() for b in inst_dec_beams if not b.done b.get_current_state() for b in inst_dec_beams if not b.done
] ]
dec_partial_seq = paddle.stack(dec_partial_seq) dec_partial_seq = paddle.stack(dec_partial_seq)
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq]) dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
return dec_partial_seq return dec_partial_seq
def prepare_beam_memory_key_padding_mask(
inst_dec_beams, memory_key_padding_mask, n_bm):
keep = []
for idx in (memory_key_padding_mask):
if not inst_dec_beams[idx].done:
keep.append(idx)
memory_key_padding_mask = memory_key_padding_mask[
paddle.to_tensor(keep)]
len_s = memory_key_padding_mask.shape[-1]
n_inst = memory_key_padding_mask.shape[0]
memory_key_padding_mask = paddle.concat(
[memory_key_padding_mask for i in range(n_bm)], axis=1)
memory_key_padding_mask = memory_key_padding_mask.reshape(
[n_inst * n_bm, len_s]) #repeat(1, n_bm)
return memory_key_padding_mask
def predict_word(dec_seq, enc_output, n_active_inst, n_bm, def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
memory_key_padding_mask): memory_key_padding_mask):
tgt_key_padding_mask = self.generate_padding_mask(dec_seq) dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
dec_seq = self.positional_encoding(dec_seq) dec_seq = self.positional_encoding(dec_seq)
tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[ tgt_mask = self.generate_square_subsequent_mask(
0]) paddle.shape(dec_seq)[0])
dec_output = self.decoder( dec_output = self.decoder(
dec_seq, dec_seq,
enc_output, enc_output,
tgt_mask=tgt_mask, tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask, tgt_key_padding_mask=None,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, )
).transpose([1, 0, 2]) dec_output = paddle.transpose(dec_output, [1, 0, 2])
dec_output = dec_output[:, dec_output = dec_output[:,
-1, :] # Pick the last step: (bh * bm) * d_h -1, :] # Pick the last step: (bh * bm) * d_h
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1) word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
word_prob = word_prob.reshape([n_active_inst, n_bm, -1]) word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
return word_prob return word_prob
def collect_active_inst_idx_list(inst_beams, word_prob, def collect_active_inst_idx_list(inst_beams, word_prob,
...@@ -302,9 +281,8 @@ class Transformer(nn.Layer): ...@@ -302,9 +281,8 @@ class Transformer(nn.Layer):
n_active_inst = len(inst_idx_to_position_map) n_active_inst = len(inst_idx_to_position_map)
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
memory_key_padding_mask = None
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm, word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
memory_key_padding_mask) None)
# Update the beam with predicted word prob information and collect incomplete instances # Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list = collect_active_inst_idx_list( active_inst_idx_list = collect_active_inst_idx_list(
inst_dec_beams, word_prob, inst_idx_to_position_map) inst_dec_beams, word_prob, inst_idx_to_position_map)
...@@ -324,27 +302,21 @@ class Transformer(nn.Layer): ...@@ -324,27 +302,21 @@ class Transformer(nn.Layer):
with paddle.no_grad(): with paddle.no_grad():
#-- Encode #-- Encode
if self.encoder is not None: if self.encoder is not None:
src = self.positional_encoding(images.transpose([1, 0, 2])) src = self.positional_encoding(images.transpose([1, 0, 2]))
src_enc = self.encoder(src).transpose([1, 0, 2]) src_enc = self.encoder(src)
else: else:
src_enc = images.squeeze(2).transpose([0, 2, 1]) src_enc = images.squeeze(2).transpose([0, 2, 1])
#-- Repeat data for beam search
n_bm = self.beam_size n_bm = self.beam_size
n_inst, len_s, d_h = src_enc.shape src_shape = paddle.shape(src_enc)
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1) inst_dec_beams = [Beam(n_bm) for _ in range(1)]
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose( active_inst_idx_list = list(range(1))
[1, 0, 2]) # Repeat data for beam search
#-- Prepare beams src_enc = paddle.tile(src_enc, [1, n_bm, 1])
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
#-- Bookkeeping for active or not
active_inst_idx_list = list(range(n_inst))
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map( inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
active_inst_idx_list) active_inst_idx_list)
#-- Decode # Decode
for len_dec_seq in range(1, 25): for len_dec_seq in range(1, 25):
src_enc_copy = src_enc.clone() src_enc_copy = src_enc.clone()
active_inst_idx_list = beam_decode_step( active_inst_idx_list = beam_decode_step(
...@@ -358,10 +330,19 @@ class Transformer(nn.Layer): ...@@ -358,10 +330,19 @@ class Transformer(nn.Layer):
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
1) 1)
result_hyp = [] result_hyp = []
for bs_hyp in batch_hyp: hyp_scores = []
bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0])) for bs_hyp, score in zip(batch_hyp, batch_scores):
l = len(bs_hyp[0])
bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
result_hyp.append(bs_hyp_pad) result_hyp.append(bs_hyp_pad)
return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64) score = float(score) / l
hyp_score = [score for _ in range(25)]
hyp_scores.append(hyp_score)
return [
paddle.to_tensor(
np.array(result_hyp), dtype=paddle.int64),
paddle.to_tensor(hyp_scores)
]
def generate_square_subsequent_mask(self, sz): def generate_square_subsequent_mask(self, sz):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
...@@ -376,7 +357,7 @@ class Transformer(nn.Layer): ...@@ -376,7 +357,7 @@ class Transformer(nn.Layer):
return mask return mask
def generate_padding_mask(self, x): def generate_padding_mask(self, x):
padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype)) padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
return padding_mask return padding_mask
def _reset_parameters(self): def _reset_parameters(self):
...@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer): ...@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer):
src, src,
src, src,
attn_mask=src_mask, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0] key_padding_mask=src_key_padding_mask)
src = src + self.dropout1(src2) src = src + self.dropout1(src2)
src = self.norm1(src) src = self.norm1(src)
src = src.transpose([1, 2, 0]) src = paddle.transpose(src, [1, 2, 0])
src = paddle.unsqueeze(src, 2) src = paddle.unsqueeze(src, 2)
src2 = self.conv2(F.relu(self.conv1(src))) src2 = self.conv2(F.relu(self.conv1(src)))
src2 = paddle.squeeze(src2, 2) src2 = paddle.squeeze(src2, 2)
src2 = src2.transpose([2, 0, 1]) src2 = paddle.transpose(src2, [2, 0, 1])
src = paddle.squeeze(src, 2) src = paddle.squeeze(src, 2)
src = src.transpose([2, 0, 1]) src = paddle.transpose(src, [2, 0, 1])
src = src + self.dropout2(src2) src = src + self.dropout2(src2)
src = self.norm2(src) src = self.norm2(src)
...@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer):
tgt, tgt,
tgt, tgt,
attn_mask=tgt_mask, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0] key_padding_mask=tgt_key_padding_mask)
tgt = tgt + self.dropout1(tgt2) tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt) tgt = self.norm1(tgt)
tgt2 = self.multihead_attn( tgt2 = self.multihead_attn(
...@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer): ...@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer):
memory, memory,
memory, memory,
attn_mask=memory_mask, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0] key_padding_mask=memory_key_padding_mask)
tgt = tgt + self.dropout2(tgt2) tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt) tgt = self.norm2(tgt)
# default # default
tgt = tgt.transpose([1, 2, 0]) tgt = paddle.transpose(tgt, [1, 2, 0])
tgt = paddle.unsqueeze(tgt, 2) tgt = paddle.unsqueeze(tgt, 2)
tgt2 = self.conv2(F.relu(self.conv1(tgt))) tgt2 = self.conv2(F.relu(self.conv1(tgt)))
tgt2 = paddle.squeeze(tgt2, 2) tgt2 = paddle.squeeze(tgt2, 2)
tgt2 = tgt2.transpose([2, 0, 1]) tgt2 = paddle.transpose(tgt2, [2, 0, 1])
tgt = paddle.squeeze(tgt, 2) tgt = paddle.squeeze(tgt, 2)
tgt = tgt.transpose([2, 0, 1]) tgt = paddle.transpose(tgt, [2, 0, 1])
tgt = tgt + self.dropout3(tgt2) tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt) tgt = self.norm3(tgt)
...@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer): ...@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer):
(-math.log(10000.0) / dim)) (-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0) pe = paddle.unsqueeze(pe, 0)
pe = pe.transpose([1, 0, 2]) pe = paddle.transpose(pe, [1, 0, 2])
self.register_buffer('pe', pe) self.register_buffer('pe', pe)
def forward(self, x): def forward(self, x):
...@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer): ...@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer):
Examples: Examples:
>>> output = pos_encoder(x) >>> output = pos_encoder(x)
""" """
x = x + self.pe[:x.shape[0], :] x = x + self.pe[:paddle.shape(x)[0], :]
return self.dropout(x) return self.dropout(x)
...@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer): ...@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer):
(-math.log(10000.0) / dim)) (-math.log(10000.0) / dim))
pe[:, 0::2] = paddle.sin(position * div_term) pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term) pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0).transpose([1, 0, 2]) pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
self.register_buffer('pe', pe) self.register_buffer('pe', pe)
self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1)) self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
...@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer): ...@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer):
Examples: Examples:
>>> output = pos_encoder(x) >>> output = pos_encoder(x)
""" """
w_pe = self.pe[:x.shape[-1], :] w_pe = self.pe[:paddle.shape(x)[-1], :]
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0) w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
w_pe = w_pe * w1 w_pe = w_pe * w1
w_pe = w_pe.transpose([1, 2, 0]) w_pe = paddle.transpose(w_pe, [1, 2, 0])
w_pe = w_pe.unsqueeze(2) w_pe = paddle.unsqueeze(w_pe, 2)
h_pe = self.pe[:x.shape[-2], :] h_pe = self.pe[:paddle.shape(x).shape[-2], :]
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0) w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
h_pe = h_pe * w2 h_pe = h_pe * w2
h_pe = h_pe.transpose([1, 2, 0]) h_pe = paddle.transpose(h_pe, [1, 2, 0])
h_pe = h_pe.unsqueeze(3) h_pe = paddle.unsqueeze(h_pe, 3)
x = x + w_pe + h_pe x = x + w_pe + h_pe
x = x.reshape( x = paddle.transpose(
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose( paddle.reshape(x,
[2, 0, 1]) [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
[2, 0, 1])
return self.dropout(x) return self.dropout(x)
...@@ -817,7 +799,7 @@ class Beam(): ...@@ -817,7 +799,7 @@ class Beam():
def sort_scores(self): def sort_scores(self):
"Sort the scores." "Sort the scores."
return self.scores, paddle.to_tensor( return self.scores, paddle.to_tensor(
[i for i in range(self.scores.shape[0])], dtype='int32') [i for i in range(int(self.scores.shape[0]))], dtype='int32')
def get_the_best_score_and_idx(self): def get_the_best_score_and_idx(self):
"Get the score of the best in the beam." "Get the score of the best in the beam."
......
...@@ -176,7 +176,19 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -176,7 +176,19 @@ class NRTRLabelDecode(BaseRecLabelDecode):
else: else:
preds_idx = preds preds_idx = preds
text = self.decode(preds_idx) if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]
if isinstance(preds_id, paddle.Tensor):
preds_id = preds_id.numpy()
if isinstance(preds_prob, paddle.Tensor):
preds_prob = preds_prob.numpy()
if preds_id[0][0] == 2:
preds_idx = preds_id[:, 1:]
preds_prob = preds_prob[:, 1:]
else:
preds_idx = preds_id
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None: if label is None:
return text return text
label = self.decode(label[:,1:]) label = self.decode(label[:,1:])
......
...@@ -60,6 +60,8 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -60,6 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
) )
infer_shape[-1] = 100 infer_shape[-1] = 100
if arch_config["algorithm"] == "NRTR":
infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table": elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488] infer_shape = [3, 488, 488]
model = to_static( model = to_static(
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import sys import sys
from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
...@@ -61,6 +61,13 @@ class TextRecognizer(object): ...@@ -61,6 +61,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
...@@ -87,6 +94,16 @@ class TextRecognizer(object): ...@@ -87,6 +94,16 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == 'NRTR':
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# return padding_im
image_pil = Image.fromarray(np.uint8(img))
img = image_pil.resize([100, 32], Image.ANTIALIAS)
img = np.array(img)
norm_img = np.expand_dims(img, -1)
norm_img = norm_img.transpose((2, 0, 1))
return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2] assert imgC == img.shape[2]
max_wh_ratio = max(max_wh_ratio, imgW / imgH) max_wh_ratio = max(max_wh_ratio, imgW / imgH)
imgW = int((32 * max_wh_ratio)) imgW = int((32 * max_wh_ratio))
...@@ -252,14 +269,16 @@ class TextRecognizer(object): ...@@ -252,14 +269,16 @@ class TextRecognizer(object):
else: else:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run() self.predictor.run()
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.benchmark: if self.benchmark:
self.autolog.times.stamp() self.autolog.times.stamp()
preds = outputs[0] if len(outputs) != 1:
preds = outputs
else:
preds = outputs[0]
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册