未验证 提交 be853853 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #9593 from Yancey1989/prefech_prog_on_server

run prefetch prog on server
......@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
......@@ -39,6 +38,7 @@ Scope::~Scope() {
}
Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
......@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}
void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
......@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}
void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <list>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
......@@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);
void EraseVars(std::vector<std::string>& var_names);
void EraseVars(const std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
......@@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle
......@@ -5,5 +5,5 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
cares zlib protobuf sendrecvop_grpc)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
cc_test(grpc_server_test SRCS grpc_server_test.cc DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_table_op)
endif()
......@@ -138,7 +138,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req);
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val);
// var handle
VarHandle var_h;
......
......@@ -138,39 +138,48 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope,
const platform::DeviceContext* dev_ctx,
framework::Executor* executor,
framework::ProgramDesc* program, int blkid)
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx),
responder_(&ctx_),
scope_(scope),
executor_(executor),
program_(program),
blkid_(blkid) {
prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}
virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG(3) << "RequestPrefetch Process in";
std::string var_name = request_->OutVarname();
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, this);
status_ = FINISH;
}
protected:
sendrecv::VariableMessage request_;
std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_;
};
......@@ -268,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_);
executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
......
......@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}
int GetSelectedPort() { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
......@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_prefetch_;
int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
......
......@@ -20,43 +20,121 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;
USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_table");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::SELECTED_ROWS);
out.SetShape({10, 10});
return block;
}
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
w_var->GetMutable<framework::SelectedRows>();
auto out_var = scope->Var("out");
out_var->GetMutable<framework::SelectedRows>();
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::SelectedRows>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto ids_var = scope->Var("ids")->GetMutable<framework::SelectedRows>();
auto rows = ids_var->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i * 2);
ids_var->mutable_value()->Resize({rows_numel, 1});
ids_var->mutable_value()->mutable_data<float>(*place);
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
int64_t rows_numel) {
CreateVarsOnScope(scope, place);
auto w = scope->Var("w")->GetMutable<framework::SelectedRows>();
auto rows = w->mutable_rows();
for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i);
auto w_value = w->mutable_value();
w_value->Resize({rows_numel, 10});
auto ptr = w_value->mutable_data<float>(*place);
for (int64_t i = 0; i < w_value->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(prepared.get());
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
rpc_service_->RunSyncUpdate();
}
TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string in_var_name("in");
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
}
......@@ -21,7 +21,7 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
}
......@@ -67,6 +67,8 @@ message VariableMessage {
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
}
message VoidMessage {}
......@@ -30,7 +30,8 @@ namespace detail {
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg) {
::grpc::ByteBuffer* msg,
const std::string& out_name) {
using VarMsg = sendrecv::VariableMessage;
sendrecv::VariableMessage request;
std::string header;
......@@ -52,6 +53,9 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteUint64(VarMsg::kTypeFieldNumber, 1);
}
if (!out_name.empty()) {
e.WriteString(VarMsg::kOutVarnameFieldNumber, out_name);
}
switch (framework::ToVarType(var->Type())) {
case framework::proto::VarType_Type_LOD_TENSOR: {
auto tensor = var->Get<framework::LoDTensor>();
......
......@@ -46,7 +46,8 @@ typedef void (*DestroyCallback)(void*);
void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg);
::grpc::ByteBuffer* msg,
const std::string& out_varname = std::string());
void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
......
......@@ -416,6 +416,20 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
default: {
// Unknown tag, return unknown error.
......
......@@ -55,6 +55,7 @@ class VariableResponse {
int Parse(const ::grpc::ByteBuffer& byte_buffer);
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册