diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 091b63bf0f907a5449f08f0e36abb6577fa5e43e..b49c61449984f51d65963958c87191b0799bcf5b 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor, desc.data_type(), DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); is.read(static_cast(buf), cpu_tensor.memory_size()); - auto cpu_place = new platform::CPUPlace(); - framework::Copy(cpu_tensor, *cpu_place, dev_ctx, tensor); - delete cpu_place; + auto dst_place = dev_ctx.GetPlace(); + framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor); #else PADDLE_THROW("Unexpected branch"); #endif diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index ac4bb5cb8227b89a2f387801e4f876b09eabbff3..c0b94746a0b7f6ffb657bbf5af18360426933858 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -79,8 +79,12 @@ class RequestSend final : public RequestBase { class RequestGet final : public RequestBase { public: explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq, framework::Scope* scope) - : RequestBase(service, cq), responder_(&ctx_), scope_(scope) { + grpc::ServerCompletionQueue* cq, framework::Scope* scope, + const platform::DeviceContext* dev_ctx) + : RequestBase(service, cq), + responder_(&ctx_), + scope_(scope), + dev_ctx_(dev_ctx) { service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); } @@ -92,7 +96,7 @@ class RequestGet final : public RequestBase { // proc request. std::string var_name = request_.varname(); auto* var = scope_->FindVar(var_name); - SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); + SerializeToMessage(var_name, var, *dev_ctx_, &reply_); // TODO(gongwb): check var's info. 
responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; @@ -103,6 +107,7 @@ class RequestGet final : public RequestBase { sendrecv::VariableMessage reply_; ServerAsyncResponseWriter responder_; framework::Scope* scope_; + const platform::DeviceContext* dev_ctx_; }; void AsyncGRPCServer::RunSyncUpdate() { @@ -165,7 +170,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { if (is_shut_down_) { return; } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_); + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); VLOG(4) << "create Requestget status:" << get->Status(); } diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 694e18ef49818f2d22789748d930d041de4f3586..2c078b77771656dc7fc0342ecf21b8d33dc11817 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -37,7 +37,7 @@ class RequestBase; class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { public: - explicit AsyncGRPCServer(std::string address) { address_ = address; } + explicit AsyncGRPCServer(const std::string &address) : address_(address) {} void RunSyncUpdate(); @@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void SetScope(framework::Scope *scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } + const MessageWithName Get() { return this->var_recv_queue_.Pop(); } void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } @@ -73,6 +75,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { std::string address_; framework::Scope *scope_; + const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. 
SimpleBlockQueue var_recv_queue_; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index cf69c12b6864b4aaea82ecf7e9eb38e44d61a215..f9ed7516826319da422fbb0af4e5c277afa7ae40 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -87,7 +87,12 @@ class RecvOp : public framework::OperatorBase { const platform::Place &dev_place) const override { // FIXME(typhoonzero): no new scopes for every run. framework::Scope &recv_scope = scope.NewScope(); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + + // FIXME(Yancey1989): initialize rpc server with lazy mode. rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); auto trainer_count = Attr("Trainers"); @@ -136,9 +141,6 @@ class RecvOp : public framework::OperatorBase { } auto *var = recv_scope.Var(grad_var_name); - platform::DeviceContextPool &pool = - platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); detail::DeserializeFromMessage(v.second, dev_ctx, var); } diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 203000f5aaff1a89ac6119b5a9b774f8d48c7c76..7c81a9524d6609a65b3167d95053bf4e85eef0db 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -33,13 +33,13 @@ class SendOp : public framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void Run(const framework::Scope& scope, - const platform::Place& dev_place) const override { + const platform::Place& place) const override { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); - // FIXME(gongwb): DeviceContext? 
- auto ctx = platform::CPUDeviceContext(); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); for (size_t i = 0; i < ins.size(); i++) { client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); }