提交 56e7a7de 编写于 作者: A Amir Lashkari

Added UniformAugment + Python Augmentation Ops

上级 862d23fe
......@@ -1312,3 +1312,177 @@ class HsvToRgb:
rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs.
"""
return util.hsv_to_rgbs(hsv_imgs, self.is_hwc)
class RandomColor:
"""
Adjust the color of the input PIL image by a random degree.
Args:
degrees (sequence): Range of random color adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)).
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.RandomColor(0.5,1.5),
>>> py_transforms.ToTensor()])
"""
def __init__(self, degrees=(0.1, 1.9)):
self.degrees = degrees
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be color adjusted.
Returns:
img (PIL Image), Color adjusted image.
"""
return util.random_color(img, self.degrees)
class RandomSharpness:
"""
Adjust the sharpness of the input PIL image by a random degree.
Args:
degrees (sequence): Range of random sharpness adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)).
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.RandomColor(0.5,1.5),
>>> py_transforms.ToTensor()])
"""
def __init__(self, degrees=(0.1, 1.9)):
self.degrees = degrees
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be sharpness adjusted.
Returns:
img (PIL Image), Color adjusted image.
"""
return util.random_sharpness(img, self.degrees)
class AutoContrast:
"""
Automatically maximize the contrast of the input PIL image.
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.AutoContrast(),
>>> py_transforms.ToTensor()])
"""
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be augmented with AutoContrast.
Returns:
img (PIL Image), Augmented image.
"""
return util.auto_contrast(img)
class Invert:
"""
Invert colors of input PIL image.
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.Invert(),
>>> py_transforms.ToTensor()])
"""
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be color Inverted.
Returns:
img (PIL Image), Color inverted image.
"""
return util.invert_color(img)
class Equalize:
"""
Equalize the histogram of input PIL image.
Examples:
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.Equalize(),
>>> py_transforms.ToTensor()])
"""
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be equalized.
Returns:
img (PIL Image), Equalized image.
"""
return util.equalize(img)
class UniformAugment:
"""
Uniformly select and apply a number of transforms sequentially from
a list of transforms. Randomly assigns a probability to each transform for
each image to decide whether apply it or not.
Args:
transforms (list): List of transformations to be chosen from to apply.
num_ops (int, optional): number of transforms to sequentially apply (default=2).
Examples:
>>> transforms_list = [py_transforms.CenterCrop(64),
>>> py_transforms.RandomColor(),
>>> py_transforms.RandomSharpness(),
>>> py_transforms.RandomRotation(30)]
>>> py_transforms.ComposeOp([py_transforms.Decode(),
>>> py_transforms.UniformAugment(transforms_list),
>>> py_transforms.ToTensor()])
"""
def __init__(self, transforms, num_ops=2):
self.transforms = transforms
self.num_ops = num_ops
def __call__(self, img):
"""
Call method.
Args:
img (PIL Image): Image to be applied transformation.
Returns:
img (PIL Image), Transformed image.
"""
return util.uniform_augment(img, self.transforms, self.num_ops)
......@@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc):
if batch_size == 0:
return hsv_to_rgb(np_hsv_imgs, is_hwc)
return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs])
def random_color(img, degrees):
"""
Adjust the color of the input PIL image by a random degree.
Args:
img (PIL Image): Image to be color adjusted.
degrees (sequence): Range of random color adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)).
Returns:
img (PIL Image), Color adjusted image.
"""
if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if isinstance(degrees, (list, tuple)):
if len(degrees) != 2:
raise ValueError("Degrees must be a sequence length 2.")
if degrees[0] < 0:
raise ValueError("Degree value must be non-negative.")
if degrees[0] > degrees[1]:
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
else:
raise TypeError("Degrees must be a sequence in (min,max) format.")
v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
return ImageEnhance.Color(img).enhance(v)
def random_sharpness(img, degrees):
"""
Adjust the sharpness of the input PIL image by a random degree.
Args:
img (PIL Image): Image to be sharpness adjusted.
degrees (sequence): Range of random sharpness adjustment degrees.
It should be in (min, max) format (default=(0.1,1.9)).
Returns:
img (PIL Image), Sharpness adjusted image.
"""
if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if isinstance(degrees, (list, tuple)):
if len(degrees) != 2:
raise ValueError("Degrees must be a sequence length 2.")
if degrees[0] < 0:
raise ValueError("Degree value must be non-negative.")
if degrees[0] > degrees[1]:
raise ValueError("Degrees should be in (min,max) format. Got (max,min).")
else:
raise TypeError("Degrees must be a sequence in (min,max) format.")
v = (degrees[1] - degrees[0]) * random.random() + degrees[0]
return ImageEnhance.Sharpness(img).enhance(v)
def auto_contrast(img):
"""
Automatically maximize the contrast of the input PIL image.
Args:
img (PIL Image): Image to be augmented with AutoContrast.
Returns:
img (PIL Image), Augmented image.
"""
if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)
def invert_color(img):
"""
Invert colors of input PIL image.
Args:
img (PIL Image): Image to be color inverted.
Returns:
img (PIL Image), Color inverted image.
"""
if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.invert(img)
def equalize(img):
"""
Equalize the histogram of input PIL image.
Args:
img (PIL Image): Image to be equalized
Returns:
img (PIL Image), Equalized image.
"""
if not is_pil(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.equalize(img)
def uniform_augment(img, transforms, num_ops):
"""
Uniformly select and apply a number of transforms sequentially from
a list of transforms. Randomly assigns a probability to each transform for
each image to decide whether apply it or not.
Args:
img: Image to be applied transformation.
transforms (list): List of transformations to be chosen from to apply.
num_ops (int): number of transforms to sequentially aaply.
Returns:
img, Transformed image.
"""
if transforms is None:
raise ValueError("transforms is not provided.")
if not isinstance(transforms, list):
raise ValueError("The transforms needs to be a list.")
if not isinstance(num_ops, int):
raise ValueError("Number of operations should be a positive integer.")
if num_ops < 1:
raise ValueError("Number of operators should equal or greater than one.")
for _ in range(num_ops):
AugmentOp = random.choice(transforms)
pr = random.random()
if random.random() < pr:
img = AugmentOp(img.copy())
transforms.remove(AugmentOp)
return img
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_auto_contrast):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_auto_contrast)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_auto_contrast[i])
plt.title("DE AutoContrast image")
plt.show()
def test_auto_contrast(plot=False):
"""
Test AutoContrast
"""
logger.info("Test AutoContrast")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# AutoContrast Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_auto_contrast = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.AutoContrast(),
F.ToTensor()])
ds_auto_contrast = ds.map(input_columns="image",
operations=transforms_auto_contrast())
ds_auto_contrast = ds_auto_contrast.batch(512)
for idx, (image,label) in enumerate(ds_auto_contrast):
if idx == 0:
images_auto_contrast = np.transpose(image, (0, 2,3,1))
else:
images_auto_contrast = np.append(images_auto_contrast,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_auto_contrast)
if __name__ == "__main__":
test_auto_contrast(plot=True)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_equalize):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_equalize)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_equalize[i])
plt.title("DE Color Equalized image")
plt.show()
def test_equalize(plot=False):
"""
Test Equalize
"""
logger.info("Test Equalize")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# Color Equalized Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_equalize = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.Equalize(),
F.ToTensor()])
ds_equalize = ds.map(input_columns="image",
operations=transforms_equalize())
ds_equalize = ds_equalize.batch(512)
for idx, (image,label) in enumerate(ds_equalize):
if idx == 0:
images_equalize = np.transpose(image, (0, 2,3,1))
else:
images_equalize = np.append(images_equalize,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_equalize[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_equalize)
if __name__ == "__main__":
test_equalize(plot=True)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_invert):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_invert)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_invert[i])
plt.title("DE Color Inverted image")
plt.show()
def test_invert(plot=False):
"""
Test Invert
"""
logger.info("Test Invert")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# Color Inverted Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_invert = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.Invert(),
F.ToTensor()])
ds_invert = ds.map(input_columns="image",
operations=transforms_invert())
ds_invert = ds_invert.batch(512)
for idx, (image,label) in enumerate(ds_invert):
if idx == 0:
images_invert = np.transpose(image, (0, 2,3,1))
else:
images_invert = np.append(images_invert,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_invert[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_invert)
if __name__ == "__main__":
test_invert(plot=True)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_random_color):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_random_color)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_random_color[i])
plt.title("DE Random Color image")
plt.show()
def test_random_color(degrees=(0.1,1.9), plot=False):
"""
Test RandomColor
"""
logger.info("Test RandomColor")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# Random Color Adjusted Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_random_color = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.RandomColor(degrees=degrees),
F.ToTensor()])
ds_random_color = ds.map(input_columns="image",
operations=transforms_random_color())
ds_random_color = ds_random_color.batch(512)
for idx, (image,label) in enumerate(ds_random_color):
if idx == 0:
images_random_color = np.transpose(image, (0, 2,3,1))
else:
images_random_color = np.append(images_random_color,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_random_color[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_random_color)
if __name__ == "__main__":
test_random_color()
test_random_color(plot=True)
test_random_color(degrees=(0.5,1.5), plot=True)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_random_sharpness):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_random_sharpness)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_random_sharpness[i])
plt.title("DE Random Sharpness image")
plt.show()
def test_random_sharpness(degrees=(0.1,1.9), plot=False):
"""
Test RandomSharpness
"""
logger.info("Test RandomSharpness")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# Random Sharpness Adjusted Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_random_sharpness = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.RandomSharpness(degrees=degrees),
F.ToTensor()])
ds_random_sharpness = ds.map(input_columns="image",
operations=transforms_random_sharpness())
ds_random_sharpness = ds_random_sharpness.batch(512)
for idx, (image,label) in enumerate(ds_random_sharpness):
if idx == 0:
images_random_sharpness = np.transpose(image, (0, 2,3,1))
else:
images_random_sharpness = np.append(images_random_sharpness,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_random_sharpness)
if __name__ == "__main__":
test_random_sharpness()
test_random_sharpness(plot=True)
test_random_sharpness(degrees=(0.5,1.5), plot=True)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib.pyplot as plt
from mindspore import log as logger
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.py_transforms as F
DATA_DIR = "../data/dataset/testImageNetData/train/"
def visualize(image_original, image_ua):
"""
visualizes the image using DE op and Numpy op
"""
num = len(image_ua)
for i in range(num):
plt.subplot(2, num, i + 1)
plt.imshow(image_original[i])
plt.title("Original image")
plt.subplot(2, num, i + num + 1)
plt.imshow(image_ua[i])
plt.title("DE UniformAugment image")
plt.show()
def test_uniform_augment(plot=False, num_ops=2):
"""
Test UniformAugment
"""
logger.info("Test UniformAugment")
# Original Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transforms_original = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.ToTensor()])
ds_original = ds.map(input_columns="image",
operations=transforms_original())
ds_original = ds_original.batch(512)
for idx, (image,label) in enumerate(ds_original):
if idx == 0:
images_original = np.transpose(image, (0, 2,3,1))
else:
images_original = np.append(images_original,
np.transpose(image, (0, 2,3,1)),
axis=0)
# UniformAugment Images
ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False)
transform_list = [F.RandomRotation(45),
F.RandomColor(),
F.RandomSharpness(),
F.Invert(),
F.AutoContrast(),
F.Equalize()]
transforms_ua = F.ComposeOp([F.Decode(),
F.Resize((224,224)),
F.UniformAugment(transforms=transform_list, num_ops=num_ops),
F.ToTensor()])
ds_ua = ds.map(input_columns="image",
operations=transforms_ua())
ds_ua = ds_ua.batch(512)
for idx, (image,label) in enumerate(ds_ua):
if idx == 0:
images_ua = np.transpose(image, (0, 2,3,1))
else:
images_ua = np.append(images_ua,
np.transpose(image, (0, 2,3,1)),
axis=0)
num_samples = images_original.shape[0]
mse = np.zeros(num_samples)
for i in range(num_samples):
mse[i] = np.mean((images_ua[i]-images_original[i])**2)
logger.info("MSE= {}".format(str(np.mean(mse))))
if plot:
visualize(images_original, images_ua)
if __name__ == "__main__":
test_uniform_augment(num_ops=1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册