未验证 提交 be853853 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9593 from Yancey1989/prefech_prog_on_server

run prefetch prog on server
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set> #include <set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
...@@ -39,6 +38,7 @@ Scope::~Scope() { ...@@ -39,6 +38,7 @@ Scope::~Scope() {
} }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
...@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
} }
void Scope::DeleteScope(Scope* scope) { void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it); this->kids_.erase(it);
...@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) { ...@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
} }
} }
void Scope::EraseVars(std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <list> #include <list>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -51,7 +52,7 @@ class Scope { ...@@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name. /// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr); Variable* Var(std::string* name = nullptr);
void EraseVars(std::vector<std::string>& var_names); void EraseVars(const std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns /// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find. /// nullptr if cannot find.
...@@ -88,6 +89,9 @@ class Scope { ...@@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr}; Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE) ...@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc) cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
endif() endif()
...@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req); SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle // var handle
VarHandle var_h; VarHandle var_h;
......
...@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase { ...@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, int blkid) framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
blkid_(blkid) { prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, this); cq_, cq_, this);
} }
virtual ~RequestPrefetch() {} virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_.varname(); } virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() { virtual void Process() {
// prefetch process... // prefetch process...
::grpc::ByteBuffer reply; ::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG(3) << "RequestPrefetch Process in"; std::string var_name = request_->OutVarname();
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, this); responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH; status_ = FINISH;
} }
protected: protected:
sendrecv::VariableMessage request_; std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* scope_;
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_; int blkid_;
}; };
...@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { ...@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_); executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
......
...@@ -63,6 +63,10 @@ class AsyncGRPCServer final { ...@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; } void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}
int GetSelectedPort() { return selected_port_; } int GetSelectedPort() { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
...@@ -111,6 +115,7 @@ class AsyncGRPCServer final { ...@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
int prefetch_blk_id_; int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_; framework::ProgramDesc *program_;
framework::Executor *executor_; framework::Executor *executor_;
int selected_port_; int selected_port_;
......
...@@ -20,43 +20,121 @@ limitations under the License. */ ...@@ -20,43 +20,121 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_; std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS);
out.SetShape({10, 10});
return block;
}
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
auto rows = ids_var->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
ids_var->mutable_value()->Resize({rows_numel, 1});
ids_var->mutable_value()->mutable_data<float>(*place);
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto rows = w->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void StartServer(const std::string& endpoint) { void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
rpc_service_->RunSyncUpdate(); rpc_service_->RunSyncUpdate();
} }
TEST(PREFETCH, CPU) { TEST(PREFETCH, CPU) {
// start up a server instance backend // start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889"); std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope // create var on local scope
std::string in_var_name("in"); int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client; detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name); out_var_name);
client.Wait(); client.Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
rpc_service_->ShutDown(); rpc_service_->ShutDown();
server_thread.join(); server_thread.join();
rpc_service_.reset(nullptr); rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
} }
...@@ -21,7 +21,7 @@ service SendRecvService { ...@@ -21,7 +21,7 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {} rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname. // Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids // pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
} }
...@@ -67,6 +67,8 @@ message VariableMessage { ...@@ -67,6 +67,8 @@ message VariableMessage {
bytes serialized = 8; bytes serialized = 8;
// selected_rows data // selected_rows data
bytes rows = 9; bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
} }
message VoidMessage {} message VoidMessage {}
...@@ -30,7 +30,8 @@ namespace detail { ...@@ -30,7 +30,8 @@ namespace detail {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) { ::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request; sendrecv::VariableMessage request;
std::string header; std::string header;
...@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1); e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
} }
if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) { switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: { case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
......
...@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*); ...@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var, void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg); ::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
......
...@@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) { ...@@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
} }
break; break;
} }
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
default: { default: {
// Unknown tag, return unknown error. // Unknown tag, return unknown error.
......
...@@ -55,6 +55,7 @@ class VariableResponse { ...@@ -55,6 +55,7 @@ class VariableResponse {
int Parse(const ::grpc::ByteBuffer& byte_buffer); int Parse(const ::grpc::ByteBuffer& byte_buffer);
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
// should call parse first. // should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); } framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册