From 15c00b9f827d8712a3d45049a5af05c45cc187f4 Mon Sep 17 00:00:00 2001
From: jhjiangcs
Date: Tue, 1 Sep 2020 15:47:08 +0000
Subject: [PATCH] add youtubednn demo.

---
 .../youtubednn_with_movielens/README.md       |  81 ++++++
 .../youtubednn_with_movielens/README_CN.md    |  83 ++++++
 .../youtubednn_with_movielens/args.py         |  46 ++++
 .../decrypt_and_evaluate.py                   |  49 ++++
 .../youtubednn_with_movielens/get_topk.py     |  67 +++++
 .../youtubednn_with_movielens/mpc_network.py  |  82 ++++++
 .../youtubednn_with_movielens/process_data.py | 239 ++++++++++++++++++
 .../train_youtubednn.py                       | 212 ++++++++++++++++
 8 files changed, 859 insertions(+)
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/README.md
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/README_CN.md
 create mode 100755 python/paddle_fl/mpc/examples/youtubednn_with_movielens/args.py
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/decrypt_and_evaluate.py
 create mode 100755 python/paddle_fl/mpc/examples/youtubednn_with_movielens/get_topk.py
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/mpc_network.py
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/process_data.py
 create mode 100644 python/paddle_fl/mpc/examples/youtubednn_with_movielens/train_youtubednn.py

diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README.md b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README.md
new file mode 100644
index 0000000..5083d12
--- /dev/null
+++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README.md
@@ -0,0 +1,81 @@
+## Instructions for PaddleFL-MPC YoutubeDNN Demo
+
+([简体中文](./README_CN.md)|English)
+
+This document introduces how to run the YoutubeDNN demo based on Paddle-MPC. The demo can be run in two ways: on a single machine or on multiple machines.
+
+### 1. Running on Single Machine
+
+#### (1). Prepare Data
+
+Generate encrypted training and testing data with `gen_cypher_sample()` in the `process_data.py` script. Run `python process_data.py` to generate the encrypted features and labels in the given directory, e.g., `./mpc_data/`. Different file name suffixes indicate which computation party each file belongs to.
+
+#### (2). Launch Demo with A Shell Script
+
+You should set the environment variables as follows:
+
+```
+export PYTHON=/your/python
+export PATH_TO_REDIS_BIN=/path/to/redis_bin
+export LOCALHOST=/your/localhost
+export REDIS_PORT=/your/redis/port
+```
+
+Launch the demo with the `run_standalone.sh` script. The concrete command is:
+
+```bash
+bash run_standalone.sh train_youtubednn.py
+```
+
+The current epoch and step are displayed on screen while training, as well as the total cost time when training finishes.
+
+Besides, predictions are made in this demo once training is finished. The predictions (`l3`, the output of the third fc layer) in cypher text format are saved in `./mpc_data/`, and the file name format is the same as described in Step 1.
+
+#### (3). Decrypt Data and Evaluate Hit Ratio
+
+Decrypt the saved prediction data (video and user features) and save the decrypted results into a specified file using `decrypt_data_to_file()` in the `process_data.py` script.
+
+The similarity between all videos and users is then evaluated with `get_topK()` in the `get_topk.py` script, which chooses the top-K videos for each user; finally the hit ratio is computed with `evaluate_hit_ratio()` in the `process_data.py` script.
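+
+A minimal sketch of this decrypt-and-evaluate flow is shown below. It mirrors the bundled `decrypt_and_evaluate.py` (which additionally removes stale output files first); the `32` is the dimension of the last hidden layer used as the user/video feature size, and the paths assume the default `./mpc_data/` directory:
+
+```python
+import args
+import get_topk
+import process_data
+
+conf = args.parse_args()
+# reconstruct plain text features from the three cypher text shares
+process_data.decrypt_data_to_file('./mpc_data/user_vec', './mpc_data/user_vec.csv', (conf.batch_size, 32))
+process_data.decrypt_data_to_file('./mpc_data/video_vec', './mpc_data/video_vec.csv', (32, conf.output_size))
+# rank videos by cosine similarity and keep the top-K for each user;
+# './mpc_data/label_actual' was written during data preparation
+get_topk.get_topK(conf, conf.topk, './mpc_data/video_vec.csv', './mpc_data/user_vec.csv',
+                  './mpc_data/label_actual', './mpc_data/label_mpc')
+# compare the recommended videos with the ones each user actually watched
+process_data.evaluate_hit_ratio('./mpc_data/label_actual', './mpc_data/label_mpc')
+```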
+
+You can simply run `python decrypt_and_evaluate.py` to perform both the decryption and the hit-ratio evaluation.
+
+### 2. Running on Multi Machines
+
+#### (1). Prepare Data
+
+The data owner encrypts the data. The concrete operations are consistent with "Prepare Data" in "Running on Single Machine".
+
+#### (2). Distribute Encrypted Data
+
+According to the suffix of the file names, distribute the encrypted data files to the `./mpc_data/` directories of all 3 computation parties. For example, send `*.part0` to `./mpc_data/` of party 0 with the `scp` command.
+
+#### (3). Modify train_youtubednn.py
+
+Each computation party modifies `localhost` in the following code to the IP address of its own machine.
+
+```python
+pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
+```
+
+#### (4). Launch Demo on Each Party
+
+**Note** that the Redis service is necessary for running the demo. Remember to clear the cache of the Redis server before launching the demo on each computation party, in order to avoid any negative influence caused by records cached in Redis. The following command can be used to clear Redis, where REDIS_BIN is the executable binary of redis-cli, and SERVER and PORT represent the IP and port of the Redis server respectively.
+
+```
+$REDIS_BIN -h $SERVER -p $PORT flushall
+```
+
+Launch the demo on each computation party with the following command,
+
+```
+$PYTHON_EXECUTABLE train_youtubednn.py --role $PARTY_ID --server $SERVER --port $PORT
+```
+
+where PYTHON_EXECUTABLE is the python in which PaddleFL is installed, PARTY_ID is the ID of the computation party (0, 1, or 2), and SERVER and PORT represent the IP and port of the Redis server respectively.
+
+Similarly, the video and user features in cypher text format are saved in the `./mpc_data/` directory; for example, `video_vec.part0` and `user_vec.part0` are saved for party 0.
+
+#### (5). Decrypt Feature Data
+
+Each computation party sends its `video_vec.part*` and `user_vec.part*` files in the `./mpc_data/` directory to the `./mpc_data/` directory of the data owner. Then the data owner decrypts them and evaluates the hit ratio as described in "Decrypt Data and Evaluate Hit Ratio" in "Running on Single Machine".
+
diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README_CN.md b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README_CN.md
new file mode 100644
index 0000000..a755ae6
--- /dev/null
+++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/README_CN.md
@@ -0,0 +1,83 @@
+## PaddleFL-MPC YoutubeDNN Demo运行说明
+
+(简体中文|[English](./README.md))
+
+本示例介绍基于PaddleFL-MPC进行YoutubeDNN模型训练和预测的使用说明，分为单机运行和多机运行两种方式。
+
+### 一. 单机运行
+
+#### 1. 准备数据
+
+使用`process_data.py`脚本中的`gen_cypher_sample()`产生加密训练数据和测试数据：直接运行`python process_data.py`，即可在指定目录（比如`./mpc_data/`）下生成对应于3个计算party的feature和label加密数据文件，以后缀名区分属于不同party的数据。
+
+#### 2. 使用shell脚本启动demo
+
+运行demo之前，需设置以下环境变量：
+
+```
+export PYTHON=/your/python
+export PATH_TO_REDIS_BIN=/path/to/redis_bin
+export LOCALHOST=/your/localhost
+export REDIS_PORT=/your/redis/port
+```
+
+然后使用`run_standalone.sh`脚本启动并运行demo，命令如下：
+
+```bash
+bash run_standalone.sh train_youtubednn.py
+```
+
+运行之后将在屏幕上打印训练进度（当前epoch和step）以及当前训练耗时，并在完成训练后保存参数`l4_weight`作为电影特征。
+
+此外，在完成训练之后，demo会继续进行预测，并将预测密文结果（第三个fc的输出`l3`）保存到`./mpc_data/`目录下的文件中，作为对应用户的特征。
+
+#### 3. 解密特征数据并计算命中率
+
+首先使用`process_data.py`脚本中的`decrypt_data_to_file()`，将保存的密文电影特征和用户特征进行解密，并将解密得到的明文结果保存到指定文件中。
+
+然后使用`get_topk.py`脚本中的`get_topK()`计算解密后的用户特征和视频特征的相似度，排序选出推荐给每个用户的k个电影，再使用`process_data.py`脚本中的`evaluate_hit_ratio()`计算命中率。
+
+解密数据和计算命中率的完整流程可参考脚本`decrypt_and_evaluate.py`。
+
+
+### 二. 多机运行
+
+#### 1. 准备数据
+
+数据方对数据进行加密处理。具体操作和单机运行中的准备数据步骤一致。
+
+#### 2. 分发数据
+
+按照后缀名，将步骤1中准备好的数据分别发送到对应计算party的`./mpc_data/`目录下。比如，使用scp命令将`*.part0`发送到party0的`./mpc_data/`目录下。
+
+#### 3. 修改各计算party的train_youtubednn.py脚本
+
+各计算party根据自己的机器环境，将脚本如下内容中的`localhost`修改为自己的IP地址：
+
+```python
+pfl_mpc.init("aby3", int(role), "localhost", server, int(port))
+```
+
+#### 4. 各计算party启动demo
+
+**注意**：运行需要用到redis服务。为了确保redis中已保存的数据不会影响demo的运行，请在各计算party启动demo之前，使用如下命令清空redis。其中，REDIS_BIN表示redis-cli可执行程序，SERVER和PORT分别表示redis server的IP地址和端口号。
+
+```
+$REDIS_BIN -h $SERVER -p $PORT flushall
+```
+
+在各计算party分别执行以下命令，启动demo：
+
+```
+$PYTHON_EXECUTABLE train_youtubednn.py --role $PARTY_ID --server $SERVER --port $PORT
+```
+
+其中，PYTHON_EXECUTABLE表示安装了PaddleFL的python，PARTY_ID表示计算party的编号，值为0、1或2，SERVER和PORT分别表示redis server的IP地址和端口号。
+
+同样地，密文电影特征和用户特征数据将会保存到`./mpc_data/`目录下的文件中。比如，party0将保存文件`video_vec.part0`和`user_vec.part0`。
+
+#### 5. 解密特征数据
+
+各计算party将`./mpc_data/`目录下的`video_vec.part*`和`user_vec.part*`文件发送到数据方的`./mpc_data/`目录下，然后按照`单机运行`中步骤3`解密特征数据并计算命中率`进行解密和评估。
diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/args.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/args.py
new file mode 100755
index 0000000..964ce85
--- /dev/null
+++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/args.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
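+"""
+Command line arguments for the MPC YoutubeDNN demo.
+"""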
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import distutils.util +import sys +import numpy as np + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument('--role', type=int, default=0, help='role') + parser.add_argument('--server', type=str, default='localhost', help='server ip') + parser.add_argument('--port', type=int, default=12345, help='server port') + + parser.add_argument('--epochs', type=int, default=1, help='epochs') + parser.add_argument('--test_epoch', type=int, default=1, help='test_epoch') + parser.add_argument('--dataset_size', type=int, default=6040, help='dataset_size') + parser.add_argument('--batch_size', type=int, default=10, help='batch_size') + parser.add_argument('--batch_num', type=int, default=600, help='batch_num') + parser.add_argument('--use_gpu', type=int, default=0, help='whether using gpu') + parser.add_argument('--mpc_data_dir', type=str, default='./mpc_data/', help='mpc_data_dir') + parser.add_argument('--model_dir', type=str, default='./model_dir/', help='model_dir') + parser.add_argument('--watch_vec_size', type=int, default=64, help='watch_vec_size') + parser.add_argument('--search_vec_size', type=int, default=64, help='search_vec_size') + parser.add_argument('--other_feat_size', type=int, default=32, help='other_feat_size') + parser.add_argument('--output_size', type=int, default=3952, help='output_size') + parser.add_argument('--base_lr', type=float, default=0.01, help='base_lr') + parser.add_argument('--topk', type=int, default=10, help='topk') + + args = parser.parse_args() + return args diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/decrypt_and_evaluate.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/decrypt_and_evaluate.py new file mode 100644 index 0000000..a2d64c6 --- /dev/null +++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/decrypt_and_evaluate.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +decrypt video and user feature and evaluate hit ratio +""" +import os +import args +import get_topk +import process_data + +if __name__ == '__main__': + args = args.parse_args() + + # decrypt video and user feature + mpc_data_dir = args.mpc_data_dir + user_vec_filepath = mpc_data_dir + 'user_vec' + plain_user_vec_filepath = user_vec_filepath + '.csv' + if os.path.exists(plain_user_vec_filepath): + os.system('rm -rf ' + plain_user_vec_filepath) + process_data.decrypt_data_to_file(user_vec_filepath, plain_user_vec_filepath, (args.batch_size, 32)) + + video_vec_filepath = mpc_data_dir + 'video_vec' + plain_video_vec_filepath = video_vec_filepath + '.csv' + if os.path.exists(plain_video_vec_filepath): + os.system('rm -rf ' + plain_video_vec_filepath) + process_data.decrypt_data_to_file(video_vec_filepath, plain_video_vec_filepath, (32, args.output_size)) + + # compute similarity between users and videos + # compute top k videos for each user + mpc_data_dir = args.mpc_data_dir + label_mpc_filepath = mpc_data_dir +'label_mpc' + label_actual_filepath = mpc_data_dir +'label_actual' + if os.path.exists(label_mpc_filepath): + os.system('rm -rf ' + label_mpc_filepath) + get_topk.get_topK(args, args.topk, plain_video_vec_filepath, plain_user_vec_filepath, label_actual_filepath, label_mpc_filepath) + + # evaluate hit ratio + process_data.evaluate_hit_ratio(label_actual_filepath, label_mpc_filepath) diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/get_topk.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/get_topk.py new file mode 100755 index 0000000..5d5cbea --- /dev/null +++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/get_topk.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Compute the similarity of videos and users, get topK video for each user +""" + +import numpy as np +import pandas as pd +import copy +import os +import logging +import args +import process_data + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger("fluid") +logger.setLevel(logging.INFO) + + +def cos_sim(vector_a, vector_b): + vector_a = np.mat(vector_a) + vector_b = np.mat(vector_b) + num = float(vector_a * vector_b.T) + denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b) + cos = num / denom + sim = 0.5 + 0.5 * cos + return sim + + +def get_topK(args, K, video_vec_path, user_vec_path, label_actual_filepath, label_paddle_filepath): + video_vec = pd.read_csv(video_vec_path, header=None) + user_vec = pd.read_csv(user_vec_path, header=None) + + if (os.path.exists(label_paddle_filepath)): + os.system("rm -rf " + label_paddle_filepath) + + for i in range(user_vec.shape[0]): + user_video_sim_list = [] + for j in range(video_vec.shape[1]): + user_video_sim = cos_sim(np.array(user_vec.loc[i]), np.array(video_vec[j])) + user_video_sim_list.append(user_video_sim) + tmp_list=copy.deepcopy(user_video_sim_list) + tmp_list.sort() + max_sim_index=[[user_video_sim_list.index(one) for one in tmp_list[::-1][:K]]] + + max_sim_index_vec = pd.DataFrame(max_sim_index) + max_sim_index_vec.to_csv(label_paddle_filepath, mode="a", index=False, header=0) + + # for debug + process_data.evaluate_hit_ratio(label_actual_filepath, label_paddle_filepath) + + +if __name__ == '__main__': + args = args.parse_args() + data_dir = './paddle_data/' + get_topK(args, args.topk, data_dir + 'video_vec.csv', data_dir + 'user_vec.csv', data_dir + 'label_paddle') diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/mpc_network.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/mpc_network.py new file mode 100644 index 0000000..793b442 --- /dev/null +++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/mpc_network.py @@ -0,0 +1,82 @@ + + +import paddle +import io +import math +import numpy as np +import paddle.fluid as fluid +import paddle_fl.mpc as pfl_mpc + +class YoutubeDNN(object): + def input_data(self, batch_size, watch_vec_size, search_vec_size, other_feat_size): + watch_vec = pfl_mpc.data(name='watch_vec', shape=[batch_size, watch_vec_size], dtype='int64') + search_vec = pfl_mpc.data(name='search_vec', shape=[batch_size, search_vec_size], dtype='int64') + other_feat = pfl_mpc.data(name='other_feat', shape=[batch_size, other_feat_size], dtype='int64') + label = pfl_mpc.data(name='label', shape=[batch_size, 3952], dtype='int64') + + inputs = [watch_vec] + [search_vec] + [other_feat] + [label] + + return inputs + + def fc(self, tag, data, out_dim, active='relu'): + init_stddev = 1.0 + scales = 1.0 / np.sqrt(data.shape[2]) + + mpc_one = 65536 / 3 + rng = np.random.RandomState(23355) + param_shape = (1, data.shape[2], out_dim) # 256, 2304 + + if tag == 'l4': + param_value_float = rng.normal(loc=0.0, scale=init_stddev * scales, size=param_shape) + param_value_float_expand = np.concatenate((param_value_float, param_value_float), axis=0) + param_value = (param_value_float_expand * mpc_one).astype('int64') + initializer_l4 = pfl_mpc.initializer.NumpyArrayInitializer(param_value) + p_attr = fluid.param_attr.ParamAttr(name='%s_weight' % tag, + initializer=initializer_l4) + active = None + else: + """ + param_init=pfl_mpc.initializer.XavierInitializer(seed=23355) + p_attr = fluid.param_attr.ParamAttr(initializer=param_init) + """ + """ + 
param_init=pfl_mpc.initializer.XavierInitializer(uniform=False, seed=23355) + p_attr = fluid.param_attr.ParamAttr(initializer=param_init) + """ + fan_in = param_shape[1] + fan_out = param_shape[2] + scale = math.sqrt(6.0 / (fan_in + fan_out)) + param_value_float = rng.normal(-1.0 * scale, 1.0 * scale, size=param_shape) + param_value_float_expand = np.concatenate((param_value_float, param_value_float), axis=0) + param_value = (param_value_float_expand * mpc_one * scale).astype('int64') + initializer_l4 = pfl_mpc.initializer.NumpyArrayInitializer(param_value) + p_attr = fluid.param_attr.ParamAttr(name='%s_weight' % tag, + initializer=initializer_l4) + + + b_attr = fluid.ParamAttr(name='%s_bias' % tag, + initializer=fluid.initializer.ConstantInitializer(int(0.1 * mpc_one))) + + out = pfl_mpc.layers.fc(input=data, + size=out_dim, + act=active, + param_attr=p_attr, + bias_attr=b_attr, + name=tag) + return out + + def net(self, inputs, output_size, layers=[32, 32, 32]): + concat_feats = fluid.layers.concat(input=inputs[:-1], axis=-1) + + l1 = self.fc('l1', concat_feats, layers[0], 'relu') + l2 = self.fc('l2', l1, layers[1], 'relu') + l3 = self.fc('l3', l2, layers[2], 'relu') + l4 = self.fc('l4', l3, output_size, None) + cost, softmax = pfl_mpc.layers.softmax_with_cross_entropy(logits=l4, + label=inputs[-1], + soft_label=True, + use_relu=True, + use_long_div=False, + return_softmax=True) + avg_cost = pfl_mpc.layers.mean(cost) + return avg_cost, l3 diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/process_data.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/process_data.py new file mode 100644 index 0000000..9082d25 --- /dev/null +++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/process_data.py @@ -0,0 +1,239 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Prepare movielens dataset for YoutubeDNN. 
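+Features and labels are secret-shared with the ABY3 protocol and written as three
+cypher text files (suffixes .part0, .part1 and .part2), one for each computation party.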
+""" +import numpy as np +import paddle +import os +import time +import six +import pandas as pd +import logging +from paddle_fl.mpc.data_utils import aby3 +import args +import get_topk + + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger('fluid') +logger.setLevel(logging.INFO) + + +args = args.parse_args() + +watch_vec_size = args.watch_vec_size +search_vec_size = args.search_vec_size +other_feat_size = args.other_feat_size +dataset_size = args.dataset_size + +batch_size = args.batch_size +sample_size = args.batch_num +output_size = args.output_size # max movie id + + +def prepare_movielens_data(sample_size, batch_size, watch_vec_size, search_vec_size, + other_feat_size, dataset_size, label_actual_filepath): + """ + prepare movielens data + """ + watch_vecs = [] + search_vecs = [] + other_feats = [] + labels = [] + + # prepare movielens data + movie_info = paddle.dataset.movielens.movie_info() + user_info = paddle.dataset.movielens.user_info() + + max_user_id = paddle.dataset.movielens.max_user_id() + user_watch = np.zeros((max_user_id, watch_vec_size)) + user_search = np.zeros((max_user_id, search_vec_size)) + user_feat = np.zeros((max_user_id, other_feat_size)) + user_labels = np.zeros((max_user_id, 1)) + + MOVIE_EMBED_TAB_HEIGHT = paddle.dataset.movielens.max_movie_id() + MOVIE_EMBED_TAB_WIDTH = watch_vec_size + + JOB_EMBED_TAB_HEIGHT = paddle.dataset.movielens.max_job_id() + 1 + JOB_EMBED_TAB_WIDTH = paddle.dataset.movielens.max_job_id() + 1 + + AGE_EMBED_TAB_HEIGHT = len(paddle.dataset.movielens.age_table) + AGE_EMBED_TAB_WIDTH = len(paddle.dataset.movielens.age_table) + + GENDER_EMBED_TAB_HEIGHT = 2 + GENDER_EMBED_TAB_WIDTH = 4 + + np.random.seed(1) + + MOVIE_EMBED_TAB = np.zeros((MOVIE_EMBED_TAB_HEIGHT, MOVIE_EMBED_TAB_WIDTH)) + AGE_EMBED_TAB = np.zeros((AGE_EMBED_TAB_HEIGHT, AGE_EMBED_TAB_WIDTH)) + GENDER_EMBED_TAB = np.zeros((GENDER_EMBED_TAB_HEIGHT, GENDER_EMBED_TAB_WIDTH)) + JOB_EMBED_TAB = np.zeros((JOB_EMBED_TAB_HEIGHT, JOB_EMBED_TAB_WIDTH)) + + + for i in range(MOVIE_EMBED_TAB_HEIGHT): + MOVIE_EMBED_TAB[i][hash(i) % MOVIE_EMBED_TAB_WIDTH] = 1 + MOVIE_EMBED_TAB[i][hash(hash(i)) % MOVIE_EMBED_TAB_WIDTH] = 1 + + for i in range(AGE_EMBED_TAB_HEIGHT): + AGE_EMBED_TAB[i][i] = 1 + + for i in range(GENDER_EMBED_TAB_HEIGHT): + GENDER_EMBED_TAB[i][i] = 1 + + for i in range(JOB_EMBED_TAB_HEIGHT): + JOB_EMBED_TAB[i][i] = 1 + + train_set_creator = paddle.dataset.movielens.train() + + pre_uid = 0 + movie_count = 0 + user_watched_movies = [[] for i in range(dataset_size)] + for instance in train_set_creator(): + uid = int(instance[0]) - 1 + gender_id = int(instance[1]) + age_id = int(instance[2]) + job_id = int(instance[3]) + mov_id = int(instance[4]) - 1 + user_watched_movies[uid].append(mov_id) + user_watch[uid, :] += MOVIE_EMBED_TAB[mov_id, :] + user_labels[uid, :] = mov_id + + user_feat[uid, :] = np.concatenate((JOB_EMBED_TAB[job_id, :], + GENDER_EMBED_TAB[gender_id, :], + AGE_EMBED_TAB[age_id, :])) + + if uid == pre_uid: + movie_count += 1 + else: + user_watch[pre_uid, :] = user_watch[pre_uid, :] / movie_count + movie_count = 1 + pre_uid = uid + user_watch[pre_uid, :] = user_watch[pre_uid, :] / movie_count + + user_search = user_watch + + if (os.path.exists(label_actual_filepath)): + os.system('rm -rf' + label_actual_filepath) + user_watched_movies_vec = pd.DataFrame(user_watched_movies) + user_watched_movies_vec.to_csv(label_actual_filepath, mode='a', index=False, header=0) + + return user_watch, user_search, user_feat, user_labels 
+ + +def gen_cypher_sample(mpc_data_dir, sample_size, batch_size, output_size): + """ + prepare movielens data and encrypt + """ + logger.info('Prepare data...') + if not os.path.exists(mpc_data_dir): + os.makedirs(mpc_data_dir) + else: + os.system('rm -rf ' + mpc_data_dir + '*') + label_actual_filepath = mpc_data_dir + 'label_actual' + user_watch, user_search, user_feat, user_labels = prepare_movielens_data( + sample_size, batch_size, watch_vec_size, search_vec_size, other_feat_size, dataset_size, label_actual_filepath) + + #watch_vecs = [] + #search_vecs = [] + #other_feat_vecs = [] + #label_vecss = [] + + for i in range(sample_size): + watch_vec = user_watch[i * batch_size : (i + 1) * batch_size, :] + search_vec = user_search[i * batch_size : (i + 1) * batch_size, :] + other_feat_vec = user_feat[i * batch_size : (i + 1) * batch_size, :] + save_cypher(cypher_file=mpc_data_dir + 'watch_vec', vec=watch_vec) + save_cypher(cypher_file=mpc_data_dir + 'search_vec', vec=search_vec) + save_cypher(cypher_file=mpc_data_dir + 'other_feat', vec=other_feat_vec) + #watch_vecs.append(watch_vec) + #search_vecs.append(search_vec) + #other_feat_vecs.append(other_feat_vec) + label = np.zeros((batch_size, output_size)) + for j in range(batch_size): + label[j, int(user_labels[j][0])] = 1 + save_cypher(cypher_file=mpc_data_dir + 'label', vec=label) + #label_vecs.append(label) + #return [watch_vecs, search_vecs, other_feat_vecs, label_vecs] + + +def save_cypher(cypher_file, vec): + """ + save cypertext to file + """ + shares = aby3.make_shares(vec) + exts = ['.part0', '.part1', '.part2'] + with open(cypher_file + exts[0], 'ab') as file0, \ + open(cypher_file + exts[1], 'ab') as file1, \ + open(cypher_file + exts[2], 'ab') as file2: + files = [file0, file1, file2] + for idx in six.moves.range(0, 3): # 3 parts + share = aby3.get_aby3_shares(shares, idx) + files[idx].write(share.tostring()) + + +def load_decrypt_data(filepath, shape): + """ + load the encrypted data and reconstruct + """ + part_readers = [] + for id in six.moves.range(3): + part_readers.append(aby3.load_aby3_shares(filepath, id=id, shape=shape)) + aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], part_readers[2]) + + for instance in aby3_share_reader(): + p = aby3.reconstruct(np.array(instance)) + print(p) + + +def decrypt_data_to_file(cypher_filepath, plaintext_filepath, shape): + """ + Load the encrypted data and reconstruct. 
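+    Reconstructed rows are appended to plaintext_filepath in CSV format.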
+ + """ + part_readers = [] + for id in six.moves.range(3): + part_readers.append( + aby3.load_aby3_shares( + cypher_filepath, id=id, shape=shape)) + aby3_share_reader = paddle.reader.compose(part_readers[0], part_readers[1], + part_readers[2]) + + for instance in aby3_share_reader(): + p = aby3.reconstruct(np.array(instance)) + tmp = pd.DataFrame(p) + tmp.to_csv(plaintext_filepath, mode='a', index=False, header=0) + + +def evaluate_hit_ratio(file1, file2): + count = 0 + same_count = 0 + f1 = open(file1, 'r') + f2 = open(file2, 'r') + while 1: + line1 = f1.readline() + line2 = f2.readline() + if (not line1) or (not line2): + break + count += 1 + set1 = set([int(float(x if x != '' and x != '\n' else 10000)) for x in line1.split(',')]) + set2 = set([int(x) for x in line2.split(',')]) + if len(set1.intersection(set2)) != 0: + same_count += 1 + logger.info(float(same_count)/count) + + +if __name__ == '__main__': + gen_cypher_sample('./mpc_data/', sample_size, batch_size, output_size) diff --git a/python/paddle_fl/mpc/examples/youtubednn_with_movielens/train_youtubednn.py b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/train_youtubednn.py new file mode 100644 index 0000000..45aa628 --- /dev/null +++ b/python/paddle_fl/mpc/examples/youtubednn_with_movielens/train_youtubednn.py @@ -0,0 +1,212 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +MPC YoutubeDNN Demo +""" + +import numpy as np +import pandas as pd +import os +import random +import time +import logging +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle_fl.mpc as pfl_mpc +from paddle_fl.mpc.data_utils import aby3 + +import args +import mpc_network + + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger('fluid') +logger.setLevel(logging.INFO) + + +def read_share(file, shape): + """ + prepare share reader + """ + ext = '.part{}'.format(args.role) + shape = (2, ) + shape + share_size = np.prod(shape) * 8 # size of int64 in bytes + def reader(): + with open(file + ext, 'rb') as part_file: + share = part_file.read(share_size) + while share: + yield np.frombuffer(share, dtype=np.int64).reshape(shape) + share = part_file.read(share_size) + return reader + + +def train(args): + """ + train + """ + # ******************** + # prepare network + pfl_mpc.init('aby3', int(args.role), 'localhost', args.server, int(args.port)) + youtube_model = mpc_network.YoutubeDNN() + inputs = youtube_model.input_data(args.batch_size, + args.watch_vec_size, + args.search_vec_size, + args.other_feat_size) + loss, l3 = youtube_model.net(inputs, args.output_size, layers=[128, 64, 32]) + + #boundaries = [200, 500, 800, 1000] + #values = [0.05, 0.02, 0.01, 0.005, 0.001] + #lr = fluid.layers.piecewise_decay(boundaries, values) + lr = args.base_lr + sgd = pfl_mpc.optimizer.SGD(learning_rate=lr) + sgd.minimize(loss) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + # ******************** + # prepare data + logger.info('Prepare data...') + mpc_data_dir = args.mpc_data_dir + if not os.path.exists(mpc_data_dir): + raise ValueError('mpc_data_dir is not found. 
Please prepare encrypted data.') + + video_vec_filepath = mpc_data_dir + 'video_vec' + video_vec_part_filepath = video_vec_filepath + '.part{}'.format(args.role) + user_vec_filepath = mpc_data_dir + 'user_vec.csv' + user_vec_part_filepath = user_vec_filepath + '.part{}'.format(args.role) + + watch_vecs = [] + search_vecs = [] + other_feats = [] + labels = [] + + watch_vec_reader = read_share(file=mpc_data_dir + 'watch_vec', shape=(args.batch_size, args.watch_vec_size)) + for vec in watch_vec_reader(): + watch_vecs.append(vec) + + search_vec_reader = read_share(file=mpc_data_dir + 'search_vec', shape=(args.batch_size, args.search_vec_size)) + for vec in search_vec_reader(): + search_vecs.append(vec) + + other_feat_reader = read_share(file=mpc_data_dir + 'other_feat', shape=(args.batch_size, args.other_feat_size)) + for vec in other_feat_reader(): + other_feats.append(vec) + + label_reader = read_share(file=mpc_data_dir + 'label', shape=(args.batch_size, args.output_size)) + for vec in label_reader(): + labels.append(vec) + + # ******************** + # train + logger.info('Start training...') + begin = time.time() + for epoch in range(args.epochs): + for i in range(args.batch_num): + loss_data = exe.run(fluid.default_main_program(), + feed={'watch_vec': watch_vecs[i], + 'search_vec': search_vecs[i], + 'other_feat': other_feats[i], + 'label': np.array(labels[i]) + }, + return_numpy=True, + fetch_list=[loss.name]) + + if i % 100 == 0: + end = time.time() + logger.info('Paddle training of epoch_id: {}, batch_id: {}, batch_time: {:.5f}s' + .format(epoch, i, end-begin)) + # save model + logger.info('save mpc model...') + cur_model_dir = os.path.join(args.model_dir, 'mpc_model', 'epoch_' + str(epoch + 1), + 'checkpoint', 'party_{}'.format(args.role)) + feed_var_names = ['watch_vec', 'search_vec', 'other_feat'] + fetch_vars = [l3] + fluid.io.save_inference_model(cur_model_dir, feed_var_names, fetch_vars, exe) + + # save all video vector + video_array = np.array(fluid.global_scope().find_var('l4_weight').get_tensor()) + if os.path.exists(video_vec_part_filepath): + os.system('rm -rf ' + video_vec_part_filepath) + with open(video_vec_part_filepath, 'w') as f: + f.write(np.array(video_array).tostring()) + + end = time.time() + logger.info('MPC training of epoch_num: {}, batch_time: {:.5f}s' + .format(args.epochs, end-begin)) + +def infer(args): + """ + infer + """ + logger.info('Start inferring...') + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + cur_model_path = os.path.join(args.model_dir, 'mpc_model', 'epoch_' + str(args.test_epoch), + 'checkpoint', 'party_{}'.format(args.role)) + + with fluid.scope_guard(fluid.Scope()): + pfl_mpc.init('aby3', args.role, 'localhost', args.server, args.port) + infer_program, feed_target_names, fetch_vars = aby3.load_mpc_model(exe=exe, + mpc_model_dir=cur_model_path, + mpc_model_filename='__model__', + inference=True) + mpc_data_dir = args.mpc_data_dir + user_vec_filepath = mpc_data_dir + 'user_vec' + user_vec_part_filepath = user_vec_filepath + '.part{}'.format(args.role) + + sample_batch = args.batch_size + watch_vecs = [] + search_vecs = [] + other_feats = [] + + watch_vec_reader = read_share(file=mpc_data_dir + 'watch_vec', shape=(sample_batch, args.watch_vec_size)) + for vec in watch_vec_reader(): + watch_vecs.append(vec) + search_vec_reader = read_share(file=mpc_data_dir + 'search_vec', shape=(sample_batch, args.search_vec_size)) + for vec in search_vec_reader(): + search_vecs.append(vec) + other_feat_reader = 
read_share(file=mpc_data_dir + 'other_feat', shape=(sample_batch, args.other_feat_size)) + for vec in other_feat_reader(): + other_feats.append(vec) + + if os.path.exists(user_vec_part_filepath): + os.system('rm -rf ' + user_vec_part_filepath) + + for i in range(args.batch_num): + l3 = exe.run(infer_program, + feed={ + 'watch_vec': watch_vecs[i], + 'search_vec': search_vecs[i], + 'other_feat': other_feats[i], + }, + return_numpy=True, + fetch_list=fetch_vars) + + with open(user_vec_part_filepath, 'a+') as f: + f.write(np.array(l3[0]).tostring()) + + +if __name__ == '__main__': + args = args.parse_args() + logger.info( + 'use_gpu: {}, batch_size: {}, epochs: {}, watch_vec_size: {}, search_vec_size: {}, ' + + 'other_feat_size: {}, output_size: {}, model_dir: {}, test_epoch: {}, base_lr: {}'.format( + args.use_gpu, args.batch_size, args.epochs, args.watch_vec_size, args.search_vec_size, args.other_feat_size, + args.output_size, args.model_dir, args.test_epoch, args.base_lr)) + + train(args) + infer(args) -- GitLab