提交 e7ac709b 编写于 作者: T typhoonzero

done

上级 a131c73f
......@@ -208,9 +208,9 @@ void AsyncGRPCServer::WaitClientGet(int count) {
}
}
// Blocks the calling thread until ready_ equals 1. The wait happens on
// condition_ready_ while holding mutex_ready_.
// NOTE(review): two signature lines (bool / void) and two wait() calls
// ([&] / [=] capture) appear back to back here; this looks like the
// before/after of one change -- confirm which pair is the live one.
bool AsyncGRPCServer::WaitServerReady() {
void AsyncGRPCServer::WaitServerReady() {
std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [&] { return this->ready_ == 1; });
condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
}
void AsyncGRPCServer::RunSyncUpdate() {
......
......@@ -47,7 +47,7 @@ class AsyncGRPCServer final {
// Stores the listen address and the sync-mode flag. ready_ starts at 0;
// WaitServerReady() blocks until it becomes 1.
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
bool WaitServerReady();
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
......@@ -120,7 +120,7 @@ class AsyncGRPCServer final {
framework::Executor *executor_;
int selected_port_;
std::mutext mutex_ready_;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
};
......
......@@ -80,12 +80,7 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
// Writes the port selected by the RPC service to the fixed file
// /tmp/paddle.selected_port.
static void SavePort(std::shared_ptr<detail::AsyncGRPCServer> rpc_service) {
  // Open first, then query the port, matching the original ordering.
  std::ofstream out("/tmp/paddle.selected_port");
  out << rpc_service->GetSelectedPort();
  // `out` is flushed and closed by its destructor on return.
}
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
......@@ -93,15 +88,27 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
// Returns the port reported by the underlying RPC service.
int ListenAndServOp::GetSelectedPort() const {
  const int port = rpc_service_->GetSelectedPort();
  return port;
}
// Stops the operator: pushes LISTEN_TERMINATE_MESSAGE to the RPC service and
// then joins server_thread_, so this call blocks until that thread exits.
// NOTE(review): presumably the pushed message is what makes the serving loop
// return -- confirm against the consumer of the queue.
void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
server_thread_->join();
}
// Publishes the port selected by the RPC service.
//
// The port is first stored into the static atomic selected_port_ and then
// written as text to `file_path` (the declaration defaults it to
// /tmp/paddle.selected_port).
//
// A failure to open or write the file is logged instead of being silently
// ignored; the "written" message is only emitted when the write succeeded.
void ListenAndServOp::SavePort(const std::string &file_path) const {
  // Store before touching the file: WaitServerReady() spins on
  // selected_port_, so waiters must be released even if the file write fails.
  selected_port_ = rpc_service_->GetSelectedPort();

  std::ofstream port_file(file_path);
  port_file << selected_port_.load();
  // close() flushes; a failed open, write or flush leaves the stream in a
  // failed state, which is checked once below.
  port_file.close();
  if (port_file.fail()) {
    VLOG(1) << "failed to write selected port to " << file_path;
    return;
  }
  VLOG(4) << "selected port written to " << file_path;
}
// Blocks the caller until selected_port_ becomes non-zero. SavePort() stores
// the port there and ResetPort() sets it back to 0.
void ListenAndServOp::WaitServerReady() {
  while (selected_port_.load() == 0) {
    // Give up the time slice instead of spinning hot, so the thread that is
    // bringing the server up is not starved of CPU while we poll.
    std::this_thread::yield();
  }
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
......@@ -265,23 +272,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
} // while(true)
}
void ListenAndServOp::StartServerThread() {
server_thread_.reset(new std::thread(
std::bind(&ListenAndServOp::ServerThreadEntry, this, rpc_service_)));
}
// Thread entry point: runs service->RunSyncUpdate() on the calling thread.
// Once that call returns, it updates barrier_cond_step_ under barrier_mutex_
// and wakes every waiter on barrier_condition_.
// NOTE(review): `cond` is neither a parameter nor a local of this function;
// unless it is declared at a wider scope this does not compile -- confirm
// which value was meant to be stored.
void ListenAndServOp::ServerThreadEntry(
std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
VLOG(4) << "RunServer thread end";
{
// Inner scope so the lock is released before notify_all().
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
barrier_condition_.notify_all();
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -315,9 +305,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// Write the server's selected port to a file so that Python can read it.
SavePort(rpc_service_);
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
} else {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <stdint.h>
#include <atomic>
#include <ostream>
#include <string>
......@@ -39,8 +40,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
int GetSelectedPort() const;
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
......@@ -51,20 +50,25 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Scope* recv_scope,
framework::BlockDesc* prefetch_block) const;
void StartServerThread();
void SavePort(
const std::string& file_path = "/tmp/paddle.selected_port") const;
void WaitServerReady();
void ServerThreadEntry();
// Returns the current value of the static atomic selected_port_.
int GetSelectedPort() { return selected_port_.load(); }
void Stop() override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
// Sets the static atomic selected_port_ back to 0.
static void ResetPort() { selected_port_.store(0); }
protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<std::thread> server_thread_;
// Mutex / condition-variable pair named for server readiness.
// (`std::mutext` was a typo for `std::mutex` and did not compile.)
std::mutex server_ready_mutex_;
std::condition_variable server_ready_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
};
} // namespace operators
......
......@@ -116,6 +116,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
void StartServerNet(bool is_sparse) {
f::Scope scope;
p::CPUPlace place;
VLOG(4) << "before init tensor";
if (is_sparse) {
InitSelectedRowsInScope(place, &scope);
} else {
......@@ -129,6 +130,7 @@ void StartServerNet(bool is_sparse) {
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
VLOG(4) << "before attr";
f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")});
......@@ -139,15 +141,19 @@ void StartServerNet(bool is_sparse) {
attrs.insert({"PrefetchBlock", prefetch_block});
attrs.insert({"grad_to_block_id", std::vector<std::string>({""})});
attrs.insert({"sync_mode", true});
VLOG(4) << "before init op";
listen_and_serv_op =
f::OpRegistry::CreateOp("listen_and_serv", {{"X", {"x1"}}}, {}, attrs);
VLOG(4) << "before run op";
listen_and_serv_op->Run(scope, place);
LOG(INFO) << "server exit";
}
TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false);
sleep(5); // wait server to start
// wait server to start
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
......@@ -181,11 +187,13 @@ TEST(SendRecvOp, CPUDense) {
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
}
TEST(SendRecvOp, CPUSparse) {
std::thread server_thread(StartServerNet, true);
sleep(3); // wait server to start
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
......@@ -226,4 +234,5 @@ TEST(SendRecvOp, CPUSparse) {
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册