提交 15c00b9f 编写于 作者: J jhjiangcs

add youtubednn demo.

上级 acfda7c1
## Instructions for PaddleFL-MPC YoutubeDNN Demo
([简体中文](./README_CN.md)|English)
This document introduces how to run YoutubeDNN demo based on Paddle-MPC, which has two ways of running, i.e., single machine and multi machines.
### 1. Running on Single Machine
#### (1). Prepare Data
Generate encrypted training and testing data utilizing `gen_cypher_sample()` in `process_data.py` script. Users can run the script with command `python process_data.py` to generate encrypted feature and label in given directory, e.g., `./mpc_data/`. Different suffix names are used for these files to indicate the ownership of different computation parties.
#### (2). Launch Demo with A Shell Script
You should set the env params as follow:
```
export PYTHON=/yor/python
export PATH_TO_REDIS_BIN=/path/to/redis_bin
export LOCALHOST=/your/localhost
export REDIS_PORT=/your/redis/port
```
Launch demo with the `run_standalone.sh` script. The concrete command is:
```bash
bash run_standalone.sh train_youtubednn.py
```
The information of current epoch and step will be displayed on screen while training, as well as the total cost time when traning finished.
Besides, predictions would be made in this demo once training is finished. The predictions (l3: third fc's output) with cypher text format would be save in `./mpc_data/` and the format of file name is similar to what is described in Step 1.
#### (3). Decrypt Data and Evaluate Hit Ratio
Decrypt the saved prediction data (video and user feature) and save the decrypted prediction results into a specified file using `decrypt_data_to_file()` in `process_data.py` script.
The similarity of all videos and users will be evaluate with the api `get_topK()` in script `get_topk.py`, then top-K videos will be chosen for each user with api `evaluate_hit_ratio()` in script `process_data.py`.
User can run the shell script `decrypt_and_evaluate.py` to decrypt data and evaluate hit ratio..
### 2. Running on Multi Machines
#### (1). Prepare Data
Data owner encrypts data. Concrete operations are consistent with “Prepare Data” in “Running on Single Machine”.
#### (2). Distribute Encrypted Data
According to the suffix of file name, distribute encrypted data files to `./mpc_data/ ` directories of all 3 computation parties. For example, send `*.part0` to `./mpc_data/` of party 0 with `scp` command.
#### (3). Modify train_youtubednn.py
Each computation party modifies `localhost` in the following code as the IP address of it's machine.
```python
pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
```
#### (4). Launch Demo on Each Party
**Note** that Redis service is necessary for demo running. Remember to clear the cache of Redis server before launching demo on each computation party, in order to avoid any negative influences caused by the cached records in Redis. The following command can be used for clear Redis, where REDIS_BIN is the executable binary of redis-cli, SERVER and PORT represent the IP and port of Redis server respectively.
```
$REDIS_BIN -h $SERVER -p $PORT flushall
```
Launch demo on each computation party with the following command,
```
$PYTHON_EXECUTABLE train_youtubednn.py --role $PARTY_ID --server $SERVER --port $PORT
```
where PYTHON_EXECUTABLE is the python which installs PaddleFL, PARTY_ID is the ID of computation party, which is 0, 1, or 2, SERVER and PORT represent the IP and port of Redis server respectively.
Similarly, video and user feature in cypher text format would be saved in `./mpc_data/` directory, for example, `video_vec.part0` and `user_vec.part0` will be saved for party 0.
#### (5). Decrypt Feature Data
Each computation party sends `video_vec.part*` and `user_vec.part*` file in `./mpc_data/` directory to the `./mpc_infer_data/` directory of data owner. Then, `decrypt and evaluate hit ratios` are are consistent with `Decrypt Data and Evaluate Hit Ratio` in `Running on Single Machine`.
## PaddleFL-MPC YoutubeDNN Demo运行说明
(简体中文|[English](./README.md))
本示例介绍基于PaddleFL-MPC进行YoutubeDNN模型训练和预测的使用说明,分为单机运行和多机运行两种方式。
### 一. 单机运行
#### 1. 准备数据
使用`process_data.py`脚本中的`gen_cypher_sample()`产生加密训练数据和测试数据,用户可以直接运行脚本`python process_data.py`在指定的目录下(比如`./mpc_data/`)产生加密特征和标签。在指定目录下生成对应于3个计算party的feature和label的加密数据文件,以后缀名区分属于不同party的数据。
#### 2. 使用shell脚本启动demo
运行demo之前,需设置以下环境变量:
```
export PYTHON=/yor/python
export PATH_TO_REDIS_BIN=/path/to/redis_bin
export LOCALHOST=/your/localhost
export REDIS_PORT=/your/redis/port
```
然后使用`run_standalone.sh`脚本,启动并运行demo,命令如下:
```bash 
bash run_standalone.sh train_youtubednn.py
```
运行之后将在屏幕上打印训练进度:当前epoch和step,以及当前训练耗时,并在完成训练后保存参数`l4_weight`作为电影特征。
此外,在完成训练之后,demo会继续进行预测,并将预测密文结果(第三个fc的输出`l3`)保存到./mpc_data/目录下的文件中,作为对应用户的特征。
#### 3. 解密特征数据并计算命中率
首先使用`process_data.py`脚本中的`decrypt_data_to_file()`,将保存的密文电影和用户特征进行解密,并且将解密得到的明文预测结果保存到指定文件中。
然后使用`get_topk.py`脚本中的`get_topK()`计算解密的用户特征和视频特征的相似度,排序选出推荐给用户的k个电影。然后在使用`process_data.py`脚本中的`evaluate_hit_ratio()`计算命中率。
解密数据和计算命中率,可参考脚本`decrypt_and_evaluate.py`
### 二. 多机运行
#### 1. 准备数据
数据方对数据进行加密处理。具体操作和单机运行中的准备数据步骤一致。
#### 2. 分发数据
按照后缀名,将步骤1中准备好的数据分别发送到对应的计算party的./mpc_data/目录下。比如,使用scp命令,将
`*.part0`发送到party0的./mpc_data/目录下。
#### 3. 修改各计算party的train_youtubednn.py脚本
各计算party根据自己的机器环境,将脚本如下内容中的`localhost`修改为自己的IP地址:
```python
pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
```
#### 4. 各计算party启动demo
**注意**:运行需要用到redis服务。为了确保redis中已保存的数据不会影响demo的运行,请在各计算party启动demo之前,使用如下命令清空redis。其中,REDIS_BIN表示redis-cli可执行程序,SERVER和PORT分别表示redis server的IP地址和端口号。
```
$REDIS_BIN -h $SERVER -p $PORT flushall
```
在各计算party分别执行以下命令,启动demo:
```
$PYTHON_EXECUTABLE train_youtubednn.py --role $PARTY_ID --server $SERVER --port $PORT
```
其中,PYTHON_EXECUTABLE表示自己安装了PaddleFL的python,PARTY_ID表示计算party的编号,值为0、1或2,SERVER和PORT分别表示redis server的IP地址和端口号。
同样地,密文电影特征和用户特征数据将会保存到`./mpc_data/`目录下的文件中。比如,在party0中将保存为文件`video_vec.part0``user_vec.part0`.
#### 5. 解密特征数据
各计算party将`./mpc_data/`目录下的`video.part*``user_vec.part*`文件发送到数据方的`./mpc_data/`目录下。然后按照`单机运行`中步骤3`解密特征数据并计算命中率`
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
import sys
import numpy as np
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--role', type=int, default=0, help='role')
parser.add_argument('--server', type=str, default='localhost', help='server ip')
parser.add_argument('--port', type=int, default=12345, help='server port')
parser.add_argument('--epochs', type=int, default=1, help='epochs')
parser.add_argument('--test_epoch', type=int, default=1, help='test_epoch')
parser.add_argument('--dataset_size', type=int, default=6040, help='dataset_size')
parser.add_argument('--batch_size', type=int, default=10, help='batch_size')
parser.add_argument('--batch_num', type=int, default=600, help='batch_num')
parser.add_argument('--use_gpu', type=int, default=0, help='whether using gpu')
parser.add_argument('--mpc_data_dir', type=str, default='./mpc_data/', help='mpc_data_dir')
parser.add_argument('--model_dir', type=str, default='./model_dir/', help='model_dir')
parser.add_argument('--watch_vec_size', type=int, default=64, help='watch_vec_size')
parser.add_argument('--search_vec_size', type=int, default=64, help='search_vec_size')
parser.add_argument('--other_feat_size', type=int, default=32, help='other_feat_size')
parser.add_argument('--output_size', type=int, default=3952, help='output_size')
parser.add_argument('--base_lr', type=float, default=0.01, help='base_lr')
parser.add_argument('--topk', type=int, default=10, help='topk')
args = parser.parse_args()
return args
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
decrypt video and user feature and evaluate hit ratio
"""
import os
import args
import get_topk
import process_data
if __name__ == '__main__':
args = args.parse_args()
# decrypt video and user feature
mpc_data_dir = args.mpc_data_dir
user_vec_filepath = mpc_data_dir + 'user_vec'
plain_user_vec_filepath = user_vec_filepath + '.csv'
if os.path.exists(plain_user_vec_filepath):
os.system('rm -rf ' + plain_user_vec_filepath)
process_data.decrypt_data_to_file(user_vec_filepath, plain_user_vec_filepath, (args.batch_size, 32))
video_vec_filepath = mpc_data_dir + 'video_vec'
plain_video_vec_filepath = video_vec_filepath + '.csv'
if os.path.exists(plain_video_vec_filepath):
os.system('rm -rf ' + plain_video_vec_filepath)
process_data.decrypt_data_to_file(video_vec_filepath, plain_video_vec_filepath, (32, args.output_size))
# compute similarity between users and videos
# compute top k videos for each user
mpc_data_dir = args.mpc_data_dir
label_mpc_filepath = mpc_data_dir +'label_mpc'
label_actual_filepath = mpc_data_dir +'label_actual'
if os.path.exists(label_mpc_filepath):
os.system('rm -rf ' + label_mpc_filepath)
get_topk.get_topK(args, args.topk, plain_video_vec_filepath, plain_user_vec_filepath, label_actual_filepath, label_mpc_filepath)
# evaluate hit ratio
process_data.evaluate_hit_ratio(label_actual_filepath, label_mpc_filepath)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Compute the similarity of videos and users, get topK video for each user
"""
import numpy as np
import pandas as pd
import copy
import os
import logging
import args
import process_data
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def cos_sim(vector_a, vector_b):
vector_a = np.mat(vector_a)
vector_b = np.mat(vector_b)
num = float(vector_a * vector_b.T)
denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b)
cos = num / denom
sim = 0.5 + 0.5 * cos
return sim
def get_topK(args, K, video_vec_path, user_vec_path, label_actual_filepath, label_paddle_filepath):
video_vec = pd.read_csv(video_vec_path, header=None)
user_vec = pd.read_csv(user_vec_path, header=None)
if (os.path.exists(label_paddle_filepath)):
os.system("rm -rf " + label_paddle_filepath)
for i in range(user_vec.shape[0]):
user_video_sim_list = []
for j in range(video_vec.shape[1]):
user_video_sim = cos_sim(np.array(user_vec.loc[i]), np.array(video_vec[j]))
user_video_sim_list.append(user_video_sim)
tmp_list=copy.deepcopy(user_video_sim_list)
tmp_list.sort()
max_sim_index=[[user_video_sim_list.index(one) for one in tmp_list[::-1][:K]]]
max_sim_index_vec = pd.DataFrame(max_sim_index)
max_sim_index_vec.to_csv(label_paddle_filepath, mode="a", index=False, header=0)
# for debug
process_data.evaluate_hit_ratio(label_actual_filepath, label_paddle_filepath)
if __name__ == '__main__':
args = args.parse_args()
data_dir = './paddle_data/'
get_topK(args, args.topk, data_dir + 'video_vec.csv', data_dir + 'user_vec.csv', data_dir + 'label_paddle')
import paddle
import io
import math
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
class YoutubeDNN(object):
def input_data(self, batch_size, watch_vec_size, search_vec_size, other_feat_size):
watch_vec = pfl_mpc.data(name='watch_vec', shape=[batch_size, watch_vec_size], dtype='int64')
search_vec = pfl_mpc.data(name='search_vec', shape=[batch_size, search_vec_size], dtype='int64')
other_feat = pfl_mpc.data(name='other_feat', shape=[batch_size, other_feat_size], dtype='int64')
label = pfl_mpc.data(name='label', shape=[batch_size, 3952], dtype='int64')
inputs = [watch_vec] + [search_vec] + [other_feat] + [label]
return inputs
def fc(self, tag, data, out_dim, active='relu'):
init_stddev = 1.0
scales = 1.0 / np.sqrt(data.shape[2])
mpc_one = 65536 / 3
rng = np.random.RandomState(23355)
param_shape = (1, data.shape[2], out_dim) # 256, 2304
if tag == 'l4':
param_value_float = rng.normal(loc=0.0, scale=init_stddev * scales, size=param_shape)
param_value_float_expand = np.concatenate((param_value_float, param_value_float), axis=0)
param_value = (param_value_float_expand * mpc_one).astype('int64')
initializer_l4 = pfl_mpc.initializer.NumpyArrayInitializer(param_value)
p_attr = fluid.param_attr.ParamAttr(name='%s_weight' % tag,
initializer=initializer_l4)
active = None
else:
"""
param_init=pfl_mpc.initializer.XavierInitializer(seed=23355)
p_attr = fluid.param_attr.ParamAttr(initializer=param_init)
"""
"""
param_init=pfl_mpc.initializer.XavierInitializer(uniform=False, seed=23355)
p_attr = fluid.param_attr.ParamAttr(initializer=param_init)
"""
fan_in = param_shape[1]
fan_out = param_shape[2]
scale = math.sqrt(6.0 / (fan_in + fan_out))
param_value_float = rng.normal(-1.0 * scale, 1.0 * scale, size=param_shape)
param_value_float_expand = np.concatenate((param_value_float, param_value_float), axis=0)
param_value = (param_value_float_expand * mpc_one * scale).astype('int64')
initializer_l4 = pfl_mpc.initializer.NumpyArrayInitializer(param_value)
p_attr = fluid.param_attr.ParamAttr(name='%s_weight' % tag,
initializer=initializer_l4)
b_attr = fluid.ParamAttr(name='%s_bias' % tag,
initializer=fluid.initializer.ConstantInitializer(int(0.1 * mpc_one)))
out = pfl_mpc.layers.fc(input=data,
size=out_dim,
act=active,
param_attr=p_attr,
bias_attr=b_attr,
name=tag)
return out
def net(self, inputs, output_size, layers=[32, 32, 32]):
concat_feats = fluid.layers.concat(input=inputs[:-1], axis=-1)
l1 = self.fc('l1', concat_feats, layers[0], 'relu')
l2 = self.fc('l2', l1, layers[1], 'relu')
l3 = self.fc('l3', l2, layers[2], 'relu')
l4 = self.fc('l4', l3, output_size, None)
cost, softmax = pfl_mpc.layers.softmax_with_cross_entropy(logits=l4,
label=inputs[-1],
soft_label=True,
use_relu=True,
use_long_div=False,
return_softmax=True)
avg_cost = pfl_mpc.layers.mean(cost)
return avg_cost, l3
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Prepare movielens dataset for YoutubeDNN.
"""
import numpy as np
import paddle
import os
import time
import six
import pandas as pd
import logging
from paddle_fl.mpc.data_utils import aby3
import args
import get_topk
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('fluid')
logger.setLevel(logging.INFO)
args = args.parse_args()
watch_vec_size = args.watch_vec_size
search_vec_size = args.search_vec_size
other_feat_size = args.other_feat_size
dataset_size = args.dataset_size
batch_size = args.batch_size
sample_size = args.batch_num
output_size = args.output_size # max movie id
def prepare_movielens_data(sample_size, batch_size, watch_vec_size, search_vec_size,
other_feat_size, dataset_size, label_actual_filepath):
"""
prepare movielens data
"""
watch_vecs = []
search_vecs = []
other_feats = []
labels = []
# prepare movielens data
movie_info = paddle.dataset.movielens.movie_info()
user_info = paddle.dataset.movielens.user_info()
max_user_id = paddle.dataset.movielens.max_user_id()
user_watch = np.zeros((max_user_id, watch_vec_size))
user_search = np.zeros((max_user_id, search_vec_size))
user_feat = np.zeros((max_user_id, other_feat_size))
user_labels = np.zeros((max_user_id, 1))
MOVIE_EMBED_TAB_HEIGHT = paddle.dataset.movielens.max_movie_id()
MOVIE_EMBED_TAB_WIDTH = watch_vec_size
JOB_EMBED_TAB_HEIGHT = paddle.dataset.movielens.max_job_id() + 1
JOB_EMBED_TAB_WIDTH = paddle.dataset.movielens.max_job_id() + 1
AGE_EMBED_TAB_HEIGHT = len(paddle.dataset.movielens.age_table)
AGE_EMBED_TAB_WIDTH = len(paddle.dataset.movielens.age_table)
GENDER_EMBED_TAB_HEIGHT = 2
GENDER_EMBED_TAB_WIDTH = 4
np.random.seed(1)
MOVIE_EMBED_TAB = np.zeros((MOVIE_EMBED_TAB_HEIGHT, MOVIE_EMBED_TAB_WIDTH))
AGE_EMBED_TAB = np.zeros((AGE_EMBED_TAB_HEIGHT, AGE_EMBED_TAB_WIDTH))
GENDER_EMBED_TAB = np.zeros((GENDER_EMBED_TAB_HEIGHT, GENDER_EMBED_TAB_WIDTH))
JOB_EMBED_TAB = np.zeros((JOB_EMBED_TAB_HEIGHT, JOB_EMBED_TAB_WIDTH))
for i in range(MOVIE_EMBED_TAB_HEIGHT):
MOVIE_EMBED_TAB[i][hash(i) % MOVIE_EMBED_TAB_WIDTH] = 1
MOVIE_EMBED_TAB[i][hash(hash(i)) % MOVIE_EMBED_TAB_WIDTH] = 1
for i in range(AGE_EMBED_TAB_HEIGHT):
AGE_EMBED_TAB[i][i] = 1
for i in range(GENDER_EMBED_TAB_HEIGHT):
GENDER_EMBED_TAB[i][i] = 1
for i in range(JOB_EMBED_TAB_HEIGHT):
JOB_EMBED_TAB[i][i] = 1
train_set_creator = paddle.dataset.movielens.train()
pre_uid = 0
movie_count = 0
user_watched_movies = [[] for i in range(dataset_size)]
for instance in train_set_creator():
uid = int(instance[0]) - 1
gender_id = int(instance[1])
age_id = int(instance[2])
job_id = int(instance[3])
mov_id = int(instance[4]) - 1
user_watched_movies[uid].append(mov_id)
user_watch[uid, :] += MOVIE_EMBED_TAB[mov_id, :]
user_labels[uid, :] = mov_id
user_feat[uid, :] = np.concatenate((JOB_EMBED_TAB[job_id, :],
GENDER_EMBED_TAB[gender_id, :],
AGE_EMBED_TAB[age_id, :]))
if uid == pre_uid:
movie_count += 1
else:
user_watch[pre_uid, :] = user_watch[pre_uid, :] / movie_count
movie_count = 1
pre_uid = uid
user_watch[pre_uid, :] = user_watch[pre_uid, :] / movie_count
user_search = user_watch
if (os.path.exists(label_actual_filepath)):
os.system('rm -rf' + label_actual_filepath)
user_watched_movies_vec = pd.DataFrame(user_watched_movies)
user_watched_movies_vec.to_csv(label_actual_filepath, mode='a', index=False, header=0)
return user_watch, user_search, user_feat, user_labels
def gen_cypher_sample(mpc_data_dir, sample_size, batch_size, output_size):
"""
prepare movielens data and encrypt
"""
logger.info('Prepare data...')
if not os.path.exists(mpc_data_dir):
os.makedirs(mpc_data_dir)
else:
os.system('rm -rf ' + mpc_data_dir + '*')
label_actual_filepath = mpc_data_dir + 'label_actual'
user_watch, user_search, user_feat, user_labels = prepare_movielens_data(
sample_size, batch_size, watch_vec_size, search_vec_size, other_feat_size, dataset_size, label_actual_filepath)
#watch_vecs = []
#search_vecs = []
#other_feat_vecs = []
#label_vecss = []
for i in range(sample_size):
watch_vec = user_watch[i * batch_size : (i + 1) * batch_size, :]
search_vec = user_search[i * batch_size : (i + 1) * batch_size, :]
other_feat_vec = user_feat[i * batch_size : (i + 1) * batch_size, :]
save_cypher(cypher_file=mpc_data_dir + 'watch_vec', vec=watch_vec)
save_cypher(cypher_file=mpc_data_dir + 'search_vec', vec=search_vec)
save_cypher(cypher_file=mpc_data_dir + 'other_feat', vec=other_feat_vec)
#watch_vecs.append(watch_vec)
#search_vecs.append(search_vec)
#other_feat_vecs.append(other_feat_vec)
label = np.zeros((batch_size, output_size))
for j in range(batch_size):
label[j, int(user_labels[j][0])] = 1
save_cypher(cypher_file=mpc_data_dir + 'label', vec=label)
#label_vecs.append(label)
#return [watch_vecs, search_vecs, other_feat_vecs, label_vecs]
def save_cypher(cypher_file, vec):
"""
save cypertext to file
"""
shares = aby3.make_shares(vec)
exts = ['.part0', '.part1', '.part2']
with open(cypher_file + exts[0], 'ab') as file0, \
open(cypher_file + exts[1], 'ab') as file1, \
open(cypher_file + exts[2], 'ab') as file2:
files = [file0, file1, file2]
for idx in six.moves.range(0, 3): # 3 parts
share = aby3.get_aby3_shares(shares, idx)
files[idx].write(share.tostring())
def load_decrypt_data(filepath, shape):
"""
load the encrypted data and reconstruct
"""
part_readers = []
for id in six.moves.range(3):
part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape))
aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2])
for instance in aby3_share_reader():
p = aby3.reconstruct(np.array(instance))
print(p)
def decrypt_data_to_file(cypher_filepath, plaintext_filepath, shape):
"""
Load the encrypted data and reconstruct.
"""
part_readers = []
for id in six.moves.range(3):
part_readers.append(
aby3.load_aby3_shares(
cypher_filepath, id=id, shape=shape))
aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1],
part_readers[2])
for instance in aby3_share_reader():
p = aby3.reconstruct(np.array(instance))
tmp = pd.DataFrame(p)
tmp.to_csv(plaintext_filepath, mode='a', index=False, header=0)
def evaluate_hit_ratio(file1, file2):
count = 0
same_count = 0
f1 = open(file1, 'r')
f2 = open(file2, 'r')
while 1:
line1 = f1.readline()
line2 = f2.readline()
if (not line1) or (not line2):
break
count += 1
set1 = set([int(float(x if x != '' and x != '\n' else 10000)) for x in line1.split(',')])
set2 = set([int(x) for x in line2.split(',')])
if len(set1.intersection(set2)) != 0:
same_count += 1
logger.info(float(same_count)/count)
if __name__ == '__main__':
gen_cypher_sample('./mpc_data/', sample_size, batch_size, output_size)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MPC YoutubeDNN Demo
"""
import numpy as np
import pandas as pd
import os
import random
import time
import logging
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle_fl.mpc as pfl_mpc
from paddle_fl.mpc.data_utils import aby3
import args
import mpc_network
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger('fluid')
logger.setLevel(logging.INFO)
def read_share(file, shape):
"""
prepare share reader
"""
ext = '.part{}'.format(args.role)
shape = (2, ) + shape
share_size = np.prod(shape) * 8 # size of int64 in bytes
def reader():
with open(file + ext, 'rb') as part_file:
share = part_file.read(share_size)
while share:
yield np.frombuffer(share, dtype=np.int64).reshape(shape)
share = part_file.read(share_size)
return reader
def train(args):
"""
train
"""
# ********************
# prepare network
pfl_mpc.init('aby3', int(args.role), 'localhost', args.server, int(args.port))
youtube_model = mpc_network.YoutubeDNN()
inputs = youtube_model.input_data(args.batch_size,
args.watch_vec_size,
args.search_vec_size,
args.other_feat_size)
loss, l3 = youtube_model.net(inputs, args.output_size, layers=[128, 64, 32])
#boundaries = [200, 500, 800, 1000]
#values = [0.05, 0.02, 0.01, 0.005, 0.001]
#lr = fluid.layers.piecewise_decay(boundaries, values)
lr = args.base_lr
sgd = pfl_mpc.optimizer.SGD(learning_rate=lr)
sgd.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# ********************
# prepare data
logger.info('Prepare data...')
mpc_data_dir = args.mpc_data_dir
if not os.path.exists(mpc_data_dir):
raise ValueError('mpc_data_dir is not found. Please prepare encrypted data.')
video_vec_filepath = mpc_data_dir + 'video_vec'
video_vec_part_filepath = video_vec_filepath + '.part{}'.format(args.role)
user_vec_filepath = mpc_data_dir + 'user_vec.csv'
user_vec_part_filepath = user_vec_filepath + '.part{}'.format(args.role)
watch_vecs = []
search_vecs = []
other_feats = []
labels = []
watch_vec_reader = read_share(file=mpc_data_dir + 'watch_vec', shape=(args.batch_size, args.watch_vec_size))
for vec in watch_vec_reader():
watch_vecs.append(vec)
search_vec_reader = read_share(file=mpc_data_dir + 'search_vec', shape=(args.batch_size, args.search_vec_size))
for vec in search_vec_reader():
search_vecs.append(vec)
other_feat_reader = read_share(file=mpc_data_dir + 'other_feat', shape=(args.batch_size, args.other_feat_size))
for vec in other_feat_reader():
other_feats.append(vec)
label_reader = read_share(file=mpc_data_dir + 'label', shape=(args.batch_size, args.output_size))
for vec in label_reader():
labels.append(vec)
# ********************
# train
logger.info('Start training...')
begin = time.time()
for epoch in range(args.epochs):
for i in range(args.batch_num):
loss_data = exe.run(fluid.default_main_program(),
feed={'watch_vec': watch_vecs[i],
'search_vec': search_vecs[i],
'other_feat': other_feats[i],
'label': np.array(labels[i])
},
return_numpy=True,
fetch_list=[loss.name])
if i % 100 == 0:
end = time.time()
logger.info('Paddle training of epoch_id: {}, batch_id: {}, batch_time: {:.5f}s'
.format(epoch, i, end-begin))
# save model
logger.info('save mpc model...')
cur_model_dir = os.path.join(args.model_dir, 'mpc_model', 'epoch_' + str(epoch + 1),
'checkpoint', 'party_{}'.format(args.role))
feed_var_names = ['watch_vec', 'search_vec', 'other_feat']
fetch_vars = [l3]
fluid.io.save_inference_model(cur_model_dir, feed_var_names, fetch_vars, exe)
# save all video vector
video_array = np.array(fluid.global_scope().find_var('l4_weight').get_tensor())
if os.path.exists(video_vec_part_filepath):
os.system('rm -rf ' + video_vec_part_filepath)
with open(video_vec_part_filepath, 'w') as f:
f.write(np.array(video_array).tostring())
end = time.time()
logger.info('MPC training of epoch_num: {}, batch_time: {:.5f}s'
.format(args.epochs, end-begin))
def infer(args):
"""
infer
"""
logger.info('Start inferring...')
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
cur_model_path = os.path.join(args.model_dir, 'mpc_model', 'epoch_' + str(args.test_epoch),
'checkpoint', 'party_{}'.format(args.role))
with fluid.scope_guard(fluid.Scope()):
pfl_mpc.init('aby3', args.role, 'localhost', args.server, args.port)
infer_program, feed_target_names, fetch_vars = aby3.load_mpc_model(exe=exe,
mpc_model_dir=cur_model_path,
mpc_model_filename='__model__',
inference=True)
mpc_data_dir = args.mpc_data_dir
user_vec_filepath = mpc_data_dir + 'user_vec'
user_vec_part_filepath = user_vec_filepath + '.part{}'.format(args.role)
sample_batch = args.batch_size
watch_vecs = []
search_vecs = []
other_feats = []
watch_vec_reader = read_share(file=mpc_data_dir + 'watch_vec', shape=(sample_batch, args.watch_vec_size))
for vec in watch_vec_reader():
watch_vecs.append(vec)
search_vec_reader = read_share(file=mpc_data_dir + 'search_vec', shape=(sample_batch, args.search_vec_size))
for vec in search_vec_reader():
search_vecs.append(vec)
other_feat_reader = read_share(file=mpc_data_dir + 'other_feat', shape=(sample_batch, args.other_feat_size))
for vec in other_feat_reader():
other_feats.append(vec)
if os.path.exists(user_vec_part_filepath):
os.system('rm -rf ' + user_vec_part_filepath)
for i in range(args.batch_num):
l3 = exe.run(infer_program,
feed={
'watch_vec': watch_vecs[i],
'search_vec': search_vecs[i],
'other_feat': other_feats[i],
},
return_numpy=True,
fetch_list=fetch_vars)
with open(user_vec_part_filepath, 'a+') as f:
f.write(np.array(l3[0]).tostring())
if __name__ == '__main__':
args = args.parse_args()
logger.info(
'use_gpu: {}, batch_size: {}, epochs: {}, watch_vec_size: {}, search_vec_size: {}, ' +
'other_feat_size: {}, output_size: {}, model_dir: {}, test_epoch: {}, base_lr: {}'.format(
args.use_gpu, args.batch_size, args.epochs, args.watch_vec_size, args.search_vec_size, args.other_feat_size,
args.output_size, args.model_dir, args.test_epoch, args.base_lr))
train(args)
infer(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册