未验证 提交 d92a75be 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #10550 from Yancey1989/overlap_send_op

overlap send ops and backward ops
...@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context ...@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base) cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph) cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
...@@ -26,7 +26,7 @@ endif() ...@@ -26,7 +26,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto) cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
...@@ -12,12 +12,13 @@ ...@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include <fstream>
#include <utility> #include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -28,6 +29,10 @@ ...@@ -28,6 +29,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, ...@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
} }
} }
bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
OpDesc *send_op) const { const ProgramDesc &program) const {
if (send_op == nullptr) { std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if (op->Type() == "send_vars") {
auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
}
}
return send_vars;
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const {
std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) {
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if (op->Type() == "recv") {
auto op_vars = op->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
}
}
return recv_vars;
}
bool MultiDevSSAGraphBuilder::IsDistTrainOp(
const OpDesc &op, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false; return false;
} }
...@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, ...@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
* Check any of opvars contains `.block` and in sendvars * Check any of opvars contains `.block` and in sendvars
*/ */
auto checker = [](const std::vector<std::string> &opvars, auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &sendvars) -> bool { const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) { for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos && if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true; return true;
} }
} }
return false; return false;
}; };
if (op.Type() == "split" || op.Type() == "split_byref") { return checker(op.OutputArgumentNames(), send_vars) ||
return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); checker(op.InputArgumentNames(), recv_vars);
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
}
return false;
} }
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
// Find "send" op first for split is in front of send. // find send/recv vars so that we can place the distributed training
OpDesc *send_op = GetSendOpDesc(program); // realted op in the place 0
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> var_name_on_devices;
...@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") { if (boost::get<int>(
// append send op if program is distributed trainer main program. op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device // always use the first device
CreateSendOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateComputationalOps(&result, *op, 1); CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
...@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
std::ostringstream sout; std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, sout); PrintGraphviz(*graph, fout);
VLOG(10) << sout.str();
} }
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
...@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, ...@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id); CreateOpHandleIOs(result, op, dev_id);
} }
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const { SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var; return var;
} }
void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
op->AddInput(dep_var);
}
}
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const { const OpDesc &op) const {
auto &p = places_[0]; auto &p = places_[0];
auto *s = local_scopes_[0]; auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0 result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output if (op.Type() == "send_barrier") {
// is created for send op. ConnectOp(result, result->ops_.back().get(), "send_vars");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
}
// TODO(Yancey1989): schedule rpc op on different place may
// increate throughput
CreateOpHandleIOs(result, op, 0); CreateOpHandleIOs(result, op, 0);
} }
......
...@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/** /**
* Is this operator as the end-point operator before/after send operator. * Is this operator as the end-point operator before/after send operator.
*/ */
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op, void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const; size_t num_places) const;
...@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient( bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const; const std::string &og) const;
......
...@@ -12,24 +12,26 @@ ...@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const Scope *local_scope, const platform::Place &place,
const platform::Place &place) const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope),
place_(place) {} place_(place),
name_(name) {}
void SendOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis whether wait VarDummyHandle. // TODO(wuyi): need further analysis whether wait VarDummyHandle.
// Wait input done // Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if (in->DebugString() == "dummy") { // HACK if (in->DebugString() == "dummy") { // HACK
continue; continue;
} }
...@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() { ...@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
op_->Run(*tmp_scope, place_); op_->Run(*tmp_scope, place_);
} }
std::string SendOpHandle::Name() const { return "send"; } std::string RPCOpHandle::Name() const { return name_; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -27,9 +27,9 @@ namespace paddle { ...@@ -27,9 +27,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct SendOpHandle : public OpHandleBase { struct RPCOpHandle : public OpHandleBase {
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place); const platform::Place& place, const std::string& name);
std::string Name() const override; std::string Name() const override;
...@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase { ...@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_; const platform::Place& place_;
const std::string name_;
}; };
} // namespace details } // namespace details
......
...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum( .InEnum(
{static_cast<int>(OpRole::kForward), {static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
......
...@@ -24,6 +24,7 @@ enum class OpRole { ...@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000, kForward = 0x0000,
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
kRPC = 0x0003,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) { ...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes(); auto nodes = trait.nodes();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) { ...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build(); dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS(); auto nodes = trait.nodes_in_DFS();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
......
...@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE) ...@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL) # listen_and_serv_op sum_op executor SERIAL)
...@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE) ...@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif() endif()
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
......
...@@ -25,6 +25,21 @@ namespace paddle { ...@@ -25,6 +25,21 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
}
bool RPCClient::AsyncSendVariable(const std::string& ep, bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
}); });
req_count_++; req_count_++;
return true; return true;
...@@ -249,8 +263,9 @@ bool RPCClient::Proceed() { ...@@ -249,8 +263,9 @@ bool RPCClient::Proceed() {
delete c; delete c;
return true; return true;
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(ep); auto it = channels_.find(ep);
if (it != channels_.end()) { if (it != channels_.end()) {
return it->second; return it->second;
...@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto ch = auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[ep] = ch; channels_[ep] = ch;
return ch; return ch;
} }
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -35,6 +36,7 @@ limitations under the License. */ ...@@ -35,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient { class RPCClient {
public: public:
RPCClient() {}
static RPCClient* GetInstance();
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -191,11 +197,17 @@ class RPCClient { ...@@ -191,11 +197,17 @@ class RPCClient {
private: private:
bool Proceed(); bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_; std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0; std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
}; };
} // namespace detail } // namespace detail
......
...@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) { ...@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
std::string in_var_name("ids"); std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
detail::RPCClient client; auto client = detail::RPCClient::GetInstance();
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name); out_var_name);
client.Wait(); client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
class FetchBarrierOp : public framework::OperatorBase {
public:
FetchBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
};
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpShapeInference);
...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out", AddOutput("Out",
"(LoDTensor) result " "(LoDTensor) result "
"to be fetched from parameter server") "to be fetched from parameter server")
...@@ -87,17 +79,6 @@ the parameter server and fetch result back. ...@@ -87,17 +79,6 @@ the parameter server and fetch result back.
} }
}; };
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase { class PrefetchOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -110,5 +91,4 @@ namespace ops = paddle::operators; ...@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prefetch, ops::PrefetchOp, REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
ops::PrefetchOpVarTypeInference,
ops::PrefetchOpShapeInference); ops::PrefetchOpShapeInference);
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase { ...@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait());
} }
PADDLE_ENFORCE(client_.Wait());
} }
private:
mutable detail::RPCClient client_;
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -65,6 +70,10 @@ This operator can get variables from server side. ...@@ -65,6 +70,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
} }
}; };
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints"); std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
auto client_var_name = Output("RPCClient"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), auto& ctx = *pool.Get(place);
"Can not find variable '%s' in the scope.", // For profiling
client_var_name); platform::RecordEvent record_event(Type(), &ctx);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); auto rpc_client = detail::RPCClient::GetInstance();
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
} }
}
}; };
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
...@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent. ...@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.") "Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
} AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
};
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
} }
}; };
...@@ -98,5 +88,4 @@ namespace ops = paddle::operators; ...@@ -98,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp, REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference); ops::SendBarrierOpShapeInference);
...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase { ...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server") AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
} }
}; };
class SendOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendOpShapeInference : public framework::InferShapeBase { class SendOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase { ...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
ops::SendOpMaker, ops::SendOpVarTypeInference, ops::SendOpMaker, ops::SendOpShapeInference);
ops::SendOpShapeInference);
...@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) { ...@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false, &initialized); std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) { while (!initialized) {
} }
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get()) static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady(); ->WaitServerReady();
...@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) { ...@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( const f::VariableNameMap &inputs = {{"X", {"x1"}}};
"send", {{"X", {"x1"}}}, const f::VariableNameMap &outputs = {{"Out", {"Out"}}};
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto in_var = scope.Var("x1"); auto in_var = scope.Var("x1");
...@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
"send", {{"X", {"x1"}}}, {{"Out", {"Out"}}}, attrs);
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>(); auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
...@@ -20,6 +20,9 @@ namespace operators { ...@@ -20,6 +20,9 @@ namespace operators {
inline bool NeedSend(const framework::Scope& scope, inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) { const std::string& varname) {
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if (varname == "dummy") return false;
auto* var = scope.FindVar(varname); auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname); varname);
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient"); // For profiling
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), platform::RecordEvent record_event(Type(), &ctx);
"Can not find variable '%s' in the scope.",
client_var_name); auto rpc_client = detail::RPCClient::GetInstance();
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() { void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
} }
}; };
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase { class SendVarsOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -112,5 +97,4 @@ namespace ops = paddle::operators; ...@@ -112,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp, REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference); ops::SendVarsOpShapeInference);
...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) { ...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.value("Forward", framework::OpRole::kForward) .value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward) .value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize) .value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss); .value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
...@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None): ...@@ -195,21 +195,23 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap)) endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals()) helper = LayerHelper("Send", **locals())
rpc_client_var = default_main_program().global_block().create_var(
name="RPC_CLIENT_VAR", persistable=True, type=core.VarDesc.VarType.RAW)
if not get_vars: if not get_vars:
get_vars = [] get_vars = []
for s in send_vars: for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True) v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v) get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op( helper.append_op(
type="send", type="send",
inputs={"X": send_vars}, inputs={"X": send_vars},
outputs={"Out": get_vars, outputs={"Out": get_vars},
"RPCClient": rpc_client_var}, attrs={
attrs={"endpoints": endpoints, "endpoints": endpoints,
"epmap": epmap}) "epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars return get_vars
......
...@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): ...@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
def test_transpiler(self): def test_transpiler(self):
trainer = self.get_trainer() trainer = self.get_trainer()
pserver, startup = self.get_pserver(self.current_pserver_ep) pserver, startup = self.get_pserver(self.current_pserver_ep)
self.assertEqual([op.type for op in trainer.global_block().ops], self.assertEqual([op.type for op in trainer.global_block().ops],
self.get_expect_trainer_ops()) self.get_expect_trainer_ops())
...@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): ...@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
optimize_ops, params_grads = self.net_conf() optimize_ops, params_grads = self.net_conf()
delete_ops(trainer.global_block(), optimize_ops) delete_ops(trainer.global_block(), optimize_ops)
return [op.type for op in trainer.global_block().ops ops = [op.type for op in trainer.global_block().ops] + [
] + ["split_byref", "send", "concat"] "split_byref", "send_vars", "send_barrier", "recv", "recv",
"fetch_barrier", "concat"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
return ops
def get_trainer(self): def get_trainer(self):
return self._transpiler_instance().get_trainer_program() return self._transpiler_instance().get_trainer_program()
......
...@@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler ...@@ -16,8 +16,9 @@ from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory from memory_optimization_transpiler import memory_optimize, release_memory
from distribute_transpiler_simple import SimpleDistributeTranspiler from distribute_transpiler_simple import SimpleDistributeTranspiler
from ps_dispatcher import HashName, RoundRobin
__all__ = [ __all__ = [
"DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler", "DistributeTranspiler", "InferenceTranspiler", "SimpleDistributeTranspiler",
"memory_optimize", "release_memory" "memory_optimize", "release_memory", "HashName", "RoundRobin"
] ]
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import math import math
import distributed_splitter as splitter from ps_dispatcher import RoundRobin, HashName, PSDispatcher
from .. import core, framework from .. import core, framework
from ..framework import Program, default_main_program, \ from ..framework import Program, default_main_program, \
default_startup_program, \ default_startup_program, \
...@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \ ...@@ -24,7 +24,9 @@ from ..framework import Program, default_main_program, \
LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
class VarBlock: class VarBlock:
...@@ -149,13 +151,27 @@ def delete_ops(block, ops): ...@@ -149,13 +151,27 @@ def delete_ops(block, ops):
block.program.sync_with_cpp() block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
class DistributeTranspiler: class DistributeTranspiler:
def transpile(self, def transpile(self,
trainer_id, trainer_id,
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=splitter.round_robin, split_method=RoundRobin,
sync_mode=True): sync_mode=True):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
...@@ -196,7 +212,7 @@ class DistributeTranspiler: ...@@ -196,7 +212,7 @@ class DistributeTranspiler:
:param sync_mode: if sync_mode is set True, it means that dist transpiler :param sync_mode: if sync_mode is set True, it means that dist transpiler
will transpile the program into sync_mode pserver and trainer program. will transpile the program into sync_mode pserver and trainer program.
""" """
assert (callable(split_method)) assert (split_method.__bases__[0] == PSDispatcher)
if program is None: if program is None:
program = default_main_program() program = default_main_program()
self.origin_program = program self.origin_program = program
...@@ -209,6 +225,7 @@ class DistributeTranspiler: ...@@ -209,6 +225,7 @@ class DistributeTranspiler:
pserver_endpoints = pservers.split(",") pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass() self.optimize_ops, params_grads = self._get_optimize_pass()
ps_dispatcher = split_method(pserver_endpoints)
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
...@@ -268,54 +285,110 @@ class DistributeTranspiler: ...@@ -268,54 +285,110 @@ class DistributeTranspiler:
grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints)) grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
param_blocks = split_dense_variable(param_list, len(pserver_endpoints)) param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
assert (len(grad_blocks) == len(param_blocks))
# step2: Create new vars for the parameters and gradients blocks and # step2: Create new vars for the parameters and gradients blocks and
# add ops to do the split. # add ops to do the split.
grad_var_mapping = self._append_split_op(program, grad_blocks)
param_var_mapping = self._create_vars_from_blocklist(program, param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks) param_blocks)
grad_var_mapping = self._create_vars_from_blocklist(
program, grad_blocks, add_trainer_suffix=self.trainer_num > 1)
grad_param_mapping = dict()
for g, p in zip(grad_blocks, param_blocks):
g_name, g_bid, _ = g.split(":")
p_name, p_bid, _ = p.split(":")
grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
param_var_mapping[p_name][int(p_bid)]
# step 3: transpile trainer side program, insert recv op and send op.
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs = []
send_outputs = []
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
# create mapping of endpoint -> split var to create pserver side program # create mapping of endpoint -> split var to create pserver side program
self.param_grad_ep_mapping = dict() self.param_grad_ep_mapping = dict()
for i, ep in enumerate(eplist): [
param = send_outputs[i] self.param_grad_ep_mapping.update({
grad = send_inputs[i] ep: {
if not self.param_grad_ep_mapping.has_key(ep): "params": [],
self.param_grad_ep_mapping[ep] = {"params": [], "grads": []} "grads": []
self.param_grad_ep_mapping[ep]["params"].append(param) }
self.param_grad_ep_mapping[ep]["grads"].append(grad) }) for ep in self.pserver_endpoints
]
rpc_client_var = program.global_block().create_var(
name=RPC_CLIENT_VAR_NAME, # step 3.1: insert send op to send gradient vars to parameter servers
persistable=True, ps_dispatcher.reset()
type=core.VarDesc.VarType.RAW) send_vars = []
for orig_varname, splited_vars in grad_var_mapping.items():
eplist = ps_dispatcher.dispatch(splited_vars)
if len(splited_vars) == 1:
orig_varname = splited_vars[0].name
index = find_op_by_output_arg(program.global_block(),
orig_varname)
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[orig_varname]
index = find_op_by_output_arg(program.global_block(),
orig_varname)
self._insert_split_op(program, orig_var, index, splited_vars)
index += 1
else:
AssertionError("Can not insert the send op by original "
"variable name :", orig_varname)
program.global_block().insert_op(
index=index + 1,
type="send_vars",
inputs={"X": splited_vars},
outputs={},
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
# create send_op if self.sync_mode:
program.global_block().append_op( program.global_block().append_op(
type="send", type="send_barrier",
inputs={"X": send_inputs}, inputs={},
outputs={"Out": send_outputs, outputs={},
"RPCClient": rpc_client_var},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
"epmap": eplist, "sync_mode": self.sync_mode,
"sync_mode": self.sync_mode RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars = []
for _, var in enumerate(send_vars):
recv_vars.append(grad_param_mapping[var])
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
eps = []
for var in splited_var:
index = [v.name for v in recv_vars].index(var.name)
eps.append(eplist[index])
program.global_block().append_op(
type="recv",
inputs={},
outputs={"Out": splited_var},
attrs={
"epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
program.global_block().append_op(
type="fetch_barrier",
inputs={},
outputs={},
attrs={
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
continue continue
...@@ -327,10 +400,8 @@ class DistributeTranspiler: ...@@ -327,10 +400,8 @@ class DistributeTranspiler:
attrs={"axis": 0}) attrs={"axis": 0})
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
self._replace_lookup_table_op_with_prefetch(program, rpc_client_var, self._replace_lookup_table_op_with_prefetch(program, eplist)
eplist) self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
self._split_table_grad_and_add_send_vars(program, rpc_client_var,
pserver_endpoints)
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
...@@ -550,8 +621,7 @@ class DistributeTranspiler: ...@@ -550,8 +621,7 @@ class DistributeTranspiler:
return s_prog return s_prog
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, def _replace_lookup_table_op_with_prefetch(self, program, eplist):
eplist):
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
self.prefetch_input_vars = None self.prefetch_input_vars = None
self.prefetch_output_vars = None self.prefetch_output_vars = None
...@@ -598,11 +668,11 @@ class DistributeTranspiler: ...@@ -598,11 +668,11 @@ class DistributeTranspiler:
index=op_index + 1, index=op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': self.prefetch_input_vars}, inputs={'X': self.prefetch_input_vars},
outputs={ outputs={"Out": self.prefetch_output_vars},
"Out": self.prefetch_output_vars, attrs={
"RPCClient": rpc_client_var "epmap": eplist,
}, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
attrs={"epmap": eplist}) })
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block().insert_op(
...@@ -622,8 +692,7 @@ class DistributeTranspiler: ...@@ -622,8 +692,7 @@ class DistributeTranspiler:
# break for loop # break for loop
break break
def _split_table_grad_and_add_send_vars(self, program, rpc_client_var, def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
pserver_endpoints):
# 2. add split_ids_op and send_vars_op to send gradient to pservers # 2. add split_ids_op and send_vars_op to send gradient to pservers
# there should only be one table_name # there should only be one table_name
all_ops = program.global_block().ops all_ops = program.global_block().ops
...@@ -643,9 +712,12 @@ class DistributeTranspiler: ...@@ -643,9 +712,12 @@ class DistributeTranspiler:
index=op_index + 2, index=op_index + 2,
type="send_vars", type="send_vars",
inputs={'X': self.table_grad_list}, inputs={'X': self.table_grad_list},
outputs={"RPCClient": rpc_client_var}, outputs={},
attrs={"sync_send": True, attrs={
"epmap": pserver_endpoints}) "sync_send": True,
"epmap": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
break break
def _create_prefetch_block(self, pserver_index, pserver_program, def _create_prefetch_block(self, pserver_index, pserver_program,
...@@ -838,32 +910,13 @@ class DistributeTranspiler: ...@@ -838,32 +910,13 @@ class DistributeTranspiler:
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=persistable) persistable=persistable)
def _append_split_op(self, program, gradblocks): def _insert_split_op(self, program, orig_var, index, splited_vars):
"""
Split variables that need to be split and append respective ops
Args:
program (ProgramDesc): ProgramDesc that gradients blong.
gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
Returns:
var_mapping (dict(varname->[new_splitted_variable])):A dict mapping
from original var name to each var split.
"""
add_suffix = False
if self.trainer_num > 1:
add_suffix = True
var_mapping = self._create_vars_from_blocklist(
program, gradblocks, add_trainer_suffix=add_suffix)
for varname, splited_vars in var_mapping.iteritems():
# variable that don't need to split have empty splited_vars
if len(splited_vars) <= 1:
continue
orig_var = program.global_block().vars[varname]
if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS:
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
height_sections.append(v.shape[0]) height_sections.append(v.shape[0])
program.global_block().append_op( program.global_block().insert_op(
index=index + 1,
type="split_selected_rows", type="split_selected_rows",
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
...@@ -872,7 +925,8 @@ class DistributeTranspiler: ...@@ -872,7 +925,8 @@ class DistributeTranspiler:
sections = [] sections = []
for v in splited_vars: for v in splited_vars:
sections.append(v.shape[0]) sections.append(v.shape[0])
program.global_block().append_op( program.global_block().insert_op(
index=index + 1,
type="split_byref", type="split_byref",
inputs={"X": orig_var}, inputs={"X": orig_var},
outputs={"Out": splited_vars}, outputs={"Out": splited_vars},
...@@ -881,7 +935,6 @@ class DistributeTranspiler: ...@@ -881,7 +935,6 @@ class DistributeTranspiler:
else: else:
AssertionError("Variable type should be in set " AssertionError("Variable type should be in set "
"[LOD_TENSOR, SELECTED_ROWS]") "[LOD_TENSOR, SELECTED_ROWS]")
return var_mapping
def _get_optimizer_input_shape(self, op_type, varkey, orig_shape, def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
param_shape): param_shape):
......
...@@ -13,45 +13,66 @@ ...@@ -13,45 +13,66 @@
# limitations under the License. # limitations under the License.
def hash_name(varlist, pserver_endpoints): class PSDispatcher(object):
""" """
hash variable names to several endpoints. PSDispatcher is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` inferface.
"""
def __init__(self, pserver_endpoints):
self._eps = pserver_endpoints
self._step = 0
@property
def eps(self):
return self._eps
def reset(self):
self._step = 0
def dispatch(self, varlist):
"""
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
AssertionError("Interface has not been implemented.")
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname class HashName(PSDispatcher):
"""
Hash variable names to several endpoints
""" """
def _hash_block(block_str, total): def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def _hash_block(self, block_str, total):
return hash(block_str) % total return hash(block_str) % total
def dispatch(self, varlist):
eplist = [] eplist = []
for var in varlist: for var in varlist:
server_id = _hash_block(var.name(), len(pserver_endpoints)) server_id = self._hash_block(var.name(), len(self._eps))
server_for_param = pserver_endpoints[server_id] server_for_param = self._eps[server_id]
eplist.append(server_for_param) eplist.append(server_for_param)
return eplist return eplist
def round_robin(varlist, pserver_endpoints): class RoundRobin(PSDispatcher):
""" """
Distribute variables to several endpoints. Distribute variables to serveral endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
""" """
assert (len(varlist) >= len(pserver_endpoints))
def __init__(self, pserver_endpoints):
super(self.__class__, self).__init__(pserver_endpoints)
def dispatch(self, varlist):
eplist = [] eplist = []
pserver_idx = 0
for var in varlist: for var in varlist:
server_for_param = pserver_endpoints[pserver_idx] server_for_param = self._eps[self._step]
eplist.append(server_for_param) eplist.append(server_for_param)
self._step += 1
pserver_idx += 1 if self._step >= len(self._eps):
if pserver_idx >= len(pserver_endpoints): self._step = 0
pserver_idx = 0
return eplist return eplist
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册