提交 55b76dca 编写于 作者: T Topdu

delete blank lines and modify forward_train

上级 a11e2199
......@@ -46,7 +46,7 @@ Architecture:
name: TransformerOptim
d_model: 512
num_encoder_layers: 6
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
Loss:
......
......@@ -27,8 +27,9 @@ def build_backbone(config, model_type):
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_nrtr_mtb import MTB
from .rec_swin import SwinTransformer
support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
support_dict = [
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
support_dict = ["ResNet"]
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn
class MTB(nn.Layer):
def __init__(self, cnn_num, in_channels):
super(MTB, self).__init__()
......@@ -8,17 +23,20 @@ class MTB(nn.Layer):
self.cnn_num = cnn_num
if self.cnn_num == 2:
for i in range(self.cnn_num):
self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D(
in_channels = in_channels if i == 0 else 32*(2**(i-1)),
out_channels = 32*(2**i),
kernel_size = 3,
stride = 2,
padding=1))
self.block.add_sublayer(
'conv_{}'.format(i),
nn.Conv2D(
in_channels=in_channels
if i == 0 else 32 * (2**(i - 1)),
out_channels=32 * (2**i),
kernel_size=3,
stride=2,
padding=1))
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i)))
self.block.add_sublayer('bn_{}'.format(i),
nn.BatchNorm2D(32 * (2**i)))
def forward(self, images):
x = self.block(images)
if self.cnn_num == 2:
# (b, w, h, c)
......
......@@ -27,14 +27,13 @@ def build_head(config):
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_nrtr_optim_head import TransformerOptim
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead']
'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
]
#table head
from .table_att_head import TableAttentionHead
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
......@@ -11,7 +25,7 @@ ones_ = constant_(value=1.)
class MultiheadAttentionOptim(nn.Layer):
r"""Allows the model to jointly attend to information
"""Allows the model to jointly attend to information
from different representation subspaces.
See reference: Attention Is All You Need
......@@ -23,37 +37,43 @@ class MultiheadAttentionOptim(nn.Layer):
embed_dim: total dimension of the model
num_heads: parallel attention layers, or heads
Examples::
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
def __init__(self,
embed_dim,
num_heads,
dropout=0.,
bias=True,
add_bias_kv=False,
add_zero_attn=False):
super(MultiheadAttentionOptim, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
self.scaling = self.head_dim**-0.5
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
self._reset_parameters()
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv1 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv2 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
self.conv3 = paddle.nn.Conv2D(
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
def _reset_parameters(self):
xavier_uniform_(self.out_proj.weight)
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
need_weights=True, static_kv=False, attn_mask=None):
def forward(self,
query,
key,
value,
key_padding_mask=None,
incremental_state=None,
need_weights=True,
static_kv=False,
attn_mask=None):
"""
Inputs of forward function
query: [target length, batch size, embed dim]
......@@ -68,8 +88,6 @@ class MultiheadAttentionOptim(nn.Layer):
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
"""
tgt_len, bsz, embed_dim = query.shape
assert embed_dim == self.embed_dim
assert list(query.shape) == [tgt_len, bsz, embed_dim]
......@@ -80,11 +98,12 @@ class MultiheadAttentionOptim(nn.Layer):
v = self._in_proj_v(value)
q *= self.scaling
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
[1, 0, 2])
src_len = k.shape[1]
......@@ -92,44 +111,48 @@ class MultiheadAttentionOptim(nn.Layer):
assert key_padding_mask.shape[0] == bsz
assert key_padding_mask.shape[1] == src_len
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
assert list(attn_output_weights.
shape) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
y = paddle.where(key==0.,key, y)
y = paddle.where(key == 0., key, y)
attn_output_weights += y
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.reshape(
[bsz * self.num_heads, tgt_len, src_len])
attn_output_weights = F.softmax(
attn_output_weights.astype('float32'), axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output_weights.astype('float32'),
axis=-1,
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
else attn_output_weights.dtype)
attn_output_weights = F.dropout(
attn_output_weights, p=self.dropout, training=self.training)
attn_output = paddle.bmm(attn_output_weights, v)
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
assert list(attn_output.
shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
attn_output = attn_output.transpose([1, 0, 2]).reshape(
[tgt_len, bsz, embed_dim])
attn_output = self.out_proj(attn_output)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
attn_output_weights = attn_output_weights.reshape(
[bsz, self.num_heads, tgt_len, src_len])
attn_output_weights = attn_output_weights.sum(
axis=1) / self.num_heads
else:
attn_output_weights = None
return attn_output, attn_output_weights
def _in_proj_q(self, query):
query = query.transpose([1, 2, 0])
query = paddle.unsqueeze(query, axis=2)
......@@ -139,7 +162,6 @@ class MultiheadAttentionOptim(nn.Layer):
return res
def _in_proj_k(self, key):
key = key.transpose([1, 2, 0])
key = paddle.unsqueeze(key, axis=2)
res = self.conv2(key)
......@@ -148,8 +170,7 @@ class MultiheadAttentionOptim(nn.Layer):
return res
def _in_proj_v(self, value):
value = value.transpose([1,2,0])#(1, 2, 0)
value = value.transpose([1, 2, 0]) #(1, 2, 0)
value = paddle.unsqueeze(value, axis=2)
res = self.conv3(value)
res = paddle.squeeze(res, axis=2)
......
......@@ -189,9 +189,9 @@ def train(config,
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
try:
try:
model_type = config['Architecture']['model_type']
except:
except:
model_type = None
if 'start_epoch' in best_model_dict:
......@@ -216,11 +216,8 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table':
if use_srn or model_type == 'table' or use_nrtr:
preds = model(images, data=batch[1:])
elif use_nrtr:
max_len = batch[2].max()
preds = model(images, batch[1][:,:2+max_len])
else:
preds = model(images)
loss = loss_class(preds, batch)
......@@ -405,9 +402,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册