提交 c9e1077d 编写于 作者: T tink2123

polish code

上级 59cc4efd
Global: Global:
use_gpu: False use_gpu: True
epoch_num: 400 epoch_num: 400
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec/b3_rare_r34_none_gru/ save_model_dir: ./output/rec/seed
save_epoch_step: 3 save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration # evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
...@@ -12,28 +12,32 @@ Global: ...@@ -12,28 +12,32 @@ Global:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg infer_img: doc/imgs_words_en/word_10.png
# for data or label process # for data or label process
character_dict_path: character_dict_path:
character_type: EN_symbol character_type: EN_symbol
max_text_length: 25 max_text_length: 100
infer_mode: False infer_mode: False
use_space_char: False use_space_char: False
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt eval_filter: True
save_res_path: ./output/rec/predicts_seed.txt
Optimizer: Optimizer:
name: Adam name: Adadelta
beta1: 0.9 weight_deacy: 0.0
beta2: 0.999 momentum: 0.9
lr: lr:
learning_rate: 0.0005 name: Piecewise
decay_epochs: [4,5,8]
values: [1.0, 0.1, 0.01]
regularizer: regularizer:
name: 'L2' name: 'L2'
factor: 0.00000 factor: 2.0e-05
Architecture: Architecture:
model_type: rec model_type: seed
algorithm: ASTER algorithm: ASTER
Transform: Transform:
name: STN_ON name: STN_ON
...@@ -54,48 +58,49 @@ Loss: ...@@ -54,48 +58,49 @@ Loss:
name: AsterLoss name: AsterLoss
PostProcess: PostProcess:
name: AttnLabelDecode name: SEEDLabelDecode
Metric: Metric:
name: RecMetric name: RecMetric
main_indicator: acc main_indicator: acc
is_filter: True
Train: Train:
dataset: dataset:
name: SimpleDataSet name: LMDBDataSet
data_dir: ./train_data/ic15_data/ data_dir: ./train_data/data_lmdb_release/training/
label_file_list: ["./train_data/ic15_data/1.txt"]
transforms: transforms:
- Fasttext:
path: "./cc.en.300.bin"
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- AttnLabelEncode: # Class handling label - SEEDLabelEncode: # Class handling label
- RecResizeImg: - SEEDResize:
image_shape: [3, 32, 100] image_shape: [3, 64, 256]
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
loader: loader:
shuffle: True shuffle: True
batch_size_per_card: 2 batch_size_per_card: 256
drop_last: True drop_last: True
num_workers: 8 num_workers: 6
Eval: Eval:
dataset: dataset:
name: SimpleDataSet name: LMDBDataSet
data_dir: ./train_data/ic15_data/ data_dir: ./train_data/data_lmdb_release/evaluation/
label_file_list: ["./train_data/ic15_data/1.txt"]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- AttnLabelEncode: # Class handling label - SEEDLabelEncode: # Class handling label
- RecResizeImg: - SEEDResize:
image_shape: [3, 32, 100] image_shape: [3, 64, 256]
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: True
batch_size_per_card: 2 batch_size_per_card: 256
num_workers: 8 num_workers: 4
...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap ...@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SEEDResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
from .operators import * from .operators import *
......
...@@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
self.unknown = "UNKNOWN" dict_character = [self.beg_str] + dict_character + [self.end_str]
dict_character = [self.beg_str] + dict_character + [self.end_str
] + [self.unknown]
return dict_character return dict_character
def __call__(self, data): def __call__(self, data):
...@@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
data['length'] = np.array(len(text)) data['length'] = np.array(len(text))
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
- len(text) - 2) - len(text) - 2)
data['label'] = np.array(text) data['label'] = np.array(text)
return data return data
...@@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode): ...@@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
return idx return idx
class SEEDLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SEEDLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character + [self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
if text is None:
return None
if len(text) >= self.max_text_len:
return None
data['length'] = np.array(len(text)) + 1 # conclue eos
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
)
data['label'] = np.array(text)
return data
class SRNLabelEncode(BaseRecLabelEncode): class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -23,6 +23,7 @@ import sys ...@@ -23,6 +23,7 @@ import sys
import six import six
import cv2 import cv2
import numpy as np import numpy as np
import fasttext
class DecodeImage(object): class DecodeImage(object):
...@@ -101,6 +102,17 @@ class ToCHWImage(object): ...@@ -101,6 +102,17 @@ class ToCHWImage(object):
return data return data
class Fasttext(object):
def __init__(self, path="None", **kwargs):
self.fast_model = fasttext.load_model(path)
def __call__(self, data):
label = data['label']
fast_label = self.fast_model[label]
data['fast_label'] = fast_label
return data
class KeepKeys(object): class KeepKeys(object):
def __init__(self, keep_keys, **kwargs): def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys self.keep_keys = keep_keys
...@@ -183,7 +195,7 @@ class DetResizeForTest(object): ...@@ -183,7 +195,7 @@ class DetResizeForTest(object):
else: else:
ratio = 1. ratio = 1.
elif self.limit_type == 'resize_long': elif self.limit_type == 'resize_long':
ratio = float(limit_side_len) / max(h,w) ratio = float(limit_side_len) / max(h, w)
else: else:
raise Exception('not support limit type, image ') raise Exception('not support limit type, image ')
resize_h = int(h * ratio) resize_h = int(h * ratio)
......
...@@ -63,6 +63,18 @@ class RecResizeImg(object): ...@@ -63,6 +63,18 @@ class RecResizeImg(object):
return data return data
class SEEDResize(object):
def __init__(self, image_shape, infer_mode=False, **kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
def __call__(self, data):
img = data['image']
norm_img = resize_no_padding_img(img, self.image_shape)
data['image'] = norm_img
return data
class SRNRecResizeImg(object): class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs): def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape self.image_shape = image_shape
...@@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape): ...@@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape):
return padding_im return padding_im
def resize_no_padding_img(img, image_shape):
imgC, imgH, imgW = image_shape
resized_image = cv2.resize(
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_image = resized_image.astype('float32')
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
return resized_image
def resize_norm_img_chinese(img, image_shape): def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape # todo: change to 0 and modified image shape
......
...@@ -22,7 +22,6 @@ from .imaug import transform, create_operators ...@@ -22,7 +22,6 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset): class SimpleDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None): def __init__(self, config, mode, logger, seed=None):
print("===== simpledataset ========")
super(SimpleDataSet, self).__init__() super(SimpleDataSet, self).__init__()
self.logger = logger self.logger = logger
self.mode = mode.lower() self.mode = mode.lower()
......
...@@ -18,7 +18,26 @@ from __future__ import print_function ...@@ -18,7 +18,26 @@ from __future__ import print_function
import paddle import paddle
from paddle import nn from paddle import nn
import fasttext
class CosineEmbeddingLoss(nn.Layer):
def __init__(self, margin=0.):
super(CosineEmbeddingLoss, self).__init__()
self.margin = margin
def forward(self, x1, x2, target):
similarity = paddle.fluid.layers.reduce_sum(
x1 * x2, dim=-1) / (paddle.norm(
x1, axis=-1) * paddle.norm(
x2, axis=-1))
one_list = paddle.full_like(target, fill_value=1)
out = paddle.fluid.layers.reduce_mean(
paddle.where(
paddle.equal(target, one_list), 1. - similarity,
paddle.maximum(
paddle.zeros_like(similarity), similarity - self.margin)))
return out
class AsterLoss(nn.Layer): class AsterLoss(nn.Layer):
...@@ -35,28 +54,28 @@ class AsterLoss(nn.Layer): ...@@ -35,28 +54,28 @@ class AsterLoss(nn.Layer):
self.ignore_index = ignore_index self.ignore_index = ignore_index
self.sequence_normalize = sequence_normalize self.sequence_normalize = sequence_normalize
self.sample_normalize = sample_normalize self.sample_normalize = sample_normalize
self.loss_func = paddle.nn.CosineSimilarity() self.loss_sem = CosineEmbeddingLoss()
self.is_cosin_loss = True
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
def forward(self, predicts, batch): def forward(self, predicts, batch):
targets = batch[1].astype("int64") targets = batch[1].astype("int64")
label_lengths = batch[2].astype('int64') label_lengths = batch[2].astype('int64')
# sem_target = batch[3].astype('float32') sem_target = batch[3].astype('float32')
embedding_vectors = predicts['embedding_vectors'] embedding_vectors = predicts['embedding_vectors']
rec_pred = predicts['rec_pred'] rec_pred = predicts['rec_pred']
# semantic loss if not self.is_cosin_loss:
# print(embedding_vectors) sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
# print(embedding_vectors.shape) else:
# targets = fasttext[targets] label_target = paddle.ones([embedding_vectors.shape[0]])
# sem_loss = 1 - self.loss_func(embedding_vectors, targets) sem_loss = paddle.sum(
self.loss_sem(embedding_vectors, sem_target, label_target))
# rec loss # rec loss
batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[ batch_size, def_max_length = targets.shape[0], targets.shape[1]
1], rec_pred.shape[2]
assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
mask = paddle.zeros([batch_size, num_steps]) mask = paddle.zeros([batch_size, def_max_length])
for i in range(batch_size): for i in range(batch_size):
mask[i, :label_lengths[i]] = 1 mask[i, :label_lengths[i]] = 1
mask = paddle.cast(mask, "float32") mask = paddle.cast(mask, "float32")
...@@ -64,16 +83,16 @@ class AsterLoss(nn.Layer): ...@@ -64,16 +83,16 @@ class AsterLoss(nn.Layer):
assert max_length == rec_pred.shape[1] assert max_length == rec_pred.shape[1]
targets = targets[:, :max_length] targets = targets[:, :max_length]
mask = mask[:, :max_length] mask = mask[:, :max_length]
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]]) rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
input = nn.functional.log_softmax(rec_pred, axis=1) input = nn.functional.log_softmax(rec_pred, axis=1)
targets = paddle.reshape(targets, [-1, 1]) targets = paddle.reshape(targets, [-1, 1])
mask = paddle.reshape(mask, [-1, 1]) mask = paddle.reshape(mask, [-1, 1])
# print("input:", input) output = -paddle.index_sample(input, index=targets) * mask
output = -paddle.gather(input, index=targets, axis=1) * mask
output = paddle.sum(output) output = paddle.sum(output)
if self.sequence_normalize: if self.sequence_normalize:
output = output / paddle.sum(mask) output = output / paddle.sum(mask)
if self.sample_normalize: if self.sample_normalize:
output = output / batch_size output = output / batch_size
loss = output
return {'loss': loss} # , 'sem_loss':sem_loss} loss = output + sem_loss * 0.1
return {'loss': loss}
...@@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer): ...@@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer):
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]]) inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
targets = paddle.reshape(targets, [-1]) targets = paddle.reshape(targets, [-1])
print("input:", paddle.argmax(inputs, axis=1))
print("targets:", targets)
return {'loss': paddle.sum(self.loss_func(inputs, targets))} return {'loss': paddle.sum(self.loss_func(inputs, targets))}
...@@ -13,13 +13,20 @@ ...@@ -13,13 +13,20 @@
# limitations under the License. # limitations under the License.
import Levenshtein import Levenshtein
import string
class RecMetric(object): class RecMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.is_filter = is_filter
self.reset() self.reset()
def _normalize_text(self, text):
text = ''.join(
filter(lambda x: x in (string.digits + string.ascii_letters), text))
return text.lower()
def __call__(self, pred_label, *args, **kwargs): def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label preds, labels = pred_label
correct_num = 0 correct_num = 0
...@@ -28,6 +35,9 @@ class RecMetric(object): ...@@ -28,6 +35,9 @@ class RecMetric(object):
for (pred, pred_conf), (target, _) in zip(preds, labels): for (pred, pred_conf), (target, _) in zip(preds, labels):
pred = pred.replace(" ", "") pred = pred.replace(" ", "")
target = target.replace(" ", "") target = target.replace(" ", "")
if self.is_filter:
pred = self._normalize_text(pred)
target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max( norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1) len(pred), len(target), 1)
if pred == target: if pred == target:
......
...@@ -26,10 +26,8 @@ def build_backbone(config, model_type): ...@@ -26,10 +26,8 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance from .rec_mv1_enhance import MobileNetV1Enhance
from .rec_resnet_aster import ResNet_ASTER
support_dict = [ support_dict = [
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
"ResNet_ASTER"
] ]
elif model_type == "e2e": elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet from .e2e_resnet_vd_pg import ResNet
...@@ -38,6 +36,9 @@ def build_backbone(config, model_type): ...@@ -38,6 +36,9 @@ def build_backbone(config, model_type):
from .table_resnet_vd import ResNet from .table_resnet_vd import ResNet
from .table_mobilenet_v3 import MobileNetV3 from .table_mobilenet_v3 import MobileNetV3
support_dict = ["ResNet", "MobileNetV3"] support_dict = ["ResNet", "MobileNetV3"]
elif model_type == "seed":
from .rec_resnet_aster import ResNet_ASTER
support_dict = ["ResNet_ASTER"]
else: else:
raise NotImplementedError raise NotImplementedError
......
此差异已折叠。
...@@ -42,6 +42,5 @@ def build_head(config): ...@@ -42,6 +42,5 @@ def build_head(config):
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format( assert module_name in support_dict, Exception('head only support {}'.format(
support_dict)) support_dict))
print(config)
module_class = eval(module_name)(**config) module_class = eval(module_name)(**config)
return module_class return module_class
...@@ -43,13 +43,14 @@ class AsterHead(nn.Layer): ...@@ -43,13 +43,14 @@ class AsterHead(nn.Layer):
self.time_step = time_step self.time_step = time_step
self.embeder = Embedding(self.time_step, in_channels) self.embeder = Embedding(self.time_step, in_channels)
self.beam_width = beam_width self.beam_width = beam_width
self.eos = self.num_classes - 1
def forward(self, x, targets=None, embed=None): def forward(self, x, targets=None, embed=None):
return_dict = {} return_dict = {}
embedding_vectors = self.embeder(x) embedding_vectors = self.embeder(x)
rec_targets, rec_lengths = targets
if self.training: if self.training:
rec_targets, rec_lengths, _ = targets
rec_pred = self.decoder([x, rec_targets, rec_lengths], rec_pred = self.decoder([x, rec_targets, rec_lengths],
embedding_vectors) embedding_vectors)
return_dict['rec_pred'] = rec_pred return_dict['rec_pred'] = rec_pred
...@@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer): ...@@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer):
# Decoder # Decoder
state = self.decoder.get_initial_state(embed) state = self.decoder.get_initial_state(embed)
outputs = [] outputs = []
for i in range(max(lengths)): for i in range(max(lengths)):
if i == 0: if i == 0:
y_prev = paddle.full( y_prev = paddle.full(
shape=[batch_size], fill_value=self.num_classes) shape=[batch_size], fill_value=self.num_classes)
else: else:
y_prev = targets[:, i - 1] y_prev = targets[:, i - 1]
output, state = self.decoder(x, state, y_prev) output, state = self.decoder(x, state, y_prev)
outputs.append(output) outputs.append(output)
outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1) outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
...@@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer): ...@@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer):
# return predicted_ids.squeeze(), predicted_scores.squeeze() # return predicted_ids.squeeze(), predicted_scores.squeeze()
return predicted_ids, predicted_scores return predicted_ids, predicted_scores
def beam_search(self, x, beam_width, eos, embed):
def _inflate(tensor, times, dim):
repeat_dims = [1] * tensor.dim()
repeat_dims[dim] = times
output = paddle.tile(tensor, repeat_dims)
return output
# https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
batch_size, l, d = x.shape
# inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
x = paddle.tile(
paddle.transpose(
x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
inflated_encoder_feats = paddle.reshape(
paddle.transpose(
x, perm=[1, 0, 2, 3]), [-1, l, d])
# Initialize the decoder
state = self.decoder.get_initial_state(embed, tile_times=beam_width)
pos_index = paddle.reshape(
paddle.arange(batch_size) * beam_width, shape=[-1, 1])
# Initialize the scores
sequence_scores = paddle.full(
shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
index = [i * beam_width for i in range(0, batch_size)]
sequence_scores[index] = 0.0
# Initialize the input vector
y_prev = paddle.full(
shape=[batch_size * beam_width], fill_value=self.num_classes)
# Store decisions for backtracking
stored_scores = list()
stored_predecessors = list()
stored_emitted_symbols = list()
for i in range(self.max_len_labels):
output, state = self.decoder(inflated_encoder_feats, state, y_prev)
state = paddle.unsqueeze(state, axis=0)
log_softmax_output = paddle.nn.functional.log_softmax(
output, axis=1)
sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
sequence_scores += log_softmax_output
scores, candidates = paddle.topk(
paddle.reshape(sequence_scores, [batch_size, -1]),
beam_width,
axis=1)
# Reshape input = (bk, 1) and sequence_scores = (bk, 1)
y_prev = paddle.reshape(
candidates % self.num_classes, shape=[batch_size * beam_width])
sequence_scores = paddle.reshape(
scores, shape=[batch_size * beam_width, 1])
# Update fields for next timestep
pos_index = paddle.expand_as(pos_index, candidates)
predecessors = paddle.cast(
candidates / self.num_classes + pos_index, dtype='int64')
predecessors = paddle.reshape(
predecessors, shape=[batch_size * beam_width, 1])
state = paddle.index_select(
state, index=predecessors.squeeze(), axis=1)
# Update sequence socres and erase scores for <eos> symbol so that they aren't expanded
stored_scores.append(sequence_scores.clone())
y_prev = paddle.reshape(y_prev, shape=[-1, 1])
eos_prev = paddle.full_like(y_prev, fill_value=eos)
mask = eos_prev == y_prev
mask = paddle.nonzero(mask)
if mask.dim() > 0:
sequence_scores = sequence_scores.numpy()
mask = mask.numpy()
sequence_scores[mask] = -float('inf')
sequence_scores = paddle.to_tensor(sequence_scores)
# Cache results for backtracking
stored_predecessors.append(predecessors)
y_prev = paddle.squeeze(y_prev)
stored_emitted_symbols.append(y_prev)
# Do backtracking to return the optimal values
#====== backtrak ======#
# Initialize return variables given different types
p = list()
l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
] # Placeholder for lengths of top-k sequences
# the last step output of the beams are not sorted
# thus they are sorted here
sorted_score, sorted_idx = paddle.topk(
paddle.reshape(
stored_scores[-1], shape=[batch_size, beam_width]),
beam_width)
# initialize the sequence scores with the sorted last step beam scores
s = sorted_score.clone()
batch_eos_found = [0] * batch_size # the number of EOS found
# in the backward loop below for each batch
t = self.max_len_labels - 1
# initialize the back pointer with the sorted order of the last step beams.
# add pos_index for indexing variable with b*k as the first dimension.
t_predecessors = paddle.reshape(
sorted_idx + pos_index.expand_as(sorted_idx),
shape=[batch_size * beam_width])
while t >= 0:
# Re-order the variables with the back pointer
current_symbol = paddle.index_select(
stored_emitted_symbols[t], index=t_predecessors, axis=0)
t_predecessors = paddle.index_select(
stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
eos_indices = stored_emitted_symbols[t] == eos
eos_indices = paddle.nonzero(eos_indices)
if eos_indices.dim() > 0:
for i in range(eos_indices.shape[0] - 1, -1, -1):
# Indices of the EOS symbol for both variables
# with b*k as the first dimension, and b, k for
# the first two dimensions
idx = eos_indices[i]
b_idx = int(idx[0] / beam_width)
# The indices of the replacing position
# according to the replacement strategy noted above
res_k_idx = beam_width - (batch_eos_found[b_idx] %
beam_width) - 1
batch_eos_found[b_idx] += 1
res_idx = b_idx * beam_width + res_k_idx
# Replace the old information in return variables
# with the new ended sequence information
t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
l[b_idx][res_k_idx] = t + 1
# record the back tracked results
p.append(current_symbol)
t -= 1
# Sort and re-order again as the added ended sequences may change
# the order (very unlikely)
s, re_sorted_idx = s.topk(beam_width)
for b_idx in range(batch_size):
l[b_idx] = [
l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
]
re_sorted_idx = paddle.reshape(
re_sorted_idx + pos_index.expand_as(re_sorted_idx),
[batch_size * beam_width])
# Reverse the sequences and re-order at the same time
# It is reversed because the backtracking happens in reverse time order
p = [
paddle.reshape(
paddle.index_select(step, re_sorted_idx, 0),
shape=[batch_size, beam_width, -1]) for step in reversed(p)
]
p = paddle.concat(p, -1)[:, 0, :]
return p, paddle.ones_like(p)
class AttentionUnit(nn.Layer): class AttentionUnit(nn.Layer):
def __init__(self, sDim, xDim, attDim): def __init__(self, sDim, xDim, attDim):
...@@ -151,21 +314,9 @@ class AttentionUnit(nn.Layer): ...@@ -151,21 +314,9 @@ class AttentionUnit(nn.Layer):
self.xDim = xDim self.xDim = xDim
self.attDim = attDim self.attDim = attDim
self.sEmbed = nn.Linear( self.sEmbed = nn.Linear(sDim, attDim)
sDim, self.xEmbed = nn.Linear(xDim, attDim)
attDim, self.wEmbed = nn.Linear(attDim, 1)
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
self.xEmbed = nn.Linear(
xDim,
attDim,
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
self.wEmbed = nn.Linear(
attDim,
1,
weight_attr=paddle.nn.initializer.Normal(std=0.01),
bias_attr=paddle.nn.initializer.Constant(0.0))
def forward(self, x, sPrev): def forward(self, x, sPrev):
batch_size, T, _ = x.shape # [b x T x xDim] batch_size, T, _ = x.shape # [b x T x xDim]
...@@ -184,10 +335,8 @@ class AttentionUnit(nn.Layer): ...@@ -184,10 +335,8 @@ class AttentionUnit(nn.Layer):
vProj = self.wEmbed(sumTanh) # [(b x T) x 1] vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
vProj = paddle.reshape(vProj, [batch_size, T]) vProj = paddle.reshape(vProj, [batch_size, T])
alpha = F.softmax( alpha = F.softmax(
vProj, axis=1) # attention weights for each sample in the minibatch vProj, axis=1) # attention weights for each sample in the minibatch
return alpha return alpha
...@@ -239,20 +388,3 @@ class DecoderUnit(nn.Layer): ...@@ -239,20 +388,3 @@ class DecoderUnit(nn.Layer):
output = paddle.squeeze(output, axis=1) output = paddle.squeeze(output, axis=1)
output = self.fc(output) output = self.fc(output)
return output, state return output, state
\ No newline at end of file
if __name__ == "__main__":
model = AttentionRecognitionHead(
num_classes=20,
in_channels=30,
sDim=512,
attDim=512,
max_len_labels=25,
out_channels=38)
data = paddle.ones([16, 64, 3])
targets = paddle.ones([16, 25])
length = paddle.to_tensor(20)
x = [data, targets, length]
output = model(x)
print(output.shape)
...@@ -44,13 +44,10 @@ class AttentionHead(nn.Layer): ...@@ -44,13 +44,10 @@ class AttentionHead(nn.Layer):
hidden = paddle.zeros((batch_size, self.hidden_size)) hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = [] output_hiddens = []
targets = targets[0]
print(targets)
if targets is not None: if targets is not None:
for i in range(num_steps): for i in range(num_steps):
char_onehots = self._char_to_onehot( char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes) targets[:, i], onehot_dim=self.num_classes)
# print("char_onehots:", char_onehots)
(outputs, hidden), alpha = self.attention_cell(hidden, inputs, (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
char_onehots) char_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
...@@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer): ...@@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer):
alpha = paddle.transpose(alpha, [0, 2, 1]) alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1) context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1) concat_context = paddle.concat([context, char_onehots], 1)
# print("concat_context:", concat_context.shape)
# print("prev_hidden:", prev_hidden.shape)
cur_hidden = self.rnn(concat_context, prev_hidden) cur_hidden = self.rnn(concat_context, prev_hidden)
......
...@@ -106,16 +106,3 @@ class STN(nn.Layer): ...@@ -106,16 +106,3 @@ class STN(nn.Layer):
x = F.sigmoid(x) x = F.sigmoid(x)
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2]) x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
return img_feat, x return img_feat, x
if __name__ == "__main__":
in_planes = 3
num_ctrlpoints = 20
np.random.seed(100)
activation = 'none' # 'sigmoid'
stn_head = STN(in_planes, num_ctrlpoints, activation)
data = np.random.randn(10, 3, 32, 64).astype("float32")
print("data:", np.sum(data))
input = paddle.to_tensor(data)
#input = paddle.randn([10, 3, 32, 64])
control_points = stn_head(input)
...@@ -326,5 +326,6 @@ class STN_ON(nn.Layer): ...@@ -326,5 +326,6 @@ class STN_ON(nn.Layer):
image, self.tps_inputsize, mode="bilinear", align_corners=True) image, self.tps_inputsize, mode="bilinear", align_corners=True)
stn_img_feat, ctrl_points = self.stn_head(stn_input) stn_img_feat, ctrl_points = self.stn_head(stn_input)
x, _ = self.tps(image, ctrl_points) x, _ = self.tps(image, ctrl_points)
#print("x:", np.sum(x.numpy()))
# print(x.shape) # print(x.shape)
return x return x
...@@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer): ...@@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer):
assert source_control_points.ndimension() == 3 assert source_control_points.ndimension() == 3
assert source_control_points.shape[1] == self.num_control_points assert source_control_points.shape[1] == self.num_control_points
assert source_control_points.shape[2] == 2 assert source_control_points.shape[2] == 2
batch_size = source_control_points.shape[0] #batch_size = source_control_points.shape[0]
batch_size = paddle.shape(source_control_points)[0]
self.padding_matrix = paddle.expand( self.padding_matrix = paddle.expand(
self.padding_matrix, shape=[batch_size, 3, 2]) self.padding_matrix, shape=[batch_size, 3, 2])
...@@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer): ...@@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer):
grid = paddle.clip(grid, 0, grid = paddle.clip(grid, 0,
1) # the source_control_points may be out of [0, 1]. 1) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
# grid = 2.0 * grid - 1.0 grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None) output_maps = grid_sample(input, grid, canvas=None)
return output_maps, source_coordinate return output_maps, source_coordinate
if __name__ == "__main__":
from stn import STN
in_planes = 3
num_ctrlpoints = 20
np.random.seed(100)
activation = 'none' # 'sigmoid'
stn_head = STN(in_planes, num_ctrlpoints, activation)
data = np.random.randn(10, 3, 32, 64).astype("float32")
input = paddle.to_tensor(data)
#input = paddle.randn([10, 3, 32, 64])
control_points = stn_head(input)
#print("control points:", control_points)
#input = paddle.randn(shape=[10,3,32,100])
tps = TPSSpatialTransformer(
output_image_size=[32, 320],
num_control_points=20,
margins=[0.05, 0.05])
out = tps(input, control_points[1])
print("out 0 :", out[0].shape)
print("out 1:", out[1].shape)
from __future__ import absolute_import
import numpy as np
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
def grid_sample(input, grid, canvas=None):
output = F.grid_sample(input, grid)
if canvas is None:
return output
else:
input_mask = input.data.new(input.size()).fill_(1)
output_mask = F.grid_sample(input_mask, grid)
padded_output = output * output_mask + canvas * (1 - output_mask)
return padded_output
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
def compute_partial_repr(input_points, control_points):
N = input_points.size(0)
M = control_points.size(0)
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
# original implementation, very slow
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
pairwise_diff_square = pairwise_diff * pairwise_diff
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
1]
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
# fix numerical error for 0 * log(0), substitute all nan with 0
mask = repr_matrix != repr_matrix
repr_matrix.masked_fill_(mask, 0)
return repr_matrix
# output_ctrl_pts are specified, according to our task.
def build_output_control_points(num_control_points, margins):
margin_x, margin_y = margins
num_ctrl_pts_per_side = num_control_points // 2
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
# ctrl_pts_top = ctrl_pts_top[1:-1,:]
# ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
output_ctrl_pts_arr = np.concatenate(
[ctrl_pts_top, ctrl_pts_bottom], axis=0)
output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
return output_ctrl_pts
# demo: ~/test/models/test_tps_transformation.py
class TPSSpatialTransformer(nn.Module):
def __init__(self,
output_image_size=None,
num_control_points=None,
margins=None):
super(TPSSpatialTransformer, self).__init__()
self.output_image_size = output_image_size
self.num_control_points = num_control_points
self.margins = margins
self.target_height, self.target_width = output_image_size
target_control_points = build_output_control_points(num_control_points,
margins)
N = num_control_points
# N = N - 4
# create padded kernel matrix
forward_kernel = torch.zeros(N + 3, N + 3)
target_control_partial_repr = compute_partial_repr(
target_control_points, target_control_points)
forward_kernel[:N, :N].copy_(target_control_partial_repr)
forward_kernel[:N, -3].fill_(1)
forward_kernel[-3, :N].fill_(1)
forward_kernel[:N, -2:].copy_(target_control_points)
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
# compute inverse matrix
inverse_kernel = torch.inverse(forward_kernel)
# create target cordinate matrix
HW = self.target_height * self.target_width
target_coordinate = list(
itertools.product(
range(self.target_height), range(self.target_width)))
target_coordinate = torch.Tensor(target_coordinate) # HW x 2
Y, X = target_coordinate.split(1, dim=1)
Y = Y / (self.target_height - 1)
X = X / (self.target_width - 1)
target_coordinate = torch.cat([X, Y],
dim=1) # convert from (y, x) to (x, y)
target_coordinate_partial_repr = compute_partial_repr(
target_coordinate, target_control_points)
target_coordinate_repr = torch.cat([
target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
],
dim=1)
# register precomputed matrices
self.register_buffer('inverse_kernel', inverse_kernel)
self.register_buffer('padding_matrix', torch.zeros(3, 2))
self.register_buffer('target_coordinate_repr', target_coordinate_repr)
self.register_buffer('target_control_points', target_control_points)
def forward(self, input, source_control_points):
assert source_control_points.ndimension() == 3
assert source_control_points.size(1) == self.num_control_points
assert source_control_points.size(2) == 2
batch_size = source_control_points.size(0)
Y = torch.cat([
source_control_points, self.padding_matrix.expand(batch_size, 3, 2)
], 1)
mapping_matrix = torch.matmul(self.inverse_kernel, Y)
source_coordinate = torch.matmul(self.target_coordinate_repr,
mapping_matrix)
grid = source_coordinate.view(-1, self.target_height, self.target_width,
2)
grid = torch.clamp(grid, 0,
1) # the source_control_points may be out of [0, 1].
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
grid = 2.0 * grid - 1.0
output_maps = grid_sample(input, grid, canvas=None)
return output_maps, source_coordinate
if __name__ == "__main__":
from stn_torch import STNHead
in_planes = 3
num_ctrlpoints = 20
torch.manual_seed(10)
activation = 'none' # 'sigmoid'
stn_head = STNHead(in_planes, num_ctrlpoints, activation)
np.random.seed(100)
data = np.random.randn(10, 3, 32, 64).astype("float32")
input = torch.tensor(data)
control_points = stn_head(input)
tps = TPSSpatialTransformer(
output_image_size=[32, 320],
num_control_points=20,
margins=[0.05, 0.05])
out = tps(input, control_points[1])
print("out 0 :", out[0].shape)
print("out 1:", out[1].shape)
...@@ -127,3 +127,34 @@ class RMSProp(object): ...@@ -127,3 +127,34 @@ class RMSProp(object):
grad_clip=self.grad_clip, grad_clip=self.grad_clip,
parameters=parameters) parameters=parameters)
return opt return opt
class Adadelta(object):
def __init__(self,
learning_rate=0.001,
epsilon=1e-08,
rho=0.95,
parameter_list=None,
weight_decay=None,
grad_clip=None,
name=None,
**kwargs):
self.learning_rate = learning_rate
self.epsilon = epsilon
self.rho = rho
self.parameter_list = parameter_list
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.grad_clip = grad_clip
self.name = name
def __call__(self, parameters):
opt = optim.Adadelta(
learning_rate=self.learning_rate,
epsilon=self.epsilon,
rho=self.rho,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
name=self.name,
parameters=parameters)
return opt
...@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess ...@@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
TableLabelDecode TableLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
...@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None): ...@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode' 'DistillationCTCLabelDecode', 'TableLabelDecode', 'SEEDLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
self.unkonwn = "UNKNOWN"
dict_character = dict_character dict_character = dict_character
dict_character = [self.beg_str] + dict_character + [self.end_str dict_character = [self.beg_str] + dict_character + [self.end_str]
] + [self.unkonwn]
return dict_character return dict_character
def decode(self, text_index, text_prob=None, is_remove_duplicate=False): def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
...@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
label = self.decode(label, is_remove_duplicate=False) label = self.decode(label, is_remove_duplicate=False)
return text, label return text, label
""" """
preds = preds["rec_pred"]
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
...@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx return idx
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
**kwargs):
super(SEEDLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
self.end_str = "eos"
dict_character = dict_character
dict_character = dict_character + [self.end_str]
return dict_character
def get_ignored_tokens(self):
end_idx = self.get_beg_end_flag_idx("eos")
return [end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "sos":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "eos":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
return idx
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
[end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
preds_idx = preds["rec_pred"]
if isinstance(preds_idx, paddle.Tensor):
preds_idx = preds_idx.numpy()
if "rec_pred_scores" in preds:
preds_idx = preds["rec_pred"]
preds_prob = preds["rec_pred_scores"]
else:
preds_idx = preds["rec_pred"].argmax(axis=2)
preds_prob = preds["rec_pred"].max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label
class SRNLabelDecode(BaseRecLabelDecode): class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -105,15 +105,12 @@ def load_dygraph_params(config, model, logger, optimizer): ...@@ -105,15 +105,12 @@ def load_dygraph_params(config, model, logger, optimizer):
params = paddle.load(pm) params = paddle.load(pm)
state_dict = model.state_dict() state_dict = model.state_dict()
new_state_dict = {} new_state_dict = {}
# for k1, k2 in zip(state_dict.keys(), params.keys()): for k1, k2 in zip(state_dict.keys(), params.keys()):
for k1 in state_dict.keys(): if list(state_dict[k1].shape) == list(params[k2].shape):
if k1 not in params: new_state_dict[k1] = params[k2]
continue
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else: else:
logger.info( logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !" f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
......
...@@ -211,11 +211,10 @@ def train(config, ...@@ -211,11 +211,10 @@ def train(config,
images = batch[0] images = batch[0]
if use_srn: if use_srn:
model_average = True model_average = True
# if use_srn or model_type == 'table' or algorithm == "ASTER": if use_srn or model_type == 'table' or model_type == "seed":
# preds = model(images, data=batch[1:])
# else:
# preds = model(images)
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
else:
preds = model(images)
state_dict = model.state_dict() state_dict = model.state_dict()
# for key in state_dict: # for key in state_dict:
# print(key) # print(key)
...@@ -415,6 +414,7 @@ def preprocess(is_train=False): ...@@ -415,6 +414,7 @@ def preprocess(is_train=False):
yaml.dump( yaml.dump(
dict(config), f, default_flow_style=False, sort_keys=False) dict(config), f, default_flow_style=False, sort_keys=False)
log_file = '{}/train.log'.format(save_model_dir) log_file = '{}/train.log'.format(save_model_dir)
print("log has save in {}/train.log".format(save_model_dir))
else: else:
log_file = None log_file = None
logger = get_logger(name='root', log_file=log_file) logger = get_logger(name='root', log_file=log_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册