diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 593c35879ae2b3680b93ac5d8443110e61cb99fe..f8af4a290b470f31980245dfd54e600d26da576d 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -12,179 +12,67 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include -#include -#include - -#include "paddle/framework/executor.h" +#include "paddle/framework/data_type.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/proto_desc.h" -#include "paddle/operators/detail/grpc_server.h" -#include "paddle/operators/detail/sendrecvop_utils.h" -#include "paddle/operators/detail/simple_block_queue.h" -#include "paddle/string/printf.h" -#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" +#include +#include "paddle/operators/detail/grpc_client.h" namespace paddle { namespace operators { -constexpr char kOptimizeBlock[] = "OptimizeBlock"; - -void RunServer(std::shared_ptr service) { - service->RunSyncUpdate(); - VLOG(4) << "RunServer thread end"; -} - -static void CreateTensorFromMessageType(framework::Variable *var, - sendrecv::VarType var_type) { - if (var_type == sendrecv::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == sendrecv::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else { - PADDLE_THROW( - "VariableMessage type %d is not in " - "[LoDTensor, SelectedRows]", - var_type); - } -} - -class RecvOp : public framework::OperatorBase { +class SendOp : public framework::OperatorBase { public: - RecvOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorBase(type, inputs, outputs, attrs) { - if (!rpc_service_) { - std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); - server_thread_.reset(new std::thread(RunServer, rpc_service_)); + SendOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope& scope, + const platform::Place& place) const override { + auto outs = Outputs("Out"); + std::vector epmap = Attr>("epmap"); + + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& ctx = *pool.Get(place); + + for (size_t i = 0; i < outs.size(); i++) { + VLOG(3) << "getting " << outs[i]; + client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } - } - - void Stop() override { - detail::MessageWithName term_msg; - term_msg.first = LISTEN_TERMINATE_MESSAGE; - rpc_service_->Push(term_msg); - rpc_service_->ShutDown(); - server_thread_->join(); - } - - std::string GetGradVarNameForTrainer(const std::string &varname) const { - if (grads_counter_.find(varname) == grads_counter_.end()) { - grads_counter_[varname] = 0; - } - return string::Sprintf("%s.trainer_%d", varname, grads_counter_[varname]++); - } - void Run(const framework::Scope &scope, - const platform::Place &dev_place) const override { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(dev_place); - framework::Scope &recv_scope = scope.NewScope(); - - // FIXME(Yancey1989): initialize rpc server with laze mode. - rpc_service_->SetScope(&recv_scope); - rpc_service_->SetDevCtx(&dev_ctx); - auto param_list = Attr>("ParamList"); - auto grad_list = Attr>("GradList"); - auto fan_in = Attr("Fanin"); - size_t param_count = param_list.size(); - - auto *block = Attr(kOptimizeBlock); - auto *program = block->Program(); - framework::Executor executor(dev_place); - - // TODO(typhoonzero): change this to a while_op for every cluster-batch. - bool exit_flag = false; - size_t barrier_size = param_count * fan_in; - while (!exit_flag) { - // Get from multiple trainers, we don't care about the order in which - // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(0); - for (size_t i = 0; i < barrier_size; ++i) { - const detail::MessageWithName &v = rpc_service_->Get(); - auto grad_var_name = v.first; - if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { - LOG(INFO) << "received terminate message and exit"; - exit_flag = true; - break; - } - auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); - std::string param_var_name; - if (it != grad_list.end()) { - param_var_name = param_list[it - grad_list.begin()]; - } else { - LOG(ERROR) << "grad has no paired param:" << grad_var_name; - } - VLOG(3) << "received grad: " << grad_var_name - << " updating param: " << param_var_name; - if (fan_in > 1) { - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); - } - auto *var = recv_scope.FindVar(grad_var_name); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << grad_var_name; - PADDLE_THROW("Can not find server side var"); - } - detail::DeserializeFromMessage(v.second, dev_ctx, var); - } - if (exit_flag) { - break; - } - - try { - executor.Run(*program, &recv_scope, block->ID(), /*global_block*/ - false /*create_local_scope*/, false /*create_vars*/); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - rpc_service_->SetCond(1); - rpc_service_->WaitClientGet(barrier_size); - grads_counter_.clear(); - } // while(true) + PADDLE_ENFORCE(client_.Wait()); } - protected: - std::shared_ptr rpc_service_; - std::shared_ptr server_thread_; - mutable std::unordered_map grads_counter_; + private: + mutable detail::RPCClient client_; }; -class RecvOpMaker : public framework::OpProtoAndCheckerMaker { +class SendOpMaker : public framework::OpProtoAndCheckerMaker { public: - RecvOpMaker(OpProto *proto, OpAttrChecker *op_checker) + SendOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("RX", "(Tensor) Input tensor to be optimized").AsDuplicable(); + AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); + AddOutput("Out", "(Tensor) Output tensor to be received from server") + .AsDuplicable(); AddComment(R"DOC( -Recv operator +Send operator -This operator will recieve tensor from send_op +This operator will send tensor to recv_op at the parameter server. )DOC"); - AddAttr("endpoint", - "(string, default 127.0.0.1:6164)" - "IP address to listen on.") - .SetDefault("127.0.0.1:6164") - .AddCustomChecker([](const std::string &ip) { return !ip.empty(); }); - AddAttr( - kOptimizeBlock, "Serialized ProgramDesc string for recv to run."); - AddAttr>( - "ParamList", "type list of string", - "grad->param name mapping to find which parameters to optimize.") + AddAttr>("endpoints", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints to send variables to.") .SetDefault({}); - AddAttr>( - "GradList", "type list of string", - "grad->param name mapping to find which parameters to optimize.") + AddAttr>("epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input " + "variables for mapping") .SetDefault({}); - AddAttr("Fanin", "type int", - "Number of trainers in the current cluster job") - .SetDefault(1); } }; @@ -193,4 +81,4 @@ This operator will recieve tensor from send_op namespace ops = paddle::operators; -REGISTER_OPERATOR(recv, ops::RecvOp, ops::RecvOpMaker); +REGISTER_OPERATOR(send, ops::SendOp, ops::SendOpMaker); diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 5aa66c20eaf77959089100f8dcee55f2bc83a71a..c90e4d8ef06581e2fb7f99dcbe70df32c1ad80c0 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -37,6 +37,7 @@ class SendOp : public framework::OperatorBase { auto ins = Inputs("X"); auto outs = Outputs("Out"); std::vector epmap = Attr>("epmap"); + bool do_get = Attr("DoGet"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -46,9 +47,11 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(client_.Wait()); - for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i]; - client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + if (do_get) { + for (size_t i = 0; i < outs.size(); i++) { + VLOG(3) << "getting " << outs[i]; + client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); + } } PADDLE_ENFORCE(client_.Wait()); @@ -79,6 +82,10 @@ This operator will send tensor to recv_op at the parameter server. "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("DoGet", + "(bool, default true)" + "Whether do GetVariable call after send") + .SetDefault(true); } }; diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc index 045a0f5434f339bab345d14881ed05450ce6588d..874eac9711377d5d8d3eb022511fcc6535a5098f 100644 --- a/paddle/operators/send_recv_op_test.cc +++ b/paddle/operators/send_recv_op_test.cc @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/string/printf.h" USE_NO_KERNEL_OP(send); -USE_NO_KERNEL_OP(recv); +USE_NO_KERNEL_OP(listen_and_serv); USE_OP(sum); namespace f = paddle::framework; @@ -33,7 +33,7 @@ namespace p = paddle::platform; namespace m = paddle::operators::math; // global for simplicity. -std::unique_ptr recv_op; +std::unique_ptr listen_and_serv_op; void InitTensorsInScope(f::Scope &scope, p::CPUPlace &place) { p::CPUDeviceContext ctx(place); @@ -120,7 +120,7 @@ void StartServerNet(bool is_sparse) { InitTensorsInScope(scope, place); } - // sub program run in recv_op, for simple test we use sum + // sub program run in listen_and_serv_op, for simple test we use sum f::ProgramDesc program; f::BlockDesc *block = program.MutableBlock(0); // X for server side tensors, RX for received tensers, must be of same shape. @@ -131,8 +131,9 @@ void StartServerNet(bool is_sparse) { attrs.insert({"ParamList", std::vector({"Out"})}); attrs.insert({"GradList", std::vector({"x1"})}); attrs.insert({"OptimizeBlock", block}); - recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs); - recv_op->Run(scope, place); + listen_and_serv_op = + f::OpRegistry::CreateOp("listen_and_serv", {{"RX", {"x1"}}}, {}, attrs); + listen_and_serv_op->Run(scope, place); } TEST(SendRecvOp, CPUDense) { @@ -161,9 +162,9 @@ TEST(SendRecvOp, CPUDense) { for (int64_t i = 0; i < target->numel(); ++i) { EXPECT_EQ(expected[i] * 2, actual[i]); } - recv_op->Stop(); + listen_and_serv_op->Stop(); server_thread.join(); - recv_op.reset(nullptr); + listen_and_serv_op.reset(nullptr); } TEST(SendRecvOp, CPUSparse) { @@ -200,7 +201,7 @@ TEST(SendRecvOp, CPUSparse) { EXPECT_EQ(expect_value->mutable_data(place)[i], actual->mutable_data(place)[i]); } - recv_op->Stop(); + listen_and_serv_op->Stop(); server_thread.join(); - recv_op.reset(); + listen_and_serv_op.reset(); }