From ccefde203adb1a5af99fb3ce30ba553f0aec1680 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Tue, 30 Jan 2018 11:30:31 +0800 Subject: [PATCH] follow comments --- paddle/operators/listen_and_serv_op.cc | 207 +++++++++++++++++++ paddle/operators/recv_op.cc | 9 +- paddle/operators/send_op.cc | 12 +- python/paddle/v2/fluid/tests/test_recv_op.py | 2 +- 4 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 paddle/operators/listen_and_serv_op.cc diff --git a/paddle/operators/listen_and_serv_op.cc b/paddle/operators/listen_and_serv_op.cc new file mode 100644 index 0000000000..5745938ed9 --- /dev/null +++ b/paddle/operators/listen_and_serv_op.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include + +#include "paddle/framework/executor.h" +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/proto_desc.h" +#include "paddle/operators/detail/grpc_server.h" +#include "paddle/operators/detail/sendrecvop_utils.h" +#include "paddle/operators/detail/simple_block_queue.h" +#include "paddle/string/printf.h" + +namespace paddle { +namespace operators { + +constexpr char kOptimizeBlock[] = "OptimizeBlock"; + +void RunServer(std::shared_ptr service) { + service->RunSyncUpdate(); + VLOG(4) << "RunServer thread end"; +} + +static void CreateTensorFromMessageType(framework::Variable *var, + sendrecv::VarType var_type) { + if (var_type == sendrecv::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else { + PADDLE_THROW( + "VariableMessage type %d is not in " + "[LoDTensor, SelectedRows]", + var_type); + } +} + +class ListenAndServOp : public framework::OperatorBase { + public: + ListenAndServOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) { + if (!rpc_service_) { + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + server_thread_.reset(new std::thread(RunServer, rpc_service_)); + } + } + + void Stop() override { + detail::MessageWithName term_msg; + term_msg.first = LISTEN_TERMINATE_MESSAGE; + rpc_service_->Push(term_msg); + rpc_service_->ShutDown(); + server_thread_->join(); + } + + std::string GetGradVarNameForTrainer(const std::string &varname) const { + if (grads_counter_.find(varname) == grads_counter_.end()) { + grads_counter_[varname] = 0; + } + return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); + } + + void Run(const framework::Scope &scope, + const platform::Place &dev_place) const override { + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(dev_place); + framework::Scope &recv_scope = scope.NewScope(); + + // FIXME(Yancey1989): initialize rpc server with lazy mode. + rpc_service_->SetScope(&recv_scope); + rpc_service_->SetDevCtx(&dev_ctx); + auto param_list = Attr>("ParamList"); + auto grad_list = Attr>("GradList"); + auto fan_in = Attr("Fanin"); + + auto *block = Attr(kOptimizeBlock); + auto *program = block->Program(); + framework::Executor executor(dev_place); + + // TODO(typhoonzero): change this to a while_op for every cluster-batch. + bool exit_flag = false; + while (!exit_flag) { + // Get from multiple trainers, we don't care about the order in which + // the gradients arrives, just add suffix 0~n and merge the gradient. + rpc_service_->SetCond(0); + size_t recv_var_cnt = 0; + int batch_barrier = 0; + while (batch_barrier != fan_in) { + const detail::MessageWithName &v = rpc_service_->Get(); + auto grad_var_name = v.first; + if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { + LOG(INFO) << "received terminate message and exit"; + exit_flag = true; + break; + } else if (grad_var_name == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "recv batch barrier message"; + batch_barrier++; + continue; + } else { + // receive a variable + recv_var_cnt++; + auto it = + std::find(grad_list.begin(), grad_list.end(), grad_var_name); + std::string param_var_name; + if (it != grad_list.end()) { + param_var_name = param_list[it - grad_list.begin()]; + } else { + LOG(ERROR) << "grad has no paired param:" << grad_var_name; + } + VLOG(3) << "received grad: " << grad_var_name + << " updating param: " << param_var_name; + + if (fan_in > 1) { + grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); + } + auto *var = recv_scope.FindVar(grad_var_name); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << grad_var_name; + PADDLE_THROW("Can not find server side var"); + } + detail::DeserializeFromMessage(v.second, dev_ctx, var); + } + } + VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; + // TODO(Yancey1989): merge SelectedRows variables here + if (exit_flag) { + rpc_service_->ShutDown(); + } + + try { + executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ + false /*create_local_scope*/, false /*create_vars*/); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + rpc_service_->SetCond(1); + rpc_service_->WaitClientGet(recv_var_cnt); + grads_counter_.clear(); + } // while(true) + } + + protected: + std::shared_ptr rpc_service_; + std::shared_ptr server_thread_; + mutable std::unordered_map grads_counter_; +}; + +class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ListenAndServOpMaker(OpProto *proto, OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddComment(R"DOC( +ListenAndServ operator + +This operator will start a RPC server which can receive variables +from send_op and send back variables to recv_op. +)DOC"); + AddAttr("endpoint", + "(string, default 127.0.0.1:6164)" + "IP address to listen on.") + .SetDefault("127.0.0.1:6164") + .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); + AddAttr(kOptimizeBlock, + "BlockID to run on server side."); + AddAttr>( + "ParamList", "type list of string", + "grad->param name mapping to find which parameters to optimize.") + .SetDefault({}); + AddAttr>( + "GradList", "type list of string", + "grad->param name mapping to find which parameters to optimize.") + .SetDefault({}); + AddAttr("Fanin", "type int", + "Number of trainers in the current cluster job") + .SetDefault(1); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp, + ops::ListenAndServOpMaker); \ No newline at end of file diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 1e64d5c65c..ba71094219 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -55,19 +55,12 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker { public: RecvOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); - AddOutput("Out", "(Tensor) Output tensor to be received from server") - .AsDuplicable(); + AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); AddComment(R"DOC( Recv operator This operator can get variables from server side. )DOC"); - AddAttr>("endpoints", - "(string vector, default 127.0.0.1:6164)" - "Server endpoints to recv variables" - "from.") - .SetDefault({}); AddAttr>("epmap", "(string vector, default 127.0.0.1:6164)" "Server endpoints in the order of input " diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 9c180a7244..51e6b9de58 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -37,7 +37,6 @@ class SendOp : public framework::OperatorBase { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); - bool do_get = Attr("DoGet"); std::vector endpoints = Attr>("endpoints"); @@ -55,7 +54,7 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(client_.Wait()); - if (do_get) { + if (outs.size() > 0) { for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); @@ -65,7 +64,8 @@ class SendOp : public framework::OperatorBase { } private: - // TODO(typhoonzero): put RPCClient in a Variable. + // TODO(typhoonzero): put RPCClient in a Variable, so that + // send and recv can use the same connection. mutable detail::RPCClient client_; }; @@ -81,6 +81,8 @@ Send operator This operator will send tensor to recv_op at the parameter server. )DOC"); + // TODO(typhoonzero): remove this attr generate de-duplicated vector from + // epmap when initializing. AddAttr>("endpoints", "(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") @@ -90,10 +92,6 @@ This operator will send tensor to recv_op at the parameter server. "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); - AddAttr("DoGet", - "(bool, default true)" - "Whether do GetVariable call after send") - .SetDefault(true); } }; diff --git a/python/paddle/v2/fluid/tests/test_recv_op.py b/python/paddle/v2/fluid/tests/test_recv_op.py index def5ca9442..3a02b88241 100644 --- a/python/paddle/v2/fluid/tests/test_recv_op.py +++ b/python/paddle/v2/fluid/tests/test_recv_op.py @@ -29,7 +29,7 @@ class TestRecvOp(unittest.TestCase): p = Process(target=self.init_serv, args=(place, )) p.daemon = True p.start() - time.sleep(5) + time.sleep(1) self.init_client(place) # FIXME(typhoonzero): find a way to gracefully shutdown the server. os.system("kill -9 %d" % p.pid) -- GitLab