From 139ce8f02e6ff4619311c021055e1d0b094f30ed Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Mon, 17 Aug 2020 19:54:13 +0800 Subject: [PATCH] update infer.py --- deploy/python/infer.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 7664b4f3..01f044e0 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -115,9 +115,10 @@ class ImageReader: # image processing thread worker def process_worker(self, imgs, idx, use_pr=False): image_path = imgs[idx] - im = cv2.imread(image_path, -1) - if len(im.shape) == 2: - im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) + cv2_imread_flag = cv2.IMREAD_COLOR + if self.config.channels == 4: + cv2_imread_flag = cv2.IMREAD_UNCHANGED + im = cv2.imread(image_path, cv2_imread_flag) channels = im.shape[2] if channels != 3 and channels != 4: print("Only support rgb(gray) or rgba image.") @@ -133,8 +134,10 @@ class ImageReader: # if use models with no pre-processing/post-processing op optimizations if not use_pr: - im_mean = np.array(self.config.mean).reshape((3, 1, 1)) - im_std = np.array(self.config.std).reshape((3, 1, 1)) + im_mean = np.array(self.config.mean).reshape((self.config.channels, + 1, 1)) + im_std = np.array(self.config.std).reshape((self.config.channels, 1, + 1)) # HWC -> CHW, don't use transpose((2, 0, 1)) im = im.swapaxes(1, 2) im = im.swapaxes(0, 1) -- GitLab