提交 3892a8ca 编写于 作者: L LDOUBLEV

fix Nan results and add test_reader func

上级 f1f9206b
...@@ -122,6 +122,8 @@ class TextRecognizer(object): ...@@ -122,6 +122,8 @@ class TextRecognizer(object):
blank = probs.shape[1] blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0] valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
if not valid_ind:
continue
# rec_res.append([preds_text, score]) # rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score] rec_res[indices[beg_img_no + rno]] = [preds_text, score]
else: else:
......
...@@ -99,6 +99,8 @@ def main(): ...@@ -99,6 +99,8 @@ def main():
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
blank = probs.shape[1] blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0] valid_ind = np.where(ind != (blank - 1))[0]
if not valid_ind:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention": elif loss_type == "attention":
preds = np.array(predict[0]) preds = np.array(predict[0])
......
...@@ -36,7 +36,7 @@ set_paddle_flags( ...@@ -36,7 +36,7 @@ set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
) )
import program import tools.program as program
from paddle import fluid from paddle import fluid
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
...@@ -106,6 +106,27 @@ def main(): ...@@ -106,6 +106,27 @@ def main():
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
def test_reader():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
train_reader = reader_main(config=config, mode="train")
import time
starttime = time.time()
count = 0
try:
for data in train_reader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
print("reader:", count, len(data), batch_time)
except Exception as e:
print(e)
print("finish reader:", count)
print("success")
if __name__ == '__main__': if __name__ == '__main__':
parser = program.ArgsParser() parser = program.ArgsParser()
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册