提交 6dd494b6 编写于 作者: T tink2123

add anno

上级 fc2f9c2e
......@@ -25,6 +25,14 @@ from copy import deepcopy
class RecModel(object):
"""
Rec model architecture
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params):
super(RecModel, self).__init__()
global_params = params['Global']
......@@ -58,6 +66,13 @@ class RecModel(object):
self.max_text_length = global_params['max_text_length']
def create_feed(self, mode):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
......@@ -96,9 +111,13 @@ class RecModel(object):
inputs = image
else:
inputs = self.tps(image)
# backbone
conv_feas = self.backbone(inputs)
# predict
predicts = self.head(conv_feas, labels, mode)
decoded_out = predicts['decoded_out']
#loss
if mode == "train":
loss = self.loss(predicts, labels)
if self.loss_type == "attention":
......@@ -108,9 +127,11 @@ class RecModel(object):
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
# export_model
elif mode == "export":
predict = predicts['predict']
predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
# eval or test
else:
return loader, {'decoded_out': decoded_out}
......@@ -33,6 +33,7 @@ class AttentionLoss(object):
predict = predicts['predict']
label_out = labels['label_out']
label_out = fluid.layers.cast(x=label_out, dtype='int64')
# calculate attention loss
cost = fluid.layers.cross_entropy(input=predict, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
......@@ -30,6 +30,7 @@ class CTCLoss(object):
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
# calculate ctc loss
cost = fluid.layers.warpctc(
input=predict, label=label, blank=self.char_num, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
......
......@@ -20,14 +20,22 @@ import sys
class CharacterOps(object):
""" Convert between text-label and text-index """
"""
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
# use the default dictionary(36 char)
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type == "ch":
character_dict_path = config['character_dict_path']
self.character_str = ""
......@@ -47,26 +55,29 @@ class CharacterOps(object):
"Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
# add start and end str for attention
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
# create char dict
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def encode(self, text):
"""convert text-label into text-index.
input:
"""
convert text-label into text-index.
Args:
text: text labels of each image. [batch_size]
output:
Reture:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
# Ignore capital
if self.character_type == "en":
text = text.lower()
text_list = []
for char in text:
if char not in self.dict:
......@@ -76,7 +87,15 @@ class CharacterOps(object):
return text
def decode(self, text_index, is_remove_duplicate=False):
""" convert text-index into text-label. """
"""
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
text: text label
"""
char_list = []
char_num = self.get_char_num()
......@@ -98,6 +117,9 @@ class CharacterOps(object):
return text
def get_char_num(self):
"""
Get character num
"""
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
......@@ -122,6 +144,19 @@ def cal_predicts_accuracy(char_ops,
labels,
labels_lod,
is_remove_duplicate=False):
"""
Calculate predicts accrarcy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod:
labels:
labels_lod:
is_remove_duplicate:
Return:
"""
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册