提交 b5c4e506 编写于 作者: wgzqz's avatar wgzqz

Fix the bug which will appear when model.channel_axis==1.

上级 36b8b247
......@@ -52,6 +52,8 @@ class Attack(object):
:param adversary: adversary
:return: None
"""
assert self.model.channel_axis() == adversary.original.ndim
if adversary.original_label is None:
adversary.original_label = np.argmax(
self.model.predict(adversary.original))
......
......@@ -54,14 +54,12 @@ class GradientMethodAttack(Attack):
if not isinstance(epsilons, Iterable):
epsilons = np.linspace(epsilons, epsilons + 1e-10, num=steps)
print(epsilons)
pre_label = adversary.original_label
min_, max_ = self.model.bounds()
print self.model.channel_axis()
assert self.model.channel_axis() == adversary.original.ndim
assert (self.model.channel_axis() == adversary.original.shape[0] or
assert (self.model.channel_axis() == 1 or
self.model.channel_axis() == adversary.original.shape[0] or
self.model.channel_axis() == adversary.original.shape[-1])
adv_img = adversary.original
......@@ -89,6 +87,8 @@ class GradientMethodAttack(Attack):
@staticmethod
def _norm(a, ord):
if a.ndim == 1:
return np.linalg.norm(a, ord=ord)
if a.ndim == a.shape[0]:
norm_shape = (a.ndim, reduce(np.dot, a.shape[1:]))
norm_axis = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册