Commit 18427581 authored by Yancey1989

prefetch prog run on new scope

Parent 0cafe390
paddle/fluid/framework/scope.cc
@@ -15,7 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include <memory>  // for unique_ptr
-#include <mutex>   // for call_once
 #include <set>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/threadpool.h"
@@ -39,6 +38,7 @@ Scope::~Scope() {
 }
 
 Scope& Scope::NewScope() const {
+  std::unique_lock<std::mutex> lock(mutex_);
   kids_.push_back(new Scope(this));
   return *kids_.back();
 }
@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
 }
 
 void Scope::DeleteScope(Scope* scope) {
+  std::unique_lock<std::mutex> lock(mutex_);
   auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
   PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
   this->kids_.erase(it);
@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
   }
 }
 
-void Scope::EraseVars(std::vector<std::string>& var_names) {
+void Scope::EraseVars(const std::vector<std::string>& var_names) {
   std::set<std::string> var_set(var_names.begin(), var_names.end());
   for (auto it = vars_.begin(); it != vars_.end();) {
     if (var_set.find(it->first) != var_set.end()) {
...
paddle/fluid/framework/scope.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <list>
+#include <mutex>  // NOLINT
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -51,7 +52,7 @@ class Scope {
   /// Create a variable with a scope-unique name.
   Variable* Var(std::string* name = nullptr);
 
-  void EraseVars(std::vector<std::string>& var_names);
+  void EraseVars(const std::vector<std::string>& var_names);
 
   /// Find a variable in the scope or any of its ancestors.  Returns
   /// nullptr if cannot find.
@@ -88,6 +89,9 @@ class Scope {
   Scope const* parent_{nullptr};
 
   DISABLE_COPY_AND_ASSIGN(Scope);
+
+ private:
+  mutable std::mutex mutex_;
 };
 }  // namespace framework
 }  // namespace paddle
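
Taken together, the scope.cc/scope.h changes make child-scope creation and deletion thread-safe: each prefetch RPC now creates and destroys its own child scope, so `kids_` can be mutated from several server threads at once. Below is a minimal standalone sketch of the locking pattern (simplified types, not Paddle code):

```cpp
#include <list>
#include <mutex>
#include <thread>

class Scope {
 public:
  Scope& NewScope() const {
    std::unique_lock<std::mutex> lock(mutex_);  // serialize mutation of kids_
    kids_.push_back(new Scope());
    return *kids_.back();
  }
  void DeleteScope(Scope* scope) {
    std::unique_lock<std::mutex> lock(mutex_);
    kids_.remove(scope);  // drop the child entry, then free it
    delete scope;
  }

 private:
  mutable std::list<Scope*> kids_;
  mutable std::mutex mutex_;
};

int main() {
  Scope root;
  // Without mutex_, these concurrent push_back/remove calls on the shared
  // parent's kids_ list would be a data race.
  std::thread t1([&] { root.DeleteScope(&root.NewScope()); });
  std::thread t2([&] { root.DeleteScope(&root.NewScope()); });
  t1.join();
  t2.join();
}
```

The mutex is `mutable` for the same reason as in the real patch: `NewScope()` is a `const` member, yet it must lock while it mutates the (also `mutable`) child list.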
paddle/fluid/operators/detail/grpc_server.cc
@@ -145,23 +145,28 @@ class RequestPrefetch final : public RequestBase {
         executor_(executor),
         program_(program),
         blkid_(blkid) {
+    request_.reset(new VariableResponse(scope, dev_ctx_));
     int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
-    service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
-                                cq_, this);
+    service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
+                                cq_, cq_, this);
   }
 
   virtual ~RequestPrefetch() {}
 
-  virtual std::string GetReqName() { return request_.varname(); }
+  virtual std::string GetReqName() { return request_->Varname(); }
 
   virtual void Process() {
     // prefetch process...
     ::grpc::ByteBuffer reply;
 
-    executor_->Run(*program_, scope_, blkid_, false, false);
-
-    std::string var_name = request_.out_varname();
-    auto* var = scope_->FindVar(var_name);
+    std::string var_name = request_->OutVarname();
+    auto var_desc = program_->Block(0).FindVar(var_name);
+    framework::Scope* local_scope = &scope_->NewScope();
+    auto* var = local_scope->FindVar(var_name);
+    InitializeVariable(var, var_desc->GetType());
+    executor_->Run(*program_, local_scope, blkid_, false, false);
+
     SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
 
     responder_.Finish(reply, ::grpc::Status::OK, this);
@@ -169,7 +174,7 @@ class RequestPrefetch final : public RequestBase {
   }
 
  protected:
-  sendrecv::VariableMessage request_;
+  std::shared_ptr<VariableResponse> request_;
   ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
   framework::Scope* scope_;
   framework::Executor* executor_;
...
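
The reworked `Process()` no longer runs the prefetch block on the shared server scope. It reads the requested output name from the message, creates a per-request child scope, initializes the output variable there, runs the block, and serializes the result from that local scope, so concurrent prefetch requests cannot overwrite each other's output. A simplified, self-contained sketch of that flow (hypothetical stand-in types, not the real gRPC handler):

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Var {
  std::vector<float> data;
};

struct Scope {
  std::map<std::string, Var> vars;
  std::vector<std::unique_ptr<Scope>> kids;
  Scope& NewScope() {
    kids.emplace_back(new Scope());
    return *kids.back();
  }
};

// Stand-in for executor_->Run(*program_, local_scope, blkid_, false, false).
void RunPrefetchBlock(Scope* local_scope, const std::string& out_name) {
  local_scope->vars[out_name].data = {1.f, 2.f, 3.f};
}

std::vector<float> HandlePrefetch(Scope* global_scope,
                                  const std::string& out_varname) {
  Scope& local_scope = global_scope->NewScope();  // per-request child scope
  local_scope.vars[out_varname] = Var{};          // InitializeVariable(...)
  RunPrefetchBlock(&local_scope, out_varname);    // executor_->Run(...)
  return local_scope.vars[out_varname].data;      // SerializeToByteBuffer(...)
}

int main() {
  Scope global;
  auto reply = HandlePrefetch(&global, "out");
  std::cout << reply.size() << " floats prefetched\n";
}
```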
paddle/fluid/operators/detail/grpc_server_test.cc
@@ -14,12 +14,13 @@ limitations under the License. */
 #include <unistd.h>
 #include <string>
-#include <thread>
+#include <thread>  // NOLINT
 
 #include "gtest/gtest.h"
 #include "paddle/fluid/operators/detail/grpc_client.h"
 #include "paddle/fluid/operators/detail/grpc_server.h"
 
+#include "paddle/fluid/framework/block_desc.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -31,9 +32,9 @@ USE_OP(lookup_table);
 
 std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
 
-framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) {
-  const auto &root_block = program.Block(0);
-  auto *block= program.AppendBlock(root_block);
+framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
+  auto root_block = program->MutableBlock(0);
+  auto* block = program->AppendBlock(*root_block);
 
   framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
   framework::VariableNameMap output({{"Output", {"out"}}});
@@ -42,32 +43,48 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) {
   op->SetInput("W", {"w"});
   op->SetInput("Ids", {"ids"});
   op->SetOutput("Out", {"out"});
+
+  auto& out = *root_block->Var("out");
+  out.SetType(framework::proto::VarType::LOD_TENSOR);
+  out.SetShape({10, 10});
+
   return block;
 }
 
-void InitTensorsInScope(framework::Scope &scope, platform::CPUPlace &place) {
-  auto w_var = scope.Var("w");
+void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
+  auto w_var = scope->Var("w");
   auto w = w_var->GetMutable<framework::LoDTensor>();
   w->Resize({10, 10});
-  float *ptr = w->mutable_data<float>(place);
-  for (int64_t i = 0; i < w->numel(); ++i) {
-    ptr[i] = static_cast<float>(i/10);
-  }
+  w->mutable_data<float>(*place);
 
-  auto out_var = scope.Var("out");
+  auto out_var = scope->Var("out");
   auto out = out_var->GetMutable<framework::LoDTensor>();
   out->Resize({5, 10});
-  out->mutable_data<float>(place);
+  out->mutable_data<float>(*place);
 
-  auto ids_var = scope.Var("ids");
+  auto ids_var = scope->Var("ids");
   auto ids = ids_var->GetMutable<framework::LoDTensor>();
   ids->Resize({5, 1});
-  auto ids_ptr = ids->mutable_data<int64_t>(place);
+}
+
+void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place) {
+  CreateVarsOnScope(scope, place);
+  auto ids = scope->Var("ids")->GetMutable<framework::LoDTensor>();
+  auto ptr = ids->mutable_data<int64_t>(*place);
   for (int64_t i = 0; i < ids->numel(); ++i) {
-    ids_ptr[i] = i * 2;
+    ptr[i] = i * 2;
   }
 }
+
+void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place) {
+  CreateVarsOnScope(scope, place);
+  auto w_var = scope->Var("w");
+  auto w = w_var->GetMutable<framework::LoDTensor>();
+  auto ptr = w->mutable_data<float>(*place);
+  for (int64_t i = 0; i < w->numel(); ++i) {
+    ptr[i] = static_cast<float>(i / 10);
+  }
+}
 
 void StartServer(const std::string& endpoint) {
   rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
@@ -76,8 +93,8 @@ void StartServer(const std::string& endpoint) {
   platform::CPUPlace place;
   framework::Executor exe(place);
   platform::CPUDeviceContext ctx(place);
-  auto* block = AppendPrefetchBlcok(program);
-  InitTensorsInScope(scope, place);
+  auto* block = AppendPrefetchBlcok(&program);
+  InitTensorsOnServer(&scope, &place);
   rpc_service_->SetProgram(&program);
   rpc_service_->SetPrefetchBlkdId(block->ID());
@@ -88,22 +105,20 @@ void StartServer(const std::string& endpoint) {
   rpc_service_->RunSyncUpdate();
 }
 
 TEST(PREFETCH, CPU) {
   // start up a server instance backend
   // TODO(Yancey1989): Need to start a server with optimize blocks and
   // prefetch blocks.
   std::thread server_thread(StartServer, "127.0.0.1:8889");
-  sleep(3);
+  sleep(2);
   framework::Scope scope;
   platform::CPUPlace place;
   platform::CPUDeviceContext ctx(place);
   // create var on local scope
-  InitTensorsInScope(scope, place);
+  InitTensorsOnClient(&scope, &place);
   std::string in_var_name("ids");
   std::string out_var_name("out");
 
   detail::RPCClient client;
   client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
                                out_var_name);
@@ -111,6 +126,7 @@ TEST(PREFETCH, CPU) {
   auto out_var = scope.Var(out_var_name);
   auto out = out_var->Get<framework::LoDTensor>();
   auto out_ptr = out.data<float>();
 
   rpc_service_->ShutDown();
   server_thread.join();
...
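
The diff truncates before the test's assertions. Given the setup above (`w` row `r` filled with the value `r`, `ids = {0, 2, 4, 6, 8}`, and `lookup_table` copying row `ids[i]` of `w` into row `i` of `out`), a plausible check right after `out_ptr` is obtained would be the fragment below; this is an illustration of the expected values, not necessarily the commit's actual test code:

```cpp
// Row i of out should equal row ids[i] of w, i.e. be filled with 2 * i.
for (int64_t i = 0; i < 5; ++i) {
  for (int64_t j = 0; j < 10; ++j) {
    EXPECT_EQ(out_ptr[i * 10 + j], static_cast<float>(i * 2));
  }
}
```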
paddle/fluid/operators/detail/variable_response.cc
@@ -108,7 +108,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
 bool VariableResponse::CopyLodTensorData(
     ::google::protobuf::io::CodedInputStream* input,
-    const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
+    const platform::DeviceContext& ctx, const framework::DDim& dims,
+    int length) {
   auto var = scope_->FindVar(meta_.varname());
   auto* tensor = var->GetMutable<framework::LoDTensor>();
   tensor->Resize(dims);
@@ -144,7 +145,8 @@ inline framework::DDim GetDims(
 bool VariableResponse::CopySelectRowsTensorData(
     ::google::protobuf::io::CodedInputStream* input,
-    const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
+    const platform::DeviceContext& ctx, const framework::DDim& dims,
+    int length) {
   auto var = scope_->FindVar(meta_.varname());
   auto* slr = var->GetMutable<framework::SelectedRows>();
   slr->set_height(meta_.slr_height());
@@ -410,6 +412,20 @@ int VariableResponse::Parse(Source* source) {
         }
         break;
       }
+      case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
+        uint32_t length;
+        if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
+          return tag;
+        }
+
+        std::string temp;
+        if (!input.ReadString(&temp, length)) {
+          return tag;
+        }
+
+        meta_.set_out_varname(temp);
+        break;
+      }
       default: {
         // Unknown tag, return unknown error.
...
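
The new `kOutVarnameFieldNumber` case parses the field by hand from the wire: a length-delimited protobuf field is a varint length followed by the raw bytes. A standalone sketch of the same pattern using the protobuf I/O API (the field number 9 here is illustrative, not taken from the actual `.proto`):

```cpp
#include <cstdint>
#include <iostream>
#include <string>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/wire_format_lite.h>

using google::protobuf::internal::WireFormatLite;
using google::protobuf::io::CodedInputStream;

int main() {
  // Hand-encoded message: field 9, wire type 2 (length-delimited),
  // tag byte = (9 << 3) | 2 = 0x4A, length 3, payload "out".
  const uint8_t buf[] = {0x4A, 0x03, 'o', 'u', 't'};
  CodedInputStream input(buf, sizeof(buf));

  uint32_t tag = input.ReadTag();
  if (WireFormatLite::GetTagWireType(tag) !=
      WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
    return 1;  // unexpected wire type, like the `return tag` path above
  }
  uint32_t length = 0;
  if (!input.ReadVarint32(&length)) return 1;  // read the payload length
  std::string out_varname;
  if (!input.ReadString(&out_varname, length)) return 1;  // read the bytes
  std::cout << "out_varname = " << out_varname << "\n";   // prints "out"
}
```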
paddle/fluid/operators/detail/variable_response.h
@@ -14,6 +14,10 @@
 #pragma once
 
+#include <string>
+#include <utility>
+#include <vector>
+
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
@@ -53,6 +57,7 @@ class VariableResponse {
   int Parse(const ::grpc::ByteBuffer& byte_buffer);
 
   inline std::string Varname() { return meta_.varname(); }
+  inline std::string OutVarname() { return meta_.out_varname(); }
 
   // should call parse first.
   framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
@@ -60,14 +65,14 @@ class VariableResponse {
  private:
   bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
                                 const platform::DeviceContext& ctx,
-                                framework::DDim& dims, int length);
+                                const framework::DDim& dims, int length);
   bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
                           const platform::DeviceContext& ctx, int length);
   bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
                          const platform::DeviceContext& ctx,
-                         framework::DDim& dims, int length);
+                         const framework::DDim& dims, int length);
 
  private:
   const framework::Scope* scope_;
...