base.py 2.0 KB
Newer Older
G
gx_wind 已提交
1 2 3
"""
The base class for adversarial attacks.
"""
wgzqz's avatar
wgzqz 已提交
4 5 6 7 8
import logging
from abc import ABCMeta
from abc import abstractmethod

import numpy as np
G
gx_wind 已提交
9 10 11 12


class Attack(object):
    """
    Abstract base class for adversarial attacks. `Attack` represents an
    adversarial attack which searches for an adversarial example. Subclasses
    should implement the _apply() method.

    Args:
        model(Model): an instance of the class advbox.base.Model.

    """
    # NOTE(review): `__metaclass__` is only honoured by Python 2. Under
    # Python 3 it is an ordinary class attribute, so @abstractmethod is not
    # enforced and `Attack` itself can still be instantiated.
    __metaclass__ = ABCMeta

    def __init__(self, model):
        self.model = model

    def __call__(self, adversary, **kwargs):
        """
        Generate the adversarial sample.

        Fills in missing labels on `adversary` (see _preprocess) and then
        delegates the actual search to the subclass's _apply().

        Args:
            adversary(object): The adversary object.
            **kwargs: Other params, forwarded unchanged to _apply().

        Returns:
            Whatever the subclass's _apply() returns.
        """
        self._preprocess(adversary)
        return self._apply(adversary, **kwargs)

    @abstractmethod
    def _apply(self, adversary, **kwargs):
        """
        Search an adversarial example.

        Args:
            adversary(object): The adversary object.
            **kwargs: Attack-specific params passed through __call__().
        """
        raise NotImplementedError

    def _preprocess(self, adversary):
        """
        Preprocess the adversary object.

        Fills in `adversary.original_label` and, for targeted attacks,
        `adversary.target_label` from the argmax of the model's prediction
        when they are not already set.

        Args:
            adversary(object): The adversary object; modified in place.

        Returns:
            None

        Raises:
            ValueError: if the attack is targeted but neither
                `adversary.target_label` nor `adversary.target` is set.
        """
        # The second tuple element is presumably a placeholder label expected
        # by Model.predict -- confirm against the Model implementation.
        if adversary.original_label is None:
            adversary.original_label = np.argmax(
                self.model.predict([(adversary.original, 0)]))
        if adversary.is_targeted_attack and adversary.target_label is None:
            if adversary.target is None:
                raise ValueError(
                    'When adversary.is_targeted_attack is True, '
                    'adversary.target_label or adversary.target must be set.')
            # target_label is None here, so derive it by classifying the
            # target sample itself.
            adversary.target_label = np.argmax(
                self.model.predict([(adversary.target, 0)]))

        logging.info('adversary:\noriginal_label: %s'
                     '\n          target_label: %s'
                     '\n          is_targeted_attack: %s',
                     adversary.original_label, adversary.target_label,
                     adversary.is_targeted_attack)