提交 fb30b3fc 编写于 作者: wgzqz's avatar wgzqz

Fix shape bug.

上级 3163b5e8
...@@ -90,7 +90,8 @@ class Adversary(object): ...@@ -90,7 +90,8 @@ class Adversary(object):
assert adversarial_example.shape == self.__original.shape assert adversarial_example.shape == self.__original.shape
ok = self._is_successful(adversarial_label) ok = self._is_successful(adversarial_label)
if ok: if ok:
self.__adversarial_example = adversarial_example self.__adversarial_example = adversarial_example.reshape(
self.__original.shape)
self.__adversarial_label = adversarial_label self.__adversarial_label = adversarial_label
return ok return ok
......
...@@ -44,8 +44,10 @@ class GradientSignAttack(Attack): ...@@ -44,8 +44,10 @@ class GradientSignAttack(Attack):
gradient = self.model.gradient([(adversary.original, gradient = self.model.gradient([(adversary.original,
adversary.original_label)]) adversary.original_label)])
gradient_sign = np.sign(gradient) * (max_ - min_) gradient_sign = np.sign(gradient) * (max_ - min_)
adv_img = adversary.original.reshape(gradient_sign.shape)
for epsilon in epsilons: for epsilon in epsilons:
adv_img = adversary.original + epsilon * gradient_sign adv_img = adv_img + epsilon * gradient_sign
adv_img = np.clip(adv_img, min_, max_) adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)])) adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'. logging.info('epsilon = {:.3f}, pre_label = {}, adv_label={}'.
......
...@@ -35,7 +35,7 @@ class IteratorGradientSignAttack(Attack): ...@@ -35,7 +35,7 @@ class IteratorGradientSignAttack(Attack):
min_, max_ = self.model.bounds() min_, max_ = self.model.bounds()
for epsilon in epsilons: for epsilon in epsilons:
adv_img = adversary.original adv_img = None
for _ in range(steps): for _ in range(steps):
if adversary.is_targeted_attack: if adversary.is_targeted_attack:
gradient = self.model.gradient([(adversary.original, gradient = self.model.gradient([(adversary.original,
...@@ -45,6 +45,8 @@ class IteratorGradientSignAttack(Attack): ...@@ -45,6 +45,8 @@ class IteratorGradientSignAttack(Attack):
gradient = self.model.gradient([(adversary.original, gradient = self.model.gradient([(adversary.original,
adversary.original_label)]) adversary.original_label)])
gradient_sign = np.sign(gradient) * (max_ - min_) gradient_sign = np.sign(gradient) * (max_ - min_)
if adv_img is None:
adv_img = adversary.original.reshape(gradient_sign.shape)
adv_img = adv_img + gradient_sign * epsilon adv_img = adv_img + gradient_sign * epsilon
adv_img = np.clip(adv_img, min_, max_) adv_img = np.clip(adv_img, min_, max_)
adv_label = np.argmax(self.model.predict([(adv_img, 0)])) adv_label = np.argmax(self.model.predict([(adv_img, 0)]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册