Commit 20c24c05 authored by Yancey1989

singleton rpc_client

Parent 28596a33
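Context for the hunks below: this commit replaces the per-scope "RPC_CLIENT_VAR" variable with a process-wide client obtained through detail::RPCClient::GetInstance(). The following is a minimal, self-contained sketch of that lazy, thread-safe singleton pattern as introduced in grpc_client.{h,cc}; the Ping method and main function are illustrative additions, not part of the commit, and the gRPC members are omitted.

#include <iostream>
#include <memory>
#include <mutex>

// Stripped-down stand-in for paddle::operators::detail::RPCClient;
// channels, the completion queue, and the Async* methods are omitted.
class RPCClient {
 public:
  RPCClient() {}  // the real class keeps a public default constructor

  // Callers (send/recv/prefetch/barrier ops) fetch the shared instance
  // instead of looking up an RPCClient variable in the scope.
  static RPCClient* GetInstance() {
    std::call_once(init_flag_, &RPCClient::Init);
    return rpc_client_.get();
  }

  void Ping() { std::cout << "rpc client is alive\n"; }  // illustrative only

 private:
  // Init is called at most once, guarded by std::call_once.
  static void Init() {
    if (rpc_client_.get() == nullptr) {
      rpc_client_.reset(new RPCClient());
    }
  }

  // Mirrors DISABLE_COPY_AND_ASSIGN(RPCClient) from the header.
  RPCClient(const RPCClient&) = delete;
  RPCClient& operator=(const RPCClient&) = delete;

  static std::unique_ptr<RPCClient> rpc_client_;
  static std::once_flag init_flag_;
};

std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
std::once_flag RPCClient::init_flag_;

int main() {
  // Every call returns the same lazily-constructed object.
  RPCClient* a = RPCClient::GetInstance();
  RPCClient* b = RPCClient::GetInstance();
  a->Ping();
  std::cout << std::boolalpha << (a == b) << "\n";  // prints: true
  return 0;
}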
......@@ -146,15 +146,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
checker(op.InputArgumentNames(), recv_vars);
}
bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
for (auto &name : op.OutputNames()) {
if (name == "RPCClient") {
return true;
}
}
return false;
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
std::unordered_map<std::string, proto::VarType::Type> var_types;
......@@ -184,7 +175,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
if (IsRPCOp(*op)) {
if (boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device
CreateRPCOp(&result, *op);
......
......@@ -80,8 +80,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
bool IsRPCOp(const OpDesc &op) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
......
......@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum(
{static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward),
......
......@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000,
kBackward = 0x0001,
kOptimize = 0x0002,
kRPC = 0x0003,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
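A small sketch of how an RPC op is now identified through the op-role attribute, using the enum values copied from the op_proto_maker.h hunk above. The IsRPCRole helper and main are illustrative names only; the commit itself inlines this comparison in MultiDevSSAGraphBuilder::Build via boost::get<int> on the attribute.

#include <cassert>

// Role values copied from op_proto_maker.h above; kLoss may be OR-ed
// with kForward or kBackward, as the InEnum list in op_proto_maker.cc shows.
enum class OpRole {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kRPC = 0x0003,
  kLoss = 0x0100,
};

// Hypothetical helper mirroring the check that replaced IsRPCOp():
// an op is treated as an RPC op when its role attribute equals kRPC.
bool IsRPCRole(int role_attr) {
  return role_attr == static_cast<int>(OpRole::kRPC);
}

int main() {
  assert(IsRPCRole(static_cast<int>(OpRole::kRPC)));
  assert(!IsRPCRole(static_cast<int>(OpRole::kLoss) |
                    static_cast<int>(OpRole::kForward)));
  return 0;
}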
......@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
......@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS();
int count = 0;
size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name();
++count;
......
......@@ -25,6 +25,21 @@ namespace paddle {
namespace operators {
namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
}
bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
......
......@@ -36,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle {
namespace operators {
......@@ -162,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient {
public:
RPCClient() {}
static RPCClient* GetInstance();
bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
......@@ -192,12 +197,17 @@ class RPCClient {
private:
bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private:
grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
};
} // namespace detail
......
......@@ -121,10 +121,13 @@ TEST(PREFETCH, DISABLED_CPU) {
std::string in_var_name("ids");
std::string out_var_name("out");
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
detail::RPCClient::GetInstance();
// detail::RPCClient::GetInstance();
// client->Wait();
// client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
// out_var_name);
// client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
......
......@@ -43,12 +43,7 @@ class FetchBarrierOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait());
......@@ -63,9 +58,6 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
SendBarrier operator
......@@ -80,17 +72,6 @@ the Parameter Server would knew all variables have been sent.
}
};
class FetchBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -103,5 +84,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpVarTypeInference,
ops::FetchBarrierOpShapeInference);
......@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out",
"(LoDTensor) result "
"to be fetched from parameter server")
......@@ -87,17 +79,6 @@ the parameter server and fetch result back.
}
};
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
ops::PrefetchOpVarTypeInference,
ops::PrefetchOpShapeInference);
......@@ -37,7 +37,6 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
......@@ -45,11 +44,7 @@ class RecvOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
......@@ -65,9 +60,6 @@ class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("Out", "(Tensor) Variables to get from server.").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
Recv operator
......
......@@ -43,12 +43,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
......@@ -65,9 +61,6 @@ class SendBarrierOp : public framework::OperatorBase {
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
SendBarrier operator
......@@ -83,17 +76,6 @@ the Parameter Server would knew all variables have been sent.
}
};
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -106,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference);
......@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
Send operator
......@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
}
};
class SendOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
ops::SendOpMaker, ops::SendOpVarTypeInference,
ops::SendOpShapeInference);
ops::SendOpMaker, ops::SendOpShapeInference);
......@@ -177,75 +177,75 @@ TEST(SendRecvOp, CPUDense) {
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place);
auto in_var = scope.Var("x1");
auto tensor = in_var->GetMutable<f::LoDTensor>();
float *expected = tensor->data<float>();
auto out_var = scope.Var("Out");
auto target = out_var->GetMutable<f::LoDTensor>();
// x1 * 2 == x0
EXPECT_NE(target->memory_size(), size_t(0));
float *actual = target->data<float>();
for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
{{"Out", {"Out"}}, attrs);
send_op->Run(scope, place);
auto in_var = scope.Var("x1");
auto tensor = in_var->GetMutable<f::LoDTensor>();
float *expected = tensor->data<float>();
auto out_var = scope.Var("Out");
auto target = out_var->GetMutable<f::LoDTensor>();
// x1 * 2 == x0
EXPECT_NE(target->memory_size(), size_t(0));
float *actual = target->data<float>();
for (int64_t i = 0; i < target->numel(); ++i) {
EXPECT_EQ(expected[i] * 2, actual[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset(nullptr);
paddle::operators::ListenAndServOp::ResetPort();
}
TEST(SendRecvOp, CPUSparse) {
std::atomic<bool> initialized;
initialized = false;
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
listen_and_serv_op_ptr->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp(
"send", {{"X", {"x1"}}},
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
auto actual = out->mutable_value();
std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
auto expect_value = expect->mutable_value();
expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
add_functor(ctx, *x0, *x1, expect.get());
EXPECT_EQ(actual->numel(), expect_value->numel());
EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
for (int64_t i = 0; i < expect_value->numel(); ++i) {
EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
actual->mutable_data<float>(place)[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
std::atomic<bool> initialized;
initialized = false;
std::thread server_thread(StartServerNet, true, &initialized);
while (!initialized) {
}
auto *listen_and_serv_op_ptr =
static_cast<paddle::operators::ListenAndServOp *>(
listen_and_serv_op.get());
ASSERT_TRUE(listen_and_serv_op_ptr != nullptr);
listen_and_serv_op_ptr->WaitServerReady();
// local net
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
InitSelectedRowsInScope(place, &scope);
scope.Var("RPC_CLIENT_VAR");
f::AttributeMap attrs;
selected_port = listen_and_serv_op_ptr->GetSelectedPort();
std::string endpoint =
paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
{{"Out", {"Out"}}}, attrs);
send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
auto x1 = scope.Var("x1")->GetMutable<f::SelectedRows>();
auto out = scope.Var("Out")->GetMutable<f::SelectedRows>();
auto actual = out->mutable_value();
std::unique_ptr<f::SelectedRows> expect{new f::SelectedRows()};
auto expect_value = expect->mutable_value();
expect_value->mutable_data<float>(f::make_ddim({5, 10}), place);
m::SelectedRowsAdd<p::CPUDeviceContext, float> add_functor;
add_functor(ctx, *x0, *x1, expect.get());
EXPECT_EQ(actual->numel(), expect_value->numel());
EXPECT_EQ(out->rows().size(), x0->rows().size() + x1->rows().size());
for (int64_t i = 0; i < expect_value->numel(); ++i) {
EXPECT_EQ(expect_value->mutable_data<float>(place)[i],
actual->mutable_data<float>(place)[i]);
}
listen_and_serv_op->Stop();
server_thread.join();
listen_and_serv_op.reset();
paddle::operators::ListenAndServOp::ResetPort();
}
......@@ -45,12 +45,7 @@ class SendVarsOp : public framework::OperatorBase {
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
......@@ -73,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC(
Send operator
......@@ -93,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
}
};
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
......@@ -116,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference);
......@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss);
.value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
......@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars,
"RPCClient": rpc_client_var},
attrs={"endpoints": endpoints,
"epmap": epmap})
outputs={"Out": get_vars},
attrs={
"endpoints": endpoints,
"epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars
......
......@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
class VarBlock:
......@@ -297,11 +299,6 @@ class DistributeTranspiler:
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)]
rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
# step 3: transpile trainer side program, insert recv op and send op.
# create mapping of endpoint -> split var to create pserver side program
......@@ -338,8 +335,11 @@ class DistributeTranspiler:
index=index + 1,
type="send_vars",
inputs={"X": splited_vars},
outputs={"RPCClient": rpc_client_var},
attrs={"epmap": eplist})
outputs={},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
......@@ -347,10 +347,11 @@ class DistributeTranspiler:
program.global_block().append_op(
type="send_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
outputs={},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode
"sync_mode": self.sync_mode,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# step 3.2: insert recv op to receive parameters from parameter server
......@@ -373,15 +374,20 @@ class DistributeTranspiler:
program.global_block().append_op(
type="recv",
inputs={},
outputs={"Out": splited_var,
"RPCClient": rpc_client_var},
attrs={"epmap": eps})
outputs={"Out": splited_var},
attrs={
"epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
program.global_block().append_op(
type="fetch_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
outputs={},
attrs={
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
......@@ -394,10 +400,8 @@ class DistributeTranspiler:
attrs={"axis": 0})
if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var,
eplist)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
self._replace_lookup_table_op_with_prefetch(program, eplist)
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
......@@ -617,8 +621,7 @@ class DistributeTranspiler:
return s_prog
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):
def _replace_lookup_table_op_with_prefetch(self, program, eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None
self.prefetch_output_vars = None
......@@ -665,11 +668,11 @@ class DistributeTranspiler:
index=op_index + 1,
type="prefetch",
inputs={'X': self.prefetch_input_vars},
outputs={
"Out": self.prefetch_output_vars,
"RPCClient": rpc_client_var
},
attrs={"epmap": eplist})
outputs={"Out": self.prefetch_output_vars},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
# insert concat_op
program.global_block().insert_op(
......@@ -689,8 +692,7 @@ class DistributeTranspiler:
# break for loop
break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var,
pserver_endpoints):
def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name
all_ops = program.global_block().ops
......@@ -710,9 +712,12 @@ class DistributeTranspiler:
index=op_index + 2,
type="send_vars",
inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var},
attrs={"sync_send": True,
"epmap": pserver_endpoints})
outputs={},
attrs={
"sync_send": True,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
break
def _create_prefetch_block(self, pserver_index, pserver_program,
......