diff --git a/mindarmour/fuzzing/__init__.py b/mindarmour/fuzzing/__init__.py index cc88f0c3cffbde4358d40316d4158771fdfa6889..ba54b51c154d9b6e75cb68369b3ed455c325089e 100644 --- a/mindarmour/fuzzing/__init__.py +++ b/mindarmour/fuzzing/__init__.py @@ -1,8 +1,8 @@ """ This module includes various metrics to fuzzing the test of DNN. """ -from .fuzzing import Fuzzing +from .fuzzing import Fuzzer from .model_coverage_metrics import ModelCoverageMetrics -__all__ = ['Fuzzing', +__all__ = ['Fuzzer', 'ModelCoverageMetrics'] diff --git a/mindarmour/fuzzing/fuzzing.py b/mindarmour/fuzzing/fuzzing.py index e613c6380ece8c0a21290e2adc09b3725c605e8f..f4a418d170637a453e987e87fb58597a0fdb6e37 100644 --- a/mindarmour/fuzzing/fuzzing.py +++ b/mindarmour/fuzzing/fuzzing.py @@ -23,11 +23,11 @@ from mindspore import Tensor from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics from mindarmour.utils._check_param import check_model, check_numpy_param, \ check_int_positive -from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ +from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ Translate, Scale, Shear, Rotate -class Fuzzing: +class Fuzzer: """ Fuzzing test framework for deep neural networks. @@ -84,7 +84,7 @@ class Fuzzing: []) transform = strages[trans_strage]( self._image_value_expand(seed), self.mode) - transform.random_param() + transform.set_params(auto_param=True) mutate_test = transform.transform() mutate_test = np.expand_dims( self._image_value_compress(mutate_test), 0) @@ -138,7 +138,7 @@ class Fuzzing: result = result.asnumpy() for index in range(len(mutate_tests)): mutate = np.expand_dims(mutate_tests[index], 0) - self.coverage_metrics.test_adequacy_coverage_calculate( + self.coverage_metrics.model_coverage_test( mutate.astype(np.float32), batch_size=1) if coverage_metric == "KMNC": coverages.append(self.coverage_metrics.get_kmnc()) diff --git a/mindarmour/fuzzing/image_transform.py b/mindarmour/fuzzing/image_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..d0786ac0b85309103284c6593a35a4ac2c7a8d3d --- /dev/null +++ b/mindarmour/fuzzing/image_transform.py @@ -0,0 +1,569 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image transform +""" +import numpy as np +from PIL import Image, ImageEnhance, ImageFilter + +from mindspore.dataset.transforms.vision.py_transforms_util import is_numpy, \ + to_pil, hwc_to_chw +from mindarmour.utils._check_param import check_param_multi_types +from mindarmour.utils.logger import LogUtil + +LOGGER = LogUtil.get_instance() +TAG = 'ModelCoverageMetrics' + + +def chw_to_hwc(img): + """ + Transpose the input image; shape (C, H, W) to shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be converted. + + Returns: + img (numpy.ndarray), Converted image. + """ + if is_numpy(img): + return img.transpose(1, 2, 0).copy() + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +def is_hwc(img): + """ + Check if the input image is shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is shape (H, W, C). + """ + if is_numpy(img): + img_shape = np.shape(img) + if img_shape[2] == 3 and img_shape[1] > 3 and img_shape[0] > 3: + return True + return False + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +def is_chw(img): + """ + Check if the input image is shape (H, W, C). + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is shape (H, W, C). + """ + if is_numpy(img): + img_shape = np.shape(img) + if img_shape[0] == 3 and img_shape[1] > 3 and img_shape[2] > 3: + return True + return False + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +def is_rgb(img): + """ + Check if the input image is RGB. + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is RGB. + """ + if is_numpy(img): + if len(np.shape(img)) == 3: + return True + return False + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +def is_normalized(img): + """ + Check if the input image is normalized between 0 to 1. + + Args: + img (numpy.ndarray): Image to be checked. + + Returns: + Bool, True if input is normalized between 0 to 1. + """ + if is_numpy(img): + minimal = np.min(img) + maximun = np.max(img) + if minimal >= 0 and maximun <= 1: + return True + return False + raise TypeError('img should be Numpy array. Got {}'.format(type(img))) + + +class ImageTransform: + """ + The abstract base class for all image transform classes. + """ + + def __init__(self): + pass + + def _check(self, image): + """ Check image format. If input image is RGB and its shape + is (C, H, W), it will be transposed to (H, W, C). If the value + of the image is not normalized , it will be normalized between 0 to 1.""" + rgb = is_rgb(image) + chw = False + normalized = is_normalized(image) + if rgb: + chw = is_chw(image) + if chw: + image = chw_to_hwc(image) + else: + image = image + else: + image = image + if normalized: + image = np.uint8(image*255) + return rgb, chw, normalized, image + + def _original_format(self, image, chw, normalized): + """ Return transformed image with original format. """ + if not is_numpy(image): + image = np.array(image) + if chw: + image = hwc_to_chw(image) + if normalized: + image = image / 255 + return image + + def transform(self, image): + pass + + +class Contrast(ImageTransform): + """ + Contrast of an image. + + Args: + factor ([float, int]): Control the contrast of an image. If 1.0 gives the + original image. If 0 gives a gray image. Default: 1. + """ + + def __init__(self, factor=1): + super(Contrast, self).__init__() + self.set_params(factor) + + def set_params(self, factor=1, auto_param=False): + """ + Set contrast parameters. + + Args: + factor ([float, int]): Control the contrast of an image. If 1.0 gives + the original image. If 0 gives a gray image. Default: 1. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.factor = np.random.uniform(-5, 5) + else: + self.factor = check_param_multi_types('factor', factor, [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + image = to_pil(image) + img_contrast = ImageEnhance.Contrast(image) + trans_image = img_contrast.enhance(self.factor) + trans_image = self._original_format(trans_image, chw, normalized) + + return trans_image + + +class Brightness(ImageTransform): + """ + Brightness of an image. + + Args: + factor ([float, int]): Control the brightness of an image. If 1.0 gives + the original image. If 0 gives a black image. Default: 1. + """ + + def __init__(self, factor=1): + super(Brightness, self).__init__() + self.set_params(factor) + + def set_params(self, factor=1, auto_param=False): + """ + Set brightness parameters. + + Args: + factor ([float, int]): Control the brightness of an image. If 1 + gives the original image. If 0 gives a black image. Default: 1. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.factor = np.random.uniform(0, 5) + else: + self.factor = check_param_multi_types('factor', factor, [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + image = to_pil(image) + img_contrast = ImageEnhance.Brightness(image) + trans_image = img_contrast.enhance(self.factor) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Blur(ImageTransform): + """ + Blurs the image using Gaussian blur filter. + + Args: + radius([float, int]): Blur radius, 0 means no blur. Default: 0. + """ + + def __init__(self, radius=0): + super(Blur, self).__init__() + self.set_params(radius) + + def set_params(self, radius=0, auto_param=False): + """ + Set blur parameters. + + Args: + radius ([float, int]): Blur radius, 0 means no blur. Default: 0. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.radius = np.random.uniform(-1.5, 1.5) + else: + self.radius = check_param_multi_types('radius', radius, [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + image = to_pil(image) + trans_image = image.filter(ImageFilter.GaussianBlur(radius=self.radius)) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Noise(ImageTransform): + """ + Add noise of an image. + + Args: + factor (float): 1 - factor is the ratio of pixels to add noise. + If 0 gives the original image. Default 0. + """ + + def __init__(self, factor=0): + super(Noise, self).__init__() + self.set_params(factor) + + def set_params(self, factor=0, auto_param=False): + """ + Set noise parameters. + + Args: + factor ([float, int]): 1 - factor is the ratio of pixels to add noise. + If 0 gives the original image. Default 0. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.factor = np.random.uniform(0.7, 1) + else: + self.factor = check_param_multi_types('factor', factor, [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image (numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + noise = np.random.uniform(low=-1, high=1, size=np.shape(image)) + trans_image = np.copy(image) + trans_image[noise < -self.factor] = 0 + trans_image[noise > self.factor] = 1 + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Translate(ImageTransform): + """ + Translate an image. + + Args: + x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0. + y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0. + """ + + def __init__(self, x_bias=0, y_bias=0): + super(Translate, self).__init__() + self.set_params(x_bias, y_bias) + + def set_params(self, x_bias=0, y_bias=0, auto_param=False): + """ + Set translate parameters. + + Args: + x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0. + y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0. + auto_param (bool): True if auto generate parameters. Default: False. + """ + self.auto_param = auto_param + if auto_param: + self.x_bias = np.random.uniform(-0.3, 0.3) + self.y_bias = np.random.uniform(-0.3, 0.3) + else: + self.x_bias = check_param_multi_types('x_bias', x_bias, + [int, float]) + self.y_bias = check_param_multi_types('y_bias', y_bias, + [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image(numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + img = to_pil(image) + if self.auto_param: + image_shape = np.shape(image) + self.x_bias = image_shape[0]*self.x_bias + self.y_bias = image_shape[1]*self.y_bias + trans_image = img.transform(img.size, Image.AFFINE, + (1, 0, self.x_bias, 0, 1, self.y_bias)) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Scale(ImageTransform): + """ + Scale an image in the middle. + + Args: + factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. Default: 1. + factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. Default: 1. + """ + + def __init__(self, factor_x=1, factor_y=1): + super(Scale, self).__init__() + self.set_params(factor_x, factor_y) + + def set_params(self, factor_x=1, factor_y=1, auto_param=False): + + """ + Set scale parameters. + + Args: + factor_x ([float, int]): Rescale in X-direction, x=factor_x*x. + Default: 1. + factor_y ([float, int]): Rescale in Y-direction, y=factor_y*y. + Default: 1. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.factor_x = np.random.uniform(0.7, 3) + self.factor_y = np.random.uniform(0.7, 3) + else: + self.factor_x = check_param_multi_types('factor_x', factor_x, + [int, float]) + self.factor_y = check_param_multi_types('factor_y', factor_y, + [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image(numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + rgb, chw, normalized, image = self._check(image) + if rgb: + h, w, _ = np.shape(image) + else: + h, w = np.shape(image) + move_x_centor = w / 2*(1 - self.factor_x) + move_y_centor = h / 2*(1 - self.factor_y) + img = to_pil(image) + trans_image = img.transform(img.size, Image.AFFINE, + (self.factor_x, 0, move_x_centor, + 0, self.factor_y, move_y_centor)) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Shear(ImageTransform): + """ + Shear an image, for each pixel (x, y) in the sheared image, the new value is + taken from a position (x+factor_x*y, factor_y*x+y) in the origin image. Then + the sheared image will be rescaled to fit original size. + + Args: + factor_x ([float, int]): Shear factor of horizontal direction. Default: 0. + factor_y ([float, int]): Shear factor of vertical direction. Default: 0. + + """ + + def __init__(self, factor_x=0, factor_y=0): + super(Shear, self).__init__() + self.set_params(factor_x, factor_y) + + def set_params(self, factor_x=0, factor_y=0, auto_param=False): + """ + Set shear parameters. + + Args: + factor_x ([float, int]): Shear factor of horizontal direction. + Default: 0. + factor_y ([float, int]): Shear factor of vertical direction. + Default: 0. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if factor_x != 0 and factor_y != 0: + msg = 'factor_x and factor_y can not be both more than 0.' + LOGGER.error(TAG, msg) + raise ValueError(msg) + if auto_param: + if np.random.uniform(-1, 1) > 0: + self.factor_x = np.random.uniform(-2, 2) + self.factor_y = 0 + else: + self.factor_x = 0 + self.factor_y = np.random.uniform(-2, 2) + else: + self.factor_x = check_param_multi_types('factor', factor_x, + [int, float]) + self.factor_y = check_param_multi_types('factor', factor_y, + [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image(numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + rgb, chw, normalized, image = self._check(image) + img = to_pil(image) + if rgb: + h, w, _ = np.shape(image) + else: + h, w = np.shape(image) + if self.factor_x != 0: + boarder_x = [0, -w, -self.factor_x*h, -w - self.factor_x*h] + min_x = min(boarder_x) + max_x = max(boarder_x) + scale = (max_x - min_x) / w + move_x_cen = (w - scale*w - scale*h*self.factor_x) / 2 + move_y_cen = h*(1 - scale) / 2 + else: + boarder_y = [0, -h, -self.factor_y*w, -h - self.factor_y*w] + min_y = min(boarder_y) + max_y = max(boarder_y) + scale = (max_y - min_y) / h + move_y_cen = (h - scale*h - scale*w*self.factor_y) / 2 + move_x_cen = w*(1 - scale) / 2 + trans_image = img.transform(img.size, Image.AFFINE, + (scale, scale*self.factor_x, move_x_cen, + scale*self.factor_y, scale, move_y_cen)) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image + + +class Rotate(ImageTransform): + """ + Rotate an image of degrees counter clockwise around its center. + + Args: + angle([float, int]): Degrees counter clockwise. Default: 0. + """ + + def __init__(self, angle=0): + super(Rotate, self).__init__() + self.set_params(angle) + + def set_params(self, angle=0, auto_param=False): + """ + Set rotate parameters. + + Args: + angle([float, int]): Degrees counter clockwise. Default: 0. + auto_param (bool): True if auto generate parameters. Default: False. + """ + if auto_param: + self.angle = np.random.uniform(0, 360) + else: + self.angle = check_param_multi_types('angle', angle, [int, float]) + + def transform(self, image): + """ + Transform the image. + + Args: + image(numpy.ndarray): Original image to be transformed. + + Returns: + numpy.ndarray, transformed image. + """ + _, chw, normalized, image = self._check(image) + img = to_pil(image) + trans_image = img.rotate(self.angle, expand=True) + trans_image = self._original_format(trans_image, chw, normalized) + return trans_image diff --git a/mindarmour/fuzzing/model_coverage_metrics.py b/mindarmour/fuzzing/model_coverage_metrics.py index 7d9a3bf61b90f6497bdc85963f5571f3081e2de4..d932c5ad6cf608f7459eef17e9519dad38c1fc62 100644 --- a/mindarmour/fuzzing/model_coverage_metrics.py +++ b/mindarmour/fuzzing/model_coverage_metrics.py @@ -133,8 +133,7 @@ class ModelCoverageMetrics: else: self._main_section_hits[i][int(section_indexes[i])] = 1 - def test_adequacy_coverage_calculate(self, dataset, bias_coefficient=0, - batch_size=32): + def calculate_coverage(self, dataset, bias_coefficient=0, batch_size=32): """ Calculate the testing adequacy of the given dataset. @@ -147,7 +146,7 @@ class ModelCoverageMetrics: Examples: >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images) - >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images) + >>> model_fuzz_test.calculate_coverage(test_images) """ dataset = check_numpy_param('dataset', dataset) batch_size = check_int_positive('batch_size', batch_size) diff --git a/mindarmour/utils/image_transform.py b/mindarmour/utils/image_transform.py deleted file mode 100644 index fd64c14eb48290794ebb1e8d0c054b450856d8f1..0000000000000000000000000000000000000000 --- a/mindarmour/utils/image_transform.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Image transform -""" -import numpy as np -from PIL import Image, ImageEnhance, ImageFilter -import random - -from mindarmour.utils._check_param import check_numpy_param - -class ImageTransform: - """ - The abstract base class for all image transform classes. - """ - - def __init__(self): - pass - - def random_param(self): - pass - - def transform(self): - pass - - -class Contrast(ImageTransform): - """ - Contrast of an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Contrast, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.factor = random.uniform(-5, 5) - - def transform(self): - img = Image.fromarray(np.uint8(self.image*255), self.mode) - img_contrast = ImageEnhance.Contrast(img) - trans_image = img_contrast.enhance(self.factor) - trans_image = np.array(trans_image)/255 - return trans_image - - -class Brightness(ImageTransform): - """ - Brightness of an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Brightness, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.factor = random.uniform(0, 5) - - def transform(self): - img = Image.fromarray(np.uint8(self.image*255), self.mode) - img_contrast = ImageEnhance.Brightness(img) - trans_image = img_contrast.enhance(self.factor) - trans_image = np.array(trans_image)/255 - return trans_image - - -class Blur(ImageTransform): - """ - GaussianBlur of an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Blur, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.radius = random.uniform(-1.5, 1.5) - - def transform(self): - """ Transform the image. """ - img = Image.fromarray(np.uint8(self.image*255), self.mode) - trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) - trans_image = np.array(trans_image)/255 - return trans_image - - -class Noise(ImageTransform): - """ - Add noise of an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Noise, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ random generate parameters """ - self.factor = random.uniform(0.7, 1) - - def transform(self): - """ Random generate parameters. """ - noise = np.random.uniform(low=-1, high=1, size=self.image.shape) - trans_image = np.copy(self.image) - trans_image[noise < -self.factor] = 0 - trans_image[noise > self.factor] = 1 - trans_image = np.array(trans_image) - return trans_image - - -class Translate(ImageTransform): - """ - Translate an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Translate, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - image_shape = np.shape(self.image) - self.x_bias = random.uniform(-image_shape[0]/3, image_shape[0]/3) - self.y_bias = random.uniform(-image_shape[1]/3, image_shape[1]/3) - - def transform(self): - """ Transform the image. """ - img = Image.fromarray(np.uint8(self.image*255), self.mode) - trans_image = img.transform(img.size, Image.AFFINE, - (1, 0, self.x_bias, 0, 1, self.y_bias)) - trans_image = np.array(trans_image)/255 - return trans_image - - -class Scale(ImageTransform): - """ - Scale an image. - - Args: - image(numpy.ndarray): Original image to be transformed. - mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Scale, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.factor_x = random.uniform(0.7, 2) - self.factor_y = random.uniform(0.7, 2) - - def transform(self): - """ Transform the image. """ - img = Image.fromarray(np.uint8(self.image*255), self.mode) - trans_image = img.transform(img.size, Image.AFFINE, - (self.factor_x, 0, 0, 0, self.factor_y, 0)) - trans_image = np.array(trans_image)/255 - return trans_image - - -class Shear(ImageTransform): - """ - Shear an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Shear, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.factor = random.uniform(0, 1) - - def transform(self): - """ Transform the image. """ - img = Image.fromarray(np.uint8(self.image*255), self.mode) - if np.random.random() > 0.5: - level = -self.factor - else: - level = self.factor - if np.random.random() > 0.5: - trans_image = img.transform(img.size, Image.AFFINE, - (1, level, 0, 0, 1, 0)) - else: - trans_image = img.transform(img.size, Image.AFFINE, - (1, 0, 0, level, 1, 0)) - trans_image = np.array(trans_image, dtype=np.float)/255 - return trans_image - - -class Rotate(ImageTransform): - """ - Rotate an image. - - Args: - image (numpy.ndarray): Original image to be transformed. - mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'], - 'L' means grey image. - """ - - def __init__(self, image, mode): - super(Rotate, self).__init__() - self.image = check_numpy_param('image', image) - self.mode = mode - - def random_param(self): - """ Random generate parameters. """ - self.angle = random.uniform(0, 360) - - def transform(self): - """ Transform the image. """ - img = Image.fromarray(np.uint8(self.image*255), self.mode) - trans_image = img.rotate(self.angle) - trans_image = np.array(trans_image)/255 - return trans_image diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py index 158565ae45d156fbc2c24334d90f1d3c94a285a5..7a69863c393b74bded83f97b9901da06308cd230 100644 --- a/tests/ut/python/fuzzing/test_coverage_metrics.py +++ b/tests/ut/python/fuzzing/test_coverage_metrics.py @@ -77,7 +77,7 @@ def test_lenet_mnist_coverage_cpu(): # get test data test_data = (np.random.random((2000, 10))*20).astype(np.float32) test_labels = np.random.randint(0, 10, 2000).astype(np.int32) - model_fuzz_test.test_adequacy_coverage_calculate(test_data) + model_fuzz_test.calculate_coverage(test_data) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) @@ -86,8 +86,7 @@ def test_lenet_mnist_coverage_cpu(): loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) attack = FastGradientSignMethod(net, eps=0.3, loss_fn=loss) adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) - model_fuzz_test.test_adequacy_coverage_calculate(adv_data, - bias_coefficient=0.5) + model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) @@ -113,7 +112,7 @@ def test_lenet_mnist_coverage_ascend(): test_data = (np.random.random((2000, 10))*20).astype(np.float32) test_labels = np.random.randint(0, 10, 2000) test_labels = (np.eye(10)[test_labels]).astype(np.float32) - model_fuzz_test.test_adequacy_coverage_calculate(test_data) + model_fuzz_test.calculate_coverage(test_data) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) @@ -121,8 +120,7 @@ def test_lenet_mnist_coverage_ascend(): # generate adv_data attack = FastGradientSignMethod(net, eps=0.3) adv_data = attack.batch_generate(test_data, test_labels, batch_size=32) - model_fuzz_test.test_adequacy_coverage_calculate(adv_data, - bias_coefficient=0.5) + model_fuzz_test.calculate_coverage(adv_data, bias_coefficient=0.5) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) diff --git a/tests/ut/python/fuzzing/test_fuzzing.py b/tests/ut/python/fuzzing/test_fuzzing.py deleted file mode 100644 index 7396f4575fecfd18d3944be2b528a35c1bcce0dc..0000000000000000000000000000000000000000 --- a/tests/ut/python/fuzzing/test_fuzzing.py +++ /dev/null @@ -1,160 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Model-fuzz coverage test. -""" -import numpy as np -import pytest -from mindspore import context -from mindspore import nn -from mindspore.common.initializer import TruncatedNormal -from mindspore.ops import operations as P -from mindspore.train import Model - -from mindarmour.fuzzing.fuzzing import Fuzzing -from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics -from mindarmour.utils.logger import LogUtil - -LOGGER = LogUtil.get_instance() -TAG = 'Fuzzing test' -LOGGER.set_level('INFO') - - -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") - - -def fc_with_initialize(input_channels, out_channels): - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - -def weight_variable(): - return TruncatedNormal(0.02) - - -class Net(nn.Cell): - """ - Lenet network - """ - def __init__(self): - super(Net, self).__init__() - self.conv1 = conv(1, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16*5*5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, 10) - self.relu = nn.ReLU() - self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() - - def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.reshape(x, (-1, 16*5*5)) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) - x = self.fc3(x) - return x - - -@pytest.mark.level0 -@pytest.mark.platform_x86_ascend_training -@pytest.mark.platform_arm_ascend_training -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_fuzzing_ascend(): - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - # load network - net = Net() - model = Model(net) - batch_size = 8 - num_classe = 10 - - # initialize fuzz test with training dataset - training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) - model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) - - # fuzz test with original test data - # get test data - test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) - test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) - test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) - - initial_seeds = [] - for img, label in zip(test_data, test_labels): - initial_seeds.append([img, label, 0]) - model_coverage_test.test_adequacy_coverage_calculate( - np.array(test_data).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) - - model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, - max_seed_num=10) - failed_tests = model_fuzz_test.fuzzing() - if failed_tests: - model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) - else: - LOGGER.info(TAG, 'Fuzzing test identifies none failed test') - - -@pytest.mark.level0 -@pytest.mark.platform_x86_cpu -@pytest.mark.env_onecard -@pytest.mark.component_mindarmour -def test_fuzzing_CPU(): - context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - # load network - net = Net() - model = Model(net) - batch_size = 8 - num_classe = 10 - - # initialize fuzz test with training dataset - training_data = np.random.rand(32, 1, 32, 32).astype(np.float32) - model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data) - - # fuzz test with original test data - # get test data - test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32) - test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32) - test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32) - - initial_seeds = [] - for img, label in zip(test_data, test_labels): - initial_seeds.append([img, label, 0]) - model_coverage_test.test_adequacy_coverage_calculate( - np.array(test_data).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) - - model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, - max_seed_num=10) - failed_tests = model_fuzz_test.fuzzing() - if failed_tests: - model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) - else: - LOGGER.info(TAG, 'Fuzzing test identifies none failed test') diff --git a/tests/ut/python/utils/test_image_transform.py b/tests/ut/python/fuzzing/test_image_transform.py similarity index 63% rename from tests/ut/python/utils/test_image_transform.py rename to tests/ut/python/fuzzing/test_image_transform.py index d1fcaee40ac3e349df24044c23e7c014fc182772..77d3d189dc2db8c67e23b0ba0ef4351706928adc 100644 --- a/tests/ut/python/utils/test_image_transform.py +++ b/tests/ut/python/fuzzing/test_image_transform.py @@ -18,7 +18,7 @@ import numpy as np import pytest from mindarmour.utils.logger import LogUtil -from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ +from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, Noise, \ Translate, Scale, Shear, Rotate LOGGER = LogUtil.get_instance() @@ -31,11 +31,10 @@ LOGGER.set_level('INFO') @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_contrast(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Contrast(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Contrast() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -43,11 +42,10 @@ def test_contrast(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_brightness(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Brightness(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Brightness() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -57,11 +55,10 @@ def test_brightness(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_blur(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Blur(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Blur() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -71,11 +68,10 @@ def test_blur(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_noise(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Noise(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Noise() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -85,11 +81,10 @@ def test_noise(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_translate(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Translate(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Translate() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -99,11 +94,10 @@ def test_translate(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_shear(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Shear(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Shear() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -113,11 +107,10 @@ def test_shear(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_scale(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Scale(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Scale() + trans.set_params(auto_param=True) + _ = trans.transform(image) @pytest.mark.level0 @@ -127,8 +120,7 @@ def test_scale(): @pytest.mark.env_onecard @pytest.mark.component_mindarmour def test_rotate(): - image = (np.random.rand(32, 32)*255).astype(np.float32) - mode = 'L' - trans = Rotate(image, mode) - trans.random_param() - _ = trans.transform() + image = (np.random.rand(32, 32)).astype(np.float32) + trans = Rotate() + trans.set_params(auto_param=True) + _ = trans.transform(image)