提交 59cc4efd 编写于 作者: T tink2123

add for SEED

上级 38801c7f
Global:
use_gpu: False
epoch_num: 400
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec/b3_rare_r34_none_gru/
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
character_type: EN_symbol
max_text_length: 25
infer_mode: False
use_space_char: False
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
learning_rate: 0.0005
regularizer:
name: 'L2'
factor: 0.00000
Architecture:
model_type: rec
algorithm: ASTER
Transform:
name: STN_ON
tps_inputsize: [32, 64]
tps_outputsize: [32, 100]
num_control_points: 20
tps_margins: [0.05,0.05]
stn_activation: none
Backbone:
name: ResNet_ASTER
Head:
name: AsterHead # AttentionHead
sDim: 512
attDim: 512
max_len_labels: 100
Loss:
name: AsterLoss
PostProcess:
name: AttnLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/1.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- AttnLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [3, 32, 100]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: True
batch_size_per_card: 2
drop_last: True
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list: ["./train_data/ic15_data/1.txt"]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- AttnLabelEncode: # Class handling label
- RecResizeImg:
image_shape: [3, 32, 100]
- KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 2
num_workers: 8
......@@ -104,6 +104,7 @@ class BaseRecLabelEncode(object):
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
self.unknown = "UNKNOWN"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
......@@ -275,7 +276,9 @@ class AttnLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = [self.beg_str] + dict_character + [self.end_str]
self.unknown = "UNKNOWN"
dict_character = [self.beg_str] + dict_character + [self.end_str
] + [self.unknown]
return dict_character
def __call__(self, data):
......@@ -288,6 +291,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
data['length'] = np.array(len(text))
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
- len(text) - 2)
data['label'] = np.array(text)
return data
......@@ -352,19 +356,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
% beg_or_end
return idx
class TableLabelEncode(object):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight = 1.0,
span_weight=1.0,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
......@@ -383,10 +390,10 @@ class TableLabelEncode(object):
substr = lines[0].decode('utf-8').strip("\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1+character_num):
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n")
list_character.append(character)
for eno in range(1+character_num, 1+character_num+elem_num):
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n")
list_elem.append(elem)
return list_character, list_elem
......@@ -412,18 +419,22 @@ class TableLabelEncode(object):
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
)
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
......@@ -450,9 +461,11 @@ class TableLabelEncode(object):
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num])
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data
def encode(self, text, char_or_elem):
......@@ -509,4 +522,3 @@ class TableLabelEncode(object):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
\ No newline at end of file
......@@ -22,6 +22,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
print("===== simpledataset ========")
super(SimpleDataSet, self).__init__()
self.logger = logger
self.mode = mode.lower()
......
......@@ -41,10 +41,13 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
from .rec_aster_loss import AsterLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'AsterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import fasttext
class AsterLoss(nn.Layer):
def __init__(self,
weight=None,
size_average=True,
ignore_index=-100,
sequence_normalize=False,
sample_normalize=True,
**kwargs):
super(AsterLoss, self).__init__()
self.weight = weight
self.size_average = size_average
self.ignore_index = ignore_index
self.sequence_normalize = sequence_normalize
self.sample_normalize = sample_normalize
self.loss_func = paddle.nn.CosineSimilarity()
def forward(self, predicts, batch):
targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64')
# sem_target = batch[3].astype('float32')
embedding_vectors = predicts['embedding_vectors']
rec_pred = predicts['rec_pred']
# semantic loss
# print(embedding_vectors)
# print(embedding_vectors.shape)
# targets = fasttext[targets]
# sem_loss = 1 - self.loss_func(embedding_vectors, targets)
# rec loss
batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[
1], rec_pred.shape[2]
assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
mask = paddle.zeros([batch_size, num_steps])
for i in range(batch_size):
mask[i, :label_lengths[i]] = 1
mask = paddle.cast(mask, "float32")
max_length = max(label_lengths)
assert max_length == rec_pred.shape[1]
targets = targets[:, :max_length]
mask = mask[:, :max_length]
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
input = nn.functional.log_softmax(rec_pred, axis=1)
targets = paddle.reshape(targets, [-1, 1])
mask = paddle.reshape(mask, [-1, 1])
# print("input:", input)
output = -paddle.gather(input, index=targets, axis=1) * mask
output = paddle.sum(output)
if self.sequence_normalize:
output = output / paddle.sum(mask)
if self.sample_normalize:
output = output / batch_size
loss = output
return {'loss': loss} # , 'sem_loss':sem_loss}
......@@ -35,5 +35,7 @@ class AttentionLoss(nn.Layer):
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1])
print("input:", paddle.argmax(inputs, axis=1))
print("targets:", targets)
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
......@@ -26,8 +26,10 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_resnet_aster import ResNet_ASTER
support_dict = [
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
"ResNet_ASTER"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
......
此差异已折叠。
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import sys
import math
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2D(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2D(
in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
# [n_position]
positions = paddle.arange(0, n_position)
# [feat_dim]
dim_range = paddle.arange(0, feat_dim)
dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
# [n_position, feat_dim]
angles = paddle.unsqueeze(
positions, axis=1) / paddle.unsqueeze(
dim_range, axis=0)
angles = paddle.cast(angles, "float32")
angles[:, 0::2] = paddle.sin(angles[:, 0::2])
angles[:, 1::2] = paddle.cos(angles[:, 1::2])
return angles
class AsterBlock(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(AsterBlock, self).__init__()
self.conv1 = conv1x1(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2D(planes)
self.relu = nn.ReLU()
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2D(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet_ASTER(nn.Layer):
"""For aster or crnn"""
def __init__(self, with_lstm=True, n_group=1, in_channels=3):
super(ResNet_ASTER, self).__init__()
self.with_lstm = with_lstm
self.n_group = n_group
self.layer0 = nn.Sequential(
nn.Conv2D(
in_channels,
32,
kernel_size=(3, 3),
stride=1,
padding=1,
bias_attr=False),
nn.BatchNorm2D(32),
nn.ReLU())
self.inplanes = 32
self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
if with_lstm:
self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
self.out_channels = 2 * 256
else:
self.out_channels = 512
def _make_layer(self, planes, blocks, stride):
downsample = None
if stride != [1, 1] or self.inplanes != planes:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
layers = []
layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
self.inplanes = planes
for _ in range(1, blocks):
layers.append(AsterBlock(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x0 = self.layer0(x)
x1 = self.layer1(x0)
x2 = self.layer2(x1)
x3 = self.layer3(x2)
x4 = self.layer4(x3)
x5 = self.layer5(x4)
cnn_feat = x5.squeeze(2) # [N, c, w]
cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
if self.with_lstm:
rnn_feat, _ = self.rnn(cnn_feat)
return rnn_feat
else:
return cnn_feat
if __name__ == "__main__":
x = paddle.randn([3, 3, 32, 100])
net = ResNet_ASTER()
encoder_feat = net(x)
print(encoder_feat.shape)
......@@ -26,12 +26,15 @@ def build_head(config):
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
from .rec_aster_head import AttentionRecognitionHead, AsterHead
# cls head
from .cls_head import ClsHead
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
'SRNHead', 'PGHead', 'TableAttentionHead']
'SRNHead', 'PGHead', 'TableAttentionHead', 'AttentionRecognitionHead',
'AsterHead'
]
#table head
from .table_att_head import TableAttentionHead
......@@ -39,5 +42,6 @@ def build_head(config):
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
support_dict))
print(config)
module_class = eval(module_name)(**config)
return module_class
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import paddle
from paddle import nn
from paddle.nn import functional as F
class AsterHead(nn.Layer):
def __init__(self,
in_channels,
out_channels,
sDim,
attDim,
max_len_labels,
time_step=25,
beam_width=5,
**kwargs):
super(AsterHead, self).__init__()
self.num_classes = out_channels
self.in_planes = in_channels
self.sDim = sDim
self.attDim = attDim
self.max_len_labels = max_len_labels
self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
attDim, max_len_labels)
self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width
def forward(self, x, targets=None, embed=None):
return_dict = {}
embedding_vectors = self.embeder(x)
rec_targets, rec_lengths = targets
if self.training:
rec_pred = self.decoder([x, rec_targets, rec_lengths],
embedding_vectors)
return_dict['rec_pred'] = rec_pred
return_dict['embedding_vectors'] = embedding_vectors
else:
rec_pred, rec_pred_scores = self.decoder.beam_search(
x, self.beam_width, self.eos, embedding_vectors)
return_dict['rec_pred'] = rec_pred
return_dict['rec_pred_scores'] = rec_pred_scores
return_dict['embedding_vectors'] = embedding_vectors
return return_dict
class Embedding(nn.Layer):
def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
super(Embedding, self).__init__()
self.in_timestep = in_timestep
self.in_planes = in_planes
self.embed_dim = embed_dim
self.mid_dim = mid_dim
self.eEmbed = nn.Linear(
in_timestep * in_planes,
self.embed_dim) # Embed encoder output to a word-embedding like
def forward(self, x):
x = paddle.reshape(x, [paddle.shape(x)[0], -1])
x = self.eEmbed(x)
return x
class AttentionRecognitionHead(nn.Layer):
"""
input: [b x 16 x 64 x in_planes]
output: probability sequence: [b x T x num_classes]
"""
def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
super(AttentionRecognitionHead, self).__init__()
self.num_classes = out_channels # this is the output classes. So it includes the <EOS>.
self.in_planes = in_channels
self.sDim = sDim
self.attDim = attDim
self.max_len_labels = max_len_labels
self.decoder = DecoderUnit(
sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
def forward(self, x, embed):
x, targets, lengths = x
batch_size = paddle.shape(x)[0]
# Decoder
state = self.decoder.get_initial_state(embed)
outputs = []
for i in range(max(lengths)):
if i == 0:
y_prev = paddle.full(
shape=[batch_size], fill_value=self.num_classes)
else:
y_prev = targets[:, i - 1]
output, state = self.decoder(x, state, y_prev)
outputs.append(output)
outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
return outputs
# inference stage.
def sample(self, x):
x, _, _ = x
batch_size = x.size(0)
# Decoder
state = paddle.zeros([1, batch_size, self.sDim])
predicted_ids, predicted_scores = [], []
for i in range(self.max_len_labels):
if i == 0:
y_prev = paddle.full(
shape=[batch_size], fill_value=self.num_classes)
else:
y_prev = predicted
output, state = self.decoder(x, state, y_prev)
output = F.softmax(output, axis=1)
score, predicted = output.max(1)
predicted_ids.append(predicted.unsqueeze(1))
predicted_scores.append(score.unsqueeze(1))
predicted_ids = paddle.concat([predicted_ids, 1])
predicted_scores = paddle.concat([predicted_scores, 1])
# return predicted_ids.squeeze(), predicted_scores.squeeze()
return predicted_ids, predicted_scores
class AttentionUnit(nn.Layer):
def __init__(self, sDim, xDim, attDim):
super(AttentionUnit, self).__init__()
self.sDim = sDim
self.xDim = xDim
self.attDim = attDim
self.sEmbed = nn.Linear(
sDim,
attDim,
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
self.xEmbed = nn.Linear(
xDim,
attDim,
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
self.wEmbed = nn.Linear(
attDim,
1,
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
def forward(self, x, sPrev):
batch_size, T, _ = x.shape # [b x T x xDim]
x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
xProj = self.xEmbed(x) # [(b x T) x attDim]
xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
sPrev = sPrev.squeeze(0)
sProj = self.sEmbed(sPrev) # [b x attDim]
sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
sProj = paddle.expand(sProj,
[batch_size, T, self.attDim]) # [b x T x attDim]
sumTanh = paddle.tanh(sProj + xProj)
sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
vProj = paddle.reshape(vProj, [batch_size, T])
alpha = F.softmax(
vProj, axis=1) # attention weights for each sample in the minibatch
return alpha
class DecoderUnit(nn.Layer):
def __init__(self, sDim, xDim, yDim, attDim):
super(DecoderUnit, self).__init__()
self.sDim = sDim
self.xDim = xDim
self.yDim = yDim
self.attDim = attDim
self.emdDim = attDim
self.attention_unit = AttentionUnit(sDim, xDim, attDim)
self.tgt_embedding = nn.Embedding(
yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
std=0.01)) # the last is used for <BOS>
self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
self.fc = nn.Linear(
sDim,
yDim,
weight_attr=nn.initializer.Normal(std=0.01),
bias_attr=nn.initializer.Constant(value=0))
self.embed_fc = nn.Linear(300, self.sDim)
def get_initial_state(self, embed, tile_times=1):
assert embed.shape[1] == 300
state = self.embed_fc(embed) # N * sDim
if tile_times != 1:
state = state.unsqueeze(1)
trans_state = paddle.transpose(state, perm=[1, 0, 2])
state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
trans_state = paddle.transpose(state, perm=[1, 0, 2])
state = paddle.reshape(trans_state, shape=[-1, self.sDim])
state = state.unsqueeze(0) # 1 * N * sDim
return state
def forward(self, x, sPrev, yPrev):
# x: feature sequence from the image decoder.
batch_size, T, _ = x.shape
alpha = self.attention_unit(x, sPrev)
context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
yPrev = paddle.cast(yPrev, dtype="int64")
yProj = self.tgt_embedding(yPrev)
concat_context = paddle.concat([yProj, context], 1)
concat_context = paddle.squeeze(concat_context, 1)
sPrev = paddle.squeeze(sPrev, 0)
output, state = self.gru(concat_context, sPrev)
output = paddle.squeeze(output, axis=1)
output = self.fc(output)
return output, state
if __name__ == "__main__":
model = AttentionRecognitionHead(
num_classes=20,
in_channels=30,
sDim=512,
attDim=512,
max_len_labels=25,
out_channels=38)
data = paddle.ones([16, 64, 3])
targets = paddle.ones([16, 25])
length = paddle.to_tensor(20)
x = [data, targets, length]
output = model(x)
print(output.shape)
......@@ -44,10 +44,13 @@ class AttentionHead(nn.Layer):
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
targets = targets[0]
print(targets)
if targets is not None:
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
# print("char_onehots:", char_onehots)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
......@@ -104,6 +107,8 @@ class AttentionGRUCell(nn.Layer):
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
# print("concat_context:", concat_context.shape)
# print("prev_hidden:", prev_hidden.shape)
cur_hidden = self.rnn(concat_context, prev_hidden)
......
......@@ -17,8 +17,9 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
from .tps import STN_ON
support_dict = ['TPS']
support_dict = ['TPS', 'STN_ON']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
def conv3x3_block(in_channels, out_channels, stride=1):
n = 3 * 3 * out_channels
w = math.sqrt(2. / n)
conv_layer = nn.Conv2D(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
weight_attr=nn.initializer.Normal(
mean=0.0, std=w),
bias_attr=nn.initializer.Constant(0))
block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
return block
class STN(nn.Layer):
def __init__(self, in_channels, num_ctrlpoints, activation='none'):
super(STN, self).__init__()
self.in_channels = in_channels
self.num_ctrlpoints = num_ctrlpoints
self.activation = activation
self.stn_convnet = nn.Sequential(
conv3x3_block(in_channels, 32), #32x64
nn.MaxPool2D(
kernel_size=2, stride=2),
conv3x3_block(32, 64), #16x32
nn.MaxPool2D(
kernel_size=2, stride=2),
conv3x3_block(64, 128), # 8*16
nn.MaxPool2D(
kernel_size=2, stride=2),
conv3x3_block(128, 256), # 4*8
nn.MaxPool2D(
kernel_size=2, stride=2),
conv3x3_block(256, 256), # 2*4,
nn.MaxPool2D(
kernel_size=2, stride=2),
conv3x3_block(256, 256)) # 1*2
self.stn_fc1 = nn.Sequential(
nn.Linear(
2 * 256,
512,
weight_attr=nn.initializer.Normal(0, 0.001),
bias_attr=nn.initializer.Constant(0)),
nn.BatchNorm1D(512),
nn.ReLU())
fc2_bias = self.init_stn()
self.stn_fc2 = nn.Linear(
512,
num_ctrlpoints * 2,
weight_attr=nn.initializer.Constant(0.0),
bias_attr=nn.initializer.Assign(fc2_bias))
def init_stn(self):
margin = 0.01
sampling_num_per_side = int(self.num_ctrlpoints / 2)
ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
ctrl_points = np.concatenate(
[ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
if self.activation == 'none':
pass
elif self.activation == 'sigmoid':
ctrl_points = -np.log(1. / ctrl_points - 1.)
ctrl_points = paddle.to_tensor(ctrl_points)
fc2_bias = paddle.reshape(
ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
return fc2_bias
def forward(self, x):
x = self.stn_convnet(x)
batch_size, _, h, w = x.shape
x = paddle.reshape(x, shape=(batch_size, -1))
img_feat = self.stn_fc1(x)
x = self.stn_fc2(0.1 * img_feat)
if self.activation == 'sigmoid':
x = F.sigmoid(x)
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
return img_feat, x
if __name__ == "__main__":
in_planes = 3
num_ctrlpoints = 20
np.random.seed(100)
activation = 'none' # 'sigmoid'
stn_head = STN(in_planes, num_ctrlpoints, activation)
data = np.random.randn(10, 3, 32, 64).astype("float32")
print("data:", np.sum(data))
input = paddle.to_tensor(data)
#input = paddle.randn([10, 3, 32, 64])
control_points = stn_head(input)
......@@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .tps_spatial_transformer import TPSSpatialTransformer
from .stn import STN
class ConvBNLayer(nn.Layer):
def __init__(self,
......@@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """
F = self.F
hat_eye = paddle.eye(F, dtype='float64') # F x F
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = paddle.norm(
C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
......@@ -301,3 +305,26 @@ class TPS(nn.Layer):
[-1, image.shape[2], image.shape[3], 2])
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
return batch_I_r
class STN_ON(nn.Layer):
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
num_control_points, tps_margins, stn_activation):
super(STN_ON, self).__init__()
self.tps = TPSSpatialTransformer(
output_image_size=tuple(tps_outputsize),
num_control_points=num_control_points,
margins=tuple(tps_margins))
self.stn_head = STN(in_channels=in_channels,
num_ctrlpoints=num_control_points,
activation=stn_activation)
self.tps_inputsize = tps_inputsize
self.out_channels = in_channels
def forward(self, image):
stn_input = paddle.nn.functional.interpolate(
image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points)
# print(x.shape)
return x
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
import itertools
def grid_sample(input, grid, canvas=None):
input.stop_gradient = False
output = F.grid_sample(input, grid)
if canvas is None:
return output
else:
input_mask = paddle.ones(shape=input.shape)
output_mask = F.grid_sample(input_mask, grid)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
N = input_points.shape[0]
M = control_points.shape[0]
pairwise_diff = paddle.reshape(
input_points, shape=[N, 1, 2]) - paddle.reshape(
control_points, shape=[1, M, 2])
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
1]
repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix
repr_matrix[mask] = 0
return repr_matrix
# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
margin_x, margin_y = margins
num_ctrl_pts_per_side = num_control_points // 2
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
# ctrl_pts_top = ctrl_pts_top[1:-1,:]
# ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
output_ctrl_pts_arr = np.concatenate(
[ctrl_pts_top, ctrl_pts_bottom], axis=0)
output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
return output_ctrl_pts
class TPSSpatialTransformer(nn.Layer):
def __init__(self,
output_image_size=None,
num_control_points=None,
margins=None):
super(TPSSpatialTransformer, self).__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
target_control_points = build_output_control_points(num_control_points,
margins)
N = num_control_points
# N = N - 4
# create padded kernel matrix
forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
target_control_partial_repr = compute_partial_repr(
target_control_points, target_control_points)
target_control_partial_repr = paddle.cast(target_control_partial_repr,
forward_kernel.dtype)
forward_kernel[:N, :N] = target_control_partial_repr
forward_kernel[:N, -3] = 1
forward_kernel[-3, :N] = 1
target_control_points = paddle.cast(target_control_points,
forward_kernel.dtype)
forward_kernel[:N, -2:] = target_control_points
forward_kernel[-2:, :N] = paddle.transpose(
target_control_points, perm=[1, 0])
# compute inverse matrix
inverse_kernel = paddle.inverse(forward_kernel)
# create target cordinate matrix
HW = self.target_height * self.target_width
target_coordinate = list(
itertools.product(
range(self.target_height), range(self.target_width)))
target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
Y, X = paddle.split(
target_coordinate, target_coordinate.shape[1], axis=1)
#Y, X = target_coordinate.split(1, dim = 1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
target_coordinate = paddle.concat(
[X, Y], axis=1) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(
target_coordinate, target_control_points)
target_coordinate_repr = paddle.concat(
[
target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
target_coordinate
],
axis=1)
# register precomputed matrices
self.inverse_kernel = inverse_kernel
self.padding_matrix = paddle.zeros(shape=[3, 2])
self.target_coordinate_repr = target_coordinate_repr
self.target_control_points = target_control_points
def forward(self, input, source_control_points):
assert source_control_points.ndimension() == 3
assert source_control_points.shape[1] == self.num_control_points
assert source_control_points.shape[2] == 2
batch_size = source_control_points.shape[0]
self.padding_matrix = paddle.expand(
self.padding_matrix, shape=[batch_size, 3, 2])
Y = paddle.concat([source_control_points, self.padding_matrix], 1)
mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
source_coordinate = paddle.matmul(self.target_coordinate_repr,
mapping_matrix)
grid = paddle.reshape(
source_coordinate,
shape=[-1, self.target_height, self.target_width, 2])
grid = paddle.clip(grid, 0,
1) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
# grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None)
return output_maps, source_coordinate
if __name__ == "__main__":
from stn import STN
in_planes = 3
num_ctrlpoints = 20
np.random.seed(100)
activation = 'none' # 'sigmoid'
stn_head = STN(in_planes, num_ctrlpoints, activation)
data = np.random.randn(10, 3, 32, 64).astype("float32")
input = paddle.to_tensor(data)
#input = paddle.randn([10, 3, 32, 64])
control_points = stn_head(input)
#print("control points:", control_points)
#input = paddle.randn(shape=[10,3,32,100])
tps = TPSSpatialTransformer(
output_image_size=[32, 320],
num_control_points=20,
margins=[0.05, 0.05])
out = tps(input, control_points[1])
print("out 0 :", out[0].shape)
print("out 1:", out[1].shape)
from __future__ import absolute_import
import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
def grid_sample(input, grid, canvas=None):
output = F.grid_sample(input, grid)
if canvas is None:
return output
else:
input_mask = input.data.new(input.size()).fill_(1)
output_mask = F.grid_sample(input_mask, grid)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
1]
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix
repr_matrix.masked_fill_(mask, 0)
return repr_matrix
# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
margin_x, margin_y = margins
num_ctrl_pts_per_side = num_control_points // 2
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
# ctrl_pts_top = ctrl_pts_top[1:-1,:]
# ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
output_ctrl_pts_arr = np.concatenate(
[ctrl_pts_top, ctrl_pts_bottom], axis=0)
output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
return output_ctrl_pts
# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):
def __init__(self,
output_image_size=None,
num_control_points=None,
margins=None):
super(TPSSpatialTransformer, self).__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
target_control_points = build_output_control_points(num_control_points,
margins)
N = num_control_points
# N = N - 4
# create padded kernel matrix
forward_kernel = torch.zeros(N + 3, N + 3)
target_control_partial_repr = compute_partial_repr(
target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel)
# create target cordinate matrix
HW = self.target_height * self.target_width
target_coordinate = list(
itertools.product(
range(self.target_height), range(self.target_width)))
target_coordinate = torch.Tensor(target_coordinate) # HW x 2
Y, X = target_coordinate.split(1, dim=1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
target_coordinate = torch.cat([X, Y],
dim=1) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(
target_coordinate, target_control_points)
target_coordinate_repr = torch.cat([
target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
],
dim=1)
# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', target_coordinate_repr)
self.register_buffer('target_control_points', target_control_points)
def forward(self, input, source_control_points):
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_control_points
assert source_control_points.size(2) == 2
batch_size = source_control_points.size(0)
Y = torch.cat([
source_control_points, self.padding_matrix.expand(batch_size, 3, 2)
], 1)
mapping_matrix = torch.matmul(self.inverse_kernel, Y)
source_coordinate = torch.matmul(self.target_coordinate_repr,
mapping_matrix)
grid = source_coordinate.view(-1, self.target_height, self.target_width,
2)
grid = torch.clamp(grid, 0,
1) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None)
return output_maps, source_coordinate
if __name__ == "__main__":
from stn_torch import STNHead
in_planes = 3
num_ctrlpoints = 20
torch.manual_seed(10)
activation = 'none' # 'sigmoid'
stn_head = STNHead(in_planes, num_ctrlpoints, activation)
np.random.seed(100)
data = np.random.randn(10, 3, 32, 64).astype("float32")
input = torch.tensor(data)
control_points = stn_head(input)
tps = TPSSpatialTransformer(
output_image_size=[32, 320],
num_control_points=20,
margins=[0.05, 0.05])
out = tps(input, control_points[1])
print("out 0 :", out[0].shape)
print("out 1:", out[1].shape)
......@@ -170,8 +170,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
self.unkonwn = "UNKNOWN"
dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str]
dict_character = [self.beg_str] + dict_character + [self.end_str
] + [self.unkonwn]
return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
......@@ -212,6 +214,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds = preds["rec_pred"]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
......@@ -324,10 +327,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object):
""" """
def __init__(self,
character_dict_path,
**kwargs):
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
......@@ -366,14 +368,14 @@ class TableLabelDecode(object):
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs,paddle.Tensor):
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds,paddle.Tensor):
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
structure_probs, 'elem')
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
......@@ -388,8 +390,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
......
......@@ -105,12 +105,15 @@ def load_dygraph_params(config, model, logger, optimizer):
params = paddle.load(pm)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
# for k1, k2 in zip(state_dict.keys(), params.keys()):
for k1 in state_dict.keys():
if k1 not in params:
continue
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
)
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
......
......@@ -187,6 +187,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type']
algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
......@@ -210,10 +211,14 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if use_srn or model_type == 'table':
# if use_srn or model_type == 'table' or algorithm == "ASTER":
# preds = model(images, data=batch[1:])
# else:
# preds = model(images)
preds = model(images, data=batch[1:])
else:
preds = model(images)
state_dict = model.state_dict()
# for key in state_dict:
# print(key)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
......@@ -395,7 +400,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'TableAttn'
'CLS', 'PGNet', 'Distillation', 'TableAttn', 'ASTER'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
......@@ -72,6 +72,8 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
character = getattr(post_process_class, 'character')
print("getattr character:", character)
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册