提交 20c24c05 编写于 作者: Y Yancey1989

singleton rpc_client

上级 28596a33
...@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars); checker(op.InputArgumentNames(), recv_vars);
} }
bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
for (auto &name : op.OutputNames()) {
if (name == "RPCClient") {
return true;
}
}
return false;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types; std::unordered_map<std::string, proto::VarType::Type> var_types;
...@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (IsRPCOp(*op)) { if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program. // append rpc op if program is distributed trainer main program.
// always use the first device // always use the first device
CreateRPCOp(&result, *op); CreateRPCOp(&result, *op);
......
...@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::vector<std::string> FindDistTrainRecvVars( std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const; const ProgramDesc &program) const;
bool IsRPCOp(const OpDesc &op) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op, void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const; const std::string &prev_op_name) const;
......
...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum( .InEnum(
{static_cast<int>(OpRole::kForward), {static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
......
...@@ -24,6 +24,7 @@ enum class OpRole { ...@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000, kForward = 0x0000,
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
kRPC = 0x0003,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) { ...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes(); auto nodes = trait.nodes();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) { ...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build(); dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS(); auto nodes = trait.nodes_in_DFS();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
......
...@@ -25,6 +25,21 @@ namespace paddle { ...@@ -25,6 +25,21 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
}
bool RPCClient::AsyncSendVariable(const std::string& ep, bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
......
...@@ -36,6 +36,7 @@ limitations under the License. */ ...@@ -36,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient { class RPCClient {
public: public:
RPCClient() {}
static RPCClient* GetInstance();
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -192,12 +197,17 @@ class RPCClient { ...@@ -192,12 +197,17 @@ class RPCClient {
private: private:
bool Proceed(); bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_; std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::atomic<int64_t> req_count_{0}; std::atomic<int64_t> req_count_{0};
std::mutex mutex_; std::mutex mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
}; };
} // namespace detail } // namespace detail
......
...@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) { ...@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
std::string in_var_name("ids"); std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
detail::RPCClient client; detail::RPCClient::GetInstance();
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name); // detail::RPCClient::GetInstance();
client.Wait(); // client->Wait();
// client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
// out_var_name);
// client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
......
...@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
...@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
...@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent. ...@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
} }
}; };
class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase { class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -103,5 +84,4 @@ namespace ops = paddle::operators; ...@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp, REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker, paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpVarTypeInference,
ops::FetchBarrierOpShapeInference); ops::FetchBarrierOpShapeInference);
...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out", AddOutput("Out",
"(LoDTensor) result " "(LoDTensor) result "
"to be fetched from parameter server") "to be fetched from parameter server")
...@@ -87,17 +79,6 @@ the parameter server and fetch result back. ...@@ -87,17 +79,6 @@ the parameter server and fetch result back.
} }
}; };
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase { class PrefetchOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -110,5 +91,4 @@ namespace ops = paddle::operators; ...@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prefetch, ops::PrefetchOp, REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
ops::PrefetchOpVarTypeInference,
ops::PrefetchOpShapeInference); ops::PrefetchOpShapeInference);
...@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
int sync_mode = Attr<int>("sync_mode"); int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
...@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), auto rpc_client = detail::RPCClient::GetInstance();
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
...@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable(); AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Recv operator Recv operator
......
...@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), auto rpc_client = detail::RPCClient::GetInstance();
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
...@@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
...@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent. ...@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent.
} }
}; };
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendBarrierOpShapeInference : public framework::InferShapeBase { class SendBarrierOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -106,5 +88,4 @@ namespace ops = paddle::operators; ...@@ -106,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp, REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference); ops::SendBarrierOpShapeInference);
...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase { ...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server") AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
} }
}; };
class SendOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendOpShapeInference : public framework::InferShapeBase { class SendOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase { ...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
ops::SendOpMaker, ops::SendOpVarTypeInference, ops::SendOpMaker, ops::SendOpShapeInference);
ops::SendOpShapeInference);
...@@ -177,7 +177,7 @@ TEST(SendRecvOp, CPUDense) { ...@@ -177,7 +177,7 @@ TEST(SendRecvOp, CPUDense) {
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}}, "send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs); {{"Out", {"Out"}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto in_var = scope.Var("x1"); auto in_var = scope.Var("x1");
...@@ -217,12 +217,12 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -217,12 +217,12 @@ TEST(SendRecvOp, CPUSparse) {
scope.Var("RPC_CLIENT_VAR"); scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs; f::AttributeMap attrs;
selected_port = listen_and_serv_op_ptr->GetSelectedPort(); selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint =
paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
"send", {{"X", {"x1"}}}, {{"Out", {"Out"}}}, attrs);
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>(); auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
...@@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() { void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
} }
}; };
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase { class SendVarsOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -116,5 +97,4 @@ namespace ops = paddle::operators; ...@@ -116,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp, REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference); ops::SendVarsOpShapeInference);
...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) { ...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.value("Forward", framework::OpRole::kForward) .value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward) .value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize) .value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss); .value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
...@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None): ...@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap)) endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals()) helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
if not get_vars: if not get_vars:
get_vars = [] get_vars = []
for s in send_vars: for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True) v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v) get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op( helper.append_op(
type="send", type="send",
inputs={"X": send_vars}, inputs={"X": send_vars},
outputs={"Out": get_vars, outputs={"Out": get_vars},
"RPCClient": rpc_client_var}, attrs={
attrs={"endpoints": endpoints, "endpoints": endpoints,
"epmap": epmap}) "epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars return get_vars
......
...@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \ ...@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
class VarBlock: class VarBlock:
...@@ -297,11 +299,6 @@ class DistributeTranspiler: ...@@ -297,11 +299,6 @@ class DistributeTranspiler:
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \ grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)] param_var_mapping[p_name][int(p_bid)]
rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
# step 3: transpile trainer side program, insert recv op and send op. # step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program # create mapping of endpoint -> split var to create pserver side program
...@@ -338,8 +335,11 @@ class DistributeTranspiler: ...@@ -338,8 +335,11 @@ class DistributeTranspiler:
index=index + 1, index=index + 1,
type="send_vars", type="send_vars",
inputs={"X": splited_vars}, inputs={"X": splited_vars},
outputs={"RPCClient": rpc_client_var}, outputs={},
attrs={"epmap": eplist}) attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
send_vars.append(var) send_vars.append(var)
...@@ -347,10 +347,11 @@ class DistributeTranspiler: ...@@ -347,10 +347,11 @@ class DistributeTranspiler:
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
inputs={}, inputs={},
outputs={"RPCClient": rpc_client_var}, outputs={},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
"sync_mode": self.sync_mode "sync_mode": self.sync_mode,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
# step 3.2: insert recv op to receive parameters from parameter server # step 3.2: insert recv op to receive parameters from parameter server
...@@ -373,15 +374,20 @@ class DistributeTranspiler: ...@@ -373,15 +374,20 @@ class DistributeTranspiler:
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={}, inputs={},
outputs={"Out": splited_var, outputs={"Out": splited_var},
"RPCClient": rpc_client_var}, attrs={
attrs={"epmap": eps}) "epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
inputs={}, inputs={},
outputs={"RPCClient": rpc_client_var}, outputs={},
attrs={"endpoints": pserver_endpoints}) attrs={
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
...@@ -394,10 +400,8 @@ class DistributeTranspiler: ...@@ -394,10 +400,8 @@ class DistributeTranspiler:
attrs={"axis": 0}) attrs={"axis": 0})
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, self._replace_lookup_table_op_with_prefetch(program, eplist)
eplist) self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
...@@ -617,8 +621,7 @@ class DistributeTranspiler: ...@@ -617,8 +621,7 @@ class DistributeTranspiler:
return s_prog return s_prog
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, def _replace_lookup_table_op_with_prefetch(self, program, eplist):
eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None self.prefetch_input_vars = None
self.prefetch_output_vars = None self.prefetch_output_vars = None
...@@ -665,11 +668,11 @@ class DistributeTranspiler: ...@@ -665,11 +668,11 @@ class DistributeTranspiler:
index=op_index + 1, index=op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': self.prefetch_input_vars},
outputs={ outputs={"Out": self.prefetch_output_vars},
"Out": self.prefetch_output_vars, attrs={
"RPCClient": rpc_client_var "epmap": eplist,
}, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
attrs={"epmap": eplist}) })
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -689,8 +692,7 @@ class DistributeTranspiler: ...@@ -689,8 +692,7 @@ class DistributeTranspiler:
# break for loop # break for loop
break break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var, def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers # 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name # there should only be one table_name
all_ops = program.global_block().ops all_ops = program.global_block().ops
...@@ -710,9 +712,12 @@ class DistributeTranspiler: ...@@ -710,9 +712,12 @@ class DistributeTranspiler:
index=op_index + 2, index=op_index + 2,
type="send_vars", type="send_vars",
inputs={'X': self.table_grad_list}, inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var}, outputs={},
attrs={"sync_send": True, attrs={
"epmap": pserver_endpoints}) "sync_send": True,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
break break
def _create_prefetch_block(self, pserver_index, pserver_program, def _create_prefetch_block(self, pserver_index, pserver_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册