diff --git a/fluid/adversarial/advbox/attacks/saliency.py b/fluid/adversarial/advbox/attacks/saliency.py
index 4372025251bbfa280ec83eca9ab4201b65076f3d..db9f05e9d31f7bdfd7498da422f7a264bdbf52b9 100644
--- a/fluid/adversarial/advbox/attacks/saliency.py
+++ b/fluid/adversarial/advbox/attacks/saliency.py
@@ -3,6 +3,8 @@ This module provide the attack method for JSMA's implement.
 """
 from __future__ import division
 
+import logging
+import random
 import numpy as np
 
 from .base import Attack
@@ -33,9 +35,13 @@ class SaliencyMapAttack(Attack):
             adversary: The Adversary object.
         """
         assert adversary is not None
-        assert (adversary.target_label is None) or adversary.is_targeted_attack
-        target_labels = [adversary.target_label]
+        if not adversary.is_targeted_attack or (adversary.target_label is None):
+            target_labels = self._generate_random_target(
+                adversary.original_label)
+        else:
+            target_labels = [adversary.target_label]
+
         for target in target_labels:
             original_image = adversary.original
 
@@ -60,6 +66,9 @@ class SaliencyMapAttack(Attack):
                 if not any(mask):
                     return adversary
 
+                logging.info('step = {}, original_label = {}, adv_label={}'.
+                             format(step, adversary.original_label, adv_label))
+
                 # get pixel location with highest influence on class
                 idx, p_sign = self._saliency_map(
                     adv_img, target, labels, mask, fast=fast)
@@ -80,7 +89,26 @@ class SaliencyMapAttack(Attack):
 
                 adv_img = np.clip(adv_img, min_, max_)
 
-        return adversary
+    def _generate_random_target(self, original_label):
+        """
+        Draw random target labels all of which are different and not the original label.
+        Args:
+            original_label(int): Original label.
+        Return:
+            target_labels(list): random target labels
+        """
+        num_random_target = 1
+        num_classes = self.model.num_classes()
+        assert num_random_target <= num_classes - 1
+
+        target_labels = random.sample(range(num_classes), num_random_target + 1)
+        target_labels = [t for t in target_labels if t != original_label]
+        target_labels = target_labels[:num_random_target]
+
+        # str_target_labels = [str(t) for t in target_labels]
+        # logging.info('Random target labels: {}'.format(', '.join(str_target_labels)))
+
+        return target_labels
 
     def _saliency_map(self, image, target, labels, mask, fast=False):
         """
@@ -108,10 +136,10 @@ class SaliencyMapAttack(Attack):
             ], 0)
 
         # compute saliency map (take into account both pos. & neg. perturbations)
-        salmap = np.abs(alphas) * np.abs(betas) * np.sign(alphas * betas)
+        sal_map = np.abs(alphas) * np.abs(betas) * np.sign(alphas * betas)
 
         # find optimal pixel & direction of perturbation
-        idx = np.argmin(salmap)
+        idx = np.argmin(sal_map)
         idx = np.unravel_index(idx, mask.shape)
         pix_sign = np.sign(alphas)[idx]
diff --git a/fluid/adversarial/mnist_tutorial_jsma.py b/fluid/adversarial/mnist_tutorial_jsma.py
index 7b6dbc32468099fc39d1ac447f1cae8832ae3012..2010b5b4655b8693c5955092a7810d327667375f 100644
--- a/fluid/adversarial/mnist_tutorial_jsma.py
+++ b/fluid/adversarial/mnist_tutorial_jsma.py
@@ -75,28 +75,22 @@ def main():
     m = PaddleModel(fluid.default_main_program(), IMG_NAME, LABEL_NAME,
                     logits.name, avg_cost.name, (-1, 1))
     attack = SaliencyMapAttack(m)
-
-    target_label = 1
-    print('target_label = %d' % target_label)
-
+    total_num = 0
+    success_num = 0
     for data in train_reader():
-        # JSMA attack
-        if target_label == data[0][1]:
-            continue
-        print('original label =%d, target_label = %d' %
-              (data[0][1], target_label))
-
-        adversary = Adversary(data[0][0], data[0][1])
-        adversary.set_target(True, target_label=target_label)
-        jsma_attack = attack(adversary)
-        if jsma_attack.is_successful():
+        total_num += 1
+        # adversary.set_target(True, target_label=target_label)
+        jsma_attack = attack(Adversary(data[0][0], data[0][1]))
+        if jsma_attack is not None and jsma_attack.is_successful():
             # plt.imshow(jsma_attack.target, cmap='Greys_r')
             # plt.show()
-            print('adversary examples label =%d' %
-                  jsma_attack.adversarial_label)
-            np.save('adv_img', jsma_attack.adversarial_example)
+            success_num += 1
+            print('original_label=%d, adversary examples label =%d' %
+                  (data[0][1], jsma_attack.adversarial_label))
+        # np.save('adv_img', jsma_attack.adversarial_example)
+        print('total num = %d, success num = %d ' % (total_num, success_num))
+        if total_num == 100:
             break
-        break
 
 
 if __name__ == '__main__':
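
Usage sketch (not part of the patch): after this change, SaliencyMapAttack falls back to a randomly drawn target label whenever the Adversary carries no target, so the tutorial can run the attack untargeted, while targeted attacks still work through set_target(). The import paths and the model/image/label placeholders below are assumptions for illustration only, not taken from this diff.

    # Hypothetical driver showing both call styles enabled by this patch.
    # `model`, `image` and `label` are placeholders, not defined in the diff.
    from advbox.adversary import Adversary                  # assumed module path
    from advbox.attacks.saliency import SaliencyMapAttack   # assumed module path

    attack = SaliencyMapAttack(model)

    # Untargeted: no set_target() call; the attack now draws a random
    # target label internally via _generate_random_target().
    untargeted = attack(Adversary(image, label))

    # Targeted: explicit target label, same call path as before the patch.
    adversary = Adversary(image, label)
    adversary.set_target(True, target_label=7)
    targeted = attack(adversary)

    for result in (untargeted, targeted):
        if result is not None and result.is_successful():
            print('adversarial label = %d' % result.adversarial_label)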