mnist_similarity_detector.py 5.9 KB
Newer Older
Z
zheng-huanhuan committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import numpy as np
import pytest
from scipy.special import softmax

from mindspore import Model
from mindspore import context
from mindspore import Tensor
from mindspore.nn import Cell
from mindspore.ops.operations import TensorAdd
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from mindarmour.utils.logger import LogUtil
from mindarmour.attacks.black.pso_attack import PSOAttack
from mindarmour.attacks.black.black_model import BlackModel
from mindarmour.detectors.black.similarity_detector import SimilarityDetector

from lenet5_net import LeNet5

# Select graph-mode execution on an Ascend device. This runs at import time,
# i.e. before the test function below builds any network.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Make modules from the parent directory importable (presumably where
# `data_processing` lives). NOTE(review): ".." is resolved against the current
# working directory, not this file's location, so this only works when the
# script is launched from its own directory -- which the relative checkpoint
# and dataset paths in the test already require.
sys.path.append("..")
from data_processing import generate_mnist_dataset

# Shared MindArmour logger and the tag attached to every message from this file.
LOGGER = LogUtil.get_instance()
TAG = 'Similarity Detector test'


class ModelToBeAttacked(BlackModel):
    """
    Black-box wrapper around the network under attack.

    Every sample passed to `predict` is recorded as a float32 array, in the
    order received, so the full query sequence can later be handed to a
    query-based detector via `get_queries`.
    """

    def __init__(self, network):
        super(ModelToBeAttacked, self).__init__()
        self._network = network
        # float32 copies of every queried sample, in call order.
        self._queries = []

    def predict(self, inputs):
        """
        Run the wrapped network on a batch and record each sample as a query.

        Args:
            inputs (numpy.ndarray): batch of samples; the first axis is
                treated as the batch axis.

        Returns:
            numpy.ndarray, the network output for `inputs`.
        """
        # Convert once and reuse it for both recording and inference.
        # `astype` returns a new array, so the recorded queries never alias
        # the caller's buffer.
        batch = inputs.astype(np.float32)
        # Iterating an ndarray yields its entries along the first axis, i.e.
        # one recorded query per sample, the same as indexing inputs[i].
        self._queries.extend(batch)
        result = self._network(Tensor(batch))
        return result.asnumpy()

    def get_queries(self):
        """
        Get every sample queried through `predict` so far.

        Returns:
            list[numpy.ndarray], the recorded queries in call order. This is
            the live internal list, not a copy.
        """
        return self._queries


class EncoderNet(Cell):
    """
    Stand-in similarity encoder for input data.

    NOTE(review): no real encoding happens here. `construct` simply returns
    `inputs + inputs`, so the output has the same shape as the input.
    `encode_dim` is only stored and reported through `get_encode_dim`; it
    does not affect the output.
    """

    def __init__(self, encode_dim):
        """
        Args:
            encode_dim (int): value reported by `get_encode_dim`.
        """
        super(EncoderNet, self).__init__()
        self._encode_dim = encode_dim
        self.add = TensorAdd()

    def construct(self, inputs):
        """
        Construct the neural network: add `inputs` to itself.

        Args:
            inputs (Tensor): input data to neural network.

        Returns:
            Tensor, `inputs + inputs` (element-wise doubling of the input).
        """
        return self.add(inputs, inputs)

    def get_encode_dim(self):
        """
        Get the dimension of encoded inputs.

        Returns:
            int, the `encode_dim` passed to the constructor.
        """
        return self._encode_dim


@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_similarity_detector():
    """
    Similarity Detector test.

    Fits a SimilarityDetector on benign MNIST images, then logs how many
    held-out benign images it flags (false positives) and how many of the
    queries issued by a PSO black-box attack it flags.
    """
    # load trained network
    ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_name)
    load_param_into_net(net, load_dict)

    # get mnist data
    data_list = "./MNIST_unzip/test"
    batch_size = 1000
    ds = generate_mnist_dataset(data_list, batch_size=batch_size)
    model = ModelToBeAttacked(net)

    batch_num = 10  # the number of batches of input samples
    all_images = []
    true_labels = []
    predict_labels = []
    for i, data in enumerate(ds.create_tuple_iterator(), start=1):
        images = data[0].astype(np.float32)
        labels = data[1]
        all_images.append(images)
        true_labels.append(labels)
        pred_labels = np.argmax(model.predict(images), axis=1)
        predict_labels.append(pred_labels)
        if i >= batch_num:
            break
    all_images = np.concatenate(all_images)
    true_labels = np.concatenate(true_labels)
    predict_labels = np.concatenate(predict_labels)
    accuracy = np.mean(np.equal(predict_labels, true_labels))
    LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy)

    # The splits below assume batch_num * batch_size == 10000 samples.
    # NOTE(review): the attacked images are also part of the detector's
    # training set.
    train_images = all_images[0:6000, :, :, :]
    attacked_images = all_images[0:10, :, :, :]
    attacked_labels = true_labels[0:10]

    # `model.predict` records every call, so the accuracy evaluation above has
    # already logged all benign images as queries. Remember where the attack's
    # queries start and end so that only those are treated as suspicious.
    attack_query_start = len(model.get_queries())

    # generate malicious query sequence of black attack
    attack = PSOAttack(model, bounds=(0.0, 1.0), pm=0.5, sparse=True,
                       t_max=1000)
    success_list, adv_data, query_list = attack.generate(attacked_images,
                                                         attacked_labels)
    attack_query_end = len(model.get_queries())
    LOGGER.info(TAG, 'pso attack success_list: %s', success_list)
    LOGGER.info(TAG, 'average of query counts is : %s', np.mean(query_list))
    pred_logits_adv = model.predict(adv_data)
    # rescale predict confidences into (0, 1).
    pred_logits_adv = softmax(pred_logits_adv, axis=1)
    pred_labels_adv = np.argmax(pred_logits_adv, axis=1)
    accuracy_adv = np.mean(np.equal(pred_labels_adv, attacked_labels))
    LOGGER.info(TAG, "prediction accuracy after attacking is : %g",
                accuracy_adv)

    benign_queries = all_images[6000:10000, :, :, :]
    all_queries = model.get_queries()
    suspicious_queries = all_queries[attack_query_start:attack_query_end]

    # explicit threshold not provided, calculate threshold for K
    encoder = Model(EncoderNet(encode_dim=256))
    detector = SimilarityDetector(max_k_neighbor=50, trans_model=encoder)
    detector.fit(inputs=train_images)

    # test benign queries: anything flagged here is a false positive
    detector.detect(benign_queries)
    num_false_positives = len(detector.get_detected_queries())
    # max(..., 1) avoids ZeroDivisionError if the benign split is empty.
    fpr = num_false_positives / max(benign_queries.shape[0], 1)
    LOGGER.info(TAG, 'Number of false positive of attack detector is : %s',
                num_false_positives)
    LOGGER.info(TAG, 'False positive rate of attack detector is : %s', fpr)

    # test attack queries
    detector.clear_buffer()
    detector.detect(suspicious_queries)
    detected_queries = detector.get_detected_queries()
    LOGGER.info(TAG, 'Number of detected attack queries is : %s',
                len(detected_queries))
    LOGGER.info(TAG, 'The detected attack query indexes are : %s',
                detected_queries)


if __name__ == "__main__":
    # Allow running this example directly as a script, outside pytest.
    test_similarity_detector()