base.py 2.2 KB
Newer Older
G
gx_wind 已提交
1 2 3
"""
The base class of adversarial attacks.
"""
wgzqz's avatar
wgzqz 已提交
4 5 6 7 8
import logging
from abc import ABCMeta
from abc import abstractmethod

import numpy as np
G
gx_wind 已提交
9 10 11 12


class Attack(ABCMeta('_AttackABC', (object,), {})):
    """
    Abstract base class for adversarial attacks.

    An `Attack` represents an adversarial attack that searches for an
    adversarial example. Subclasses must implement the `_apply()` method.

    Args:
        model(Model): an instance of the class advbox.base.Model.
    """

    # NOTE: the base class is created by calling ABCMeta directly so that the
    # abstract-method check is enforced on both Python 2 and Python 3. A
    # class-level `__metaclass__ = ABCMeta` attribute is silently ignored by
    # Python 3, which would let a subclass without `_apply()` be instantiated.

    def __init__(self, model):
        self.model = model

    def __call__(self, adversary, **kwargs):
        """
        Generate the adversarial sample.

        Runs `_preprocess()` on the adversary and then delegates to
        `_apply()`.

        Args:
            adversary(object): The adversary object.
            **kwargs: Other named arguments, forwarded to `_apply()`.

        Returns:
            Whatever the subclass implementation of `_apply()` returns.
        """
        self._preprocess(adversary)
        return self._apply(adversary, **kwargs)

    @abstractmethod
    def _apply(self, adversary, **kwargs):
        """
        Search an adversarial example.

        Args:
            adversary(object): The adversary object.
            **kwargs: Other named arguments.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def _preprocess(self, adversary):
        """
        Preprocess the adversary object.

        Fills in `adversary.original_label` (and, for targeted attacks,
        `adversary.target_label`) from the model's prediction when they are
        not set, then logs the resulting labels.

        Args:
            adversary(object): The adversary object; it is modified in place.

        Returns:
            None

        Raises:
            AssertionError: if `self.model.channel_axis()` differs from
                `adversary.original.ndim`.
            ValueError: if the attack is targeted but neither
                `adversary.target_label` nor `adversary.target` is set.
        """
        # NOTE(review): this compares the channel axis index with the number
        # of dimensions of the input; kept as in the original -- confirm that
        # this is the intended consistency check.
        assert self.model.channel_axis() == adversary.original.ndim, (
            'model.channel_axis() must equal adversary.original.ndim')

        if adversary.original_label is None:
            # Use the model's top prediction as the label of the original.
            adversary.original_label = np.argmax(
                self.model.predict(adversary.original))

        if adversary.is_targeted_attack and adversary.target_label is None:
            if adversary.target is None:
                raise ValueError(
                    'When adversary.is_targeted_attack is true, '
                    'adversary.target_label or adversary.target must be set.')
            # Derive the target label from the model's prediction on the
            # target sample.
            adversary.target_label = np.argmax(
                self.model.predict(adversary.target))

        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is actually enabled.
        logging.info('adversary:'
                     '\n         original_label: %s'
                     '\n         target_label: %s'
                     '\n         is_targeted_attack: %s',
                     adversary.original_label, adversary.target_label,
                     adversary.is_targeted_attack)