# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Natural adversarial defense test.
"""
import logging

import numpy as np
import pytest
from mindspore import context
from mindspore import nn
from mindspore.nn.optim.momentum import Momentum

from mock_net import Net
from mindarmour.defenses.natural_adversarial_defense import \
    NaturalAdversarialDefense
from mindarmour.utils.logger import LogUtil

# Tag prefix and shared logger instance used for debug output in test_nad.
TAG = 'Nad_Test'
LOGGER = LogUtil.get_instance()


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_nad():
    """UT for natural adversarial defense.

    Builds a random float32 input batch of shape (batch_size, 1, 32, 32)
    with one-hot labels (or integer labels when ``sparse`` is True).
    Wraps the mock ``Net`` in a ``NaturalAdversarialDefense`` and runs a
    single ``defense`` call in GRAPH_MODE on the 'Ascend' device target.
    The test passes when every element of the returned loss is non-negative.
    """
    num_classes = 10
    batch_size = 32

    # When False, labels are one-hot encoded to match the non-sparse
    # SoftmaxCrossEntropyWithLogits configured below.
    sparse = False
    context.set_context(mode=context.GRAPH_MODE)
    context.set_context(device_target='Ascend')

    # create test data; a locally seeded generator keeps the test
    # reproducible without mutating numpy's global random state.
    rng = np.random.RandomState(0)
    inputs = rng.rand(batch_size, 1, 32, 32).astype(np.float32)
    labels = rng.randint(num_classes, size=batch_size).astype(np.int32)
    if not sparse:
        labels = np.eye(num_classes)[labels].astype(np.float32)

    net = Net()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=sparse)
    optimizer = Momentum(net.trainable_params(), 0.001, 0.9)

    # defense
    nad = NaturalAdversarialDefense(net, loss_fn=loss_fn, optimizer=optimizer)
    LOGGER.set_level(logging.DEBUG)
    LOGGER.debug(TAG, '---start natural adversarial defense--')
    loss = nad.defense(inputs, labels)
    LOGGER.debug(TAG, '---end natural adversarial defense--')
    # Softmax cross-entropy is never negative, so require it of every
    # element; np.any would pass as soon as a single element was >= 0.
    # (NaN also fails this check, since NaN >= 0.0 is False.)
    assert np.all(loss >= 0.0)