提交 f253c47b 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!73 reconstruct image transform

Merge pull request !73 from ZhidanLiu/master
"""
This module includes various metrics to fuzzing the test of DNN.
"""
from .fuzzing import Fuzzing
from .fuzzing import Fuzzer
from .model_coverage_metrics import ModelCoverageMetrics
__all__ = ['Fuzzing',
__all__ = ['Fuzzer',
'ModelCoverageMetrics']
......@@ -23,11 +23,11 @@ from mindspore import Tensor
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_int_positive
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \
Translate, Scale, Shear, Rotate
class Fuzzing:
class Fuzzer:
"""
Fuzzing test framework for deep neural networks.
......@@ -84,7 +84,7 @@ class Fuzzing:
[])
transform = strages[trans_strage](
self._image_value_expand(seed), self.mode)
transform.random_param()
transform.set_params(auto_param=True)
mutate_test = transform.transform()
mutate_test = np.expand_dims(
self._image_value_compress(mutate_test), 0)
......@@ -138,7 +138,7 @@ class Fuzzing:
result = result.asnumpy()
for index in range(len(mutate_tests)):
mutate = np.expand_dims(mutate_tests[index], 0)
self.coverage_metrics.test_adequacy_coverage_calculate(
self.coverage_metrics.model_coverage_test(
mutate.astype(np.float32), batch_size=1)
if coverage_metric == "KMNC":
coverages.append(self.coverage_metrics.get_kmnc())
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform
"""
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
from mindspore.dataset.transforms.vision.py_transforms_util import is_numpy, \
to_pil, hwc_to_chw
from mindarmour.utils._check_param import check_param_multi_types
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'ModelCoverageMetrics'
def chw_to_hwc(img):
"""
Transpose the input image; shape (C, H, W) to shape (H, W, C).
Args:
img (numpy.ndarray): Image to be converted.
Returns:
img (numpy.ndarray), Converted image.
"""
if is_numpy(img):
return img.transpose(1, 2, 0).copy()
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
def is_hwc(img):
"""
Check if the input image is shape (H, W, C).
Args:
img (numpy.ndarray): Image to be checked.
Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
def is_chw(img):
"""
Check if the input image is shape (H, W, C).
Args:
img (numpy.ndarray): Image to be checked.
Returns:
Bool, True if input is shape (H, W, C).
"""
if is_numpy(img):
img_shape = np.shape(img)
if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
def is_rgb(img):
"""
Check if the input image is RGB.
Args:
img (numpy.ndarray): Image to be checked.
Returns:
Bool, True if input is RGB.
"""
if is_numpy(img):
if len(np.shape(img)) == 3:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
def is_normalized(img):
"""
Check if the input image is normalized between 0 to 1.
Args:
img (numpy.ndarray): Image to be checked.
Returns:
Bool, True if input is normalized between 0 to 1.
"""
if is_numpy(img):
minimal = np.min(img)
maximun = np.max(img)
if minimal >= 0 and maximun <= 1:
return True
return False
raise TypeError('img should be Numpy array. Got {}'.format(type(img)))
class ImageTransform:
"""
The abstract base class for all image transform classes.
"""
def __init__(self):
pass
def _check(self, image):
""" Check image format. If input image is RGB and its shape
is (C, H, W), it will be transposed to (H, W, C). If the value
of the image is not normalized , it will be normalized between 0 to 1."""
rgb = is_rgb(image)
chw = False
normalized = is_normalized(image)
if rgb:
chw = is_chw(image)
if chw:
image = chw_to_hwc(image)
else:
image = image
else:
image = image
if normalized:
image = np.uint8(image*255)
return rgb, chw, normalized, image
def _original_format(self, image, chw, normalized):
""" Return transformed image with original format. """
if not is_numpy(image):
image = np.array(image)
if chw:
image = hwc_to_chw(image)
if normalized:
image = image / 255
return image
def transform(self, image):
pass
class Contrast(ImageTransform):
"""
Contrast of an image.
Args:
factor ([float, int]): Control the contrast of an image. If 1.0 gives the
original image. If 0 gives a gray image. Default: 1.
"""
def __init__(self, factor=1):
super(Contrast, self).__init__()
self.set_params(factor)
def set_params(self, factor=1, auto_param=False):
"""
Set contrast parameters.
Args:
factor ([float, int]): Control the contrast of an image. If 1.0 gives
the original image. If 0 gives a gray image. Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(-5, 5)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])
def transform(self, image):
"""
Transform the image.
Args:
image (numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Contrast(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Brightness(ImageTransform):
"""
Brightness of an image.
Args:
factor ([float, int]): Control the brightness of an image. If 1.0 gives
the original image. If 0 gives a black image. Default: 1.
"""
def __init__(self, factor=1):
super(Brightness, self).__init__()
self.set_params(factor)
def set_params(self, factor=1, auto_param=False):
"""
Set brightness parameters.
Args:
factor ([float, int]): Control the brightness of an image. If 1
gives the original image. If 0 gives a black image. Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(0, 5)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])
def transform(self, image):
"""
Transform the image.
Args:
image (numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
image = to_pil(image)
img_contrast = ImageEnhance.Brightness(image)
trans_image = img_contrast.enhance(self.factor)
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Blur(ImageTransform):
"""
Blurs the image using Gaussian blur filter.
Args:
radius([float, int]): Blur radius, 0 means no blur. Default: 0.
"""
def __init__(self, radius=0):
super(Blur, self).__init__()
self.set_params(radius)
def set_params(self, radius=0, auto_param=False):
"""
Set blur parameters.
Args:
radius ([float, int]): Blur radius, 0 means no blur. Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.radius = np.random.uniform(-1.5, 1.5)
else:
self.radius = check_param_multi_types('radius', radius, [int, float])
def transform(self, image):
"""
Transform the image.
Args:
image (numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
image = to_pil(image)
trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius))
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Noise(ImageTransform):
"""
Add noise of an image.
Args:
factor (float): 1 - factor is the ratio of pixels to add noise.
If 0 gives the original image. Default 0.
"""
def __init__(self, factor=0):
super(Noise, self).__init__()
self.set_params(factor)
def set_params(self, factor=0, auto_param=False):
"""
Set noise parameters.
Args:
factor ([float, int]): 1 - factor is the ratio of pixels to add noise.
If 0 gives the original image. Default 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor = np.random.uniform(0.7, 1)
else:
self.factor = check_param_multi_types('factor', factor, [int, float])
def transform(self, image):
"""
Transform the image.
Args:
image (numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
noise = np.random.uniform(low=-1, high=1, size=np.shape(image))
trans_image = np.copy(image)
trans_image[noise < -self.factor] = 0
trans_image[noise > self.factor] = 1
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Translate(ImageTransform):
"""
Translate an image.
Args:
x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0.
y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0.
"""
def __init__(self, x_bias=0, y_bias=0):
super(Translate, self).__init__()
self.set_params(x_bias, y_bias)
def set_params(self, x_bias=0, y_bias=0, auto_param=False):
"""
Set translate parameters.
Args:
x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0.
y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
self.auto_param = auto_param
if auto_param:
self.x_bias = np.random.uniform(-0.3, 0.3)
self.y_bias = np.random.uniform(-0.3, 0.3)
else:
self.x_bias = check_param_multi_types('x_bias', x_bias,
[int, float])
self.y_bias = check_param_multi_types('y_bias', y_bias,
[int, float])
def transform(self, image):
"""
Transform the image.
Args:
image(numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
img = to_pil(image)
if self.auto_param:
image_shape = np.shape(image)
self.x_bias = image_shape[0]*self.x_bias
self.y_bias = image_shape[1]*self.y_bias
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, self.x_bias, 0, 1, self.y_bias))
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Scale(ImageTransform):
"""
Scale an image in the middle.
Args:
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. Default: 1.
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. Default: 1.
"""
def __init__(self, factor_x=1, factor_y=1):
super(Scale, self).__init__()
self.set_params(factor_x, factor_y)
def set_params(self, factor_x=1, factor_y=1, auto_param=False):
"""
Set scale parameters.
Args:
factor_x ([float, int]): Rescale in X-direction, x=factor_x*x.
Default: 1.
factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y.
Default: 1.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.factor_x = np.random.uniform(0.7, 3)
self.factor_y = np.random.uniform(0.7, 3)
else:
self.factor_x = check_param_multi_types('factor_x', factor_x,
[int, float])
self.factor_y = check_param_multi_types('factor_y', factor_y,
[int, float])
def transform(self, image):
"""
Transform the image.
Args:
image(numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
rgb, chw, normalized, image = self._check(image)
if rgb:
h, w, _ = np.shape(image)
else:
h, w = np.shape(image)
move_x_centor = w / 2*(1 - self.factor_x)
move_y_centor = h / 2*(1 - self.factor_y)
img = to_pil(image)
trans_image = img.transform(img.size, Image.AFFINE,
(self.factor_x, 0, move_x_centor,
0, self.factor_y, move_y_centor))
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Shear(ImageTransform):
"""
Shear an image, for each pixel (x, y) in the sheared image, the new value is
taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then
the sheared image will be rescaled to fit original size.
Args:
factor_x ([float, int]): Shear factor of horizontal direction. Default: 0.
factor_y ([float, int]): Shear factor of vertical direction. Default: 0.
"""
def __init__(self, factor_x=0, factor_y=0):
super(Shear, self).__init__()
self.set_params(factor_x, factor_y)
def set_params(self, factor_x=0, factor_y=0, auto_param=False):
"""
Set shear parameters.
Args:
factor_x ([float, int]): Shear factor of horizontal direction.
Default: 0.
factor_y ([float, int]): Shear factor of vertical direction.
Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if factor_x != 0 and factor_y != 0:
msg = 'factor_x and factor_y can not be both more than 0.'
LOGGER.error(TAG, msg)
raise ValueError(msg)
if auto_param:
if np.random.uniform(-1, 1) > 0:
self.factor_x = np.random.uniform(-2, 2)
self.factor_y = 0
else:
self.factor_x = 0
self.factor_y = np.random.uniform(-2, 2)
else:
self.factor_x = check_param_multi_types('factor', factor_x,
[int, float])
self.factor_y = check_param_multi_types('factor', factor_y,
[int, float])
def transform(self, image):
"""
Transform the image.
Args:
image(numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
rgb, chw, normalized, image = self._check(image)
img = to_pil(image)
if rgb:
h, w, _ = np.shape(image)
else:
h, w = np.shape(image)
if self.factor_x != 0:
boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h]
min_x = min(boarder_x)
max_x = max(boarder_x)
scale = (max_x - min_x) / w
move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2
move_y_cen = h*(1 - scale) / 2
else:
boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w]
min_y = min(boarder_y)
max_y = max(boarder_y)
scale = (max_y - min_y) / h
move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2
move_x_cen = w*(1 - scale) / 2
trans_image = img.transform(img.size, Image.AFFINE,
(scale, scale*self.factor_x, move_x_cen,
scale*self.factor_y, scale, move_y_cen))
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
class Rotate(ImageTransform):
"""
Rotate an image of degrees counter clockwise around its center.
Args:
angle([float, int]): Degrees counter clockwise. Default: 0.
"""
def __init__(self, angle=0):
super(Rotate, self).__init__()
self.set_params(angle)
def set_params(self, angle=0, auto_param=False):
"""
Set rotate parameters.
Args:
angle([float, int]): Degrees counter clockwise. Default: 0.
auto_param (bool): True if auto generate parameters. Default: False.
"""
if auto_param:
self.angle = np.random.uniform(0, 360)
else:
self.angle = check_param_multi_types('angle', angle, [int, float])
def transform(self, image):
"""
Transform the image.
Args:
image(numpy.ndarray): Original image to be transformed.
Returns:
numpy.ndarray, transformed image.
"""
_, chw, normalized, image = self._check(image)
img = to_pil(image)
trans_image = img.rotate(self.angle, expand=True)
trans_image = self._original_format(trans_image, chw, normalized)
return trans_image
......@@ -133,8 +133,7 @@ class ModelCoverageMetrics:
else:
self._main_section_hits[i][int(section_indexes[i])] = 1
def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0,
batch_size=32):
def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32):
"""
Calculate the testing adequacy of the given dataset.
......@@ -147,7 +146,7 @@ class ModelCoverageMetrics:
Examples:
>>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
>>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
>>> model_fuzz_test.calculate_coverage(test_images)
"""
dataset = check_numpy_param('dataset', dataset)
batch_size = check_int_positive('batch_size', batch_size)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform
"""
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import random
from mindarmour.utils._check_param import check_numpy_param
class ImageTransform:
"""
The abstract base class for all image transform classes.
"""
def __init__(self):
pass
def random_param(self):
pass
def transform(self):
pass
class Contrast(ImageTransform):
"""
Contrast of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Contrast, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(-5, 5)
def transform(self):
img = Image.fromarray(np.uint8(self.image*255), self.mode)
img_contrast = ImageEnhance.Contrast(img)
trans_image = img_contrast.enhance(self.factor)
trans_image = np.array(trans_image)/255
return trans_image
class Brightness(ImageTransform):
"""
Brightness of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Brightness, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(0, 5)
def transform(self):
img = Image.fromarray(np.uint8(self.image*255), self.mode)
img_contrast = ImageEnhance.Brightness(img)
trans_image = img_contrast.enhance(self.factor)
trans_image = np.array(trans_image)/255
return trans_image
class Blur(ImageTransform):
"""
GaussianBlur of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Blur, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.radius = random.uniform(-1.5, 1.5)
def transform(self):
""" Transform the image. """
img = Image.fromarray(np.uint8(self.image*255), self.mode)
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius))
trans_image = np.array(trans_image)/255
return trans_image
class Noise(ImageTransform):
"""
Add noise of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Noise, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" random generate parameters """
self.factor = random.uniform(0.7, 1)
def transform(self):
""" Random generate parameters. """
noise = np.random.uniform(low=-1, high=1, size=self.image.shape)
trans_image = np.copy(self.image)
trans_image[noise < -self.factor] = 0
trans_image[noise > self.factor] = 1
trans_image = np.array(trans_image)
return trans_image
class Translate(ImageTransform):
"""
Translate an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Translate, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
image_shape = np.shape(self.image)
self.x_bias = random.uniform(-image_shape[0]/3, image_shape[0]/3)
self.y_bias = random.uniform(-image_shape[1]/3, image_shape[1]/3)
def transform(self):
""" Transform the image. """
img = Image.fromarray(np.uint8(self.image*255), self.mode)
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, self.x_bias, 0, 1, self.y_bias))
trans_image = np.array(trans_image)/255
return trans_image
class Scale(ImageTransform):
"""
Scale an image.
Args:
image(numpy.ndarray): Original image to be transformed.
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Scale, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor_x = random.uniform(0.7, 2)
self.factor_y = random.uniform(0.7, 2)
def transform(self):
""" Transform the image. """
img = Image.fromarray(np.uint8(self.image*255), self.mode)
trans_image = img.transform(img.size, Image.AFFINE,
(self.factor_x, 0, 0, 0, self.factor_y, 0))
trans_image = np.array(trans_image)/255
return trans_image
class Shear(ImageTransform):
"""
Shear an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Shear, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(0, 1)
def transform(self):
""" Transform the image. """
img = Image.fromarray(np.uint8(self.image*255), self.mode)
if np.random.random() > 0.5:
level = -self.factor
else:
level = self.factor
if np.random.random() > 0.5:
trans_image = img.transform(img.size, Image.AFFINE,
(1, level, 0, 0, 1, 0))
else:
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, 0, level, 1, 0))
trans_image = np.array(trans_image, dtype=np.float)/255
return trans_image
class Rotate(ImageTransform):
"""
Rotate an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Rotate, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.angle = random.uniform(0, 360)
def transform(self):
""" Transform the image. """
img = Image.fromarray(np.uint8(self.image*255), self.mode)
trans_image = img.rotate(self.angle)
trans_image = np.array(trans_image)/255
return trans_image
......@@ -77,7 +77,7 @@ def test_lenet_mnist_coverage_cpu():
# get test data
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_labels = np.random.randint(0, 10, 2000).astype(np.int32)
model_fuzz_test.test_adequacy_coverage_calculate(test_data)
model_fuzz_test.calculate_coverage(test_data)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
......@@ -86,8 +86,7 @@ def test_lenet_mnist_coverage_cpu():
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss)
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
bias_coefficient=0.5)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
......@@ -113,7 +112,7 @@ def test_lenet_mnist_coverage_ascend():
test_data = (np.random.random((2000, 10))*20).astype(np.float32)
test_labels = np.random.randint(0, 10, 2000)
test_labels = (np.eye(10)[test_labels]).astype(np.float32)
model_fuzz_test.test_adequacy_coverage_calculate(test_data)
model_fuzz_test.calculate_coverage(test_data)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
......@@ -121,8 +120,7 @@ def test_lenet_mnist_coverage_ascend():
# generate adv_data
attack = FastGradientSignMethod(net, eps=0.3)
adv_data = attack.batch_generate(test_data, test_labels, batch_size=32)
model_fuzz_test.test_adequacy_coverage_calculate(adv_data,
bias_coefficient=0.5)
model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5)
LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc())
LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc())
LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac())
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
"""
import numpy as np
import pytest
from mindspore import context
from mindspore import nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
from mindspore.train import Model
from mindarmour.fuzzing.fuzzing import Fuzzing
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils.logger import LogUtil
LOGGER = LogUtil.get_instance()
TAG = 'Fuzzing test'
LOGGER.set_level('INFO')
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
return TruncatedNormal(0.02)
class Net(nn.Cell):
"""
Lenet network
"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
# initialize fuzz test with training dataset
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)
# fuzz test with original test data
# get test data
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
for img, label in zip(test_data, test_labels):
initial_seeds.append([img, label, 0])
model_coverage_test.test_adequacy_coverage_calculate(
np.array(test_data).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
max_seed_num=10)
failed_tests = model_fuzz_test.fuzzing()
if failed_tests:
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
else:
LOGGER.info(TAG, 'Fuzzing test identifies none failed test')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_CPU():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
# initialize fuzz test with training dataset
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)
# fuzz test with original test data
# get test data
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
for img, label in zip(test_data, test_labels):
initial_seeds.append([img, label, 0])
model_coverage_test.test_adequacy_coverage_calculate(
np.array(test_data).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
max_seed_num=10)
failed_tests = model_fuzz_test.fuzzing()
if failed_tests:
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
else:
LOGGER.info(TAG, 'Fuzzing test identifies none failed test')
......@@ -18,7 +18,7 @@ import numpy as np
import pytest
from mindarmour.utils.logger import LogUtil
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \
Translate, Scale, Shear, Rotate
LOGGER = LogUtil.get_instance()
......@@ -31,11 +31,10 @@ LOGGER.set_level('INFO')
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_contrast():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Contrast(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Contrast()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -43,11 +42,10 @@ def test_contrast():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_brightness():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Brightness(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Brightness()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -57,11 +55,10 @@ def test_brightness():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_blur():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Blur(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Blur()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -71,11 +68,10 @@ def test_blur():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_noise():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Noise(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Noise()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -85,11 +81,10 @@ def test_noise():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_translate():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Translate(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Translate()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -99,11 +94,10 @@ def test_translate():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_shear():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Shear(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Shear()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -113,11 +107,10 @@ def test_shear():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_scale():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Scale(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Scale()
trans.set_params(auto_param=True)
_ = trans.transform(image)
@pytest.mark.level0
......@@ -127,8 +120,7 @@ def test_scale():
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_rotate():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Rotate(image, mode)
trans.random_param()
_ = trans.transform()
image = (np.random.rand(32, 32)).astype(np.float32)
trans = Rotate()
trans.set_params(auto_param=True)
_ = trans.transform(image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册