diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index d21e0f7b96c80edf171c476ab8c21a45240a30e7..d8e711994c5dba15ce0a1c237558b121888902e3 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
          checker(op.InputArgumentNames(), recv_vars);
 }
 
-bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
-  for (auto &name : op.OutputNames()) {
-    if (name == "RPCClient") {
-      return true;
-    }
-  }
-  return false;
-}
-
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   std::unordered_map<std::string, proto::VarType::Type> var_types;
@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   bool is_forwarding = true;
 
   for (auto *op : program.Block(0).AllOps()) {
-    if (IsRPCOp(*op)) {
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kRPC)) {
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       CreateRPCOp(&result, *op);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index be17c2a92ef2f365933235717e437d2c900af8f8..e07597dbd80889c366babe79455beb12c9eb80d9 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::vector<std::string> FindDistTrainRecvVars(
       const ProgramDesc &program) const;
 
-  bool IsRPCOp(const OpDesc &op) const;
-
   void ConnectOp(SSAGraph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;
diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc
index 5a4380a83a2e5bf492098032cd9de7bf274fe47e..ae9f4efd44acdcdff2806deea6826e4089459a78 100644
--- a/paddle/fluid/framework/op_proto_maker.cc
+++ b/paddle/fluid/framework/op_proto_maker.cc
@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
       .InEnum(
           {static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kBackward),
-           static_cast<int>(OpRole::kOptimize),
+           static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
            static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
            static_cast<int>(OpRole::kLoss) |
                static_cast<int>(OpRole::kBackward),
diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h
index 9bd6ca6ea32734707a5c37b3ecfe449436c04c8c..8493b9d8b326c71a33b95bf95e5fc1743c686eb7 100644
--- a/paddle/fluid/framework/op_proto_maker.h
+++ b/paddle/fluid/framework/op_proto_maker.h
@@ -24,6 +24,7 @@ enum class OpRole {
   kForward = 0x0000,
   kBackward = 0x0001,
   kOptimize = 0x0002,
+  kRPC = 0x0003,
 
   kLoss = 0x0100,
   // The default value of op's role. This should be only used for unittests and
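
The new OpRole::kRPC value slots into the existing role encoding: plain roles
occupy the low bits, while kLoss is a flag OR-ed onto them. That is why the
graph builder above compares the role attribute for equality with kRPC, while
combined roles such as loss-plus-backward are tested bitwise. A standalone
sketch (illustration only, not Paddle code) of that arithmetic:

    #include <iostream>

    // Same values as the enum added in op_proto_maker.h above.
    enum class OpRole {
      kForward = 0x0000,
      kBackward = 0x0001,
      kOptimize = 0x0002,
      kRPC = 0x0003,
      kLoss = 0x0100,
    };

    int main() {
      // A backward op that also computes the loss stores the OR of two roles.
      int role =
          static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kBackward);
      // RPC ops are matched by exact equality, as in the graph builder hunk.
      std::cout << (role == static_cast<int>(OpRole::kRPC)) << "\n";  // 0
      // Flag-style roles are tested with a bitwise AND instead.
      std::cout << ((role & static_cast<int>(OpRole::kLoss)) != 0) << "\n";  // 1
      return 0;
    }
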
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
index 51d38d6251d853fa8a02a4e22f819cfc44294453..9d7cceeb65888b8ba3fdf39e88fc2877abd82d11 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;
@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
   dfg.Build();
   GraphTraits<DataFlowGraph> trait(&dfg);
   auto nodes = trait.nodes_in_DFS();
-  int count = 0;
+  size_t count = 0;
   for (auto it = nodes.begin(); it != nodes.end(); ++it) {
     LOG(INFO) << "visiting " << it->name();
     ++count;
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 4c9c7be40c143c748c12dc08c22a09ea590366a2..f7ce7786874285795878b655365974f082c00b44 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -25,6 +25,21 @@ namespace paddle {
 namespace operators {
 namespace detail {
 
+std::once_flag RPCClient::init_flag_;
+
+std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
+
+RPCClient* RPCClient::GetInstance() {
+  std::call_once(init_flag_, &RPCClient::Init);
+  return rpc_client_.get();
+}
+
+void RPCClient::Init() {
+  if (rpc_client_.get() == nullptr) {
+    rpc_client_.reset(new RPCClient());
+  }
+}
+
 bool RPCClient::AsyncSendVariable(const std::string& ep,
                                   const platform::DeviceContext& ctx,
                                   const framework::Scope& scope,
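
GetInstance/Init above is the standard std::call_once lazy singleton: the first
caller constructs the client, every later caller (including ops running on
other threads) receives the same pointer, and the null check inside Init is
only a belt-and-braces guard, since call_once already serializes it. A
self-contained sketch of the pattern (Client is a stand-in name, not Paddle's
class):

    #include <iostream>
    #include <memory>
    #include <mutex>
    #include <thread>
    #include <vector>

    class Client {
     public:
      static Client* GetInstance() {
        // Runs the initializer exactly once, even under concurrent callers.
        std::call_once(init_flag_, [] { instance_.reset(new Client()); });
        return instance_.get();
      }

     private:
      static std::once_flag init_flag_;
      static std::unique_ptr<Client> instance_;
    };

    std::once_flag Client::init_flag_;
    std::unique_ptr<Client> Client::instance_(nullptr);

    int main() {
      std::vector<std::thread> threads;
      for (int i = 0; i < 4; ++i) {
        // All four threads print the same address.
        threads.emplace_back([] { std::cout << Client::GetInstance() << "\n"; });
      }
      for (auto& t : threads) t.join();
      return 0;
    }
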
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index e5007b509a30a5251e78e8636d53d81022dae0d3..449d5105afb8c02294a0ef57610e7de1b1631b35 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -36,6 +36,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN
 
 namespace paddle {
 namespace operators {
@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
 
 class RPCClient {
  public:
+  RPCClient() {}
+
+  static RPCClient* GetInstance();
+
   bool AsyncSendVariable(const std::string& ep,
                          const platform::DeviceContext& ctx,
                          const framework::Scope& scope,
@@ -192,12 +197,17 @@ class RPCClient {
  private:
   bool Proceed();
   std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
+  // Init is called by GetInstance.
+  static void Init();
 
  private:
   grpc::CompletionQueue cq_;
   std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
   std::atomic<int64_t> req_count_{0};
   std::mutex mutex_;
+  static std::unique_ptr<RPCClient> rpc_client_;
+  static std::once_flag init_flag_;
+  DISABLE_COPY_AND_ASSIGN(RPCClient);
 };
 
 }  // namespace detail
diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index 73e75c9087fef756840c76db249f8996253ced64..264e3c6671f31f055e3520f82919ef74102cd0d7 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
   std::string in_var_name("ids");
   std::string out_var_name("out");
 
-  detail::RPCClient client;
-  client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
-                               out_var_name);
-  client.Wait();
+  detail::RPCClient::GetInstance();
+
+  //  detail::RPCClient::GetInstance();
+  //  client->Wait();
+  //  client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
+  //                                out_var_name);
+  //  client->Wait();
 
   auto var = scope.Var(out_var_name);
   auto value = var->GetMutable<framework::SelectedRows>()->value();
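
The DISABLE_COPY_AND_ASSIGN line in grpc_client.h is what keeps the singleton
unique; without it, an accidental copy of *GetInstance() would fork the
completion queue and channel map. The macro comes from
paddle/fluid/platform/macros.h; a typical definition (check that header for
the exact form in this revision) deletes the copy and move operations:

    // Sketch of the usual expansion; the real macro may differ in detail.
    #define DISABLE_COPY_AND_ASSIGN(classname)         \
      classname(const classname&) = delete;            \
      classname(classname&&) = delete;                 \
      classname& operator=(const classname&) = delete; \
      classname& operator=(classname&&) = delete
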
diff --git a/paddle/fluid/operators/fetch_barrier_op.cc b/paddle/fluid/operators/fetch_barrier_op.cc
index 5d2e558699116c270dc603b0137dee7f16bee1f8..79ec02f52094121d01c6bda2a5d99d2211893e89 100644
--- a/paddle/fluid/operators/fetch_barrier_op.cc
+++ b/paddle/fluid/operators/fetch_barrier_op.cc
@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);
 
-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();
 
     PADDLE_ENFORCE(rpc_client->Wait());
 
@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
 class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 SendBarrier operator
 
@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
   }
 };
 
-class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class FetchBarrierOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}
@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
                   paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
-                  ops::FetchBarrierOpVarTypeInference,
                   ops::FetchBarrierOpShapeInference);
diff --git a/paddle/fluid/operators/prefetch_op.cc b/paddle/fluid/operators/prefetch_op.cc
index 4cfea958e8e50156c90af8806414b043e15f8a9c..e0a9b24ac8978418a1a4ece62286e022bec8b834 100644
--- a/paddle/fluid/operators/prefetch_op.cc
+++ b/paddle/fluid/operators/prefetch_op.cc
@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
 
-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();
 
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which will be"
-              "initialized at most once.");
     AddOutput("Out",
               "(LoDTensor) result "
               "to be fetched from parameter server")
@@ -87,17 +79,6 @@ the parameter server and fetch result back.
   }
 };
 
-class PrefetchOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class PrefetchOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}
@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
                   paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
-                  ops::PrefetchOpVarTypeInference,
                   ops::PrefetchOpShapeInference);
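
Every RPC operator in this patch receives the same mechanical change, and the
operator diffs above and below all reduce to one pattern: the dummy RPCClient
output, the scope lookup that fished the client out of a RAW variable, and the
VarTypeInference class that existed only to declare that variable's type all
collapse into a single call.

    // Before (sketch of the removed lines):
    //   auto* client_var = scope.FindVar(Output("RPCClient"));
    //   detail::RPCClient* rpc_client =
    //       client_var->GetMutable<detail::RPCClient>();
    // After: one process-wide client; no output slot, no VarTypeInference.
    auto rpc_client = detail::RPCClient::GetInstance();
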
diff --git a/paddle/fluid/operators/recv_op.cc b/paddle/fluid/operators/recv_op.cc
index d416ba1e1fda4e7803d0b5a00cb2f7b26ce215b8..d8ddb7b448910b5e0e6e71742eb2fdc6a225c919 100644
--- a/paddle/fluid/operators/recv_op.cc
+++ b/paddle/fluid/operators/recv_op.cc
@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
-    auto client_var_name = Output("RPCClient");
     int sync_mode = Attr<int>("sync_mode");
 
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase {
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);
 
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();
 
     for (size_t i = 0; i < outs.size(); i++) {
       VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
     AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 Recv operator
 
diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc
index 354eb4fa13913eb6ec01885cf411627bf8cfa61c..2c77ee2e2792d6fdd76bacd68b6c3b4a296b2e3a 100644
--- a/paddle/fluid/operators/send_barrier_op.cc
+++ b/paddle/fluid/operators/send_barrier_op.cc
@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);
-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+
+    auto rpc_client = detail::RPCClient::GetInstance();
 
     // need to wait before sending send_barrier message
     PADDLE_ENFORCE(rpc_client->Wait());
@@ -65,9 +61,6 @@ class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 SendBarrier operator
 
@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent.
   }
 };
 
-class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class SendBarrierOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}
@@ -106,5 +88,4 @@ namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
                   paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
-                  ops::SendBarrierOpVarTypeInference,
                   ops::SendBarrierOpShapeInference);
diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc
index 95bb1f3c695297e6d8134a647925310207118a9b..a5150f242ca3b0befafa2443f0bc466e2aea85e4 100644
--- a/paddle/fluid/operators/send_op.cc
+++ b/paddle/fluid/operators/send_op.cc
@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
     // For profiling
     platform::RecordEvent record_event(Type(), &ctx);
 
-    auto client_var_name = Output("RPCClient");
-    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
-                            "Can not find variable '%s' in the scope.",
-                            client_var_name);
-    auto* client_var = scope.FindVar(client_var_name);
-    detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
+    auto rpc_client = detail::RPCClient::GetInstance();
 
     for (size_t i = 0; i < ins.size(); i++) {
       if (NeedSend(scope, ins[i])) {
@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
     AddOutput("Out", "(Tensor) Output tensor to be received from server")
         .AsDuplicable();
-    AddOutput("RPCClient",
-              "(RPCClient) The RPC client object which is"
-              "initialized at most once.");
     AddComment(R"DOC(
 Send operator
 
@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
   }
 };
 
-class SendOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(const framework::OpDesc& op_desc,
-                  framework::BlockDesc* block) const override {
-    auto out_var_name = op_desc.Output("RPCClient").front();
-    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
-    auto var_type = framework::proto::VarType::RAW;
-    out_var.SetType(var_type);
-  }
-};
-
 class SendOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext* ctx) const override {}
@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
-                  ops::SendOpMaker, ops::SendOpVarTypeInference,
-                  ops::SendOpShapeInference);
+                  ops::SendOpMaker, ops::SendOpShapeInference);
diff --git a/paddle/fluid/operators/send_recv_op_test.cc b/paddle/fluid/operators/send_recv_op_test.cc
index d5303eaf50722234d205264e56892b1723104d53..2b3dc81676f8f2518b02b892d2da841a58ea76e4 100644
--- a/paddle/fluid/operators/send_recv_op_test.cc
+++ b/paddle/fluid/operators/send_recv_op_test.cc
@@ -177,75 +177,75 @@ TEST(SendRecvOp, CPUDense) {
   attrs.insert({"epmap", std::vector<std::string>({endpoint})});
   auto send_op = f::OpRegistry::CreateOp(
       "send", {{"X", {"x1"}}},
-      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
-  send_op->Run(scope, place);
-
-  auto in_var = scope.Var("x1");
-  auto tensor = in_var->GetMutable<f::LoDTensor>();
-  float *expected = tensor->data<float>();
-  auto out_var = scope.Var("Out");
-  auto target = out_var->GetMutable<f::LoDTensor>();
-  // x1 * 2 == x0
-  EXPECT_NE(target->memory_size(), size_t(0));
-  float *actual = target->data<float>();
-  for (int64_t i = 0; i < target->numel(); ++i) {
-    EXPECT_EQ(expected[i] * 2, actual[i]);
-  }
-  listen_and_serv_op->Stop();
-  server_thread.join();
-  listen_and_serv_op.reset(nullptr);
-  paddle::operators::ListenAndServOp::ResetPort();
+      {{"Out", {"Out"}}}, attrs);
+  send_op->Run(scope, place);
+
+  auto in_var = scope.Var("x1");
+  auto tensor = in_var->GetMutable<f::LoDTensor>();
+  float *expected = tensor->data<float>();
+  auto out_var = scope.Var("Out");
+  auto target = out_var->GetMutable<f::LoDTensor>();
+  // x1 * 2 == x0
+  EXPECT_NE(target->memory_size(), size_t(0));
+  float *actual = target->data<float>();
+  for (int64_t i = 0; i < target->numel(); ++i) {
+    EXPECT_EQ(expected[i] * 2, actual[i]);
+  }
+  listen_and_serv_op->Stop();
+  server_thread.join();
+  listen_and_serv_op.reset(nullptr);
+  paddle::operators::ListenAndServOp::ResetPort();
 }
 
 TEST(SendRecvOp, CPUSparse) {
-  std::atomic<bool> initialized;
-  initialized = false;
-  std::thread server_thread(StartServerNet, true, &initialized);
-  while (!initialized) {
-  }
-  auto *listen_and_serv_op_ptr =
-      static_cast<paddle::operators::ListenAndServOp *>(
-          listen_and_serv_op.get());
-  ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
-  listen_and_serv_op_ptr->WaitServerReady();
-
-  // local net
-  f::Scope scope;
-  p::CPUPlace place;
-  p::CPUDeviceContext ctx(place);
-  InitSelectedRowsInScope(place, &scope);
-  scope.Var("RPC_CLIENT_VAR");
-  f::AttributeMap attrs;
-  selected_port = listen_and_serv_op_ptr->GetSelectedPort();
-  std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
-  attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
-  attrs.insert({"epmap", std::vector<std::string>({endpoint})});
-  auto send_op = f::OpRegistry::CreateOp(
-      "send", {{"X", {"x1"}}},
-      {{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
-  send_op->Run(scope, place);
-
-  auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
-  auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
-  auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
-  auto actual = out->mutable_value();
-
-  std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
-  auto expect_value = expect->mutable_value();
-  expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
-
-  m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
-  add_functor(ctx, *x0, *x1, expect.get());
-
-  EXPECT_EQ(actual->numel(), expect_value->numel());
-  EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
-
-  for (int64_t i = 0; i < expect_value->numel(); ++i) {
-    EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
-              actual->mutable_data<float>(place)[i]);
-  }
-  listen_and_serv_op->Stop();
-  server_thread.join();
-  listen_and_serv_op.reset();
-  paddle::operators::ListenAndServOp::ResetPort();
+  std::atomic<bool> initialized;
+  initialized = false;
+  std::thread server_thread(StartServerNet, true, &initialized);
+  while (!initialized) {
+  }
+  auto *listen_and_serv_op_ptr =
+      static_cast<paddle::operators::ListenAndServOp *>(
+          listen_and_serv_op.get());
+  ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
+  listen_and_serv_op_ptr->WaitServerReady();
+
+  // local net
+  f::Scope scope;
+  p::CPUPlace place;
+  p::CPUDeviceContext ctx(place);
+  InitSelectedRowsInScope(place, &scope);
+  scope.Var("RPC_CLIENT_VAR");
+  f::AttributeMap attrs;
+  selected_port = listen_and_serv_op_ptr->GetSelectedPort();
+  std::string endpoint =
+      paddle::string::Sprintf("127.0.0.1:%d", selected_port);
+  attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
+  attrs.insert({"epmap", std::vector<std::string>({endpoint})});
+  auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
+                                         {{"Out", {"Out"}}}, attrs);
+  send_op->Run(scope, place);
+
+  auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
+  auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
+  auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
+  auto actual = out->mutable_value();
+
+  std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
+  auto expect_value = expect->mutable_value();
+  expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
+
+  m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
+  add_functor(ctx, *x0, *x1, expect.get());
+
+  EXPECT_EQ(actual->numel(), expect_value->numel());
+  EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
+
+  for (int64_t i = 0; i < expect_value->numel(); ++i) {
+    EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
+              actual->mutable_data<float>(place)[i]);
+  }
+  listen_and_serv_op->Stop();
+  server_thread.join();
+  listen_and_serv_op.reset();
+  paddle::operators::ListenAndServOp::ResetPort();
 }
AddComment(R"DOC( Send operator @@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server } }; -class SendVarsOpVarTypeInference : public framework::VarTypeInference { - public: - void operator()(const framework::OpDesc& op_desc, - framework::BlockDesc* block) const override { - auto out_var_name = op_desc.Output("RPCClient").front(); - auto& out_var = block->FindRecursiveOrCreateVar(out_var_name); - auto var_type = framework::proto::VarType::RAW; - out_var.SetType(var_type); - } -}; - class SendVarsOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext* ctx) const override {} @@ -116,5 +97,4 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(send_vars, ops::SendVarsOp, paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker, - ops::SendVarsOpVarTypeInference, ops::SendVarsOpShapeInference); diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 9111abca5aac97e9d5c7b00ce5173f08e49cda12..76aa7d2010682416f68e982e9b89da9813abb078 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) { .value("Forward", framework::OpRole::kForward) .value("Backward", framework::OpRole::kBackward) .value("Optimize", framework::OpRole::kOptimize) - .value("Loss", framework::OpRole::kLoss); + .value("Loss", framework::OpRole::kLoss) + .value("RPC", framework::OpRole::kRPC); op_proto_and_checker_maker.def( "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 03d4602f7a99dc335260cffdcdc30a839f3988cd..8758ac9f94ab91b5be5fc70917c64db38997d1c1 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None): endpoints = list(set(epmap)) helper = LayerHelper("Send", **locals()) - rpc_client_var = default_main_program().global_block().create_var( - name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW) if not get_vars: get_vars = [] for s in send_vars: v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True) get_vars.append(v) + rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName() helper.append_op( type="send", inputs={"X": send_vars}, - outputs={"Out": get_vars, - "RPCClient": rpc_client_var}, - attrs={"endpoints": endpoints, - "epmap": epmap}) + outputs={"Out": get_vars}, + attrs={ + "endpoints": endpoints, + "epmap": epmap, + rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC + }) + return get_vars diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index a9de5419faadba82f92913526999d22dd4c64f3e..4e17fdb16b6c2eb9846fd27ccde36e532d600a7e 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \ LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" -RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" +RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName( +) +RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC class VarBlock: @@ -297,11 +299,6 @@ class DistributeTranspiler: grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ 
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index a9de5419faadba82f92913526999d22dd4c64f3e..4e17fdb16b6c2eb9846fd27ccde36e532d600a7e 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
 
 LOOKUP_TABLE_TYPE = "lookup_table"
 LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
-RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
+RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
+)
+RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
 
 
 class VarBlock:
@@ -297,11 +299,6 @@ class DistributeTranspiler:
                 grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
                     param_var_mapping[p_name][int(p_bid)]
 
-        rpc_client_var = program.global_block().create_var(
-            name=RPC_CLIENT_VAR_NAME,
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
-
         # step 3: transpile trainer side program, insert recv op and send op.
 
         # create mapping of endpoint -> split var to create pserver side program
@@ -338,8 +335,11 @@ class DistributeTranspiler:
                 index=index + 1,
                 type="send_vars",
                 inputs={"X": splited_vars},
-                outputs={"RPCClient": rpc_client_var},
-                attrs={"epmap": eplist})
+                outputs={},
+                attrs={
+                    "epmap": eplist,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
             for _, var in enumerate(splited_vars):
                 send_vars.append(var)
 
@@ -347,10 +347,11 @@ class DistributeTranspiler:
             program.global_block().append_op(
                 type="send_barrier",
                 inputs={},
-                outputs={"RPCClient": rpc_client_var},
+                outputs={},
                 attrs={
                     "endpoints": pserver_endpoints,
-                    "sync_mode": self.sync_mode
+                    "sync_mode": self.sync_mode,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })
 
         # step 3.2: insert recv op to receive parameters from parameter server
@@ -373,15 +374,20 @@ class DistributeTranspiler:
             program.global_block().append_op(
                 type="recv",
                 inputs={},
-                outputs={"Out": splited_var,
-                         "RPCClient": rpc_client_var},
-                attrs={"epmap": eps})
+                outputs={"Out": splited_var},
+                attrs={
+                    "epmap": eps,
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                })
 
         program.global_block().append_op(
             type="fetch_barrier",
             inputs={},
-            outputs={"RPCClient": rpc_client_var},
-            attrs={"endpoints": pserver_endpoints})
+            outputs={},
+            attrs={
+                "endpoints": pserver_endpoints,
+                RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+            })
 
         for varname, splited_var in param_var_mapping.iteritems():
             if len(splited_var) <= 1:
@@ -394,10 +400,8 @@ class DistributeTranspiler:
                 attrs={"axis": 0})
 
         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
-                                                        eplist)
-            self._split_table_grad_and_add_send_vars(program, rpc_client_var,
-                                                     pserver_endpoints)
+            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
         # remove optimize ops and add a send op to main_program
@@ -617,8 +621,7 @@ class DistributeTranspiler:
         return s_prog
 
     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
-                                               eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -665,11 +668,11 @@ class DistributeTranspiler:
                         index=op_index + 1,
                         type="prefetch",
                         inputs={'X': self.prefetch_input_vars},
-                        outputs={
-                            "Out": self.prefetch_output_vars,
-                            "RPCClient": rpc_client_var
-                        },
-                        attrs={"epmap": eplist})
+                        outputs={"Out": self.prefetch_output_vars},
+                        attrs={
+                            "epmap": eplist,
+                            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                        })
 
                     # insert concat_op
                     program.global_block().insert_op(
@@ -689,8 +692,7 @@ class DistributeTranspiler:
                     # break for loop
                     break
 
-    def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
-                                            pserver_endpoints):
+    def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
         # 2. add split_ids_op and send_vars_op to send gradient to pservers
         # there should only be one table_name
         all_ops = program.global_block().ops
@@ -710,9 +712,12 @@ class DistributeTranspiler:
                     index=op_index + 2,
                     type="send_vars",
                     inputs={'X': self.table_grad_list},
-                    outputs={"RPCClient": rpc_client_var},
-                    attrs={"sync_send": True,
-                           "epmap": pserver_endpoints})
+                    outputs={},
+                    attrs={
+                        "sync_send": True,
+                        "epmap": pserver_endpoints,
+                        RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    })
                 break
 
     def _create_prefetch_block(self, pserver_index, pserver_program,