未验证 提交 ab32cf3d 编写于 作者: D Dong Daxiang 提交者: GitHub

Merge pull request #78 from XDUXK/dlg-attack

add DLG tool of model inversion attack
### 1. What is DLG?
#### 1.1 Background
In federated learning framework, users’ data are stored locally while only gradients of target model being transferred among different devices. This framework can prevent privacy leakage of these data and the model utility can benefit from abundant data of multiply users. However, recent research work showed that federated learning framework incurred severe privacy leakage, e.g., **Deep Leakage from Gradients (DLG) [1]**.
#### 1.2 Method
In DLG, in which attackers can steal accurate participant’s train data from the shared gradients. The workflow of DLG is shown as follow:
![workflow](./images/dlg.png)
(1) Target data (e.g., one picture of the victim’s training set) and label are sent into model **F**, and output gradients of model parameters ∇W. The attacker can obtain ∇W, for example, by downloading it from a parameter server in federated learning framework.
(2) Dummy data and label generated randomly by the attacker.
(3) Dummy data and label are sent into the same model **F** and also output gradients of model parameters ∇W'.
(4) Define the square error of ∇W and ∇W' as a loss. Optimize the loss by updating dummy data and label with their gradients.
(5) The attacker sends the updated dummy data and label into model **F** again. Repeat step 3-5 until results with small loss are obtained, i.e., the generated data is close to the original training data.
### 2. How to use this tool?
#### 2.1 Usage
Execute python script "`python mnist_example.py`". Optional arguments include:
```python
parser.add_argument('--use_gpu',
type=bool, default=False,
help='Whether to use GPU or not.')
parser.add_argument('--batch_size',
type=int, default=2,
help='The batch size of normal training.')
parser.add_argument('--iterations',
type=int, default=3000,
help='The iterations of attacking training.')
parser.add_argument('--learning_rate',
type=float, default=-8.5,
help='The learning rate of attacking training.')
parser.add_argument('--result_dir',
type=str, default='./att_results',
help='The directory for saving attack result.')
```
#### 2.2 Procedure of example of DLG attack on MNIST
As is show in "`mnist_example.py`", you can define the normal training process with PaddlePaddle as usual. What the difference is that at the first step of normal training, you should just obtain original gradients of model parameters generated by the original training data, and then pass the original gradients to the DLG module (in `dlg/dlg.py`) together with several arguments such as iterations, learning rate, feature variable, label variable, model network used for MNIST, the same executor with normal training, and the original gradients. Then the program would entrance the DLG attack module and generate effective attacking results in a defined directory.
#### 2.3 Result of DLG attack on MNIST
We use MNIST dataset to evaluate the effect and performance of DLG attack. A figure of example is shown as following. The "`target.png"` in the left is the target picture, and the pictures in the right named with the format of "`result_NUM.png`" are the attacking results, where "`NUM`" means the number of attacking iterations.
* **Effect**: As we can see, effective attacking result occurs after about 1700 attacking iterations. And the attacker is able to get almost the same result with target picture after 2100 attacking iterations.
* **Performance**: We use the default argument of attacking iterations, i.e., "`--iterations=3000`" to run "`mnist_example.py`" with CPU. The total time costed is about 12 seconds, which means that DLG attack is very efficient.
![result_example](./images/result_example.png)
### 3. Requirements.
* **PaddlePaddle version** should be **1.7 or higher**.
* Normally, you can use Python 2.7 or 3.7, and we **recommend Python 3.7**.
#### Referances
[1] Zhu, L., Liu, Z., & Han, S. (2019). Deep leakage from gradients. In *Advances in Neural Information Processing Systems* (pp. 14747-14756).
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Import DLG module.
"""
from .dlg import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides DLG attack. Please refer to
README for more details.
"""
import os
import time
import numpy
import paddle.fluid as fluid
from PIL import Image
__all__ = ["dlg_attack"]
def dlg_attack(args, feature, label, network, exe, origin_grad):
"""
The implementation of DLG attack.
:param args: the parameters for dlg attack
:param feature: the variable of feature
:param label: the variable of label
:param network: the network of model which is trained
:param exe: the same executor with normal training procedure
:param origin_grad: the original gradients of model params
generated by target data, i.e., the data which is being attacked.
:return:
"""
main_program = fluid.Program()
# use a new program
with fluid.program_guard(main_program):
# the dummy feature, which aims to imitate the target data
dummy_x = fluid.data(name="dummy_x",
shape=list(feature.shape),
dtype=feature.dtype)
# let dummy_x can be updated
dummy_x.stop_gradient = False
# the dummy_label
dummy_y = fluid.data(name="dummy_y",
shape=list(label.shape),
dtype=label.dtype)
# let dummy_y can be updated
dummy_y.stop_gradient = False
# use the model network of training
_, dummy_loss = network(dummy_x, dummy_y)
# get gradients of params that can be trainable
all_params = main_program.global_block().all_parameters()
grad_params = [param for param in all_params if param.trainable]
dummy_grads = fluid.gradients(dummy_loss, grad_params)
# original gradients
origin_grad_vars = []
for g_id, origin_g in enumerate(origin_grad):
grad_name = "origin_g_" + str(g_id)
grad_shape = origin_g.shape
grad = fluid.data(name=grad_name,
shape=grad_shape,
dtype=origin_g.dtype)
origin_grad_vars.append(grad)
# the target loss of optimization, i.e., the difference
# between gradients of model parameters generated respectively
# by target data and dummy data
diff_loss = 0.0
for orig_g, dum_g in zip(origin_grad_vars, dummy_grads):
cur_loss = fluid.layers.square_error_cost(orig_g, dum_g)
cur_loss = fluid.layers.reduce_mean(cur_loss)
diff_loss += cur_loss
mean_diff_loss = fluid.layers.mean(diff_loss)
# the gradient of dummy_x
grad_of_x = fluid.gradients(mean_diff_loss, dummy_x)
dummy_feature_shape = [1 if d == -1 else d for d in list(feature.shape)]
dummy_label_shape = [1 if d == -1 else d for d in list(label.shape)]
# Generate dummy target data. The main two types, i.e., float32 and int64,
# are used here for feature and label variables respectively, which can be
# changed according to different types in different scenarios.
dummy_feature = numpy.random.normal(0, 1,
size=dummy_feature_shape
).astype("float32")
dummy_label = numpy.zeros(shape=dummy_label_shape).astype("int64")
feed_dict = {}
# add original gradients into feed_dict
for idx, orig_g in enumerate(origin_grad):
key = "origin_g_" + str(idx)
feed_dict[key] = orig_g
# the time of starting attack
start = time.time()
for iteration in range(args.iterations):
feed_dict["dummy_x"] = dummy_feature
feed_dict["dummy_y"] = dummy_label
result = exe.run(main_program,
feed=feed_dict,
fetch_list=[mean_diff_loss] + grad_of_x)
grad_diff_loss, feature_grad = result[0][0], result[1:]
# update dummy_x with it's gradient
feature_grad = numpy.array(feature_grad).reshape(dummy_feature_shape)
dummy_feature = numpy.add(dummy_feature, args.learning_rate * feature_grad)
dummy_feature = numpy.array(dummy_feature)
# the shape of target image
img_shape = dummy_feature_shape[-2:]
# save attack results per 100 iterations
if iteration % 100 == 0:
print("Attack Iteration {}: grad_diff_loss = {}"
.format(iteration, grad_diff_loss))
if not os.path.exists(args.result_dir):
os.makedirs(args.result_dir)
img = Image.fromarray((dummy_feature * 255)
.reshape(img_shape)
.astype(numpy.uint8))
img.save(args.result_dir + "/result_{}.png".format(iteration))
end = time.time()
print("Attack cost time in seconds: {}".format(end - start))
# exit after attack finished
exit("Attack finished.")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides an example of DLG attack on MNIST. Please refer to
README for more details.
"""
from __future__ import print_function
import argparse
import numpy
import paddle
import paddle.fluid as fluid
from PIL import Image
from paddle.fluid.param_attr import ParamAttr
from dlg import dlg
def parse_args():
"""
Parse command line arguments.
:return:
"""
parser = argparse.ArgumentParser("DLG")
parser.add_argument("--use_gpu",
type=bool, default=False,
help="Whether to use GPU or not.")
parser.add_argument("--batch_size",
type=int, default=2,
help="The batch size of normal training.")
parser.add_argument("--iterations",
type=int, default=3000,
help="The iterations of attacking training.")
parser.add_argument("--learning_rate",
type=float, default=-8.5,
help="The learning rate of attacking training.")
parser.add_argument("--result_dir",
type=str, default="./att_results",
help="the directory for saving attack result.")
args = parser.parse_args()
return args
def network(img, label):
"""
The network of model.
:param img: the feature of training data
:param label: the label of training data
:return: the prediction and average loos
"""
# ensure that dummy data use the same initialized
# model params with real data
param_attr = ParamAttr(name="fc.w_0")
bias_attr = ParamAttr(name="fc.b_0")
prediction = fluid.layers.fc(input=img,
size=10,
act="softmax",
param_attr=param_attr,
bias_attr=bias_attr)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return prediction, avg_loss
def train_and_attack(args):
"""
The training procedure that starts from several normal training steps as usual,
but entrance the dlg method as soon as the gradients of target data are obtained.
:param args: the execution parameters.
:return:
"""
if args.use_gpu and not fluid.core.is_compiled_with_cuda():
return
startup_program = fluid.default_startup_program()
main_program = fluid.default_main_program()
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=500),
batch_size=args.batch_size)
img = fluid.data(name="img", shape=[None, 28, 28], dtype="float32")
label = fluid.data(name="label", shape=[None, 1], dtype="int64")
prediction, avg_loss = network(img, label)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
# ensure that the model parameters are not be updated before attack finished.
_ = optimizer.backward(avg_loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
for step_id, data in enumerate(train_reader()):
params = main_program.global_block().all_parameters()
grad_param = [param.name + "@GRAD" for param in params if param.trainable]
# save the target data for checking out the effectiveness of attack
image = Image.fromarray((data[0][0] * 255).reshape(28, 28).astype(numpy.uint8))
image.save("./target.png")
target_x = numpy.array(data[0][0]).reshape((1, 28, 28))
target_y = numpy.array(data[0][1]).reshape(1, 1)
metrics = exe.run(
main_program,
feed={"img": target_x, "label": target_y},
fetch_list=[avg_loss] + grad_param)
# entrance DLG attack procedure at the first step
if step_id == 0:
# the gradients of model parameters generated by target data
origin_grad = metrics[1:]
dlg.dlg_attack(args, img, label, network, exe, origin_grad)
if __name__ == "__main__":
arguments = parse_args()
train_and_attack(arguments)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册