提交 ae418490 编写于 作者: wgzqz's avatar wgzqz

Add bad_adversary_example property to record the example even it is bad.

上级 307dd366
......@@ -18,13 +18,15 @@ class Adversary(object):
"""
assert original is not None
self.original_label = original_label
self.target_label = None
self.adversarial_label = None
self.__original = original
self.__original_label = original_label
self.__target_label = None
self.__target = None
self.__is_targeted_attack = False
self.__adversarial_example = None
self.__adversarial_label = None
self.__bad_adversarial_example = None
def set_target(self, is_targeted_attack, target=None, target_label=None):
"""
......@@ -38,10 +40,10 @@ class Adversary(object):
"""
assert (target_label is None) or is_targeted_attack
self.__is_targeted_attack = is_targeted_attack
self.__target_label = target_label
self.target_label = target_label
self.__target = target
if not is_targeted_attack:
self.__target_label = None
self.target_label = None
self.__target = None
def set_original(self, original, original_label=None):
......@@ -53,10 +55,11 @@ class Adversary(object):
"""
if original != self.__original:
self.__original = original
self.__original_label = original_label
self.original_label = original_label
self.__adversarial_example = None
self.__bad_adversarial_example = None
if original is None:
self.__original_label = None
self.original_label = None
def _is_successful(self, adversarial_label):
"""
......@@ -65,11 +68,11 @@ class Adversary(object):
:param adversarial_label: adversarial label.
:return: bool
"""
if self.__target_label is not None:
return adversarial_label == self.__target_label
if self.target_label is not None:
return adversarial_label == self.target_label
else:
return (adversarial_label is not None) and \
(adversarial_label != self.__original_label)
(adversarial_label != self.original_label)
def is_successful(self):
"""
......@@ -77,7 +80,7 @@ class Adversary(object):
:return: bool
"""
return self._is_successful(self.__adversarial_label)
return self._is_successful(self.adversarial_label)
def try_accept_the_example(self, adversarial_example, adversarial_label):
"""
......@@ -93,7 +96,9 @@ class Adversary(object):
ok = self._is_successful(adversarial_label)
if ok:
self.__adversarial_example = adversarial_example
self.__adversarial_label = adversarial_label
self.adversarial_label = adversarial_label
else:
self.__bad_adversarial_example = adversarial_example
return ok
def perturbation(self, multiplying_factor=1.0):
......@@ -104,9 +109,14 @@ class Adversary(object):
:return: The perturbation that is multiplied by multiplying_factor.
"""
assert self.__original is not None
assert self.__adversarial_example is not None
return multiplying_factor * (
self.__adversarial_example - self.__original)
assert (self.__adversarial_example is not None) or \
(self.__bad_adversarial_example is not None)
if self.__adversarial_example is not None:
return multiplying_factor * (
self.__adversarial_example - self.__original)
else:
return multiplying_factor * (
self.__bad_adversarial_example - self.__original)
@property
def is_targeted_attack(self):
......@@ -115,20 +125,6 @@ class Adversary(object):
"""
return self.__is_targeted_attack
@property
def target_label(self):
"""
:property: target_label
"""
return self.__target_label
@target_label.setter
def target_label(self, label):
"""
:property: target_label
"""
self.__target_label = label
@property
def target(self):
"""
......@@ -143,20 +139,6 @@ class Adversary(object):
"""
return self.__original
@property
def original_label(self):
"""
:property: original
"""
return self.__original_label
@original_label.setter
def original_label(self, label):
"""
original_label setter
"""
self.__original_label = label
@property
def adversarial_example(self):
"""
......@@ -164,23 +146,9 @@ class Adversary(object):
"""
return self.__adversarial_example
@adversarial_example.setter
def adversarial_example(self, example):
"""
adversarial_example setter
"""
self.__adversarial_example = example
@property
def adversarial_label(self):
"""
:property: adversarial_label
"""
return self.__adversarial_label
@adversarial_label.setter
def adversarial_label(self, label):
def bad_adversarial_example(self):
"""
adversarial_label setter
:property: bad_adversarial_example
"""
self.__adversarial_label = label
return self.__bad_adversarial_example
......@@ -66,8 +66,9 @@ class Attack(object):
adversary.target_label = np.argmax(
self.model.predict(adversary.target))
logging.info('adversary:\noriginal_label: {}'
'\n target_label: {}'
'\n is_targeted_attack: {}'
logging.info('adversary:'
'\n original_label: {}'
'\n target_label: {}'
'\n is_targeted_attack: {}'
''.format(adversary.original_label, adversary.target_label,
adversary.is_targeted_attack))
......@@ -10,6 +10,8 @@ import numpy as np
from .base import Attack
__all__ = ['DeepFoolAttack']
class DeepFoolAttack(Attack):
"""
......@@ -70,9 +72,12 @@ class DeepFoolAttack(Attack):
f = self.model.predict(x)
gradient = self.model.gradient(x, pre_label)
adv_label = np.argmax(f)
logging.info('iteration = {}, f = {}, pre_label = {}'
', adv_label={}'.format(iteration, f[pre_label],
pre_label, adv_label))
logging.info('iteration={}, f[pre_label]={}, f[target_label]={}'
', f[adv_label]={}, pre_label={}, adv_label={}'
''.format(iteration, f[pre_label], (
f[adversary.target_label]
if adversary.is_targeted_attack else 'NaN'), f[
adv_label], pre_label, adv_label))
if adversary.try_accept_the_example(x, adv_label):
return adversary
......
......@@ -10,6 +10,8 @@ from scipy.optimize import fmin_l_bfgs_b
from .base import Attack
__all__ = ['LBFGSAttack', 'LBFGS']
class LBFGSAttack(Attack):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册