diff --git a/example/mnist_demo/lenet5_mnist_coverage.py b/example/mnist_demo/lenet5_mnist_coverage.py index b5181e893d18b4ceb4ec387cd21581eed7372973..0ce254018decd353d1de557d51049915f2a05782 100644 --- a/example/mnist_demo/lenet5_mnist_coverage.py +++ b/example/mnist_demo/lenet5_mnist_coverage.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -import numpy as np +import numpy as np from mindspore import Model from mindspore import context -from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from lenet5_net import LeNet5 from mindarmour.attacks.gradient_method import FastGradientSignMethod -from mindarmour.utils.logger import LogUtil from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil sys.path.append("..") from data_processing import generate_mnist_dataset diff --git a/example/mnist_demo/lenet5_mnist_fuzzing.py b/example/mnist_demo/lenet5_mnist_fuzzing.py index d6604fd412064e15c5619c7ebe76bf6c8c01a5ba..4adcbbd64f627b6667d7fddf17e8f771d0418498 100644 --- a/example/mnist_demo/lenet5_mnist_fuzzing.py +++ b/example/mnist_demo/lenet5_mnist_fuzzing.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys -import numpy as np +import numpy as np from mindspore import Model from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.nn import SoftmaxCrossEntropyWithLogits -from mindarmour.attacks.gradient_method import FastGradientSignMethod -from mindarmour.utils.logger import LogUtil -from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics -from mindarmour.fuzzing.fuzzing import Fuzzing from lenet5_net import LeNet5 +from mindarmour.fuzzing.fuzzing import Fuzzing +from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics +from mindarmour.utils.logger import LogUtil sys.path.append("..") from data_processing import generate_mnist_dataset @@ -81,8 +79,11 @@ def test_lenet_mnist_fuzzing(): model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20) failed_tests = model_fuzz_test.fuzzing() - model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + if failed_tests: + model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + else: + LOGGER.info(TAG, 'Fuzzing test identifies none failed test') if __name__ == '__main__': diff --git a/example/mnist_demo/mnist_attack_cw.py b/example/mnist_demo/mnist_attack_cw.py index 3fa614e27da399794e324aebfb5d5a4b91b9f405..6fdd626c1575c2ddd1b1a760fd432ce089c74952 100644 --- a/example/mnist_demo/mnist_attack_cw.py +++ b/example/mnist_demo/mnist_attack_cw.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.carlini_wagner import CarliniWagnerL2Attack -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_deepfool.py b/example/mnist_demo/mnist_attack_deepfool.py index 925f50bcf3884d2e559d605e697dbe3f5a002923..5a0b9b036b79e069dd05b81fb35f9bbea78594dd 100644 --- a/example/mnist_demo/mnist_attack_deepfool.py +++ b/example/mnist_demo/mnist_attack_deepfool.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.deep_fool import DeepFool -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_fgsm.py b/example/mnist_demo/mnist_attack_fgsm.py index f951656b92700d51c794d24b818a5316e47c20b9..636830395ea1c4f6b608428f5175363d9364e88f 100644 --- a/example/mnist_demo/mnist_attack_fgsm.py +++ b/example/mnist_demo/mnist_attack_fgsm.py @@ -13,21 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.gradient_method import FastGradientSignMethod - -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_genetic.py b/example/mnist_demo/mnist_attack_genetic.py index 6c4a6f604c0cd011c4a857530510621a91857b7c..65f75c65b5f6dc937fec76f3e46c995998f1c556 100644 --- a/example/mnist_demo/mnist_attack_genetic.py +++ b/example/mnist_demo/mnist_attack_genetic.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax -from mindarmour.attacks.black.genetic_attack import GeneticAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel -from mindarmour.utils.logger import LogUtil +from mindarmour.attacks.black.genetic_attack import GeneticAttack from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -97,8 +96,8 @@ def test_genetic_attack_on_mnist(): per_bounds=0.1, step_size=0.25, temp=0.1, sparse=True) targeted_labels = np.random.randint(0, 10, size=len(true_labels)) - for i in range(len(true_labels)): - if targeted_labels[i] == true_labels[i]: + for i, true_l in enumerate(true_labels): + if targeted_labels[i] == true_l: targeted_labels[i] = (targeted_labels[i] + 1) % 10 start_time = time.clock() success_list, adv_data, query_list = attack.generate( diff --git a/example/mnist_demo/mnist_attack_hsja.py b/example/mnist_demo/mnist_attack_hsja.py index 11d0c1884aba063a3714d4a9f8dbfcf5e52d4097..7b38c1719d7b5338aa0bef4233b0fc062adb7736 100644 --- a/example/mnist_demo/mnist_attack_hsja.py +++ b/example/mnist_demo/mnist_attack_hsja.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys + import numpy as np import pytest - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.attacks.black.hop_skip_jump_attack import HopSkipJumpAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel - +from mindarmour.attacks.black.hop_skip_jump_attack import HopSkipJumpAttack from mindarmour.utils.logger import LogUtil -from lenet5_net import LeNet5 sys.path.append("..") from data_processing import generate_mnist_dataset @@ -64,9 +63,9 @@ def random_target_labels(true_labels): def create_target_images(dataset, data_labels, target_labels): res = [] for label in target_labels: - for i in range(len(data_labels)): - if data_labels[i] == label: - res.append(dataset[i]) + for data_label, data in zip(data_labels, dataset): + if data_label == label: + res.append(data) break return np.array(res) @@ -126,9 +125,9 @@ def test_hsja_mnist_attack(): target_images = create_target_images(test_images, predict_labels, target_labels) attack.set_target_images(target_images) - success_list, adv_data, query_list = attack.generate(test_images, target_labels) + success_list, adv_data, _ = attack.generate(test_images, target_labels) else: - success_list, adv_data, query_list = attack.generate(test_images, None) + success_list, adv_data, _ = attack.generate(test_images, None) adv_datas = [] gts = [] @@ -136,7 +135,7 @@ def test_hsja_mnist_attack(): if success: adv_datas.append(adv) gts.append(gt) - if len(gts) > 0: + if gts: adv_datas = np.concatenate(np.asarray(adv_datas), axis=0) gts = np.asarray(gts) pred_logits_adv = model.predict(adv_datas) diff --git a/example/mnist_demo/mnist_attack_jsma.py b/example/mnist_demo/mnist_attack_jsma.py index de8b24ff0d825b84960983eb8fef6db5cf19f069..658fe32b07be9bb7cdbfc22b09f0640ca4e7ed15 100644 --- a/example/mnist_demo/mnist_attack_jsma.py +++ b/example/mnist_demo/mnist_attack_jsma.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.jsma import JSMAAttack -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -79,8 +78,8 @@ def test_jsma_attack(): predict_labels = np.concatenate(predict_labels) true_labels = np.concatenate(test_labels) targeted_labels = np.random.randint(0, 10, size=len(true_labels)) - for i in range(len(true_labels)): - if targeted_labels[i] == true_labels[i]: + for i, true_l in enumerate(true_labels): + if targeted_labels[i] == true_l: targeted_labels[i] = (targeted_labels[i] + 1) % 10 accuracy = np.mean(np.equal(predict_labels, true_labels)) LOGGER.info(TAG, "prediction accuracy before attacking is : %g", accuracy) diff --git a/example/mnist_demo/mnist_attack_lbfgs.py b/example/mnist_demo/mnist_attack_lbfgs.py index 425b105b93e4506467dab62760e5eed7172c86e6..5e4bd60ca6b05c3ea84a5c18805510af696d87c8 100644 --- a/example/mnist_demo/mnist_attack_lbfgs.py +++ b/example/mnist_demo/mnist_attack_lbfgs.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.lbfgs import LBFGS -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -85,8 +84,8 @@ def test_lbfgs_attack(): is_targeted = True if is_targeted: targeted_labels = np.random.randint(0, 10, size=len(true_labels)).astype(np.int32) - for i in range(len(true_labels)): - if targeted_labels[i] == true_labels[i]: + for i, true_l in enumerate(true_labels): + if targeted_labels[i] == true_l: targeted_labels[i] = (targeted_labels[i] + 1) % 10 else: targeted_labels = true_labels.astype(np.int32) diff --git a/example/mnist_demo/mnist_attack_mdi2fgsm.py b/example/mnist_demo/mnist_attack_mdi2fgsm.py index eb983b5cd885fe9c4665fd0f61251e54169c6b53..23c197be56ede4660b2c54e061f32b181e26dbe9 100644 --- a/example/mnist_demo/mnist_attack_mdi2fgsm.py +++ b/example/mnist_demo/mnist_attack_mdi2fgsm.py @@ -13,21 +13,20 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net - -from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod - -from mindarmour.utils.logger import LogUtil -from mindarmour.evaluations.attack_evaluation import AttackEvaluate +from scipy.special import softmax from lenet5_net import LeNet5 +from mindarmour.attacks.iterative_gradient_method import \ + MomentumDiverseInputIterativeMethod +from mindarmour.evaluations.attack_evaluation import AttackEvaluate +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_nes.py b/example/mnist_demo/mnist_attack_nes.py index 35e322c2f380cd4d6dbf455648fb73da1468fa77..08187719bc91d9d36b3e85fb14e2194afac135df 100644 --- a/example/mnist_demo/mnist_attack_nes.py +++ b/example/mnist_demo/mnist_attack_nes.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys + import numpy as np import pytest - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.attacks.black.natural_evolutionary_strategy import NES +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel - +from mindarmour.attacks.black.natural_evolutionary_strategy import NES from mindarmour.utils.logger import LogUtil -from lenet5_net import LeNet5 sys.path.append("..") from data_processing import generate_mnist_dataset @@ -73,9 +72,9 @@ def _pseudorandom_target(index, total_indices, true_class): def create_target_images(dataset, data_labels, target_labels): res = [] for label in target_labels: - for i in range(len(data_labels)): - if data_labels[i] == label: - res.append(dataset[i]) + for data_label, data in zip(data_labels, dataset): + if data_label == label: + res.append(data) break return np.array(res) diff --git a/example/mnist_demo/mnist_attack_pgd.py b/example/mnist_demo/mnist_attack_pgd.py index e084aca2336586f42b606f44e7e2944b3c693fec..d433c037237f06eba42c0a952e91535f1fa9d1c5 100644 --- a/example/mnist_demo/mnist_attack_pgd.py +++ b/example/mnist_demo/mnist_attack_pgd.py @@ -13,21 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks.iterative_gradient_method import ProjectedGradientDescent - -from mindarmour.utils.logger import LogUtil from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_pointwise.py b/example/mnist_demo/mnist_attack_pointwise.py index 5ac33e02edb81eea26de9826bd6759d4743995d5..53bece667ac3dc9939925988f7e70d525cc72888 100644 --- a/example/mnist_demo/mnist_attack_pointwise.py +++ b/example/mnist_demo/mnist_attack_pointwise.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax -from mindarmour.attacks.black.pointwise_attack import PointWiseAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel -from mindarmour.utils.logger import LogUtil +from mindarmour.attacks.black.pointwise_attack import PointWiseAttack from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -99,8 +98,8 @@ def test_pointwise_attack_on_mnist(): attack = PointWiseAttack(model=model, is_targeted=is_target) if is_target: targeted_labels = np.random.randint(0, 10, size=len(true_labels)) - for i in range(len(true_labels)): - if targeted_labels[i] == true_labels[i]: + for i, true_l in enumerate(true_labels): + if targeted_labels[i] == true_l: targeted_labels[i] = (targeted_labels[i] + 1) % 10 else: targeted_labels = true_labels diff --git a/example/mnist_demo/mnist_attack_pso.py b/example/mnist_demo/mnist_attack_pso.py index 19c4213c76af9b9ff80402c20f285e7da536fd95..bd0f72ff383ac714ed2b4d802abd6498cedd2d99 100644 --- a/example/mnist_demo/mnist_attack_pso.py +++ b/example/mnist_demo/mnist_attack_pso.py @@ -13,20 +13,19 @@ # limitations under the License. import sys import time + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax -from mindarmour.attacks.black.pso_attack import PSOAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel -from mindarmour.utils.logger import LogUtil +from mindarmour.attacks.black.pso_attack import PSOAttack from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_attack_salt_and_pepper.py b/example/mnist_demo/mnist_attack_salt_and_pepper.py index 441ebe4f93a3584c13238d5a2516fe3c305647d5..635f54bfafc51b340e8cd49d433baf4ef2ec71e8 100644 --- a/example/mnist_demo/mnist_attack_salt_and_pepper.py +++ b/example/mnist_demo/mnist_attack_salt_and_pepper.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax -from mindarmour.attacks.black.salt_and_pepper_attack import SaltAndPepperNoiseAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel -from mindarmour.utils.logger import LogUtil +from mindarmour.attacks.black.salt_and_pepper_attack import SaltAndPepperNoiseAttack from mindarmour.evaluations.attack_evaluation import AttackEvaluate - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -102,8 +101,8 @@ def test_salt_and_pepper_attack_on_mnist(): sparse=True) if is_target: targeted_labels = np.random.randint(0, 10, size=len(true_labels)) - for i in range(len(true_labels)): - if targeted_labels[i] == true_labels[i]: + for i, true_l in enumerate(true_labels): + if targeted_labels[i] == true_l: targeted_labels[i] = (targeted_labels[i] + 1) % 10 else: targeted_labels = true_labels diff --git a/example/mnist_demo/mnist_defense_nad.py b/example/mnist_demo/mnist_defense_nad.py index e9e04d390743da9ca7063a0343098f86e0e38109..06a1391aae77ff6a30597a98bf8ab84bb5ec7034 100644 --- a/example/mnist_demo/mnist_defense_nad.py +++ b/example/mnist_demo/mnist_defense_nad.py @@ -12,25 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. """defense example using nad""" -import sys - import logging +import sys import numpy as np import pytest - from mindspore import Tensor from mindspore import context from mindspore import nn from mindspore.nn import SoftmaxCrossEntropyWithLogits from mindspore.train.serialization import load_checkpoint, load_param_into_net +from lenet5_net import LeNet5 from mindarmour.attacks import FastGradientSignMethod from mindarmour.defenses import NaturalAdversarialDefense from mindarmour.utils.logger import LogUtil -from lenet5_net import LeNet5 - sys.path.append("..") from data_processing import generate_mnist_dataset diff --git a/example/mnist_demo/mnist_evaluation.py b/example/mnist_demo/mnist_evaluation.py index 35871f604d22c9cefc91f82806713e529b8d8882..4451e4ebd0072f619befbb396e62654779e2abcd 100644 --- a/example/mnist_demo/mnist_evaluation.py +++ b/example/mnist_demo/mnist_evaluation.py @@ -12,30 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. """evaluate example""" -import sys import os +import sys import time -import numpy as np -from scipy.special import softmax -from lenet5_net import LeNet5 +import numpy as np from mindspore import Model from mindspore import Tensor from mindspore import context from mindspore import nn from mindspore.nn import Cell -from mindspore.ops.operations import TensorAdd from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore.ops.operations import TensorAdd from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax +from lenet5_net import LeNet5 from mindarmour.attacks import FastGradientSignMethod from mindarmour.attacks import GeneticAttack from mindarmour.attacks.black.black_model import BlackModel from mindarmour.defenses import NaturalAdversarialDefense +from mindarmour.detectors.black.similarity_detector import SimilarityDetector from mindarmour.evaluations import BlackDefenseEvaluate from mindarmour.evaluations import DefenseEvaluate from mindarmour.utils.logger import LogUtil -from mindarmour.detectors.black.similarity_detector import SimilarityDetector sys.path.append("..") from data_processing import generate_mnist_dataset @@ -237,7 +237,7 @@ def test_black_defense(): # gen black-box adversarial examples of test data for idx in range(attacked_size): raw_st = time.time() - raw_sl, raw_a, raw_qc = attack_rm.generate( + _, raw_a, raw_qc = attack_rm.generate( np.expand_dims(attacked_sample[idx], axis=0), np.expand_dims(attack_target_label[idx], axis=0)) raw_t = time.time() - raw_st @@ -271,7 +271,7 @@ def test_black_defense(): sparse=False) for idx in range(attacked_size): def_st = time.time() - def_sl, def_a, def_qc = attack_dm.generate( + _, def_a, def_qc = attack_dm.generate( np.expand_dims(attacked_sample[idx], axis=0), np.expand_dims(attack_target_label[idx], axis=0)) def_t = time.time() - def_st diff --git a/example/mnist_demo/mnist_similarity_detector.py b/example/mnist_demo/mnist_similarity_detector.py index da438a702f90463bcb9a77f7f503541557c4e461..2a59ed2cc521dd701695766b846d91a19803c2c1 100644 --- a/example/mnist_demo/mnist_similarity_detector.py +++ b/example/mnist_demo/mnist_similarity_detector.py @@ -12,23 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import sys + import numpy as np import pytest -from scipy.special import softmax - from mindspore import Model -from mindspore import context from mindspore import Tensor +from mindspore import context from mindspore.nn import Cell from mindspore.ops.operations import TensorAdd from mindspore.train.serialization import load_checkpoint, load_param_into_net +from scipy.special import softmax -from mindarmour.utils.logger import LogUtil -from mindarmour.attacks.black.pso_attack import PSOAttack +from lenet5_net import LeNet5 from mindarmour.attacks.black.black_model import BlackModel +from mindarmour.attacks.black.pso_attack import PSOAttack from mindarmour.detectors.black.similarity_detector import SimilarityDetector - -from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") diff --git a/example/mnist_demo/mnist_train.py b/example/mnist_demo/mnist_train.py index eeaba3f80a3b5f945d8fdbb8f21e088509c59844..9f1721300565c4655e2ed247cc3af3fd89f0473e 100644 --- a/example/mnist_demo/mnist_train.py +++ b/example/mnist_demo/mnist_train.py @@ -11,20 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================ import os import sys import mindspore.nn as nn -from mindspore import context, Tensor +from mindspore import context +from mindspore.nn.metrics import Accuracy +from mindspore.train import Model from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train import Model -from mindspore.nn.metrics import Accuracy - -from mindarmour.utils.logger import LogUtil from lenet5_net import LeNet5 +from mindarmour.utils.logger import LogUtil sys.path.append("..") from data_processing import generate_mnist_dataset diff --git a/mindarmour/attacks/gradient_method.py b/mindarmour/attacks/gradient_method.py index 66cab6f08feb219d3aedbea33f87fd77ddcd4fbf..936eaa202433fc0e3ed0ff27561624ee8982e3d9 100644 --- a/mindarmour/attacks/gradient_method.py +++ b/mindarmour/attacks/gradient_method.py @@ -183,8 +183,7 @@ class FastGradientMethod(GradientMethod): >>> grad = self._gradient([[0.2, 0.3, 0.4]], >>> [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) """ - sens = Tensor(np.array([1.0], self._dtype)) - out_grad = self._grad_all(Tensor(inputs), Tensor(labels), sens) + out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) if isinstance(out_grad, tuple): out_grad = out_grad[0] gradient = out_grad.asnumpy() @@ -286,8 +285,7 @@ class FastGradientSignMethod(GradientMethod): >>> grad = self._gradient([[0.2, 0.3, 0.4]], >>> [[0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) """ - sens = Tensor(np.array([1.0], self._dtype)) - out_grad = self._grad_all(Tensor(inputs), Tensor(labels), sens) + out_grad = self._grad_all(Tensor(inputs), Tensor(labels)) if isinstance(out_grad, tuple): out_grad = out_grad[0] gradient = out_grad.asnumpy() diff --git a/mindarmour/attacks/iterative_gradient_method.py b/mindarmour/attacks/iterative_gradient_method.py index 337fec805412c891d36fbce3a610c582585e6125..9a212b21b1d122da15acb9cf8b1fc69fe27b78ae 100644 --- a/mindarmour/attacks/iterative_gradient_method.py +++ b/mindarmour/attacks/iterative_gradient_method.py @@ -351,9 +351,8 @@ class MomentumIterativeMethod(IterativeGradientMethod): >>> grad = self._gradient([[0.5, 0.3, 0.4]], >>> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0]) """ - sens = Tensor(np.array([1.0], inputs.dtype)) # get grad of loss over x - out_grad = self._loss_grad(Tensor(inputs), Tensor(labels), sens) + out_grad = self._loss_grad(Tensor(inputs), Tensor(labels)) if isinstance(out_grad, tuple): out_grad = out_grad[0] gradient = out_grad.asnumpy() diff --git a/mindarmour/attacks/lbfgs.py b/mindarmour/attacks/lbfgs.py index a3c2ccc1527bc1ec157b01f072d0a8b2280b7cb0..790e5ea9bfaf9cca2846cccdd69527fe79e1a922 100644 --- a/mindarmour/attacks/lbfgs.py +++ b/mindarmour/attacks/lbfgs.py @@ -115,12 +115,11 @@ class LBFGS(Attack): def _gradient(self, cur_input, labels, shape): """ Return model gradient to minimize loss in l-bfgs-b.""" label_dtype = labels.dtype - sens = Tensor(np.array([1], self._dtype)) labels = np.expand_dims(labels, axis=0).astype(label_dtype) # input shape should like original shape reshape_input = np.expand_dims(cur_input.reshape(shape), axis=0) - out_grad = self._grad_all(Tensor(reshape_input), Tensor(labels), sens) + out_grad = self._grad_all(Tensor(reshape_input), Tensor(labels)) if isinstance(out_grad, tuple): out_grad = out_grad[0] return out_grad.asnumpy() @@ -131,9 +130,9 @@ class LBFGS(Attack): the cross-entropy loss. """ cur_input = cur_input.astype(self._dtype) - l2_distance = np.linalg.norm(cur_input.reshape( - (cur_input.shape[0], -1)) - start_input.reshape( - (start_input.shape[0], -1))) + l2_distance = np.linalg.norm( + cur_input.reshape((cur_input.shape[0], -1)) - start_input.reshape( + (start_input.shape[0], -1))) logits = self._forward_one(cur_input.reshape(shape)).flatten() logits = logits - np.max(logits) if self._sparse: diff --git a/mindarmour/fuzzing/fuzzing.py b/mindarmour/fuzzing/fuzzing.py index e0e21aa3bffd6a00657abefed4e0710c33d686e8..21f4b3c19631f8af0a364f8c0284e7e60540473e 100644 --- a/mindarmour/fuzzing/fuzzing.py +++ b/mindarmour/fuzzing/fuzzing.py @@ -14,17 +14,17 @@ """ Fuzzing. """ -import numpy as np from random import choice -from mindspore import Tensor +import numpy as np from mindspore import Model +from mindspore import Tensor from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics -from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ - Translate, Scale, Shear, Rotate from mindarmour.utils._check_param import check_model, check_numpy_param, \ check_int_positive +from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \ + Translate, Scale, Shear, Rotate class Fuzzing: @@ -40,9 +40,10 @@ class Fuzzing: target_model (Model): Target fuzz model. train_dataset (numpy.ndarray): Training dataset used for determine the neurons' output boundaries. - const_K (int): The number of mutate tests for a seed. + const_k (int): The number of mutate tests for a seed. mode (str): Image mode used in image transform, 'L' means grey graph. Default: 'L'. + max_seed_num (int): The initial seeds max value. Default: 1000 """ def __init__(self, initial_seeds, target_model, train_dataset, const_K, @@ -50,7 +51,7 @@ class Fuzzing: self.initial_seeds = initial_seeds self.target_model = check_model('model', target_model, Model) self.train_dataset = check_numpy_param('train_dataset', train_dataset) - self.K = check_int_positive('const_k', const_K) + self.const_k = check_int_positive('const_k', const_K) self.mode = mode self.max_seed_num = check_int_positive('max_seed_num', max_seed_num) self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10, @@ -73,7 +74,7 @@ class Fuzzing: 'Noise': Noise, 'Translate': Translate, 'Scale': Scale, 'Shear': Shear, 'Rotate': Rotate} - for _ in range(self.K): + for _ in range(self.const_k): for _ in range(try_num): if (info[0] == info[1]).all(): trans_strage = self._random_pick_mutate(affine_trans, @@ -91,7 +92,7 @@ class Fuzzing: if trans_strage in affine_trans: info[1] = mutate_test mutate_tests.append(mutate_test) - if len(mutate_tests) == 0: + if not mutate_tests: mutate_tests.append(seed) return np.array(mutate_tests) @@ -109,7 +110,7 @@ class Fuzzing: seed = self._select_next() failed_tests = [] seed_num = 0 - while len(seed) > 0 and seed_num < self.max_seed_num: + while seed and seed_num < self.max_seed_num: mutate_tests = self._metamorphic_mutate(seed[0]) coverages, results = self._run(mutate_tests, coverage_metric) coverage_gains = self._coverage_gains(coverages) @@ -157,13 +158,13 @@ class Fuzzing: beta = 0.2 diff = np.array(seed - mutate_test).flatten() size = np.shape(diff)[0] - L0 = np.linalg.norm(diff, ord=0) - Linf = np.linalg.norm(diff, ord=np.inf) - if L0 > alpha*size: - if Linf < 256: + l0 = np.linalg.norm(diff, ord=0) + linf = np.linalg.norm(diff, ord=np.inf) + if l0 > alpha*size: + if linf < 256: is_valid = True else: - if Linf < beta*255: + if linf < beta*255: is_valid = True return is_valid diff --git a/mindarmour/utils/util.py b/mindarmour/utils/util.py index 094177b4f795888341c4e9cc036e5004b1dfb1dd..906b32ccbbd1d336c8b2b3089a47fe0a74c6147b 100644 --- a/mindarmour/utils/util.py +++ b/mindarmour/utils/util.py @@ -13,7 +13,6 @@ # limitations under the License. """ Util for MindArmour. """ import numpy as np - from mindspore import Tensor from mindspore.nn import Cell from mindspore.ops.composite import GradOperation @@ -99,23 +98,21 @@ class GradWrapWithLoss(Cell): super(GradWrapWithLoss, self).__init__() self._grad_all = GradOperation(name="get_all", get_all=True, - sens_param=True) + sens_param=False) self._network = network - def construct(self, inputs, labels, weight): + def construct(self, inputs, labels): """ Compute gradient of `inputs` with labels and weight. Args: inputs (Tensor): Inputs of network. labels (Tensor): Labels of inputs. - weight (Tensor): Weight of each gradient, `weight` has the same - shape with labels. Returns: Tensor, gradient matrix. """ - gout = self._grad_all(self._network)(inputs, labels, weight) + gout = self._grad_all(self._network)(inputs, labels) return gout[0] diff --git a/tests/st/resnet50/resnet_cifar10.py b/tests/st/resnet50/resnet_cifar10.py index 080cfc13df4ce4bab6e33c667f3f5b9e1c1d114f..43282901f5a5c3ccf30ec2aacbe28b5b2904ca1a 100644 --- a/tests/st/resnet50/resnet_cifar10.py +++ b/tests/st/resnet50/resnet_cifar10.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np -import math - from mindspore import nn -from mindspore.ops import operations as P from mindspore.common.tensor import Tensor -from mindspore import context +from mindspore.ops import operations as P def variance_scaling_raw(shape): @@ -110,8 +107,7 @@ class ResidualBlock(nn.Cell): def __init__(self, in_channels, out_channels, - stride=1, - down_sample=False): + stride=1): super(ResidualBlock, self).__init__() out_chls = out_channels // self.expansion @@ -168,7 +164,7 @@ class ResidualBlockWithDown(nn.Cell): self.bn3 = bn_with_initialize_last(out_channels) self.relu = P.ReLU() - self.downSample = down_sample + self.downsample = down_sample self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0) self.bn_down_sample = bn_with_initialize(out_channels) diff --git a/tests/st/resnet50/test_cifar10_attack_fgsm.py b/tests/st/resnet50/test_cifar10_attack_fgsm.py index 51e741bdc9d3e063d98c0b7b2b8c6e22402383b5..6591faaa1521aa7eb8ddbd90f537a21e0168f15c 100644 --- a/tests/st/resnet50/test_cifar10_attack_fgsm.py +++ b/tests/st/resnet50/test_cifar10_attack_fgsm.py @@ -18,7 +18,6 @@ Fuction: Usage: py.test test_cifar10_attack_fgsm.py """ -import os import numpy as np import pytest diff --git a/tests/ut/python/attacks/black/test_genetic_attack.py b/tests/ut/python/attacks/black/test_genetic_attack.py index 8ae7fb7de9bbf5196a280240308ad032b9b34d16..85fe57e97dcabd1f5a58c59dfc925d7f742d9e8e 100644 --- a/tests/ut/python/attacks/black/test_genetic_attack.py +++ b/tests/ut/python/attacks/black/test_genetic_attack.py @@ -16,15 +16,13 @@ Genetic-Attack test. """ import numpy as np import pytest - import mindspore.ops.operations as M from mindspore import Tensor -from mindspore.nn import Cell from mindspore import context +from mindspore.nn import Cell -from mindarmour.attacks.black.genetic_attack import GeneticAttack from mindarmour.attacks.black.black_model import BlackModel - +from mindarmour.attacks.black.genetic_attack import GeneticAttack context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -115,7 +113,7 @@ def test_supplement(): adaptive=True, sparse=False) # raise error - _, adv_data, _ = attack.generate(inputs, labels) + _, _, _ = attack.generate(inputs, labels) @pytest.mark.level0 @@ -140,5 +138,5 @@ def test_value_error(): adaptive=True, sparse=False) # raise error - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError): assert attack.generate(inputs, labels) diff --git a/tests/ut/python/attacks/black/test_hsja.py b/tests/ut/python/attacks/black/test_hsja.py index c67354ac85d8ff921ed8c03c9e19d5ec94dc2c91..9bb42fe9df3dbb9468efcaeafc87139b98baaac2 100644 --- a/tests/ut/python/attacks/black/test_hsja.py +++ b/tests/ut/python/attacks/black/test_hsja.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sys import os +import sys + import numpy as np import pytest - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.attacks.black.hop_skip_jump_attack import HopSkipJumpAttack from mindarmour.attacks.black.black_model import BlackModel - +from mindarmour.attacks.black.hop_skip_jump_attack import HopSkipJumpAttack from mindarmour.utils.logger import LogUtil + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../../")) from example.mnist_demo.lenet5_net import LeNet5 @@ -135,7 +135,7 @@ def test_hsja_mnist_attack(): attack.set_target_images(target_images) success_list, adv_data, _ = attack.generate(test_images, target_labels) else: - success_list, adv_data, query_list = attack.generate(test_images, None) + success_list, adv_data, _ = attack.generate(test_images, None) assert (adv_data != test_images).any() adv_datas = [] @@ -144,7 +144,7 @@ def test_hsja_mnist_attack(): if success: adv_datas.append(adv) gts.append(gt) - if len(gts) > 0: + if gts: adv_datas = np.concatenate(np.asarray(adv_datas), axis=0) gts = np.asarray(gts) pred_logits_adv = model.predict(adv_datas) @@ -162,5 +162,5 @@ def test_hsja_mnist_attack(): def test_value_error(): model = get_model() norm = 'l2' - with pytest.raises(ValueError) as e: + with pytest.raises(ValueError): assert HopSkipJumpAttack(model, constraint=norm, stepsize_search='bad-search') diff --git a/tests/ut/python/attacks/black/test_nes.py b/tests/ut/python/attacks/black/test_nes.py index 33f0f3d794f7a39fc2a7de69d285d943a7a306e0..eca3c646003ea82daa860a23c7458483fc87dc5b 100644 --- a/tests/ut/python/attacks/black/test_nes.py +++ b/tests/ut/python/attacks/black/test_nes.py @@ -11,19 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import sys + import numpy as np -import os import pytest - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindarmour.attacks.black.natural_evolutionary_strategy import NES from mindarmour.attacks.black.black_model import BlackModel - +from mindarmour.attacks.black.natural_evolutionary_strategy import NES from mindarmour.utils.logger import LogUtil + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../../")) from example.mnist_demo.lenet5_net import LeNet5 @@ -156,7 +156,7 @@ def nes_mnist_attack(scene, top_k): assert (advs != test_images[:batch_num]).any() adv_pred = np.argmax(model.predict(advs), axis=1) - adv_accuracy = np.mean(np.equal(adv_pred, true_labels[:test_length])) + _ = np.mean(np.equal(adv_pred, true_labels[:test_length])) @pytest.mark.level0 diff --git a/tests/ut/python/attacks/black/test_pointwise_attack.py b/tests/ut/python/attacks/black/test_pointwise_attack.py index 7acd0f4f12bfbda0c42e7a7a519a29c458f02a7b..29ddbd5d583440856d9279d7be610e75f2faa7eb 100644 --- a/tests/ut/python/attacks/black/test_pointwise_attack.py +++ b/tests/ut/python/attacks/black/test_pointwise_attack.py @@ -14,19 +14,18 @@ """ PointWise Attack test """ -import sys import os +import sys + import numpy as np import pytest - - from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindarmour.attacks.black.black_model import BlackModel from mindarmour.attacks.black.pointwise_attack import PointWiseAttack from mindarmour.utils.logger import LogUtil -from mindarmour.attacks.black.black_model import BlackModel sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../../")) @@ -75,13 +74,13 @@ def test_pointwise_attack_method(): input_np = np.load(os.path.join(current_dir, '../../test_data/test_images.npy'))[:3] labels = np.load(os.path.join(current_dir, - '../../test_data/test_labels.npy'))[:3] + '../../test_data/test_labels.npy'))[:3] model = ModelToBeAttacked(net) pre_label = np.argmax(model.predict(input_np), axis=1) LOGGER.info(TAG, 'original sample predict labels are :{}'.format(pre_label)) LOGGER.info(TAG, 'true labels are: {}'.format(labels)) attack = PointWiseAttack(model, sparse=True, is_targeted=False) - is_adv, adv_data, query_times = attack.generate(input_np, pre_label) + is_adv, adv_data, _ = attack.generate(input_np, pre_label) LOGGER.info(TAG, 'adv sample predict labels are: {}' .format(np.argmax(model.predict(adv_data), axis=1))) diff --git a/tests/ut/python/attacks/test_gradient_method.py b/tests/ut/python/attacks/test_gradient_method.py index bab0a0474392e9881be24dc8f744b2a17d771e73..65c748f2cb711019fb5cc9a406d2999a058af077 100644 --- a/tests/ut/python/attacks/test_gradient_method.py +++ b/tests/ut/python/attacks/test_gradient_method.py @@ -233,10 +233,6 @@ def test_assert_error(): """ Random least likely class method unit test. """ - input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) - label = np.asarray([2], np.int32) - label = np.eye(3)[label].astype(np.float32) - with pytest.raises(ValueError) as e: assert RandomLeastLikelyClassMethod(Net(), eps=0.05, alpha=0.21) assert str(e.value) == 'eps must be larger than alpha!' diff --git a/tests/ut/python/attacks/test_iterative_gradient_method.py b/tests/ut/python/attacks/test_iterative_gradient_method.py index 9a766e222e62141bbb8f4924615cc624df95aa9e..3a9fcb024a94634c08770c2be0ce0c0b524f113c 100644 --- a/tests/ut/python/attacks/test_iterative_gradient_method.py +++ b/tests/ut/python/attacks/test_iterative_gradient_method.py @@ -134,10 +134,9 @@ def test_diverse_input_iterative_method(): label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) - for i in range(5): - attack = DiverseInputIterativeMethod(Net()) - ms_adv_x = attack.generate(input_np, label) - assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ + attack = DiverseInputIterativeMethod(Net()) + ms_adv_x = attack.generate(input_np, label) + assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ ' value must not be equal to' \ ' original value.' @@ -155,10 +154,9 @@ def test_momentum_diverse_input_iterative_method(): label = np.asarray([2], np.int32) label = np.eye(3)[label].astype(np.float32) - for i in range(5): - attack = MomentumDiverseInputIterativeMethod(Net()) - ms_adv_x = attack.generate(input_np, label) - assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ + attack = MomentumDiverseInputIterativeMethod(Net()) + ms_adv_x = attack.generate(input_np, label) + assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ 'generate value must not be equal to' \ ' original value.' diff --git a/tests/ut/python/attacks/test_lbfgs.py b/tests/ut/python/attacks/test_lbfgs.py index 649ea1f1e8a23484f2463e1223232b4988d7f758..d1387e2a34903c000b634d0ef59e49dccad6b0ec 100644 --- a/tests/ut/python/attacks/test_lbfgs.py +++ b/tests/ut/python/attacks/test_lbfgs.py @@ -14,11 +14,11 @@ """ LBFGS-Attack test. """ +import os import sys + import numpy as np import pytest -import os - from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net @@ -69,4 +69,4 @@ def test_lbfgs_attack(): attack = LBFGS(net, is_targeted=True) LOGGER.debug(TAG, 'target_np is :{}'.format(target_np[0])) - adv_data = attack.generate(input_np, target_np) + _ = attack.generate(input_np, target_np) diff --git a/tests/ut/python/defenses/mock_net.py b/tests/ut/python/defenses/mock_net.py index 663b5a0a4d2c95955c8c842e919c1b12fdd8d71e..d9ad42d1615d39342ea05fee69907373c43e1ae6 100644 --- a/tests/ut/python/defenses/mock_net.py +++ b/tests/ut/python/defenses/mock_net.py @@ -18,10 +18,8 @@ import numpy as np from mindspore import nn from mindspore import Tensor -from mindspore.nn import Cell from mindspore.nn import WithLossCell, TrainOneStepCell from mindspore.nn.optim.momentum import Momentum -from mindspore.ops import operations as P from mindspore import context from mindspore.common.initializer import TruncatedNormal @@ -58,7 +56,7 @@ class Net(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() def construct(self, x): x = self.conv1(x) @@ -67,7 +65,7 @@ class Net(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (-1, 16*5*5)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) @@ -75,6 +73,7 @@ class Net(nn.Cell): x = self.fc3(x) return x + if __name__ == '__main__': num_classes = 10 batch_size = 32 @@ -104,4 +103,3 @@ if __name__ == '__main__': train_net.set_train() train_net(Tensor(inputs_np), Tensor(labels_np)) - diff --git a/tests/ut/python/defenses/test_ad.py b/tests/ut/python/defenses/test_ad.py index d90c853a74f8dc2de9c5c64cd255da264388a450..78b581ad4c0f9722c5ca2fbf1a6087952df5d78c 100644 --- a/tests/ut/python/defenses/test_ad.py +++ b/tests/ut/python/defenses/test_ad.py @@ -14,20 +14,19 @@ """ Adversarial defense test. """ -import numpy as np -import pytest import logging -from mindspore import nn +import numpy as np +import pytest from mindspore import Tensor from mindspore import context +from mindspore import nn from mindspore.nn.optim.momentum import Momentum +from mock_net import Net from mindarmour.defenses.adversarial_defense import AdversarialDefense from mindarmour.utils.logger import LogUtil -from mock_net import Net - LOGGER = LogUtil.get_instance() TAG = 'Ad_Test' diff --git a/tests/ut/python/defenses/test_ead.py b/tests/ut/python/defenses/test_ead.py index 3001e24219f151b7e506622a34ae7df0389535d2..9eeac213f1a9ff3c151486b120739a3aa4c371a2 100644 --- a/tests/ut/python/defenses/test_ead.py +++ b/tests/ut/python/defenses/test_ead.py @@ -14,22 +14,21 @@ """ ensemble adversarial defense test. """ -import numpy as np -import pytest import logging -from mindspore import nn +import numpy as np +import pytest from mindspore import context +from mindspore import nn from mindspore.nn.optim.momentum import Momentum +from mock_net import Net from mindarmour.attacks.gradient_method import FastGradientSignMethod from mindarmour.attacks.iterative_gradient_method import \ ProjectedGradientDescent from mindarmour.defenses.adversarial_defense import EnsembleAdversarialDefense from mindarmour.utils.logger import LogUtil -from mock_net import Net - LOGGER = LogUtil.get_instance() TAG = 'Ead_Test' @@ -54,7 +53,7 @@ def test_ead(): if not sparse: labels = np.eye(num_classes)[labels].astype(np.float32) - net = Net() + net = SimpleNet() loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse) optimizer = Momentum(net.trainable_params(), 0.001, 0.9) diff --git a/tests/ut/python/defenses/test_nad.py b/tests/ut/python/defenses/test_nad.py index a5e06a80375c5abcd88f0854e4fdd16bc7590b33..938a4e53ef76460da57ed18357ab0f296ca5465d 100644 --- a/tests/ut/python/defenses/test_nad.py +++ b/tests/ut/python/defenses/test_nad.py @@ -14,20 +14,19 @@ """ Natural adversarial defense test. """ -import numpy as np -import pytest import logging -from mindspore import nn +import numpy as np +import pytest from mindspore import context +from mindspore import nn from mindspore.nn.optim.momentum import Momentum +from mock_net import Net from mindarmour.defenses.natural_adversarial_defense import \ NaturalAdversarialDefense from mindarmour.utils.logger import LogUtil -from mock_net import Net - LOGGER = LogUtil.get_instance() TAG = 'Nad_Test' diff --git a/tests/ut/python/defenses/test_pad.py b/tests/ut/python/defenses/test_pad.py index f4ee0ad8331e7af0f8c6382e3c76129859aa7c89..79080954e2004582d3abbc4de05a414bd6d6ab32 100644 --- a/tests/ut/python/defenses/test_pad.py +++ b/tests/ut/python/defenses/test_pad.py @@ -14,20 +14,19 @@ """ Projected adversarial defense test. """ -import numpy as np -import pytest import logging -from mindspore import nn +import numpy as np +import pytest from mindspore import context +from mindspore import nn from mindspore.nn.optim.momentum import Momentum +from mock_net import Net from mindarmour.defenses.projected_adversarial_defense import \ ProjectedAdversarialDefense from mindarmour.utils.logger import LogUtil -from mock_net import Net - LOGGER = LogUtil.get_instance() TAG = 'Pad_Test' diff --git a/tests/ut/python/detectors/black/test_similarity_detector.py b/tests/ut/python/detectors/black/test_similarity_detector.py index 255a58bd0d987be2f377e482995bfd666cdaa181..284d4f8e348eb52bfec9ddeca3f51f85e8688294 100644 --- a/tests/ut/python/detectors/black/test_similarity_detector.py +++ b/tests/ut/python/detectors/black/test_similarity_detector.py @@ -98,4 +98,3 @@ def test_similarity_detector(): 1561, 1612, 1663, 1714, 1765, 1816, 1867, 1918, 1969] assert np.all(detector.get_detected_queries() == expected_value) - diff --git a/tests/ut/python/detectors/test_spatial_smoothing.py b/tests/ut/python/detectors/test_spatial_smoothing.py index fe7669c7a6cc21f18a0fd83329780d6f3ec2222d..4ed8b8830dcc7f9710687166cd90e2e3e640a245 100644 --- a/tests/ut/python/detectors/test_spatial_smoothing.py +++ b/tests/ut/python/detectors/test_spatial_smoothing.py @@ -111,6 +111,3 @@ def test_spatial_smoothing_diff(): 0.38254014, 0.543059, 0.06452079, 0.36902517, 1.1845329, 0.3870097]) assert np.allclose(diffs, expected_value, 0.0001, 0.0001) - - - diff --git a/tests/ut/python/evaluations/black/test_black_defense_eval.py b/tests/ut/python/evaluations/black/test_black_defense_eval.py index 4cbd58640c0e635974bdf0e89b754095bbcb5214..3cb925d2e01b25573c23f8ea9dcf489be1eb32ee 100644 --- a/tests/ut/python/evaluations/black/test_black_defense_eval.py +++ b/tests/ut/python/evaluations/black/test_black_defense_eval.py @@ -53,14 +53,14 @@ def test_def_eval(): # create obj def_eval = BlackDefenseEvaluate(raw_preds, - def_preds, - raw_query_counts, - def_query_counts, - raw_query_time, - def_query_time, - def_detection_counts, - true_labels, - max_queries=100) + def_preds, + raw_query_counts, + def_query_counts, + raw_query_time, + def_query_time, + def_detection_counts, + true_labels, + max_queries=100) # run eval qcv = def_eval.qcv() asv = def_eval.asv() diff --git a/tests/ut/python/evaluations/test_attack_eval.py b/tests/ut/python/evaluations/test_attack_eval.py index daee550abae4e4d9b9d1a4c8ca6f34754be1a4bb..645f5d8b80005ecbba53558ccb2baa271df6f3ac 100644 --- a/tests/ut/python/evaluations/test_attack_eval.py +++ b/tests/ut/python/evaluations/test_attack_eval.py @@ -30,8 +30,8 @@ def test_attack_eval(): np.random.seed(1024) inputs = np.random.normal(size=(3, 512, 512, 3)) labels = np.array([[0.1, 0.1, 0.2, 0.6], - [0.1, 0.7, 0.0, 0.2], - [0.8, 0.1, 0.0, 0.1]]) + [0.1, 0.7, 0.0, 0.2], + [0.8, 0.1, 0.0, 0.1]]) adv_x = inputs + np.ones((3, 512, 512, 3))*0.001 adv_y = np.array([[0.1, 0.1, 0.2, 0.6], [0.1, 0.0, 0.8, 0.1], @@ -63,8 +63,8 @@ def test_value_error(): np.random.seed(1024) inputs = np.random.normal(size=(3, 512, 512, 3)) labels = np.array([[0.1, 0.1, 0.2, 0.6], - [0.1, 0.7, 0.0, 0.2], - [0.8, 0.1, 0.0, 0.1]]) + [0.1, 0.7, 0.0, 0.2], + [0.8, 0.1, 0.0, 0.1]]) adv_x = inputs + np.ones((3, 512, 512, 3))*0.001 adv_y = np.array([[0.1, 0.1, 0.2, 0.6], [0.1, 0.0, 0.8, 0.1], @@ -81,7 +81,7 @@ def test_value_error(): @pytest.mark.platform_x86_ascend_training @pytest.mark.env_card @pytest.mark.component_mindarmour -def test_value_error(): +def test_empty_input_error(): # prepare test data np.random.seed(1024) inputs = np.array([]) diff --git a/tests/ut/python/evaluations/test_radar_metric.py b/tests/ut/python/evaluations/test_radar_metric.py index f93ef577ee918372348b672296f7706e5cd6205b..5324a22dad2fe621bdee6d361aeee4914a8d7945 100644 --- a/tests/ut/python/evaluations/test_radar_metric.py +++ b/tests/ut/python/evaluations/test_radar_metric.py @@ -30,7 +30,7 @@ def test_radar_metric(): metrics_labels = ['before', 'after'] # create obj - rm = RadarMetric(metrics_name, metrics_data, metrics_labels, title='', + _ = RadarMetric(metrics_name, metrics_data, metrics_labels, title='', scale='sparse') @@ -54,4 +54,3 @@ def test_value_error(): with pytest.raises(ValueError): assert RadarMetric(['MR', 'ACAC', 'ASS'], metrics_data, metrics_labels, title='', scale='bad_s') - diff --git a/tests/ut/python/fuzzing/test_coverage_metrics.py b/tests/ut/python/fuzzing/test_coverage_metrics.py index dd98507b1dd41dd4ceccbec9850488dd271f1459..158565ae45d156fbc2c24334d90f1d3c94a285a5 100644 --- a/tests/ut/python/fuzzing/test_coverage_metrics.py +++ b/tests/ut/python/fuzzing/test_coverage_metrics.py @@ -125,4 +125,4 @@ def test_lenet_mnist_coverage_ascend(): bias_coefficient=0.5) LOGGER.info(TAG, 'KMNC of this test is : %s', model_fuzz_test.get_kmnc()) LOGGER.info(TAG, 'NBC of this test is : %s', model_fuzz_test.get_nbc()) - LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) \ No newline at end of file + LOGGER.info(TAG, 'SNAC of this test is : %s', model_fuzz_test.get_snac()) diff --git a/tests/ut/python/fuzzing/test_fuzzing.py b/tests/ut/python/fuzzing/test_fuzzing.py index 6ddf0aaafe0830bf2fd6963be25f7e2deabaef7c..7396f4575fecfd18d3944be2b528a35c1bcce0dc 100644 --- a/tests/ut/python/fuzzing/test_fuzzing.py +++ b/tests/ut/python/fuzzing/test_fuzzing.py @@ -16,18 +16,15 @@ Model-fuzz coverage test. """ import numpy as np import pytest -import sys - -from mindspore.train import Model -from mindspore import nn -from mindspore.ops import operations as P from mindspore import context +from mindspore import nn from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P +from mindspore.train import Model -from mindarmour.utils.logger import LogUtil -from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics from mindarmour.fuzzing.fuzzing import Fuzzing - +from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics +from mindarmour.utils.logger import LogUtil LOGGER = LogUtil.get_instance() TAG = 'Fuzzing test' @@ -116,17 +113,18 @@ def test_fuzzing_ascend(): model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, max_seed_num=10) failed_tests = model_fuzz_test.fuzzing() - model_coverage_test.test_adequacy_coverage_calculate( - np.array(failed_tests).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) + if failed_tests: + model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + else: + LOGGER.info(TAG, 'Fuzzing test identifies none failed test') @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @pytest.mark.component_mindarmour -def test_fuzzing_ascend(): +def test_fuzzing_CPU(): context.set_context(mode=context.GRAPH_MODE, device_target="CPU") # load network net = Net() @@ -155,7 +153,8 @@ def test_fuzzing_ascend(): model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5, max_seed_num=10) failed_tests = model_fuzz_test.fuzzing() - model_coverage_test.test_adequacy_coverage_calculate( - np.array(failed_tests).astype(np.float32)) - LOGGER.info(TAG, 'KMNC of this test is : %s', - model_coverage_test.get_kmnc()) + if failed_tests: + model_coverage_test.test_adequacy_coverage_calculate(np.array(failed_tests).astype(np.float32)) + LOGGER.info(TAG, 'KMNC of this test is : %s', model_coverage_test.get_kmnc()) + else: + LOGGER.info(TAG, 'Fuzzing test identifies none failed test') diff --git a/tests/ut/python/utils/test_image_transform.py b/tests/ut/python/utils/test_image_transform.py index 1b49c95feeb7a4f30e859b34fd78947403240dfe..d1fcaee40ac3e349df24044c23e7c014fc182772 100644 --- a/tests/ut/python/utils/test_image_transform.py +++ b/tests/ut/python/utils/test_image_transform.py @@ -35,7 +35,7 @@ def test_contrast(): mode = 'L' trans = Contrast(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -47,7 +47,7 @@ def test_brightness(): mode = 'L' trans = Brightness(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -61,7 +61,7 @@ def test_blur(): mode = 'L' trans = Blur(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -75,7 +75,7 @@ def test_noise(): mode = 'L' trans = Noise(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -89,7 +89,7 @@ def test_translate(): mode = 'L' trans = Translate(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -103,7 +103,7 @@ def test_shear(): mode = 'L' trans = Shear(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -117,7 +117,7 @@ def test_scale(): mode = 'L' trans = Scale(image, mode) trans.random_param() - trans_image = trans.transform() + _ = trans.transform() @pytest.mark.level0 @@ -131,6 +131,4 @@ def test_rotate(): mode = 'L' trans = Rotate(image, mode) trans.random_param() - trans_image = trans.transform() - - + _ = trans.transform()