Unverified commit 539c8707, authored by zhang wenhui, committed by GitHub

add fl_listen_and_serv &fl_transpiler,test=develop (#19091)

Add the fl_listen_and_serv op for federated learning; fl_distribute_transpiler inserts this op into the pserver program. The op simply listens on the endpoint and performs the sum & scale aggregation.
Parent 5368b365
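For orientation, here is a minimal, hand-written sketch of how a pserver program could attach this operator, based on the input and attributes declared by FlListenAndServOpMaker below (X, endpoint, Fanin, sync_mode, optimize_blocks). The variable name, shape, and block layout are illustrative only; in practice fl_distribute_transpiler generates this program.

import paddle.fluid as fluid

pserver_program = fluid.Program()

# Block 1: the aggregation (sum & scale) logic; fl_distribute_transpiler
# normally fills in its ops, so it is left empty in this sketch.
optimize_block = pserver_program._create_block(
    pserver_program.global_block().idx)

# A parameter variable received from the trainers (illustrative name/shape).
param = pserver_program.global_block().create_var(
    name="fc_0.w_0", shape=[1, 1], dtype="float32", persistable=True)

pserver_program.global_block().append_op(
    type="fl_listen_and_serv",
    inputs={"X": [param]},
    outputs={},
    attrs={
        "endpoint": "127.0.0.1:6164",
        "Fanin": 2,  # how many trainers send to this server
        "sync_mode": True,
        "optimize_blocks": [optimize_block],
    })

Running such a program with a fluid.Executor blocks inside RunSyncLoop, serving trainers until the server is asked to exit (see the test at the bottom of this commit).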
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>     // for removing the port file
#include <sys/time.h>  // for gettimeofday(), used by GetTimestamp()
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread> // NOLINT
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(flrpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(flrpc_get_thread_num, 12, "number of threads for rpc get");
namespace paddle {
namespace operators {
void FlRunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
}
static void flsplit(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
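// Run each of the given prepared blocks asynchronously on the framework
// thread pool and wait for all of them to finish.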
static void FlParallelExecuteBlocks(
const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
&prepared,
framework::ProgramDesc *program, framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
int run_block = idx; // thread local
try {
VLOG(3) << "running server block: " << run_block
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
LOG(FATAL) << "run sub program:" << idx << " error " << e.what();
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
FlListenAndServOp::FlListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
FlListenAndServOp::~FlListenAndServOp() {}
void FlListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port
rpc_service_->SavePort();
}
static int64_t GetTimestamp() {
struct timeval tp;
gettimeofday(&tp, NULL);
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
void FlListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
platform::DeviceContext *dev_ctx) const {
VLOG(2) << "RunSyncLoop";
size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
// Prepare all the server blocks.
std::vector<int> optimize_blocks_list;
for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i);
}
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert a placeholder for block 0, which holds the current op itself;
// NOTE: the first entry in `optimize_prepared` should never be run.
optimize_prepared.insert(
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
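// Each round of the synchronous FL loop below works as follows:
//   1. open kRequestGet and wait on its barrier, so every trainer fetches
//      the current parameters from this pserver;
//   2. open kRequestSend and wait on its barrier, so every trainer uploads
//      its locally updated parameters;
//   3. run the optimize blocks, which aggregate (sum & scale) the received
//      parameters before the next round.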
while (true) {
// Receive from multiple trainers; we don't care about the order in which
// the gradients arrive, just add suffixes 0~n and merge the gradients.
VLOG(3) << "wait all clients to get pserver parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
if (rpc_service_->IsExit()) {
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
VLOG(3) << "wait all clients to send after_optimizer parameters";
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
// NOTE: if is_gpu_place, CUDA kernels are launched from multiple threads
// and this still works.
// Optimize blocks that share the same parent ID run in parallel.
// TODO(Yancey1989): switch to ParallelExecutor in the future.
int32_t last_parent_blkid = optimize_blocks[0]->Parent();
std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(optimize_blocks[0]->ID());
double ts = GetTimestamp();
for (size_t i = 1; i < optimize_blocks.size(); ++i) {
// skip the first optimize block because it is already in the
// parallel_blkids.
int blkid = optimize_blocks[i]->ID();
if (program->Block(blkid).Parent() != last_parent_blkid) {
FlParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent();
}
parallel_blkids.push_back(blkid);
}
FlParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
program, recv_scope);
VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
} // while(true)
}
static void FillRequestCtx(distributed::RequestHandler *h,
framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetRPCServer(rpc_server);
}
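// RunImpl wires everything together: it creates the RPC server on the
// configured endpoint, registers the Send/Get request handlers, starts the
// server thread, saves the selected port for Python, and then enters
// RunSyncLoop until the server is asked to exit.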
void FlListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this process as a PS: whether to profile is decided by listening to the trainer.
platform::SetProfileListener();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
auto inputs = Inputs("X");
PADDLE_ENFORCE_EQ(!rpc_service_, true, "rpc_service_ must be null");
std::string endpoint = Attr<std::string>("endpoint");
VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
request_send_handler_.reset(
new distributed::RequestSendHandler(sync_mode, false));
request_get_handler_.reset(
new distributed::RequestGetHandler(sync_mode, false));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
FLAGS_flrpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestGet,
request_get_handler_.get(),
FLAGS_flrpc_get_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
PADDLE_ENFORCE_GE(
optimize_blocks.size(), 1,
"there should be at least 1 optimize block on the pserver side.");
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
// Start the server listening only after all members are initialized.
server_thread_.reset(new std::thread(FlRunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, FlSignalHandler::StopAndExit);
signal(SIGTERM, FlSignalHandler::StopAndExit);
// Cache the type of the received vars as `sparse_vars_` and `dense_vars_`
// so that we can reset them at the end of each iteration.
// NOTE: only used in sync update
// Write the server's selected port to a file for Python to use.
SavePort();
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx);
}
class FlListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
" will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<std::vector<framework::BlockDesc *>>(
kOptimizeBlocks, "Optimize blocks to run on server side.")
.SetDefault({});
}
};
void FlSignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here: the device used for printing may already be released.
// exit() will release internally allocated resources.
exit(0);
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fl_listen_and_serv, ops::FlListenAndServOp,
ops::FlListenAndServOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks";
void FlRunServer(std::shared_ptr<distributed::RPCServer> service);
template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
public:
typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
return std::find_if(this->begin(), this->end(),
[&v](const std::pair<const TKey, TValue> &p) {
return p.second == v;
});
}
};
class FlListenAndServOp : public framework::OperatorBase {
public:
FlListenAndServOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~FlListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
platform::DeviceContext* dev_ctx) const;
void SavePort() const;
int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
protected:
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
mutable std::vector<std::string> dense_vars_;
};
class FlSignalHandler {
public:
static void StopAndExit(int signal_num);
private:
DISABLE_COPY_AND_ASSIGN(FlSignalHandler);
};
} // namespace operators
} // namespace paddle
@@ -1035,8 +1035,8 @@ class Operator(object):
     OP_WITHOUT_KERNEL_SET = {
         'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
         'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
-        'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id',
-        'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream',
+        'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
+        'gen_nccl_id', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream',
         'c_sync_comm_stream'
     }
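Adding fl_listen_and_serv to OP_WITHOUT_KERNEL_SET marks it as a kernel-less operator, so the executor runs it directly through FlListenAndServOp::RunImpl instead of selecting a compute kernel. A quick, illustrative check from Python:

from paddle.fluid.framework import Operator

# fl_listen_and_serv is now registered as a kernel-less operator
assert 'fl_listen_and_serv' in Operator.OP_WITHOUT_KERNEL_SET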
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program
import os
import signal
import subprocess
import time
import unittest
from multiprocessing import Process
from op_test import OpTest
import numpy
import urllib
import sys
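# Trainer side: build the same network, then deserialize the pre-generated
# recv/main/send programs and run a few recv -> train -> send rounds against
# the pserver.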
def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id):
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None)
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # loss function
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    # optimizer
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    sgd_optimizer.minimize(avg_cost)

    with open("trainer_recv_program.dms", "rb") as f:
        trainer_recv_program_desc_str = f.read()
    with open("trainer_main_program.dms", "rb") as f:
        trainer_main_program_desc_str = f.read()
    with open("trainer_send_program.dms", "rb") as f:
        trainer_send_program_desc_str = f.read()

    recv_program = Program.parse_from_string(trainer_recv_program_desc_str)
    main_program = Program.parse_from_string(trainer_main_program_desc_str)
    send_program = Program.parse_from_string(trainer_send_program_desc_str)

    trainer_startup_program = fluid.default_startup_program()

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(trainer_startup_program)
    for i in range(5):
        exe.run(recv_program)
        exe.run(main_program,
                feed={
                    "x": numpy.array([1, 2]).astype('float32').reshape(2, 1),
                    "y": numpy.array([2, 3]).astype('float32').reshape(2, 1)
                })
        exe.run(send_program)

def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
    x = fluid.layers.data(name='x', shape=[1], dtype='float32')
    y_predict = fluid.layers.fc(input=x, size=1, act=None)
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    # loss function
    cost = fluid.layers.square_error_cost(input=y_predict, label=y)
    avg_cost = fluid.layers.mean(cost)

    # optimizer
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
    sgd_optimizer.minimize(avg_cost)

    with open("pserver_startup_program.dms", "rb") as f:
        pserver_startup_program_desc_str = f.read()
    with open("pserver_main_program.dms", "rb") as f:
        pserver_main_program_desc_str = f.read()

    startup_program = Program.parse_from_string(
        pserver_startup_program_desc_str)
    main_program = Program.parse_from_string(pserver_main_program_desc_str)

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_program)
    exe.run(main_program)

class TestFlListenAndServOp(OpTest):
    def setUp(self):
        self.ps_timeout = 5
        self.ip = "127.0.0.1"
        self.port = "6000"
        self.trainers = 2
        self.trainer_id = 0

    def _start_pserver(self, use_cuda, sync_mode, pserver_func):
        p = Process(
            target=pserver_func,
            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
                  self.trainer_id))
        p.daemon = True
        p.start()
        return p

    def _start_trainer0(self, use_cuda, sync_mode, pserver_func):
        p = Process(
            target=pserver_func,
            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 0))
        p.daemon = True
        p.start()
        return p

    def _start_trainer1(self, use_cuda, sync_mode, pserver_func):
        p = Process(
            target=pserver_func,
            args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 1))
        p.daemon = True
        p.start()
        return p
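    # Wait until the pserver has written its port file under /tmp, which
    # signals that its RPC server is up and listening.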
    def _wait_ps_ready(self, pid):
        start_left_time = self.ps_timeout
        sleep_time = 0.5
        while True:
            assert start_left_time >= 0, "wait ps ready failed"
            time.sleep(sleep_time)
            try:
                os.stat("/tmp/paddle.%d.port" % pid)
                return
            except os.error:
                start_left_time -= sleep_time

    def test_rpc_interfaces(self):
        # TODO(Yancey1989): need to make sure the rpc interfaces work correctly.
        pass
    def test_handle_signal_in_serv_op(self):
        # run pserver on CPU in sync mode
        if sys.platform == 'win32':
            pass
        else:
            print(sys.platform)
            cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/pserver_startup_program.dms"
            os.system(cmd)
            cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/pserver_main_program.dms"
            os.system(cmd)
            cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/trainer_recv_program.dms"
            os.system(cmd)
            cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/trainer_main_program.dms"
            os.system(cmd)
            cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/trainer_send_program.dms"
            os.system(cmd)

            p1 = self._start_pserver(False, True, run_pserver)
            self._wait_ps_ready(p1.pid)
            time.sleep(5)
            t1 = self._start_trainer0(False, True, run_trainer)
            time.sleep(2)
            t2 = self._start_trainer1(False, True, run_trainer)
            # raise SIGINT to the pserver and trainers
            time.sleep(2)
            cmd_del = "rm trainer*dms* pserver*dms*"
            os.system(cmd_del)
            os.kill(p1.pid, signal.SIGINT)
            p1.join()
            os.kill(t1.pid, signal.SIGINT)
            t1.join()
            os.kill(t2.pid, signal.SIGINT)
            t2.join()


if __name__ == '__main__':
    unittest.main()