提交 6c39bce4 编写于 作者: wgzqz's avatar wgzqz

Using model.num_class to get the class count of the original data.

上级 94fea877
...@@ -38,10 +38,11 @@ class DeepFoolAttack(Attack): ...@@ -38,10 +38,11 @@ class DeepFoolAttack(Attack):
labels = [adversary.target_label] labels = [adversary.target_label]
else: else:
max_class_count = 10 max_class_count = 10
if len(f) > max_class_count: class_count = self.model.num_classes()
if class_count > max_class_count:
labels = np.argsort(f)[-(max_class_count + 1):-1] labels = np.argsort(f)[-(max_class_count + 1):-1]
else: else:
labels = np.arange(len(f)) labels = np.arange(class_count)
gradient = self.model.gradient([(adversary.original, pre_label)]) gradient = self.model.gradient([(adversary.original, pre_label)])
x = adversary.original.reshape(gradient.shape) x = adversary.original.reshape(gradient.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册