Commit 0763ae9a authored by Q qiaolongfei

remove unused file

Parent dc3d2dc8
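// File: paddle/fluid/operators/async_listen_and_serv_op.cc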
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <chrono>
#include <fstream>
#include <ostream>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/async_listen_and_serv_op.h"
#include "paddle/utils/StringUtil.h"
namespace paddle {
namespace operators {
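// Splits `str` on `sep`; intermediate empty pieces are kept, a trailing
// empty piece is dropped.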
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
if (str.empty()) {
return;
}
size_t pos = 0;
size_t next = str.find(sep, pos);
while (next != std::string::npos) {
pieces->push_back(str.substr(pos, next - pos));
pos = next + 1;
next = str.find(sep, pos);
}
if (!str.substr(pos).empty()) {
pieces->push_back(str.substr(pos));
}
}
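// Entry point for the server thread: blocks inside RunAsyncUpdate() until the
// gRPC server is shut down.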
void RunServer(std::shared_ptr<detail::SyncGRPCServer> service) {
service->RunAsyncUpdate();
VLOG(4) << "RunServer thread end";
}
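// Dispatches one prepared optimize block to the framework thread pool so that
// updates for different gradients can run concurrently; exceptions are logged
// instead of being propagated.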
static void AsyncExecuteBlock(framework::Executor *executor,
framework::ExecutorPrepareContext *prepared,
framework::Scope *scope) {
// Capture the pointers by value; this closure may outlive the caller's stack frame.
framework::Async([executor, prepared, scope]() {
try {
executor->RunPreparedContext(prepared, scope, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
}
AsyncListenAndServOp::AsyncListenAndServOp(
const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
int AsyncListenAndServOp::GetSelectedPort() const {
return rpc_service_->GetSelectedPort();
}
void AsyncListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
server_thread_->join();
}
void AsyncListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
framework::Scope &recv_scope = scope.NewScope();
if (!rpc_service_) {
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::SyncGRPCServer(endpoint));
}
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_id;
std::unordered_map<int32_t, std::string> id_to_grad;
auto grad_map_str = Attr<std::vector<std::string>>("grad_to_id");
for (auto &grad_and_id : grad_map_str) {
std::vector<std::string> pieces;
split(grad_and_id, ' ', &pieces);
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_id[pieces[0]] = block_id;
id_to_grad[block_id] = pieces[0];
}
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
framework::Executor executor(dev_place);
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
if (blkid != static_cast<size_t>(prefetch_block->ID())) {
block_list.push_back(blkid);
}
}
PADDLE_ENFORCE_EQ(grad_map_str.size(), block_list.size(),
"grad num should be equal to optimize block num");
auto optimize_prepared = executor.Prepare(*program, block_list);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared;
for (size_t i = 0; i < block_list.size(); ++i) {
grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
// set proper fields for table lookup and update
rpc_service_->SetExecutor(&executor);
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
prefetch_prepared.release();
rpc_service_->SetProgram(program);
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
std::this_thread::sleep_for(std::chrono::seconds(5));
// Write to a file of server selected port for python use.
std::ofstream port_file;
port_file.open("/tmp/paddle.selected_port");
port_file << rpc_service_->GetSelectedPort();
port_file.close();
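// Main receive loop: pop each incoming gradient, dispatch it to its optimize
// block asynchronously, and exit when a terminate message arrives.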
bool exit_flag = false;
while (!exit_flag) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
AsyncExecuteBlock(&executor, grad_to_prepared[recv_var_name].get(),
&recv_scope);
// TODO(qiao): explain why
if (var->IsType<framework::SelectedRows>()) {
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
} // while(true)
}
class AsyncListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AsyncListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
AddComment(R"DOC(
ListenAndServ operator
This operator will start a RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
AddAttr<std::string>("endpoint",
"(string, default 127.0.0.1:6164)"
"IP address to listen on.")
.SetDefault("127.0.0.1:6164")
.AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
AddAttr<std::vector<std::string>>(
"grad_to_id",
"(['param1@GRAD.block0:1', 'param2@GRAD.blockn:2']) "
"a map from grad name to its optimize block id")
.SetDefault({});
AddAttr<framework::BlockDesc *>(kOptimizeBlock,
"BlockID to run on server side.");
AddAttr<framework::BlockDesc *>(kPrefetchBlock,
"prefetch block to run on server side.");
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(async_listen_and_serv, ops::AsyncListenAndServOp,
ops::AsyncListenAndServOpMaker);
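// File: paddle/fluid/operators/async_listen_and_serv_op.h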
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/async_grpc_server.h"
namespace paddle {
namespace operators {
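// Names of the block attributes read by AsyncListenAndServOp::RunImpl.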
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::SyncGRPCServer> service);
class AsyncListenAndServOp : public framework::OperatorBase {
public:
AsyncListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs);
int GetSelectedPort() const;
void Stop() override;
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override;
protected:
mutable std::shared_ptr<detail::SyncGRPCServer> rpc_service_;
mutable std::shared_ptr<std::thread> server_thread_;
};
} // namespace operators
} // namespace paddle
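// File: paddle/fluid/operators/detail/async_grpc_server.cc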
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/async_grpc_server.h"
#include <cassert>
#include <limits>
#include <string>
using ::grpc::ServerAsyncResponseWriter;
namespace paddle {
namespace operators {
namespace detail {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
// https://stackoverflow.com/questions/41732884/grpc-multiple-services-in-cpp-async-server
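// One in-flight RPC. Each instance is used as the completion-queue tag for a
// single request and moves from PROCESS (waiting to be handled) to FINISH
// (response sent, ready to be deleted).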
class RequestBase {
public:
explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
const platform::DeviceContext* dev_ctx)
: service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE(cq_);
}
virtual ~RequestBase() {}
virtual void Process() { assert(false); }
CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() {
assert(false);
return "";
}
protected:
::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_;
CallStatus status_;
const platform::DeviceContext* dev_ctx_;
};
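// Handles one Send RPC: the deserialized variable is pushed onto the
// received-message queue that AsyncListenAndServOp drains.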
class RequestSend final : public RequestBase {
public:
explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, ReceivedQueue* queue,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
request_.reset(new VariableResponse(true, scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}
virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
queue_->Push(std::make_pair(request_->Varname(), request_));
sendrecv::VoidMessage reply;
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
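// Handles one Get RPC: looks up the requested variable in the server scope
// and serializes it back to the client.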
class RequestGet final : public RequestBase {
public:
explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq, dev_ctx), responder_(&ctx_), scope_(scope) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
}
virtual ~RequestGet() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual void Process() {
// proc request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
::grpc::ByteBuffer reply;
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
sendrecv::VariableMessage request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
};
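// Handles one Prefetch RPC: runs the prepared prefetch block (e.g. a
// parameter table lookup) and returns the resulting variable.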
class RequestPrefetch final : public RequestBase {
public:
explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(true, scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}
virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
std::string var_name = request_->OutVarname();
VLOG(3) << "prefetch var " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
};
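// Builds and starts the gRPC server, spawns one handler thread per completion
// queue (send / get / prefetch), and blocks until the server shuts down.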
void SyncGRPCServer::RunAsyncUpdate() {
::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(),
&selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();
cq_prefetch_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_
<< " selected port: " << selected_port_;
std::function<void()> send_register =
std::bind(&SyncGRPCServer::TryToRegisterNewSendOne, this);
std::function<void()> get_register =
std::bind(&SyncGRPCServer::TryToRegisterNewGetOne, this);
std::function<void()> prefetch_register =
std::bind(&SyncGRPCServer::TryToRegisterNewPrefetchOne, this);
// TODO(wuyi): Run these "HandleRequest" in thread pool
t_send_.reset(
new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
new std::thread(std::bind(&SyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
t_prefetch_.reset(new std::thread(
std::bind(&SyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
// Block until the server is shut down, then join the handler threads.
server_->Wait();
t_send_->join();
t_get_->join();
t_prefetch_->join();
}
void SyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_);
cq_send_->Shutdown();
cq_get_->Shutdown();
cq_prefetch_->Shutdown();
}
// Shutting down the async server gracefully is non-trivial: the completion
// queues must be shut down as well as the server itself.
void SyncGRPCServer::ShutDown() {
is_shut_down_ = true;
ShutdownQueue();
server_->Shutdown();
}
void SyncGRPCServer::TryToRegisterNewSendOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return;
}
RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_,
&var_recv_queue_, dev_ctx_);
VLOG(4) << "Create RequestSend status:" << send->Status();
}
void SyncGRPCServer::TryToRegisterNewGetOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne";
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
VLOG(4) << "Create RequestGet status:" << get->Status();
}
void SyncGRPCServer::TryToRegisterNewPrefetchOne() {
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return;
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
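// Event loop for one completion queue. Every tag popped from the queue is a
// RequestBase: PROCESS tags are handled and a fresh request of the same kind
// is registered; FINISH tags are deleted.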
void SyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
const std::string& cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
void* tag = NULL;
bool ok = false;
while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " while in";
if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!";
break;
}
VLOG(3) << "HandleRequest for " << cq_name << " while after Next";
PADDLE_ENFORCE(tag);
RequestBase* base = reinterpret_cast<RequestBase*>(tag);
// reference:
// https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]";
TryToRegisterNewOne();
delete base;
continue;
}
switch (base->Status()) {
case PROCESS: {
VLOG(4) << cq_name << " status:" << base->Status();
TryToRegisterNewOne();
base->Process();
break;
}
case FINISH: {
VLOG(4) << cq_name << " status:" << base->Status();
delete base;
break;
}
default: { assert(false); }
}
}
}
} // namespace detail
} // namespace operators
} // namespace paddle
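// File: paddle/fluid/operators/detail/async_grpc_server.h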
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <thread> // NOLINT
#include <utility>
#include "grpc++/grpc++.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
namespace detail {
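// A received message pairs the incoming variable's name with its deserialized
// payload; a null payload is used as the terminate signal.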
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef SimpleBlockQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
class SyncGRPCServer final {
public:
explicit SyncGRPCServer(const std::string &address) : address_(address) {}
void RunAsyncUpdate();
void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}
int GetSelectedPort() { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const std::string &msg_name) {
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
}
void ShutDown();
protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
void TryToRegisterNewPrefetchOne();
void ShutdownQueue();
private:
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_;
std::string address_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
std::unique_ptr<std::thread> t_send_;
std::unique_ptr<std::thread> t_get_;
std::unique_ptr<std::thread> t_prefetch_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
};
} // namespace detail
} // namespace operators
} // namespace paddle