diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 83aa927c293676c3800ed945c175e4f3dc5629d6..cc3916e7bb6d5d0ca906222e22023ba8a25cce4a 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -85,7 +85,7 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) { } void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, - bool create_local_scope) { + bool create_local_scope, bool create_vars) { // TODO(tonyyang-svail): // - only runs on the first device (i.e. no interdevice communication) // - will change to use multiple blocks for RNN op and Cond Op @@ -94,33 +94,35 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, auto& device = device_contexts_[0]; Scope* local_scope = scope; - if (create_local_scope) { - local_scope = &scope->NewScope(); - for (auto& var : block.AllVars()) { - if (var->Name() == framework::kEmptyVarName) { - continue; + if (create_vars) { + if (create_local_scope) { + local_scope = &scope->NewScope(); + for (auto& var : block.AllVars()) { + if (var->Name() == framework::kEmptyVarName) { + continue; + } + + if (var->Persistable()) { + auto* ptr = scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + VLOG(3) << "Create Variable " << var->Name() + << " global, which pointer is " << ptr; + } else { + auto* ptr = local_scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + VLOG(3) << "Create Variable " << var->Name() + << " locally, which pointer is " << ptr; + } } - - if (var->Persistable()) { - auto* ptr = scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create Variable " << var->Name() - << " global, which pointer is " << ptr; - } else { + } else { + for (auto& var : block.AllVars()) { auto* ptr = local_scope->Var(var->Name()); CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create Variable " << var->Name() - << " locally, which pointer is " << ptr; + VLOG(3) << "Create variable " << var->Name() << ", which pointer is " + << ptr; } - } - } else { - for (auto& var : block.AllVars()) { - auto* ptr = local_scope->Var(var->Name()); - CreateTensor(ptr, var->GetType()); - VLOG(3) << "Create variable " << var->Name() << ", which pointer is " - << ptr; - } - } + } // if (create_local_scope) + } // if (create_vars) for (auto& op_desc : block.AllOps()) { auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index b745f4f6474ef688774f4c833a3958942e9aa8cb..28da0608300ca18b48a84974c29458d7e2f49918 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -35,7 +35,8 @@ class Executor { * ProgramDesc * Scope */ - void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true); + void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true, + bool create_vars = true); private: std::vector device_contexts_; diff --git a/paddle/operators/detail/recv_impl.cc b/paddle/operators/detail/recv_impl.cc index dab3d1e14c81c9f0f91258b99e51ba1428ac2ed0..bc930cbb007b73b6bbf9d4e999ee6c4388c8d0f8 100644 --- a/paddle/operators/detail/recv_impl.cc +++ b/paddle/operators/detail/recv_impl.cc @@ -20,7 +20,7 @@ namespace detail { Status SendRecvServerImpl::SendVariable(ServerContext *context, const VariableMessage *in_var, - VariableMessage *out_var) { + VoidMessage *out_var) { // TODO(typhoonzero): support different variable types. std::istringstream iss(in_var->serialized()); framework::LoDTensor t; @@ -29,6 +29,12 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context, std::make_pair(in_var->varname(), std::move(t)); var_recv_queue_.Push(std::move(tensor_with_name)); + return Status::OK; +} + +Status SendRecvServerImpl::GetVariable(ServerContext *context, + const VoidMessage *in_var, + VariableMessage *out_var) { // Block util the sub graph is done. auto out_tensor_with_name = var_return_queue_.Pop(); std::ostringstream oss; @@ -36,10 +42,9 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context, platform::CPUDeviceContext()); std::string *varname = out_var->mutable_varname(); - *varname = in_var->varname(); + *varname = out_tensor_with_name.first; std::string *serialized = out_var->mutable_serialized(); *serialized = oss.str(); - return Status::OK; } diff --git a/paddle/operators/detail/send_impl.cc b/paddle/operators/detail/send_impl.cc index 2313255dcba3f0209ce0559d804b8c594086d5a5..bf22d3df818358de5362c82c60e955b976238483 100644 --- a/paddle/operators/detail/send_impl.cc +++ b/paddle/operators/detail/send_impl.cc @@ -19,10 +19,10 @@ namespace operators { namespace detail { bool RPCClient::SendVariable(const framework::Scope& scope, - const std::string& inname, - const std::string& outname) { + const std::string& inname) { ClientContext context; - VariableMessage msg, out_msg; + VariableMessage msg; + VoidMessage out_msg; // FIXME(typhoonzero): pass device context to here. auto ctx = platform::CPUDeviceContext(); auto* var = scope.FindVar(inname); @@ -40,7 +40,22 @@ bool RPCClient::SendVariable(const framework::Scope& scope, LOG(ERROR) << "gRPC error: " << status.error_message(); return false; } - std::istringstream iss(out_msg.serialized()); + return true; +} + +bool RPCClient::GetVariable(const framework::Scope& scope) { + ClientContext context; + VariableMessage msg; + VoidMessage void_msg; + auto ctx = platform::CPUDeviceContext(); + Status status = stub_->GetVariable(&context, void_msg, &msg); + if (!status.ok()) { + LOG(ERROR) << "gRPC error: " << status.error_message(); + return false; + } + + std::istringstream iss(msg.serialized()); + auto outname = msg.varname(); framework::LoDTensor ret_tensor; framework::DeserializeFromStream(iss, &ret_tensor); auto* outvar = scope.FindVar(outname); diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index 9b4058fd6172b1b9d6a66d4a13a4cf9d23663112..d00c33fe42af1c63435db8c730a1d7b789420d12 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -20,7 +20,9 @@ service SendRecvService { // For parameter server round-robin like hashing, do not split tensors. // Send and recv only one tensor // TODO(typhoonzero): add streaming API - rpc SendVariable(VariableMessage) returns (VariableMessage) {} + rpc SendVariable(VariableMessage) returns (VoidMessage) {} + // Argument VariableMessage for GetVariable should only contain varname. + rpc GetVariable(VoidMessage) returns (VariableMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/operators/detail/send_recv_impl.h b/paddle/operators/detail/send_recv_impl.h index b6b9919c609bcf564ee5df49932e28a028a178c6..df01345e342789d5816944f2e3637ea64f0c6960 100644 --- a/paddle/operators/detail/send_recv_impl.h +++ b/paddle/operators/detail/send_recv_impl.h @@ -55,7 +55,9 @@ class SendRecvServerImpl final : public SendRecvService::Service { explicit SendRecvServerImpl() {} Status SendVariable(ServerContext *context, const VariableMessage *in_var, - VariableMessage *out_var) override; + VoidMessage *out_var) override; + Status GetVariable(ServerContext *context, const VoidMessage *in_var, + VariableMessage *out_var) override; const TensorWithName Get() { return this->var_recv_queue_.Pop(); } @@ -75,8 +77,8 @@ class RPCClient { RPCClient(std::shared_ptr channel) : stub_(SendRecvService::NewStub(channel)) {} - bool SendVariable(const framework::Scope &scope, const std::string &inname, - const std::string &outname); + bool SendVariable(const framework::Scope &scope, const std::string &inname); + bool GetVariable(const framework::Scope &scope); private: std::unique_ptr stub_; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 94cb39391f9d40d500ba01985bb58dd96572cf24..754338ec6bd880852974c9765b7c1f72bc2440f5 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -66,37 +66,25 @@ class RecvOp : public framework::OperatorBase { const platform::DeviceContext &dev_ctx) const override { // FIXME(typhoonzero): no new scopes for every run. framework::Scope &recv_scope = scope.NewScope(); - // blocking get one var from client. - const detail::TensorWithName &v = rpc_service_->Get(); - auto grad_var_name = v.first; - auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); - auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); - std::string param_var_name; - if (it != grad_list.end()) { - param_var_name = param_list[it - grad_list.begin()]; + size_t param_count = param_list.size(); + for (size_t i = 0; i < param_count; ++i) { + // blocking get one var from client. + const detail::TensorWithName &v = rpc_service_->Get(); + auto grad_var_name = v.first; + auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); + std::string param_var_name; + if (it != grad_list.end()) { + param_var_name = param_list[it - grad_list.begin()]; + } + VLOG(10) << "recved grad: " << grad_var_name + << " updating param: " << param_var_name; + auto *var = recv_scope.Var(grad_var_name); + auto *tensor = var->GetMutable(); + // FIXME(typhoonzero): do not copy + framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor); } - // find input by "grad_var_name" - // auto inputs = Inputs("RX"); - - // FIXME(typhoonzero): Find the parameter name from input grad name - // rename X -> Param - // rename RX -> Grad - - LOG(ERROR) << "recved grad: " << grad_var_name - << " param: " << param_var_name; - auto *var = recv_scope.Var(grad_var_name); - auto *tensor = var->GetMutable(); - - // Param is in parent scope, put it in current scope. - auto *param_var = recv_scope.FindVar(param_var_name); - auto param_scope = recv_scope.FindScope(param_var); - param_scope->Rename(param_var_name, "Param"); - recv_scope.Rename(grad_var_name, "Grad"); - - // FIXME(typhoonzero): do not copy - framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor); std::string program_str = Attr("OptimizeProgram"); framework::ProgramDesc program_desc; @@ -104,17 +92,20 @@ class RecvOp : public framework::OperatorBase { framework::ProgramDescBind program(program_desc); framework::Executor executor(dev_ctx); // Run sub graph to get optimized tensor - executor.Run(program, &recv_scope, 0, /*global_block*/ - false /*create_local_scope*/); - - auto *out_var = recv_scope.FindVar("ParamOut"); - detail::TensorWithName out; - out.first = param_var_name; - out.second = out_var->Get(); - rpc_service_->Push(out); - // rename back the params - param_scope.Rename("Param", param_var_name); - recv_scope.Rename("Grad", grad_var_name); + try { + executor.Run(program, &recv_scope, 0, /*global_block*/ + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + + for (size_t i = 0; i < param_count; ++i) { + auto *out_var = recv_scope.FindVar(param_list[i]); + detail::TensorWithName out; + out.first = param_list[i]; + out.second = out_var->Get(); + rpc_service_->Push(out); + } } protected: diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 648905743c8b075d250b727dfaf5541bbc75869e..ab1ae5b31dd6c51432de8bd9b80d295c74e88b45 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -48,11 +48,18 @@ class SendOp : public framework::OperatorBase { // should block until server responds. for (auto in : ins) { LOG(ERROR) << "sending grad: " << in; - bool ret = client_->SendVariable(scope, in, in); + bool ret = client_->SendVariable(scope, in); if (!ret) { LOG(ERROR) << "send variable error"; } } + for (auto in : ins) { + LOG(ERROR) << "updating from server..."; + bool ret = client_->GetVariable(scope); + if (!ret) { + LOG(ERROR) << "GetVariable error"; + } + } } protected: diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index ba699442ce60f6636593ad1262d903e89def5a7f..c8c9a4ef366869e369f54961849cf306b1d0c264 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -138,7 +138,6 @@ class Executor(object): inputs=opt_op.inputs, outputs=opt_op.outputs, attrs=opt_op.attrs) - print("optimize program: ", optimize_sub_program) pserver_program.global_block().append_op( type="recv", @@ -248,7 +247,7 @@ class Executor(object): outputs={'Out': [fetch_var]}, attrs={'col': i}) - self.executor.run(program.desc, scope, 0, True) + self.executor.run(program.desc, scope, 0, True, True) outs = [ core.get_fetch_variable(scope, fetch_var_name, i) for i in xrange(len(fetch_list)) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py index 208002c8d6cbbc073eaf225a08ed61e3e672afb5..5178131ea771dd2fcfadae5a2366d32c2e2141de 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv_dist.py @@ -44,6 +44,7 @@ exe.optimize(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1) pserver_endpoint = os.getenv("PSERVER") if pserver_endpoint: pserver_prog = exe.get_pserver_program(pserver_endpoint, optimize_ops) + print("pserver startup: ", fluid.default_startup_program()) exe.run(fluid.default_startup_program()) while True: exe.run(pserver_prog)