Unverified Commit 2b163230 authored by D dyning, committed by GitHub

Merge pull request #848 from tink2123/add_anno

add comments for rec
@@ -25,6 +25,12 @@ from copy import deepcopy
class RecModel(object):
"""
Rec model architecture
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params):
super(RecModel, self).__init__()
global_params = params['Global']
@@ -64,6 +70,12 @@ class RecModel(object):
self.num_heads = None
def create_feed(self, mode):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
if mode == "train":
@@ -189,9 +201,12 @@ class RecModel(object):
inputs = image
else:
inputs = self.tps(image)
# backbone
conv_feas = self.backbone(inputs)
# predict
predicts = self.head(conv_feas, labels, mode)
decoded_out = predicts['decoded_out']
# loss
if mode == "train":
loss = self.loss(predicts, labels)
if self.loss_type == "attention":
@@ -211,7 +226,7 @@ class RecModel(object):
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
# export_model
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
if self.loss_type == "ctc": if self.loss_type == "ctc":
...@@ -225,6 +240,7 @@ class RecModel(object): ...@@ -225,6 +240,7 @@ class RecModel(object):
] ]
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
# eval or test
else: else:
predict = predicts['predict'] predict = predicts['predict']
if self.loss_type == "ctc": if self.loss_type == "ctc":
......
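For context, a hedged sketch of how a caller might consume the train-mode return values above, assuming the enclosing method is the model's __call__ and the keys match the outputs dict in this diff:

model = RecModel(params)                      # params parsed from the yaml config
loader, outputs = model(mode="train")
fetch_list = [outputs['total_loss'].name,     # loss to optimize and log
              outputs['decoded_out'].name,    # decoded character indices
              outputs['label'].name]          # ground-truth labels for accuracy
# loader feeds the batches; exe.run(..., fetch_list=fetch_list) then retrieves these values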
@@ -27,6 +27,12 @@ import numpy as np
class CTCPredict(object):
"""
CTC predict
Args:
params(object): Params from yaml file and settings from command line
"""
def __init__(self, params):
super(CTCPredict, self).__init__()
self.char_num = params['char_num']
......
@@ -33,6 +33,7 @@ class AttentionLoss(object):
predict = predicts['predict']
label_out = labels['label_out']
label_out = fluid.layers.cast(x=label_out, dtype='int64')
# calculate attention loss
cost = fluid.layers.cross_entropy(input=predict, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
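A small numpy illustration of the summed cross-entropy computed above (plain numbers, not the fluid graph; the probabilities and labels are made up):

import numpy as np

predict = np.array([[0.7, 0.2, 0.1],   # step 0: probability of each character id
                    [0.1, 0.8, 0.1]])  # step 1
label_out = np.array([0, 1])           # ground-truth character id at each step
cost = -np.log(predict[np.arange(len(label_out)), label_out])
sum_cost = cost.sum()                  # ~0.580, summed over all time steps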
@@ -30,6 +30,7 @@ class CTCLoss(object):
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
# calculate ctc loss
cost = fluid.layers.warpctc(
input=predict, label=label, blank=self.char_num, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
......
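A hedged note on the tensors warpctc receives here, with a toy layout (values and shapes are illustrative assumptions):

import numpy as np

# toy layout of the CTC label data for two texts "12" and "ab", using the default 36-char
# dictionary with no sos/eos (so '1'->1, '2'->2, 'a'->10, 'b'->11):
label_ids = np.array([[1], [2], [10], [11]], dtype='int32')  # concatenated char ids
label_lod = [[0, 2, 4]]                                       # text i spans ids[lod[i]:lod[i + 1]]
# predict is a [sum(time_steps), char_num + 1] LoD tensor; blank=self.char_num means the
# blank class is the extra id appended after all real characters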
@@ -20,15 +20,21 @@ import sys
class CharacterOps(object):
"""
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
# use the default dictionary (36 chars)
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type in [
"ch", 'japan', 'korean', 'french', 'german'
]:
@@ -55,25 +61,27 @@ class CharacterOps(object):
"Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
# add start and end str for attention
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
# create char dict
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
def encode(self, text):
"""convert text-label into text-index. """
input: convert text-label into text-index.
Args:
text: text labels of each image. [batch_size] text: text labels of each image. [batch_size]
Return:
output:
text: concatenated text index for CTCLoss. text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
""" """
# Ignore capital
if self.character_type == "en": if self.character_type == "en":
text = text.lower() text = text.lower()
@@ -86,7 +94,15 @@ class CharacterOps(object):
return text
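A hedged usage sketch of encode with the default English dictionary; the exact index values assume characters are numbered in dictionary order and that no sos/eos is prepended for the ctc loss type:

config = {"character_type": "en", "loss_type": "ctc", "max_text_length": 25}
char_ops = CharacterOps(config)
indices = char_ops.encode("OCR")   # lower-cased to "ocr"
# expected to be roughly [24, 12, 27] in the "0123456789abcdefghijklmnopqrstuvwxyz" dictionary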
def decode(self, text_index, is_remove_duplicate=False):
"""
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: whether to remove duplicate characters; defaults to False
Return:
text: text label
"""
char_list = []
char_num = self.get_char_num()
@@ -108,6 +124,9 @@ class CharacterOps(object):
return text
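Continuing the sketch above, a hedged illustration of decode for CTC greedy decoding; it assumes the ignored blank id equals get_char_num(), i.e. 36 for the default dictionary:

raw = [24, 24, 36, 12, 12, 27]                        # 'o', 'o', blank, 'c', 'c', 'r'
text = char_ops.decode(raw, is_remove_duplicate=True)
# expected "ocr": repeated indices are collapsed and the blank id is dropped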
def get_char_num(self):
"""
Get character num
"""
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
@@ -132,6 +151,21 @@ def cal_predicts_accuracy(char_ops,
labels,
labels_lod,
is_remove_duplicate=False):
"""
Calculate prediction accuracy
Args:
char_ops: CharacterOps
preds: prediction results, as text indices
preds_lod: lod tensor of preds
labels: label of input image, text index
labels_lod: lod tensor of label
is_remove_duplicate: whether to remove duplicate characters; defaults to False
Return:
acc: The accuracy of test set
acc_num: The correct number of samples predicted
img_num: The total sample number of the test set
"""
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
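A minimal numpy sketch of the lod-based exact-match accuracy this function computes; sample boundaries and values are illustrative, and the real function additionally decodes the indices through char_ops before comparing:

import numpy as np

preds      = np.array([24, 12, 27, 11, 10])   # concatenated predicted char ids
preds_lod  = [0, 3, 5]                        # sample i spans preds[lod[i]:lod[i + 1]]
labels     = np.array([24, 12, 27, 11, 11])
labels_lod = [0, 3, 5]

acc_num, img_num = 0, 0
for i in range(len(labels_lod) - 1):
    p = preds[preds_lod[i]:preds_lod[i + 1]]
    l = labels[labels_lod[i]:labels_lod[i + 1]]
    img_num += 1
    if len(p) == len(l) and (p == l).all():
        acc_num += 1
acc = acc_num * 1.0 / img_num                 # 0.5 for this toy batch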
@@ -189,6 +223,14 @@ def cal_predicts_accuracy_srn(char_ops,
def convert_rec_attention_infer_res(preds):
"""
Convert attention-based recognition results into a flat prediction tensor with lod information
Args:
preds: the output of the model
Return:
convert_ids: a 1-D Tensor that represents all the predicted results
target_lod: The lod information of the predicted results
"""
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
......
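A hedged numpy toy of the conversion described above: each row of preds is cut at its end token and the kept ids are flattened with an accompanying lod. The sos/eos ids 0 and 1 and the padding are assumptions for illustration:

import numpy as np

preds = np.array([[0, 24, 12, 27, 1, 1],     # sos, 'o', 'c', 'r', eos, padding
                  [0, 11, 10, 1, 1, 1]])     # sos, 'b', 'a', eos, padding
convert_ids, target_lod = [], [0]
for row in preds:
    end_pos = np.where(row == 1)[0]           # positions of the assumed eos id
    text = row[1:end_pos[0]] if len(end_pos) > 0 else row[1:]
    convert_ids.extend(text.tolist())
    target_lod.append(target_lod[-1] + len(text))
convert_ids = np.array(convert_ids).reshape(-1, 1)
# convert_ids -> [[24], [12], [27], [11], [10]], target_lod -> [0, 3, 5]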
@@ -122,7 +122,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict):
"""
Evaluate the recognition model on the standard lmdb benchmark datasets
"""
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir']
......
@@ -150,19 +150,20 @@ def check_gpu(use_gpu):
def build(config, main_prog, startup_prog, mode):
"""
Build a program using a model and an optimizer
1. create a dataloader
2. create a model
3. create fetches
4. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
mode(str): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetch_name_list(dict): dict of model outputs (including loss and measures)
fetch_varname_list(list): list of the output variables' names
opt_loss_name(str): name of loss
"""
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
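A hedged sketch of the four steps inside the guards above; the model class, optimizer choice, and config keys are assumed for illustration and are not necessarily the repository's exact names:

import paddle.fluid as fluid

with fluid.program_guard(main_prog, startup_prog):
    with fluid.unique_name.guard():
        model = RecModel(config)                   # 1+2: dataloader and model (class from this diff)
        dataloader, outputs = model(mode=mode)
        fetch_name_list = list(outputs.keys())     # 3: fetches
        fetch_varname_list = [outputs[name].name for name in fetch_name_list]
        opt_loss_name = None
        if mode == "train":                        # 4: optimizer on the training loss
            opt_loss = outputs['total_loss']
            optimizer = fluid.optimizer.Adam(learning_rate=0.001)  # assumed optimizer settings
            optimizer.minimize(opt_loss)
            opt_loss_name = opt_loss.name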
@@ -207,8 +208,8 @@ def build_export(config, main_prog, startup_prog):
Build input and output for exporting a checkpoints model to an inference model
Args:
config(dict): config
main_prog: main program
startup_prog: startup program
Returns:
feeded_var_names(list[str]): var names of input for exported inference model
target_vars(list[Variable]): output vars for exported inference model
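For reference, a hedged sketch of how such feeded_var_names and target_vars are typically passed on to save the inference model; the output directory and checkpoint loading are placeholders:

import paddle.fluid as fluid

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)
# ... load the trained checkpoint into main_prog here ...
fluid.io.save_inference_model(
    dirname="./inference_model",          # assumed output directory
    feeded_var_names=feeded_var_names,
    target_vars=target_vars,
    executor=exe,
    main_program=main_prog)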
@@ -257,9 +258,14 @@ def train_eval_det_run(config,
train_info_dict,
eval_info_dict,
is_slim=None):
"""
Feed data to the model and fetch the measures and loss for detection
Args:
config: config
exe: the executor that runs the program
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
@@ -376,9 +382,14 @@ def train_eval_rec_run(config,
train_info_dict,
eval_info_dict,
is_slim=None):
"""
Feed data to the model and fetch the measures and loss for recognition
Args:
config: config
exe: the executor that runs the program
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
......
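A hedged sketch of the feed-and-fetch loop such a train_eval function runs; the dict keys and the exception-based reader protocol are assumptions based on typical fluid DataLoader usage, not the exact code:

import paddle.fluid as fluid

for epoch in range(epoch_num):
    train_info_dict['reader'].start()
    try:
        while True:
            fetched = exe.run(train_info_dict['compile_program'],
                              fetch_list=train_info_dict['fetch_varname_list'],
                              return_numpy=False)
            # smooth the loss over log_smooth_window steps and log it;
            # periodically run the eval program and keep the best accuracy
    except fluid.core.EOFException:
        train_info_dict['reader'].reset()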