提交 e7ac709b 编写于 作者: T typhoonzero

done

上级 a131c73f
...@@ -208,9 +208,9 @@ void AsyncGRPCServer::WaitClientGet(int count) { ...@@ -208,9 +208,9 @@ void AsyncGRPCServer::WaitClientGet(int count) {
} }
} }
bool AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [&] { return this->ready_ == 1; }); condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
} }
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
......
...@@ -47,7 +47,7 @@ class AsyncGRPCServer final { ...@@ -47,7 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode) explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {} : address_(address), sync_mode_(sync_mode), ready_(0) {}
bool WaitServerReady(); void WaitServerReady();
void RunSyncUpdate(); void RunSyncUpdate();
// functions to sync server barrier status. // functions to sync server barrier status.
...@@ -120,7 +120,7 @@ class AsyncGRPCServer final { ...@@ -120,7 +120,7 @@ class AsyncGRPCServer final {
framework::Executor *executor_; framework::Executor *executor_;
int selected_port_; int selected_port_;
std::mutext mutex_ready_; std::mutex mutex_ready_;
std::condition_variable condition_ready_; std::condition_variable condition_ready_;
int ready_; int ready_;
}; };
......
...@@ -80,12 +80,7 @@ static void ParallelExecuteBlocks( ...@@ -80,12 +80,7 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} }
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) { std::atomic_int ListenAndServOp::selected_port_{0};
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service->GetSelectedPort();
port_file.close();
}
ListenAndServOp::ListenAndServOp(const std::string &type, ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
...@@ -93,15 +88,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type, ...@@ -93,15 +88,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
int ListenAndServOp::GetSelectedPort() const {
return rpc_service_->GetSelectedPort();
}
void ListenAndServOp::Stop() { void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE); rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
server_thread_->join(); server_thread_->join();
} }
void ListenAndServOp::SavePort(const std::string &file_path) const {
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort();
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
...@@ -265,23 +272,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -265,23 +272,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true) } // while(true)
} }
void ListenAndServOp::StartServerThread() {
server_thread_.reset(new std::thread(
std::bind(&ListenAndServOp::ServerThreadEntry, this, rpc_service_)));
}
void ListenAndServOp::ServerThreadEntry(
std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
{
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
barrier_condition_.notify_all();
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
...@@ -315,9 +305,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -315,9 +305,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready..."; VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(rpc_service_); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else { } else {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <atomic>
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs, const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs); const framework::AttributeMap& attrs);
int GetSelectedPort() const;
void RunSyncLoop(framework::Executor* executor, void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program, framework::ProgramDesc* program,
framework::Scope* recv_scope, framework::Scope* recv_scope,
...@@ -51,20 +50,25 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -51,20 +50,25 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope, framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const; framework::BlockDesc* prefetch_block) const;
void StartServerThread(); void SavePort(
const std::string& file_path = "/tmp/paddle.selected_port") const;
void WaitServerReady();
void ServerThreadEntry(); int GetSelectedPort() { return selected_port_; }
void Stop() override; void Stop() override;
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected: protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_; mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
std::mutext server_ready_mutex_; // FIXME(wuyi): it's static so that the operator can be cloned.
std::condition_variable server_ready_; static std::atomic_int selected_port_;
}; };
} // namespace operators } // namespace operators
......
...@@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs, ...@@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
void StartServerNet(bool is_sparse) { void StartServerNet(bool is_sparse) {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
VLOG(4) << "before init tensor";
if (is_sparse) { if (is_sparse) {
InitSelectedRowsInScope(place, &scope); InitSelectedRowsInScope(place, &scope);
} else { } else {
...@@ -129,6 +130,7 @@ void StartServerNet(bool is_sparse) { ...@@ -129,6 +130,7 @@ void StartServerNet(bool is_sparse) {
auto *prefetch_block = program.AppendBlock(root_block); auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape. // X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block); AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
VLOG(4) << "before attr";
f::AttributeMap attrs; f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")}); attrs.insert({"endpoint", std::string("127.0.0.1:0")});
...@@ -139,15 +141,19 @@ void StartServerNet(bool is_sparse) { ...@@ -139,15 +141,19 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"PrefetchBlock", prefetch_block}); attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})}); attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true}); attrs.insert({"sync_mode", true});
VLOG(4) << "before init op";
listen_and_serv_op = listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs); f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
VLOG(4) << "before run op";
listen_and_serv_op->Run(scope, place); listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit"; LOG(INFO) << "server exit";
} }
TEST(SendRecvOp, CPUDense) { TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false); std::thread server_thread(StartServerNet, false);
sleep(5); // wait server to start // wait server to start
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net // local net
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
...@@ -181,11 +187,13 @@ TEST(SendRecvOp, CPUDense) { ...@@ -181,11 +187,13 @@ TEST(SendRecvOp, CPUDense) {
listen_and_serv_op->Stop(); listen_and_serv_op->Stop();
server_thread.join(); server_thread.join();
listen_and_serv_op.reset(nullptr); listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
} }
TEST(SendRecvOp, CPUSparse) { TEST(SendRecvOp, CPUSparse) {
std::thread server_thread(StartServerNet, true); std::thread server_thread(StartServerNet, true);
sleep(3); // wait server to start static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net // local net
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
...@@ -226,4 +234,5 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -226,4 +234,5 @@ TEST(SendRecvOp, CPUSparse) {
listen_and_serv_op->Stop(); listen_and_serv_op->Stop();
server_thread.join(); server_thread.join();
listen_and_serv_op.reset(); listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册