diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index cc580a69f741ae436fcc65f546083d311947b12c..9761ddbad9123372d706db158aba8008956f30e9 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -122,6 +122,8 @@ class TextRecognizer(object): blank = probs.shape[1] valid_ind = np.where(ind != (blank - 1))[0] score = np.mean(probs[valid_ind, ind[valid_ind]]) + if len(valid_ind) == 0: + continue # rec_res.append([preds_text, score]) rec_res[indices[beg_img_no + rno]] = [preds_text, score] else: diff --git a/tools/infer_rec.py b/tools/infer_rec.py index b1ddc8dda198b92414bc7d7a853de78527456e5a..9abbf076e9a4902cdd0876c320021fc45e227a2c 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -99,6 +99,8 @@ def main(): ind = np.argmax(probs, axis=1) blank = probs.shape[1] valid_ind = np.where(ind != (blank - 1))[0] + if len(valid_ind) == 0: + continue score = np.mean(probs[valid_ind, ind[valid_ind]]) elif loss_type == "attention": preds = np.array(predict[0]) diff --git a/tools/train.py b/tools/train.py index 287ed2059e9393a3fc1758e76b008891939a8424..15d6ebb2138ce19a2f65c7d1fabd56d86b7645be 100755 --- a/tools/train.py +++ b/tools/train.py @@ -36,7 +36,7 @@ set_paddle_flags( FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory ) -import program +import tools.program as program from paddle import fluid from ppocr.utils.utility import initial_logger logger = initial_logger() @@ -106,6 +106,26 @@ def main(): program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) +def test_reader(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + print(config) + train_reader = reader_main(config=config, mode="train") + import time + starttime = time.time() + count = 0 + try: + for data in train_reader(): + count += 1 + if count % 1 == 0: + batch_time = time.time() - starttime + starttime = time.time() + print("reader:", count, len(data), batch_time) + except Exception as e: + logger.info(e) + logger.info("finish reader: {}, Success!".format(count)) + + if __name__ == '__main__': parser = program.ArgsParser() FLAGS = parser.parse_args()