From b5c4e506d04589c88a6005ce137f0f1738714b55 Mon Sep 17 00:00:00 2001 From: guangzhuwu Date: Thu, 8 Feb 2018 15:38:04 +0800 Subject: [PATCH] Fix the bug which will appear when model.channel_axis==1. --- fluid/adversarial/advbox/attacks/base.py | 2 ++ fluid/adversarial/advbox/attacks/gradient_method.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/fluid/adversarial/advbox/attacks/base.py b/fluid/adversarial/advbox/attacks/base.py index eb9b1d48..d322d8b5 100644 --- a/fluid/adversarial/advbox/attacks/base.py +++ b/fluid/adversarial/advbox/attacks/base.py @@ -52,6 +52,8 @@ class Attack(object): :param adversary: adversary :return: None """ + assert self.model.channel_axis() == adversary.original.ndim + if adversary.original_label is None: adversary.original_label = np.argmax( self.model.predict(adversary.original)) diff --git a/fluid/adversarial/advbox/attacks/gradient_method.py b/fluid/adversarial/advbox/attacks/gradient_method.py index e7d42c4f..70910136 100644 --- a/fluid/adversarial/advbox/attacks/gradient_method.py +++ b/fluid/adversarial/advbox/attacks/gradient_method.py @@ -54,14 +54,12 @@ class GradientMethodAttack(Attack): if not isinstance(epsilons, Iterable): epsilons = np.linspace(epsilons, epsilons + 1e-10, num=steps) - print(epsilons) - pre_label = adversary.original_label min_, max_ = self.model.bounds() - print self.model.channel_axis() assert self.model.channel_axis() == adversary.original.ndim - assert (self.model.channel_axis() == adversary.original.shape[0] or + assert (self.model.channel_axis() == 1 or + self.model.channel_axis() == adversary.original.shape[0] or self.model.channel_axis() == adversary.original.shape[-1]) adv_img = adversary.original @@ -89,6 +87,8 @@ class GradientMethodAttack(Attack): @staticmethod def _norm(a, ord): + if a.ndim == 1: + return np.linalg.norm(a, ord=ord) if a.ndim == a.shape[0]: norm_shape = (a.ndim, reduce(np.dot, a.shape[1:])) norm_axis = 1 -- GitLab