提交 3248cc58 编写于 作者: Y Yang Nie

update `DecodeImage`

上级 4266b879
...@@ -140,13 +140,12 @@ class DecodeImage(object): ...@@ -140,13 +140,12 @@ class DecodeImage(object):
""" decode image """ """ decode image """
def __init__(self, def __init__(self,
to_np=True,
to_rgb=True, to_rgb=True,
to_np=False,
channel_first=False, channel_first=False,
backend="cv2", backend="cv2"):
return_numpy=True):
self.to_rgb = to_rgb
self.to_np = to_np # to numpy self.to_np = to_np # to numpy
self.to_rgb = to_rgb # only enabled when to_np is True
self.channel_first = channel_first # only enabled when to_np is True self.channel_first = channel_first # only enabled when to_np is True
if backend.lower() not in ["cv2", "pil"]: if backend.lower() not in ["cv2", "pil"]:
...@@ -156,38 +155,33 @@ class DecodeImage(object): ...@@ -156,38 +155,33 @@ class DecodeImage(object):
backend = "cv2" backend = "cv2"
self.backend = backend.lower() self.backend = backend.lower()
if not return_numpy: if not to_np:
assert to_rgb, f"\"to_rgb\" must be True while \"return_numpy\" is False." logger.warning(
assert not channel_first, f"\"channel_first\" must be False while \"return_numpy\" is False." f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
self.return_numpy = return_numpy )
def __call__(self, img): def __call__(self, img):
if isinstance(img, Image.Image): if isinstance(img, Image.Image):
if self.return_numpy: assert self.backend == "pil", "invalid input 'img' in DecodeImage"
img = np.asarray(img)[:, :, ::-1] # to bgr
elif isinstance(img, np.ndarray): elif isinstance(img, np.ndarray):
assert self.return_numpy, "invalid input 'img' in DecodeImage" assert self.backend == "cv2", "invalid input 'img' in DecodeImage"
else: elif isinstance(img, bytes):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
if self.backend == "pil": if self.backend == "pil":
data = io.BytesIO(img) data = io.BytesIO(img)
img = Image.open(data).convert("RGB") img = Image.open(data)
if self.return_numpy:
img = np.asarray(img)[:, :, ::-1] # to bgr
else: else:
data = np.frombuffer(img, dtype='uint8') data = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(data, 1) img = cv2.imdecode(data, 1)
else:
raise ValueError("invalid input 'img' in DecodeImage")
if self.to_np:
if self.backend == "pil":
assert img.mode == "RGB", f"invalid shape of image[{img.shape}]"
img = np.asarray(img)[:, :, ::-1] # BRG
if self.return_numpy:
if self.to_rgb: if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img.shape)
img = img[:, :, ::-1] img = img[:, :, ::-1]
if self.channel_first: if self.channel_first:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册