From 7521c0592ec14976ea44c0a432399bc2aca89b5f Mon Sep 17 00:00:00 2001 From: LutaoChu <30695251+LutaoChu@users.noreply.github.com> Date: Tue, 2 Jun 2020 00:33:14 +0800 Subject: [PATCH] fix infer.py bug (#280) --- deploy/python/infer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 05e84eb1..7664b4f3 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -107,7 +107,6 @@ class DeployConfig: self.use_pr = deploy_conf["USE_PR"] - class ImageReader: def __init__(self, configs): self.config = configs @@ -117,19 +116,18 @@ class ImageReader: def process_worker(self, imgs, idx, use_pr=False): image_path = imgs[idx] im = cv2.imread(image_path, -1) - channels = im.shape[2] - ori_h = im.shape[0] - ori_w = im.shape[1] - if channels == 1: + if len(im.shape) == 2: im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) - channels = im.shape[2] + channels = im.shape[2] if channels != 3 and channels != 4: print("Only support rgb(gray) or rgba image.") return -1 + ori_h = im.shape[0] + ori_w = im.shape[1] # resize to eval_crop_size eval_crop_size = self.config.eval_crop_size - if (ori_h != eval_crop_size[0] or ori_w != eval_crop_size[1]): + if (ori_h != eval_crop_size[1] or ori_w != eval_crop_size[0]): im = cv2.resize( im, eval_crop_size, fx=0, fy=0, interpolation=cv2.INTER_LINEAR) -- GitLab