Commit 18427581 authored by Yancey1989

Run the prefetch program on a new scope

Parent 0cafe390
@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include <set>
#include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h"
@@ -39,6 +38,7 @@ Scope::~Scope() {
}
Scope& Scope::NewScope() const {
std::unique_lock<std::mutex> lock(mutex_);
kids_.push_back(new Scope(this));
return *kids_.back();
}
@@ -92,6 +92,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
}
void Scope::DeleteScope(Scope* scope) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);
this->kids_.erase(it);
@@ -103,7 +104,7 @@ void Scope::DeleteScope(Scope* scope) {
}
}
void Scope::EraseVars(std::vector<std::string>& var_names) {
void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) {
......
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <list>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <vector>
@@ -51,7 +52,7 @@ class Scope {
/// Create a variable with a scope-unique name.
Variable* Var(std::string* name = nullptr);
void EraseVars(std::vector<std::string>& var_names);
void EraseVars(const std::vector<std::string>& var_names);
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
@@ -88,6 +89,9 @@ class Scope {
Scope const* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope);
private:
mutable std::mutex mutex_;
};
} // namespace framework
} // namespace paddle
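Taken together, the scope.cc and scope.h hunks above make child-scope bookkeeping thread-safe: NewScope() and DeleteScope() both take the new mutex_ before touching kids_, so concurrent prefetch requests can create and destroy per-request scopes without racing. A minimal standalone sketch of that locking pattern (member names mirror the diff; the constructor, destructor, and error handling are simplified placeholders, not the real class):

#include <algorithm>
#include <list>
#include <mutex>

class Scope {
 public:
  Scope() = default;
  ~Scope() {
    for (Scope* kid : kids_) delete kid;
  }

  // Both mutators serialize on mutex_ so concurrent RPC handlers
  // cannot corrupt the kids_ list.
  Scope& NewScope() const {
    std::unique_lock<std::mutex> lock(mutex_);
    kids_.push_back(new Scope(this));
    return *kids_.back();
  }

  void DeleteScope(Scope* scope) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = std::find(kids_.begin(), kids_.end(), scope);
    if (it != kids_.end()) {  // the real code enforces this with PADDLE_ENFORCE
      kids_.erase(it);
      delete scope;
    }
  }

 private:
  explicit Scope(const Scope* parent) : parent_(parent) {}
  const Scope* parent_{nullptr};
  mutable std::list<Scope*> kids_;
  mutable std::mutex mutex_;
};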
@@ -145,23 +145,28 @@ class RequestPrefetch final : public RequestBase {
executor_(executor),
program_(program),
blkid_(blkid) {
request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_,
cq_, this);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
cq_, cq_, this);
}
virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_.varname(); }
virtual std::string GetReqName() { return request_->Varname(); }
virtual void Process() {
// prefetch process...
::grpc::ByteBuffer reply;
executor_->Run(*program_, scope_, blkid_, false, false);
std::string var_name = request_->OutVarname();
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->Run(*program_, local_scope, blkid_, false, false);
std::string var_name = request_.out_varname();
auto* var = scope_->FindVar(var_name);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
responder_.Finish(reply, ::grpc::Status::OK, this);
@@ -169,7 +174,7 @@ class RequestPrefetch final : public RequestBase {
}
protected:
sendrecv::VariableMessage request_;
std::shared_ptr<VariableResponse> request_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::Executor* executor_;
......
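The grpc_server.cc hunks interleave removed and added lines; reassembled in order, the new Process() body runs the prefetch block on a fresh child scope instead of the shared server scope. The same statements as above, with comments added (the identifiers are all taken from the diff):

virtual void Process() {
  ::grpc::ByteBuffer reply;
  std::string var_name = request_->OutVarname();         // output var requested by the client
  auto var_desc = program_->Block(0).FindVar(var_name);  // its descriptor in the root block
  framework::Scope* local_scope = &scope_->NewScope();   // per-request child scope
  auto* var = local_scope->FindVar(var_name);            // resolves through ancestor scopes
  InitializeVariable(var, var_desc->GetType());          // ensure the output var has the right type
  executor_->Run(*program_, local_scope, blkid_, false, false);  // run only the prefetch block
  SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
  responder_.Finish(reply, ::grpc::Status::OK, this);
}

Running on a child scope keeps variables created by the prefetch block out of the long-lived server scope, which is the apparent motivation named in the commit title.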
@@ -14,12 +14,13 @@ limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include <thread> // NOLINT
#include "gtest/gtest.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -31,9 +32,9 @@ USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) {
const auto &root_block = program.Block(0);
auto *block= program.AppendBlock(root_block);
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
@@ -42,32 +43,48 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc& program) {
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({10, 10});
return block;
}
void InitTensorsInScope(framework::Scope &scope, platform::CPUPlace &place) {
auto w_var = scope.Var("w");
void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto w_var = scope->Var("w");
auto w = w_var->GetMutable<framework::LoDTensor>();
w->Resize({10, 10});
float *ptr = w->mutable_data<float>(place);
for (int64_t i = 0; i < w->numel(); ++i) {
ptr[i] = static_cast<float>(i/10);
}
w->mutable_data<float>(*place);
auto out_var = scope.Var("out");
auto out_var = scope->Var("out");
auto out = out_var->GetMutable<framework::LoDTensor>();
out->Resize({5, 10});
out->mutable_data<float>(place);
out->mutable_data<float>(*place);
auto ids_var = scope.Var("ids");
auto ids_var = scope->Var("ids");
auto ids = ids_var->GetMutable<framework::LoDTensor>();
ids->Resize({5, 1});
auto ids_ptr = ids->mutable_data<int64_t>(place);
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place) {
CreateVarsOnScope(scope, place);
auto ids = scope->Var("ids")->GetMutable<framework::LoDTensor>();
auto ptr = ids->mutable_data<int64_t>(*place);
for (int64_t i = 0; i < ids->numel(); ++i) {
ids_ptr[i] = i * 2;
ptr[i] = i * 2;
}
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place) {
CreateVarsOnScope(scope, place);
auto w_var = scope->Var("w");
auto w = w_var->GetMutable<framework::LoDTensor>();
auto ptr = w->mutable_data<float>(*place);
for (int64_t i = 0; i < w->numel(); ++i) {
ptr[i] = static_cast<float>(i / 10);
}
}
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
@@ -76,8 +93,8 @@ void StartServer(const std::string& endpoint) {
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(program);
InitTensorsInScope(scope, place);
auto* block = AppendPrefetchBlcok(&program);
InitTensorsOnServer(&scope, &place);
rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchBlkdId(block->ID());
@@ -88,22 +105,20 @@ void StartServer(const std::string& endpoint) {
rpc_service_->RunSyncUpdate();
}
TEST(PREFETCH, CPU) {
// start up a server instance backend
// TODO(Yancey1989): Need to start a server with optimize blocks and
// prefetch blocks.
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(3);
sleep(2);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
InitTensorsInScope(scope, place);
InitTensorsOnClient(&scope, &place);
std::string in_var_name("ids");
std::string out_var_name("out");
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
@@ -111,6 +126,7 @@ TEST(PREFETCH, CPU) {
auto out_var = scope.Var(out_var_name);
auto out = out_var->Get<framework::LoDTensor>();
auto out_ptr = out.data<float>();
rpc_service_->ShutDown();
server_thread.join();
......
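The assertions at the end of TEST(PREFETCH, CPU) are elided above. Given the initialization helpers (w's element i is i / 10, so row r of the 10×10 table holds the value r, and ids is {0, 2, 4, 6, 8}), a lookup_table prefetch should fill row j of out with the value 2·j. A hypothetical sketch of such a check, not the actual elided test body:

// Hypothetical assertions; the real test's checks are truncated above.
for (int64_t j = 0; j < 5; ++j) {
  for (int64_t k = 0; k < 10; ++k) {
    EXPECT_EQ(out_ptr[j * 10 + k], static_cast<float>(j * 2));
  }
}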
@@ -108,7 +108,8 @@ bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
bool VariableResponse::CopyLodTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname());
auto* tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(dims);
@@ -144,7 +145,8 @@ inline framework::DDim GetDims(
bool VariableResponse::CopySelectRowsTensorData(
::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, framework::DDim& dims, int length) {
const platform::DeviceContext& ctx, const framework::DDim& dims,
int length) {
auto var = scope_->FindVar(meta_.varname());
auto* slr = var->GetMutable<framework::SelectedRows>();
slr->set_height(meta_.slr_height());
@@ -410,6 +412,20 @@ int VariableResponse::Parse(Source* source) {
}
break;
}
case sendrecv::VariableMessage::kOutVarnameFieldNumber: {
uint32_t length;
if ((wt != WIRETYPE_LENGTH_DELIMITED) || !input.ReadVarint32(&length)) {
return tag;
}
std::string temp;
if (!input.ReadString(&temp, length)) {
return tag;
}
meta_.set_out_varname(temp);
break;
}
default: {
// Unknown tag, return unknown error.
......
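The new kOutVarnameFieldNumber case follows the protobuf wire format for a length-delimited field: a varint byte count followed by that many payload bytes. A minimal standalone sketch of the same read pattern (the function name is illustrative; ReadVarint32 and ReadString are real CodedInputStream methods):

#include <google/protobuf/io/coded_stream.h>
#include <string>

// Reads one length-delimited string payload from `input` into `out`;
// returns false if the stream is truncated or malformed.
bool ReadLengthDelimitedString(
    ::google::protobuf::io::CodedInputStream* input, std::string* out) {
  uint32_t length = 0;
  if (!input->ReadVarint32(&length)) return false;  // payload size
  return input->ReadString(out, length);            // payload bytes
}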
@@ -14,6 +14,10 @@
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
@@ -53,6 +57,7 @@ class VariableResponse {
int Parse(const ::grpc::ByteBuffer& byte_buffer);
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
// should call parse first.
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
@@ -60,14 +65,14 @@
private:
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
const framework::DDim& dims, int length);
bool CopySelectRowsData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx, int length);
bool CopyLodTensorData(::google::protobuf::io::CodedInputStream* input,
const platform::DeviceContext& ctx,
framework::DDim& dims, int length);
const framework::DDim& dims, int length);
private:
const framework::Scope* scope_;
......