Unverified commit e2d65fef, authored by Dong Daxiang, committed by GitHub

Merge pull request #74 from mapingshuo/fl-mobile

add mobile
clean:
rm -rf *~ *pyc */*~ */*.pyc
## Federated Learning Algorithm Simulator (fl-mobile simulator)
FL-mobile is a framework that integrates on-device algorithm research, simulation, training, and deployment. The algorithm simulator is one part of FL-mobile.
The simulator is designed to emulate real-world scenarios in which many mobile devices cooperate in training. The core idea is to simulate a number of on-device clients on a server, so that algorithm effectiveness can be validated quickly. The simulator's strengths:
- Supports single-node and distributed training
- Supports training on common open-source datasets
- Supports private and shared parameters within a model; private parameters do not take part in global updates
## Prerequisites
- Install mpirun
- Install gRPC for Python
```shell
pip install grpcio==1.28.1
```
- Install Paddle
## Quick Start
Taking the [reddit data](https://github.com/TalwalkarLab/leaf/tree/master/data/reddit) from the Leaf benchmark as an example, we model it with an LSTM and walk through a single-node training run in the simulator. This example covers the basic usage of the simulator.
### Prepare the data
```
wget https://paddle-serving.bj.bcebos.com/temporary_files_for_docker/reddit_subsampled.zip --no-check-certificate
unzip reddit_subsampled.zip
```
In the simulator we assume that user data arrives at day granularity, so we reorganize the downloaded data as follows:
```
tree lm_data
lm_data
|-- 20200101
| `-- train_data.json
|-- 20200102
| `-- test_data.json
`-- vocab.json
```
As shown above, we treat the training data as the data for 20200101 and the test data as the data for 20200102.
### Generate the server code
```
cd protos
python run_codegen.py
cd ..
```
### Start training
```shell
export PYTHONPATH=$PWD:$PYTHONPATH
mpirun -np 2 python application.py lm_data
```
### Training results
```shell
framework.py : INFO infer results: 0.085723
```
That is, the Top-1 accuracy on the test set is 8.6%.
## Adding Your Own Dataset and Trainer
To train your own federated model, you need to do four things:
1. Create a reader; see `reader/leaf_reddit_reader.py`
2. Create a trainer; see `trainer/language_model_trainer.py`
3. Create a model, i.e. the network definition; see `model/language_model.py`
4. Create an application; see `application.py`
## Simulator Overview
The framework consists mainly of a scheduler and simulators: the scheduler coordinates the data and the global parameters, while the simulators do the actual training and update the private parameters.
- scheduler
In a training run there is exactly one global scheduler, and every machine runs a scheduler client that exchanges parameters and data with the global scheduler.
- simulator
Every machine runs one simulator, and each simulator owns several shards, which are used for parallel training on that machine.
As a distributed framework, the FL-mobile simulator also consists of four stages: model initialization, model distribution, model training, and model update. Let's walk through an actual training run to see how each stage works (a FedAvg sketch follows the list):
- Step 1: Model initialization
1. Global parameter initialization: the simulator with index 0 initializes the model and then passes the parameters to the scheduler through the UpdateGlobalParams() interface;
2. Personalized parameter initialization
- Step 2: Model distribution
1. Global parameter distribution: before each simulator starts training, it first fetches the global parameters from the SchedulerServer, via scheduler_client.get_global_params();
2. Personalized parameter distribution: before training, each local trainer fetches its personalized parameters from the data server, via get_param_by_uid;
- Step 3: Model training
Model training is the core of the whole pipeline. Multiple trainers train in parallel, with no parameter synchronization during training; the number of training steps of each trainer is determined by the amount of that user's data.
- Step 4: Model update
1. Global parameter update: after all trainers finish, a synchronization takes place and the parameter deltas are computed with the FedAvg algorithm; the deltas are uploaded to the scheduler, the new global parameters are pulled down, and the flow returns to Step 2;
2. Personalized parameter update: this one is simpler; each trainer calls set_param_by_uid to update its own personalized parameters;
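To make the FedAvg step concrete, here is a minimal, self-contained sketch of the weighted-delta aggregation described in Step 4. It illustrates the algorithm only and is not the framework's code (the real logic lives in `optimizer/fed_avg_optimizer.py`); the function and variable names below are our own:
```python
# FedAvg sketch: weight each user's parameter delta by its share of the
# total training instances, then apply the aggregated delta to the model.
import numpy as np

def fed_avg(old_global, new_params_by_user, inst_num_by_user, lr=1.0):
    total = float(sum(inst_num_by_user.values()))
    delta = {name: np.zeros_like(w) for name, w in old_global.items()}
    for uid, user_params in new_params_by_user.items():
        weight = inst_num_by_user[uid] / total
        for name, w in user_params.items():
            delta[name] += lr * weight * (w - old_global[name])
    # in the framework, `delta` would be sent to the scheduler; here we
    # apply it locally to produce the new global parameters
    return {name: old_global[name] + delta[name] for name in old_global}

old = {"w": np.zeros(3, dtype=np.float32)}
users = {"u1": {"w": np.ones(3, dtype=np.float32)},
         "u2": {"w": 3 * np.ones(3, dtype=np.float32)}}
print(fed_avg(old, users, {"u1": 1, "u2": 3}))  # -> w = [2.5, 2.5, 2.5]
```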
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
For large-scale simulation of mobile devices, we need to consider:
- model structure and input features
- trainer, i.e. how to design the training strategy for each simulation
- user sampler, i.e. which users participate in training
- optimizer, i.e. how to update the global weights
Currently, we couple the trainer and the model together for simplicity
'''
from utils import FLSimRoleMaker
from framework import SimulationFramework
from trainer import LanguageModelTrainer
from optimizer import FedAvgOptimizer
from sampler import UniformSampler, Test1percentSampler
from datetime import date, timedelta
import sys
import time
role_maker = FLSimRoleMaker()
role_maker.init_env(local_shard_num=30)
simulator = SimulationFramework(role_maker)
language_model_trainer = LanguageModelTrainer()
language_model_trainer.set_trainer_configs({
"epoch": 3,
"max_steps_in_epoch": -1,
"lr": 0.1,
"batch_size": 5,
})
sampler = UniformSampler()
sampler.set_sample_num(30)
sampler.set_min_ins_num(1)
test_sampler = Test1percentSampler()
fed_avg_optimizer = FedAvgOptimizer(learning_rate=2.0)
simulator.set_trainer(language_model_trainer)
simulator.set_sampler(sampler)
simulator.set_test_sampler(test_sampler)
simulator.set_fl_optimizer(fed_avg_optimizer)
if simulator.is_scheduler():
simulator.run_scheduler_service()
elif simulator.is_simulator():
base_path = sys.argv[1]
dates = []
start_date = date(2020, 1, 1)
end_date = date(2020, 1, 2)
delta = timedelta(days=1)
while start_date <= end_date:
dates.append(start_date.strftime("%Y%m%d"))
start_date += delta
print("dates: {}".format(dates))
time.sleep(10)
simulator.run_simulation(
base_path, dates, sim_num_everyday=100, do_test=True, test_skip_day=1)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_client_impl import DataClient
from .scheduler_client_impl import SchedulerClient
# -*- coding: utf-8 -*-
from __future__ import print_function
import grpc
import servers.data_server_pb2 as data_server_pb2
import servers.data_server_pb2_grpc as data_server_pb2_grpc
from concurrent import futures
from multiprocessing import Process
from utils.hdfs_utils import HDFSClient, multi_download
import time
import sys
import os
import xxhash
import numpy as np
from utils.logger import logging
class DataClient(object):
def __init__(self):
self.stub_list = []
self.load_data_into_patch = None
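    # map a uid to a data-server shard with a stable xxhash32 (fixed seed),
    # so the same user always lands on the same shard; -1 means hashing failed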
def uid_shard(self, uid):
try:
uid_hash = xxhash.xxh32(str(uid), seed=101).intdigest()
        except Exception:
return -1
shard_idx = uid_hash % len(self.stub_list)
return shard_idx
# should set all params to numpy array with shape and dtype
# buggy here
def set_param_by_uid(self, uid, param_dict):
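        # serialize each numpy param into a protobuf Param message (flattened
        # weights plus shape) and send it to the shard that owns this uid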
shard_idx = self.uid_shard(uid)
if shard_idx == -1:
return -1
user_param = data_server_pb2.UserParams()
user_param.uid = uid
for key in param_dict:
param = data_server_pb2.Param()
param.name = key
np_var = param_dict[param.name]
param.shape.extend(np_var.shape)
param.weight.extend(np_var.ravel())
user_param.user_params.extend([param])
call_future = self.stub_list[shard_idx].UpdateUserParams.future(
user_param)
err_code = call_future.result().err_code
return err_code
def get_param_by_uid(self, uid):
shard_idx = self.uid_shard(uid)
if shard_idx == -1:
return -1
data = data_server_pb2.Data()
data.uid = uid
call_future = self.stub_list[shard_idx].GetUserParams.future(data)
user_params = call_future.result()
param_dict = {}
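        # rebuild each flattened weight list into an ndarray with its
        # original shape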
for param in user_params.user_params:
param_dict[param.name] = np.array(
list(param.weight), dtype=np.float32)
param_dict[param.name].shape = list(param.shape)
return param_dict
def clear_user_data(self, date):
def clear():
for stub in self.stub_list:
data = data_server_pb2.Data()
data.date = date
call_future = stub.ClearUserData.future(data)
res = call_future.result()
p = Process(target=clear, args=())
p.start()
p.join()
def get_data_by_uid(self, uid, date):
shard_idx = self.uid_shard(uid)
if shard_idx == -1:
return -1
data = data_server_pb2.Data()
data.uid = uid
data.date = date
call_future = self.stub_list[shard_idx].GetUserData.future(data)
user_data_list = []
for item in call_future.result().line_str:
user_data_list.append(item)
return user_data_list
def set_data_server_endpoints(self, endpoints):
self.stub_list = []
for ep in endpoints:
options = [('grpc.max_message_length', 1024 * 1024 * 1024),
('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
channel = grpc.insecure_channel(ep, options=options)
stub = data_server_pb2_grpc.DataServerStub(channel)
self.stub_list.append(stub)
def global_shuffle_by_patch(self, data_patch, date, concurrency):
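        # push patches to the data servers in rounds of at most `concurrency`
        # worker processes; each line is routed to the shard that owns its uid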
        shuffle_time = len(data_patch) // concurrency + 1
for i in range(shuffle_time):
if i * concurrency >= len(data_patch):
break
pros = []
end = min((i + 1) * concurrency, len(data_patch))
patch_list = data_patch[i * concurrency:end]
width = len(patch_list)
for j in range(width):
p = Process(
target=self.send_one_patch, args=(patch_list[j], date))
pros.append(p)
for p in pros:
p.start()
for p in pros:
p.join()
logging.info("shuffle round {} done.".format(i))
def send_one_patch(self, patch, date):
for line in patch:
group = line.strip().split("\t")
if len(group) != 3:
continue
data = data_server_pb2.Data()
data.uid = group[0]
data.date = date
data.line = line.strip()
stub_idx = self.uid_shard(data.uid)
if stub_idx == -1:
logging.info("send_one_patch continue for uid: %s" % data.uid)
continue
call_future = self.stub_list[stub_idx].SendData.future(data)
u_num = call_future.result()
def global_shuffle_by_file(self, filelist, concurrency):
pass
def set_load_data_into_patch_func(self, func):
self.load_data_into_patch = func
def get_local_files(self,
base_path,
date,
node_idx,
node_num,
hdfs_configs=None):
full_path = "{}/{}".format(base_path, date)
if os.path.exists(full_path):
file_list = os.listdir(full_path)
local_files = ["{}/{}".format(full_path, x) for x in file_list]
elif hdfs_configs is not None:
local_files = self.download_from_hdfs(hdfs_configs, base_path,
date, node_idx, node_num)
else:
local_files = []
return local_files
def download_from_hdfs(self, hdfs_configs, base_path, date, node_idx,
node_num):
# return local filelist
hdfs_client = HDFSClient("$HADOOP_HOME", hdfs_configs)
multi_download(
hdfs_client,
"{}/{}".format(base_path, date),
date,
node_idx,
node_num,
multi_processes=30)
filelist = os.listdir(date)
files = ["{}/{}".format(date, fn) for fn in filelist]
return files
def test_global_shuffle():
data_client = DataClient()
server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
data_client.set_data_server_endpoints(server_endpoints)
date = "0330"
file_name = ["data_with_uid/part-01991"]
with open(file_name[0]) as fin:
for line in fin:
group = line.strip().split("\t")
uid = group[0]
            user_data_dict = data_client.get_data_by_uid(uid, date)
def test_set_param():
data_client = DataClient()
server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
data_client.set_data_server_endpoints(server_endpoints)
uid = ["1001", "10001", "100001", "101"]
param_dict = {"w0": [1.0, 1.1, 1.2, 1.3], "b0": [1.1, 1.2, 1.3, 1.5]}
for cur_i in uid:
data_client.set_param_by_uid(cur_i, param_dict)
def test_get_param():
data_client = DataClient()
server_endpoints = ["127.0.0.1:{}".format(50050 + i) for i in range(10)]
data_client.set_data_server_endpoints(server_endpoints)
uid = ["1001", "10001", "100001", "101"]
for cur_i in uid:
param_dict = data_client.get_param_by_uid(cur_i)
print(param_dict)
if __name__ == "__main__":
#load_data_global_shuffle()
#test_global_shuffle()
test_set_param()
test_get_param()
# -*- coding: utf-8 -*-
from __future__ import print_function
import grpc
import servers.scheduler_server_pb2 as scheduler_server_pb2
import servers.scheduler_server_pb2_grpc as scheduler_server_pb2_grpc
import servers.data_server_pb2_grpc as data_server_pb2_grpc
import servers.data_server_pb2 as data_server_pb2
import numpy as np
from concurrent import futures
from multiprocessing import Process
import time
import sys
import os
class SchedulerClient(object):
def __init__(self):
self.stub = None
self.stub_list = []
self.global_param_info = {}
def update_user_inst_num(self, date, user_info_dict):
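        # report each user's instance count for `date` to the scheduler;
        # these counts drive user sampling (e.g. min_ins_num filtering)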
user_info = scheduler_server_pb2.UserInstInfo()
for key in user_info_dict:
single_user_info = scheduler_server_pb2.UserInstNum()
single_user_info.uid = key
single_user_info.inst_num = user_info_dict[key]
user_info.inst_nums.extend([single_user_info])
user_info.shard_num = len(self.stub_list)
user_info.date = date
call_future = self.stub.UpdateUserInstNum.future(user_info)
res = call_future.result()
return res.err_code
def set_scheduler_server_endpoints(self, endpoints):
options = [('grpc.max_message_length', 1024 * 1024 * 1024),
('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
channel = grpc.insecure_channel(endpoints[0], options=options)
self.stub = scheduler_server_pb2_grpc.SchedulerServerStub(channel)
def set_data_server_endpoints(self, endpoints):
self.stub_list = []
for ep in endpoints:
options = [('grpc.max_message_length', 1024 * 1024 * 1024),
('grpc.max_receive_message_length', 1024 * 1024 * 1024)]
channel = grpc.insecure_channel(ep, options=options)
stub = data_server_pb2_grpc.DataServerStub(channel)
self.stub_list.append(stub)
def uniform_sample_user_list(self, date, node_id, sample_num, shard_num,
node_num, min_ins_num):
user_info_dict = {}
req = scheduler_server_pb2.Request()
req.node_idx = node_id
req.sample_num = sample_num
req.shard_num = shard_num
req.node_num = node_num
req.date = date
req.min_ins_num = min_ins_num
call_future = self.stub.SampleUsersToTrain.future(req)
user_info = call_future.result()
for user in user_info.inst_nums:
user_info_dict[user.uid] = user.inst_num
return user_info_dict
def hash_sample_user_list(self, date, node_id, sample_num, shard_num,
node_num):
user_info_dict = {}
req = scheduler_server_pb2.Request()
req.node_idx = node_id
req.sample_num = sample_num
req.shard_num = shard_num
req.node_num = node_num
req.date = date
call_future = self.stub.SampleUsersWithHash.future(req)
user_info = call_future.result()
for user in user_info.inst_nums:
user_info_dict[user.uid] = user.inst_num
return user_info_dict
def sample_test_user_list(self, date, node_id, shard_num, node_num):
user_info_dict = {}
req = scheduler_server_pb2.Request()
req.node_idx = node_id
req.shard_num = shard_num
req.node_num = node_num
req.date = date
call_future = self.stub.SampleUsersToTest.future(req)
user_info = call_future.result()
for user in user_info.inst_nums:
user_info_dict[user.uid] = user.inst_num
return user_info_dict
def fixed_sample_user_list(self, date, node_id, sample_num, shard_num,
node_num):
user_info_dict = {}
req = scheduler_server_pb2.Request()
req.node_idx = node_id
req.sample_num = sample_num
req.shard_num = shard_num
req.node_num = node_num
req.date = date
call_future = self.stub.FixedUsersToTrain.future(req)
user_info = call_future.result()
for user in user_info.inst_nums:
user_info_dict[user.uid] = user.inst_num
return user_info_dict
def get_global_params(self):
req = scheduler_server_pb2.Request()
req.node_idx = 0
req.sample_num = 0
req.shard_num = 0
req.node_num = 0
call_future = self.stub.GetGlobalParams.future(req)
global_p = call_future.result()
result_dict = {}
for param in global_p.global_params:
result_dict[param.name] = np.array(
list(param.weight), dtype=np.float32)
result_dict[param.name].shape = param.shape
return result_dict
def update_global_params(self, global_params):
global_p = scheduler_server_pb2.GlobalParams()
for key in global_params:
param = scheduler_server_pb2.Param()
param.name = key
var, shape = global_params[key], global_params[key].shape
self.global_param_info[key] = shape
param.weight.extend(var.ravel())
param.shape.extend(shape)
global_p.global_params.extend([param])
call_future = self.stub.UpdateGlobalParams.future(global_p)
res = call_future.result()
return res.err_code
def fedavg_update(self, global_param_delta_dict):
global_p = scheduler_server_pb2.GlobalParams()
for key in global_param_delta_dict:
param = scheduler_server_pb2.Param()
param.name = key
parameter_delta, shape = global_param_delta_dict[
param.name], global_param_delta_dict[param.name].shape
param.weight.extend(parameter_delta.ravel())
param.shape.extend(shape)
global_p.global_params.extend([param])
call_future = self.stub.FedAvgUpdate.future(global_p)
res = call_future.result()
return res.err_code
def stop_scheduler_server(self):
empty_input = scheduler_server_pb2.SchedulerServerEmptyInput()
call_future = self.stub.Exit.future(empty_input)
res = call_future.result()
def test_uniform_sample_user_list():
client = SchedulerClient()
client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
# buggy
#user_list = client.get_user_list()
    global_user_dict = {}
    for i in range(10000):
        global_user_dict[str(i)] = 100
    client.update_user_inst_num("20200101", global_user_dict)
    user_info = client.uniform_sample_user_list("20200101", 0, 100, 1, 1, 1)
    global_param = {
        "w0": np.array([1.0, 1.0, 1.0], dtype=np.float32),
        "w1": np.array([2.0, 2.0, 2.0], dtype=np.float32)
    }
    client.update_global_params(global_param)
    fetched_params = client.get_global_params()
def test_get_global_params():
client = SchedulerClient()
client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
global_param = {"w0": [1.0, 1.0, 1.0], "w1": [2.0, 2.0, 2.0]}
client.update_global_params(global_param)
fetched_params = client.get_global_params()
print(fetched_params)
def test_update_global_params():
client = SchedulerClient()
client.set_scheduler_server_endpoints(["127.0.0.1:60001"])
global_param = {"w0": [1.0, 1.0, 1.0], "w1": [3.0, 3.0, 3.0]}
client.update_global_params(global_param)
fetched_params = client.get_global_params()
print(fetched_params)
if __name__ == "__main__":
test_uniform_sample_user_list()
test_update_global_params()
test_get_global_params()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.role_maker import FLSimRoleMaker
from clients import DataClient
from clients import SchedulerClient
from servers import DataServer
from servers import SchedulerServer
from multiprocessing import Process, Pool, Manager, Pipe, Lock
from utils.logger import logging
import pickle
import time
import numpy as np
import sys
import os
class SimulationFramework(object):
def __init__(self, role_maker):
self.data_client = None
self.scheduler_client = None
self.role_maker = role_maker
        # we currently assume all users train a homogeneous model
self.trainer = None
# for sampling users
self.sampler = None
        # for updating global weights
self.fl_optimizer = None
        # for sampling users to test
self.test_sampler = None
self.profile_file = open("profile", "w")
self.do_profile = True
# for data downloading
self.hdfs_configs = None
def set_hdfs_configs(self, configs):
self.hdfs_configs = configs
def set_trainer(self, trainer):
self.trainer = trainer
def set_sampler(self, sampler):
self.sampler = sampler
def set_test_sampler(self, sampler):
self.test_sampler = sampler
def set_fl_optimizer(self, optimizer):
self.fl_optimizer = optimizer
def is_scheduler(self):
return self.role_maker.is_global_scheduler()
def is_simulator(self):
return self.role_maker.is_simulator()
def run_scheduler_service(self):
if self.role_maker.is_global_scheduler():
self._run_global_scheduler()
def _barrier_simulators(self):
self.role_maker.barrier_simulator()
def _start_data_server(self, endpoint):
data_server = DataServer()
port = endpoint.split(":")[1]
data_server.start(endpoint=port)
def _run_global_scheduler(self):
scheduler_server = SchedulerServer()
endpoint = self.role_maker.get_global_scheduler_endpoint()
port = endpoint.split(":")[1]
scheduler_server.start(endpoint=port)
def _get_data_services(self):
data_server_endpoints = \
self.role_maker.get_local_data_server_endpoint()
data_server_pros = []
for i, ep in enumerate(data_server_endpoints):
p = Process(target=self._start_data_server, args=(ep, ))
data_server_pros.append(p)
return data_server_pros
def _profile(self, func, *args, **kwargs):
if self.do_profile:
start = time.time()
res = func(*args, **kwargs)
end = time.time()
self.profile_file.write("%s\t\t%f s\n" %
(func.__name__, end - start))
return res
else:
return func(*args, **kwargs)
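    # one simulated round for `date`: sample users, launch one training
    # process per sampled user seeded with the current global params, then
    # fold the per-user results back into the global model via the optimizer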
def _run_sim(self, date, sim_num_everyday=1):
sim_idx = self.role_maker.simulator_idx()
sim_num = self.role_maker.simulator_num()
sim_all_trainer_run_time = 0
        sim_read_param_and_optimize = 0
for sim in range(sim_num_everyday):
logging.info("sim id: %d" % sim)
# sampler algorithm
user_info_dict = self._profile(
self.sampler.sample_user_list, self.scheduler_client, date,
sim_idx, len(self.data_client.stub_list), sim_num)
if self.do_profile:
print("sim_idx: ", sim_idx)
print("shard num: ", len(self.data_client.stub_list))
print("sim_num: ", sim_num)
print("user_info_dict: ", user_info_dict)
global_param_dict = self._profile(
self.scheduler_client.get_global_params)
processes = []
os.system("rm -rf _global_param")
os.system("mkdir _global_param")
start = time.time()
for idx, user in enumerate(user_info_dict):
arg_dict = {
"uid": str(user),
"date": date,
"data_endpoints":
self.role_maker.get_data_server_endpoints(),
"global_params": global_param_dict,
"user_param_names": self.trainer.get_user_param_names(),
"global_param_names":
self.trainer.get_global_param_names(),
"write_global_param_file":
"_global_param/process_%d" % idx,
}
p = Process(
target=self.trainer.train_one_user_func,
args=(arg_dict, self.trainer.trainer_config))
p.start()
processes.append(p)
if self.do_profile:
logging.info("wait processes to close")
for i, p in enumerate(processes):
processes[i].join()
end = time.time()
sim_all_trainer_run_time += (end - start)
start = time.time()
train_result = []
new_global_param_by_user = {}
training_sample_by_user = {}
for i, p in enumerate(processes):
param_dir = "_global_param/process_%d/" % i
with open(param_dir + "/_info", "r") as f:
user, train_sample_num = pickle.load(f)
param_dict = {}
for f_name in os.listdir(os.path.join(param_dir, "params")):
f_path = os.path.join(param_dir, "params", f_name)
if os.path.isdir(f_path): # layer
for layer_param in os.listdir(f_path):
layer_param_path = os.path.join(f_path,
layer_param)
with open(layer_param_path) as f:
param_dict["{}/{}".format(
f_name, layer_param)] = np.load(f)
else:
with open(f_path) as f:
param_dict[f_name] = np.load(f)
new_global_param_by_user[user] = param_dict
training_sample_by_user[user] = train_sample_num
self.fl_optimizer.update(training_sample_by_user,
new_global_param_by_user,
global_param_dict, self.scheduler_client)
            end = time.time()
            sim_read_param_and_optimize += (end - start)
        if self.do_profile:
            self.profile_file.write("sim_all_trainer_run_time\t\t%f s\n" %
                                    sim_all_trainer_run_time)
            self.profile_file.write("sim_read_param_and_optimize\t\t%f s\n" %
                                    sim_read_param_and_optimize)
logging.info("training done for date %s." % date)
def _test(self, date):
if self.trainer.infer_one_user_func is None:
            return
logging.info("doing test...")
if self.test_sampler is None:
logging.error("self.test_sampler should not be None when testing")
sim_idx = self.role_maker.simulator_idx()
sim_num = self.role_maker.simulator_num()
user_info_dict = self.test_sampler.sample_user_list(
self.scheduler_client,
date,
sim_idx,
len(self.data_client.stub_list),
sim_num, )
if self.do_profile:
print("test user info_dict: ", user_info_dict)
global_param_dict = self.scheduler_client.get_global_params()
def divide_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i + n]
# at most 50 process for testing
chunk_size = 50
# at most 100 uid for testing
max_test_uids = 100
        uid_chunks = divide_chunks(list(user_info_dict.keys()), chunk_size)
os.system("rm -rf _test_result")
os.system("mkdir _test_result")
tested_uids = 0
for uids in uid_chunks:
if tested_uids >= max_test_uids:
break
processes = []
for user in uids:
arg_dict = {
"uid": str(user),
"date": date,
"data_endpoints":
self.role_maker.get_data_server_endpoints(),
"global_params": global_param_dict,
"user_param_names": self.trainer.get_user_param_names(),
"global_param_names":
self.trainer.get_global_param_names(),
"infer_result_dir": "_test_result/uid-%s" % user,
}
p = Process(
target=self.trainer.infer_one_user_func,
args=(arg_dict, self.trainer.trainer_config))
p.start()
processes.append(p)
if self.do_profile:
logging.info("wait test processes to close")
for i, p in enumerate(processes):
processes[i].join()
tested_uids += chunk_size
infer_results = []
# only support one test metric now
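        # final metric = instance-count-weighted average over all tested users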
for uid in os.listdir("_test_result"):
with open("_test_result/" + uid + "/res", 'r') as f:
                sample_count, metric = f.readlines()[0].strip('\n').split('\t')
                infer_results.append((int(sample_count), float(metric)))
if sum([x[0] for x in infer_results]) == 0:
logging.info("infer results: 0.0")
else:
count = sum([x[0] for x in infer_results])
metric = sum([x[0] * x[1] for x in infer_results]) / count
logging.info("infer results: %f" % metric)
def _save_and_upload(self, date, fs_upload_path):
if self.trainer.save_and_upload_func is None:
return
if fs_upload_path is None:
return
dfs_upload_path = fs_upload_path + date + "_" + str(
self.role_maker.simulator_idx())
global_param_dict = self.scheduler_client.get_global_params()
arg_dict = {
"date": date,
"global_params": global_param_dict,
"user_param_names": self.trainer.get_user_param_names(),
"global_param_names": self.trainer.get_global_param_names(),
}
self.trainer.save_and_upload_func(
arg_dict, self.trainer.trainer_config, dfs_upload_path)
def run_simulation(self,
base_path,
dates,
fs_upload_path=None,
sim_num_everyday=1,
do_test=False,
test_skip_day=6):
if not self.role_maker.is_simulator():
            return
data_services = self._get_data_services()
for service in data_services:
service.start()
self._barrier_simulators()
self.data_client = DataClient()
self.data_client.set_load_data_into_patch_func(
self.trainer.get_load_data_into_patch_func())
self.data_client.set_data_server_endpoints(
self.role_maker.get_data_server_endpoints())
self.scheduler_client = SchedulerClient()
self.scheduler_client.set_data_server_endpoints(
self.role_maker.get_data_server_endpoints())
self.scheduler_client.set_scheduler_server_endpoints(
[self.role_maker.get_global_scheduler_endpoint()])
logging.info("trainer config: ", self.trainer.trainer_config)
self.trainer.prepare(do_test=do_test)
if self.role_maker.simulator_idx() == 0:
self.trainer.init_global_model(self.scheduler_client)
self._barrier_simulators()
for date_idx, date in enumerate(dates):
if date_idx > 0:
self.do_profile = False
self.profile_file.close()
logging.info("reading data for date: %s" % date)
local_files = self._profile(
self.data_client.get_local_files,
base_path,
date,
self.role_maker.simulator_idx(),
self.role_maker.simulator_num(),
hdfs_configs=self.hdfs_configs)
logging.info("loading data into patch for date: %s" % date)
data_patch, local_user_dict = self._profile(
self.data_client.load_data_into_patch, local_files, 10000)
logging.info("shuffling data for date: %s" % date)
self._profile(self.data_client.global_shuffle_by_patch, data_patch,
date, 30)
logging.info("updating user inst num for date: %s" % date)
self._profile(self.scheduler_client.update_user_inst_num, date,
local_user_dict)
self.role_maker.barrier_simulator()
if do_test and date_idx != 0 and date_idx % test_skip_day == 0:
self._barrier_simulators()
self._profile(self._test, date)
self._barrier_simulators()
self._profile(self._save_and_upload, date, fs_upload_path)
self._run_sim(date, sim_num_everyday=sim_num_everyday)
self.role_maker.barrier_simulator()
logging.info("clear user data for date: %s" % date)
self.data_client.clear_user_data(date)
self._barrier_simulators()
logging.info("training done all date.")
logging.info("stoping scheduler")
self.scheduler_client.stop_scheduler_server()
for pro in data_services:
pro.terminate()
logging.info("after terminate for all server.")
# reddit_subsampled
wget https://paddle-serving.bj.bcebos.com/temporary_files_for_docker/reddit_subsampled.zip --no-check-certificate
unzip reddit_subsampled.zip
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .layer_base import Embedding
from .layer_base import SequenceConv
from .layer_base import Concat
from .layer_base import Pooling
from .layer_base import FC
from .layer_base import CrossEntropySum
from .layer_base import ACC
from .layer_base import AUC
from .data_base import TextData
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
class DataBase(object):
def __init__(self):
pass
class TextData(object):
def __init__(self):
pass
def create(self, name=None):
assert name is not None
return fluid.layers.data(
name=name, shape=[1], dtype='int64', lod_level=1)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
class LayerBase(object):
def __init__(self):
pass
class Embedding(LayerBase):
def __init__(self):
pass
def attr(self, shapes, name=None):
self.vocab_size = shapes[0]
self.emb_dim = shapes[1]
self.name = name
def forward(self, i):
is_sparse = True
if self.name is not None:
param_attr = fluid.ParamAttr(name=self.name)
else:
param_attr = None
emb = fluid.layers.embedding(
input=i,
is_sparse=is_sparse,
size=[self.vocab_size, self.emb_dim],
param_attr=param_attr,
padding_idx=0)
return emb
class SequenceConv(LayerBase):
def __init__(self):
pass
def attr(self, name=None):
self.num_filters = 64
self.win_size = 3
self.name = name
def forward(self, i):
if self.name is not None:
param_attr = fluid.ParamAttr(name=self.name)
else:
param_attr = None
conv = fluid.nets.sequence_conv_pool(
input=i,
num_filters=self.num_filters,
filter_size=self.win_size,
param_attr=param_attr,
act="tanh",
pool_type="max")
return conv
class Concat(LayerBase):
def forward(self, inputs):
concat = fluid.layers.concat(inputs, axis=1)
return concat
class Pooling(LayerBase):
def __init__(self):
self.pool_type = 'sum'
def attr(self, pool_type='sum'):
self.pool_type = pool_type
def forward(self, i):
        pool = fluid.layers.sequence_pool(input=i, pool_type=self.pool_type)
return pool
class FC(LayerBase):
def __init__(self):
return
def attr(self, shapes, act='relu', name=None):
self.name = name
self.size = shapes[0]
self.act = act
def forward(self, i):
if self.name is not None:
param_attr = fluid.ParamAttr(name=self.name)
else:
param_attr = None
fc = fluid.layers.fc(input=i,
size=self.size,
act=self.act,
param_attr=param_attr)
return fc
class CrossEntropySum(LayerBase):
def __init__(self):
pass
def forward(self, prediction, label):
cost = fluid.layers.cross_entropy(input=prediction, label=label)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
class ACC(LayerBase):
def __init__(self):
pass
def forward(self, prediction, label):
return fluid.layers.accuracy(input=prediction, label=label)
class AUC(LayerBase):
def forward(self, prediction, label):
auc, batch_auc_var, auc_states = \
fluid.layers.auc(input=prediction, label=label,
num_thresholds=2 ** 12,
slide_steps=20)
return auc, batch_auc_var
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .language_model import LanguageModel
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model import *
from layer import *
from .model_base import ModelBase
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers.control_flow import StaticRNN as PaddingRNN
import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.contrib.layers import basic_lstm
class LanguageModel(ModelBase):
def __init__(self):
# model args
self.hidden_size_ = 200
self.vocab_size_ = 10000
self.num_layers_ = 2
        self.num_steps_ = 10  # fixed sequence length
self.init_scale_ = 0.1
self.dropout_ = 0.0
self.rnn_model_ = 'basic_lstm'
self.pad_symbol_ = 0
self.unk_symbol_ = 1
# results
self.correct_ = None
self.prediction_ = None
self.loss_ = None
# private vars
self.user_params_ = []
self.program_ = None
self.startup_program_ = None
self.input_name_list_ = None
self.target_var_names_ = []
def get_model_input_names(self):
return self.input_name_list_
def get_model_loss(self):
return self.loss_
def get_model_loss_name(self):
return self.loss_.name
def get_model_metrics(self):
return {"correct": self.correct_.name}
def get_target_names(self):
return self.target_var_names_
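    # builds the LSTM language-model graph: embedding -> multi-layer LSTM ->
    # softmax projection; Top-1 accuracy masks out <unk>/<pad> predictions
    # (they always count as wrong), and the loss is per-step cross-entropy
    # averaged over the batch and summed over time steps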
def build_model(self, model_configs):
hidden_size = self.hidden_size_
init_scale = self.init_scale_
dropout = self.dropout_
num_layers = self.num_layers_
num_steps = self.num_steps_
pad_symbol = self.pad_symbol_
unk_symbol = self.unk_symbol_
vocab_size = self.vocab_size_
rnn_model = self.rnn_model_
x = fluid.data(name="x", shape=[None, num_steps], dtype='int64')
y = fluid.data(name="y", shape=[None, num_steps], dtype='int64')
x = layers.reshape(x, shape=[-1, num_steps, 1])
y = layers.reshape(y, shape=[-1, 1])
self.input_name_list_ = ['x', 'y']
init_hidden = layers.fill_constant_batch_size_like(
input=x,
shape=[-1, num_layers, hidden_size],
value=0,
dtype="float32")
init_cell = layers.fill_constant_batch_size_like(
input=x,
shape=[-1, num_layers, hidden_size],
value=0,
dtype="float32")
init_hidden = layers.transpose(init_hidden, perm=[1, 0, 2])
init_cell = layers.transpose(init_cell, perm=[1, 0, 2])
init_hidden_reshape = layers.reshape(
init_hidden, shape=[num_layers, -1, hidden_size])
init_cell_reshape = layers.reshape(
init_cell, shape=[num_layers, -1, hidden_size])
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='embedding_para',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
x_emb = layers.reshape(
x_emb, shape=[-1, num_steps, hidden_size], inplace=True)
        if dropout is not None and dropout > 0.0:
x_emb = layers.dropout(
x_emb,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
if rnn_model == "padding":
rnn_out, last_hidden, last_cell = self._padding_rnn(
x_emb,
len=num_steps,
init_hidden=init_hidden_reshape,
init_cell=init_cell_reshape)
elif rnn_model == "static":
rnn_out, last_hidden, last_cell = self._encoder_static(
x_emb,
len=num_steps,
init_hidden=init_hidden_reshape,
init_cell=init_cell_reshape)
elif rnn_model == "cudnn":
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn_out, last_hidden, last_cell = layers.lstm(
x_emb,
init_hidden_reshape,
init_cell_reshape,
num_steps,
hidden_size,
num_layers,
is_bidirec=False,
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
rnn_out = layers.transpose(rnn_out, perm=[1, 0, 2])
elif rnn_model == "basic_lstm":
rnn_out, last_hidden, last_cell = basic_lstm(
x_emb,
init_hidden,
init_cell,
hidden_size,
num_layers=num_layers,
batch_first=True,
dropout_prob=dropout,
param_attr=ParamAttr(
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0.0)),
forget_bias=0.0)
else:
raise Exception("type not support")
rnn_out = layers.reshape(
rnn_out, shape=[-1, num_steps, hidden_size], inplace=True)
softmax_weight = layers.create_parameter(
[hidden_size, vocab_size],
dtype="float32",
name="softmax_weight",
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
softmax_bias = layers.create_parameter(
[vocab_size],
dtype="float32",
name='softmax_bias',
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
projection = layers.matmul(rnn_out, softmax_weight)
projection = layers.elementwise_add(projection, softmax_bias)
projection = layers.reshape(
projection, shape=[-1, vocab_size], inplace=True)
# correct predictions
labels_reshaped = fluid.layers.reshape(y, [-1])
pred = fluid.layers.cast(
fluid.layers.argmax(projection, 1), dtype="int64")
correct_pred = fluid.layers.cast(
fluid.layers.equal(pred, labels_reshaped), dtype="int64")
self.prediction_ = pred
self.target_var_names_.append(pred)
# predicting unknown is always considered wrong
unk_tensor = fluid.layers.fill_constant(
fluid.layers.shape(labels_reshaped),
value=unk_symbol,
dtype='int64')
pred_unk = fluid.layers.cast(
fluid.layers.equal(pred, unk_tensor), dtype="int64")
correct_unk = fluid.layers.elementwise_mul(pred_unk, correct_pred)
# predicting padding is always considered wrong
pad_tensor = fluid.layers.fill_constant(
fluid.layers.shape(labels_reshaped), value=0, dtype='int64')
pred_pad = fluid.layers.cast(
fluid.layers.equal(pred, pad_tensor), dtype="int64")
correct_pad = fluid.layers.elementwise_mul(pred_pad, correct_pred)
# acc
correct_count = fluid.layers.reduce_sum(correct_pred) \
- fluid.layers.reduce_sum(correct_unk) \
- fluid.layers.reduce_sum(correct_pad)
self.correct_ = correct_count
self.target_var_names_.append(correct_count)
loss = layers.softmax_with_cross_entropy(
logits=projection, label=y, soft_label=False)
loss = layers.reshape(loss, shape=[-1, num_steps], inplace=True)
loss = layers.reduce_mean(loss, dim=[0])
loss = layers.reduce_sum(loss)
self.loss_ = loss
self.target_var_names_.append(loss)
loss.persistable = True
# This will feed last_hidden, last_cell to init_hidden, init_cell, which
# can be used directly in next batch. This can avoid the fetching of
# last_hidden and last_cell and feeding of init_hidden and init_cell in
# each training step.
#last_hidden = layers.transpose(last_hidden, perm=[1, 0, 2])
#last_cell = layers.transpose(last_cell, perm=[1, 0, 2])
#self.input_name_list_ = ['x', 'y', 'init_hidden', 'init_cell']
self.program_ = fluid.default_main_program()
self.startup_program_ = fluid.default_startup_program()
    def _padding_rnn(self, input_embedding, len=3, init_hidden=None,
                     init_cell=None):
weight_1_arr = []
weight_2_arr = []
bias_arr = []
hidden_array = []
cell_array = []
mask_array = []
hidden_size = self.hidden_size_
init_scale = self.init_scale_
        dropout = self.dropout_
        num_layers = self.num_layers_
        num_steps = self.num_steps_
for i in range(num_layers):
weight_1 = layers.create_parameter(
[hidden_size * 2, hidden_size * 4],
dtype="float32",
name="fc_weight1_" + str(i),
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
weight_1_arr.append(weight_1)
bias_1 = layers.create_parameter(
[hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
bias_arr.append(bias_1)
pre_hidden = layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1])
pre_cell = layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1])
pre_hidden = layers.reshape(pre_hidden, shape=[-1, hidden_size])
pre_cell = layers.reshape(pre_cell, shape=[-1, hidden_size])
hidden_array.append(pre_hidden)
cell_array.append(pre_cell)
input_embedding = layers.transpose(input_embedding, perm=[1, 0, 2])
rnn = PaddingRNN()
with rnn.step():
input = rnn.step_input(input_embedding)
for k in range(num_layers):
pre_hidden = rnn.memory(init=hidden_array[k])
pre_cell = rnn.memory(init=cell_array[k])
weight_1 = weight_1_arr[k]
bias = bias_arr[k]
nn = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=nn, y=weight_1)
gate_input = layers.elementwise_add(gate_input, bias)
i = layers.slice(
gate_input, axes=[1], starts=[0], ends=[hidden_size])
j = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size],
ends=[hidden_size * 2])
f = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size * 2],
ends=[hidden_size * 3])
o = layers.slice(
gate_input,
axes=[1],
starts=[hidden_size * 3],
ends=[hidden_size * 4])
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
rnn.update_memory(pre_hidden, m)
rnn.update_memory(pre_cell, c)
rnn.step_output(m)
rnn.step_output(c)
input = m
                if dropout is not None and dropout > 0.0:
input = layers.dropout(
input,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
rnn.step_output(input)
rnnout = rnn()
last_hidden_array = []
last_cell_array = []
real_res = rnnout[-1]
for i in range(num_layers):
m = rnnout[i * 2]
c = rnnout[i * 2 + 1]
m.stop_gradient = True
c.stop_gradient = True
last_h = layers.slice(
m, axes=[0], starts=[num_steps - 1], ends=[num_steps])
last_hidden_array.append(last_h)
last_c = layers.slice(
c, axes=[0], starts=[num_steps - 1], ends=[num_steps])
last_cell_array.append(last_c)
real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
last_hidden = layers.concat(last_hidden_array, 0)
last_cell = layers.concat(last_cell_array, 0)
return real_res, last_hidden, last_cell
    def _encoder_static(self,
                        input_embedding,
                        len=3,
                        init_hidden=None,
                        init_cell=None):
weight_1_arr = []
weight_2_arr = []
bias_arr = []
hidden_array = []
cell_array = []
mask_array = []
hidden_size = self.hidden_size_
init_scale = self.init_scale_
        dropout = self.dropout_
num_layers = self.num_layers_
for i in range(num_layers):
weight_1 = layers.create_parameter(
[hidden_size * 2, hidden_size * 4],
dtype="float32",
name="fc_weight1_" + str(i),
default_initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale))
weight_1_arr.append(weight_1)
bias_1 = layers.create_parameter(
[hidden_size * 4],
dtype="float32",
name="fc_bias1_" + str(i),
default_initializer=fluid.initializer.Constant(0.0))
bias_arr.append(bias_1)
pre_hidden = layers.slice(
init_hidden, axes=[0], starts=[i], ends=[i + 1])
pre_cell = layers.slice(
init_cell, axes=[0], starts=[i], ends=[i + 1])
pre_hidden = layers.reshape(
pre_hidden, shape=[-1, hidden_size], inplace=True)
pre_cell = layers.reshape(
pre_cell, shape=[-1, hidden_size], inplace=True)
hidden_array.append(pre_hidden)
cell_array.append(pre_cell)
res = []
sliced_inputs = layers.split(
input_embedding, num_or_sections=len, dim=1)
for index in range(len):
input = sliced_inputs[index]
input = layers.reshape(
input, shape=[-1, hidden_size], inplace=True)
for k in range(num_layers):
pre_hidden = hidden_array[k]
pre_cell = cell_array[k]
weight_1 = weight_1_arr[k]
bias = bias_arr[k]
nn = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=nn, y=weight_1)
gate_input = layers.elementwise_add(gate_input, bias)
i, j, f, o = layers.split(
gate_input, num_or_sections=4, dim=-1)
c = pre_cell * layers.sigmoid(f) + layers.sigmoid(
i) * layers.tanh(j)
m = layers.tanh(c) * layers.sigmoid(o)
hidden_array[k] = m
cell_array[k] = c
input = m
                if dropout is not None and dropout > 0.0:
input = layers.dropout(
input,
dropout_prob=dropout,
dropout_implementation='upscale_in_train')
res.append(input)
last_hidden = layers.concat(hidden_array, 1)
last_hidden = layers.reshape(
last_hidden, shape=[-1, num_layers, hidden_size], inplace=True)
last_hidden = layers.transpose(x=last_hidden, perm=[1, 0, 2])
last_cell = layers.concat(cell_array, 1)
last_cell = layers.reshape(
last_cell, shape=[-1, num_layers, hidden_size])
last_cell = layers.transpose(x=last_cell, perm=[1, 0, 2])
real_res = layers.concat(res, 0)
real_res = layers.reshape(
real_res, shape=[len, -1, hidden_size], inplace=True)
real_res = layers.transpose(x=real_res, perm=[1, 0, 2])
return real_res, last_hidden, last_cell
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import numpy as np
def set_user_param_dict(param_names, param_dict, scope):
place = fluid.CPUPlace()
for var_name in param_names:
param = scope.find_var(var_name)
if param is None:
print("var name: {} does not exist in memory".format(var_name))
continue
param.get_tensor().set(param_dict[var_name], place)
return
def set_global_param_dict(param_names, param_dict, scope):
place = fluid.CPUPlace()
for var_name in param_names:
param = scope.find_var(var_name)
if param is None:
print("var name: {} does not exist in memory".format(var_name))
continue
if var_name not in param_dict:
print("var name: {} does not exist in global param dict".format(
var_name))
exit()
var_numpy = param_dict[var_name]
param.get_tensor().set(var_numpy, place)
return
class ModelBase(object):
def __init__(self):
pass
def init_model(self):
pass
def build_model(self, model_configs):
pass
def get_model_inputs(self):
pass
def get_model_loss(self):
pass
def get_model_metrics(self):
pass
def get_startup_program(self):
pass
def get_main_program(self):
pass
def get_user_param_dict(self):
param_dict = {}
scope = fluid.global_scope()
for var_pair in self.get_user_param_names():
param = scope.find_var(var_pair[0])
if param is None:
print("var name: {} does not exist in memory".format(var_pair[
0]))
continue
var = param.get_tensor().__array__()
param_dict[var_pair[0]] = [var, var_pair[1].shape]
return param_dict
def get_global_param_dict(self):
param_dict = {}
scope = fluid.global_scope()
for var_pair in self.get_global_param_names():
param = scope.find_var(var_pair[0])
if param is None:
print("var name: {} does not exist in memory".format(var_pair[
0]))
continue
var = param.get_tensor().__array__()
param_dict[var_pair[0]] = var
return param_dict
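    # naming convention: variables whose names contain "@USER" are private
    # per-user parameters; all other persistable variables are global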
def get_user_param_names(self):
user_params = []
for var_name, var in self.startup_program_.global_block().vars.items():
if var.persistable and "@USER" in var_name and \
"learning_rate" not in var_name:
user_params.append((var_name, var))
return user_params
def get_global_param_names(self):
global_params = []
for var_name, var in self.startup_program_.global_block().vars.items():
if var.persistable and "@USER" not in var_name and \
"learning_rate" not in var_name and "generated_var" not in var_name:
global_params.append((var_name, var))
return global_params
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fed_avg_optimizer import FedAvgOptimizer
from .fed_avg_optimizer import SumOptimizer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optimizer_base import OptimizerBase
class FedAvgOptimizer(OptimizerBase):
def __init__(self, learning_rate=1.0):
self.learning_rate = learning_rate
def update(self, user_info, new_global_param_by_user, old_global_param,
scheduler_client):
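        # FedAvg: delta = lr * sum_u (n_u / N) * (w_u - w_old), where n_u is
        # user u's instance count and N the total; the aggregated delta is
        # shipped to the scheduler through the FedAvgUpdate RPC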
total_weight = 0.0
for uid in user_info:
total_weight += float(user_info[uid])
update_dict = {}
for key in new_global_param_by_user:
uid = key
uid_global_w = new_global_param_by_user[key]
weight = float(user_info[key]) / total_weight
for param_name in uid_global_w:
if param_name in update_dict:
update_dict[param_name] += \
self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name])
else:
update_dict[param_name] = \
self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name])
scheduler_client.fedavg_update(update_dict)
class SumOptimizer(OptimizerBase):
def __init__(self, learning_rate=1.0):
self.learning_rate = learning_rate
def update(self, user_info, new_global_param_by_user, old_global_param,
scheduler_client):
update_dict = {}
for key in new_global_param_by_user:
uid = key
uid_global_w = new_global_param_by_user[key]
weight = 1.0
for param_name in uid_global_w:
if param_name in update_dict:
update_dict[param_name] += \
self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name])
else:
update_dict[param_name] = \
self.learning_rate * weight * (uid_global_w[param_name] - old_global_param[param_name])
scheduler_client.fedavg_update(update_dict)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class OptimizerBase(object):
def __init__(self, role_maker):
pass
    def update(self, user_info, new_global_param_by_user, old_global_param,
               scheduler_client):
pass
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package server;
message Data {
string date = 1;
string uid = 2;
string line = 3;
}
message Res {
int32 err_code = 1;
int32 user_num = 2;
}
message UserData {
repeated string line_str = 1;
int32 err_code = 2;
}
message Param {
string name = 1;
repeated float weight = 2;
repeated int32 shape = 3;
}
message UserParams {
repeated Param user_params = 1;
int32 err_code = 2;
string uid = 3;
}
service DataServer {
rpc SendData(Data) returns (Res) {}
rpc GetUserData(Data) returns (UserData) {}
rpc ClearUserData(Data) returns (Res) {}
rpc GetUserParams(Data) returns (UserParams) {}
rpc UpdateUserParams(UserParams) returns (Res) {}
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs protoc with the gRPC plugin to generate messages and gRPC stubs."""
from grpc_tools import protoc
#python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. data_server.proto
protoc.main((
'',
'-I.',
'--python_out=.',
'--grpc_python_out=.',
'data_server.proto', ))
protoc.main((
'',
'-I.',
'--python_out=.',
'--grpc_python_out=.',
'scheduler_server.proto', ))
import os
os.system("mv *pb2.py ../servers/")
os.system("mv *pb2_grpc.py ../servers/")
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package scheduler;
message Request {
int32 node_idx = 1;
int32 sample_num = 2;
int32 shard_num = 3;
int32 node_num = 4;
int32 min_ins_num = 5;
string date = 6;
}
message UserList {
repeated string uids = 1;
repeated float weight = 2;
}
message Param {
string name = 1;
repeated float weight = 2;
repeated int32 shape = 3;
}
message GlobalParams {
repeated Param global_params = 1;
int32 err_code = 2;
}
message UserInstNum {
string uid = 1;
int32 inst_num = 2;
}
message UserInstInfo {
repeated UserInstNum inst_nums = 1;
int32 shard_num = 2;
string date = 3;
}
message Res { int32 err_code = 1; }
message SchedulerServerEmptyInput {}
service SchedulerServer {
  // RPCs for user sampling and global parameter exchange.
rpc SampleUsersToTrain(Request) returns (UserInstInfo) {}
rpc FixedUsersToTrain(Request) returns (UserInstInfo) {}
rpc GetGlobalParams(Request) returns (GlobalParams) {}
rpc UpdateGlobalParams(GlobalParams) returns (Res) {}
rpc FedAvgUpdate(GlobalParams) returns (Res) {}
rpc UpdateUserInstNum(UserInstInfo) returns (Res) {}
rpc Exit(SchedulerServerEmptyInput) returns (Res) {}
rpc SampleUsersToTest(Request) returns (UserInstInfo) {}
rpc SampleUsersWithHash(Request) returns (UserInstInfo) {}
}
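As with the data server, the generated stub can be driven directly. A minimal sketch, assuming the generated modules are on `PYTHONPATH` and a scheduler is listening on `127.0.0.1:60001` (the endpoint used by the scheduler's `__main__` block further down):

```python
import grpc

import scheduler_server_pb2
import scheduler_server_pb2_grpc

channel = grpc.insecure_channel("127.0.0.1:60001")
stub = scheduler_server_pb2_grpc.SchedulerServerStub(channel)

# sample 10 users with at least 1 instance from this node's shards
req = scheduler_server_pb2.Request(
    node_idx=0, sample_num=10, shard_num=30, node_num=1,
    min_ins_num=1, date="20200101")
info = stub.SampleUsersToTrain(req)
for user in info.inst_nums:
    print(user.uid, user.inst_num)
```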
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import os
import sys
import numpy as np
from utils.logger import logging
import json
PAD_SYMBOL, UNK_SYMBOL = 0, 1
DATA_PATH = "lm_data"
VOCAB_PATH = os.path.join(DATA_PATH, "vocab.json")
TRAIN_DATA_PATH = os.path.join(DATA_PATH, "20200101", "train_data.json")
VOCAB = None
def build_counter(train_data):
train_tokens = []
for u in train_data:
for c in train_data[u]['x']:
train_tokens.extend([s for s in c])
all_tokens = []
for i in train_tokens:
all_tokens.extend(i)
train_tokens = []
counter = collections.Counter()
counter.update(all_tokens)
all_tokens = []
return counter
def build_vocab(filename, vocab_size=10000):
train_data = {}
with open(filename) as json_file:
data = json.load(json_file)
train_data = data['user_data']
counter = build_counter(train_data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
# -2 to account for the unknown and pad symbols
count_pairs = count_pairs[:(vocab_size - 2)]
words, _ = list(zip(*count_pairs))
vocab = {}
vocab['<PAD>'] = PAD_SYMBOL
vocab['<UNK>'] = UNK_SYMBOL
for i, w in enumerate(words):
if w != '<PAD>':
vocab[w] = i + 1
return {
'vocab': vocab,
'size': vocab_size,
'unk_symbol': vocab['<UNK>'],
'pad_symbol': vocab['<PAD>']
}
def save_vocab(filename, vocab):
with open(filename, "w") as f:
f.write(json.dumps(vocab))
def load_vocab(filename):
with open(filename) as f:
return json.loads(f.read())
if os.path.exists(VOCAB_PATH):
logging.info("load vocab form: {}".format(VOCAB_PATH))
VOCAB = load_vocab(VOCAB_PATH)
else:
#TODO: singleton
logging.info("build vocab form: {}".format(TRAIN_DATA_PATH))
VOCAB = build_vocab(TRAIN_DATA_PATH)
logging.info("save vocab into: {}".format(VOCAB_PATH))
save_vocab(VOCAB_PATH, VOCAB)
if VOCAB is None:
logging.error("load vocab error")
raise Exception("load vocab error")
def train_reader(lines):
def local_iter():
seg_id = 0
for line in lines:
assert (len(line.split("\t")) == 3)
uid, _, input_str = line.split("\t")
data = json.loads(input_str)
data_x = data["x"]
data_y = data["y"]
data_mask = data["mask"]
input_data, input_length = process_x(data_x, VOCAB)
target_data = process_y(data_y, VOCAB)
yield [input_data] + [target_data]
return local_iter
def infer_reader(lines):
return train_reader(lines)
def load_data_into_patch(filelist, patch_size):
data_patch = []
idx = 0
local_user_dict = {}
for fn in filelist:
tmp_list = []
with open(fn) as fin:
raw_data = json.loads(fin.read())["user_data"]
local_user_dict = {k: 0 for k in raw_data.keys()}
for user, data in raw_data.items():
data_x = data["x"]
data_y = data["y"]
for c, l in zip(data_x, data_y):
for inst_i in range(len(c)):
local_user_dict[user] += 1
idx += 1
inst = {
"x": c[inst_i],
"y": l["target_tokens"][inst_i],
"mask": l["count_tokens"][inst_i]
}
line = "{}\t\t{}".format(user, json.dumps(inst))
if idx % patch_size == 0:
data_patch.append(tmp_list)
tmp_list = [line]
else:
tmp_list.append(line)
if len(tmp_list) > 0:
data_patch.append(tmp_list)
return data_patch, local_user_dict
def tokens_to_ids(tokens, vocab):
to_ret = [vocab.get(word, vocab["<UNK>"]) for word in tokens]
return np.array(to_ret, dtype="int64")
def process_x(raw_x, vocab):
tokens = tokens_to_ids(raw_x, vocab["vocab"])
lengths = np.sum(tokens != vocab["pad_symbol"])
return tokens, lengths
def process_y(raw_y, vocab):
tokens = tokens_to_ids(raw_y, vocab["vocab"])
return tokens
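To make the reader's id mapping concrete, here is how `process_x` and `process_y` behave on a tiny hypothetical vocabulary (this snippet is illustrative only and not part of the reader):

```python
toy_vocab = {
    "vocab": {"<PAD>": 0, "<UNK>": 1, "the": 2, "cat": 3},
    "size": 4,
    "unk_symbol": 1,
    "pad_symbol": 0,
}
tokens, length = process_x(["the", "cat", "sat", "<PAD>"], toy_vocab)
print(tokens)   # [2 3 1 0] -- "sat" is OOV, so it maps to <UNK>
print(length)   # 3         -- trailing <PAD> tokens are not counted
print(process_y(["cat", "the", "<PAD>", "<PAD>"], toy_vocab))  # [3 2 0 0]
```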
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .uniform_sampler import UniformSampler
from .fixed_sampler import FixedSampler
from .sampler_test_user import Test1wSampler
from .sampler_test_user import Test1percentSampler
from .hash_sampler import HashSampler
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sampler_base import SamplerBase
class FixedSampler(SamplerBase):
def __init__(self):
self.sample_num = 100
def set_sample_num(self, sample_num):
self.sample_num = sample_num
def sample_user_list(self, scheduler_client, date, sim_idx, shard_num,
sim_num):
user_info = scheduler_client.fixed_sample_user_list(
date, sim_idx, self.sample_num, shard_num, sim_num)
return user_info
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sampler_base import SamplerBase
class HashSampler(SamplerBase):
def __init__(self):
self.sample_num = 100
def set_sample_num(self, sample_num):
self.sample_num = sample_num
def sample_user_list(self, scheduler_client, date, sim_idx, shard_num,
sim_num):
user_info = scheduler_client.hash_sample_user_list(
date, sim_idx, self.sample_num, shard_num, sim_num)
return user_info
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class SamplerBase(object):
def __init__(self, role_maker):
pass
def sample_user_list(self, scheduler_client, sim_idx, shard_num, sim_num,
min_ins_num):
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sampler_base import SamplerBase
class Test1wSampler(SamplerBase):
def __init__(self):
pass
def sample_user_list(self, scheduler_client, date, sim_idx, shard_num,
sim_num):
user_info = scheduler_client.sample_test_user_list(date, sim_idx,
shard_num, sim_num)
return user_info
class Test1percentSampler(SamplerBase):
def __init__(self):
pass
def sample_user_list(self, scheduler_client, date, sim_idx, shard_num,
sim_num):
user_info = scheduler_client.sample_test_user_list(date, sim_idx,
shard_num, sim_num)
return user_info
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .sampler_base import SamplerBase
class UniformSampler(SamplerBase):
def __init__(self):
self.sample_num = 100
self.min_ins_num = 1
def set_sample_num(self, sample_num):
self.sample_num = sample_num
def set_min_ins_num(self, min_ins_num):
self.min_ins_num = min_ins_num
def sample_user_list(self, scheduler_client, date, sim_idx, shard_num,
sim_num):
user_info = scheduler_client.uniform_sample_user_list(
date, sim_idx, self.sample_num, shard_num, sim_num,
self.min_ins_num)
return user_info
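All samplers share the `sample_user_list` interface, so a simulator can swap sampling strategies without touching the rest of the training loop. An illustrative sketch (assuming the package directory is named `samplers` and `scheduler_client` is an already-initialized scheduler client):

```python
from samplers import UniformSampler  # hypothetical package path

sampler = UniformSampler()
sampler.set_sample_num(50)    # draw 50 users per round
sampler.set_min_ins_num(5)    # skip users with fewer than 5 instances
user_info = sampler.sample_user_list(
    scheduler_client, date="20200101", sim_idx=0, shard_num=30, sim_num=1)
```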
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .data_server_impl import DataServerServicer
from .data_server_impl import DataServer
from .scheduler_server_impl import SchedulerServerServicer
from .scheduler_server_impl import SchedulerServer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from concurrent import futures
import data_server_pb2
import data_server_pb2_grpc
import grpc
import sys
from utils.logger import logging
# Python 2/3 compatibility: `unicode` does not exist under Python 3.
try:
    unicode
except NameError:
    unicode = str
class DataServerServicer(object):
def __init__(self):
self.data_dict = {}
self.param_dict = {}
self.request_num = 0
def GetUserParams(self, request, context):
uid = unicode(request.uid)
if uid in self.param_dict:
return self.param_dict[uid]
else:
user_param = data_server_pb2.UserParams()
user_param.err_code = -1
return user_param
def UpdateUserParams(self, request, context):
self.param_dict[request.uid] = request
res = data_server_pb2.Res()
res.err_code = 0
return res
def ClearUserData(self, request, context):
date = request.date
self.data_dict[date].clear()
res = data_server_pb2.Res()
res.err_code = 0
return res
    def GetUserData(self, request, context):
        uid = unicode(request.uid)
        date = unicode(request.date)
        if date not in self.data_dict:
            user_data = data_server_pb2.UserData()
            user_data.err_code = -1
            return user_data
        if uid in self.data_dict[date]:
            self.data_dict[date][uid].err_code = 0
            return self.data_dict[date][uid]
        else:
            # cache an empty entry so repeated misses are cheap
            user_data = data_server_pb2.UserData()
            user_data.err_code = -1
            self.data_dict[date][uid] = user_data
            return self.data_dict[date][uid]
def SendData(self, request, context):
date = unicode(request.date)
uid = unicode(request.uid)
        if date in self.data_dict:
            if uid in self.data_dict[date]:
self.data_dict[date][uid].line_str.extend([request.line])
else:
user_data = data_server_pb2.UserData()
self.data_dict[date][uid] = user_data
self.data_dict[date][uid].line_str.extend([request.line])
else:
user_data = data_server_pb2.UserData()
self.data_dict[date] = {}
self.data_dict[date][uid] = user_data
self.data_dict[date][uid].line_str.extend([request.line])
res = data_server_pb2.Res()
res.err_code = 0
return res
class DataServer(object):
def __init__(self):
pass
def start(self, max_workers=1000, concurrency=100, endpoint=""):
if endpoint == "":
logging.error("You should specify endpoint in start function")
return
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
options=[('grpc.max_send_message_length', 1024 * 1024 * 1024),
('grpc.max_receive_message_length', 1024 * 1024 * 1024)],
maximum_concurrent_rpcs=concurrency)
data_server_pb2_grpc.add_DataServerServicer_to_server(
DataServerServicer(), server)
server.add_insecure_port('[::]:{}'.format(endpoint))
server.start()
server.wait_for_termination()
if __name__ == "__main__":
data_server = DataServer()
data_server.start(endpoint=sys.argv[1])
# -*- coding: utf-8 -*-
from __future__ import print_function
from concurrent import futures
import scheduler_server_pb2
import scheduler_server_pb2_grpc
import grpc
import numpy as np
import random
import sys, os
import time
import xxhash
from utils.logger import logging
class SchedulerServerServicer(object):
def __init__(self):
self.global_param_dict = {}
self.uid_inst_num_dict = {}
self.shard_id_dict = {}
def uid_shard(self, uid, shard_num):
        try:
            uid_hash = xxhash.xxh32(str(uid), seed=101).intdigest()
        except Exception:
            return -1
shard_idx = uid_hash % shard_num
return shard_idx
def is_test_uid(self, uid):
return xxhash.xxh32(str(uid), seed=222).intdigest() % 100 == 3
    # We assume the shard num stays fixed within one training job but may
    # differ across jobs, so the shard num is sent with every
    # UpdateUserInstNum request.
def UpdateUserInstNum(self, request, context):
shard_num = request.shard_num
date = request.date
if date not in self.uid_inst_num_dict:
self.uid_inst_num_dict[date] = {}
if date not in self.shard_id_dict:
self.shard_id_dict[date] = {}
for user in request.inst_nums:
shard_id = self.uid_shard(user.uid, shard_num)
if shard_id == -1:
logging.info("UpdateUserInstNum continue")
continue
if user.uid in self.uid_inst_num_dict[date]:
self.uid_inst_num_dict[date][user.uid] += user.inst_num
else:
self.uid_inst_num_dict[date][user.uid] = user.inst_num
if shard_id not in self.shard_id_dict[date]:
self.shard_id_dict[date][shard_id] = [user.uid]
else:
self.shard_id_dict[date][shard_id].append(user.uid)
res = scheduler_server_pb2.Res()
res.err_code = 0
return res
    '''
    SampleUsersToTrain:
    request.node_idx: which worker node the request comes from
    request.sample_num: how many users to sample
    request.shard_num: total shard number of this task
    request.node_num: total number of training nodes
    request.min_ins_num: skip users with fewer instances than this
    request.date: which day's data to sample from
    '''
def SampleUsersToTrain(self, request, context):
node_idx = request.node_idx
sample_num = request.sample_num
shard_num = request.shard_num
node_num = request.node_num
date = request.date
        shard_per_node = shard_num // node_num
        begin_idx = node_idx * shard_per_node
        min_ins_num = request.min_ins_num
        uid_list = []
        i = 0
        while i < sample_num:
            # randint is inclusive on both ends, so subtract one to stay
            # inside this node's shard range
            shard_idx = begin_idx + random.randint(0, shard_per_node - 1)
if shard_idx not in self.shard_id_dict[date]:
continue
sample_idx = random.randint(
0, len(self.shard_id_dict[date][shard_idx]) - 1)
uid = self.shard_id_dict[date][shard_idx][sample_idx]
if self.uid_inst_num_dict[date][uid] < min_ins_num:
continue
uid_list.append(uid)
i += 1
info = scheduler_server_pb2.UserInstInfo()
for uid in uid_list:
inst_num = scheduler_server_pb2.UserInstNum()
inst_num.uid = uid
inst_num.inst_num = self.uid_inst_num_dict[date][uid]
info.inst_nums.extend([inst_num])
return info
def SampleUsersWithHash(self, request, context):
node_idx = request.node_idx
sample_num = request.sample_num
shard_num = request.shard_num
node_num = request.node_num
date = request.date
        shard_per_node = shard_num // node_num
        begin_idx = node_idx * shard_per_node
        uid_list = []
        i = 0
        while i < sample_num:
            shard_idx = begin_idx + random.randint(0, shard_per_node - 1)
if shard_idx not in self.shard_id_dict[date]:
continue
sample_idx = random.randint(
0, len(self.shard_id_dict[date][shard_idx]) - 1)
uid = self.shard_id_dict[date][shard_idx][sample_idx]
if not self.is_test_uid(uid):
continue
uid_list.append(uid)
i += 1
info = scheduler_server_pb2.UserInstInfo()
for uid in uid_list:
inst_num = scheduler_server_pb2.UserInstNum()
inst_num.uid = uid
inst_num.inst_num = self.uid_inst_num_dict[date][uid]
info.inst_nums.extend([inst_num])
return info
def SampleUsersToTest(self, request, context):
node_idx = request.node_idx
shard_num = request.shard_num
node_num = request.node_num
date = request.date
        shard_per_node = shard_num // node_num
        shard_begin_idx = node_idx * shard_per_node
        shard_end_idx = (node_idx + 1) * shard_per_node
        uid_list = []
        for shard_idx in range(shard_begin_idx, shard_end_idx):
            # skip shards that received no users for this date
            if shard_idx not in self.shard_id_dict[date]:
                continue
            for uid in self.shard_id_dict[date][shard_idx]:
if self.is_test_uid(uid):
uid_list.append(uid)
info = scheduler_server_pb2.UserInstInfo()
for uid in uid_list:
inst_num = scheduler_server_pb2.UserInstNum()
inst_num.uid = uid
inst_num.inst_num = self.uid_inst_num_dict[date][uid]
info.inst_nums.extend([inst_num])
return info
    def FixedUsersToTrain(self, request, context):
        sample_num = request.sample_num
        date = request.date
        uid_list = []
        # a fixed roster of at most 100 users, read from a prepared file
        assert (sample_num <= 100)
        with open("data/test_user_100.txt") as f:
            for line in f.readlines():
                uid_list.append(line.strip())
info = scheduler_server_pb2.UserInstInfo()
for uid in uid_list[:sample_num]:
inst_num = scheduler_server_pb2.UserInstNum()
inst_num.uid = uid
inst_num.inst_num = self.uid_inst_num_dict[date][uid]
info.inst_nums.extend([inst_num])
return info
def GetGlobalParams(self, request, context):
if self.global_param_dict == {}:
logging.debug("global param has not been initialized")
return
# logging.info("node {} is asking for global params".format(request.node_idx))
global_param_pb = scheduler_server_pb2.GlobalParams()
for key in self.global_param_dict:
param = scheduler_server_pb2.Param()
param.name = key
var, shape = self.global_param_dict[key]
param.weight.extend(var)
# print("GetGlobalParams before param.shape")
param.shape.extend(shape)
# print("GetGlobalParams after param.shape")
global_param_pb.global_params.extend([param])
# print("finish GetGlobalParams")
return global_param_pb
def UpdateGlobalParams(self, request, context):
for param in request.global_params:
self.global_param_dict[param.name] = [param.weight, param.shape]
res = scheduler_server_pb2.Res()
res.err_code = 0
return res
def FedAvgUpdate(self, request, context):
for param in request.global_params:
old_param, shape = self.global_param_dict[param.name]
for idx, item in enumerate(old_param):
old_param[idx] += param.weight[idx]
res = scheduler_server_pb2.Res()
res.err_code = 0
return res
def Exit(self, request, context):
with open("_shutdown_scheduler", "w") as f:
f.write("_shutdown_scheduler\n")
res = scheduler_server_pb2.Res()
res.err_code = 0
return res
class SchedulerServer(object):
def __init__(self):
pass
def start(self, max_workers=1000, concurrency=100, endpoint=""):
if endpoint == "":
logging.info("You should specify endpoint in start function")
return
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=max_workers),
options=[('grpc.max_send_message_length', 1024 * 1024 * 1024),
('grpc.max_receive_message_length', 1024 * 1024 * 1024)],
maximum_concurrent_rpcs=concurrency)
scheduler_server_pb2_grpc.add_SchedulerServerServicer_to_server(
SchedulerServerServicer(), server)
# print("SchedulerServer add endpoint: ", '[::]:{}'.format(endpoint))
server.add_insecure_port('[::]:{}'.format(endpoint))
server.start()
logging.info("server started")
os.system("rm _shutdown_scheduler")
while (not os.path.isfile("_shutdown_scheduler")):
time.sleep(10)
if __name__ == "__main__":
scheduler_server = SchedulerServer()
scheduler_server.start(endpoint=60001)
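Note that `FedAvgUpdate` above only accumulates element-wise sums onto the stored global parameters; for this to realize FedAvg, each trainer is expected to upload a delta already scaled by its share of the total sample count. A hedged sketch of that convention (all names here are hypothetical, not part of this file):

```python
import numpy as np

def scaled_delta(local_params, old_global, user_sample_num, total_sample_num):
    """Scale one trainer's parameter delta so that server-side summation
    in FedAvgUpdate yields the sample-weighted FedAvg update."""
    weight = float(user_sample_num) / total_sample_num
    return {name: (np.asarray(local_params[name]) -
                   np.asarray(old_global[name])) * weight
            for name in local_params}
```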
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .language_model_trainer import LanguageModelTrainer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .trainer_base import TrainerBase
from model import LanguageModel
from clients import DataClient
import paddle.fluid as fluid
from utils.hdfs_utils import multi_upload, HDFSClient
import reader.leaf_reddit_reader as reader
from utils.logger import logging
from itertools import groupby
import numpy as np
import random
import paddle
import pickle
import os
from model.model_base import set_user_param_dict
from model.model_base import set_global_param_dict
def train_one_user(arg_dict, trainer_config):
show_metric = trainer_config["show_metric"]
shuffle = trainer_config["shuffle"]
max_training_steps = trainer_config["max_training_steps"]
batch_size = trainer_config["batch_size"]
# logging.info("training one user...")
main_program = fluid.Program.parse_from_string(trainer_config[
"main_program_desc"])
startup_program = fluid.Program.parse_from_string(trainer_config[
"startup_program_desc"])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
if (startup_program is None):
logging.error("startup_program is None")
exit()
exe.run(startup_program)
feeder = fluid.DataFeeder(
feed_list=trainer_config["input_names"],
place=place,
program=main_program)
data_server_endpoints = arg_dict["data_endpoints"]
# create data clients
data_client = DataClient()
data_client.set_data_server_endpoints(data_server_endpoints)
uid = arg_dict["uid"]
date = arg_dict["date"]
global_param_dict = arg_dict["global_params"]
user_data = data_client.get_data_by_uid(uid, date)
train_reader = reader.train_reader(user_data)
    if shuffle:
train_reader = paddle.reader.shuffle(train_reader, buf_size=10000)
train_reader = paddle.batch(train_reader, batch_size=batch_size)
# get user param
# logging.debug("do not need to get user params")
set_global_param_dict(arg_dict["global_param_names"],
arg_dict["global_params"], scope)
if (main_program is None):
logging.error("main_program is None")
exit()
epoch = trainer_config["epoch"]
max_steps_in_epoch = trainer_config.get("max_steps_in_epoch", -1)
metrics = trainer_config["metrics"]
metric_keys = metrics.keys()
fetch_list = [main_program.global_block().var(trainer_config["loss_name"])]
for key in metric_keys:
fetch_list.append(main_program.global_block().var(metrics[key]))
seq_len = 10
for ei in range(epoch):
trained_sample_num = 0
step = 0
fetch_res_list = []
total_loss = 0.0
total_correct = 0
for data in train_reader():
fetch_res = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=fetch_list)
step += 1
trained_sample_num += len(data)
fetch_res_list.append([x[0] for x in fetch_res])
if max_steps_in_epoch != -1 and step >= max_steps_in_epoch:
break
if show_metric and trained_sample_num > 0:
loss = sum([x[0] for x in fetch_res_list]) / trained_sample_num
print("loss: {}, ppl: {}".format(loss, np.exp(loss)))
for i, key in enumerate(metric_keys):
if key == "correct":
value = float(sum([x[i + 1] for x in fetch_res_list
])) / trained_sample_num
print("correct: {}".format(value / seq_len))
local_updated_param_dict = {}
# update user param
# logging.debug("do not need to update user params")
data_client.set_param_by_uid(uid, local_updated_param_dict)
    write_global_param_file = arg_dict["write_global_param_file"]
    for var_name in arg_dict["global_param_names"]:
        var = scope.var(var_name).get_tensor().__array__().astype(np.float32)
        filename = os.path.join(write_global_param_file, "params", var_name)
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        # binary mode keeps np.save/pickle portable across Python 2 and 3
        with open(filename, "wb") as f:
            np.save(f, var)
    with open("%s/_info" % write_global_param_file, "wb") as f:
        pickle.dump([uid, trained_sample_num], f)
def infer_one_user(arg_dict, trainer_config):
"""
infer a model with global_param and user params
input:
global_param
user_params
infer_program
user_data
output:
        [sample_count, top1]
"""
# run startup program, set params
uid = arg_dict["uid"]
batch_size = trainer_config["batch_size"]
startup_program = fluid.Program.parse_from_string(trainer_config[
"startup_program_desc"])
infer_program = fluid.Program.parse_from_string(trainer_config[
"infer_program_desc"])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
if (startup_program is None):
logging.error("startup_program is None")
exit()
if (infer_program is None):
logging.error("infer_program is None")
exit()
exe.run(startup_program)
data_client = DataClient()
data_client.set_data_server_endpoints(arg_dict["data_endpoints"])
# get user param
# logging.debug("do not need to get user params")
set_global_param_dict(arg_dict["global_param_names"],
arg_dict["global_params"], scope)
# reader
date = arg_dict["date"]
global_param_dict = arg_dict["global_params"]
user_data = data_client.get_data_by_uid(uid, date)
infer_reader = reader.infer_reader(user_data)
infer_reader = paddle.batch(infer_reader, batch_size=batch_size)
# run infer program
    if not os.path.exists(arg_dict["infer_result_dir"]):
        os.makedirs(arg_dict["infer_result_dir"])
feeder = fluid.DataFeeder(
feed_list=trainer_config["input_names"],
place=place,
program=infer_program)
fetch_list = trainer_config["target_names"]
#logging.info("fetch_list: {}".format(fetch_list))
fetch_res = []
sample_count = 0
total_loss = 0.0
total_correct = 0
iters = 0
steps = 0
seq_len = 10
    for data in infer_reader():
        pred, correct_count, loss = exe.run(infer_program,
                                            feed=feeder.feed(data),
                                            fetch_list=fetch_list)
total_loss += loss
total_correct += correct_count
steps += 1
sample_count += len(data)
correct = float(total_correct) / (seq_len * sample_count)
# logging.info("correct: {}".format(correct))
with open(arg_dict["infer_result_dir"] + "/res", "w") as f:
f.write("%d\t%f\n" % (1, correct))
def save_and_upload(arg_dict, trainer_config, dfs_upload_path):
logging.info("do not save and upload...")
return
def evaluate_a_group(group):
group_list = []
for label, pred, _ in group:
# print("%s\t%s\n" % (label, pred))
group_list.append((int(label), float(pred)))
random.shuffle(group_list)
labels = [x[0] for x in group_list]
preds = [x[1] for x in group_list]
true_res = labels.index(1) if 1 in labels else -1
pred_res = preds.index(max(preds))
if pred_res == true_res:
return 1
else:
return 0
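`evaluate_a_group` scores a single candidate group; `itertools.groupby` (imported above) is the natural way to feed it. An illustrative sketch on hypothetical `(label, pred, group_id)` triples:

```python
triples = [("1", "0.9", "q1"), ("0", "0.4", "q1"),
           ("0", "0.7", "q2"), ("1", "0.2", "q2")]
triples.sort(key=lambda x: x[2])  # groupby needs sorted input
hits = [evaluate_a_group(list(g))
        for _, g in groupby(triples, key=lambda x: x[2])]
print(float(sum(hits)) / len(hits))  # 0.5: only q1's top prediction is correct
```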
class LanguageModelTrainer(TrainerBase):
"""
    LanguageModelTrainer only supports training with PaddlePaddle.
"""
def __init__(self):
super(LanguageModelTrainer, self).__init__()
self.main_program_ = fluid.Program()
self.startup_program_ = fluid.Program()
self.infer_program_ = fluid.Program()
self.main_program_desc_ = ""
self.startup_program_desc_ = ""
self.infer_program_desc_ = ""
self.train_one_user_func = train_one_user
self.infer_one_user_func = infer_one_user
self.save_and_upload_func = save_and_upload
self.input_model_ = None
def get_load_data_into_patch_func(self):
return reader.load_data_into_patch
    def prepare(self, do_test=False):
        self.generate_program_desc(do_test)
def get_user_param_names(self):
# return [x[0] for x in self.input_model_.get_user_param_names()]
pass
def get_global_param_names(self):
return [x[0] for x in self.input_model_.get_global_param_names()]
def generate_program_desc(self, do_test=False):
"""
generate the paddle program desc
"""
with fluid.program_guard(self.main_program_, self.startup_program_):
self.input_model_ = LanguageModel()
model_configs = {}
self.input_model_.build_model(model_configs)
optimizer = fluid.optimizer.SGD(
learning_rate=self.trainer_config["lr"])
optimizer.minimize(self.input_model_.get_model_loss())
self.main_program_desc_ = self.main_program_.desc.serialize_to_string()
self.startup_program_desc_ = self.startup_program_.desc.serialize_to_string(
)
self.update_trainer_configs("loss_name",
self.input_model_.get_model_loss_name())
self.update_trainer_configs(
"input_names",
self.input_model_.get_model_input_names(), )
self.update_trainer_configs(
"target_names",
self.input_model_.get_target_names(), )
self.update_trainer_configs(
"metrics",
self.input_model_.get_model_metrics(), )
self.update_trainer_configs("show_metric", True)
self.update_trainer_configs("max_training_steps", "inf")
self.update_trainer_configs("shuffle", False)
self.update_trainer_configs("main_program_desc",
self.main_program_desc_)
self.update_trainer_configs("startup_program_desc",
self.startup_program_desc_)
if do_test:
input_names = self.input_model_.get_model_input_names()
target_var_names = self.input_model_.get_target_names()
self.infer_program_ = self.main_program_._prune_with_input(
feeded_var_names=input_names, targets=target_var_names)
self.infer_program_ = self.infer_program_._inference_optimize(
prune_read_op=True)
fluid.io.prepend_feed_ops(self.infer_program_, input_names)
fluid.io.append_fetch_ops(self.infer_program_, target_var_names)
self.infer_program_.desc._set_version()
fluid.core.save_op_compatible_info(self.infer_program_.desc)
self.infer_program_desc_ = self.infer_program_.desc.serialize_to_string(
)
self.update_trainer_configs("infer_program_desc",
self.infer_program_desc_)
def init_global_model(self, scheduler_client):
logging.info("initializing global model")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(self.startup_program_)
logging.info("finish initializing global model")
global_param_dict = self.input_model_.get_global_param_dict()
scheduler_client.update_global_params(global_param_dict)
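Putting the pieces together, a minimal driver might look like the sketch below. The config keys beyond `lr`, `batch_size`, and `epoch` are filled in by `generate_program_desc`, and `scheduler_client` is assumed to be an already-initialized scheduler client:

```python
trainer = LanguageModelTrainer()
trainer.set_trainer_configs({
    "lr": 0.1,          # read by the SGD optimizer in generate_program_desc
    "batch_size": 32,
    "epoch": 1,
})
trainer.prepare(do_test=True)                 # builds train + infer descs
trainer.init_global_model(scheduler_client)   # pushes initial global params
```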
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainerBase(object):
def __init__(self):
self.trainer_config = None
self.train_one_user_func = None
self.infer_one_user_func = None
self.save_and_upload_func = None
def get_user_param_names(self):
pass
def get_global_param_names(self):
pass
def set_trainer_configs(self, trainer_configs):
"""
        Configure training parameters; only basic Python types are supported.
"""
self.trainer_config = trainer_configs
def update_trainer_configs(self, key, val):
self.trainer_config[key] = val
def prepare(self):
"""
        generate the network description string
"""
pass
def init_global_model(self, scheduler_client):
"""
        initialize the network parameters, which will be broadcast to all simulators
"""
pass
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .role_maker import FLSimRoleMaker
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import logging
logging.basicConfig(
stream=sys.stdout,
format='%(asctime)s %(filename)s : %(levelname)s %(message)s',
level=logging.DEBUG)
logger = logging.getLogger(__name__)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from utils.logger import logging
class FLSimRoleMaker(object):
def __init__(self):
from mpi4py import MPI
self.MPI = MPI
self._comm = MPI.COMM_WORLD
def init_env(self, local_shard_num=1):
if (local_shard_num > 250):
logging.error("shard num must be less than or equal to 250")
exit()
self.free_ports = self._get_free_endpoints(local_shard_num, 40000)
import socket
ip = socket.gethostbyname(socket.gethostname())
self.free_endpoints = ["{}:{}".format(ip, x) for x in self.free_ports]
self._comm.barrier()
self.node_type = 1
if self._comm.Get_rank() == 0:
self.node_type = 0
self.group_comm = self._comm.Split(self.node_type)
self.all_endpoints = self._comm.allgather(self.free_endpoints)
def simulator_num(self):
return self._comm.Get_size() - 1
def simulator_idx(self):
return self._comm.Get_rank() - 1
def get_global_scheduler_endpoint(self):
return self.all_endpoints[0][0]
    def get_data_server_endpoints(self):
        all_endpoints = []
        for eps in self.all_endpoints[1:]:
            # the first half of each simulator's free ports serve data
            all_endpoints.extend(eps[:len(eps) // 2])
        return all_endpoints
def get_local_data_server_endpoint(self):
if self._comm.Get_rank() < 1:
return None
local_endpoints = self.all_endpoints[self._comm.Get_rank()]
        return local_endpoints[:len(local_endpoints) // 2]
def get_local_param_server_endpoint(self):
if self._comm.Get_rank() < 1:
return None
local_endpoints = self.all_endpoints[self._comm.Get_rank()]
        return local_endpoints[len(local_endpoints) // 2:]
def is_global_scheduler(self):
rank = self._comm.Get_rank()
return rank == 0
def is_simulator(self):
return self._comm.Get_rank() > 0
def barrier_simulator(self):
if self._comm.Get_rank() > 0:
self.group_comm.barrier()
def barrier(self):
self.group_comm.barrier()
def _get_free_endpoints(self, local_shard_num, start_port):
import psutil
conns = psutil.net_connections()
x = [conn.laddr.port for conn in conns]
free_endpoints = []
step = 500
start_range = start_port + self._comm.Get_rank() * step
for i in range(start_range, start_range + step, 1):
if i in x:
continue
if i > 65535:
continue
else:
free_endpoints.append(i)
if len(free_endpoints) == local_shard_num * 2:
break
return free_endpoints
endpoint=("50050" "50051" "50052" "50053" "50054" "50055" "50056" "50057" "50058" "50059")
for i in {0..9}
do
python data_server_impl.py ${endpoint[$i]} &
done
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from role_maker import FLSimRoleMaker
role_maker = FLSimRoleMaker()
role_maker.init_env(local_shard_num=30)
print("simulator num: {}".format(role_maker.simulator_num()))
print("simulator idx: {}".format(role_maker.simulator_idx()))
print("global scheduler endpoint: {}".format(
role_maker.get_global_scheduler_endpoint()))
print("data server endpoints")
print(role_maker.get_data_server_endpoints())
print("local data server")
print(role_maker.get_local_data_server_endpoint())
print("local param server")
print(role_maker.get_local_param_server_endpoint())
role_maker.barrier_simulator()