diff --git a/deploy/python/infer.py b/deploy/python/infer.py index 7664b4f36279569eed79a91378f557f96816b859..01f044e0124da9e4dce24f0df320551cc2c11e52 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -115,9 +115,10 @@ class ImageReader: # image processing thread worker def process_worker(self, imgs, idx, use_pr=False): image_path = imgs[idx] - im = cv2.imread(image_path, -1) - if len(im.shape) == 2: - im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) + cv2_imread_flag = cv2.IMREAD_COLOR + if self.config.channels == 4: + cv2_imread_flag = cv2.IMREAD_UNCHANGED + im = cv2.imread(image_path, cv2_imread_flag) channels = im.shape[2] if channels != 3 and channels != 4: print("Only support rgb(gray) or rgba image.") @@ -133,8 +134,10 @@ class ImageReader: # if use models with no pre-processing/post-processing op optimizations if not use_pr: - im_mean = np.array(self.config.mean).reshape((3, 1, 1)) - im_std = np.array(self.config.std).reshape((3, 1, 1)) + im_mean = np.array(self.config.mean).reshape((self.config.channels, + 1, 1)) + im_std = np.array(self.config.std).reshape((self.config.channels, 1, + 1)) # HWC -> CHW, don't use transpose((2, 0, 1)) im = im.swapaxes(1, 2) im = im.swapaxes(0, 1)