diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index b256ed3b238175cff6dc581d2b7a420b6d7cbf70..f3c90050c92803789252304bb4e3a9a4bf04c70b 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -123,7 +123,7 @@ class BaseRecLabelEncode(object): [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] length: length of each text. [batch_size] """ - if len(text) > self.max_text_len: + if len(text) == 0 or len(text) > self.max_text_len: return None if self.character_type == "en": text = text.lower() @@ -138,9 +138,6 @@ class BaseRecLabelEncode(object): return None return text_list - def get_ignored_tokens(self): - return [0] # for ctc blank - class CTCLabelEncode(BaseRecLabelEncode): """ Convert between text-label and text-index """ @@ -160,8 +157,6 @@ class CTCLabelEncode(BaseRecLabelEncode): text = self.encode(text) if text is None: return None - if len(text) > self.max_text_len: - return None data['length'] = np.array(len(text)) text = text + [0] * (self.max_text_len - len(text)) data['label'] = np.array(text) @@ -195,11 +190,6 @@ class AttnLabelEncode(BaseRecLabelEncode): text = self.encode(text) return text - def get_ignored_tokens(self): - beg_idx = self.get_beg_end_flag_idx("beg") - end_idx = self.get_beg_end_flag_idx("end") - return [beg_idx, end_idx] - def get_beg_end_flag_idx(self, beg_or_end): if beg_or_end == "beg": idx = np.array(self.dict[self.beg_str]) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 00f0ffc1fe3387f139789490c4e6557eebb646d0..7d7e4720143c26e36343e3c8f94a0bf4b2caf892 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -82,7 +82,7 @@ class TextClassifier(object): cls_res = [['', 0.0]] * img_num batch_num = self.cls_batch_num - predict_time = 0 + elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] @@ -107,14 +107,14 @@ class TextClassifier(object): self.predictor.run([norm_img_batch]) prob_out = self.output_tensors[0].copy_to_cpu() cls_res = self.postprocess_op(prob_out) - elapse = time.time() - starttime + elapse += time.time() - starttime for rno in range(len(cls_res)): label, score = cls_res[rno] cls_res[indices[beg_img_no + rno]] = [label, score] if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]], 1) - return img_list, cls_res, predict_time + return img_list, cls_res, elapse def main(args): @@ -143,10 +143,10 @@ def main(args): "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() for ino in range(len(img_list)): - print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino])) - print("Total predict time for %d images, cost: %.3f" % - (len(img_list), predict_time)) + print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ + ino])) + print("Total predict time for {} images, cost: {:.3f}".format( + len(img_list), predict_time)) - -if __name__ == "__main__": - main(utility.parse_args()) + if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 69db67db3588ec551b9e586ce439ddc96154b899..4b4825a66a145faf78d96446604329730a453381 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -174,15 +174,15 @@ if __name__ == "__main__": if img is None: logger.info("error in loading image:{}".format(image_file)) continue - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) dt_boxes, elapse = text_detector(img) if count > 0: total_time += elapse count += 1 - print("Predict time of %s:" % image_file, elapse) + print("Predict time of {}: {}".format(image_file, elapse)) src_im = utility.draw_text_det_res(dt_boxes, image_file) - img_name_pure = image_file.split("/")[-1] - cv2.imwrite( - os.path.join(draw_img_save, "det_res_%s" % img_name_pure), src_im) + img_name_pure = os.path.split(image_file)[-1] + img_path = os.path.join(draw_img_save, + "det_res_{}".format(img_name_pure)) + cv2.imwrite(img_path, src_im) if count > 1: print("Avg Time:", total_time / (count - 1)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 54dbb03b570b605e5c05150b5fabef4d45d4337a..c1f20ef3c6b42772d47665032504b1fae039cbcd 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -115,7 +115,7 @@ class TextRecognizer(object): rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] - elapse = time.time() - starttime + elapse += time.time() - starttime return rec_res, elapse @@ -145,9 +145,10 @@ def main(args): "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() for ino in range(len(img_list)): - print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) - print("Total predict time for %d images, cost: %.3f" % - (len(img_list), predict_time)) + print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ + ino])) + print("Total predict time for {} images, cost: {:.3f}".format( + len(img_list), predict_time)) if __name__ == "__main__": diff --git a/tools/program.py b/tools/program.py index a5c3f794b1467f782fa3a023d3a89d8825c00e4f..8e84d30e64fa19a99fea205bca2d08c490b6fd7e 100755 --- a/tools/program.py +++ b/tools/program.py @@ -236,7 +236,6 @@ def train(config, train_batch_cost = 0.0 train_reader_cost = 0.0 batch_sum = 0 - batch_start = time.time() # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: @@ -275,6 +274,7 @@ def train(config, best_model_dict[main_indicator], global_step) global_step += 1 + batch_start = time.time() if dist.get_rank() == 0: save_model( model, @@ -333,20 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc -def save_inference_mode(model, config, logger): - model.eval() - save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], - config['Architecture']['model_type']) - if config['Architecture']['model_type'] == 'rec': - input_shape = [None, 3, 32, None] - jit_model = paddle.jit.to_static( - model, input_spec=[paddle.static.InputSpec(input_shape)]) - paddle.jit.save(jit_model, save_path) - logger.info('inference model save to {}'.format(save_path)) - - model.train() - - def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) diff --git a/tools/train.py b/tools/train.py index 1cf644e6fd4b61d7925c6d9dda79855c7a72e886..6e44c5982ec5595c9202d83b14c058a7579c6a27 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) - program.save_inference_mode(model, config, logger) def test_reader(config, device, logger):