diff --git a/tools/infer_kie.py b/tools/infer_kie.py index 62ef697240ffe89fcb858c5308bd010105dde2ab..41e0856d467582b3560ecc8ae9f20d36aeda891b 100755 --- a/tools/infer_kie.py +++ b/tools/infer_kie.py @@ -33,8 +33,9 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import load_model import tools.program as program +import time def read_class_list(filepath): @@ -80,7 +81,8 @@ def draw_kie_result(batch, node, idx_to_cls, count): vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 vis_img[:, :w] = img vis_img[:, w:] = pred_img - save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" + save_kie_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/kie_results/" if not os.path.exists(save_kie_path): os.makedirs(save_kie_path) save_path = os.path.join(save_kie_path, str(count) + ".png") @@ -93,7 +95,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + load_model(config, model) # create data ops transforms = [] @@ -111,10 +113,15 @@ def main(): os.makedirs(os.path.dirname(save_res_path)) model.eval() + + warmup_times = 0 + count_t = [] with open(save_res_path, "wb") as fout: with open(config['Global']['infer_img'], "rb") as f: lines = f.readlines() for index, data_line in enumerate(lines): + if index == 10: + warmup_t = time.time() data_line = data_line.decode('utf-8') substr = data_line.strip("\n").split("\t") img_path, label = data_dir + "/" + substr[0], substr[1] @@ -122,16 +129,23 @@ def main(): with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img + st = time.time() batch = transform(data, ops) batch_pred = [0] * len(batch) for i in range(len(batch)): batch_pred[i] = paddle.to_tensor( np.expand_dims( batch[i], axis=0)) + st = time.time() node, edge = model(batch_pred) node = F.softmax(node, -1) + count_t.append(time.time() - st) draw_kie_result(batch, node, idx_to_cls, index) logger.info("success!") + logger.info("It took {} s for predict {} images.".format( + np.sum(count_t), len(count_t))) + ips = np.sum(count_t[warmup_times:]) / len(count_t[warmup_times:]) + logger.info("The ips is {} images/s".format(ips)) if __name__ == '__main__':