提交 94eea16e 编写于 作者: T typhoonzero

fix sendrecv port bind

上级 3fd92662
...@@ -186,7 +186,8 @@ void AsyncGRPCServer::WaitClientGet(int count) { ...@@ -186,7 +186,8 @@ void AsyncGRPCServer::WaitClientGet(int count) {
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
::grpc::ServerBuilder builder; ::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials()); builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
&selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_); builder.RegisterService(&service_);
...@@ -196,7 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() { ...@@ -196,7 +197,8 @@ void AsyncGRPCServer::RunSyncUpdate() {
cq_prefetch_ = builder.AddCompletionQueue(); cq_prefetch_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl; LOG(INFO) << "Server listening on " << address_
<< " selected port: " << selected_port_;
std::function<void()> send_register = std::function<void()> send_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this);
...@@ -242,6 +244,9 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { ...@@ -242,6 +244,9 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
while (scope_ == nullptr) {
sleep(0.01);
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
&var_recv_queue_, dev_ctx_); &var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status(); VLOG(4) << "Create RequestSend status:" << send->Status();
......
...@@ -62,6 +62,8 @@ class AsyncGRPCServer final { ...@@ -62,6 +62,8 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; } void SetExecutor(framework::Executor *executor) { executor_ = executor; }
int GetSelectedPort() { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const std::string &msg_name) { void Push(const std::string &msg_name) {
...@@ -109,6 +111,7 @@ class AsyncGRPCServer final { ...@@ -109,6 +111,7 @@ class AsyncGRPCServer final {
int prefetch_blk_id_; int prefetch_blk_id_;
framework::ProgramDesc *program_; framework::ProgramDesc *program_;
framework::Executor *executor_; framework::Executor *executor_;
int selected_port_;
}; };
}; // namespace detail }; // namespace detail
......
...@@ -12,100 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,100 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdint.h>
#include <ostream> #include <ostream>
#include <thread>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate(); service->RunSyncUpdate();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void CreateTensorFromMessageType(framework::Variable *var, ListenAndServOp::ListenAndServOp(const std::string &type,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
int run_block = idx; // thread local
try {
executor->Run(*program, scope, run_block, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
class ListenAndServOp : public framework::OperatorBase {
public:
ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) { : OperatorBase(type, inputs, outputs, attrs) {}
if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint"); int ListenAndServOp::GetSelectedPort() {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); return rpc_service_->GetSelectedPort();
server_thread_.reset(new std::thread(RunServer, rpc_service_)); }
}
}
void Stop() override { void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
server_thread_->join(); server_thread_->join();
} }
void RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place); auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
LOG(INFO) << "created recv scope: " << &recv_scope;
if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
}
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto ins = Inputs("X"); auto ins = Inputs("X");
auto fan_in = Attr<int>("Fanin"); auto fan_in = Attr<int>("Fanin");
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program(); auto *program = block->Program();
int num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// FIXME(Yancey1989): initialize rpc server with lazy mode.
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// TODO(qiao) set proper fields for table lookup and update // TODO(qiao) set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor); rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0); rpc_service_->SetPrefetchBlkdId(0);
rpc_service_->SetProgram(program); rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
// FIXME(typhoonzero): do we need to wait until the server port is ready?
sleep(5);
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false; bool exit_flag = false;
...@@ -153,15 +120,14 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -153,15 +120,14 @@ class ListenAndServOp : public framework::OperatorBase {
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future // TODO(Yancey1989): need to use ParallelExecutor for future
size_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = program->Block(1).Parent();
std::vector<size_t> parallel_blkids; std::vector<size_t> parallel_blkids;
parallel_blkids.push_back(1); parallel_blkids.push_back(1);
double ts = detail::GetTimestamp(); double ts = detail::GetTimestamp();
for (size_t blkid = 2; blkid < num_blocks; ++blkid) { for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
if (program->Block(blkid).Parent() != last_parent_blkid) { if (program->Block(blkid).Parent() != last_parent_blkid) {
for (size_t idx : parallel_blkids) VLOG(3) << idx; for (size_t idx : parallel_blkids) VLOG(3) << idx;
ParallelExecuteBlocks(parallel_blkids, &executor, program, ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
&recv_scope);
parallel_blkids.clear(); parallel_blkids.clear();
last_parent_blkid = program->Block(blkid).Parent(); last_parent_blkid = program->Block(blkid).Parent();
} }
...@@ -169,8 +135,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -169,8 +135,7 @@ class ListenAndServOp : public framework::OperatorBase {
} }
ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope); ParallelExecuteBlocks(parallel_blkids, &executor, program, &recv_scope);
VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts VLOG(3) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";
<< "(ms)";
// Reset the received sparse variables, the sum operator would not // Reset the received sparse variables, the sum operator would not
// sum the input sparse variables which rows is empty at the next // sum the input sparse variables which rows is empty at the next
...@@ -185,12 +150,7 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -185,12 +150,7 @@ class ListenAndServOp : public framework::OperatorBase {
rpc_service_->WaitClientGet(fan_in); rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear(); sparse_vars.clear();
} // while(true) } // while(true)
} }
protected:
std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
std::shared_ptr<std::thread> server_thread_;
};
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <ostream>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
namespace paddle {
namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
static void CreateTensorFromMessageType(framework::Variable *var,
sendrecv::VarType var_type) {
if (var_type == sendrecv::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>();
} else if (var_type == sendrecv::VarType::SELECTED_ROWS) {
var->GetMutable<framework::SelectedRows>();
} else {
PADDLE_THROW(
"VariableMessage type %d is not in "
"[LoDTensor, SelectedRows]",
var_type);
}
}
static void ParallelExecuteBlocks(const std::vector<size_t> &parallel_blkids,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *scope) {
std::vector<std::future<void>> fs;
for (size_t idx : parallel_blkids) {
fs.push_back(framework::Async([&executor, &program, &scope, idx]() {
int run_block = idx; // thread local
try {
executor->Run(*program, scope, run_block, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
}));
}
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
class ListenAndServOp : public framework::OperatorBase {
public:
ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs);
int GetSelectedPort();
void Stop() override;
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override;
protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<std::thread> server_thread_;
};
} // namespace operators
} // namespace paddle
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -34,6 +35,7 @@ namespace m = paddle::operators::math; ...@@ -34,6 +35,7 @@ namespace m = paddle::operators::math;
// global for simplicity. // global for simplicity.
std::unique_ptr<f::OperatorBase> listen_and_serv_op; std::unique_ptr<f::OperatorBase> listen_and_serv_op;
int selected_port;
void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) {
p::CPUDeviceContext ctx(place); p::CPUDeviceContext ctx(place);
...@@ -128,14 +130,16 @@ void StartServerNet(bool is_sparse) { ...@@ -128,14 +130,16 @@ void StartServerNet(bool is_sparse) {
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:6174")}); attrs.insert({"endpoint", std::string("127.0.0.1:0")});
attrs.insert({"Fanin", 1}); attrs.insert({"Fanin", 1});
attrs.insert({"ParamList", std::vector<std::string>({"Out"})}); attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
attrs.insert({"GradList", std::vector<std::string>({"x1"})}); attrs.insert({"GradList", std::vector<std::string>({"x1"})});
attrs.insert({"OptimizeBlock", optimize_block}); attrs.insert({"OptimizeBlock", optimize_block});
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
LOG(INFO) << "selected port before run " << selected_port;
listen_and_serv_op->Run(scope, place); listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit";
} }
TEST(SendRecvOp, CPUDense) { TEST(SendRecvOp, CPUDense) {
...@@ -149,12 +153,19 @@ TEST(SendRecvOp, CPUDense) { ...@@ -149,12 +153,19 @@ TEST(SendRecvOp, CPUDense) {
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})}); selected_port = static_cast<paddle::operators::ListenAndServOp *>(
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})}); listen_and_serv_op.get())
->GetSelectedPort();
LOG(INFO) << "selected port " << selected_port;
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}}, "send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
LOG(INFO) << "before run " << endpoint;
send_op->Run(scope, place); send_op->Run(scope, place);
LOG(INFO) << "end run";
auto in_var = scope.Var("x1"); auto in_var = scope.Var("x1");
auto tensor = in_var->GetMutable<f::LoDTensor>(); auto tensor = in_var->GetMutable<f::LoDTensor>();
...@@ -167,6 +178,7 @@ TEST(SendRecvOp, CPUDense) { ...@@ -167,6 +178,7 @@ TEST(SendRecvOp, CPUDense) {
for (int64_t i = 0; i < target->numel(); ++i) { for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]); EXPECT_EQ(expected[i] * 2, actual[i]);
} }
LOG(INFO) << "before stop";
listen_and_serv_op->Stop(); listen_and_serv_op->Stop();
server_thread.join(); server_thread.join();
listen_and_serv_op.reset(nullptr); listen_and_serv_op.reset(nullptr);
...@@ -182,8 +194,13 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -182,8 +194,13 @@ TEST(SendRecvOp, CPUSparse) {
InitSelectedRowsInScope(scope, place); InitSelectedRowsInScope(scope, place);
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoints", std::vector<std::string>({"127.0.0.1:6174"})}); selected_port = static_cast<paddle::operators::ListenAndServOp *>(
attrs.insert({"epmap", std::vector<std::string>({"127.0.0.1:6174"})}); listen_and_serv_op.get())
->GetSelectedPort();
LOG(INFO) << "selected port " << selected_port;
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}}, "send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册