test_ead.py 2.3 KB
Newer Older
Z
zheng-huanhuan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ensemble adversarial defense test.
"""
import logging

Z
zheng-huanhuan 已提交
19 20
import numpy as np
import pytest
Z
zheng-huanhuan 已提交
21
from mindspore import context
Z
zheng-huanhuan 已提交
22
from mindspore import nn
Z
zheng-huanhuan 已提交
23 24
from mindspore.nn.optim.momentum import Momentum

25 26
from mindarmour.adv_robustness.attacks import FastGradientSignMethod
from mindarmour.adv_robustness.attacks import \
Z
zheng-huanhuan 已提交
27
    ProjectedGradientDescent
28
from mindarmour.adv_robustness.defenses import EnsembleAdversarialDefense
Z
zheng-huanhuan 已提交
29 30
from mindarmour.utils.logger import LogUtil

31 32
from ut.python.utils.mock_net import Net

Z
zheng-huanhuan 已提交
33 34 35 36 37 38 39 40 41 42 43 44
# Logger instance obtained from mindarmour.utils.logger.LogUtil.
LOGGER = LogUtil.get_instance()
# Tag passed as the first argument to the LOGGER.debug calls in this module.
TAG = 'Ead_Test'


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_ead():
    """UT for ensemble adversarial defense.

    Builds one random batch of inputs with shape (batch_size, 1, 32, 32)
    and one-hot labels, runs a single ``EnsembleAdversarialDefense.defense``
    call using FGSM and PGD as the attack ensemble, and asserts that the
    returned loss contains at least one non-negative value.
    """
    num_classes = 10
    batch_size = 64

    sparse = False
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target='Ascend')

    # create test data
    inputs = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
    labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
    if not sparse:
        # Dense (one-hot) labels are required when the loss is non-sparse.
        labels = np.eye(num_classes)[labels].astype(np.float32)

    # A single network instance is shared by the optimizer, the attacks and
    # the defense. Re-creating the network after building the optimizer
    # would leave the optimizer holding the parameters of a discarded model.
    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=sparse)
    optimizer = Momentum(net.trainable_params(), 0.001, 0.9)

    fgsm = FastGradientSignMethod(net)
    pgd = ProjectedGradientDescent(net)
    ead = EnsembleAdversarialDefense(net, [fgsm, pgd], loss_fn=loss_fn,
                                     optimizer=optimizer)
    LOGGER.set_level(logging.DEBUG)
    LOGGER.debug(TAG, '---start ensemble adversarial defense--')
    loss = ead.defense(inputs, labels)
    LOGGER.debug(TAG, '---end ensemble adversarial defense--')
    assert np.any(loss >= 0.0)