提交 3248cc58 编写于 作者: Y Yang Nie

update `DecodeImage`

上级 4266b879
......@@ -140,13 +140,12 @@ class DecodeImage(object):
""" decode image """
def __init__(self,
to_np=True,
to_rgb=True,
to_np=False,
channel_first=False,
backend="cv2",
return_numpy=True):
self.to_rgb = to_rgb
backend="cv2"):
self.to_np = to_np # to numpy
self.to_rgb = to_rgb # only enabled when to_np is True
self.channel_first = channel_first # only enabled when to_np is True
if backend.lower() not in ["cv2", "pil"]:
......@@ -156,38 +155,33 @@ class DecodeImage(object):
backend = "cv2"
self.backend = backend.lower()
if not return_numpy:
assert to_rgb, f"\"to_rgb\" must be True while \"return_numpy\" is False."
assert not channel_first, f"\"channel_first\" must be False while \"return_numpy\" is False."
self.return_numpy = return_numpy
if not to_np:
logger.warning(
f"\"to_rgb\" and \"channel_first\" are only enabled when to_np is True. \"to_np\" is now {to_np}."
)
def __call__(self, img):
if isinstance(img, Image.Image):
if self.return_numpy:
img = np.asarray(img)[:, :, ::-1] # to bgr
assert self.backend == "pil", "invalid input 'img' in DecodeImage"
elif isinstance(img, np.ndarray):
assert self.return_numpy, "invalid input 'img' in DecodeImage"
else:
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
assert self.backend == "cv2", "invalid input 'img' in DecodeImage"
elif isinstance(img, bytes):
if self.backend == "pil":
data = io.BytesIO(img)
img = Image.open(data).convert("RGB")
if self.return_numpy:
img = np.asarray(img)[:, :, ::-1] # to bgr
img = Image.open(data)
else:
data = np.frombuffer(img, dtype='uint8')
data = np.frombuffer(img, dtype="uint8")
img = cv2.imdecode(data, 1)
else:
raise ValueError("invalid input 'img' in DecodeImage")
if self.to_np:
if self.backend == "pil":
assert img.mode == "RGB", f"invalid shape of image[{img.shape}]"
img = np.asarray(img)[:, :, ::-1] # BRG
if self.return_numpy:
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]"
img = img[:, :, ::-1]
if self.channel_first:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册