提交 f132f51e 编写于 作者: Y Yancey1989

prepare prefetch context

上级 46989663
...@@ -138,13 +138,14 @@ class RequestPrefetch final : public RequestBase { ...@@ -138,13 +138,14 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope, framework::Scope* scope,
const platform::DeviceContext* dev_ctx, const platform::DeviceContext* dev_ctx,
framework::Executor* executor, framework::Executor* executor,
framework::ProgramDesc* program, int blkid) framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx)
: RequestBase(service, cq, dev_ctx), : RequestBase(service, cq, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), scope_(scope),
executor_(executor), executor_(executor),
program_(program), program_(program),
blkid_(blkid) { prefetch_ctx_(prefetch_ctx) {
request_.reset(new VariableResponse(scope, dev_ctx_)); request_.reset(new VariableResponse(scope, dev_ctx_));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
...@@ -164,8 +165,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -164,8 +165,7 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope = &scope_->NewScope(); framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name); auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType()); InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
executor_->Run(*program_, local_scope, blkid_, false, false);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
...@@ -179,6 +179,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -179,6 +179,7 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* scope_; framework::Scope* scope_;
framework::Executor* executor_; framework::Executor* executor_;
framework::ProgramDesc* program_; framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int blkid_; int blkid_;
}; };
...@@ -276,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { ...@@ -276,7 +277,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
} }
RequestPrefetch* prefetch = RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_,
executor_, program_, prefetch_blk_id_); executor_, program_, prefetch_ctx_);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
} }
......
...@@ -63,6 +63,10 @@ class AsyncGRPCServer final { ...@@ -63,6 +63,10 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; } void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
prefetch_ctx_ = prepared;
}
int GetSelectedPort() { return selected_port_; } int GetSelectedPort() { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); } const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
...@@ -111,6 +115,7 @@ class AsyncGRPCServer final { ...@@ -111,6 +115,7 @@ class AsyncGRPCServer final {
std::unique_ptr<std::thread> t_prefetch_; std::unique_ptr<std::thread> t_prefetch_;
int prefetch_blk_id_; int prefetch_blk_id_;
framework::ExecutorPrepareContext *prefetch_ctx_;
framework::ProgramDesc *program_; framework::ProgramDesc *program_;
framework::Executor *executor_; framework::Executor *executor_;
int selected_port_; int selected_port_;
......
...@@ -96,10 +96,11 @@ void StartServer(const std::string& endpoint) { ...@@ -96,10 +96,11 @@ void StartServer(const std::string& endpoint) {
framework::Executor exe(place); framework::Executor exe(place);
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program); auto* block = AppendPrefetchBlcok(&program);
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10); InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program); rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchBlkdId(block->ID()); rpc_service_->SetPrefetchPreparedCtx(prepared.get());
rpc_service_->SetDevCtx(&ctx); rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope); rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe); rpc_service_->SetExecutor(&exe);
...@@ -125,7 +126,6 @@ TEST(PREFETCH, CPU) { ...@@ -125,7 +126,6 @@ TEST(PREFETCH, CPU) {
out_var_name); out_var_name);
client.Wait(); client.Wait();
// auto out_var = scope.Var(out_var_name);
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place); auto ptr = value.mutable_data<float>(place);
......
...@@ -21,7 +21,7 @@ service SendRecvService { ...@@ -21,7 +21,7 @@ service SendRecvService {
rpc SendVariable(VariableMessage) returns (VoidMessage) {} rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname. // Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// Prefetch variable by Ids // Look up table block execution output variable name.
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {} rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册