From 3892a8ca02d99d869cca8cc23408c7cf75ab4193 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 1 Jul 2020 12:45:59 +0800 Subject: [PATCH] fix Nan results and add test_reader func --- tools/infer/predict_rec.py | 2 ++ tools/infer_rec.py | 2 ++ tools/train.py | 23 ++++++++++++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index cc580a69..b1538938 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -122,6 +122,8 @@ class TextRecognizer(object): blank = probs.shape[1] valid_ind = np.where(ind != (blank - 1))[0] score = np.mean(probs[valid_ind, ind[valid_ind]]) + if not valid_ind: + continue # rec_res.append([preds_text, score]) rec_res[indices[beg_img_no + rno]] = [preds_text, score] else: diff --git a/tools/infer_rec.py b/tools/infer_rec.py index b1ddc8dd..6086e499 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -99,6 +99,8 @@ def main(): ind = np.argmax(probs, axis=1) blank = probs.shape[1] valid_ind = np.where(ind != (blank - 1))[0] + if not valid_ind: + continue score = np.mean(probs[valid_ind, ind[valid_ind]]) elif loss_type == "attention": preds = np.array(predict[0]) diff --git a/tools/train.py b/tools/train.py index 287ed205..20f63c25 100755 --- a/tools/train.py +++ b/tools/train.py @@ -36,7 +36,7 @@ set_paddle_flags( FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory ) -import program +import tools.program as program from paddle import fluid from ppocr.utils.utility import initial_logger logger = initial_logger() @@ -106,6 +106,27 @@ def main(): program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) +def test_reader(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + print(config) + train_reader = reader_main(config=config, mode="train") + import time + starttime = time.time() + count = 0 + try: + for data in train_reader(): + count += 1 + if count % 1 == 0: + batch_time = time.time() - starttime + starttime = time.time() + print("reader:", count, len(data), batch_time) + except Exception as e: + print(e) + print("finish reader:", count) + print("success") + + if __name__ == '__main__': parser = program.ArgsParser() FLAGS = parser.parse_args() -- GitLab