提交 9de18095 编写于 作者: Y Yancey1989

fluid distributed on CUDA place

上级 cb6b468e
...@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor, ...@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
desc.data_type(), desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace())); DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), cpu_tensor.memory_size()); is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
auto cpu_place = new platform::CPUPlace(); auto dst_place = dev_ctx.GetPlace();
framework::Copy(cpu_tensor, *cpu_place, dev_ctx, tensor); framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor);
delete cpu_place;
#else #else
PADDLE_THROW("Unexpected branch"); PADDLE_THROW("Unexpected branch");
#endif #endif
......
...@@ -74,8 +74,12 @@ class RequestSend final : public RequestBase { ...@@ -74,8 +74,12 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope) grpc::ServerCompletionQueue* cq, framework::Scope* scope,
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) { const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
} }
...@@ -85,7 +89,7 @@ class RequestGet final : public RequestBase { ...@@ -85,7 +89,7 @@ class RequestGet final : public RequestBase {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name); auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_); SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
// TODO(gongwb): check var's info. // TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this); responder_.Finish(reply_, grpc::Status::OK, this);
} }
...@@ -95,6 +99,7 @@ class RequestGet final : public RequestBase { ...@@ -95,6 +99,7 @@ class RequestGet final : public RequestBase {
sendrecv::VariableMessage reply_; sendrecv::VariableMessage reply_;
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_; ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_; framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
}; };
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::RunSyncUpdate() {
...@@ -155,7 +160,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { ...@@ -155,7 +160,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) { if (is_shut_down_) {
return; return;
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_); RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
VLOG(4) << "create Requestget status:" << get->Status(); VLOG(4) << "create Requestget status:" << get->Status();
} }
......
...@@ -37,7 +37,7 @@ class RequestBase; ...@@ -37,7 +37,7 @@ class RequestBase;
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
public: public:
explicit AsyncGRPCServer(std::string address) { address_ = address; } explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
void RunSyncUpdate(); void RunSyncUpdate();
...@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void SetScope(framework::Scope *scope) { scope_ = scope; } void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
const MessageWithName Get() { return this->var_recv_queue_.Pop(); } const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
...@@ -74,6 +76,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { ...@@ -74,6 +76,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::string address_; std::string address_;
framework::Scope *scope_; framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_; SimpleBlockQueue<MessageWithName> var_recv_queue_;
......
...@@ -87,7 +87,11 @@ class RecvOp : public framework::OperatorBase { ...@@ -87,7 +87,11 @@ class RecvOp : public framework::OperatorBase {
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
// FIXME(typhoonzero): no new scopes for every run. // FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
rpc_service_->SetScope(&recv_scope); rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList"); auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers"); auto trainer_count = Attr<int>("Trainers");
...@@ -134,9 +138,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -134,9 +138,6 @@ class RecvOp : public framework::OperatorBase {
} }
auto *var = recv_scope.Var(grad_var_name); auto *var = recv_scope.Var(grad_var_name);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
detail::DeserializeFromMessage(v.second, dev_ctx, var); detail::DeserializeFromMessage(v.second, dev_ctx, var);
} }
......
...@@ -33,13 +33,15 @@ class SendOp : public framework::OperatorBase { ...@@ -33,13 +33,15 @@ class SendOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope, void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override { const platform::Place& place) const override {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
// FIXME(gongwb): DeviceContext? // FIXME(gongwb): DeviceContext?
auto ctx = platform::CPUDeviceContext(); // auto ctx = platform::CPUDeviceContext();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册