提交 d740acce 编写于 作者: W wjj19950828

fixed ToPILImage

上级 270c5dc0
...@@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F ...@@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F
class ToPILImage(BaseTransform): class ToPILImage(BaseTransform):
def __init__(self, mode=None, keys=None): def __init__(self, mode=None, keys=None):
super(ToTensor, self).__init__(keys) super(ToPILImage, self).__init__(keys)
self.mode = mode
def _apply_image(self, pic): def _apply_image(self, pic):
""" """
...@@ -53,7 +54,7 @@ class ToPILImage(BaseTransform): ...@@ -53,7 +54,7 @@ class ToPILImage(BaseTransform):
npimg = pic npimg = pic
if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy( if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy(
).dtype) and mode != 'F': ).dtype) and self.mode != 'F':
pic = pic.mul(255).byte() pic = pic.mul(255).byte()
if isinstance(pic, paddle.Tensor): if isinstance(pic, paddle.Tensor):
npimg = np.transpose(pic.numpy(), (1, 2, 0)) npimg = np.transpose(pic.numpy(), (1, 2, 0))
...@@ -74,40 +75,40 @@ class ToPILImage(BaseTransform): ...@@ -74,40 +75,40 @@ class ToPILImage(BaseTransform):
expected_mode = 'I' expected_mode = 'I'
elif npimg.dtype == np.float32: elif npimg.dtype == np.float32:
expected_mode = 'F' expected_mode = 'F'
if mode is not None and mode != expected_mode: if self.mode is not None and self.mode != expected_mode:
raise ValueError( raise ValueError(
"Incorrect mode ({}) supplied for input type {}. Should be {}" "Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode)) .format(self.mode, np.dtype, expected_mode))
mode = expected_mode self.mode = expected_mode
elif npimg.shape[2] == 2: elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA'] permitted_2_channel_modes = ['LA']
if mode is not None and mode not in permitted_2_channel_modes: if self.mode is not None and self.mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs". raise ValueError("Only modes {} are supported for 2D inputs".
format(permitted_2_channel_modes)) format(permitted_2_channel_modes))
if mode is None and npimg.dtype == np.uint8: if self.mode is None and npimg.dtype == np.uint8:
mode = 'LA' self.mode = 'LA'
elif npimg.shape[2] == 4: elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
if mode is not None and mode not in permitted_4_channel_modes: if self.mode is not None and self.mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs". raise ValueError("Only modes {} are supported for 4D inputs".
format(permitted_4_channel_modes)) format(permitted_4_channel_modes))
if mode is None and npimg.dtype == np.uint8: if self.mode is None and npimg.dtype == np.uint8:
mode = 'RGBA' self.mode = 'RGBA'
else: else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
if mode is not None and mode not in permitted_3_channel_modes: if self.mode is not None and self.mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs". raise ValueError("Only modes {} are supported for 3D inputs".
format(permitted_3_channel_modes)) format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8: if self.mode is None and npimg.dtype == np.uint8:
mode = 'RGB' self.mode = 'RGB'
if mode is None: if self.mode is None:
raise TypeError('Input type {} is not supported'.format( raise TypeError('Input type {} is not supported'.format(
npimg.dtype)) npimg.dtype))
return Image.fromarray(npimg, mode=mode) return Image.fromarray(npimg, mode=self.mode)
``` ```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册