提交 9de18095 编写于 作者: Y Yancey1989

fluid distributed on CUDA place

上级 cb6b468e
......@@ -315,9 +315,8 @@ inline void DeserializeFromStream(std::istream& is, Tensor* tensor,
desc.data_type(),
DeserializedDataFunctor(&buf, &cpu_tensor, ctx.GetPlace()));
is.read(static_cast<char*>(buf), cpu_tensor.memory_size());
auto cpu_place = new platform::CPUPlace();
framework::Copy(cpu_tensor, *cpu_place, dev_ctx, tensor);
delete cpu_place;
auto dst_place = dev_ctx.GetPlace();
framework::Copy(cpu_tensor, dst_place, dev_ctx, tensor);
#else
PADDLE_THROW("Unexpected branch");
#endif
......
......@@ -74,8 +74,12 @@ class RequestSend final : public RequestBase {
class RequestGet final : public RequestBase {
public:
explicit RequestGet(sendrecv::SendRecvService::AsyncService* service,
grpc::ServerCompletionQueue* cq, framework::Scope* scope)
: RequestBase(service, cq), responder_(&ctx_), scope_(scope) {
grpc::ServerCompletionQueue* cq, framework::Scope* scope,
const platform::DeviceContext* dev_ctx)
: RequestBase(service, cq),
responder_(&ctx_),
scope_(scope),
dev_ctx_(dev_ctx) {
service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this);
}
......@@ -85,7 +89,7 @@ class RequestGet final : public RequestBase {
// Process the request.
std::string var_name = request_.varname();
auto* var = scope_->FindVar(var_name);
SerializeToMessage(var_name, var, platform::CPUDeviceContext(), &reply_);
SerializeToMessage(var_name, var, *dev_ctx_, &reply_);
// TODO(gongwb): check var's info.
responder_.Finish(reply_, grpc::Status::OK, this);
}
......@@ -95,6 +99,7 @@ class RequestGet final : public RequestBase {
sendrecv::VariableMessage reply_;
ServerAsyncResponseWriter<sendrecv::VariableMessage> responder_;
framework::Scope* scope_;
const platform::DeviceContext* dev_ctx_;
};
void AsyncGRPCServer::RunSyncUpdate() {
......@@ -155,7 +160,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
if (is_shut_down_) {
return;
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_);
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_);
VLOG(4) << "create Requestget status:" << get->Status();
}
......
......@@ -37,7 +37,7 @@ class RequestBase;
class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
public:
explicit AsyncGRPCServer(std::string address) { address_ = address; }
explicit AsyncGRPCServer(const std::string &address) : address_(address) {}
void RunSyncUpdate();
......@@ -47,6 +47,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
// Stores the scope pointer on this server. The pointer is handed to each
// RequestGet the server creates, which looks requested variables up in it by
// name.
void SetScope(framework::Scope *scope) { scope_ = scope; }
// Stores the device context pointer on this server. Only the pointer is kept
// (the context itself is not copied); it is handed to each RequestGet the
// server creates, which dereferences it when serializing the requested
// variable into the reply.
// NOTE(review): dev_ctx_ has no visible initializer, so this presumably must
// be called before any Get request is registered — confirm.
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
// Pops one message from the received-variable queue and returns it.
// NOTE(review): whether Pop() blocks while the queue is empty is not visible
// here — confirm against SimpleBlockQueue.
const MessageWithName Get() { return this->var_recv_queue_.Pop(); }
// Appends msg to the received-variable queue, from which Get() pops messages.
void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }
......@@ -74,6 +76,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
std::string address_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// Variables received over RPC; operators fetch them from this queue.
SimpleBlockQueue<MessageWithName> var_recv_queue_;
......
......@@ -87,7 +87,11 @@ class RecvOp : public framework::OperatorBase {
const platform::Place &dev_place) const override {
// FIXME(typhoonzero): avoid creating a new scope on every run.
framework::Scope &recv_scope = scope.NewScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
......@@ -134,9 +138,6 @@ class RecvOp : public framework::OperatorBase {
}
auto *var = recv_scope.Var(grad_var_name);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
......
......@@ -33,13 +33,15 @@ class SendOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const platform::Place& place) const override {
auto ins = Inputs("X");
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
// FIXME(gongwb): DeviceContext?
auto ctx = platform::CPUDeviceContext();
// auto ctx = platform::CPUDeviceContext();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册