提交 9dd50e6a 编写于 作者: L LDOUBLEV

add ips

上级 3312d624
......@@ -33,8 +33,9 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.utils.save_load import init_model
from ppocr.utils.save_load import load_model
import tools.program as program
import time
def read_class_list(filepath):
......@@ -80,7 +81,8 @@ def draw_kie_result(batch, node, idx_to_cls, count):
vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
vis_img[:, :w] = img
vis_img[:, w:] = pred_img
save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
save_kie_path = os.path.dirname(config['Global'][
'save_res_path']) + "/kie_results/"
if not os.path.exists(save_kie_path):
os.makedirs(save_kie_path)
save_path = os.path.join(save_kie_path, str(count) + ".png")
......@@ -93,7 +95,7 @@ def main():
# build model
model = build_model(config['Architecture'])
init_model(config, model, logger)
load_model(config, model)
# create data ops
transforms = []
......@@ -111,10 +113,15 @@ def main():
os.makedirs(os.path.dirname(save_res_path))
model.eval()
warmup_times = 0
count_t = []
with open(save_res_path, "wb") as fout:
with open(config['Global']['infer_img'], "rb") as f:
lines = f.readlines()
for index, data_line in enumerate(lines):
if index == 10:
warmup_t = time.time()
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path, label = data_dir + "/" + substr[0], substr[1]
......@@ -122,16 +129,23 @@ def main():
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
st = time.time()
batch = transform(data, ops)
batch_pred = [0] * len(batch)
for i in range(len(batch)):
batch_pred[i] = paddle.to_tensor(
np.expand_dims(
batch[i], axis=0))
st = time.time()
node, edge = model(batch_pred)
node = F.softmax(node, -1)
count_t.append(time.time() - st)
draw_kie_result(batch, node, idx_to_cls, index)
logger.info("success!")
logger.info("It took {} s for predict {} images.".format(
np.sum(count_t), len(count_t)))
ips = np.sum(count_t[warmup_times:]) / len(count_t[warmup_times:])
logger.info("The ips is {} images/s".format(ips))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册