diff --git a/mindarmour/defenses/adversarial_defense.py b/mindarmour/defenses/adversarial_defense.py
index e4a7ba581da72ef75ec063221c5a3c5a70e359c6..4cbe0bea61390369b09be8f87178a8417c35fdf3 100644
--- a/mindarmour/defenses/adversarial_defense.py
+++ b/mindarmour/defenses/adversarial_defense.py
@@ -23,7 +23,7 @@ from mindspore.nn import SoftmaxCrossEntropyWithLogits
 from mindspore.nn import WithLossCell, TrainOneStepCell
 
 from mindarmour.utils._check_param import check_pair_numpy_param, check_model, \
-    check_param_in_range, check_param_type, check_param_multi_types
+    check_param_in_range, check_param_type, check_param_multi_types
 from mindarmour.defenses.defense import Defense
 
 
@@ -166,4 +166,36 @@ class AdversarialDefenseWithAttacks(AdversarialDefense):
         return loss.asnumpy()
 
 
-EnsembleAdversarialDefense = AdversarialDefenseWithAttacks
+class EnsembleAdversarialDefense(AdversarialDefenseWithAttacks):
+    """
+    Ensemble adversarial defense.
+
+    Args:
+        network (Cell): A MindSpore network to be defended.
+        attacks (list[Attack]): List of attack methods.
+        loss_fn (Function): Loss function. Default: None.
+        optimizer (Cell): Optimizer used to train the network. Default: None.
+        bounds (tuple): Upper and lower bounds of data. In form of (clip_min,
+            clip_max). Default: (0.0, 1.0).
+        replace_ratio (float): Ratio of replacing original samples with
+            adversarial ones, which must be between 0 and 1. Default: 0.5.
+
+    Raises:
+        ValueError: If replace_ratio is not between 0 and 1.
+
+    Examples:
+        >>> net = Net()
+        >>> fgsm = FastGradientSignMethod(net)
+        >>> pgd = ProjectedGradientDescent(net)
+        >>> ead = EnsembleAdversarialDefense(net, [fgsm, pgd])
+        >>> ead.defense(inputs, labels)
+    """
+
+    def __init__(self, network, attacks, loss_fn=None, optimizer=None,
+                 bounds=(0.0, 1.0), replace_ratio=0.5):
+        super(EnsembleAdversarialDefense, self).__init__(network,
+                                                         attacks,
+                                                         loss_fn,
+                                                         optimizer,
+                                                         bounds,
+                                                         replace_ratio)
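
For reference, a minimal usage sketch of the new class, expanded from the docstring Examples above. The network (Net), the attack import paths, and the input/label shapes are illustrative assumptions, not part of this patch.

# Usage sketch only: `Net`, the attack import paths, and the data shapes
# below are assumptions; adjust them to the MindArmour version in use.
import numpy as np
from mindarmour.defenses.adversarial_defense import EnsembleAdversarialDefense
from mindarmour.attacks.gradient_method import FastGradientSignMethod              # assumed path
from mindarmour.attacks.iterative_gradient_method import ProjectedGradientDescent  # assumed path

net = Net()  # placeholder: any MindSpore Cell classifier
fgsm = FastGradientSignMethod(net)
pgd = ProjectedGradientDescent(net)

# Replace half of each training batch with adversarial samples drawn from the two attacks.
ead = EnsembleAdversarialDefense(net, [fgsm, pgd], replace_ratio=0.5)

inputs = np.random.rand(32, 1, 32, 32).astype(np.float32)             # placeholder image batch
labels = np.eye(10)[np.random.randint(0, 10, 32)].astype(np.float32)  # placeholder one-hot labels
loss = ead.defense(inputs, labels)  # one training step on the mixed clean/adversarial batch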