Commit 2bf8ad9b authored by: T Topdu

modify transformeroptim, resize

Parent 73058cc0
@@ -43,7 +43,7 @@ Architecture:
     name: MTB
     cnn_num: 2
   Head:
-    name: TransformerOptim
+    name: Transformer
     d_model: 512
     num_encoder_layers: 6
     beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
@@ -69,8 +69,9 @@ Train:
           img_mode: BGR
           channel_first: False
       - NRTRLabelEncode: # Class handling label
-      - PILResize:
+      - NRTRRecResizeImg:
           image_shape: [100, 32]
+          resize_type: PIL # PIL or OpenCV
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
@@ -88,8 +89,9 @@ Eval:
           img_mode: BGR
           channel_first: False
       - NRTRLabelEncode: # Class handling label
-      - PILResize:
+      - NRTRRecResizeImg:
           image_shape: [100, 32]
+          resize_type: PIL # PIL or OpenCV
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
...
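For context, a minimal sketch of how a transform entry of the shape used in the Train/Eval sections above could be turned into an operator instance. The build_op helper is a simplified stand-in written for illustration (it is not PaddleOCR's actual operator factory) and it assumes the NRTRRecResizeImg class introduced by this commit is importable from ppocr.data.imaug.

# Sketch only: map one YAML transform entry to an operator instance.
import importlib

def build_op(op_entry, module='ppocr.data.imaug'):
    # op_entry looks like {'NRTRRecResizeImg': {'image_shape': [100, 32], 'resize_type': 'PIL'}}
    name, params = next(iter(op_entry.items()))
    cls = getattr(importlib.import_module(module), name)
    return cls(**(params or {}))

resize_op = build_op({'NRTRRecResizeImg': {'image_shape': [100, 32],
                                           'resize_type': 'PIL'}})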
@@ -21,7 +21,7 @@
 from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
 from .operators import *
...
@@ -42,30 +42,21 @@ class ClsResizeImg(object):
         data['image'] = norm_img
         return data


-class PILResize(object):
-    def __init__(self, image_shape, **kwargs):
-        self.image_shape = image_shape
-    def __call__(self, data):
-        img = data['image']
-        image_pil = Image.fromarray(np.uint8(img))
-        norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
-        norm_img = np.array(norm_img)
-        norm_img = np.expand_dims(norm_img, -1)
-        norm_img = norm_img.transpose((2, 0, 1))
-        data['image'] = norm_img.astype(np.float32) / 128. - 1.
-        return data
-
-
-class CVResize(object):
-    def __init__(self, image_shape, **kwargs):
+class NRTRRecResizeImg(object):
+    def __init__(self, image_shape, resize_type, **kwargs):
         self.image_shape = image_shape
+        self.resize_type = resize_type

     def __call__(self, data):
         img = data['image']
-        #print(img)
-        norm_img = cv2.resize(img,self.image_shape)
-        norm_img = np.expand_dims(norm_img, -1)
+        if self.resize_type == 'PIL':
+            image_pil = Image.fromarray(np.uint8(img))
+            img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+            img = np.array(img)
+        if self.resize_type == 'OpenCV':
+            img = cv2.resize(img, self.image_shape)
+        norm_img = np.expand_dims(img, -1)
         norm_img = norm_img.transpose((2, 0, 1))
         data['image'] = norm_img.astype(np.float32) / 128. - 1.
         return data
...
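A quick usage sketch of the merged resize op above, assuming the repo root is on PYTHONPATH, a single-channel input image (the code path expects one channel before expand_dims), and a Pillow version older than 10 (Image.ANTIALIAS was removed in Pillow 10). Both resize_type branches interpret image_shape as (width, height) = (100, 32) and yield a (1, 32, 100) float32 array scaled to roughly [-1, 1).

# Sketch: run the new NRTRRecResizeImg on a dummy grayscale crop.
import numpy as np
from ppocr.data.imaug.rec_img_aug import NRTRRecResizeImg

sample = {'image': np.random.randint(0, 256, (48, 160), dtype=np.uint8)}

for resize_type in ('PIL', 'OpenCV'):
    op = NRTRRecResizeImg(image_shape=[100, 32], resize_type=resize_type)
    out = op(dict(sample))                  # shallow copy; the op rebinds data['image']
    print(resize_type, out['image'].shape)  # -> (1, 32, 100)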
@@ -3,34 +3,26 @@
 from paddle import nn
 import paddle.nn.functional as F


-def cal_performance(pred, tgt):
-    pred = pred.max(1)[1]
-    tgt = tgt.contiguous().view(-1)
-    non_pad_mask = tgt.ne(0)
-    n_correct = pred.eq(tgt)
-    n_correct = n_correct.masked_select(non_pad_mask).sum().item()
-    return n_correct
-
-
 class NRTRLoss(nn.Layer):
-    def __init__(self,smoothing=True, **kwargs):
+    def __init__(self, smoothing=True, **kwargs):
         super(NRTRLoss, self).__init__()
-        self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
+        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
         self.smoothing = smoothing

     def forward(self, pred, batch):
         pred = pred.reshape([-1, pred.shape[2]])
         max_len = batch[2].max()
-        tgt = batch[1][:,1:2+max_len]
-        tgt = tgt.reshape([-1] )
+        tgt = batch[1][:, 1:2 + max_len]
+        tgt = tgt.reshape([-1])
         if self.smoothing:
             eps = 0.1
             n_class = pred.shape[1]
             one_hot = F.one_hot(tgt, pred.shape[1])
             one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
             log_prb = F.log_softmax(pred, axis=1)
-            non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
+            non_pad_mask = paddle.not_equal(
+                tgt, paddle.zeros(
+                    tgt.shape, dtype='int64'))
             loss = -(one_hot * log_prb).sum(axis=1)
             loss = loss.masked_select(non_pad_mask).mean()
         else:
...
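For reference, a small NumPy re-derivation of the smoothed branch of NRTRLoss.forward on a toy batch, just to make the math concrete: the one-hot target is softened with eps = 0.1, multiplied by the log-softmax of the logits, and padding positions (index 0) are masked out. This is an illustration of the formula, not the PaddleOCR implementation.

# Sketch: the label-smoothing loss from NRTRLoss.forward, redone in NumPy.
import numpy as np

eps, n_class = 0.1, 5
logits = np.random.randn(4, n_class).astype(np.float32)  # (batch*len, n_class)
tgt = np.array([3, 1, 0, 2])                              # 0 is the padding index

one_hot = np.eye(n_class)[tgt]
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)

log_prb = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))  # log_softmax
loss_per_token = -(one_hot * log_prb).sum(axis=1)

non_pad = tgt != 0
print(float(loss_per_token[non_pad].mean()))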
@@ -26,13 +26,13 @@ def build_head(config):
     from .rec_ctc_head import CTCHead
     from .rec_att_head import AttentionHead
     from .rec_srn_head import SRNHead
-    from .rec_nrtr_optim_head import TransformerOptim
+    from .rec_nrtr_head import Transformer

     # cls head
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
+        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
     ]

     #table head
...
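The renamed class is selected by the string in the YAML Head.name field, which is why the config, support_dict, and import all change together. A simplified sketch of that name-to-class dispatch pattern (a registry-based stand-in for illustration, not the exact build_head body):

# Sketch: dispatch a head config like {'name': 'Transformer', 'd_model': 512, ...}
def build_head_sketch(config, registry):
    config = dict(config)                 # avoid mutating the caller's config
    name = config.pop('name')
    assert name in registry, f'head {name} is not supported'
    return registry[name](**config)       # remaining keys become constructor kwargs

# Usage (assuming Transformer is importable and accepts these kwargs):
# from ppocr.modeling.heads.rec_nrtr_head import Transformer
# head = build_head_sketch({'name': 'Transformer', 'd_model': 512, 'beam_size': 10},
#                          registry={'Transformer': Transformer})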
@@ -24,7 +24,7 @@
 zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)

-class MultiheadAttentionOptim(nn.Layer):
+class MultiheadAttention(nn.Layer):
     """Allows the model to jointly attend to information
     from different representation subspaces.
     See reference: Attention Is All You Need
@@ -46,7 +46,7 @@ class MultiheadAttentionOptim(nn.Layer):
                  bias=True,
                  add_bias_kv=False,
                  add_zero_attn=False):
-        super(MultiheadAttentionOptim, self).__init__()
+        super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
...
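The rename does not change behavior: MultiheadAttention still wraps the standard attention from "Attention Is All You Need", as its docstring says. As a reminder of the core computation this layer is built around, here is a single-head scaled dot-product attention in NumPy (illustrative only, not this layer's code):

# Sketch: single-head scaled dot-product attention, softmax(Q K^T / sqrt(d)) V.
import numpy as np

def scaled_dot_product_attention(q, k, v):
    d = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)          # (batch, tgt_len, src_len)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                # softmax over source positions
    return weights @ v                                       # (batch, tgt_len, d)

q = np.random.randn(2, 5, 64)
k = v = np.random.randn(2, 7, 64)
print(scaled_dot_product_attention(q, k, v).shape)           # (2, 5, 64)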
@@ -21,7 +21,7 @@
 from paddle.nn import LayerList
 from paddle.nn.initializer import XavierNormal as xavier_uniform_
 from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
 import numpy as np
-from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
+from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
 from paddle.nn.initializer import Constant as constant_
 from paddle.nn.initializer import XavierNormal as xavier_normal_
@@ -29,7 +29,7 @@
 zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)

-class TransformerOptim(nn.Layer):
+class Transformer(nn.Layer):
     """A transformer model. User is able to modify the attributes as needed. The architechture
     is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
     Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
@@ -63,7 +63,7 @@ class TransformerOptim(nn.Layer):
                  out_channels=0,
                  dst_vocab_size=99,
                  scale_embedding=True):
-        super(TransformerOptim, self).__init__()
+        super(Transformer, self).__init__()
         self.embedding = Embeddings(
             d_model=d_model,
             vocab=dst_vocab_size,
@@ -215,8 +215,7 @@ class TransformerOptim(nn.Layer):
         n_curr_active_inst = len(curr_active_inst_idx)
         new_shape = (n_curr_active_inst * n_bm, *d_hs)

-        beamed_tensor = beamed_tensor.reshape(
-            [n_prev_active_inst, -1])
+        beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
         beamed_tensor = beamed_tensor.index_select(
             paddle.to_tensor(curr_active_inst_idx), axis=0)
         beamed_tensor = beamed_tensor.reshape([*new_shape])
@@ -486,7 +485,7 @@ class TransformerEncoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerEncoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)

         self.conv1 = Conv2D(
@@ -555,9 +554,9 @@ class TransformerDecoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerDecoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
-        self.multihead_attn = MultiheadAttentionOptim(
+        self.multihead_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)

         self.conv1 = Conv2D(
...
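The reshape simplification at @@ -215 sits inside the beam-search bookkeeping that, after each decoding step, keeps only the beams of instances that have not finished yet. A NumPy sketch of that reshape, index_select, reshape pattern with made-up sizes (illustrative only):

# Sketch: keep the beams of still-active instances, as in the reshaped block above.
import numpy as np

n_prev_active_inst, n_bm, d_h = 4, 10, 512          # instances, beam size, hidden size
beamed = np.random.randn(n_prev_active_inst * n_bm, d_h)

curr_active_inst_idx = np.array([0, 2, 3])           # instances still decoding
n_curr_active_inst = len(curr_active_inst_idx)

part = beamed.reshape(n_prev_active_inst, -1)        # (4, 10*512): one row per instance
part = part[curr_active_inst_idx]                    # index_select on axis 0
part = part.reshape(n_curr_active_inst * n_bm, d_h)  # back to (3*10, 512)
print(part.shape)                                     # (30, 512)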