提交 b38240b8 编写于 作者: F flytocc

add `pil` backend for DecodeImage

上级 a618534e
......@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
from functools import partial
import io
import six
import math
import random
......@@ -93,11 +94,22 @@ class OperatorParamError(ValueError):
class DecodeImage(object):
""" decode image """
def __init__(self, to_rgb=True, to_np=False, channel_first=False):
def __init__(self,
to_rgb=True,
to_np=False,
channel_first=False,
backend="cv2"):
self.to_rgb = to_rgb
self.to_np = to_np # to numpy
self.channel_first = channel_first # only enabled when to_np is True
if backend.lower() not in ["cv2", "pil"]:
logger.warning(
f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
)
backend = "cv2"
self.backend = backend.lower()
def __call__(self, img):
if six.PY2:
assert type(img) is str and len(
......@@ -105,8 +117,15 @@ class DecodeImage(object):
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
if self.backend == "pil":
data = io.BytesIO(img)
img = Image.open(data).convert("RGB")
img = np.asarray(img)[:, :, ::-1] # to bgr
else:
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册