diff --git a/deploy/python/predict_cls.py b/deploy/python/predict_cls.py index a9165a92efa62a9252834a988050eebfa8d89f69..b161512f7f34437e1e5070b0251925129a17820c 100644 --- a/deploy/python/predict_cls.py +++ b/deploy/python/predict_cls.py @@ -27,7 +27,6 @@ from utils.get_image_list import get_image_list from python.preprocess import create_operators from python.postprocess import build_postprocess - class ClsPredictor(Predictor): def __init__(self, config): super().__init__(config["Global"]) @@ -59,6 +58,8 @@ class ClsPredictor(Predictor): input_tensor.copy_from_cpu(image) self.paddle_predictor.run() batch_output = output_tensor.copy_to_cpu() + if self.postprocess is not None: + batch_output = self.postprocess(batch_output) return batch_output @@ -66,14 +67,38 @@ def main(config): cls_predictor = ClsPredictor(config) image_list = get_image_list(config["Global"]["infer_imgs"]) - assert config["Global"]["batch_size"] == 1 - for idx, image_file in enumerate(image_list): - img = cv2.imread(image_file)[:, :, ::-1] - output = cls_predictor.predict(img) - output = cls_predictor.postprocess(output, [image_file]) - print(output) - return + batch_imgs = [] + batch_names = [] + cnt = 0 + for idx, img_path in enumerate(image_list): + img = cv2.imread(img_path) + if img is None: + logger.warning( + "Image file failed to read and has been skipped. The path: {}". + format(img_path)) + else: + img = img[:, :, ::-1] + batch_imgs.append(img) + img_name = os.path.basename(img_path) + batch_names.append(img_name) + cnt += 1 + if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list): + if len(batch_imgs) == 0: + continue + + batch_results = cls_predictor.predict(batch_imgs) + for number, result_dict in enumerate(batch_results): + filename = batch_names[number] + clas_ids = result_dict["class_ids"] + scores_str = "[{}]".format(", ".join("{:.2f}".format( + r) for r in result_dict["scores"])) + label_names = result_dict["label_names"] + print("{}:\tclass id(s): {}, score(s): {}, label_name(s): {}". + format(filename, clas_ids, scores_str, label_names)) + batch_imgs = [] + batch_names = [] + return if __name__ == "__main__": args = config.parse_args() diff --git a/deploy/python/predict_rec.py b/deploy/python/predict_rec.py index de293bf0097f9ea48e1ecd296ec2695e15c54ba4..d41c513f89fd83972e86bc5941a8dba1fd488856 100644 --- a/deploy/python/predict_rec.py +++ b/deploy/python/predict_rec.py @@ -54,12 +54,14 @@ class RecPredictor(Predictor): input_tensor.copy_from_cpu(image) self.paddle_predictor.run() batch_output = output_tensor.copy_to_cpu() - + if feature_normalize: feas_norm = np.sqrt( np.sum(np.square(batch_output), axis=1, keepdims=True)) batch_output = np.divide(batch_output, feas_norm) - + + if self.postprocess is not None: + batch_output = self.postprocess(batch_output) return batch_output @@ -67,14 +69,33 @@ def main(config): rec_predictor = RecPredictor(config) image_list = get_image_list(config["Global"]["infer_imgs"]) - assert config["Global"]["batch_size"] == 1 - for idx, image_file in enumerate(image_list): - batch_input = [] - img = cv2.imread(image_file)[:, :, ::-1] - output = rec_predictor.predict(img) - if rec_predictor.postprocess is not None: - output = rec_predictor.postprocess(output) - print(output) + batch_imgs = [] + batch_names = [] + cnt = 0 + for idx, img_path in enumerate(image_list): + img = cv2.imread(img_path) + if img is None: + logger.warning( + "Image file failed to read and has been skipped. The path: {}". + format(img_path)) + else: + img = img[:, :, ::-1] + batch_imgs.append(img) + img_name = os.path.basename(img_path) + batch_names.append(img_name) + cnt += 1 + + if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list): + if len(batch_imgs) == 0: + continue + + batch_results = rec_predictor.predict(batch_imgs) + for number, result_dict in enumerate(batch_results): + filename = batch_names[number] + print("{}:\t{}".format(filename, result_dict)) + batch_imgs = [] + batch_names = [] + return