Unverified commit caa90a65, authored by tangwei12, committed by GitHub

Integrated Trainer of Parameter Server (API add `fluid.contrib.layers.sparse_embedding` only) (#22957)

* Integrated Trainer of Parameter Server
Parent: af74675b
@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
   }
 }
-// get RpcContext and remote send and recv op
+// get CommContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  using RpcCtxMap = operators::distributed::RpcCtxMap;
-  VLOG(3) << "ProcessGraph";
-  RpcCtxMap send_varname_to_ctx;
-  for (auto &node : graphs[0]->Nodes()) {
-    VLOG(3) << "node name " << node->Name();
-    if (node && node->IsOp()) {
-      if (node->Name() == "send") {
-        auto send_var_name = node->Op()->Input("X")[0];
-        auto send_varnames =
-            BOOST_GET_CONST(std::vector<std::string>,
-                            node->Op()->GetNullableAttr("send_varnames"));
-        auto epmap = BOOST_GET_CONST(std::vector<std::string>,
-                                     node->Op()->GetNullableAttr("epmap"));
-        auto height_section = BOOST_GET_CONST(
-            std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
-        auto trainer_id =
-            BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
-        auto merge_add =
-            BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
-        if (!merge_add) {
-          merge_add = FLAGS_communicator_is_sgd_optimizer;
-        }
-        auto use_send_handler = BOOST_GET_CONST(
-            bool, node->Op()->GetNullableAttr("use_send_handler"));
-        send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
-            send_var_name, send_varnames, epmap, height_section, trainer_id,
-            merge_add, use_send_handler);
-        VLOG(3) << "find and init an send op: "
-                << send_varname_to_ctx[send_var_name];
-      }
-    }
-  }
   // init communicator here
-  if (send_varname_to_ctx.size() > 0) {
-    auto *instance = operators::distributed::Communicator::GetInstance();
-    auto initialized = instance ? true : false;
-    PADDLE_ENFORCE_EQ(initialized, true,
-                      platform::errors::InvalidArgument(
-                          "Communicator is not Initialized, you may use "
-                          "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
-                          "develop/markdown_doc/transpiler)"));
-  }
+  auto *instance = operators::distributed::Communicator::GetInstance();
+  auto initialized = instance ? true : false;
+  PADDLE_ENFORCE_EQ(initialized, true,
+                    platform::errors::InvalidArgument(
+                        "Communicator is not Initialized, you may use "
+                        "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
+                        "develop/markdown_doc/transpiler)"));
 #endif
 }
......
@@ -122,7 +122,7 @@ class SelectedRows {
   /*
    * @brief Get the index of the key from id_to_index_ map.
    */
-  inline int64_t GetIndexFromId(int64_t key) {
+  inline int64_t GetIndexFromId(int64_t key) const {
     auto iter = id_to_index_.find(key);
     if (iter == id_to_index_.end()) {
       return -1;
......
@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
     PADDLE_THROW("unknown var type to copy");
   }
 }
 } // namespace framework
 } // namespace paddle
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/variable.h"
 namespace paddle {
 namespace framework {
......
@@ -13,6 +13,7 @@ cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_rec
 cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
 cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
+cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool)
 cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
 # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
@@ -26,7 +27,7 @@ if(WITH_GRPC)
     collective_client.cc collective_server.cc
     ${GRPC_SRCS}
     PROTO send_recv.proto
-    DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
+    DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
@@ -50,12 +51,12 @@ else()
   set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
   cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
-    DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op)
+    DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
 endif()
 cc_test(rpc_server_test SRCS rpc_server_test.cc
-  DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_op)
+  DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
@@ -446,11 +446,12 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
 }
 VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
   return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
 }
......
@@ -102,7 +102,8 @@ class BRPCClient : public RPCClient {
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
   bool Wait() override;
......
@@ -22,67 +22,69 @@ namespace paddle {
 namespace operators {
 namespace distributed {
-struct RpcContext {
-  RpcContext() = default;
-  RpcContext(const std::string &name, const std::vector<std::string> &names,
-             const std::vector<std::string> &emap,
-             const std::vector<int64_t> &sections, int id,
-             bool merge_add_ = true, bool use_send_handler_ = true)
-      : var_name(name),
-        splited_var_names(names),
-        epmap(emap),
-        height_sections(sections),
-        trainer_id(id),
-        merge_add(merge_add_),
-        use_send_handler(use_send_handler_) {}
-
-  RpcContext(const RpcContext &ctx) {
-    var_name = ctx.var_name;
-    splited_var_names = ctx.splited_var_names;
-    epmap = ctx.epmap;
-    height_sections = ctx.height_sections;
-    trainer_id = ctx.trainer_id;
-    merge_add = ctx.merge_add;
-    use_send_handler = ctx.use_send_handler;
-  }
-
-  std::string var_name;
-  std::vector<std::string> splited_var_names;
-  std::vector<std::string> epmap;
-  std::vector<int64_t> height_sections;
-  int trainer_id;
-  bool merge_add;
-  bool use_send_handler;
-};
-
-inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
-  os << "{";
-  os << "var_name: " << rpc_ctx.var_name << "\n";
-  os << "splited_var_names: [";
-  for (auto &name : rpc_ctx.splited_var_names) {
-    os << name << ", ";
-  }
-  os << "]\n";
-  os << "epmap: [";
-  for (auto &ep : rpc_ctx.epmap) {
-    os << ep << ", ";
-  }
-  os << "]\n";
-  os << "height_sections: [";
-  for (auto &section : rpc_ctx.height_sections) {
-    os << section << ", ";
-  }
-  os << "]\n";
-  os << "merge add: " << rpc_ctx.merge_add;
-  os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
-  os << "}";
-  return os;
-}
+struct CommContext {
+  CommContext() = default;
+  CommContext(const std::string &name, const std::vector<std::string> &names,
+              const std::vector<std::string> &emap,
+              const std::vector<int64_t> &sections,
+              const std::vector<std::string> &origin_names, int id,
+              bool merge_add_ = true, bool is_sparse_ = true,
+              bool is_distributed_ = false)
+      : var_name(name),
+        splited_varnames(names),
+        epmap(emap),
+        height_sections(sections),
+        origin_varnames(origin_names),
+        trainer_id(id),
+        merge_add(merge_add_),
+        is_sparse(is_sparse_),
+        is_distributed(is_distributed_) {}
+
+  CommContext(const CommContext &ctx) {
+    var_name = ctx.var_name;
+    splited_varnames = ctx.splited_varnames;
+    epmap = ctx.epmap;
+    height_sections = ctx.height_sections;
+    trainer_id = ctx.trainer_id;
+    merge_add = ctx.merge_add;
+    is_sparse = ctx.is_sparse;
+    origin_varnames = ctx.origin_varnames;
+    is_distributed = ctx.is_distributed;
+  }
+
+  std::string print() const {
+    std::stringstream ss;
+    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
+
+    for (size_t i = 0; i < splited_varnames.size(); i++) {
+      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
+         << " section: " << height_sections[i] << " ";
+    }
+
+    ss << "origin varnames: ";
+    for (size_t i = 0; i < origin_varnames.size(); i++) {
+      ss << origin_varnames[i] << " ";
+    }
+
+    ss << " aggregation->add: " << merge_add << " ";
+    ss << " is_sparse: " << is_sparse << "\n";
+    ss << " is_distributed: " << is_distributed << "\n";
+
+    return ss.str();
+  }
+
+  std::string var_name;
+  std::vector<std::string> splited_varnames;
+  std::vector<std::string> epmap;
+  std::vector<int64_t> height_sections;
+  std::vector<std::string> origin_varnames;
+  int trainer_id;
+  bool merge_add;
+  bool is_sparse;
+  bool is_distributed;
+};
 }  // namespace distributed
 }  // namespace operators
......
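For reference, a minimal sketch of how the new CommContext might be constructed and inspected. The argument order follows the constructor shown above, while the variable names, endpoints, and section sizes below are made-up placeholders; it assumes compilation inside the Paddle source tree so communicator_common.h is available.

// Sketch only, not part of the commit.
#include <iostream>
#include "paddle/fluid/operators/distributed/communicator_common.h"

void BuildSparseSendContext() {
  using paddle::operators::distributed::CommContext;
  // "emb@GRAD" is a hypothetical sparse gradient split across two pservers.
  CommContext ctx(/*name=*/"emb@GRAD",
                  /*names=*/{"emb@GRAD.block0", "emb@GRAD.block1"},
                  /*emap=*/{"127.0.0.1:6000", "127.0.0.1:6001"},
                  /*sections=*/{500000, 500000},
                  /*origin_names=*/{"emb@GRAD"},
                  /*id=*/0,
                  /*merge_add_=*/true,
                  /*is_sparse_=*/true,
                  /*is_distributed_=*/false);
  // print() replaces the old operator<< and returns a readable summary.
  std::cout << ctx.print();
}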
@@ -409,7 +409,8 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
 }
 VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
@@ -422,8 +423,8 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   s->Prepare(h, time_out);
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
   platform::RecordRPCEvent record_event(method);
......
@@ -222,7 +222,8 @@ class GRPCClient : public RPCClient {
       int64_t time_out = FLAGS_rpc_deadline) override;
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
   VarHandlePtr AsyncDistributeNotify(
......
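For orientation, a hedged sketch of how a caller might use the widened checkpoint-notify signature (endpoint, directory, and now an explicit variable name). It assumes compilation inside the Paddle tree, where the RPCCLIENT_T macro selects GRPCClient or BRPCClient and CHECKPOINT_SAVE_MESSAGE comes from request_handler.h (see the constants hunk further down); the endpoint, directory, and trainer id are placeholders.

// Sketch only, not part of the commit.
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler.h"

namespace paddle {
namespace operators {
namespace distributed {

void NotifyCheckpointSave(const std::string &ep, const std::string &dirname) {
  RPCClient *client = RPCClient::GetInstance<RPCCLIENT_T>(/*trainer_id=*/0);
  // The varname now travels in the request instead of always being the fixed
  // CHECKPOINT_SAVE_MESSAGE constant, so a table name could be passed as well.
  auto handle = client->AsyncCheckpointNotify(ep, dirname,
                                              CHECKPOINT_SAVE_MESSAGE);
  handle->Wait();  // block until the pserver acknowledges
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle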
@@ -103,11 +103,13 @@ class RequestSend final : public RequestBase {
   void Process() override {
     std::string varname = GetReqName();
-    VLOG(4) << "RequestSend var_name:" << varname;
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
     int trainer_id = request_->GetTrainerId();
+    VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
     framework::Variable* outvar = nullptr;
     request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
@@ -332,8 +334,9 @@ class RequestPrefetch final : public RequestBase {
     std::string out_var_name = request_->OutVarname();
     std::string table_name = request_->TableName();
     int trainer_id = request_->GetTrainerId();
     VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
-            << " out_var_name: " << out_var_name;
+            << " out_var_name: " << out_var_name << " trainer: " << trainer_id;
     auto scope = request_->GetMutableLocalScope();
     auto invar = scope->FindVar(in_var_name);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag LargeScaleKV::init_flag_;
std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
This diff is collapsed.
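The collapsed diff above most likely holds the LargeScaleKV class definition itself; only the singleton statics (init_flag_ and scale_kv_) are visible here. As a rough, generic sketch (not the actual Paddle implementation) of how a std::once_flag plus shared_ptr pair like that is typically wired up:

#include <memory>
#include <mutex>

// Generic stand-in for the real LargeScaleKV, whose definition is collapsed here.
class DemoKV {
 public:
  static void Init() {
    std::call_once(init_flag_, []() { instance_.reset(new DemoKV()); });
  }
  static DemoKV *GetInstance() {
    Init();
    return instance_.get();
  }

 private:
  DemoKV() = default;
  static std::once_flag init_flag_;          // mirrors LargeScaleKV::init_flag_
  static std::shared_ptr<DemoKV> instance_;  // mirrors LargeScaleKV::scale_kv_
};

std::once_flag DemoKV::init_flag_;
std::shared_ptr<DemoKV> DemoKV::instance_(nullptr);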
@@ -41,39 +41,55 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector,
const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids;
for (auto id : ids_vector) {
all_ids.insert(id);
}
auto abs_sections = ToAbsoluteSection(height_section);
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
static void SplitIdsIntoMultipleVarsBySection( static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names, const std::vector<int64_t> &in_ids,
const std::vector<int64_t>& height_section, const std::vector<std::string> &in_varnames, const int tables,
const std::vector<std::vector<int64_t>>& splited_ids, const int pservers, const bool is_distibuted, framework::Scope *scope,
framework::Scope* scope) { std::vector<std::vector<int64_t>> *splited_ids,
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); std::vector<std::vector<int64_t>> *origin_ids) {
PADDLE_ENFORCE_EQ(
in_varnames.size(), tables,
platform::errors::OutOfRange(
"send varnames size: %d not equal table number: %d, internal error",
in_varnames.size(), tables));
PADDLE_ENFORCE_LE(
tables, pservers,
platform::errors::OutOfRange("table number %d not equal or less than "
"pserver number: %d, internal error",
tables, pservers));
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
for (size_t i = 0; i < in_var_names.size(); ++i) { std::set<int64_t> st(in_ids.begin(), in_ids.end());
auto* id_tensor = std::vector<int64_t> all_ids;
scope->Var(in_var_names[i])->GetMutable<framework::LoDTensor>(); all_ids.assign(st.begin(), st.end());
auto& ids = splited_ids[i];
splited_ids->resize(tables);
origin_ids->resize(tables);
if (is_distibuted) {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*splited_ids)[pserver_id].push_back(id);
(*origin_ids)[pserver_id].push_back(id);
}
} else {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*origin_ids)[pserver_id].push_back(id);
id = id / pservers;
(*splited_ids)[pserver_id].push_back(id);
}
}
for (size_t i = 0; i < in_varnames.size(); ++i) {
auto *id_tensor =
scope->Var(in_varnames[i])->GetMutable<framework::LoDTensor>();
auto &ids = (*splited_ids)[i];
if (!ids.empty()) { if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>( auto *id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place); framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size()); memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
} }
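A standalone sketch of the id-to-pserver routing used by SplitIdsIntoMultipleVarsBySection above: in distributed (large-scale) mode the raw id is sent as-is, otherwise the id is compressed to a local row offset with id / pservers, while the original id is kept separately so results can be scattered back. Only the standard library is used and the names are illustrative.

#include <cstdint>
#include <vector>

// Route each id to pserver (id % pservers); keep the original id so the
// prefetched rows can later be mapped back to the caller's id space.
void ShardIds(const std::vector<int64_t> &ids, int pservers, bool is_distributed,
              std::vector<std::vector<int64_t>> *splited_ids,
              std::vector<std::vector<int64_t>> *origin_ids) {
  splited_ids->assign(pservers, {});
  origin_ids->assign(pservers, {});
  for (auto id : ids) {
    auto pserver_id = id % pservers;
    (*origin_ids)[pserver_id].push_back(id);
    // Sliced tables address rows by the quotient; distributed (sparse_embedding)
    // tables key directly by the raw id.
    (*splited_ids)[pserver_id].push_back(is_distributed ? id : id / pservers);
  }
}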
@@ -83,12 +99,18 @@ static void SplitIdsIntoMultipleVarsBySection(
 typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
void prefetch_core( void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables, const std::vector<int64_t> &ids, const TableAndEndpoints &tables,
const std::vector<int64_t>& height_sections, const framework::ExecutionContext &context, const framework::Scope &scope,
const framework::ExecutionContext& context, const framework::Scope& scope, const bool is_distributed,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) { std::unordered_map<int64_t, std::vector<float>> *recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); distributed::RPCClient *rpc_client =
auto& actual_ctx = *pool.Get(context.GetPlace()); distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
int pservers = context.Attr<int>("pserver_num");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &actual_ctx = *pool.Get(context.GetPlace());
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
@@ -99,19 +121,17 @@ void prefetch_core(
     out_var_names.push_back("prefetch_recv@" + tables[i].second);
   }
auto splited_ids = SplitIds(ids, height_sections); std::vector<std::vector<int64_t>> split_ids;
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, std::vector<std::vector<int64_t>> origin_ids;
local_scope.get()); SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers,
is_distributed, local_scope.get(),
&split_ids, &origin_ids);
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto &name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>(); local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) { if (NeedSend(*local_scope.get(), in_var_names[i])) {
@@ -126,20 +146,18 @@ void prefetch_core(
   }
   for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
} }
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), ""); for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) {
auto &ids_in_this_section = origin_ids[o_idx];
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) { if (!ids_in_this_section.empty()) {
auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx]) auto &prefetch_out_var =
->Get<framework::LoDTensor>(); local_scope->Var(out_var_names[o_idx])->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>(); const auto *out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims(); auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, ""); PADDLE_ENFORCE_EQ(dims.size(), 2, "");
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]); PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
@@ -147,8 +165,7 @@ void prefetch_core(
       auto row_numel = dims[1];
       for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i]; auto origin_id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
std::vector<float> vecs(row_numel); std::vector<float> vecs(row_numel);
std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin()); std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
(*recved_vec_map)[origin_id] = vecs; (*recved_vec_map)[origin_id] = vecs;
@@ -159,38 +176,35 @@ void prefetch_core(
     }
   }
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string &id_name, const std::string &out_name,
const std::string& persistable_var_name, const bool backfill, const std::string &persistable_var_name,
const std::vector<std::string>& table_names, const bool is_distributed,
const std::vector<std::string>& endpoints, const std::vector<std::string> &table_names,
const std::vector<int64_t>& height_sections, const std::vector<std::string> &endpoints,
const framework::ExecutionContext& context, const framework::ExecutionContext &context,
const framework::Scope& scope) { const framework::Scope &scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names, prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed,
endpoints, height_sections, context, scope); table_names, endpoints, context, scope);
} }
void prefetchs(const std::vector<std::string>& id_var_names, void prefetchs(const std::vector<std::string> &id_var_names,
const std::vector<std::string>& out_var_names, const std::vector<std::string> &out_var_names,
const std::string& persistable_var_name, const bool backfill, const std::string &persistable_var_name,
const std::vector<std::string>& table_names, const bool is_distributed,
const std::vector<std::string>& endpoints, const std::vector<std::string> &table_names,
const std::vector<int64_t>& height_sections, const std::vector<std::string> &endpoints,
const framework::ExecutionContext& context, const framework::ExecutionContext &context,
const framework::Scope& scope) { const framework::Scope &scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
auto vec_dim_1 = 0; auto vec_dim_1 = 0;
framework::Variable* var = scope.FindVar(persistable_var_name); auto vec_dim_0 = 0;
framework::Variable *var = scope.FindVar(persistable_var_name);
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument( if (var->IsType<SelectedRows>()) {
"prefetch can only support LodTensor only")); vec_dim_1 = var->Get<framework::SelectedRows>().value().dims()[1];
} else {
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1]; vec_dim_0 = var->Get<framework::LoDTensor>().dims()[0];
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
}
PADDLE_ENFORCE_GT(vec_dim_1, 0, PADDLE_ENFORCE_GT(vec_dim_1, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
@@ -203,37 +217,38 @@ void prefetchs(const std::vector<std::string>& id_var_names,
     PADDLE_THROW("multi prefetch only support CPU currently");
   }
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union; std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables; TableAndEndpoints tables;
for (auto& id_name : id_var_names) { for (auto &id_name : id_var_names) {
auto* id_tensor = auto *in_var = scope.FindVar(id_name);
scope.FindVar(id_name)->GetMutable<framework::LoDTensor>(); auto &id_tensor = in_var->Get<framework::LoDTensor>();
auto id_dims = id_tensor->dims(); std::copy_n(id_tensor.data<int64_t>(), id_tensor.numel(),
id_tensor->Resize(framework::make_ddim( back_inserter(ids_union));
{static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
ids_group.push_back(ids);
ids_lods.push_back(id_tensor->lod());
} }
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end()); std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end()); ids_union.assign(s.begin(), s.end());
for (auto &i : ids_union) {
PADDLE_ENFORCE_GE(
i, 0, platform::errors::OutOfRange(
"each element in embedding should be larger or equal 0"));
if (!is_distributed) {
PADDLE_ENFORCE_LT(
i, vec_dim_0,
platform::errors::OutOfRange(
"embedding id must in [0, %d) when is_distributed False",
vec_dim_0));
}
}
for (size_t i = 0; i < table_names.size(); i++) { for (size_t i = 0; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i])); tables.push_back(std::make_pair(table_names[i], endpoints[i]));
} }
std::unordered_map<int64_t, std::vector<float>> recved_vec_map; std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, height_sections, context, scope, prefetch_core(ids_union, tables, context, scope, is_distributed,
&recved_vec_map); &recved_vec_map);
auto padding_idx = distributed::kNoPadding; auto padding_idx = distributed::kNoPadding;
@@ -242,20 +257,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
     padding_idx = context.Attr<int64_t>("padding_idx");
   }
// copy vectors to out vars
for (size_t i = 0; i < out_var_names.size(); i++) { for (size_t i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i]; auto *in_var = scope.FindVar(id_var_names[i]);
auto* out_t = auto &id_tensor = in_var->Get<framework::LoDTensor>();
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>(); auto ids_size = id_tensor.dims()[0];
out_t->Resize( const auto *id_data = id_tensor.data<int64_t>();
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto* out_d = out_t->mutable_data<float>(place);
for (size_t idx = 0; idx < ids.size(); idx++) { auto *out_t =
const auto& id = ids[idx]; scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->set_lod(id_tensor.lod());
out_t->Resize(framework::make_ddim({ids_size, vec_dim_1}));
auto *out_d = out_t->mutable_data<float>(place);
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = id_data[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) { if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1); memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else { } else {
......
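A small standard-library sketch of the copy-back step at the end of prefetchs above: output rows whose id equals padding_idx are zero-filled, all other rows are copied from the received id-to-vector map (the map type matches recved_vec_map in the code above). The non-padding branch is collapsed in this view, so the memcpy here is an assumption about what it does.

#include <cstdint>
#include <cstring>
#include <unordered_map>
#include <vector>

void ScatterRows(const std::vector<int64_t> &ids, int64_t vec_dim_1,
                 int64_t padding_idx,
                 const std::unordered_map<int64_t, std::vector<float>> &recved,
                 float *out) {
  for (size_t idx = 0; idx < ids.size(); ++idx) {
    const auto id = ids[idx];
    if (id == padding_idx) {
      // padding rows are zeroed instead of fetched
      std::memset(out + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
    } else {
      const auto &vec = recved.at(id);
      std::memcpy(out + idx * vec_dim_1, vec.data(), sizeof(float) * vec_dim_1);
    }
  }
}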
@@ -31,7 +31,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
                const std::string& persistable_var_name, const bool backfill,
                const std::vector<std::string>& table_names,
                const std::vector<std::string>& endpoints,
-               const std::vector<int64_t>& height_sections,
                const framework::ExecutionContext& context,
                const framework::Scope& scope);
@@ -39,7 +38,6 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::string& persistable_var_name, const bool backfill,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& endpoints,
-              const std::vector<int64_t>& height_sections,
              const framework::ExecutionContext& context,
              const framework::Scope& scope);
......
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
+#include <algorithm>
 #include <memory>
 #include <set>
 #include <string>
@@ -40,153 +41,131 @@ using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(2) << "ParameterRecv in " << rpc_ctx.var_name; platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
int64_t height = 0;
int64_t ids_num = 0;
int64_t width = 0;
std::vector<int64_t> all_ids;
auto pserver_num = rpc_ctx.splited_varnames.size();
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
height += recv_t.height();
ids_num += recv_t.rows().size();
width = recv_t.value().dims()[1];
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
}
auto *var = scope.FindVar(rpc_ctx.var_name);
auto *t_ = var->GetMutable<framework::SelectedRows>();
T *out_data =
t_->mutable_value()->mutable_data<T>({ids_num, width}, cpu_place);
t_->set_height(height);
t_->set_rows(all_ids);
int64_t cnt = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto rows = recv_t.rows().size();
const T *in_data = recv_t.value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
}
t_->SyncIndex();
}
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name); std::vector<distributed::VarHandlePtr> rets;
// recv all vars to local scope // variable do not spilt
if (recv_var->IsType<framework::LoDTensor>() || if (rpc_ctx.origin_varnames.size() == 1 &&
recv_var->IsType<framework::SelectedRows>()) { rpc_ctx.splited_varnames.size() == 1) {
std::vector<distributed::VarHandlePtr> rets; auto varname = rpc_ctx.origin_varnames[0];
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
auto &recv_var_name = rpc_ctx.splited_var_names[i]; rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
local_scope->Var(recv_var_name); scope, varname, varname));
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
if (recv_var->IsType<framework::LoDTensor>()) {
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(),
recv_var_name, recv_var_name));
} else {
// sparse param in pserver_scope is SelectedRows
rets.push_back(rpc_client->AsyncGetVar(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
}
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
} }
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
return;
} else { } else {
PADDLE_THROW("unsupported var type to recv!"); PADDLE_ENFORCE(false, platform::errors::Unimplemented(
"ParameterRecv can not recv dense with multi "
"parts now, add it soon."));
} }
}
// concat recved tensor into one var template <typename T>
if (recv_var->IsType<framework::LoDTensor>()) { void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
size_t output_offset = 0; const framework::Scope &scope, bool barrier) {
size_t row_offset = 0; VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>(); PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
auto dev_ctx = paddle::platform::CPUDeviceContext(); platform::errors::InvalidArgument(
int64_t recv_numel = 0; "origin_varnames.size() >= 1 is permitted"));
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *recv_var = local_scope->FindVar(recv_var_name); if (rpc_ctx.is_sparse) {
if (recv_var->IsType<framework::LoDTensor>()) { RecvSelectedRows<T>(rpc_ctx, scope);
auto &in = recv_var->Get<framework::LoDTensor>(); } else {
recv_numel += in.numel(); RecvLodTensor<T>(rpc_ctx, scope);
auto in_stride = framework::stride_numel(in.dims());
auto out_stride = framework::stride_numel(recv_tensor->dims());
StridedNumelCopyWithAxis<T>(
dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
in.data<T>(), in_stride, in_stride[0]);
output_offset += in_stride[0];
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1];
recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
<< sstream.str();
}
for (size_t i = 0; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
}
row_offset += recv_slr.height();
} else {
PADDLE_THROW("unsupported recieved var type");
}
}
auto numel = recv_tensor->numel();
PADDLE_ENFORCE_EQ(
recv_numel, numel,
platform::errors::InvalidArgument(
"The number of receive tensor's elements are not valid. The "
"recevie tensor numel is %d, the actual tensor numel is %d.",
recv_numel, numel));
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto cpu_place = platform::CPUPlace();
auto *slr = recv_var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
int64_t width = 0;
int64_t height = 0;
std::vector<int64_t> new_rows{};
// trans sparse ids from local to global
std::vector<int64_t> abs_sections =
ToAbsoluteSection(rpc_ctx.height_sections);
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
width = var_slr->mutable_value()->dims()[1];
height += var_slr->height();
auto row_offset = abs_sections[i];
VLOG(4) << "Recv split_var " << recv_var_name << " Row size "
<< var_slr_row->size();
for (size_t j = 0; j < var_slr_row->size(); j++) {
new_rows.push_back(row_offset + (*var_slr_row)[j]);
}
}
slr->set_rows(new_rows);
slr->set_height(height);
slr->mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(slr->mutable_rows()->size()), width}),
cpu_place);
auto *slr_data = slr->mutable_value()->data<float>();
size_t row_offset = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
auto var_slr_row_size = var_slr_row->size();
auto *var_slr_data = var_slr->mutable_value()->data<float>();
memcpy(slr_data + row_offset * width, var_slr_data,
sizeof(float) * width * var_slr_row_size);
row_offset += var_slr_row_size;
}
} }
VLOG(2) << "ParameterRecv out " << rpc_ctx.var_name; VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
} }
template struct ParameterRecv<float>; template struct ParameterRecv<float>;
......
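A standard-library sketch of the row-id merge in RecvSelectedRows above: each pserver i returns rows in its compressed local id space, and the trainer restores the global id with local_id * pserver_num + i before concatenating the value blocks. Types and names are illustrative only.

#include <cstdint>
#include <vector>

// Restore global row ids from the per-pserver SelectedRows slices.
std::vector<int64_t> MergeRows(
    const std::vector<std::vector<int64_t>> &rows_per_pserver) {
  const int64_t pserver_num = static_cast<int64_t>(rows_per_pserver.size());
  std::vector<int64_t> all_ids;
  for (int64_t i = 0; i < pserver_num; ++i) {
    for (auto local_id : rows_per_pserver[i]) {
      // inverse of the send-side split: local_id = global_id / pserver_num,
      // i = global_id % pserver_num
      all_ids.push_back(local_id * pserver_num + i);
    }
  }
  return all_ids;
}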
@@ -18,7 +18,7 @@
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/distributed/rpc_common.h"
+#include "paddle/fluid/operators/distributed/communicator_common.h"
 namespace paddle {
 namespace operators {
@@ -26,7 +26,10 @@ namespace distributed {
 template <typename T>
 struct ParameterRecv {
-  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
+                  bool barrier);
+
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
 };
 }; // namespace distributed
......
@@ -41,42 +41,67 @@ using DDim = framework::DDim;
 typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext( inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext(
const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) { const CommContext &rpc_ctx, const framework::Scope &scope,
int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs; EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name); auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) { if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must >=1"); PADDLE_ENFORCE_GE(multi_parts, 1,
platform::errors::InvalidArgument(
if (multi_parts == 1) { "multi_parts must == 1 in parameter send, now is: %d",
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { multi_parts));
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i])); for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
} table_pairs.push_back(
} else { std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i]));
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (int x = 0; x < multi_parts; x++) {
auto table =
string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
}
}
} }
} else if (send_var->IsType<framework::LoDTensor>()) {
PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
} else { } else {
PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!"); PADDLE_THROW(platform::errors::InvalidArgument(
"GetMultiFieldCommContext unsupported LoDTensor current!"));
} }
return table_pairs; return table_pairs;
} // namespace distributed } // namespace distributed
void SendByNotifyRPC(const CommContext &rpc_ctx,
const framework::Scope &scope) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto &send_var_name = rpc_ctx.var_name;
std::vector<distributed::VarHandlePtr> rets;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
if (NeedSend(scope, send_var_name)) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope,
send_var_name));
VLOG(4) << "send var " << send_var_name << " by notify RPC done";
}
} else {
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name;
}
for (auto &handle : rets) {
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
template <typename T> template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool sync, const framework::Scope &scope, bool sync,
int multi_parts) { int multi_parts) {
if (rpc_ctx.var_name == STEP_COUNTER) {
SendByNotifyRPC(rpc_ctx, scope);
return;
}
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -86,11 +111,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
   std::vector<distributed::VarHandlePtr> rets;
   auto *send_var = scope.FindVar(rpc_ctx.var_name);
   if (send_var->IsType<framework::LoDTensor>()) {
size_t out_num = rpc_ctx.splited_var_names.size(); size_t out_num = rpc_ctx.splited_varnames.size();
if (out_num > 1) { if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>(); auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims(); auto &send_tensor_dims = send_tensor.dims();
@@ -110,72 +134,49 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       // create output var in local scope
       size_t row_offset = 0;
       for (size_t i = 0; i < out_num; ++i) {
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i]) framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i])
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]); *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0]; row_offset += outs_dims[i][0];
} }
} else { } else {
auto &send_tensor = send_var->Get<framework::LoDTensor>(); auto &send_tensor = send_var->Get<framework::LoDTensor>();
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[0]) framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0])
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
out->ShareDataWith(send_tensor); out->ShareDataWith(send_tensor);
} }
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i]; auto &send_var_name = rpc_ctx.splited_varnames[i];
VLOG(4) << "send var name: " << send_var_name; auto &endpoint = rpc_ctx.epmap[i];
auto &endpoint = rpc_ctx.epmap[i]; VLOG(4) << " send var name: " << send_var_name
VLOG(4) << "send var endpoint: " << endpoint; << "endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name); if (NeedSend(*local_scope.get(), send_var_name)) {
if (NeedSend(*local_scope.get(), send_var_name)) { VLOG(3) << "sending " << send_var_name << " to " << endpoint;
VLOG(3) << "sending " << send_var_name << " to " << endpoint; rets.push_back(rpc_client->AsyncSendVar(
rets.push_back(rpc_client->AsyncSendVar( endpoint, cpu_ctx, *local_scope.get(), send_var_name));
endpoint, cpu_ctx, *local_scope.get(), send_var_name)); VLOG(4) << "send var " << send_var_name << " async handle done";
VLOG(4) << "send var " << send_var_name << " async handle done"; } else {
} else { VLOG(3) << "don't send non-initialized variable: "
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.splited_varnames[i];
<< rpc_ctx.splited_var_names[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} }
} }
} else if (send_var->IsType<framework::SelectedRows>()) { } else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>(); auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows(); auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) { if (send_rows.size() == 0) {
LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which " LOG(WARNING)
"may cause an unknown error. Please check the state of " << "WARNING: The variable sent to pserver is empty, which "
"use_double_buffer in pyreader async mode, you need to " "may cause an unknown error. Please check the state of "
"turn it false."; "use_double_buffer in pyreader/dataloader async mode, you need to "
"turn it false.";
} }
std::vector<std::vector<size_t>> outs_rows_idx; std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx; std::vector<std::vector<size_t>> outs_dense_idx;
auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts); auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1);
outs_rows_idx.resize(table_pairs.size()); outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size()); outs_dense_idx.resize(table_pairs.size());
@@ -190,32 +191,77 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       outs.push_back(out);
     }
// split rows index into output sparse vars if (!rpc_ctx.is_distributed) {
for (size_t i = 0; i < send_rows.size(); ++i) { auto pserver_num = rpc_ctx.epmap.size();
auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
auto table_idx = send_rows[i] % multi_parts; // split rows index into output sparse vars
auto out_idx = ep_idx * multi_parts + table_idx; for (size_t i = 0; i < send_rows.size(); ++i) {
outs_rows_idx[out_idx].push_back(send_rows[i]); auto ep_idx = send_rows[i] % pserver_num;
outs_dense_idx[out_idx].push_back(i); auto id = send_rows[i] / pserver_num;
} outs_rows_idx[ep_idx].push_back(id);
outs_dense_idx[ep_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
} else {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto out_idx = send_rows[i] % pserver_num;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
for (size_t ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) { for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
for (int part = 0; part < multi_parts; part++) { out_idx++) {
auto out_idx = ctx * multi_parts + part;
auto rows_idx = outs_rows_idx[out_idx]; auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims(); auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size(); dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]); outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear(); outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place()); outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) { if (rows_idx.size() > 0) {
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]); outs[out_idx]->mutable_rows()->push_back(idx);
} }
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place); auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) { for (size_t j = 0; j < rows_idx.size(); j++) {
@@ -225,12 +271,15 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
                          src + outs_dense_idx[out_idx][j] * row_numel,
                          sizeof(T) * row_numel);
} else { } else {
PADDLE_THROW("do not support GPU now"); PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
} }
} }
} }
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(), PADDLE_ENFORCE_EQ(
"rows should has the same size with tensor dim 0"); rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
} }
} }
@@ -240,8 +289,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
     auto need_send = NeedSend(*local_scope.get(), send_var_name);
     VLOG(4) << "send var name: " << send_var_name
<< "send var endpoint: " << endpoint << " send var endpoint: " << endpoint
<< "need send: " << need_send; << " need send: " << need_send;
if (need_send) { if (need_send) {
VLOG(4) << "sending " << send_var_name << " to " << endpoint; VLOG(4) << "sending " << send_var_name << " to " << endpoint;
@@ -251,7 +300,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
       VLOG(4) << "send var " << send_var_name << " async handle done";
} else { } else {
VLOG(4) << "don't send non-initialized variable: " VLOG(4) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i]; << rpc_ctx.splited_varnames[i];
} }
} }
} else { } else {
...@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
if (sync) { if (sync) {
for (auto &handle : rets) { for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle; VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
} }
} }
} }
......
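As a side illustration (not part of this commit), the row routing used in ParameterSend above is a plain modulo shard: every sparse row id goes to the parameter server at index `id % pserver_num`, and its position in the local dense tensor is remembered for the later row copy; the ids are sent unshifted, unlike the old `abs_sections` offset scheme. A minimal, self-contained sketch, assuming three pserver endpoints:

// Illustrative sketch only -- not from the diff above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t pserver_num = 3;  // assumed number of pserver endpoints
  const std::vector<int64_t> send_rows = {0, 4, 7, 9, 12};

  // outs_rows_idx[p]  : row ids routed to pserver p
  // outs_dense_idx[p] : their positions in the local dense tensor
  std::vector<std::vector<int64_t>> outs_rows_idx(pserver_num);
  std::vector<std::vector<int64_t>> outs_dense_idx(pserver_num);

  for (size_t i = 0; i < send_rows.size(); ++i) {
    const auto out_idx = send_rows[i] % pserver_num;  // shard by id
    outs_rows_idx[out_idx].push_back(send_rows[i]);
    outs_dense_idx[out_idx].push_back(static_cast<int64_t>(i));
  }

  for (size_t p = 0; p < pserver_num; ++p) {
    std::cout << "pserver " << p << " gets rows:";
    for (auto id : outs_rows_idx[p]) std::cout << " " << id;
    std::cout << "\n";
  }
  return 0;
}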
...@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"

namespace paddle {
namespace operators {
...@@ -26,7 +26,7 @@ namespace distributed {
template <typename T>
struct ParameterSend {
  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
                  bool sync, int multi_parts);
};
......
...@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000;
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
...@@ -29,6 +29,7 @@
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"

namespace paddle {
namespace operators {
...@@ -38,13 +39,13 @@ namespace distributed {
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

bool RequestSendHandler::Handle(const std::string &varname,
                                framework::Scope *scope,
                                framework::Variable *invar,
                                framework::Variable **outvar,
                                const int trainer_id,
                                const std::string &out_var_name,
                                const std::string &table_name) {
  VLOG(4) << "RequestSendHandler:" << varname;

  // Sync
...@@ -82,16 +83,34 @@ bool RequestSendHandler::Handle(const std::string& varname,
      scope->Rename(varname, run_varname);
    }

    auto *var = scope->FindVar(run_varname);

    // for sparse ids
    if (var->IsType<framework::SelectedRows>()) {
      if (distributed_mode_ == DistributedMode::kAsync ||
          distributed_mode_ == DistributedMode::kHalfAsync) {
        auto *ins = distributed::LargeScaleKV::GetInstance();
        if (ins->GradInLargeScale(run_varname)) {
          auto *large_scale_var = ins->GetByGrad(run_varname);
          for (auto name : large_scale_var->CachedVarnames()) {
            scope->Var(name);
          }
        }
      }
      if (distributed_mode_ == DistributedMode::kGeo) {
        if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(
                run_varname)) {
          auto &grad_slr =
              scope->FindVar(run_varname)->Get<framework::SelectedRows>();
          AsyncSparseParamUpdateRecorder::GetInstance()->Update(
              run_varname, grad_slr.rows());
        }
      }
    }

    executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
                                  scope);
    return true;
  } else {  // sync
    rpc_server_->WaitCond(kRequestSend);
...@@ -104,13 +123,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
  return true;
}

bool RequestGetHandler::Handle(const std::string &varname,
                               framework::Scope *scope,
                               framework::Variable *invar,
                               framework::Variable **outvar,
                               const int trainer_id,
                               const std::string &out_var_name,
                               const std::string &table_name) {
  VLOG(3) << "RequestGetHandler:" << varname
          << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
          << " table_name: " << table_name;
...@@ -138,39 +157,38 @@ bool RequestGetHandler::Handle(const std::string& varname,
        VLOG(3) << "copying " << varname << " to " << param_bak_name;
        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
      }

      VLOG(1) << "Table name empty? " << table_name.empty();
      if (distributed_mode_ == DistributedMode::kGeo) {
        VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
                << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
                       varname);
      }

      if (distributed_mode_ == DistributedMode::kGeo &&
          AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
          !table_name.empty()) {
        std::vector<int64_t> updated_rows;
        AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
            varname, trainer_id, &updated_rows);
        if (VLOG_IS_ON(3)) {
          std::ostringstream sstream;
          sstream << "[";
          for (auto &row_id : updated_rows) {
            sstream << row_id << ", ";
          }
          sstream << "]";
          VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
                  << sstream.str();
        }

        auto &origin_tensor =
            scope_->FindVar(varname)->Get<framework::LoDTensor>();
        auto *origin_tensor_data = origin_tensor.data<float>();
        auto &dims = origin_tensor.dims();
        *outvar = scope->Var();
        auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
        out_slr->set_rows(updated_rows);
        out_slr->set_height(dims[0]);
        auto out_dims = framework::make_ddim(
            {static_cast<int64_t>(updated_rows.size()), dims[1]});
        auto *data = out_slr->mutable_value()->mutable_data<float>(
            out_dims, origin_tensor.place());
        auto width = dims[1];
        for (size_t i = 0; i < updated_rows.size(); ++i) {
...@@ -186,13 +204,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
  return true;
}

bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
                                        framework::Scope *scope,
                                        framework::Variable *invar,
                                        framework::Variable **outvar,
                                        const int trainer_id,
                                        const std::string &out_var_name,
                                        const std::string &table_name) {
  VLOG(4) << "RequestGetNoBarrierHandler:" << varname
          << " out_var_name: " << out_var_name;
...@@ -212,77 +230,96 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
  return true;
}
bool RequestPrefetchHandler::Handle(const std::string &varname,
                                    framework::Scope *scope,
                                    framework::Variable *invar,
                                    framework::Variable **outvar,
                                    const int trainer_id,
                                    const std::string &out_var_name,
                                    const std::string &table_name) {
  VLOG(4) << "RequestPrefetchHandler " << varname;

  (*outvar)->GetMutable<framework::LoDTensor>();

  VLOG(1) << "Prefetch "
          << "tablename: " << table_name << " ids:" << varname
          << " out: " << out_var_name;
  paddle::platform::CPUPlace cpu_place;
  auto *ins = distributed::LargeScaleKV::GetInstance();

  if (ins->ParamInLargeScale(table_name)) {
    auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name);
    lookup_table_op->Run(*scope, cpu_place);
  } else {
    auto lookup_table_op =
        BuildLookupTableOp(table_name, varname, out_var_name);
    lookup_table_op->Run(*scope, cpu_place);
  }

  return true;
}
bool RequestCheckpointHandler::Handle(const std::string &varname,
                                      framework::Scope *scope,
                                      framework::Variable *invar,
                                      framework::Variable **outvar,
                                      const int trainer_id,
                                      const std::string &out_var_name,
                                      const std::string &table_name) {
  VLOG(4) << "receive save var " << varname << " with path " << out_var_name;

  auto *ins = distributed::LargeScaleKV::GetInstance();
  ins->Get(varname)->Save(out_var_name);

  // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
  // paddle::platform::CPUPlace cpu_place;
  // checkpoint_op->Run(*scope_, cpu_place);

  return true;
}
bool RequestNotifyHandler::Handle(const std::string &varname,
                                  framework::Scope *scope,
                                  framework::Variable *invar,
                                  framework::Variable **outvar,
                                  const int trainer_id,
                                  const std::string &out_var_name,
                                  const std::string &table_name) {
  VLOG(3) << "RequestNotifyHandler: " << varname
          << ", trainer_id: " << trainer_id;

  string::Piece decay_piece(STEP_COUNTER);
  string::Piece var_name_piece = string::Piece(varname);
  if (string::Contains(var_name_piece, decay_piece)) {
    VLOG(3) << "LearningRate Decay Counter Update";

    auto *send_var = scope->FindVar(varname);
    auto send_var_tensor = send_var->Get<framework::LoDTensor>();
    auto *send_value =
        send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());

    auto counter = decay_counters.at(trainer_id);
    counter += send_value[0];
    decay_counters.at(trainer_id) = counter;

    auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER);
    if (global_step_var == nullptr) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "can not find LEARNING_RATE_DECAY_COUNTER "));
    }

    auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
    auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());

    auto global_counter = 0;
    for (auto &trainer_counter : decay_counters) {
      global_counter += trainer_counter.second;
    }
    value[0] = global_counter;

    if (lr_decay_prepared_ctx_.get() == nullptr) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "can not find decay block for executor"));
    }

    executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
  }
  return true;
......
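For context (this sketch is not part of the commit), the bookkeeping that RequestNotifyHandler now performs can be shown in isolation: each trainer reports its local step count under a variable whose name contains STEP_COUNTER, the server keeps one counter slot per trainer, and the global learning-rate-decay step is the sum over all trainers. A minimal, self-contained sketch, assuming a fan-in of three trainers:

// Illustrative sketch only -- not from the diff above.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <utility>

int main() {
  const int trainers = 3;  // assumed fan-in
  std::unordered_map<int, int64_t> decay_counters;
  for (int i = 0; i < trainers; ++i) decay_counters[i] = 0;

  // Pretend a few notify RPCs arrive as (trainer_id, reported steps).
  const std::pair<int, int64_t> notifications[] = {{0, 5}, {1, 3}, {0, 2}};
  for (const auto &msg : notifications) {
    decay_counters.at(msg.first) += msg.second;
  }

  int64_t global_counter = 0;
  for (const auto &kv : decay_counters) global_counter += kv.second;

  std::cout << "global step counter = " << global_counter << "\n";  // prints 10
  return 0;
}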
...@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
...@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler {
              const std::string& table_name = "") override;

 private:
  std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
      const std::string& table_name, const std::string& id_name,
      const std::string& out_name) {
    framework::OpDesc desc;
    desc.SetType("lookup_sparse_table_read");
    desc.SetInput("Ids", {id_name});
    desc.SetOutput("Out", std::vector<std::string>({out_name}));
    desc.SetAttr("tablename", {table_name});
    desc.SetAttr("init", true);
    desc.SetAttr("value_names", std::vector<std::string>({"Param"}));

    auto op = paddle::framework::OpRegistry::CreateOp(desc);
    return op;
  }

  std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
      const std::string& table_name, const std::string& id_name,
      const std::string& out_name) {
...@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler {
 public:
  explicit RequestCheckpointHandler(int distributed_mode)
      : RequestHandler(distributed_mode) {}

  virtual ~RequestCheckpointHandler() {}
  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
...@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler {
              const std::string& table_name = "") override;

 private:
  std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
      const std::string& varname, const std::string& file_path) {
    paddle::framework::proto::OpDesc op_desc;
    op_desc.set_type("save");
    BuildVar("X", {varname.data()}, op_desc.add_inputs());

    auto attr = op_desc.mutable_attrs()->Add();
    attr->set_name("file_path");
    attr->set_type(paddle::framework::proto::AttrType::STRING);
    attr->set_s(file_path);

    auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
    return op;
  }
};
class RequestNotifyHandler final : public RequestHandler {
 public:
  explicit RequestNotifyHandler(int distributed_mode, int trainers)
      : RequestHandler(distributed_mode) {
    this->trainers = trainers;
    for (int i = 0; i < trainers; i++) {
      decay_counters[i] = 0;
    }
  }
  virtual ~RequestNotifyHandler() {}
  bool Handle(const std::string& varname, framework::Scope* scope,
...@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler {
              const std::string& table_name = "") override;

 private:
  int trainers;
  std::unordered_map<int, int64_t> decay_counters;
};

}  // namespace distributed
......
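As a hedged usage sketch (not part of the commit), the new PullLargeScaleOp helper above is meant to be driven like any other OpDesc-built operator: construct the descriptor, create the op through OpRegistry, and run it against a scope that already holds the ids variable. Only calls that appear in this diff are used (OpDesc setters, OpRegistry::CreateOp, OperatorBase::Run); the variable names "ids" and "out" and the free-function wrapper are illustrative assumptions.

// Hedged sketch -- not from the diff; shows the intended call pattern.
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"

// Assumes `scope` already holds an "ids" variable naming the rows to read.
void PrefetchFromLargeScaleTable(paddle::framework::Scope* scope,
                                 const std::string& table_name) {
  paddle::framework::OpDesc desc;
  desc.SetType("lookup_sparse_table_read");
  desc.SetInput("Ids", {"ids"});
  desc.SetOutput("Out", std::vector<std::string>({"out"}));
  desc.SetAttr("tablename", {table_name});
  desc.SetAttr("init", true);
  desc.SetAttr("value_names", std::vector<std::string>({"Param"}));

  auto op = paddle::framework::OpRegistry::CreateOp(desc);
  paddle::platform::CPUPlace cpu_place;
  op->Run(*scope, cpu_place);  // fills the scope's "out" with the pulled rows
}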
...@@ -77,8 +77,8 @@ class RPCClient {
      int64_t time_out = FLAGS_rpc_deadline) = 0;

  virtual VarHandlePtr AsyncCheckpointNotify(
      const std::string& ep, const std::string& dirname,
      const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;

  virtual VarHandlePtr AsyncDistributeNotify(
      const std::string& ep, const platform::DeviceContext& ctx,
......
...@@ -34,7 +34,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;

USE_NO_KERNEL_OP(lookup_sparse_table_read);

std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
...@@ -46,10 +46,12 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
  framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
  framework::VariableNameMap output({{"Output", {"out"}}});
  auto op = block->AppendOp();
  op->SetType("lookup_sparse_table_read");
  op->SetInput("W", {"w"});
  op->SetInput("Ids", {"ids"});
  op->SetOutput("Out", {"out"});
  op->SetAttr("tablename", {"w"});
  op->SetAttr("value_names", {"Param"});

  auto& out = *root_block->Var("out");
  out.SetType(framework::proto::VarType::LOD_TENSOR);
...@@ -99,16 +101,10 @@ void StartServer(const std::string& rpc_name) {
  platform::CPUPlace place;
  framework::Executor exe(place);
  platform::CPUDeviceContext ctx(place);

  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared;

  g_req_handler->SetProgram(&program);
  g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
...@@ -128,49 +124,6 @@ void StartServer(const std::string& rpc_name) {
  server_thread.join();
}
TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
}
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) {
  setenv("http_proxy", "", 1);
  setenv("https_proxy", "", 1);
......
...@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...@@ -42,6 +43,7 @@ void RunServer(std::shared_ptr<distributed::RPCServer> service) {
  service->StartServer();
  VLOG(4) << "RunServer thread end";
}

static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
...@@ -109,6 +111,19 @@ static int64_t GetTimestamp() {
  return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}

// For sync, sparse variables need recover grad type from LodTensor to
// SelectedRows
void ResetSparseVarsType(framework::Scope *recv_scope) {
  auto *ins = distributed::LargeScaleKV::GetInstance();
  auto grads = ins->GetAllGrads();

  for (auto &grad : grads) {
    auto *v = recv_scope->FindVar(grad);
    v->Clear();
    v->GetMutable<framework::SelectedRows>();
  }
}

void ListenAndServOp::RunSyncLoop(
    framework::Executor *executor, framework::ProgramDesc *program,
    framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
...@@ -179,6 +194,7 @@ void ListenAndServOp::RunSyncLoop(
    VLOG(3) << "ResetReceivedVars";
    ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
    ResetSparseVarsType(recv_scope);

    VLOG(3) << "wait all clients to get parameters back";
    rpc_service_->SetCond(distributed::kRequestGet);
...@@ -372,12 +388,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
      new distributed::RequestGetHandler(distributed_mode, dc_sgd));
  request_prefetch_handler_.reset(
      new distributed::RequestPrefetchHandler(distributed_mode));
  request_checkpoint_handler_.reset(
      new distributed::RequestCheckpointHandler(distributed_mode));
  request_get_no_barrier_handler_.reset(
      new distributed::RequestGetNoBarrierHandler());
  request_notify_handler_.reset(
      new distributed::RequestNotifyHandler(distributed_mode, fan_in));

  rpc_service_->RegisterRPC(distributed::kRequestSend,
                            request_send_handler_.get(), rpc_send_thread_num);
......
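A small sketch (not part of the commit) of the Variable re-typing that ResetSparseVarsType above relies on: clearing a variable and requesting a different payload type turns a gradient that arrived as a dense LoDTensor during the sync step back into a SelectedRows for the next round. Only calls visible in this diff (IsType, Clear, GetMutable) are used; the wrapper function name is an assumption.

// Hedged sketch -- not from the diff; shows the re-typing in isolation.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"

void RecoverGradType(paddle::framework::Variable* v) {
  if (v->IsType<paddle::framework::LoDTensor>()) {
    v->Clear();  // drop the dense payload received during the sync step
    v->GetMutable<paddle::framework::SelectedRows>();  // re-create as sparse
  }
}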
(Diffs for 14 additional changed files are collapsed and not shown here.)