提交 6dd494b6 编写于 作者: T tink2123

add anno

上级 fc2f9c2e
...@@ -25,6 +25,14 @@ from copy import deepcopy ...@@ -25,6 +25,14 @@ from copy import deepcopy
class RecModel(object): class RecModel(object):
"""
Rec model architecture
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params): def __init__(self, params):
super(RecModel, self).__init__() super(RecModel, self).__init__()
global_params = params['Global'] global_params = params['Global']
...@@ -58,6 +66,13 @@ class RecModel(object): ...@@ -58,6 +66,13 @@ class RecModel(object):
self.max_text_length = global_params['max_text_length'] self.max_text_length = global_params['max_text_length']
def create_feed(self, mode): def create_feed(self, mode):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
...@@ -96,9 +111,13 @@ class RecModel(object): ...@@ -96,9 +111,13 @@ class RecModel(object):
inputs = image inputs = image
else: else:
inputs = self.tps(image) inputs = self.tps(image)
# backbone
conv_feas = self.backbone(inputs) conv_feas = self.backbone(inputs)
# predict
predicts = self.head(conv_feas, labels, mode) predicts = self.head(conv_feas, labels, mode)
decoded_out = predicts['decoded_out'] decoded_out = predicts['decoded_out']
#loss
if mode == "train": if mode == "train":
loss = self.loss(predicts, labels) loss = self.loss(predicts, labels)
if self.loss_type == "attention": if self.loss_type == "attention":
...@@ -108,9 +127,11 @@ class RecModel(object): ...@@ -108,9 +127,11 @@ class RecModel(object):
outputs = {'total_loss':loss, 'decoded_out':\ outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label} decoded_out, 'label':label}
return loader, outputs return loader, outputs
# export_model
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
predict = fluid.layers.softmax(predict) predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
# eval or test
else: else:
return loader, {'decoded_out': decoded_out} return loader, {'decoded_out': decoded_out}
...@@ -33,6 +33,7 @@ class AttentionLoss(object): ...@@ -33,6 +33,7 @@ class AttentionLoss(object):
predict = predicts['predict'] predict = predicts['predict']
label_out = labels['label_out'] label_out = labels['label_out']
label_out = fluid.layers.cast(x=label_out, dtype='int64') label_out = fluid.layers.cast(x=label_out, dtype='int64')
# calculate attention loss
cost = fluid.layers.cross_entropy(input=predict, label=label_out) cost = fluid.layers.cross_entropy(input=predict, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost) sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost return sum_cost
...@@ -30,6 +30,7 @@ class CTCLoss(object): ...@@ -30,6 +30,7 @@ class CTCLoss(object):
def __call__(self, predicts, labels): def __call__(self, predicts, labels):
predict = predicts['predict'] predict = predicts['predict']
label = labels['label'] label = labels['label']
# calculate ctc loss
cost = fluid.layers.warpctc( cost = fluid.layers.warpctc(
input=predict, label=label, blank=self.char_num, norm_by_times=True) input=predict, label=label, blank=self.char_num, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost) sum_cost = fluid.layers.reduce_sum(cost)
......
...@@ -20,14 +20,22 @@ import sys ...@@ -20,14 +20,22 @@ import sys
class CharacterOps(object): class CharacterOps(object):
""" Convert between text-label and text-index """ """
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config): def __init__(self, config):
self.character_type = config['character_type'] self.character_type = config['character_type']
self.loss_type = config['loss_type'] self.loss_type = config['loss_type']
# use the default dictionary(36 char)
if self.character_type == "en": if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type == "ch": elif self.character_type == "ch":
character_dict_path = config['character_dict_path'] character_dict_path = config['character_dict_path']
self.character_str = "" self.character_str = ""
...@@ -47,26 +55,29 @@ class CharacterOps(object): ...@@ -47,26 +55,29 @@ class CharacterOps(object):
"Nonsupport type of the character: {}".format(self.character_str) "Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
# add start and end str for attention
if self.loss_type == "attention": if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character dict_character = [self.beg_str, self.end_str] + dict_character
# create char dict
self.dict = {} self.dict = {}
for i, char in enumerate(dict_character): for i, char in enumerate(dict_character):
self.dict[char] = i self.dict[char] = i
self.character = dict_character self.character = dict_character
def encode(self, text): def encode(self, text):
"""convert text-label into text-index. """
input: convert text-label into text-index.
Args:
text: text labels of each image. [batch_size] text: text labels of each image. [batch_size]
output: Reture:
text: concatenated text index for CTCLoss. text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
""" """
# Ignore capital
if self.character_type == "en": if self.character_type == "en":
text = text.lower() text = text.lower()
text_list = [] text_list = []
for char in text: for char in text:
if char not in self.dict: if char not in self.dict:
...@@ -76,7 +87,15 @@ class CharacterOps(object): ...@@ -76,7 +87,15 @@ class CharacterOps(object):
return text return text
def decode(self, text_index, is_remove_duplicate=False): def decode(self, text_index, is_remove_duplicate=False):
""" convert text-index into text-label. """ """
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: Whether to remove duplicate characters,
The default is False
Return:
text: text label
"""
char_list = [] char_list = []
char_num = self.get_char_num() char_num = self.get_char_num()
...@@ -98,6 +117,9 @@ class CharacterOps(object): ...@@ -98,6 +117,9 @@ class CharacterOps(object):
return text return text
def get_char_num(self): def get_char_num(self):
"""
Get character num
"""
return len(self.character) return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end): def get_beg_end_flag_idx(self, beg_or_end):
...@@ -122,6 +144,19 @@ def cal_predicts_accuracy(char_ops, ...@@ -122,6 +144,19 @@ def cal_predicts_accuracy(char_ops,
labels, labels,
labels_lod, labels_lod,
is_remove_duplicate=False): is_remove_duplicate=False):
"""
Calculate predicts accrarcy
Args:
char_ops: CharacterOps
preds: preds result,text index
preds_lod:
labels:
labels_lod:
is_remove_duplicate:
Return:
"""
acc_num = 0 acc_num = 0
img_num = 0 img_num = 0
for ino in range(len(labels_lod) - 1): for ino in range(len(labels_lod) - 1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册