# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adversarial defense test.
"""
import logging

import numpy as np
import pytest
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.nn.optim.momentum import Momentum

from mindarmour.adv_robustness.defenses import AdversarialDefense
from mindarmour.utils.logger import LogUtil

from ut.python.utils.mock_net import Net

# Shared MindArmour logger for this test module; TAG labels every message
# the test emits through it.
LOGGER = LogUtil.get_instance()
TAG = 'Ad_Test'


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_ad():
    """UT for adversarial defense.

    Builds the mock ``Net`` with a softmax cross-entropy loss and a Momentum
    optimizer, wraps them in ``AdversarialDefense``, runs ``defense`` on one
    random batch and checks that the returned loss is non-negative.
    Runs in graph mode on an Ascend device.
    """
    num_classes = 10
    batch_size = 32

    # Labels are one-hot encoded below, so the loss must be dense as well.
    sparse = False
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target='Ascend')

    # create test data; a local seeded generator keeps the batch reproducible
    # across runs without touching numpy's global random state.
    # NOTE(review): assumes the mock Net consumes (batch, 1, 32, 32) inputs.
    rng = np.random.RandomState(0)
    inputs = rng.rand(batch_size, 1, 32, 32).astype(np.float32)
    labels = rng.randint(num_classes, size=batch_size).astype(np.int32)
    if not sparse:
        labels = np.eye(num_classes)[labels].astype(np.float32)

    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=sparse)
    optimizer = Momentum(learning_rate=Tensor(np.array([0.001], np.float32)),
                         momentum=0.9,
                         params=net.trainable_params())

    ad_defense = AdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer)
    LOGGER.set_level(logging.DEBUG)
    LOGGER.debug(TAG, '--start adversarial defense--')
    loss = ad_defense.defense(inputs, labels)
    LOGGER.debug(TAG, '--end adversarial defense--')
    # Softmax cross-entropy can never be negative, so *every* returned loss
    # value must satisfy this (np.any would pass even if most were invalid).
    assert np.all(loss >= 0.0)