提交 5467dbdd 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!13 Add a fuzzing test framework for DNN and image transform methods

Merge pull request !13 from ZhidanLiu/master
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
from mindspore import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SoftmaxCrossEntropyWithLogits
from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.fuzzing.fuzzing import Fuzzing
from lenet5_net import LeNet5
sys.path.append("..")
from data_processing import generate_mnist_dataset
LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_test'
LOGGER.set_level('INFO')
def test_lenet_mnist_fuzzing():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# upload trained network
ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
net = LeNet5()
load_dict = load_checkpoint(ckpt_name)
load_param_into_net(net, load_dict)
model = Model(net)
# get training data
data_list = "./MNIST_datasets/train"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
train_images = []
for data in ds.create_tuple_iterator():
images = data[0].astype(np.float32)
train_images.append(images)
train_images = np.concatenate(train_images, axis=0)
# initialize fuzz test with training dataset
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
# fuzz test with original test data
# get test data
data_list = "./MNIST_datasets/test"
batch_size = 32
ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
test_images = []
test_labels = []
for data in ds.create_tuple_iterator():
images = data[0].astype(np.float32)
labels = data[1]
test_images.append(images)
test_labels.append(labels)
test_images = np.concatenate(test_images, axis=0)
test_labels = np.concatenate(test_labels, axis=0)
initial_seeds = []
# make initial seeds
for img, label in zip(test_images, test_labels):
initial_seeds.append([img, label, 0])
initial_seeds = initial_seeds[:100]
model_coverage_test.test_adequacy_coverage_calculate(np.array(test_images[:100]).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20)
failed_tests = model_fuzz_test.fuzzing()
model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc())
if __name__ == '__main__':
test_lenet_mnist_fuzzing()
"""
This module includes various metrics to fuzzing the test of DNN.
"""
from .fuzzing import Fuzzing
from .model_coverage_metrics import ModelCoverageMetrics
__all__ = ['ModelCoverageMetrics']
\ No newline at end of file
__all__ = ['Fuzzing',
'ModelCoverageMetrics']
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fuzzing.
"""
import numpy as np
from random import choice
from mindspore import Tensor
from mindspore import Model
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
Translate, Scale, Shear, Rotate
from mindarmour.utils._check_param import check_model, check_numpy_param, \
check_int_positive
class Fuzzing:
"""
Fuzzing test framework for deep neural networks.
Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep
Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_
Args:
initial_seeds (list): Initial fuzzing seed, format: [[image, label, 0],
[image, label, 0], ...].
target_model (Model): Target fuzz model.
train_dataset (numpy.ndarray): Training dataset used for determine
the neurons' output boundaries.
const_K (int): The number of mutate tests for a seed.
mode (str): Image mode used in image transform, 'L' means grey graph.
Default: 'L'.
"""
def __init__(self, initial_seeds, target_model, train_dataset, const_K,
mode='L', max_seed_num=1000):
self.initial_seeds = initial_seeds
self.target_model = check_model('model', target_model, Model)
self.train_dataset = check_numpy_param('train_dataset', train_dataset)
self.K = check_int_positive('const_k', const_K)
self.mode = mode
self.max_seed_num = check_int_positive('max_seed_num', max_seed_num)
self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10,
train_dataset)
def _image_value_expand(self, image):
return image*255
def _image_value_compress(self, image):
return image / 255
def _metamorphic_mutate(self, seed, try_num=50):
if self.mode == 'L':
seed = seed[0]
info = [seed, seed]
mutate_tests = []
affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
strages = {'Contrast': Contrast, 'Brightness': Brightness, 'Blur': Blur,
'Noise': Noise,
'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
'Rotate': Rotate}
for _ in range(self.K):
for _ in range(try_num):
if (info[0] == info[1]).all():
trans_strage = self._random_pick_mutate(affine_trans,
pixel_value_trans)
else:
trans_strage = self._random_pick_mutate(affine_trans, [])
transform = strages[trans_strage](
self._image_value_expand(seed), self.mode)
transform.random_param()
mutate_test = transform.transform()
mutate_test = np.expand_dims(
self._image_value_compress(mutate_test), 0)
if self._is_trans_valid(seed, mutate_test):
if trans_strage in affine_trans:
info[1] = mutate_test
mutate_tests.append(mutate_test)
if len(mutate_tests) == 0:
mutate_tests.append(seed)
return np.array(mutate_tests)
def fuzzing(self, coverage_metric='KMNC'):
"""
Fuzzing tests for deep neural networks.
Args:
coverage_metric (str): Model coverage metric of neural networks.
Default: 'KMNC'.
Returns:
list, mutated tests mis-predicted by target dnn model.
"""
seed = self._select_next()
failed_tests = []
seed_num = 0
while len(seed) > 0 and seed_num < self.max_seed_num:
mutate_tests = self._metamorphic_mutate(seed[0])
coverages, results = self._run(mutate_tests, coverage_metric)
coverage_gains = self._coverage_gains(coverages)
for mutate, cov, res in zip(mutate_tests, coverage_gains, results):
if np.argmax(seed[1]) != np.argmax(res):
failed_tests.append(mutate)
continue
if cov > 0:
self.initial_seeds.append([mutate, seed[1], 0])
seed = self._select_next()
seed_num += 1
return failed_tests
def _coverage_gains(self, coverages):
gains = [0] + coverages[:-1]
gains = np.array(coverages) - np.array(gains)
return gains
def _run(self, mutate_tests, coverage_metric="KNMC"):
coverages = []
result = self.target_model.predict(
Tensor(mutate_tests.astype(np.float32)))
result = result.asnumpy()
for index in range(len(mutate_tests)):
mutate = np.expand_dims(mutate_tests[index], 0)
self.coverage_metrics.test_adequacy_coverage_calculate(
mutate.astype(np.float32), batch_size=1)
if coverage_metric == "KMNC":
coverages.append(self.coverage_metrics.get_kmnc())
return coverages, result
def _select_next(self):
seed = choice(self.initial_seeds)
return seed
def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list):
strage = choice(affine_trans_list + pixel_value_trans_list)
return strage
def _is_trans_valid(self, seed, mutate_test):
is_valid = False
alpha = 0.02
beta = 0.2
diff = np.array(seed - mutate_test).flatten()
size = np.shape(diff)[0]
L0 = np.linalg.norm(diff, ord=0)
Linf = np.linalg.norm(diff, ord=np.inf)
if L0 > alpha*size:
if Linf < 256:
is_valid = True
else:
if Linf < beta*255:
is_valid = True
return is_valid
......@@ -76,6 +76,11 @@ class ModelCoverageMetrics:
upper_compare_array = np.concatenate(
[output, np.array([self._upper_bounds])], axis=0)
self._upper_bounds = np.max(upper_compare_array, axis=0)
if batches == 0:
output = self._model.predict(Tensor(train_dataset)).asnumpy()
self._lower_bounds = np.min(output, axis=0)
self._upper_bounds = np.max(output, axis=0)
output_mat.append(output)
self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
axis=0)
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform
"""
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import random
from mindarmour.utils._check_param import check_numpy_param
class ImageTransform:
"""
The abstract base class for all image transform classes.
"""
def __init__(self):
pass
def random_param(self):
pass
def transform(self):
pass
class Contrast(ImageTransform):
"""
Contrast of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Contrast, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(-10, 10)
def transform(self):
img = Image.fromarray(self.image, self.mode)
img_contrast = ImageEnhance.Contrast(img)
trans_image = img_contrast.enhance(self.factor)
trans_image = np.array(trans_image)
return trans_image
class Brightness(ImageTransform):
"""
Brightness of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Brightness, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(-10, 10)
def transform(self):
img = Image.fromarray(self.image, self.mode)
img_contrast = ImageEnhance.Brightness(img)
trans_image = img_contrast.enhance(self.factor)
trans_image = np.array(trans_image)
return trans_image
class Blur(ImageTransform):
"""
GaussianBlur of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Blur, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.radius = random.uniform(-10, 10)
def transform(self):
""" Transform the image. """
img = Image.fromarray(self.image, self.mode)
trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius))
trans_image = np.array(trans_image)
return trans_image
class Noise(ImageTransform):
"""
Add noise of an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Noise, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" random generate parameters """
self.factor = random.uniform(-1, 1)
def transform(self):
""" Random generate parameters. """
noise = np.random.uniform(low=-1, high=1, size=self.image.shape)
trans_image = np.copy(self.image)
trans_image[noise < -self.factor] = 0
trans_image[noise > self.factor] = 255
trans_image = np.array(trans_image)
return trans_image
class Translate(ImageTransform):
"""
Translate an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Translate, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
image_shape = np.shape(self.image)
self.x_bias = random.uniform(0, image_shape[0])
self.y_bias = random.uniform(0, image_shape[1])
def transform(self):
""" Transform the image. """
img = Image.fromarray(self.image, self.mode)
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, self.x_bias, 0, 1, self.y_bias))
trans_image = np.array(trans_image)
return trans_image
class Scale(ImageTransform):
"""
Scale an image.
Args:
image(numpy.ndarray): Original image to be transformed.
mode(str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Scale, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor_x = random.uniform(0, 1)
self.factor_y = random.uniform(0, 1)
def transform(self):
""" Transform the image. """
img = Image.fromarray(self.image, self.mode)
trans_image = img.transform(img.size, Image.AFFINE,
(self.factor_x, 0, 0, 0, self.factor_y, 0))
trans_image = np.array(trans_image)
return trans_image
class Shear(ImageTransform):
"""
Shear an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Shear, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.factor = random.uniform(0, 1)
def transform(self):
""" Transform the image. """
img = Image.fromarray(self.image, self.mode)
if np.random.random() > 0.5:
level = -self.factor
else:
level = self.factor
if np.random.random() > 0.5:
trans_image = img.transform(img.size, Image.AFFINE,
(1, level, 0, 0, 1, 0))
else:
trans_image = img.transform(img.size, Image.AFFINE,
(1, 0, 0, level, 1, 0))
trans_image = np.array(trans_image, dtype=np.float)
return trans_image
class Rotate(ImageTransform):
"""
Rotate an image.
Args:
image (numpy.ndarray): Original image to be transformed.
mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
'L' means grey image.
"""
def __init__(self, image, mode):
super(Rotate, self).__init__()
self.image = check_numpy_param('image', image)
self.mode = mode
def random_param(self):
""" Random generate parameters. """
self.angle = random.uniform(0, 360)
def transform(self):
""" Transform the image. """
img = Image.fromarray(self.image, self.mode)
trans_image = img.rotate(self.angle)
trans_image = np.array(trans_image)
return trans_image
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
"""
import numpy as np
import pytest
import sys
from mindspore.train import Model
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.initializer import TruncatedNormal
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.fuzzing.fuzzing import Fuzzing
LOGGER = LogUtil.get_instance()
TAG = 'Fuzzing test'
LOGGER.set_level('INFO')
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
return TruncatedNormal(0.02)
class Net(nn.Cell):
"""
Lenet network
"""
def __init__(self):
super(Net, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16*5*5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (-1, 16*5*5))
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
# initialize fuzz test with training dataset
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)
# fuzz test with original test data
# get test data
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
for img, label in zip(test_data, test_labels):
initial_seeds.append([img, label, 0])
model_coverage_test.test_adequacy_coverage_calculate(
np.array(test_data).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
max_seed_num=10)
failed_tests = model_fuzz_test.fuzzing()
model_coverage_test.test_adequacy_coverage_calculate(
np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
# load network
net = Net()
model = Model(net)
batch_size = 8
num_classe = 10
# initialize fuzz test with training dataset
training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)
# fuzz test with original test data
# get test data
test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
test_labels = np.random.randint(num_classe, size=batch_size).astype(np.int32)
test_labels = (np.eye(num_classe)[test_labels]).astype(np.float32)
initial_seeds = []
for img, label in zip(test_data, test_labels):
initial_seeds.append([img, label, 0])
model_coverage_test.test_adequacy_coverage_calculate(
np.array(test_data).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
max_seed_num=10)
failed_tests = model_fuzz_test.fuzzing()
model_coverage_test.test_adequacy_coverage_calculate(
np.array(failed_tests).astype(np.float32))
LOGGER.info(TAG, 'KMNC of this test is : %s',
model_coverage_test.get_kmnc())
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform test.
"""
import numpy as np
import pytest
from mindarmour.utils.logger import LogUtil
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
Translate, Scale, Shear, Rotate
LOGGER = LogUtil.get_instance()
TAG = 'Image transform test'
LOGGER.set_level('INFO')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_contrast():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Contrast(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_brightness():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Brightness(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_blur():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Blur(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_noise():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Noise(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_translate():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Translate(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_shear():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Shear(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_scale():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Scale(image, mode)
trans.random_param()
trans_image = trans.transform()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_rotate():
image = (np.random.rand(32, 32)*255).astype(np.float32)
mode = 'L'
trans = Rotate(image, mode)
trans.random_param()
trans_image = trans.transform()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册