diff --git a/python/paddle_fl/mobile/Makefile b/python/paddle_fl/mobile/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..314cbe4978368392ee2d0c392f4087ddfdfa1083
--- /dev/null
+++ b/python/paddle_fl/mobile/Makefile
@@ -0,0 +1,2 @@
+clean:
+	rm -rf *~ *pyc */*~ */*.pyc
diff --git a/python/paddle_fl/mobile/README.md b/python/paddle_fl/mobile/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc14e3a9c2133f87107e6e8cc2bc9e19633b7e0f
--- /dev/null
+++ b/python/paddle_fl/mobile/README.md
@@ -0,0 +1,109 @@
+
+## Federated Learning Simulator (fl-mobile simulator)
+
+FL-mobile is an integrated framework for researching, training and deploying federated learning algorithms for mobile devices. The simulator is one component of FL-mobile.
+
+The simulator is designed to mimic the real online scenario in which many mobile devices train a model cooperatively. The idea is to simulate a number of on-device clients on servers, so that algorithms can be validated quickly. Its advantages are:
+
+- Support for both single-machine and distributed training
+- Support for training on common open-source datasets
+- Support for private and shared parameters in a model; private parameters do not take part in global updates
+
+## Prerequisites
+
+- Install mpirun
+- Install gRPC for Python
+  ```shell
+  pip install grpcio==1.28.1
+  ```
+- Install Paddle
+
+## Quick Start
+
+Taking the [reddit data](https://github.com/TalwalkarLab/leaf/tree/master/data/reddit) from the Leaf dataset as an example, we model it with an LSTM and walk through a single-machine training run in the simulator. This example covers the basic usage of the simulator.
+
+### Prepare the data
+
+```
+wget https://paddle-serving.bj.bcebos.com/temporary_files_for_docker/reddit_subsampled.zip --no-check-certificate
+unzip reddit_subsampled.zip
+```
+The simulator assumes that user data is organized by day, so we rearrange the downloaded data as follows:
+
+```
+tree lm_data
+lm_data
+|-- 20200101
+|   `-- train_data.json
+|-- 20200102
+|   `-- test_data.json
+`-- vocab.json
+```
+As you can see, the training data is treated as the data of 20200101 and the test data as the data of 20200102.
+
+### Generate the server code
+
+```
+cd protos
+python run_codegen.py
+cd ..
+```
+
+### Start training
+
+```shell
+export PYTHONPATH=$PWD:$PYTHONPATH
+mpirun -np 2 python application.py lm_data
+```
+
+### Training results
+
+```shell
+framework.py : INFO infer results: 0.085723
+```
+
+That is, the Top-1 accuracy on the test set is 8.6%.
+
+## Adding your own dataset and trainer
+
+To train your own federated model, you need to do four things:
+
+1. Create a reader; see `reader/leaf_reddit_reader.py` for reference
+2. Create a trainer; see `trainer/language_model_trainer.py` for reference
+3. Create a model, i.e. the network; see `model/language_model.py` for reference
+4. Create an application; see `application.py` for reference
+
+## Simulator overview
+
+The framework consists mainly of a scheduler and simulators. The scheduler orchestrates data and global parameters, while the simulators do the actual training and update the private parameters.
+
+- scheduler
+In one training run there is only one global scheduler, and every machine runs a scheduler client that exchanges parameters and data with the global scheduler.
+
+- simulator
+Every machine runs one Simulator, and each Simulator owns several shards, which are used for parallel training on the local machine.
+
+As a distributed framework, the FL-mobile simulator consists of four stages: model initialization, model distribution, model training and model update. Let us walk through an actual training run to see what each stage does:
+
+- Step 1: Model initialization
+
+  1. Global parameter initialization: the simulator with index 0 initializes the model and then passes the parameters to the scheduler through the UpdateGlobalParams() interface;
+
+  2. Personalized parameter initialization
+
+- Step 2: Model distribution
+
+  1. Global parameter distribution: before training, every simulator fetches the global parameters from the SchedulerServer, i.e. through scheduler_client.get_global_params();
+
+  2. Personalized parameter distribution: before training, each local trainer fetches its personalized parameters from the data server through the get_param_by_uid interface;
+
+- Step 3: Model training
+
+  Model training is the core of the whole pipeline. Multiple trainers train in parallel, without synchronizing parameters during training; the number of training steps of each trainer is decided by the amount of data of that user.
+
+- Step 4: Model update
+
+  1. Global parameter update: after all trainers finish training, one synchronization is performed and the parameter deltas are aggregated with the FedAvg algorithm; the deltas are then uploaded to the scheduler, the new global parameters are pulled, and the flow goes back to Step 2;
+
+  2. Personalized parameter update: this one is simpler; each trainer calls set_param_by_uid to update its own personalized parameters;
diff --git a/python/paddle_fl/mobile/__init__.py b/python/paddle_fl/mobile/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54
--- /dev/null
+++ b/python/paddle_fl/mobile/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle_fl/mobile/application.py b/python/paddle_fl/mobile/application.py
new file mode 100644
index 0000000000000000000000000000000000000000..905c2e7375482777b1a18c4fe861eab2aa1fa7c9
--- /dev/null
+++ b/python/paddle_fl/mobile/application.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+For large-scale simulation of mobile devices, we need to consider:
+- model structure and input features
+- trainer, i.e. how to design the training strategy for each simulation
+- user sampler, i.e. which users participate in training
+- optimizer, i.e. how to update the global weights
+
+Currently, we couple trainer and model together for simplicity
+'''
+from utils import FLSimRoleMaker
+from framework import SimulationFramework
+from trainer import LanguageModelTrainer
+from optimizer import FedAvgOptimizer
+from sampler import UniformSampler, Test1percentSampler
+from datetime import date, timedelta
+import sys
+import time
+
+role_maker = FLSimRoleMaker()
+role_maker.init_env(local_shard_num=30)
+simulator = SimulationFramework(role_maker)
+
+language_model_trainer = LanguageModelTrainer()
+
+language_model_trainer.set_trainer_configs({
+    "epoch": 3,
+    "max_steps_in_epoch": -1,
+    "lr": 0.1,
+    "batch_size": 5,
+})
+
+sampler = UniformSampler()
+sampler.set_sample_num(30)
+sampler.set_min_ins_num(1)
+test_sampler = Test1percentSampler()
+fed_avg_optimizer = FedAvgOptimizer(learning_rate=2.0)
+
+simulator.set_trainer(language_model_trainer)
+simulator.set_sampler(sampler)
+simulator.set_test_sampler(test_sampler)
+simulator.set_fl_optimizer(fed_avg_optimizer)
+
+if simulator.is_scheduler():
+    simulator.run_scheduler_service()
+elif simulator.is_simulator():
+    base_path = sys.argv[1]
+    dates = []
+    start_date = date(2020, 1, 1)
+    end_date = date(2020, 1, 2)
+    delta = timedelta(days=1)
+    while start_date <= end_date:
+        dates.append(start_date.strftime("%Y%m%d"))
+        start_date += delta
+
+    print("dates: {}".format(dates))
+
+    time.sleep(10)
+    simulator.run_simulation(
+        base_path, dates, sim_num_everyday=100, do_test=True, test_skip_day=1)
diff --git a/python/paddle_fl/mobile/clients/__init__.py b/python/paddle_fl/mobile/clients/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f558d0d0b945caa03275cf6b914a893089d745ca
--- /dev/null
+++ b/python/paddle_fl/mobile/clients/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data_client_impl import DataClient +from .scheduler_client_impl import SchedulerClient diff --git a/python/paddle_fl/mobile/clients/data_client_impl.py b/python/paddle_fl/mobile/clients/data_client_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..098c7565bc524c94fd0b19ff73daa32871a6a23a --- /dev/null +++ b/python/paddle_fl/mobile/clients/data_client_impl.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function +import grpc +import servers.data_server_pb2 as data_server_pb2 +import servers.data_server_pb2_grpc as data_server_pb2_grpc +from concurrent import futures +from multiprocessing import Process +from utils.hdfs_utils import HDFSClient, multi_download +import time +import sys +import os +import xxhash +import numpy as np +from utils.logger import logging + + +class DataClient(object): + def __init__(self): + self.stub_list = [] + self.load_data_into_patch = None + + def uid_shard(self, uid): + try: + uid_hash = xxhash.xxh32(str(uid), seed=101).intdigest() + except: + return -1 + shard_idx = uid_hash % len(self.stub_list) + return shard_idx + + # should set all params to numpy array with shape and dtype + # buggy here + def set_param_by_uid(self, uid, param_dict): + shard_idx = self.uid_shard(uid) + if shard_idx == -1: + return -1 + user_param = data_server_pb2.UserParams() + user_param.uid = uid + for key in param_dict: + param = data_server_pb2.Param() + param.name = key + np_var = param_dict[param.name] + param.shape.extend(np_var.shape) + param.weight.extend(np_var.ravel()) + user_param.user_params.extend([param]) + + call_future = self.stub_list[shard_idx].UpdateUserParams.future( + user_param) + err_code = call_future.result().err_code + return err_code + + def get_param_by_uid(self, uid): + shard_idx = self.uid_shard(uid) + if shard_idx == -1: + return -1 + data = data_server_pb2.Data() + data.uid = uid + call_future = self.stub_list[shard_idx].GetUserParams.future(data) + user_params = call_future.result() + param_dict = {} + for param in user_params.user_params: + param_dict[param.name] = np.array( + list(param.weight), dtype=np.float32) + param_dict[param.name].shape = list(param.shape) + return param_dict + + def clear_user_data(self, date): + def clear(): + for stub in self.stub_list: + data = data_server_pb2.Data() + data.date = date + call_future = stub.ClearUserData.future(data) + res = call_future.result() + + p = Process(target=clear, args=()) + p.start() + p.join() + + def get_data_by_uid(self, uid, date): + shard_idx = self.uid_shard(uid) + if shard_idx == -1: + return -1 + data = data_server_pb2.Data() + data.uid = uid + data.date = date + call_future = self.stub_list[shard_idx].GetUserData.future(data) + user_data_list = [] + for item in call_future.result().line_str: + user_data_list.append(item) + return user_data_list + + def set_data_server_endpoints(self, endpoints): + self.stub_list = [] + for ep in endpoints: + options = [('grpc.max_message_length', 
1024 * 1024 * 1024),
+                       ('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
+            channel = grpc.insecure_channel(ep, options=options)
+            stub = data_server_pb2_grpc.DataServerStub(channel)
+            self.stub_list.append(stub)
+
+    def global_shuffle_by_patch(self, data_patch, date, concurrency):
+        # integer division so range() gets an int on Python 3 as well
+        shuffle_time = len(data_patch) // concurrency + 1
+        for i in range(shuffle_time):
+            if i * concurrency >= len(data_patch):
+                break
+            pros = []
+            end = min((i + 1) * concurrency, len(data_patch))
+            patch_list = data_patch[i * concurrency:end]
+            width = len(patch_list)
+            for j in range(width):
+                p = Process(
+                    target=self.send_one_patch, args=(patch_list[j], date))
+                pros.append(p)
+            for p in pros:
+                p.start()
+            for p in pros:
+                p.join()
+            logging.info("shuffle round {} done.".format(i))
+
+    def send_one_patch(self, patch, date):
+        for line in patch:
+            group = line.strip().split("\t")
+            if len(group) != 3:
+                continue
+            data = data_server_pb2.Data()
+            data.uid = group[0]
+            data.date = date
+            data.line = line.strip()
+            stub_idx = self.uid_shard(data.uid)
+            if stub_idx == -1:
+                logging.info("send_one_patch continue for uid: %s" % data.uid)
+                continue
+            call_future = self.stub_list[stub_idx].SendData.future(data)
+            u_num = call_future.result()
+
+    def global_shuffle_by_file(self, filelist, concurrency):
+        pass
+
+    def set_load_data_into_patch_func(self, func):
+        self.load_data_into_patch = func
+
+    def get_local_files(self,
+                        base_path,
+                        date,
+                        node_idx,
+                        node_num,
+                        hdfs_configs=None):
+        full_path = "{}/{}".format(base_path, date)
+        if os.path.exists(full_path):
+            file_list = os.listdir(full_path)
+            local_files = ["{}/{}".format(full_path, x) for x in file_list]
+        elif hdfs_configs is not None:
+            local_files = self.download_from_hdfs(hdfs_configs, base_path,
+                                                  date, node_idx, node_num)
+        else:
+            local_files = []
+        return local_files
+
+    def download_from_hdfs(self, hdfs_configs, base_path, date, node_idx,
+                           node_num):
+        # return local filelist
+        hdfs_client = HDFSClient("$HADOOP_HOME", hdfs_configs)
+        multi_download(
+            hdfs_client,
+            "{}/{}".format(base_path, date),
+            date,
+            node_idx,
+            node_num,
+            multi_processes=30)
+        filelist = os.listdir(date)
+        files = ["{}/{}".format(date, fn) for fn in filelist]
+        return files
+
+
+def test_global_shuffle():
+    data_client = DataClient()
+    server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
+    data_client.set_data_server_endpoints(server_endpoints)
+    date = "0330"
+    file_name = ["data_with_uid/part-01991"]
+    with open(file_name[0]) as fin:
+        for line in fin:
+            group = line.strip().split("\t")
+            uid = group[0]
+            user_data_dict = data_client.get_data_by_uid(uid, date)
+
+
+def test_set_param():
+    data_client = DataClient()
+    server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
+    data_client.set_data_server_endpoints(server_endpoints)
+    uid = ["1001", "10001", "100001", "101"]
+    # params must be numpy arrays so that shape and dtype can be serialized
+    param_dict = {
+        "w0": np.array([1.0, 1.1, 1.2, 1.3], dtype=np.float32),
+        "b0": np.array([1.1, 1.2, 1.3, 1.5], dtype=np.float32)
+    }
+    for cur_i in uid:
+        data_client.set_param_by_uid(cur_i, param_dict)
+
+
+def test_get_param():
+    data_client = DataClient()
+    server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
+    data_client.set_data_server_endpoints(server_endpoints)
+    uid = ["1001", "10001", "100001", "101"]
+    for cur_i in uid:
+        param_dict = data_client.get_param_by_uid(cur_i)
+        print(param_dict)
+
+
+if __name__ == "__main__":
+    #load_data_global_shuffle()
+    #test_global_shuffle()
+    test_set_param()
+    test_get_param()
diff --git a/python/paddle_fl/mobile/clients/scheduler_client_impl.py
b/python/paddle_fl/mobile/clients/scheduler_client_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..71898d0ad43e8e3f17eea954ef67ff6e6f8ffbae --- /dev/null +++ b/python/paddle_fl/mobile/clients/scheduler_client_impl.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function +import grpc +import servers.scheduler_server_pb2 as scheduler_server_pb2 +import servers.scheduler_server_pb2_grpc as scheduler_server_pb2_grpc +import servers.data_server_pb2_grpc as data_server_pb2_grpc +import servers.data_server_pb2 as data_server_pb2 +import numpy as np +from concurrent import futures +from multiprocessing import Process + +import time +import sys +import os + + +class SchedulerClient(object): + def __init__(self): + self.stub = None + self.stub_list = [] + self.global_param_info = {} + + def update_user_inst_num(self, date, user_info_dict): + user_info = scheduler_server_pb2.UserInstInfo() + for key in user_info_dict: + single_user_info = scheduler_server_pb2.UserInstNum() + single_user_info.uid = key + single_user_info.inst_num = user_info_dict[key] + user_info.inst_nums.extend([single_user_info]) + user_info.shard_num = len(self.stub_list) + user_info.date = date + call_future = self.stub.UpdateUserInstNum.future(user_info) + res = call_future.result() + return res.err_code + + def set_scheduler_server_endpoints(self, endpoints): + options = [('grpc.max_message_length', 1024 * 1024 * 1024), + ('grpc.max_receive_message_length', 1024 * 1024 * 1024)] + channel = grpc.insecure_channel(endpoints[0], options=options) + self.stub = scheduler_server_pb2_grpc.SchedulerServerStub(channel) + + def set_data_server_endpoints(self, endpoints): + self.stub_list = [] + for ep in endpoints: + options = [('grpc.max_message_length', 1024 * 1024 * 1024), + ('grpc.max_receive_message_length', 1024 * 1024 * 1024)] + channel = grpc.insecure_channel(ep, options=options) + stub = data_server_pb2_grpc.DataServerStub(channel) + self.stub_list.append(stub) + + def uniform_sample_user_list(self, date, node_id, sample_num, shard_num, + node_num, min_ins_num): + user_info_dict = {} + req = scheduler_server_pb2.Request() + req.node_idx = node_id + req.sample_num = sample_num + req.shard_num = shard_num + req.node_num = node_num + req.date = date + req.min_ins_num = min_ins_num + call_future = self.stub.SampleUsersToTrain.future(req) + user_info = call_future.result() + for user in user_info.inst_nums: + user_info_dict[user.uid] = user.inst_num + return user_info_dict + + def hash_sample_user_list(self, date, node_id, sample_num, shard_num, + node_num): + user_info_dict = {} + req = scheduler_server_pb2.Request() + req.node_idx = node_id + req.sample_num = sample_num + req.shard_num = shard_num + req.node_num = node_num + req.date = date + call_future = self.stub.SampleUsersWithHash.future(req) + user_info = call_future.result() + for user in user_info.inst_nums: + user_info_dict[user.uid] = user.inst_num + return user_info_dict + + def sample_test_user_list(self, date, node_id, shard_num, node_num): + user_info_dict = {} + req = scheduler_server_pb2.Request() + req.node_idx = node_id + req.shard_num = shard_num + req.node_num = node_num + req.date = date + call_future = self.stub.SampleUsersToTest.future(req) + user_info = call_future.result() + for user in user_info.inst_nums: + user_info_dict[user.uid] = user.inst_num + return user_info_dict + + def fixed_sample_user_list(self, date, node_id, sample_num, shard_num, + node_num): + user_info_dict = {} + req = 
scheduler_server_pb2.Request()
+        req.node_idx = node_id
+        req.sample_num = sample_num
+        req.shard_num = shard_num
+        req.node_num = node_num
+        req.date = date
+        call_future = self.stub.FixedUsersToTrain.future(req)
+        user_info = call_future.result()
+        for user in user_info.inst_nums:
+            user_info_dict[user.uid] = user.inst_num
+        return user_info_dict
+
+    def get_global_params(self):
+        req = scheduler_server_pb2.Request()
+        req.node_idx = 0
+        req.sample_num = 0
+        req.shard_num = 0
+        req.node_num = 0
+        call_future = self.stub.GetGlobalParams.future(req)
+        global_p = call_future.result()
+        result_dict = {}
+        for param in global_p.global_params:
+            result_dict[param.name] = np.array(
+                list(param.weight), dtype=np.float32)
+            result_dict[param.name].shape = list(param.shape)
+        return result_dict
+
+    def update_global_params(self, global_params):
+        global_p = scheduler_server_pb2.GlobalParams()
+        for key in global_params:
+            param = scheduler_server_pb2.Param()
+            param.name = key
+            var, shape = global_params[key], global_params[key].shape
+            self.global_param_info[key] = shape
+            param.weight.extend(var.ravel())
+            param.shape.extend(shape)
+            global_p.global_params.extend([param])
+        call_future = self.stub.UpdateGlobalParams.future(global_p)
+        res = call_future.result()
+        return res.err_code
+
+    def fedavg_update(self, global_param_delta_dict):
+        global_p = scheduler_server_pb2.GlobalParams()
+        for key in global_param_delta_dict:
+            param = scheduler_server_pb2.Param()
+            param.name = key
+            parameter_delta, shape = global_param_delta_dict[
+                param.name], global_param_delta_dict[param.name].shape
+            param.weight.extend(parameter_delta.ravel())
+            param.shape.extend(shape)
+            global_p.global_params.extend([param])
+        call_future = self.stub.FedAvgUpdate.future(global_p)
+        res = call_future.result()
+        return res.err_code
+
+    def stop_scheduler_server(self):
+        empty_input = scheduler_server_pb2.SchedulerServerEmptyInput()
+        call_future = self.stub.Exit.future(empty_input)
+        res = call_future.result()
+
+
+def test_uniform_sample_user_list():
+    client = SchedulerClient()
+    client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
+    # register instance counts for 10000 fake users on a dummy date;
+    # update_user_inst_num expects a {uid: inst_num} dict
+    global_user_dict = {}
+    for i in range(10000):
+        global_user_dict[str(i)] = 100
+    client.update_user_inst_num("0330", global_user_dict)
+    user_info = client.uniform_sample_user_list(
+        "0330", 0, sample_num=30, shard_num=1, node_num=1, min_ins_num=1)
+    global_param = {
+        "w0": np.array([1.0, 1.0, 1.0], dtype=np.float32),
+        "w1": np.array([2.0, 2.0, 2.0], dtype=np.float32)
+    }
+    client.update_global_params(global_param)
+    fetched_params = client.get_global_params()
+
+
+def test_get_global_params():
+    client = SchedulerClient()
+    client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
+    # params must be numpy arrays: update_global_params reads .shape and .ravel()
+    global_param = {
+        "w0": np.array([1.0, 1.0, 1.0], dtype=np.float32),
+        "w1": np.array([2.0, 2.0, 2.0], dtype=np.float32)
+    }
+    client.update_global_params(global_param)
+    fetched_params = client.get_global_params()
+    print(fetched_params)
+
+
+def test_update_global_params():
+    client = SchedulerClient()
+    client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
+    global_param = {
+        "w0": np.array([1.0, 1.0, 1.0], dtype=np.float32),
+        "w1": np.array([3.0, 3.0, 3.0], dtype=np.float32)
+    }
+    client.update_global_params(global_param)
+    fetched_params = client.get_global_params()
+    print(fetched_params)
+
+
+if __name__ == "__main__":
+    test_uniform_sample_user_list()
+    test_update_global_params()
+    test_get_global_params()
diff --git a/python/paddle_fl/mobile/framework.py b/python/paddle_fl/mobile/framework.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b56a8baa8a624b3bbff5f98cb45ec9b6bd3eff4
--- /dev/null
+++ b/python/paddle_fl/mobile/framework.py
@@ -0,0 +1,354 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from utils.role_maker import FLSimRoleMaker
+from clients import DataClient
+from clients import SchedulerClient
+from servers import DataServer
+from servers import SchedulerServer
+from multiprocessing import Process, Pool, Manager, Pipe, Lock
+from utils.logger import logging
+import pickle
+import time
+import numpy as np
+import sys
+import os
+
+
+class SimulationFramework(object):
+    def __init__(self, role_maker):
+        self.data_client = None
+        self.scheduler_client = None
+        self.role_maker = role_maker
+        # we suppose currently we train a homogeneous model
+        self.trainer = None
+        # for sampling users
+        self.sampler = None
+        # for updating global weights
+        self.fl_optimizer = None
+        # for sampling users to test
+        self.test_sampler = None
+        self.profile_file = open("profile", "w")
+        self.do_profile = True
+        # for data downloading
+        self.hdfs_configs = None
+
+    def set_hdfs_configs(self, configs):
+        self.hdfs_configs = configs
+
+    def set_trainer(self, trainer):
+        self.trainer = trainer
+
+    def set_sampler(self, sampler):
+        self.sampler = sampler
+
+    def set_test_sampler(self, sampler):
+        self.test_sampler = sampler
+
+    def set_fl_optimizer(self, optimizer):
+        self.fl_optimizer = optimizer
+
+    def is_scheduler(self):
+        return self.role_maker.is_global_scheduler()
+
+    def is_simulator(self):
+        return self.role_maker.is_simulator()
+
+    def run_scheduler_service(self):
+        if self.role_maker.is_global_scheduler():
+            self._run_global_scheduler()
+
+    def _barrier_simulators(self):
+        self.role_maker.barrier_simulator()
+
+    def _start_data_server(self, endpoint):
+        data_server = DataServer()
+        port = endpoint.split(":")[1]
+        data_server.start(endpoint=port)
+
+    def _run_global_scheduler(self):
+        scheduler_server = SchedulerServer()
+        endpoint = self.role_maker.get_global_scheduler_endpoint()
+        port = endpoint.split(":")[1]
+        scheduler_server.start(endpoint=port)
+
+    def _get_data_services(self):
+        data_server_endpoints = \
+            self.role_maker.get_local_data_server_endpoint()
+        data_server_pros = []
+        for i, ep in enumerate(data_server_endpoints):
+            p = Process(target=self._start_data_server, args=(ep, ))
+            data_server_pros.append(p)
+
+        return data_server_pros
+
+    def _profile(self, func, *args, **kwargs):
+        if self.do_profile:
+            start = time.time()
+            res = func(*args, **kwargs)
+            end = time.time()
+            self.profile_file.write("%s\t\t%f s\n" %
+                                    (func.__name__, end - start))
+            return res
+        else:
+            return func(*args, **kwargs)
+
+    def _run_sim(self, date, sim_num_everyday=1):
+        sim_idx = self.role_maker.simulator_idx()
+        sim_num = self.role_maker.simulator_num()
+        sim_all_trainer_run_time = 0
+        sim_read_param_and_optimize = 0
+        for sim in range(sim_num_everyday):
+            logging.info("sim id: %d" % sim)
+            # sampler algorithm
+            user_info_dict = self._profile(
+                self.sampler.sample_user_list, self.scheduler_client, date,
+                sim_idx, len(self.data_client.stub_list), sim_num)
+
+            if self.do_profile:
+                print("sim_idx: ", sim_idx)
+                print("shard num: ", len(self.data_client.stub_list))
+                print("sim_num: ", sim_num)
+                print("user_info_dict: ", user_info_dict)
+
+            global_param_dict = self._profile(
+                self.scheduler_client.get_global_params)
+            processes = []
+            os.system("rm -rf _global_param")
+            os.system("mkdir _global_param")
+            start = time.time()
+            for idx, user in enumerate(user_info_dict):
+                arg_dict = {
+                    "uid": str(user),
+                    "date": date,
+                    "data_endpoints":
+                    self.role_maker.get_data_server_endpoints(),
+                    "global_params": global_param_dict,
+                    "user_param_names": self.trainer.get_user_param_names(),
+                    "global_param_names":
+                    self.trainer.get_global_param_names(),
+                    "write_global_param_file":
+                    "_global_param/process_%d" % idx,
+                }
+                p = Process(
+                    target=self.trainer.train_one_user_func,
+                    args=(arg_dict, self.trainer.trainer_config))
+                p.start()
+                processes.append(p)
+            if self.do_profile:
+                logging.info("wait processes to close")
+            for i, p in enumerate(processes):
+                processes[i].join()
+            end = time.time()
+            sim_all_trainer_run_time += (end - start)
+
+            start = time.time()
+            train_result = []
+            new_global_param_by_user = {}
+            training_sample_by_user = {}
+            for i, p in enumerate(processes):
+                param_dir = "_global_param/process_%d/" % i
+                with open(param_dir + "/_info", "r") as f:
+                    user, train_sample_num = pickle.load(f)
+                param_dict = {}
+                for f_name in os.listdir(os.path.join(param_dir, "params")):
+                    f_path = os.path.join(param_dir, "params", f_name)
+                    if os.path.isdir(f_path):  # layer
+                        for layer_param in os.listdir(f_path):
+                            layer_param_path = os.path.join(f_path,
+                                                            layer_param)
+                            with open(layer_param_path) as f:
+                                param_dict["{}/{}".format(
+                                    f_name, layer_param)] = np.load(f)
+                    else:
+                        with open(f_path) as f:
+                            param_dict[f_name] = np.load(f)
+                new_global_param_by_user[user] = param_dict
+                training_sample_by_user[user] = train_sample_num
+
+            self.fl_optimizer.update(training_sample_by_user,
+                                     new_global_param_by_user,
+                                     global_param_dict, self.scheduler_client)
+            end = time.time()
+            sim_read_param_and_optimize += (end - start)
+        if self.do_profile:
+            self.profile_file.write("sim_all_trainer_run_time\t\t%f s\n" %
                                    sim_all_trainer_run_time)
+            self.profile_file.write("sim_read_param_and_optimize\t\t%f s\n" %
+                                    sim_read_param_and_optimize)
+
+        logging.info("training done for date %s." % date)
+
+    def _test(self, date):
+        # nothing to test if the trainer does not provide an infer function
+        if self.trainer.infer_one_user_func is None:
+            return
+        logging.info("doing test...")
+        if self.test_sampler is None:
+            logging.error("self.test_sampler should not be None when testing")
+            return
+
+        sim_idx = self.role_maker.simulator_idx()
+        sim_num = self.role_maker.simulator_num()
+        user_info_dict = self.test_sampler.sample_user_list(
+            self.scheduler_client,
+            date,
+            sim_idx,
+            len(self.data_client.stub_list),
+            sim_num, )
+        if self.do_profile:
+            print("test user info_dict: ", user_info_dict)
+        global_param_dict = self.scheduler_client.get_global_params()
+
+        def divide_chunks(l, n):
+            for i in range(0, len(l), n):
+                yield l[i:i + n]
+
+        # at most 50 processes for testing
+        chunk_size = 50
+        # at most 100 uids for testing
+        max_test_uids = 100
+        uid_chunks = divide_chunks(list(user_info_dict.keys()), chunk_size)
+        os.system("rm -rf _test_result")
+        os.system("mkdir _test_result")
+
+        tested_uids = 0
+        for uids in uid_chunks:
+            if tested_uids >= max_test_uids:
+                break
+            processes = []
+            for user in uids:
+                arg_dict = {
+                    "uid": str(user),
+                    "date": date,
+                    "data_endpoints":
+                    self.role_maker.get_data_server_endpoints(),
+                    "global_params": global_param_dict,
+                    "user_param_names": self.trainer.get_user_param_names(),
+                    "global_param_names":
+                    self.trainer.get_global_param_names(),
+                    "infer_result_dir": "_test_result/uid-%s" % user,
+                }
+                p = Process(
+                    target=self.trainer.infer_one_user_func,
+                    args=(arg_dict, self.trainer.trainer_config))
+                p.start()
+                processes.append(p)
+            if self.do_profile:
+                logging.info("wait test processes to close")
+            for i, p in enumerate(processes):
+                processes[i].join()
+            tested_uids += chunk_size
+
+        infer_results = []
+        # only support one test metric now
+        for uid in os.listdir("_test_result"):
+            with open("_test_result/" + uid + "/res", 'r') as f:
+                sample_count, metric = f.readlines()[0].strip('\n').split('\t')
+                infer_results.append((int(sample_count), float(metric)))
+        if sum([x[0] for x in infer_results]) == 0:
+            logging.info("infer results: 0.0")
+        else:
+            count = sum([x[0] for x in infer_results])
+            metric = sum([x[0] * x[1] for x in infer_results]) / count
+            logging.info("infer results: %f" % metric)
+
+    def _save_and_upload(self, date, fs_upload_path):
+        if self.trainer.save_and_upload_func is None:
+            return
+        if fs_upload_path is None:
+            return
+        dfs_upload_path = fs_upload_path + date + "_" + str(
+            self.role_maker.simulator_idx())
+        global_param_dict = self.scheduler_client.get_global_params()
+        arg_dict = {
+            "date": date,
+            "global_params": global_param_dict,
+            "user_param_names": self.trainer.get_user_param_names(),
+            "global_param_names": self.trainer.get_global_param_names(),
+        }
+        self.trainer.save_and_upload_func(
+            arg_dict, self.trainer.trainer_config, dfs_upload_path)
+
+    def run_simulation(self,
+                       base_path,
+                       dates,
+                       fs_upload_path=None,
+                       sim_num_everyday=1,
+                       do_test=False,
+                       test_skip_day=6):
+        # only simulator roles run the simulation loop
+        if not self.role_maker.is_simulator():
+            return
+        data_services = self._get_data_services()
+        for service in data_services:
+            service.start()
+        self._barrier_simulators()
+        self.data_client = DataClient()
+        self.data_client.set_load_data_into_patch_func(
+            self.trainer.get_load_data_into_patch_func())
+        self.data_client.set_data_server_endpoints(
+            self.role_maker.get_data_server_endpoints())
+        self.scheduler_client = SchedulerClient()
+        self.scheduler_client.set_data_server_endpoints(
+            self.role_maker.get_data_server_endpoints())
+        self.scheduler_client.set_scheduler_server_endpoints(
+            [self.role_maker.get_global_scheduler_endpoint()])
+        logging.info("trainer config: {}".format(self.trainer.trainer_config))
+        self.trainer.prepare(do_test=do_test)
+
+        if self.role_maker.simulator_idx() == 0:
+            self.trainer.init_global_model(self.scheduler_client)
+        self._barrier_simulators()
+
+        for date_idx, date in enumerate(dates):
+            if date_idx > 0:
+                self.do_profile = False
+                self.profile_file.close()
+            logging.info("reading data for date: %s" % date)
+            local_files = self._profile(
+                self.data_client.get_local_files,
+                base_path,
+                date,
+                self.role_maker.simulator_idx(),
+                self.role_maker.simulator_num(),
+                hdfs_configs=self.hdfs_configs)
+
+            logging.info("loading data into patch for date: %s" % date)
+            data_patch, local_user_dict = self._profile(
+                self.data_client.load_data_into_patch, local_files, 10000)
+            logging.info("shuffling data for date: %s" % date)
+            self._profile(self.data_client.global_shuffle_by_patch, data_patch,
+                          date, 30)
+
+            logging.info("updating user inst num for date: %s" % date)
+            self._profile(self.scheduler_client.update_user_inst_num, date,
+                          local_user_dict)
+            self.role_maker.barrier_simulator()
+
+            if do_test and date_idx != 0 and date_idx % test_skip_day == 0:
+                self._barrier_simulators()
+                self._profile(self._test, date)
+                self._barrier_simulators()
+                self._profile(self._save_and_upload, date, fs_upload_path)
+
+            self._run_sim(date, sim_num_everyday=sim_num_everyday)
+            self.role_maker.barrier_simulator()
+            logging.info("clear user data for date: %s" % date)
+            self.data_client.clear_user_data(date)
+
+        self._barrier_simulators()
+        logging.info("training done for all dates.")
+        logging.info("stopping scheduler")
+        self.scheduler_client.stop_scheduler_server()
+        for pro in data_services:
+            pro.terminate()
+        logging.info("all data servers terminated.")
diff --git a/python/paddle_fl/mobile/get_data.sh b/python/paddle_fl/mobile/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6b157d6b440d52053b8f4e45822001ac6929043a
--- /dev/null
+++ b/python/paddle_fl/mobile/get_data.sh
@@ -0,0 +1,3 @@
+# reddit_subsampled
+wget https://paddle-serving.bj.bcebos.com/temporary_files_for_docker/reddit_subsampled.zip --no-check-certificate
+unzip reddit_subsampled.zip
diff --git a/python/paddle_fl/mobile/layer/__init__.py b/python/paddle_fl/mobile/layer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c5e3c4c77d2ba7713538a21d98f519a4caaa89e
--- /dev/null
+++ b/python/paddle_fl/mobile/layer/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
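+
+# A minimal usage sketch (illustrative only; `data`, `vocab_size` and
+# `emb_dim` are hypothetical names, the attr()/forward() calls follow
+# layer_base.py):
+#
+#   emb = Embedding()
+#   emb.attr([vocab_size, emb_dim], name="emb")
+#   fc = FC()
+#   fc.attr([2], act="softmax", name="fc")
+#   prediction = fc.forward(Pooling().forward(emb.forward(data)))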
+from .layer_base import Embedding +from .layer_base import SequenceConv +from .layer_base import Concat +from .layer_base import Pooling +from .layer_base import FC +from .layer_base import CrossEntropySum +from .layer_base import ACC +from .layer_base import AUC +from .data_base import TextData diff --git a/python/paddle_fl/mobile/layer/data_base.py b/python/paddle_fl/mobile/layer/data_base.py new file mode 100644 index 0000000000000000000000000000000000000000..c211d36e063c4d9658d60e851b6cb053918b914c --- /dev/null +++ b/python/paddle_fl/mobile/layer/data_base.py @@ -0,0 +1,30 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid + + +class DataBase(object): + def __init__(self): + pass + + +class TextData(object): + def __init__(self): + pass + + def create(self, name=None): + assert name is not None + return fluid.layers.data( + name=name, shape=[1], dtype='int64', lod_level=1) diff --git a/python/paddle_fl/mobile/layer/layer_base.py b/python/paddle_fl/mobile/layer/layer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..59d4d9961cbdbe995430c61675dc089ef337418f --- /dev/null +++ b/python/paddle_fl/mobile/layer/layer_base.py @@ -0,0 +1,135 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
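+
+# Every layer below follows the same two-step convention: attr() records the
+# configuration (shapes, activation, parameter name) and forward() emits the
+# corresponding paddle.fluid op. Passing a name to attr() pins the underlying
+# fluid.ParamAttr, which is what later lets model_base.py tell user-private
+# parameters (names containing "@USER") apart from global ones.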
+import paddle.fluid as fluid
+
+
+class LayerBase(object):
+    def __init__(self):
+        pass
+
+
+class Embedding(LayerBase):
+    def __init__(self):
+        pass
+
+    def attr(self, shapes, name=None):
+        self.vocab_size = shapes[0]
+        self.emb_dim = shapes[1]
+        self.name = name
+
+    def forward(self, i):
+        is_sparse = True
+        if self.name is not None:
+            param_attr = fluid.ParamAttr(name=self.name)
+        else:
+            param_attr = None
+
+        emb = fluid.layers.embedding(
+            input=i,
+            is_sparse=is_sparse,
+            size=[self.vocab_size, self.emb_dim],
+            param_attr=param_attr,
+            padding_idx=0)
+        return emb
+
+
+class SequenceConv(LayerBase):
+    def __init__(self):
+        pass
+
+    def attr(self, name=None):
+        self.num_filters = 64
+        self.win_size = 3
+        self.name = name
+
+    def forward(self, i):
+        if self.name is not None:
+            param_attr = fluid.ParamAttr(name=self.name)
+        else:
+            param_attr = None
+        conv = fluid.nets.sequence_conv_pool(
+            input=i,
+            num_filters=self.num_filters,
+            filter_size=self.win_size,
+            param_attr=param_attr,
+            act="tanh",
+            pool_type="max")
+        return conv
+
+
+class Concat(LayerBase):
+    def forward(self, inputs):
+        concat = fluid.layers.concat(inputs, axis=1)
+        return concat
+
+
+class Pooling(LayerBase):
+    def __init__(self):
+        self.pool_type = 'sum'
+
+    def attr(self, pool_type='sum'):
+        self.pool_type = pool_type
+
+    def forward(self, i):
+        # honor the pool_type configured through attr()
+        pool = fluid.layers.sequence_pool(input=i, pool_type=self.pool_type)
+        return pool
+
+
+class FC(LayerBase):
+    def __init__(self):
+        pass
+
+    def attr(self, shapes, act='relu', name=None):
+        self.name = name
+        self.size = shapes[0]
+        self.act = act
+
+    def forward(self, i):
+        if self.name is not None:
+            param_attr = fluid.ParamAttr(name=self.name)
+        else:
+            param_attr = None
+        fc = fluid.layers.fc(input=i,
+                             size=self.size,
+                             act=self.act,
+                             param_attr=param_attr)
+        return fc
+
+
+class CrossEntropySum(LayerBase):
+    def __init__(self):
+        pass
+
+    def forward(self, prediction, label):
+        cost = fluid.layers.cross_entropy(input=prediction, label=label)
+        sum_cost = fluid.layers.reduce_sum(cost)
+        return sum_cost
+
+
+class ACC(LayerBase):
+    def __init__(self):
+        pass
+
+    def forward(self, prediction, label):
+        return fluid.layers.accuracy(input=prediction, label=label)
+
+
+class AUC(LayerBase):
+    def forward(self, prediction, label):
+        auc, batch_auc_var, auc_states = \
+            fluid.layers.auc(input=prediction, label=label,
+                             num_thresholds=2 ** 12,
+                             slide_steps=20)
+        return auc, batch_auc_var
diff --git a/python/paddle_fl/mobile/model/__init__.py b/python/paddle_fl/mobile/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..75d9d3acc966f16e0bbe3423f617e8a10b43cb08
--- /dev/null
+++ b/python/paddle_fl/mobile/model/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
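+
+# A model plugged into the simulator subclasses ModelBase and, like
+# LanguageModel below, provides at least (sketch of the assumed contract;
+# MyModel is a hypothetical name):
+#
+#   class MyModel(ModelBase):
+#       def build_model(self, model_configs): ...   # build program, set loss
+#       def get_model_input_names(self): ...        # feed names, e.g. ['x', 'y']
+#       def get_model_loss(self): ...               # loss variable to minimize
+#       def get_model_metrics(self): ...            # {metric_name: var_name}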
+from .language_model import LanguageModel diff --git a/python/paddle_fl/mobile/model/language_model.py b/python/paddle_fl/mobile/model/language_model.py new file mode 100644 index 0000000000000000000000000000000000000000..454c973e56073e5aa78c5fa6333fa16f7ec5965e --- /dev/null +++ b/python/paddle_fl/mobile/model/language_model.py @@ -0,0 +1,443 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from model import * +from layer import * +from .model_base import ModelBase +import paddle.fluid as fluid + +import paddle.fluid.layers as layers +import paddle.fluid as fluid +from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN +import numpy as np +from paddle.fluid import ParamAttr +from paddle.fluid.contrib.layers import basic_lstm + + +class LanguageModel(ModelBase): + def __init__(self): + # model args + self.hidden_size_ = 200 + self.vocab_size_ = 10000 + self.num_layers_ = 2 + self.num_steps_ = 10 # fix + self.init_scale_ = 0.1 + self.dropout_ = 0.0 + self.rnn_model_ = 'basic_lstm' + self.pad_symbol_ = 0 + self.unk_symbol_ = 1 + + # results + self.correct_ = None + self.prediction_ = None + self.loss_ = None + + # private vars + self.user_params_ = [] + self.program_ = None + self.startup_program_ = None + self.input_name_list_ = None + self.target_var_names_ = [] + + def get_model_input_names(self): + return self.input_name_list_ + + def get_model_loss(self): + return self.loss_ + + def get_model_loss_name(self): + return self.loss_.name + + def get_model_metrics(self): + return {"correct": self.correct_.name} + + def get_target_names(self): + return self.target_var_names_ + + def build_model(self, model_configs): + hidden_size = self.hidden_size_ + init_scale = self.init_scale_ + dropout = self.dropout_ + num_layers = self.num_layers_ + num_steps = self.num_steps_ + pad_symbol = self.pad_symbol_ + unk_symbol = self.unk_symbol_ + vocab_size = self.vocab_size_ + rnn_model = self.rnn_model_ + x = fluid.data(name="x", shape=[None, num_steps], dtype='int64') + y = fluid.data(name="y", shape=[None, num_steps], dtype='int64') + x = layers.reshape(x, shape=[-1, num_steps, 1]) + y = layers.reshape(y, shape=[-1, 1]) + self.input_name_list_ = ['x', 'y'] + + init_hidden = layers.fill_constant_batch_size_like( + input=x, + shape=[-1, num_layers, hidden_size], + value=0, + dtype="float32") + init_cell = layers.fill_constant_batch_size_like( + input=x, + shape=[-1, num_layers, hidden_size], + value=0, + dtype="float32") + + init_hidden = layers.transpose(init_hidden, perm=[1, 0, 2]) + init_cell = layers.transpose(init_cell, perm=[1, 0, 2]) + + init_hidden_reshape = layers.reshape( + init_hidden, shape=[num_layers, -1, hidden_size]) + init_cell_reshape = layers.reshape( + init_cell, shape=[num_layers, -1, hidden_size]) + + x_emb = layers.embedding( + input=x, + size=[vocab_size, hidden_size], + dtype='float32', + 
is_sparse=False, + param_attr=fluid.ParamAttr( + name='embedding_para', + initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale))) + + x_emb = layers.reshape( + x_emb, shape=[-1, num_steps, hidden_size], inplace=True) + if dropout != None and dropout > 0.0: + x_emb = layers.dropout( + x_emb, + dropout_prob=dropout, + dropout_implementation='upscale_in_train') + + if rnn_model == "padding": + rnn_out, last_hidden, last_cell = self._padding_rnn( + x_emb, + len=num_steps, + init_hidden=init_hidden_reshape, + init_cell=init_cell_reshape) + elif rnn_model == "static": + rnn_out, last_hidden, last_cell = self._encoder_static( + x_emb, + len=num_steps, + init_hidden=init_hidden_reshape, + init_cell=init_cell_reshape) + elif rnn_model == "cudnn": + x_emb = layers.transpose(x_emb, perm=[1, 0, 2]) + rnn_out, last_hidden, last_cell = layers.lstm( + x_emb, + init_hidden_reshape, + init_cell_reshape, + num_steps, + hidden_size, + num_layers, + is_bidirec=False, + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2]) + elif rnn_model == "basic_lstm": + rnn_out, last_hidden, last_cell = basic_lstm( + x_emb, + init_hidden, + init_cell, + hidden_size, + num_layers=num_layers, + batch_first=True, + dropout_prob=dropout, + param_attr=ParamAttr( + initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0)), + forget_bias=0.0) + else: + raise Exception("type not support") + + rnn_out = layers.reshape( + rnn_out, shape=[-1, num_steps, hidden_size], inplace=True) + + softmax_weight = layers.create_parameter( + [hidden_size, vocab_size], + dtype="float32", + name="softmax_weight", + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + softmax_bias = layers.create_parameter( + [vocab_size], + dtype="float32", + name='softmax_bias', + default_initializer=fluid.initializer.UniformInitializer( + low=-init_scale, high=init_scale)) + + projection = layers.matmul(rnn_out, softmax_weight) + projection = layers.elementwise_add(projection, softmax_bias) + projection = layers.reshape( + projection, shape=[-1, vocab_size], inplace=True) + + # correct predictions + labels_reshaped = fluid.layers.reshape(y, [-1]) + pred = fluid.layers.cast( + fluid.layers.argmax(projection, 1), dtype="int64") + correct_pred = fluid.layers.cast( + fluid.layers.equal(pred, labels_reshaped), dtype="int64") + self.prediction_ = pred + self.target_var_names_.append(pred) + + # predicting unknown is always considered wrong + unk_tensor = fluid.layers.fill_constant( + fluid.layers.shape(labels_reshaped), + value=unk_symbol, + dtype='int64') + pred_unk = fluid.layers.cast( + fluid.layers.equal(pred, unk_tensor), dtype="int64") + correct_unk = fluid.layers.elementwise_mul(pred_unk, correct_pred) + + # predicting padding is always considered wrong + pad_tensor = fluid.layers.fill_constant( + fluid.layers.shape(labels_reshaped), value=0, dtype='int64') + pred_pad = fluid.layers.cast( + fluid.layers.equal(pred, pad_tensor), dtype="int64") + correct_pad = fluid.layers.elementwise_mul(pred_pad, correct_pred) + + # acc + correct_count = fluid.layers.reduce_sum(correct_pred) \ + - fluid.layers.reduce_sum(correct_unk) \ + - fluid.layers.reduce_sum(correct_pad) + self.correct_ = correct_count + self.target_var_names_.append(correct_count) + + loss = layers.softmax_with_cross_entropy( + 
logits=projection, label=y, soft_label=False)
+
+        loss = layers.reshape(loss, shape=[-1, num_steps], inplace=True)
+        loss = layers.reduce_mean(loss, dim=[0])
+        loss = layers.reduce_sum(loss)
+        self.loss_ = loss
+        self.target_var_names_.append(loss)
+
+        loss.persistable = True
+
+        # This will feed last_hidden, last_cell to init_hidden, init_cell,
+        # which can be used directly in the next batch. This avoids fetching
+        # last_hidden and last_cell and feeding init_hidden and init_cell in
+        # each training step.
+        #last_hidden = layers.transpose(last_hidden, perm=[1, 0, 2])
+        #last_cell = layers.transpose(last_cell, perm=[1, 0, 2])
+        #self.input_name_list_ = ['x', 'y', 'init_hidden', 'init_cell']
+
+        self.program_ = fluid.default_main_program()
+        self.startup_program_ = fluid.default_startup_program()
+
+    def _padding_rnn(self, input_embedding, len=3, init_hidden=None,
+                     init_cell=None):
+        weight_1_arr = []
+        weight_2_arr = []
+        bias_arr = []
+        hidden_array = []
+        cell_array = []
+        mask_array = []
+        hidden_size = self.hidden_size_
+        init_scale = self.init_scale_
+        dropout = self.dropout_
+        num_layers = self.num_layers_
+        num_steps = self.num_steps_
+        for i in range(num_layers):
+            weight_1 = layers.create_parameter(
+                [hidden_size * 2, hidden_size * 4],
+                dtype="float32",
+                name="fc_weight1_" + str(i),
+                default_initializer=fluid.initializer.UniformInitializer(
+                    low=-init_scale, high=init_scale))
+            weight_1_arr.append(weight_1)
+            bias_1 = layers.create_parameter(
+                [hidden_size * 4],
+                dtype="float32",
+                name="fc_bias1_" + str(i),
+                default_initializer=fluid.initializer.Constant(0.0))
+            bias_arr.append(bias_1)
+
+            pre_hidden = layers.slice(
+                init_hidden, axes=[0], starts=[i], ends=[i + 1])
+            pre_cell = layers.slice(
+                init_cell, axes=[0], starts=[i], ends=[i + 1])
+            pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size])
+            pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size])
+            hidden_array.append(pre_hidden)
+            cell_array.append(pre_cell)
+
+        input_embedding = layers.transpose(input_embedding, perm=[1, 0, 2])
+        rnn = PaddingRNN()
+
+        with rnn.step():
+            input = rnn.step_input(input_embedding)
+            for k in range(num_layers):
+                pre_hidden = rnn.memory(init=hidden_array[k])
+                pre_cell = rnn.memory(init=cell_array[k])
+                weight_1 = weight_1_arr[k]
+                bias = bias_arr[k]
+
+                nn = layers.concat([input, pre_hidden], 1)
+                gate_input = layers.matmul(x=nn, y=weight_1)
+
+                gate_input = layers.elementwise_add(gate_input, bias)
+                i = layers.slice(
+                    gate_input, axes=[1], starts=[0], ends=[hidden_size])
+                j = layers.slice(
+                    gate_input,
+                    axes=[1],
+                    starts=[hidden_size],
+                    ends=[hidden_size * 2])
+                f = layers.slice(
+                    gate_input,
+                    axes=[1],
+                    starts=[hidden_size * 2],
+                    ends=[hidden_size * 3])
+                o = layers.slice(
+                    gate_input,
+                    axes=[1],
+                    starts=[hidden_size * 3],
+                    ends=[hidden_size * 4])
+
+                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
+                    i) * layers.tanh(j)
+                m = layers.tanh(c) * layers.sigmoid(o)
+
+                rnn.update_memory(pre_hidden, m)
+                rnn.update_memory(pre_cell, c)
+
+                rnn.step_output(m)
+                rnn.step_output(c)
+
+                input = m
+
+                if dropout != None and dropout > 0.0:
+                    input = layers.dropout(
+                        input,
+                        dropout_prob=dropout,
+                        dropout_implementation='upscale_in_train')
+
+            rnn.step_output(input)
+        rnnout = rnn()
+
+        last_hidden_array = []
+        last_cell_array = []
+        real_res = rnnout[-1]
+        for i in range(num_layers):
+            m = rnnout[i * 2]
+            c = rnnout[i * 2 + 1]
+            m.stop_gradient = True
+            c.stop_gradient = True
+            last_h = layers.slice(
+                m, axes=[0], starts=[num_steps - 1], ends=[num_steps])
+            last_hidden_array.append(last_h)
+            last_c = layers.slice(
+                c, axes=[0], starts=[num_steps - 1], ends=[num_steps])
+            last_cell_array.append(last_c)
+        real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
+        last_hidden = layers.concat(last_hidden_array, 0)
+        last_cell = layers.concat(last_cell_array, 0)
+
+        return real_res, last_hidden, last_cell
+
+    def _encoder_static(self,
+                        input_embedding,
+                        len=3,
+                        init_hidden=None,
+                        init_cell=None):
+        weight_1_arr = []
+        weight_2_arr = []
+        bias_arr = []
+        hidden_array = []
+        cell_array = []
+        mask_array = []
+        hidden_size = self.hidden_size_
+        init_scale = self.init_scale_
+        dropout = self.dropout_
+        num_layers = self.num_layers_
+        for i in range(num_layers):
+            weight_1 = layers.create_parameter(
+                [hidden_size * 2, hidden_size * 4],
+                dtype="float32",
+                name="fc_weight1_" + str(i),
+                default_initializer=fluid.initializer.UniformInitializer(
+                    low=-init_scale, high=init_scale))
+            weight_1_arr.append(weight_1)
+            bias_1 = layers.create_parameter(
+                [hidden_size * 4],
+                dtype="float32",
+                name="fc_bias1_" + str(i),
+                default_initializer=fluid.initializer.Constant(0.0))
+            bias_arr.append(bias_1)
+
+            pre_hidden = layers.slice(
+                init_hidden, axes=[0], starts=[i], ends=[i + 1])
+            pre_cell = layers.slice(
+                init_cell, axes=[0], starts=[i], ends=[i + 1])
+            pre_hidden = layers.reshape(
+                pre_hidden, shape=[-1, hidden_size], inplace=True)
+            pre_cell = layers.reshape(
+                pre_cell, shape=[-1, hidden_size], inplace=True)
+            hidden_array.append(pre_hidden)
+            cell_array.append(pre_cell)
+
+        res = []
+        sliced_inputs = layers.split(
+            input_embedding, num_or_sections=len, dim=1)
+
+        for index in range(len):
+            input = sliced_inputs[index]
+            input = layers.reshape(
+                input, shape=[-1, hidden_size], inplace=True)
+            for k in range(num_layers):
+                pre_hidden = hidden_array[k]
+                pre_cell = cell_array[k]
+                weight_1 = weight_1_arr[k]
+                bias = bias_arr[k]
+                nn = layers.concat([input, pre_hidden], 1)
+                gate_input = layers.matmul(x=nn, y=weight_1)
+
+                gate_input = layers.elementwise_add(gate_input, bias)
+                i, j, f, o = layers.split(
+                    gate_input, num_or_sections=4, dim=-1)
+
+                c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
+                    i) * layers.tanh(j)
+                m = layers.tanh(c) * layers.sigmoid(o)
+
+                hidden_array[k] = m
+                cell_array[k] = c
+                input = m
+
+                if dropout != None and dropout > 0.0:
+                    input = layers.dropout(
+                        input,
+                        dropout_prob=dropout,
+                        dropout_implementation='upscale_in_train')
+
+            res.append(input)
+
+        last_hidden = layers.concat(hidden_array, 1)
+        last_hidden = layers.reshape(
+            last_hidden, shape=[-1, num_layers, hidden_size], inplace=True)
+        last_hidden = layers.transpose(x=last_hidden, perm=[1, 0, 2])
+
+        last_cell = layers.concat(cell_array, 1)
+        last_cell = layers.reshape(
+            last_cell, shape=[-1, num_layers, hidden_size])
+        last_cell = layers.transpose(x=last_cell, perm=[1, 0, 2])
+
+        real_res = layers.concat(res, 0)
+        real_res = layers.reshape(
+            real_res, shape=[len, -1, hidden_size], inplace=True)
+        real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
+
+        return real_res, last_hidden, last_cell
diff --git a/python/paddle_fl/mobile/model/model_base.py b/python/paddle_fl/mobile/model/model_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..787ff4ffb0c287695f2865408341acacc3822e8b
--- /dev/null
+++ b/python/paddle_fl/mobile/model/model_base.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import numpy as np + + +def set_user_param_dict(param_names, param_dict, scope): + place = fluid.CPUPlace() + for var_name in param_names: + param = scope.find_var(var_name) + if param is None: + print("var name: {} does not exist in memory".format(var_name)) + continue + param.get_tensor().set(param_dict[var_name], place) + return + + +def set_global_param_dict(param_names, param_dict, scope): + place = fluid.CPUPlace() + for var_name in param_names: + param = scope.find_var(var_name) + if param is None: + print("var name: {} does not exist in memory".format(var_name)) + continue + if var_name not in param_dict: + print("var name: {} does not exist in global param dict".format( + var_name)) + exit() + var_numpy = param_dict[var_name] + param.get_tensor().set(var_numpy, place) + return + + +class ModelBase(object): + def __init__(self): + pass + + def init_model(self): + pass + + def build_model(self, model_configs): + pass + + def get_model_inputs(self): + pass + + def get_model_loss(self): + pass + + def get_model_metrics(self): + pass + + def get_startup_program(self): + pass + + def get_main_program(self): + pass + + def get_user_param_dict(self): + param_dict = {} + scope = fluid.global_scope() + for var_pair in self.get_user_param_names(): + param = scope.find_var(var_pair[0]) + if param is None: + print("var name: {} does not exist in memory".format(var_pair[ + 0])) + continue + var = param.get_tensor().__array__() + param_dict[var_pair[0]] = [var, var_pair[1].shape] + return param_dict + + def get_global_param_dict(self): + param_dict = {} + scope = fluid.global_scope() + for var_pair in self.get_global_param_names(): + param = scope.find_var(var_pair[0]) + if param is None: + print("var name: {} does not exist in memory".format(var_pair[ + 0])) + continue + var = param.get_tensor().__array__() + param_dict[var_pair[0]] = var + return param_dict + + def get_user_param_names(self): + user_params = [] + for var_name, var in self.startup_program_.global_block().vars.items(): + if var.persistable and "@USER" in var_name and \ + "learning_rate" not in var_name: + user_params.append((var_name, var)) + return user_params + + def get_global_param_names(self): + global_params = [] + for var_name, var in self.startup_program_.global_block().vars.items(): + if var.persistable and "@USER" not in var_name and \ + "learning_rate" not in var_name and "generated_var" not in var_name: + global_params.append((var_name, var)) + return global_params diff --git a/python/paddle_fl/mobile/optimizer/__init__.py b/python/paddle_fl/mobile/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9337dbec4bc6f1c19b9a8df6d7a2473170cdcf2 --- /dev/null +++ b/python/paddle_fl/mobile/optimizer/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .fed_avg_optimizer import FedAvgOptimizer +from .fed_avg_optimizer import SumOptimizer diff --git a/python/paddle_fl/mobile/optimizer/fed_avg_optimizer.py b/python/paddle_fl/mobile/optimizer/fed_avg_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fd7545eef4e6e14b9743cecf36701a202753c9 --- /dev/null +++ b/python/paddle_fl/mobile/optimizer/fed_avg_optimizer.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .optimizer_base import OptimizerBase + + +class FedAvgOptimizer(OptimizerBase): + def __init__(self, learning_rate=1.0): + self.learning_rate = learning_rate + pass + + def update(self, user_info, new_global_param_by_user, old_global_param, + scheduler_client): + total_weight = 0.0 + for uid in user_info: + total_weight += float(user_info[uid]) + update_dict = {} + for key in new_global_param_by_user: + uid = key + uid_global_w = new_global_param_by_user[key] + weight = float(user_info[key]) / total_weight + for param_name in uid_global_w: + if param_name in update_dict: + update_dict[param_name] += \ + self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name]) + else: + update_dict[param_name] = \ + self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name]) + + scheduler_client.fedavg_update(update_dict) + + +class SumOptimizer(OptimizerBase): + def __init__(self, learning_rate=1.0): + self.learning_rate = learning_rate + pass + + def update(self, user_info, new_global_param_by_user, old_global_param, + scheduler_client): + update_dict = {} + for key in new_global_param_by_user: + uid = key + uid_global_w = new_global_param_by_user[key] + weight = 1.0 + for param_name in uid_global_w: + if param_name in update_dict: + update_dict[param_name] += \ + self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name]) + else: + update_dict[param_name] = \ + self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name]) + + scheduler_client.fedavg_update(update_dict) diff --git a/python/paddle_fl/mobile/optimizer/optimizer_base.py b/python/paddle_fl/mobile/optimizer/optimizer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3dcc823f5207043feccf870f3d8875138e4c1b --- /dev/null +++ b/python/paddle_fl/mobile/optimizer/optimizer_base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class OptimizerBase(object): + def __init__(self, role_maker): + pass + + def update(self, new_global_param_by_user, old_global_param, + scheduler_client): + pass diff --git a/python/paddle_fl/mobile/protos/data_server.proto b/python/paddle_fl/mobile/protos/data_server.proto new file mode 100644 index 0000000000000000000000000000000000000000..fdb2a8351edef0efac07f915aa580d5bfe616ea4 --- /dev/null +++ b/python/paddle_fl/mobile/protos/data_server.proto @@ -0,0 +1,53 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; +package server; + +message Data { + string date = 1; + string uid = 2; + string line = 3; +} + +message Res { + int32 err_code = 1; + int32 user_num = 2; +} + +message UserData { + repeated string line_str = 1; + int32 err_code = 2; +} + +message Param { + string name = 1; + repeated float weight = 2; + repeated int32 shape = 3; +} + +message UserParams { + repeated Param user_params = 1; + int32 err_code = 2; + string uid = 3; +} + +service DataServer { + // (Method definitions not shown) + rpc SendData(Data) returns (Res) {} + rpc GetUserData(Data) returns (UserData) {} + rpc ClearUserData(Data) returns (Res) {} + rpc GetUserParams(Data) returns (UserParams) {} + rpc UpdateUserParams(UserParams) returns (Res) {} +} diff --git a/python/paddle_fl/mobile/protos/run_codegen.py b/python/paddle_fl/mobile/protos/run_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..706f2d263352cd2d6e839cfd9ed3843b47c5e4d6 --- /dev/null +++ b/python/paddle_fl/mobile/protos/run_codegen.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2015 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Runs protoc with the gRPC plugin to generate messages and gRPC stubs.""" + +from grpc_tools import protoc + +#python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. data_server.proto +protoc.main(( + '', + '-I.', + '--python_out=.', + '--grpc_python_out=.', + 'data_server.proto', )) + +protoc.main(( + '', + '-I.', + '--python_out=.', + '--grpc_python_out=.', + 'scheduler_server.proto', )) + +import os +os.system("mv *pb2.py ../servers/") +os.system("mv *pb2_grpc.py ../servers/") diff --git a/python/paddle_fl/mobile/protos/scheduler_server.proto b/python/paddle_fl/mobile/protos/scheduler_server.proto new file mode 100644 index 0000000000000000000000000000000000000000..39da4e36a9b8e1b1918a2d7419ff0cc21bfa531e --- /dev/null +++ b/python/paddle_fl/mobile/protos/scheduler_server.proto @@ -0,0 +1,69 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; +package scheduler; + +message Request { + int32 node_idx = 1; + int32 sample_num = 2; + int32 shard_num = 3; + int32 node_num = 4; + int32 min_ins_num = 5; + string date = 6; +} + +message UserList { + repeated string uids = 1; + repeated float weight = 2; +} + +message Param { + string name = 1; + repeated float weight = 2; + repeated int32 shape = 3; +} + +message GlobalParams { + repeated Param global_params = 1; + int32 err_code = 2; +} + +message UserInstNum { + string uid = 1; + int32 inst_num = 2; +} + +message UserInstInfo { + repeated UserInstNum inst_nums = 1; + int32 shard_num = 2; + string date = 3; +} + +message Res { int32 err_code = 1; } + +message SchedulerServerEmptyInput {} + +service SchedulerServer { + // (Method definitions not shown) + rpc SampleUsersToTrain(Request) returns (UserInstInfo) {} + rpc FixedUsersToTrain(Request) returns (UserInstInfo) {} + rpc GetGlobalParams(Request) returns (GlobalParams) {} + rpc UpdateGlobalParams(GlobalParams) returns (Res) {} + rpc FedAvgUpdate(GlobalParams) returns (Res) {} + rpc UpdateUserInstNum(UserInstInfo) returns (Res) {} + rpc Exit(SchedulerServerEmptyInput) returns (Res) {} + rpc SampleUsersToTest(Request) returns (UserInstInfo) {} + rpc SampleUsersWithHash(Request) returns (UserInstInfo) {} +} diff --git a/python/paddle_fl/mobile/reader/__init__.py b/python/paddle_fl/mobile/reader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4085427caed2dd8ebe15eda343456ca38b8011f6 --- /dev/null +++ b/python/paddle_fl/mobile/reader/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/python/paddle_fl/mobile/reader/leaf_reddit_reader.py b/python/paddle_fl/mobile/reader/leaf_reddit_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2f926ce6b1da5041e528347f15341b74ad6c7e
--- /dev/null
+++ b/python/paddle_fl/mobile/reader/leaf_reddit_reader.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import os
+import sys
+import numpy as np
+from utils.logger import logging
+import json
+
+PAD_SYMBOL, UNK_SYMBOL = 0, 1
+DATA_PATH = "lm_data"
+VOCAB_PATH = os.path.join(DATA_PATH, "vocab.json")
+TRAIN_DATA_PATH = os.path.join(DATA_PATH, "20200101", "train_data.json")
+VOCAB = None
+
+
+def build_counter(train_data):
+    train_tokens = []
+    for u in train_data:
+        for c in train_data[u]['x']:
+            train_tokens.extend([s for s in c])
+
+    all_tokens = []
+    for i in train_tokens:
+        all_tokens.extend(i)
+    train_tokens = []
+
+    counter = collections.Counter()
+    counter.update(all_tokens)
+    all_tokens = []
+    return counter
+
+
+def build_vocab(filename, vocab_size=10000):
+    train_data = {}
+    with open(filename) as json_file:
+        data = json.load(json_file)
+        train_data = data['user_data']
+    counter = build_counter(train_data)
+
+    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
+    # -2 to account for the unknown and pad symbols
+    count_pairs = count_pairs[:(vocab_size - 2)]
+    words, _ = list(zip(*count_pairs))
+
+    # special token strings follow the LEAF reddit convention
+    vocab = {}
+    vocab['<PAD>'] = PAD_SYMBOL
+    vocab['<UNK>'] = UNK_SYMBOL
+
+    for i, w in enumerate(words):
+        if w != '<PAD>':
+            vocab[w] = i + 1
+
+    return {
+        'vocab': vocab,
+        'size': vocab_size,
+        'unk_symbol': vocab['<UNK>'],
+        'pad_symbol': vocab['<PAD>']
+    }
+
+
+def save_vocab(filename, vocab):
+    with open(filename, "w") as f:
+        f.write(json.dumps(vocab))
+
+
+def load_vocab(filename):
+    with open(filename) as f:
+        return json.loads(f.read())
+
+
+if os.path.exists(VOCAB_PATH):
+    logging.info("load vocab from: {}".format(VOCAB_PATH))
+    VOCAB = load_vocab(VOCAB_PATH)
+else:
+    #TODO: singleton
+    logging.info("build vocab from: {}".format(TRAIN_DATA_PATH))
+    VOCAB = build_vocab(TRAIN_DATA_PATH)
+    logging.info("save vocab into: {}".format(VOCAB_PATH))
+    save_vocab(VOCAB_PATH, VOCAB)
+if VOCAB is None:
+    logging.error("load vocab error")
+    raise Exception("load vocab error")
+
+
+def train_reader(lines):
+    def local_iter():
+        seg_id = 0
+        for line in lines:
+            assert (len(line.split("\t")) == 3)
+            uid, _, input_str = line.split("\t")
+            data = json.loads(input_str)
+            data_x = data["x"]
+            data_y = data["y"]
+            data_mask = data["mask"]
+
+            input_data, input_length = process_x(data_x, VOCAB)
+            target_data = process_y(data_y, VOCAB)
+            yield [input_data] + [target_data]
+
+    return local_iter
+
+
+def infer_reader(lines):
+    return train_reader(lines)
+
+
+def load_data_into_patch(filelist, patch_size):
+    data_patch = []
+    idx = 0
+    local_user_dict = {}
+    for fn in filelist:
+        tmp_list = []
+        with open(fn) as fin:
+            raw_data = json.loads(fin.read())["user_data"]
+            local_user_dict = {k: 0 for k in raw_data.keys()}
+            for user, data in raw_data.items():
+                data_x = data["x"]
+                data_y = data["y"]
+                for c, l in zip(data_x, data_y):
+                    for inst_i in range(len(c)):
+                        local_user_dict[user] += 1
+                        idx += 1
+                        inst = {
+                            "x": c[inst_i],
+                            "y": l["target_tokens"][inst_i],
+                            "mask": l["count_tokens"][inst_i]
+                        }
+                        line = "{}\t\t{}".format(user, json.dumps(inst))
+                        if idx % patch_size == 0:
+                            data_patch.append(tmp_list)
+                            tmp_list = [line]
+                        else:
+                            tmp_list.append(line)
+        if len(tmp_list) > 0:
+            data_patch.append(tmp_list)
+    return data_patch, local_user_dict
+
+
+def tokens_to_ids(tokens, vocab):
+    to_ret = [vocab.get(word, vocab["<UNK>"]) for word in tokens]
+    return np.array(to_ret, dtype="int64")
+
+
+def process_x(raw_x, vocab):
+    tokens = tokens_to_ids(raw_x, vocab["vocab"])
+    lengths = np.sum(tokens != vocab["pad_symbol"])
+    return tokens, lengths
+
+
+def process_y(raw_y, vocab):
+    tokens = tokens_to_ids(raw_y, vocab["vocab"])
+    return tokens
diff --git a/python/paddle_fl/mobile/sampler/__init__.py b/python/paddle_fl/mobile/sampler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa842b81a6c3ec06835aca48020d67391e357927
--- /dev/null
+++ b/python/paddle_fl/mobile/sampler/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .uniform_sampler import UniformSampler
+from .fixed_sampler import FixedSampler
+from .sampler_test_user import Test1wSampler
+from .sampler_test_user import Test1percentSampler
+from .hash_sampler import HashSampler
diff --git a/python/paddle_fl/mobile/sampler/fixed_sampler.py b/python/paddle_fl/mobile/sampler/fixed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..85544c7727eb04b3d578d31757a2e5c578a17a37
--- /dev/null
+++ b/python/paddle_fl/mobile/sampler/fixed_sampler.py
@@ -0,0 +1,28 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +from .sampler_base import SamplerBase + + +class FixedSampler(SamplerBase): + def __init__(self): + self.sample_num = 100 + + def set_sample_num(self, sample_num): + self.sample_num = sample_num + + def sample_user_list(self, scheduler_client, date, sim_idx, shard_num, + sim_num): + user_info = scheduler_client.fixed_sample_user_list( + date, sim_idx, self.sample_num, shard_num, sim_num) + return user_info diff --git a/python/paddle_fl/mobile/sampler/hash_sampler.py b/python/paddle_fl/mobile/sampler/hash_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c2248c3767cea246071740c020a25275afa77ca2 --- /dev/null +++ b/python/paddle_fl/mobile/sampler/hash_sampler.py @@ -0,0 +1,28 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .sampler_base import SamplerBase + + +class HashSampler(SamplerBase): + def __init__(self): + self.sample_num = 100 + + def set_sample_num(self, sample_num): + self.sample_num = sample_num + + def sample_user_list(self, scheduler_client, date, sim_idx, shard_num, + sim_num): + user_info = scheduler_client.hash_sample_user_list( + date, sim_idx, self.sample_num, shard_num, sim_num) + return user_info diff --git a/python/paddle_fl/mobile/sampler/sampler_base.py b/python/paddle_fl/mobile/sampler/sampler_base.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd6f2842cb522c4ba95917e261330bb568d4571 --- /dev/null +++ b/python/paddle_fl/mobile/sampler/sampler_base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class SamplerBase(object): + def __init__(self, role_maker): + pass + + def sample_user_list(self, scheduler_client, sim_idx, shard_num, sim_num, + min_ins_num): + pass diff --git a/python/paddle_fl/mobile/sampler/sampler_test_user.py b/python/paddle_fl/mobile/sampler/sampler_test_user.py new file mode 100644 index 0000000000000000000000000000000000000000..f2242f3fdd7ace8530c986c9fb31a69e3b650b3e --- /dev/null +++ b/python/paddle_fl/mobile/sampler/sampler_test_user.py @@ -0,0 +1,36 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .sampler_base import SamplerBase + + +class Test1wSampler(SamplerBase): + def __init__(self): + pass + + def sample_user_list(self, scheduler_client, date, sim_idx, shard_num, + sim_num): + user_info = scheduler_client.sample_test_user_list(date, sim_idx, + shard_num, sim_num) + return user_info + + +class Test1percentSampler(SamplerBase): + def __init__(self): + pass + + def sample_user_list(self, scheduler_client, date, sim_idx, shard_num, + sim_num): + user_info = scheduler_client.sample_test_user_list(date, sim_idx, + shard_num, sim_num) + return user_info diff --git a/python/paddle_fl/mobile/sampler/uniform_sampler.py b/python/paddle_fl/mobile/sampler/uniform_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..be73164d7fb786d1598f88cd11814b1b656792ec --- /dev/null +++ b/python/paddle_fl/mobile/sampler/uniform_sampler.py @@ -0,0 +1,33 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .sampler_base import SamplerBase + + +class UniformSampler(SamplerBase): + def __init__(self): + self.sample_num = 100 + self.min_ins_num = 1 + + def set_sample_num(self, sample_num): + self.sample_num = sample_num + + def set_min_ins_num(self, min_ins_num): + self.min_ins_num = min_ins_num + + def sample_user_list(self, scheduler_client, date, sim_idx, shard_num, + sim_num): + user_info = scheduler_client.uniform_sample_user_list( + date, sim_idx, self.sample_num, shard_num, sim_num, + self.min_ins_num) + return user_info diff --git a/python/paddle_fl/mobile/servers/__init__.py b/python/paddle_fl/mobile/servers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b2f42b0901356f5f8dcf24b8dceb4c11a62d17 --- /dev/null +++ b/python/paddle_fl/mobile/servers/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
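+#
+# This package bundles the two gRPC services the simulator runs:
+#   - DataServer: stores per-user training data and personalized parameters
+#     (SendData / GetUserData / GetUserParams / UpdateUserParams)
+#   - SchedulerServer: samples users for training or testing and applies
+#     FedAvg deltas to the global parameters it serves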
+
+from .data_server_impl import DataServerServicer
+from .data_server_impl import DataServer
+from .scheduler_server_impl import SchedulerServerServicer
+from .scheduler_server_impl import SchedulerServer
diff --git a/python/paddle_fl/mobile/servers/data_server_impl.py b/python/paddle_fl/mobile/servers/data_server_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6b13bb2bb1a77002998c96313cbeb2828b27a5f
--- /dev/null
+++ b/python/paddle_fl/mobile/servers/data_server_impl.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from concurrent import futures
+import data_server_pb2
+import data_server_pb2_grpc
+import grpc
+import sys
+from utils.logger import logging
+
+
+class DataServerServicer(object):
+    def __init__(self):
+        self.data_dict = {}
+        self.param_dict = {}
+        self.request_num = 0
+
+    def GetUserParams(self, request, context):
+        uid = unicode(request.uid)
+        if uid in self.param_dict:
+            return self.param_dict[uid]
+        else:
+            user_param = data_server_pb2.UserParams()
+            user_param.err_code = -1
+            return user_param
+
+    def UpdateUserParams(self, request, context):
+        self.param_dict[request.uid] = request
+        res = data_server_pb2.Res()
+        res.err_code = 0
+        return res
+
+    def ClearUserData(self, request, context):
+        date = request.date
+        self.data_dict[date].clear()
+        res = data_server_pb2.Res()
+        res.err_code = 0
+        return res
+
+    def GetUserData(self, request, context):
+        uid = unicode(request.uid)
+        date = unicode(request.date)
+        if date not in self.data_dict:
+            user_data = data_server_pb2.UserData()
+            user_data.err_code = -1
+            return user_data
+        if uid in self.data_dict[date]:
+            self.data_dict[date][uid].err_code = 0
+            return self.data_dict[date][uid]
+        else:
+            user_data = data_server_pb2.UserData()
+            user_data.err_code = -1
+            self.data_dict[date][uid] = user_data
+            return self.data_dict[date][uid]
+
+    def SendData(self, request, context):
+        date = unicode(request.date)
+        uid = unicode(request.uid)
+        if date in self.data_dict:
+            if uid in self.data_dict[request.date]:
+                self.data_dict[date][uid].line_str.extend([request.line])
+            else:
+                user_data = data_server_pb2.UserData()
+                self.data_dict[date][uid] = user_data
+                self.data_dict[date][uid].line_str.extend([request.line])
+        else:
+            user_data = data_server_pb2.UserData()
+            self.data_dict[date] = {}
+            self.data_dict[date][uid] = user_data
+            self.data_dict[date][uid].line_str.extend([request.line])
+
+        res = data_server_pb2.Res()
+        res.err_code = 0
+        return res
+
+
+class DataServer(object):
+    def __init__(self):
+        pass
+
+    def start(self, max_workers=1000, concurrency=100, endpoint=""):
+        if endpoint == "":
+            logging.error("You should specify endpoint in start function")
+            return
+        server = grpc.server(
+            futures.ThreadPoolExecutor(max_workers=max_workers),
+            options=[('grpc.max_send_message_length', 1024 * 1024 * 1024),
+                     ('grpc.max_receive_message_length', 1024 * 1024 * 1024)],
+            maximum_concurrent_rpcs=concurrency)
+        data_server_pb2_grpc.add_DataServerServicer_to_server(
+            DataServerServicer(), server)
+        server.add_insecure_port('[::]:{}'.format(endpoint))
+        server.start()
+        server.wait_for_termination()
+
+
+if __name__ == "__main__":
+    data_server = DataServer()
+    data_server.start(endpoint=sys.argv[1])
diff --git a/python/paddle_fl/mobile/servers/scheduler_server_impl.py b/python/paddle_fl/mobile/servers/scheduler_server_impl.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b2b84fa7f0f7f787d89bdd6271b6dacda0f3f1e
--- /dev/null
+++ b/python/paddle_fl/mobile/servers/scheduler_server_impl.py
@@ -0,0 +1,246 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function
+from concurrent import futures
+import scheduler_server_pb2
+import scheduler_server_pb2_grpc
+import grpc
+import numpy as np
+import random
+import sys, os
+import time
+import xxhash
+from utils.logger import logging
+
+
+class SchedulerServerServicer(object):
+    def __init__(self):
+        self.global_param_dict = {}
+        self.uid_inst_num_dict = {}
+        self.shard_id_dict = {}
+
+    def uid_shard(self, uid, shard_num):
+        try:
+            uid_hash = xxhash.xxh32(str(uid), seed=101).intdigest()
+        except Exception:
+            return -1
+        shard_idx = uid_hash % shard_num
+        return shard_idx
+
+    def is_test_uid(self, uid):
+        return xxhash.xxh32(str(uid), seed=222).intdigest() % 100 == 3
+
+    # we suppose shard num will not be changed during one training job
+    # but can be changed with another job
+    # so we send shard num every time we update user inst num
+    def UpdateUserInstNum(self, request, context):
+        shard_num = request.shard_num
+        date = request.date
+        if date not in self.uid_inst_num_dict:
+            self.uid_inst_num_dict[date] = {}
+        if date not in self.shard_id_dict:
+            self.shard_id_dict[date] = {}
+        for user in request.inst_nums:
+            shard_id = self.uid_shard(user.uid, shard_num)
+            if shard_id == -1:
+                logging.info("UpdateUserInstNum continue")
+                continue
+            if user.uid in self.uid_inst_num_dict[date]:
+                self.uid_inst_num_dict[date][user.uid] += user.inst_num
+            else:
+                self.uid_inst_num_dict[date][user.uid] = user.inst_num
+            if shard_id not in self.shard_id_dict[date]:
+                self.shard_id_dict[date][shard_id] = [user.uid]
+            else:
+                self.shard_id_dict[date][shard_id].append(user.uid)
+        res = scheduler_server_pb2.Res()
+        res.err_code = 0
+        return res
+
+    '''
+    SampleUsersToTrain:
+        request.node_idx: from which worker node the request is from
+        request.sample_num: how many users do we need to sample
+        request.shard_num: total shard number of this task
+        request.node_num: total number of training node
+    '''
+
+    def SampleUsersToTrain(self, request, context):
+        node_idx = request.node_idx
+        sample_num = request.sample_num
+        shard_num = request.shard_num
+        node_num = request.node_num
+        date = request.date
+        shard_per_node = shard_num / node_num
+        begin_idx = node_idx * shard_per_node
+        min_ins_num = request.min_ins_num
+
+        uid_list = []
+        i = 0
+        while i < sample_num:
+            # randint is inclusive on both ends, so the offset must stay in
+            # [0, shard_per_node - 1] to keep the shard on this node
+            shard_idx = begin_idx + random.randint(0, shard_per_node - 1)
+            if shard_idx not in self.shard_id_dict[date]:
+                continue
+            sample_idx = random.randint(
+                0, len(self.shard_id_dict[date][shard_idx]) - 1)
+            uid = self.shard_id_dict[date][shard_idx][sample_idx]
+            if self.uid_inst_num_dict[date][uid] < min_ins_num:
+                continue
+            uid_list.append(uid)
+            i += 1
+
+        info = scheduler_server_pb2.UserInstInfo()
+        for uid in uid_list:
+            inst_num = scheduler_server_pb2.UserInstNum()
+            inst_num.uid = uid
+            inst_num.inst_num = self.uid_inst_num_dict[date][uid]
+            info.inst_nums.extend([inst_num])
+        return info
+
+    def SampleUsersWithHash(self, request, context):
+        node_idx = request.node_idx
+        sample_num = request.sample_num
+        shard_num = request.shard_num
+        node_num = request.node_num
+        date = request.date
+        shard_per_node = shard_num / node_num
+        begin_idx = node_idx * shard_per_node
+
+        uid_list = []
+        i = 0
+        while i < sample_num:
+            shard_idx = begin_idx + random.randint(0, shard_per_node - 1)
+            if shard_idx not in self.shard_id_dict[date]:
+                continue
+            sample_idx = random.randint(
+                0, len(self.shard_id_dict[date][shard_idx]) - 1)
+            uid = self.shard_id_dict[date][shard_idx][sample_idx]
+            if not self.is_test_uid(uid):
+                continue
+            uid_list.append(uid)
+            i += 1
+
+        info = scheduler_server_pb2.UserInstInfo()
+        for uid in uid_list:
+            inst_num = scheduler_server_pb2.UserInstNum()
+            inst_num.uid = uid
+            inst_num.inst_num = self.uid_inst_num_dict[date][uid]
+            info.inst_nums.extend([inst_num])
+        return info
+
+    def SampleUsersToTest(self, request, context):
+        node_idx = request.node_idx
+        shard_num = request.shard_num
+        node_num = request.node_num
+        date = request.date
+        shard_per_node = shard_num / node_num
+        shard_begin_idx = node_idx * shard_per_node
+        shard_end_idx = (node_idx + 1) * shard_per_node
+        uid_list = []
+
+        for shard_idx in range(shard_begin_idx, shard_end_idx):
+            for uid in self.shard_id_dict[date][shard_idx]:
+                if self.is_test_uid(uid):
+                    uid_list.append(uid)
+        info = scheduler_server_pb2.UserInstInfo()
+        for uid in uid_list:
+            inst_num = scheduler_server_pb2.UserInstNum()
+            inst_num.uid = uid
+            inst_num.inst_num = self.uid_inst_num_dict[date][uid]
+            info.inst_nums.extend([inst_num])
+        return info
+
+    def FixedUsersToTrain(self, request, context):
+        node_idx = request.node_idx
+        sample_num = request.sample_num
+        shard_num = request.shard_num
+        node_num = request.node_num
+        date = request.date
+
+        uid_list = []
+        assert (sample_num <= 100)
+        with open("data/test_user_100.txt") as f:
+            for line in f.readlines():
+                uid_list.append(line.strip())
+        info = scheduler_server_pb2.UserInstInfo()
+        for uid in uid_list[:sample_num]:
+            inst_num = scheduler_server_pb2.UserInstNum()
+            inst_num.uid = uid
+            inst_num.inst_num = self.uid_inst_num_dict[date][uid]
+            info.inst_nums.extend([inst_num])
+        return info
+
+    def GetGlobalParams(self, request, context):
+        if self.global_param_dict == {}:
+            logging.debug("global param has not been initialized")
+            return
+        global_param_pb = scheduler_server_pb2.GlobalParams()
+        for key in self.global_param_dict:
+            param = scheduler_server_pb2.Param()
+            param.name = key
+            var, shape = self.global_param_dict[key]
+            param.weight.extend(var)
+            param.shape.extend(shape)
+            global_param_pb.global_params.extend([param])
+        return global_param_pb
+
+    def UpdateGlobalParams(self, request, context):
+        for param in request.global_params:
+            self.global_param_dict[param.name] = [param.weight, param.shape]
+        res = scheduler_server_pb2.Res()
+        res.err_code = 0
+        return res
+
+    def FedAvgUpdate(self, request, context):
+        for param in request.global_params:
+            old_param, shape =
self.global_param_dict[param.name] + for idx, item in enumerate(old_param): + old_param[idx] += param.weight[idx] + res = scheduler_server_pb2.Res() + res.err_code = 0 + return res + + def Exit(self, request, context): + with open("_shutdown_scheduler", "w") as f: + f.write("_shutdown_scheduler\n") + res = scheduler_server_pb2.Res() + res.err_code = 0 + return res + + +class SchedulerServer(object): + def __init__(self): + pass + + def start(self, max_workers=1000, concurrency=100, endpoint=""): + if endpoint == "": + logging.info("You should specify endpoint in start function") + return + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=max_workers), + options=[('grpc.max_send_message_length', 1024 * 1024 * 1024), + ('grpc.max_receive_message_length', 1024 * 1024 * 1024)], + maximum_concurrent_rpcs=concurrency) + scheduler_server_pb2_grpc.add_SchedulerServerServicer_to_server( + SchedulerServerServicer(), server) + # print("SchedulerServer add endpoint: ", '[::]:{}'.format(endpoint)) + server.add_insecure_port('[::]:{}'.format(endpoint)) + server.start() + logging.info("server started") + os.system("rm _shutdown_scheduler") + while (not os.path.isfile("_shutdown_scheduler")): + time.sleep(10) + + +if __name__ == "__main__": + scheduler_server = SchedulerServer() + scheduler_server.start(endpoint=60001) diff --git a/python/paddle_fl/mobile/trainer/__init__.py b/python/paddle_fl/mobile/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99d91d344ca7e99f8d0af9803fa812787d1c6928 --- /dev/null +++ b/python/paddle_fl/mobile/trainer/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .language_model_trainer import LanguageModelTrainer diff --git a/python/paddle_fl/mobile/trainer/language_model_trainer.py b/python/paddle_fl/mobile/trainer/language_model_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee382b03820566ef760b4f3ab0fb472d25ced1e --- /dev/null +++ b/python/paddle_fl/mobile/trainer/language_model_trainer.py @@ -0,0 +1,325 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
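+#
+# NOTE: train_one_user and infer_one_user below run in separate worker
+# processes, so they rebuild state from plain values: Paddle programs are
+# re-parsed from serialized descs (Program.parse_from_string), global
+# parameters are copied into the local scope, and each user's data is
+# fetched from the DataServer by uid.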
+from .trainer_base import TrainerBase +from model import LanguageModel +from clients import DataClient +import paddle.fluid as fluid +from utils.hdfs_utils import multi_upload, HDFSClient +import reader.leaf_reddit_reader as reader +from utils.logger import logging +from itertools import groupby +import numpy as np +import random +import paddle +import pickle +import os +from model.model_base import set_user_param_dict +from model.model_base import set_global_param_dict + + +def train_one_user(arg_dict, trainer_config): + show_metric = trainer_config["show_metric"] + shuffle = trainer_config["shuffle"] + max_training_steps = trainer_config["max_training_steps"] + batch_size = trainer_config["batch_size"] + # logging.info("training one user...") + main_program = fluid.Program.parse_from_string(trainer_config[ + "main_program_desc"]) + startup_program = fluid.Program.parse_from_string(trainer_config[ + "startup_program_desc"]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.global_scope() + if (startup_program is None): + logging.error("startup_program is None") + exit() + exe.run(startup_program) + + feeder = fluid.DataFeeder( + feed_list=trainer_config["input_names"], + place=place, + program=main_program) + data_server_endpoints = arg_dict["data_endpoints"] + # create data clients + data_client = DataClient() + data_client.set_data_server_endpoints(data_server_endpoints) + uid = arg_dict["uid"] + date = arg_dict["date"] + global_param_dict = arg_dict["global_params"] + user_data = data_client.get_data_by_uid(uid, date) + train_reader = reader.train_reader(user_data) + if shuffle == True: + train_reader = paddle.reader.shuffle(train_reader, buf_size=10000) + train_reader = paddle.batch(train_reader, batch_size=batch_size) + + # get user param + # logging.debug("do not need to get user params") + + set_global_param_dict(arg_dict["global_param_names"], + arg_dict["global_params"], scope) + + if (main_program is None): + logging.error("main_program is None") + exit() + + epoch = trainer_config["epoch"] + max_steps_in_epoch = trainer_config.get("max_steps_in_epoch", -1) + metrics = trainer_config["metrics"] + metric_keys = metrics.keys() + fetch_list = [main_program.global_block().var(trainer_config["loss_name"])] + for key in metric_keys: + fetch_list.append(main_program.global_block().var(metrics[key])) + + seq_len = 10 + for ei in range(epoch): + trained_sample_num = 0 + step = 0 + fetch_res_list = [] + total_loss = 0.0 + total_correct = 0 + for data in train_reader(): + fetch_res = exe.run(main_program, + feed=feeder.feed(data), + fetch_list=fetch_list) + step += 1 + trained_sample_num += len(data) + fetch_res_list.append([x[0] for x in fetch_res]) + if max_steps_in_epoch != -1 and step >= max_steps_in_epoch: + break + + if show_metric and trained_sample_num > 0: + loss = sum([x[0] for x in fetch_res_list]) / trained_sample_num + print("loss: {}, ppl: {}".format(loss, np.exp(loss))) + for i, key in enumerate(metric_keys): + if key == "correct": + value = float(sum([x[i + 1] for x in fetch_res_list + ])) / trained_sample_num + print("correct: {}".format(value / seq_len)) + + local_updated_param_dict = {} + # update user param + # logging.debug("do not need to update user params") + + data_client.set_param_by_uid(uid, local_updated_param_dict) + # global_updated_param_dict = {} + write_global_param_file = arg_dict["write_global_param_file"] + #os.makedirs("%s/params" % write_global_param_file) + for var_name in arg_dict["global_param_names"]: + var = 
scope.var(var_name).get_tensor().__array__().astype(np.float32) + filename = os.path.join(write_global_param_file, "params", var_name) + #logging.info("filename: {}".format(filename)) + dirname = os.path.dirname(filename) + if not os.path.exists(dirname): + os.makedirs(dirname) + with open(filename, "w") as f: + np.save(f, var) + with open("%s/_info" % write_global_param_file, "w") as f: + pickle.dump([uid, trained_sample_num], file=f) + + +def infer_one_user(arg_dict, trainer_config): + """ + infer a model with global_param and user params + input: + global_param + user_params + infer_program + user_data + output: + [sample_cout, top1] + """ + # run startup program, set params + uid = arg_dict["uid"] + batch_size = trainer_config["batch_size"] + startup_program = fluid.Program.parse_from_string(trainer_config[ + "startup_program_desc"]) + infer_program = fluid.Program.parse_from_string(trainer_config[ + "infer_program_desc"]) + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.global_scope() + + if (startup_program is None): + logging.error("startup_program is None") + exit() + if (infer_program is None): + logging.error("infer_program is None") + exit() + + exe.run(startup_program) + + data_client = DataClient() + data_client.set_data_server_endpoints(arg_dict["data_endpoints"]) + + # get user param + # logging.debug("do not need to get user params") + + set_global_param_dict(arg_dict["global_param_names"], + arg_dict["global_params"], scope) + + # reader + + date = arg_dict["date"] + global_param_dict = arg_dict["global_params"] + user_data = data_client.get_data_by_uid(uid, date) + infer_reader = reader.infer_reader(user_data) + infer_reader = paddle.batch(infer_reader, batch_size=batch_size) + + # run infer program + os.mkdir(arg_dict["infer_result_dir"]) + #pred_file = open(arg_dict["infer_result_dir"] + '/' + "pred_file", "w") + feeder = fluid.DataFeeder( + feed_list=trainer_config["input_names"], + place=place, + program=infer_program) + + fetch_list = trainer_config["target_names"] + #logging.info("fetch_list: {}".format(fetch_list)) + fetch_res = [] + sample_count = 0 + + total_loss = 0.0 + total_correct = 0 + iters = 0 + steps = 0 + seq_len = 10 + for data in infer_reader(): + # feed_data = [x["features"] + [x["label"]] for x in data] + # prediction, acc_val= exe.run(infer_program, + pred, correct_count, loss = exe.run(infer_program, + feed=feeder.feed(data), + fetch_list=fetch_list) + total_loss += loss + total_correct += correct_count + steps += 1 + sample_count += len(data) + + correct = float(total_correct) / (seq_len * sample_count) + # logging.info("correct: {}".format(correct)) + with open(arg_dict["infer_result_dir"] + "/res", "w") as f: + f.write("%d\t%f\n" % (1, correct)) + + +def save_and_upload(arg_dict, trainer_config, dfs_upload_path): + logging.info("do not save and upload...") + return + + +def evaluate_a_group(group): + group_list = [] + for label, pred, _ in group: + # print("%s\t%s\n" % (label, pred)) + group_list.append((int(label), float(pred))) + random.shuffle(group_list) + labels = [x[0] for x in group_list] + preds = [x[1] for x in group_list] + true_res = labels.index(1) if 1 in labels else -1 + pred_res = preds.index(max(preds)) + if pred_res == true_res: + return 1 + else: + return 0 + + +class LanguageModelTrainer(TrainerBase): + """ + LanguageModelTrainer only support training with PaddlePaddle + """ + + def __init__(self): + super(LanguageModelTrainer, self).__init__() + self.main_program_ = fluid.Program() + 
self.startup_program_ = fluid.Program() + self.infer_program_ = fluid.Program() + self.main_program_desc_ = "" + self.startup_program_desc_ = "" + self.infer_program_desc_ = "" + self.train_one_user_func = train_one_user + self.infer_one_user_func = infer_one_user + self.save_and_upload_func = save_and_upload + self.input_model_ = None + + def get_load_data_into_patch_func(self): + return reader.load_data_into_patch + + def prepare(self, do_test=False): + self.generate_program_desc(do_test) + pass + + def get_user_param_names(self): + # return [x[0] for x in self.input_model_.get_user_param_names()] + pass + + def get_global_param_names(self): + return [x[0] for x in self.input_model_.get_global_param_names()] + + def generate_program_desc(self, do_test=False): + """ + generate the paddle program desc + """ + with fluid.program_guard(self.main_program_, self.startup_program_): + self.input_model_ = LanguageModel() + model_configs = {} + self.input_model_.build_model(model_configs) + optimizer = fluid.optimizer.SGD( + learning_rate=self.trainer_config["lr"]) + optimizer.minimize(self.input_model_.get_model_loss()) + + self.main_program_desc_ = self.main_program_.desc.serialize_to_string() + self.startup_program_desc_ = self.startup_program_.desc.serialize_to_string( + ) + self.update_trainer_configs("loss_name", + self.input_model_.get_model_loss_name()) + self.update_trainer_configs( + "input_names", + self.input_model_.get_model_input_names(), ) + self.update_trainer_configs( + "target_names", + self.input_model_.get_target_names(), ) + self.update_trainer_configs( + "metrics", + self.input_model_.get_model_metrics(), ) + self.update_trainer_configs("show_metric", True) + self.update_trainer_configs("max_training_steps", "inf") + self.update_trainer_configs("shuffle", False) + self.update_trainer_configs("main_program_desc", + self.main_program_desc_) + self.update_trainer_configs("startup_program_desc", + self.startup_program_desc_) + + if do_test: + input_names = self.input_model_.get_model_input_names() + target_var_names = self.input_model_.get_target_names() + self.infer_program_ = self.main_program_._prune_with_input( + feeded_var_names=input_names, targets=target_var_names) + self.infer_program_ = self.infer_program_._inference_optimize( + prune_read_op=True) + fluid.io.prepend_feed_ops(self.infer_program_, input_names) + fluid.io.append_fetch_ops(self.infer_program_, target_var_names) + self.infer_program_.desc._set_version() + fluid.core.save_op_compatible_info(self.infer_program_.desc) + self.infer_program_desc_ = self.infer_program_.desc.serialize_to_string( + ) + self.update_trainer_configs("infer_program_desc", + self.infer_program_desc_) + + def init_global_model(self, scheduler_client): + logging.info("initializing global model") + place = fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(self.startup_program_) + logging.info("finish initializing global model") + + global_param_dict = self.input_model_.get_global_param_dict() + scheduler_client.update_global_params(global_param_dict) diff --git a/python/paddle_fl/mobile/trainer/trainer_base.py b/python/paddle_fl/mobile/trainer/trainer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0ddd20845303491e26ac3ae790abb3651c1815 --- /dev/null +++ b/python/paddle_fl/mobile/trainer/trainer_base.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TrainerBase(object): + def __init__(self): + self.trainer_config = None + self.train_one_user_func = None + self.infer_one_user_func = None + self.save_and_upload_func = None + + def get_user_param_names(self): + pass + + def get_global_param_names(self): + pass + + def set_trainer_configs(self, trainer_configs): + """ + config training parameter, only support the basic types of python + """ + self.trainer_config = trainer_configs + + def update_trainer_configs(self, key, val): + self.trainer_config[key] = val + + def prepare(self): + """ + generate network description string; + """ + pass + + def init_global_model(self, scheduler_client): + """ + initialize the network parameters, which will be broadcasted to all simulator + """ + pass diff --git a/python/paddle_fl/mobile/utils/__init__.py b/python/paddle_fl/mobile/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e9b1f50d87cc5fb3c52a0cfd90c7f60d0f0d8a --- /dev/null +++ b/python/paddle_fl/mobile/utils/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .role_maker import FLSimRoleMaker diff --git a/python/paddle_fl/mobile/utils/hdfs_utils.py b/python/paddle_fl/mobile/utils/hdfs_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ef950c685f380a6f5aa542e4a00b8b6763cee87c --- /dev/null +++ b/python/paddle_fl/mobile/utils/hdfs_utils.py @@ -0,0 +1,621 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
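+# NOTE: HDFSClient does not use a native HDFS library; every method shells
+# out to the `hadoop fs` command line built from self.pre_commands and
+# parses the subprocess output.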
+"""hdfs_utils.py will move to fluid/incubate/fleet/utils/hdfs.py""" + +import os +import sys +import subprocess +import multiprocessing +from datetime import datetime + +import re +import copy +import errno + +import logging + +__all__ = ["HDFSClient", "multi_download", "multi_upload"] + + +def get_logger(name, level, fmt): + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.StreamHandler(sys.stderr) + if fmt: + formatter = logging.Formatter(fmt=fmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + +# DO NOT LOG +_logger = get_logger( + __name__, logging.CRITICAL, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class HDFSClient(object): + """ + A tool of HDFS + + Args: + hadoop_home (string): hadoop_home + configs (dict): hadoop config, it is a dict, please contain \ + key "fs.default.name" and "hadoop.job.ugi" + Can be a float value + Examples: + hadoop_home = "/home/client/hadoop-client/hadoop/" + + configs = { + "fs.default.name": "hdfs://xxx.hadoop.com:54310", + "hadoop.job.ugi": "hello,hello123" + } + + client = HDFSClient(hadoop_home, configs) + + client.ls("/user/com/train-25") + files = client.lsr("/user/com/train-25/models") + """ + + def __init__(self, hadoop_home, configs): + self.pre_commands = [] + hadoop_bin = '%s/bin/hadoop' % hadoop_home + self.pre_commands.append(hadoop_bin) + dfs = 'fs' + self.pre_commands.append(dfs) + + for k, v in configs.items(): + config_command = '-D%s=%s' % (k, v) + self.pre_commands.append(config_command) + + def __run_hdfs_cmd(self, commands, retry_times=5): + whole_commands = copy.deepcopy(self.pre_commands) + whole_commands.extend(commands) + + _logger.info('Running system command: {0}'.format(' '.join( + whole_commands))) + + ret_code = 0 + ret_out = None + ret_err = None + whole_commands = " ".join(whole_commands) + for x in range(retry_times + 1): + proc = subprocess.Popen( + whole_commands, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True) + (output, errors) = proc.communicate() + ret_code, ret_out, ret_err = proc.returncode, output, errors + if ret_code: + _logger.warn( + 'Times: %d, Error running command: %s. Return code: %d, Error: %s' + % (x, ' '.join(whole_commands), proc.returncode, errors)) + else: + break + return ret_code, ret_out, ret_err + + def upload(self, hdfs_path, local_path, overwrite=False, retry_times=5): + """ + upload the local file to hdfs + + Args: + hdfs_path(str): the hdfs file path + local_path(str): the local file path + overwrite(bool|None): will overwrite the file on HDFS or not + retry_times(int|5): retry times + + Returns: + True or False + """ + assert hdfs_path is not None + assert local_path is not None and os.path.exists(local_path) + + if os.path.isdir(local_path): + _logger.warn( + "The Local path: {} is dir and I will support it later, return". + format(local_path)) + return False + + base = os.path.basename(local_path) + if not self.is_exist(hdfs_path): + self.makedirs(hdfs_path) + else: + if self.is_exist(os.path.join(hdfs_path, base)): + if overwrite: + _logger.error( + "The HDFS path: {} is exist and overwrite is True, delete it". + format(hdfs_path)) + self.delete(hdfs_path) + else: + _logger.error( + "The HDFS path: {} is exist and overwrite is False, return". 
+ format(hdfs_path)) + return False + + put_commands = ["-put", local_path, hdfs_path] + returncode, output, errors = self.__run_hdfs_cmd(put_commands, + retry_times) + if returncode: + _logger.error("Put local path: {} to HDFS path: {} failed".format( + local_path, hdfs_path)) + return False + else: + _logger.info("Put local path: {} to HDFS path: {} successfully". + format(local_path, hdfs_path)) + return True + + def download(self, hdfs_path, local_path, overwrite=False, unzip=False): + """ + download file from HDFS + + Args: + hdfs_path(str): the hdfs file path + local_path(str): the local file path + overwrite(bool|None): will overwrite the file on HDFS or not + unzip(bool|False): if the download file is compressed by zip, unzip it or not. + + Returns: + True or False + """ + _logger.info('Downloading %r to %r.', hdfs_path, local_path) + _logger.info('Download of %s to %r complete.', hdfs_path, local_path) + + if not self.is_exist(hdfs_path): + print("HDFS path: {} do not exist".format(hdfs_path)) + return False + if self.is_dir(hdfs_path): + _logger.error( + "The HDFS path: {} is dir and I will support it later, return". + format(hdfs_path)) + + if os.path.exists(local_path): + base = os.path.basename(hdfs_path) + local_file = os.path.join(local_path, base) + if os.path.exists(local_file): + if overwrite: + os.remove(local_file) + else: + _logger.error( + "The Local path: {} is exist and overwrite is False, return". + format(local_file)) + return False + + self.make_local_dirs(local_path) + + download_commands = ["-get", hdfs_path, local_path] + returncode, output, errors = self.__run_hdfs_cmd(download_commands) + if returncode: + _logger.error("Get local path: {} from HDFS path: {} failed". + format(local_path, hdfs_path)) + return False + else: + _logger.info("Get local path: {} from HDFS path: {} successfully". + format(local_path, hdfs_path)) + return True + + def is_exist(self, hdfs_path=None): + """ + whether the remote HDFS path exists + + Args: + hdfs_path(str): the hdfs file path + + Returns: + True or False + """ + exist_cmd = ['-test', '-e', hdfs_path] + returncode, output, errors = self.__run_hdfs_cmd( + exist_cmd, retry_times=1) + + if returncode: + _logger.error("HDFS is_exist HDFS path: {} failed".format( + hdfs_path)) + return False + else: + _logger.info("HDFS is_exist HDFS path: {} successfully".format( + hdfs_path)) + return True + + def is_dir(self, hdfs_path=None): + """ + whether the remote HDFS path is directory + + Args: + hdfs_path(str): the hdfs file path + + Returns: + True or False + """ + + if not self.is_exist(hdfs_path): + return False + + dir_cmd = ['-test', '-d', hdfs_path] + returncode, output, errors = self.__run_hdfs_cmd( + dir_cmd, retry_times=1) + + if returncode: + _logger.error("HDFS path: {} failed is not a directory".format( + hdfs_path)) + return False + else: + _logger.info("HDFS path: {} successfully is a directory".format( + hdfs_path)) + return True + + def delete(self, hdfs_path): + """ + Remove a file or directory from HDFS. + + whether the remote HDFS path exists + + Args: + hdfs_path: HDFS path. + + Returns: + True or False + This function returns `True` if the deletion was successful and `False` if + no file or directory previously existed at `hdfs_path`. 
+ """ + _logger.info('Deleting %r.', hdfs_path) + + if not self.is_exist(hdfs_path): + _logger.warn("HDFS path: {} do not exist".format(hdfs_path)) + return True + + if self.is_dir(hdfs_path): + del_cmd = ['-rmr', hdfs_path] + else: + del_cmd = ['-rm', hdfs_path] + + returncode, output, errors = self.__run_hdfs_cmd( + del_cmd, retry_times=0) + + if returncode: + _logger.error("HDFS path: {} delete files failure".format( + hdfs_path)) + return False + else: + _logger.info("HDFS path: {} delete files successfully".format( + hdfs_path)) + return True + + def rename(self, hdfs_src_path, hdfs_dst_path, overwrite=False): + """ + Move a file or folder on HDFS. + + Args: + hdfs_path(str): HDFS path. + overwrite(bool|False): If the path already exists and overwrite is False, will return False. + + Returns: + True or False + """ + assert hdfs_src_path is not None + assert hdfs_dst_path is not None + + if not self.is_exist(hdfs_src_path): + _logger.info("HDFS path do not exist: {}".format(hdfs_src_path)) + if self.is_exist(hdfs_dst_path) and not overwrite: + _logger.error("HDFS path is exist: {} and overwrite=False".format( + hdfs_dst_path)) + + rename_command = ['-mv', hdfs_src_path, hdfs_dst_path] + returncode, output, errors = self.__run_hdfs_cmd( + rename_command, retry_times=1) + + if returncode: + _logger.error("HDFS rename path: {} to {} failed".format( + hdfs_src_path, hdfs_dst_path)) + return False + else: + _logger.info("HDFS rename path: {} to {} successfully".format( + hdfs_src_path, hdfs_dst_path)) + return True + + @staticmethod + def make_local_dirs(local_path): + """ + create a directiory local, is same to mkdir + Args: + local_path: local path that wants to create a directiory. + """ + try: + os.makedirs(local_path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + def makedirs(self, hdfs_path): + """ + Create a remote directory, recursively if necessary. + + Args: + hdfs_path(str): Remote path. Intermediate directories will be created appropriately. + + Returns: + True or False + """ + _logger.info('Creating directories to %r.', hdfs_path) + assert hdfs_path is not None + + if self.is_exist(hdfs_path): + _logger.error("HDFS path is exist: {}".format(hdfs_path)) + return + + mkdirs_commands = ['-mkdir', hdfs_path] + returncode, output, errors = self.__run_hdfs_cmd( + mkdirs_commands, retry_times=1) + + if returncode: + _logger.error("HDFS mkdir path: {} failed".format(hdfs_path)) + return False + else: + _logger.error("HDFS mkdir path: {} successfully".format(hdfs_path)) + return True + + def ls(self, hdfs_path): + """ + ls directory contents about HDFS hdfs_path + + Args: + hdfs_path(str): Remote HDFS path will be ls. + + Returns: + List: a contents list about hdfs_path. + """ + assert hdfs_path is not None + + if not self.is_exist(hdfs_path): + return [] + + ls_commands = ['-ls', hdfs_path] + returncode, output, errors = self.__run_hdfs_cmd( + ls_commands, retry_times=1) + + if returncode: + _logger.error("HDFS list path: {} failed".format(hdfs_path)) + return [] + else: + _logger.info("HDFS list path: {} successfully".format(hdfs_path)) + + ret_lines = [] + regex = re.compile('\s+') + out_lines = output.strip().split("\n") + for line in out_lines: + re_line = regex.split(line) + if len(re_line) == 8: + ret_lines.append(re_line[7]) + return ret_lines + + def lsr(self, hdfs_path, only_file=True, sort=True): + """ + list directory contents about HDFS hdfs_path recursively + + Args: + hdfs_path(str): Remote HDFS path. + only_file(bool|True): will discard folders. 
+    def lsr(self, hdfs_path, only_file=True, sort=True):
+        """
+        list directory contents of an HDFS path recursively
+
+        Args:
+            hdfs_path(str): Remote HDFS path.
+            only_file(bool|True): whether to discard folders.
+            sort(bool|True): whether to sort by creation time.
+
+        Returns:
+            List: a list of the contents of hdfs_path.
+        """
+
+        def sort_by_time(entry):
+            # entry is (path, "YYYY-MM-DD HH:MM"); sort on the timestamp
+            return datetime.strptime(entry[1], '%Y-%m-%d %H:%M')
+
+        assert hdfs_path is not None
+
+        if not self.is_exist(hdfs_path):
+            return []
+
+        ls_commands = ['-lsr', hdfs_path]
+        returncode, output, errors = self.__run_hdfs_cmd(
+            ls_commands, retry_times=1)
+
+        if returncode:
+            _logger.error("HDFS list all files: {} failed".format(hdfs_path))
+            return []
+        else:
+            _logger.info("HDFS list all files: {} successfully".format(
+                hdfs_path))
+            lines = []
+            regex = re.compile(r'\s+')
+            out_lines = output.strip().split("\n")
+            for line in out_lines:
+                re_line = regex.split(line)
+                if len(re_line) == 8:
+                    # directory entries start with 'd' in the permission field
+                    if only_file and re_line[0][0] == "d":
+                        continue
+                    else:
+                        lines.append(
+                            (re_line[7], re_line[5] + " " + re_line[6]))
+            if sort:
+                lines.sort(key=sort_by_time)
+            ret_lines = [ret[0] for ret in lines]
+            return ret_lines
+
+
+def multi_download(client,
+                   hdfs_path,
+                   local_path,
+                   trainer_id,
+                   trainers,
+                   multi_processes=5):
+    """
+    Download files from HDFS using multiple processes.
+
+    Args:
+        client(HDFSClient): instance of HDFSClient
+        hdfs_path(str): path on hdfs
+        local_path(str): path on local
+        trainer_id(int): current trainer id
+        trainers(int): total number of trainers
+        multi_processes(int|5): number of concurrent download processes, default=5
+
+    Returns:
+        List: the local paths of the downloaded files.
+    """
+    print("multi-downloading from %s to %s" % (hdfs_path, local_path))
+
+    def __subprocess_download(datas):
+        for data in datas:
+            re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
+            if re_path == os.curdir:
+                sub_local_re_path = local_path
+            else:
+                sub_local_re_path = os.path.join(local_path, re_path)
+            client.download(data, sub_local_re_path)
+
+    assert isinstance(client, HDFSClient)
+
+    client.make_local_dirs(local_path)
+    _logger.info("Make local dir {} successfully".format(local_path))
+
+    all_need_download = client.lsr(hdfs_path, sort=True)
+    need_download = all_need_download[trainer_id::trainers]
+    _logger.info(
+        "Get {} files from all {} files that need to be downloaded from {}".
+        format(len(need_download), len(all_need_download), hdfs_path))
+
+    _logger.info("Start {} processes to download data".format(
+        multi_processes))
+    procs = []
+    for i in range(multi_processes):
+        process_datas = need_download[i::multi_processes]
+        p = multiprocessing.Process(
+            target=__subprocess_download, args=(process_datas, ))
+        procs.append(p)
+        p.start()
+
+    # wait for all download processes to finish
+    for proc in procs:
+        proc.join()
+
+    _logger.info("Finish {} processes to download data".format(
+        multi_processes))
+
+    local_downloads = []
+    for data in need_download:
+        data_name = os.path.basename(data)
+        re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
+        if re_path == os.curdir:
+            local_re_path = os.path.join(local_path, data_name)
+        else:
+            local_re_path = os.path.join(local_path, re_path, data_name)
+        local_downloads.append(local_re_path)
+
+    return local_downloads
+
+
+def getfilelist(path):
+    """walk a local path and print every file found under it"""
+    rlist = []
+    for dir, folder, file in os.walk(path):
+        for i in file:
+            t = os.path.join(dir, i)
+            rlist.append(t)
+    for r in rlist:
+        print(r)
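+# A small sketch of how multi_download above (and multi_upload below) shard
+# their work, using a hypothetical file list: with trainers=2 and
+# multi_processes=2, trainer 1 first takes every second file via
+# all_files[trainer_id::trainers], then splits its share across processes
+# via need_download[i::multi_processes]:
+#
+#   all_files = ["f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7"]
+#   trainer 1 gets ["f1", "f3", "f5", "f7"]
+#   process 0 gets ["f1", "f5"], process 1 gets ["f3", "f7"]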
+def multi_upload(client,
+                 hdfs_path,
+                 local_path,
+                 multi_processes=5,
+                 overwrite=False,
+                 sync=True):
+    """
+    Upload files to HDFS using multiple processes.
+
+    Args:
+        client(HDFSClient): instance of HDFSClient
+        hdfs_path(str): path on hdfs
+        local_path(str): path on local
+        multi_processes(int|5): number of concurrent upload processes, default=5
+        overwrite(bool|False): whether to overwrite files on HDFS
+        sync(bool|True): whether to upload files synchronously.
+
+    Returns:
+        None
+    """
+    print("multi-uploading from %s to %s" % (local_path, hdfs_path))
+
+    def __subprocess_upload(datas):
+        for data in datas:
+            re_path = os.path.relpath(os.path.dirname(data), local_path)
+            hdfs_re_path = os.path.join(hdfs_path, re_path)
+            client.upload(hdfs_re_path, data, overwrite, retry_times=5)
+
+    def get_local_files(path):
+        rlist = []
+
+        if not os.path.isdir(path):
+            return rlist
+
+        for dirname, folder, files in os.walk(path):
+            for i in files:
+                t = os.path.join(dirname, i)
+                rlist.append(t)
+        return rlist
+
+    assert isinstance(client, HDFSClient)
+
+    all_files = get_local_files(local_path)
+    if not all_files:
+        _logger.info("there is nothing to upload, exit")
+        return
+    _logger.info("Start {} processes to upload data".format(multi_processes))
+    procs = []
+    for i in range(multi_processes):
+        process_datas = all_files[i::multi_processes]
+        p = multiprocessing.Process(
+            target=__subprocess_upload, args=(process_datas, ))
+        procs.append(p)
+        p.start()
+
+    # wait for all upload processes to finish
+    for proc in procs:
+        proc.join()
+
+    _logger.info("Finish {} processes to upload data".format(multi_processes))
+
+
+if __name__ == "__main__":
+    hadoop_home = "/home/client/hadoop-client/hadoop/"
+
+    configs = {
+        "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+        "hadoop.job.ugi": "hello,hello123"
+    }
+
+    client = HDFSClient(hadoop_home, configs)
+
+    client.ls("/user/com/train-25")
+    files = client.lsr("/user/com/train-25/models")
+
+    downloads = multi_download(
+        client,
+        "/user/com/train-25/model",
+        "/home/xx/data1",
+        1,
+        5,
+        multi_processes=5)
+
+    multi_upload(client, "/user/com/train-25/model", "/home/xx/data1")
diff --git a/python/paddle_fl/mobile/utils/logger.py b/python/paddle_fl/mobile/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..1234a58bee6eaf6e2226e94b1071359b75562d94
--- /dev/null
+++ b/python/paddle_fl/mobile/utils/logger.py
@@ -0,0 +1,22 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import logging
+
+logging.basicConfig(
+    stream=sys.stdout,
+    format='%(asctime)s %(filename)s : %(levelname)s %(message)s',
+    level=logging.DEBUG)
+logger = logging.getLogger(__name__)
diff --git a/python/paddle_fl/mobile/utils/role_maker.py b/python/paddle_fl/mobile/utils/role_maker.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd37d157b9611521a6c371c86266e57f1092b20
--- /dev/null
+++ b/python/paddle_fl/mobile/utils/role_maker.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from utils.logger import logging
+
+
+class FLSimRoleMaker(object):
+    def __init__(self):
+        from mpi4py import MPI
+        self.MPI = MPI
+        self._comm = MPI.COMM_WORLD
+
+    def init_env(self, local_shard_num=1):
+        if local_shard_num > 250:
+            logging.error("shard num must be less than or equal to 250")
+            exit()
+        self.free_ports = self._get_free_endpoints(local_shard_num, 40000)
+        import socket
+        ip = socket.gethostbyname(socket.gethostname())
+        self.free_endpoints = ["{}:{}".format(ip, x) for x in self.free_ports]
+        self._comm.barrier()
+        self.node_type = 1
+        if self._comm.Get_rank() == 0:
+            self.node_type = 0
+        self.group_comm = self._comm.Split(self.node_type)
+        self.all_endpoints = self._comm.allgather(self.free_endpoints)
+
+    def simulator_num(self):
+        return self._comm.Get_size() - 1
+
+    def simulator_idx(self):
+        return self._comm.Get_rank() - 1
+
+    def get_global_scheduler_endpoint(self):
+        return self.all_endpoints[0][0]
+
+    def get_data_server_endpoints(self):
+        all_endpoints = []
+        for eps in self.all_endpoints[1:]:
+            # the first half of each simulator's endpoints are data servers
+            all_endpoints.extend(eps[:len(eps) // 2])
+        return all_endpoints
+
+    def get_local_data_server_endpoint(self):
+        if self._comm.Get_rank() < 1:
+            return None
+        local_endpoints = self.all_endpoints[self._comm.Get_rank()]
+        return local_endpoints[:len(local_endpoints) // 2]
+
+    def get_local_param_server_endpoint(self):
+        if self._comm.Get_rank() < 1:
+            return None
+        local_endpoints = self.all_endpoints[self._comm.Get_rank()]
+        return local_endpoints[len(local_endpoints) // 2:]
+
+    def is_global_scheduler(self):
+        rank = self._comm.Get_rank()
+        return rank == 0
+
+    def is_simulator(self):
+        return self._comm.Get_rank() > 0
+
+    def barrier_simulator(self):
+        if self._comm.Get_rank() > 0:
+            self.group_comm.barrier()
+
+    def barrier(self):
+        self.group_comm.barrier()
+
+    def _get_free_endpoints(self, local_shard_num, start_port):
+        import psutil
+        conns = psutil.net_connections()
+        occupied_ports = [conn.laddr.port for conn in conns]
+        free_endpoints = []
+        step = 500
+        start_range = start_port + self._comm.Get_rank() * step
+        for i in range(start_range, start_range + step, 1):
+            if i in occupied_ports:
+                continue
+            if i > 65535:
+                continue
+            else:
+                free_endpoints.append(i)
+            # each shard needs a data server port and a param server port
+            if len(free_endpoints) == local_shard_num * 2:
+                break
+        return free_endpoints
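+# A sketch of the endpoint layout this role maker produces (an assumption
+# based on the slicing above, with hypothetical ports): after
+# init_env(local_shard_num=2), each rank reserves 2 * 2 = 4 free ports, e.g.
+#
+#   all_endpoints[rank] = ["ip:40000", "ip:40001", "ip:40002", "ip:40003"]
+#
+# The first half ("ip:40000", "ip:40001") serve as data server endpoints and
+# the second half ("ip:40002", "ip:40003") as param server endpoints; rank 0
+# acts as the global scheduler and its first endpoint is its service address.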
diff --git a/python/paddle_fl/mobile/utils/run_servers.sh b/python/paddle_fl/mobile/utils/run_servers.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3db7d5d69a273ae73e790cf03b8b89f4a0302904
--- /dev/null
+++ b/python/paddle_fl/mobile/utils/run_servers.sh
@@ -0,0 +1,6 @@
+endpoint=("50050" "50051" "50052" "50053" "50054" "50055" "50056" "50057" "50058" "50059")
+
+for i in {0..9}
+do
+    python data_server_impl.py ${endpoint[$i]} &
+done
diff --git a/python/paddle_fl/mobile/utils/test_role_maker.py b/python/paddle_fl/mobile/utils/test_role_maker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f336761bc56573c5d43fe31620da8fa30daf6dd4
--- /dev/null
+++ b/python/paddle_fl/mobile/utils/test_role_maker.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from role_maker import FLSimRoleMaker
+
+role_maker = FLSimRoleMaker()
+role_maker.init_env(local_shard_num=30)
+print("simulator num: {}".format(role_maker.simulator_num()))
+print("simulator idx: {}".format(role_maker.simulator_idx()))
+print("global scheduler endpoint: {}".format(
+    role_maker.get_global_scheduler_endpoint()))
+print("data server endpoints")
+print(role_maker.get_data_server_endpoints())
+print("local data server")
+print(role_maker.get_local_data_server_endpoint())
+print("local param server")
+print(role_maker.get_local_param_server_endpoint())
+role_maker.barrier_simulator()
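+
+# To exercise this test across multiple ranks (a usage note, assuming mpi4py
+# and an MPI launcher are installed), run it from the utils directory with,
+# for example:
+#
+#   mpirun -np 3 python test_role_maker.py
+#
+# Rank 0 becomes the global scheduler; ranks 1 and 2 act as simulators.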