Unverified commit caa90a65, authored by tangwei12, committed by GitHub

Integrated Trainer of Parameter Server (API add `fluid.contrib.layers.sparse_embedding` only) (#22957)

* Integrated Trainer of Parameter Server

Parent: af74675b
@@ -42,45 +42,10 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
   }
 }
 
-// get RpcContext and remote send and recv op
+// get CommContext and remote send and recv op
 void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-  using RpcCtxMap = operators::distributed::RpcCtxMap;
-  VLOG(3) << "ProcessGraph";
-  RpcCtxMap send_varname_to_ctx;
-  for (auto &node : graphs[0]->Nodes()) {
-    VLOG(3) << "node name " << node->Name();
-    if (node && node->IsOp()) {
-      if (node->Name() == "send") {
-        auto send_var_name = node->Op()->Input("X")[0];
-        auto send_varnames =
-            BOOST_GET_CONST(std::vector<std::string>,
-                            node->Op()->GetNullableAttr("send_varnames"));
-        auto epmap = BOOST_GET_CONST(std::vector<std::string>,
-                                     node->Op()->GetNullableAttr("epmap"));
-        auto height_section = BOOST_GET_CONST(
-            std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
-        auto trainer_id =
-            BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
-        auto merge_add =
-            BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
-        if (!merge_add) {
-          merge_add = FLAGS_communicator_is_sgd_optimizer;
-        }
-        auto use_send_handler = BOOST_GET_CONST(
-            bool, node->Op()->GetNullableAttr("use_send_handler"));
-        send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
-            send_var_name, send_varnames, epmap, height_section, trainer_id,
-            merge_add, use_send_handler);
-        VLOG(3) << "find and init an send op: "
-                << send_varname_to_ctx[send_var_name];
-      }
-    }
-  }
 
   // init communicator here
-  if (send_varname_to_ctx.size() > 0) {
   auto *instance = operators::distributed::Communicator::GetInstance();
   auto initialized = instance ? true : false;
   PADDLE_ENFORCE_EQ(initialized, true,
@@ -88,7 +53,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
                     "Communicator is not Initialized, you may use "
                     "FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
                     "develop/markdown_doc/transpiler)"));
-  }
 #endif
 }
......
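For context, the hunk above drops the graph-parsing path that used to build an RpcCtxMap from `send` op attributes; ProcessGraph now only asserts that the communicator singleton was created beforehand (by the fleet/transpiler setup). A minimal sketch of that "must already be initialized" check, using a hypothetical stand-in class rather than Paddle's real Communicator:

// Minimal sketch of the initialization check the hunk relies on.
// `FakeCommunicator` is a stand-in name, not Paddle's Communicator class.
#include <memory>
#include <stdexcept>

class FakeCommunicator {
 public:
  static FakeCommunicator *GetInstance() { return instance_.get(); }
  static void Init() { instance_.reset(new FakeCommunicator()); }

 private:
  FakeCommunicator() = default;
  static std::unique_ptr<FakeCommunicator> instance_;
};

std::unique_ptr<FakeCommunicator> FakeCommunicator::instance_;

void ProcessGraphLikeCheck() {
  // Mirrors PADDLE_ENFORCE_EQ(initialized, true, ...): fail loudly if the
  // communicator was not created by the fleet/transpiler setup step.
  if (FakeCommunicator::GetInstance() == nullptr) {
    throw std::runtime_error("Communicator is not initialized, use FleetAPI");
  }
}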
@@ -122,7 +122,7 @@ class SelectedRows {
   /*
    * @brief Get the index of the key from id_to_index_ map.
    */
-  inline int64_t GetIndexFromId(int64_t key) {
+  inline int64_t GetIndexFromId(int64_t key) const {
     auto iter = id_to_index_.find(key);
     if (iter == id_to_index_.end()) {
       return -1;
......
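As a small illustration of why the getter above gains a `const` qualifier, a sketch of a hash-map-backed id lookup that can be called through const references (the `IdIndex` type here is illustrative, not Paddle's SelectedRows):

// Sketch only: an unordered_map-backed index like SelectedRows::id_to_index_.
#include <cstdint>
#include <unordered_map>

struct IdIndex {
  std::unordered_map<int64_t, int64_t> id_to_index_;

  // Returns the stored index for `key`, or -1 when the key is absent,
  // mirroring GetIndexFromId in the hunk above.
  int64_t GetIndexFromId(int64_t key) const {
    auto iter = id_to_index_.find(key);
    return iter == id_to_index_.end() ? -1 : iter->second;
  }
};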
@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
     PADDLE_THROW("unknown var type to copy");
   }
 }
 }  // namespace framework
 }  // namespace paddle
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <vector>
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/variable.h"
 
 namespace paddle {
 namespace framework {
......
@@ -13,6 +13,7 @@ cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_rec
 cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
 
 cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
+cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool)
 cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
 # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
@@ -26,7 +27,7 @@ if(WITH_GRPC)
       collective_client.cc collective_server.cc
       ${GRPC_SRCS}
     PROTO send_recv.proto
-    DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
+    DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
 
   set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
@@ -50,12 +51,12 @@ else()
   set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
   cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
-    DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op)
+    DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
 endif()
 
 cc_test(rpc_server_test SRCS rpc_server_test.cc
-  DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_op)
+  DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
 cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
 cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
 cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
@@ -446,11 +446,12 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
 }
 
 VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
   return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
 }
......
@@ -102,7 +102,8 @@ class BRPCClient : public RPCClient {
       const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   bool Wait() override;
......
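To illustrate the widened checkpoint-notify signature (a checkpoint directory plus an explicit variable name instead of a single `dir`), a hedged caller sketch against an abstract client interface; `MyRpcClient` and `NotifyAllPservers` are hypothetical stand-ins, not the real BRPC/GRPC client API:

// Sketch only: each pserver is told which variable to checkpoint and where
// to save it. The interface below mimics the new signature in the hunk above.
#include <cstdint>
#include <string>
#include <vector>

struct MyRpcClient {
  virtual ~MyRpcClient() = default;
  virtual void AsyncCheckpointNotify(const std::string &ep,
                                     const std::string &dirname,
                                     const std::string &varname,
                                     int64_t time_out) = 0;
};

void NotifyAllPservers(MyRpcClient *client,
                       const std::vector<std::string> &eps,
                       const std::string &dirname,
                       const std::string &varname) {
  for (const auto &ep : eps) {
    // One notify per endpoint; time_out value is arbitrary here.
    client->AsyncCheckpointNotify(ep, dirname, varname, /*time_out=*/600000);
  }
}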
@@ -22,67 +22,69 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-struct RpcContext {
-  RpcContext() = default;
-
-  RpcContext(const std::string &name, const std::vector<std::string> &names,
-             const std::vector<std::string> &emap,
-             const std::vector<int64_t> &sections, int id,
-             bool merge_add_ = true, bool use_send_handler_ = true)
-      : var_name(name),
-        splited_var_names(names),
-        epmap(emap),
-        height_sections(sections),
-        trainer_id(id),
-        merge_add(merge_add_),
-        use_send_handler(use_send_handler_) {}
-
-  RpcContext(const RpcContext &ctx) {
-    var_name = ctx.var_name;
-    splited_var_names = ctx.splited_var_names;
-    epmap = ctx.epmap;
-    height_sections = ctx.height_sections;
-    trainer_id = ctx.trainer_id;
-    merge_add = ctx.merge_add;
-    use_send_handler = ctx.use_send_handler;
-  }
-
-  std::string var_name;
-  std::vector<std::string> splited_var_names;
-  std::vector<std::string> epmap;
-  std::vector<int64_t> height_sections;
-  int trainer_id;
-  bool merge_add;
-  bool use_send_handler;
-};
-
-inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
-  os << "{";
-  os << "var_name: " << rpc_ctx.var_name << "\n";
-  os << "splited_var_names: [";
-  for (auto &name : rpc_ctx.splited_var_names) {
-    os << name << ", ";
-  }
-  os << "]\n";
-  os << "epmap: [";
-  for (auto &ep : rpc_ctx.epmap) {
-    os << ep << ", ";
-  }
-  os << "]\n";
-  os << "height_sections: [";
-  for (auto &section : rpc_ctx.height_sections) {
-    os << section << ", ";
-  }
-  os << "]\n";
-  os << "merge add: " << rpc_ctx.merge_add;
-  os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
-  os << "}";
-  return os;
-}
+struct CommContext {
+  CommContext() = default;
+
+  CommContext(const std::string &name, const std::vector<std::string> &names,
+              const std::vector<std::string> &emap,
+              const std::vector<int64_t> &sections,
+              const std::vector<std::string> &origin_names, int id,
+              bool merge_add_ = true, bool is_sparse_ = true,
+              bool is_distributed_ = false)
+      : var_name(name),
+        splited_varnames(names),
+        epmap(emap),
+        height_sections(sections),
+        origin_varnames(origin_names),
+        trainer_id(id),
+        merge_add(merge_add_),
+        is_sparse(is_sparse_),
+        is_distributed(is_distributed_) {}
+
+  CommContext(const CommContext &ctx) {
+    var_name = ctx.var_name;
+    splited_varnames = ctx.splited_varnames;
+    epmap = ctx.epmap;
+    height_sections = ctx.height_sections;
+    trainer_id = ctx.trainer_id;
+    merge_add = ctx.merge_add;
+    is_sparse = ctx.is_sparse;
+    origin_varnames = ctx.origin_varnames;
+    is_distributed = ctx.is_distributed;
+  }
+
+  std::string print() const {
+    std::stringstream ss;
+
+    ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
+
+    for (size_t i = 0; i < splited_varnames.size(); i++) {
+      ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
+         << " section: " << height_sections[i] << " ";
+    }
+
+    ss << "origin varnames: ";
+    for (size_t i = 0; i < origin_varnames.size(); i++) {
+      ss << origin_varnames[i] << " ";
+    }
+
+    ss << " aggregation->add: " << merge_add << " ";
+    ss << " is_sparse: " << is_sparse << "\n";
+    ss << " is_distributed: " << is_distributed << "\n";
+
+    return ss.str();
+  }
+
+  std::string var_name;
+  std::vector<std::string> splited_varnames;
+  std::vector<std::string> epmap;
+  std::vector<int64_t> height_sections;
+  std::vector<std::string> origin_varnames;
+  int trainer_id;
+  bool merge_add;
+  bool is_sparse;
+  bool is_distributed;
+};
 
 }  // namespace distributed
 }  // namespace operators
......
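As a usage illustration of the `CommContext` introduced above, a hedged sketch that builds one context for a variable split across two pservers and dumps it with `print()`. The variable name, endpoints, and section heights are made-up values; only the constructor order and the header path come from the diff:

// Hedged usage sketch; requires building inside the Paddle source tree.
#include <iostream>

#include "paddle/fluid/operators/distributed/communicator_common.h"

void BuildAndDumpCommContext() {
  using paddle::operators::distributed::CommContext;

  // "embedding" is sharded into two slices, one per pserver endpoint.
  CommContext ctx("embedding",
                  {"embedding.block0", "embedding.block1"},  // splited_varnames
                  {"127.0.0.1:6170", "127.0.0.1:6171"},      // epmap
                  {500000, 500000},                          // height_sections
                  {"embedding"},                             // origin_varnames
                  /*id=*/0,
                  /*merge_add_=*/true,
                  /*is_sparse_=*/true,
                  /*is_distributed_=*/false);

  std::cout << ctx.print();
}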
@@ -409,7 +409,8 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
 }
 
 VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
-                                               const std::string& dir,
+                                               const std::string& dirname,
+                                               const std::string& varname,
                                                int64_t time_out) {
   const auto ch = GetChannel(ep);
@@ -422,8 +423,8 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
   s->Prepare(h, time_out);
 
   sendrecv::VariableMessage req;
-  req.set_varname(CHECKPOINT_SAVE_MESSAGE);
-  req.set_out_varname(dir);
+  req.set_varname(varname);
+  req.set_out_varname(dirname);
 
   platform::RecordRPCEvent record_event(method);
......
@@ -222,7 +222,8 @@ class GRPCClient : public RPCClient {
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncCheckpointNotify(
-      const std::string& ep, const std::string& dir,
+      const std::string& ep, const std::string& dirname,
+      const std::string& varname,
       int64_t time_out = FLAGS_rpc_deadline) override;
 
   VarHandlePtr AsyncDistributeNotify(
......
@@ -103,11 +103,13 @@ class RequestSend final : public RequestBase {
   void Process() override {
     std::string varname = GetReqName();
-    VLOG(4) << "RequestSend var_name:" << varname;
 
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
     int trainer_id = request_->GetTrainerId();
+    VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
 
     framework::Variable* outvar = nullptr;
     request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
@@ -332,8 +334,9 @@ class RequestPrefetch final : public RequestBase {
     std::string out_var_name = request_->OutVarname();
     std::string table_name = request_->TableName();
     int trainer_id = request_->GetTrainerId();
+
     VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
-            << " out_var_name: " << out_var_name;
+            << " out_var_name: " << out_var_name << " trainer: " << trainer_id;
 
     auto scope = request_->GetMutableLocalScope();
     auto invar = scope->FindVar(in_var_name);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag LargeScaleKV::init_flag_;
std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
This diff is collapsed.
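The new `large_scale_kv.cc` above only defines the singleton's static members (`init_flag_`, `scale_kv_`); the accessor itself lives in the collapsed header. A generic sketch of the usual `std::call_once` accessor that such statics imply; `KVHolder` is illustrative only, not the real LargeScaleKV API:

// Generic call_once singleton sketch matching the statics defined above.
#include <memory>
#include <mutex>

class KVHolder {
 public:
  static KVHolder *GetInstance() {
    // Construct exactly once, even with concurrent callers.
    std::call_once(init_flag_, []() { instance_.reset(new KVHolder()); });
    return instance_.get();
  }

 private:
  KVHolder() = default;
  static std::once_flag init_flag_;
  static std::shared_ptr<KVHolder> instance_;
};

std::once_flag KVHolder::init_flag_;
std::shared_ptr<KVHolder> KVHolder::instance_(nullptr);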
@@ -41,39 +41,55 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector,
const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids;
for (auto id : ids_vector) {
all_ids.insert(id);
}
auto abs_sections = ToAbsoluteSection(height_section);
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
static void SplitIdsIntoMultipleVarsBySection( static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names, const std::vector<int64_t> &in_ids,
const std::vector<int64_t>& height_section, const std::vector<std::string> &in_varnames, const int tables,
const std::vector<std::vector<int64_t>>& splited_ids, const int pservers, const bool is_distibuted, framework::Scope *scope,
framework::Scope* scope) { std::vector<std::vector<int64_t>> *splited_ids,
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), ""); std::vector<std::vector<int64_t>> *origin_ids) {
PADDLE_ENFORCE_EQ(
in_varnames.size(), tables,
platform::errors::OutOfRange(
"send varnames size: %d not equal table number: %d, internal error",
in_varnames.size(), tables));
PADDLE_ENFORCE_LE(
tables, pservers,
platform::errors::OutOfRange("table number %d not equal or less than "
"pserver number: %d, internal error",
tables, pservers));
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
for (size_t i = 0; i < in_var_names.size(); ++i) { std::set<int64_t> st(in_ids.begin(), in_ids.end());
auto* id_tensor = std::vector<int64_t> all_ids;
scope->Var(in_var_names[i])->GetMutable<framework::LoDTensor>(); all_ids.assign(st.begin(), st.end());
auto& ids = splited_ids[i];
splited_ids->resize(tables);
origin_ids->resize(tables);
if (is_distibuted) {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*splited_ids)[pserver_id].push_back(id);
(*origin_ids)[pserver_id].push_back(id);
}
} else {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*origin_ids)[pserver_id].push_back(id);
id = id / pservers;
(*splited_ids)[pserver_id].push_back(id);
}
}
for (size_t i = 0; i < in_varnames.size(); ++i) {
auto *id_tensor =
scope->Var(in_varnames[i])->GetMutable<framework::LoDTensor>();
auto &ids = (*splited_ids)[i];
if (!ids.empty()) { if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>( auto *id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place); framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size()); memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
} }
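The new splitting logic above shards lookup ids by `id % pservers` and, for non-distributed tables, also rebases the id sent to each shard to `id / pservers`, while the original id is kept for re-assembly. A standalone sketch of that mapping (the helper name is illustrative, not the real SplitIdsIntoMultipleVarsBySection):

// Standalone sketch of the mod/div sharding used above.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
ShardIds(const std::vector<int64_t> &ids, int pservers, bool is_distributed) {
  std::vector<std::vector<int64_t>> splited(pservers), origin(pservers);
  for (auto id : ids) {
    auto shard = id % pservers;                  // which pserver owns this id
    origin[shard].push_back(id);                 // keep original id for merge
    splited[shard].push_back(is_distributed ? id : id / pservers);  // local id
  }
  return {splited, origin};
}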
@@ -83,12 +99,18 @@ static void SplitIdsIntoMultipleVarsBySection(
 typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
void prefetch_core( void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables, const std::vector<int64_t> &ids, const TableAndEndpoints &tables,
const std::vector<int64_t>& height_sections, const framework::ExecutionContext &context, const framework::Scope &scope,
const framework::ExecutionContext& context, const framework::Scope& scope, const bool is_distributed,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) { std::unordered_map<int64_t, std::vector<float>> *recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); distributed::RPCClient *rpc_client =
auto& actual_ctx = *pool.Get(context.GetPlace()); distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
int pservers = context.Attr<int>("pserver_num");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &actual_ctx = *pool.Get(context.GetPlace());
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
@@ -99,19 +121,17 @@ void prefetch_core(
     out_var_names.push_back("prefetch_recv@" + tables[i].second);
   }
auto splited_ids = SplitIds(ids, height_sections); std::vector<std::vector<int64_t>> split_ids;
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, std::vector<std::vector<int64_t>> origin_ids;
local_scope.get()); SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers,
is_distributed, local_scope.get(),
&split_ids, &origin_ids);
// create output var in local scope // create output var in local scope
for (auto& name : out_var_names) { for (auto &name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>(); local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) { for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) { if (NeedSend(*local_scope.get(), in_var_names[i])) {
@@ -126,20 +146,18 @@ void prefetch_core(
   }
 
   for (size_t i = 0; i < rets.size(); i++) {
-    PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
+    PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
+                                               "internal error in RPCClient"));
   }
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), ""); for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) {
auto &ids_in_this_section = origin_ids[o_idx];
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) { if (!ids_in_this_section.empty()) {
auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx]) auto &prefetch_out_var =
->Get<framework::LoDTensor>(); local_scope->Var(out_var_names[o_idx])->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>(); const auto *out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims(); auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, ""); PADDLE_ENFORCE_EQ(dims.size(), 2, "");
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]); PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
@@ -147,8 +165,7 @@ void prefetch_core(
       auto row_numel = dims[1];
 
       for (int64_t i = 0; i < dims[0]; ++i) {
-        auto id = ids_in_this_section[i];
-        auto origin_id = id + abs_sections[section_idx];
+        auto origin_id = ids_in_this_section[i];
         std::vector<float> vecs(row_numel);
         std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
         (*recved_vec_map)[origin_id] = vecs;
@@ -159,38 +176,35 @@ void prefetch_core(
     }
   }
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string &id_name, const std::string &out_name,
const std::string& persistable_var_name, const bool backfill, const std::string &persistable_var_name,
const std::vector<std::string>& table_names, const bool is_distributed,
const std::vector<std::string>& endpoints, const std::vector<std::string> &table_names,
const std::vector<int64_t>& height_sections, const std::vector<std::string> &endpoints,
const framework::ExecutionContext& context, const framework::ExecutionContext &context,
const framework::Scope& scope) { const framework::Scope &scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names, prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed,
endpoints, height_sections, context, scope); table_names, endpoints, context, scope);
} }
void prefetchs(const std::vector<std::string>& id_var_names, void prefetchs(const std::vector<std::string> &id_var_names,
const std::vector<std::string>& out_var_names, const std::vector<std::string> &out_var_names,
const std::string& persistable_var_name, const bool backfill, const std::string &persistable_var_name,
const std::vector<std::string>& table_names, const bool is_distributed,
const std::vector<std::string>& endpoints, const std::vector<std::string> &table_names,
const std::vector<int64_t>& height_sections, const std::vector<std::string> &endpoints,
const framework::ExecutionContext& context, const framework::ExecutionContext &context,
const framework::Scope& scope) { const framework::Scope &scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
auto vec_dim_1 = 0; auto vec_dim_1 = 0;
framework::Variable* var = scope.FindVar(persistable_var_name); auto vec_dim_0 = 0;
framework::Variable *var = scope.FindVar(persistable_var_name);
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"prefetch can only support LodTensor only"));
if (var->IsType<SelectedRows>()) {
vec_dim_1 = var->Get<framework::SelectedRows>().value().dims()[1];
} else {
vec_dim_0 = var->Get<framework::LoDTensor>().dims()[0];
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1]; vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
}
PADDLE_ENFORCE_GT(vec_dim_1, 0, PADDLE_ENFORCE_GT(vec_dim_1, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
@@ -203,37 +217,38 @@ void prefetchs(const std::vector<std::string>& id_var_names,
     PADDLE_THROW("multi prefetch only support CPU currently");
   }
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union; std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables; TableAndEndpoints tables;
for (auto& id_name : id_var_names) { for (auto &id_name : id_var_names) {
auto* id_tensor = auto *in_var = scope.FindVar(id_name);
scope.FindVar(id_name)->GetMutable<framework::LoDTensor>(); auto &id_tensor = in_var->Get<framework::LoDTensor>();
auto id_dims = id_tensor->dims(); std::copy_n(id_tensor.data<int64_t>(), id_tensor.numel(),
id_tensor->Resize(framework::make_ddim( back_inserter(ids_union));
{static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
ids_group.push_back(ids);
ids_lods.push_back(id_tensor->lod());
} }
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end()); std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end()); ids_union.assign(s.begin(), s.end());
for (auto &i : ids_union) {
PADDLE_ENFORCE_GE(
i, 0, platform::errors::OutOfRange(
"each element in embedding should be larger or equal 0"));
if (!is_distributed) {
PADDLE_ENFORCE_LT(
i, vec_dim_0,
platform::errors::OutOfRange(
"embedding id must in [0, %d) when is_distributed False",
vec_dim_0));
}
}
for (size_t i = 0; i < table_names.size(); i++) { for (size_t i = 0; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i])); tables.push_back(std::make_pair(table_names[i], endpoints[i]));
} }
std::unordered_map<int64_t, std::vector<float>> recved_vec_map; std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, height_sections, context, scope, prefetch_core(ids_union, tables, context, scope, is_distributed,
&recved_vec_map); &recved_vec_map);
auto padding_idx = distributed::kNoPadding; auto padding_idx = distributed::kNoPadding;
@@ -242,20 +257,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
     padding_idx = context.Attr<int64_t>("padding_idx");
   }
// copy vectors to out vars
for (size_t i = 0; i < out_var_names.size(); i++) { for (size_t i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i]; auto *in_var = scope.FindVar(id_var_names[i]);
auto* out_t = auto &id_tensor = in_var->Get<framework::LoDTensor>();
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>(); auto ids_size = id_tensor.dims()[0];
out_t->Resize( const auto *id_data = id_tensor.data<int64_t>();
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto* out_d = out_t->mutable_data<float>(place); auto *out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
for (size_t idx = 0; idx < ids.size(); idx++) { out_t->set_lod(id_tensor.lod());
const auto& id = ids[idx]; out_t->Resize(framework::make_ddim({ids_size, vec_dim_1}));
auto *out_d = out_t->mutable_data<float>(place);
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = id_data[idx];
      if (padding_idx != distributed::kNoPadding && id == padding_idx) {
        memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
      } else {
......
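The copy-out loop above writes one output row per input id, zero-filling rows whose id equals `padding_idx` and otherwise copying the prefetched vector. A minimal sketch of that gather step over a flat float buffer (types simplified; the function name and the "missing id" fallback are my own, the real code assumes every id was prefetched):

// Simplified gather step mirroring the loop above.
#include <cstdint>
#include <cstring>
#include <unordered_map>
#include <vector>

void GatherRows(const std::vector<int64_t> &ids,
                const std::unordered_map<int64_t, std::vector<float>> &recved,
                int64_t width, int64_t padding_idx, float *out) {
  for (size_t i = 0; i < ids.size(); ++i) {
    float *dst = out + i * width;
    auto it = recved.find(ids[i]);
    if (ids[i] == padding_idx || it == recved.end()) {
      std::memset(dst, 0, sizeof(float) * width);   // padding (or missing) row
    } else {
      std::memcpy(dst, it->second.data(), sizeof(float) * width);
    }
  }
}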
@@ -31,7 +31,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
                const std::string& persistable_var_name, const bool backfill,
                const std::vector<std::string>& table_names,
                const std::vector<std::string>& endpoints,
-               const std::vector<int64_t>& height_sections,
                const framework::ExecutionContext& context,
                const framework::Scope& scope);
@@ -39,7 +38,6 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::string& persistable_var_name, const bool backfill,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& endpoints,
-              const std::vector<int64_t>& height_sections,
              const framework::ExecutionContext& context,
              const framework::Scope& scope);
......
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <algorithm>
 #include <memory>
 #include <set>
 #include <string>
@@ -40,153 +41,131 @@ using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
 template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx, void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(2) << "ParameterRecv in " << rpc_ctx.var_name;
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace()); auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
// recv all vars to local scope
if (recv_var->IsType<framework::LoDTensor>() ||
recv_var->IsType<framework::SelectedRows>()) {
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i]; auto &recv_var_name = rpc_ctx.splited_varnames[i];
local_scope->Var(recv_var_name); local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i]; VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
if (recv_var->IsType<framework::LoDTensor>()) {
// sparse param in recv_scope is LoDTensor // sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx, rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name)); recv_var_name, recv_var_name));
} else {
// sparse param in pserver_scope is SelectedRows
rets.push_back(rpc_client->AsyncGetVar(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
} "internal error in RPCClient"));
} else {
PADDLE_THROW("unsupported var type to recv!");
} }
// concat recved tensor into one var int64_t height = 0;
if (recv_var->IsType<framework::LoDTensor>()) { int64_t ids_num = 0;
size_t output_offset = 0; int64_t width = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor = std::vector<int64_t> all_ids;
recv_var->GetMutable<framework::LoDTensor>(); auto pserver_num = rpc_ctx.splited_varnames.size();
auto dev_ctx = paddle::platform::CPUDeviceContext();
int64_t recv_numel = 0; for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
for (auto &recv_var_name : rpc_ctx.splited_var_names) { auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name); auto *recv_var = local_scope->FindVar(recv_var_name);
if (recv_var->IsType<framework::LoDTensor>()) { auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto &in = recv_var->Get<framework::LoDTensor>();
recv_numel += in.numel(); height += recv_t.height();
auto in_stride = framework::stride_numel(in.dims()); ids_num += recv_t.rows().size();
auto out_stride = framework::stride_numel(recv_tensor->dims()); width = recv_t.value().dims()[1];
StridedNumelCopyWithAxis<T>(
dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride, std::transform(recv_t.rows().begin(), recv_t.rows().end(),
in.data<T>(), in_stride, in_stride[0]); std::back_inserter(all_ids),
output_offset += in_stride[0]; [&](int64_t id) { return id * pserver_num + i; });
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1];
recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
} }
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " " auto *var = scope.FindVar(rpc_ctx.var_name);
<< sstream.str(); auto *t_ = var->GetMutable<framework::SelectedRows>();
T *out_data =
t_->mutable_value()->mutable_data<T>({ids_num, width}, cpu_place);
t_->set_height(height);
t_->set_rows(all_ids);
int64_t cnt = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto rows = recv_t.rows().size();
const T *in_data = recv_t.value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
} }
t_->SyncIndex();
}
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
for (size_t i = 0; i < recv_slr.rows().size(); ++i) { distributed::RPCClient *rpc_client =
auto row_id = recv_slr.rows()[i] + row_offset; distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width, std::vector<distributed::VarHandlePtr> rets;
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
// variable do not spilt
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
} }
row_offset += recv_slr.height();
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
return;
} else { } else {
PADDLE_THROW("unsupported recieved var type"); PADDLE_ENFORCE(false, platform::errors::Unimplemented(
} "ParameterRecv can not recv dense with multi "
"parts now, add it soon."));
} }
auto numel = recv_tensor->numel(); }
PADDLE_ENFORCE_EQ(
recv_numel, numel, template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool barrier) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The number of receive tensor's elements are not valid. The " "origin_varnames.size() >= 1 is permitted"));
"recevie tensor numel is %d, the actual tensor numel is %d.",
recv_numel, numel)); if (rpc_ctx.is_sparse) {
} else if (recv_var->IsType<framework::SelectedRows>()) { RecvSelectedRows<T>(rpc_ctx, scope);
auto cpu_place = platform::CPUPlace(); } else {
auto *slr = recv_var->GetMutable<framework::SelectedRows>(); RecvLodTensor<T>(rpc_ctx, scope);
slr->mutable_rows()->clear();
slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
int64_t width = 0;
int64_t height = 0;
std::vector<int64_t> new_rows{};
// trans sparse ids from local to global
std::vector<int64_t> abs_sections =
ToAbsoluteSection(rpc_ctx.height_sections);
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
width = var_slr->mutable_value()->dims()[1];
height += var_slr->height();
auto row_offset = abs_sections[i];
VLOG(4) << "Recv split_var " << recv_var_name << " Row size "
<< var_slr_row->size();
for (size_t j = 0; j < var_slr_row->size(); j++) {
new_rows.push_back(row_offset + (*var_slr_row)[j]);
}
}
slr->set_rows(new_rows);
slr->set_height(height);
slr->mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(slr->mutable_rows()->size()), width}),
cpu_place);
auto *slr_data = slr->mutable_value()->data<float>();
size_t row_offset = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
auto var_slr_row_size = var_slr_row->size();
auto *var_slr_data = var_slr->mutable_value()->data<float>();
memcpy(slr_data + row_offset * width, var_slr_data,
sizeof(float) * width * var_slr_row_size);
row_offset += var_slr_row_size;
}
} }
VLOG(2) << "ParameterRecv out " << rpc_ctx.var_name; VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
} }
template struct ParameterRecv<float>;
......
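On the receive side, each pserver returns a SelectedRows slice whose local row ids must be mapped back to global ids with `global = local * pserver_num + shard_index` before the slices are concatenated. A small sketch of that remapping (a standalone helper, not the real RecvSelectedRows):

// Standalone sketch of the row-id remapping used when merging per-pserver
// SelectedRows slices: shard i holding local row r holds global row
// r * pserver_num + i.
#include <cstdint>
#include <vector>

std::vector<int64_t> MergeRowIds(
    const std::vector<std::vector<int64_t>> &rows_per_pserver) {
  std::vector<int64_t> all_ids;
  const int64_t pserver_num = static_cast<int64_t>(rows_per_pserver.size());
  for (int64_t i = 0; i < pserver_num; ++i) {
    for (auto local_row : rows_per_pserver[i]) {
      all_ids.push_back(local_row * pserver_num + i);
    }
  }
  return all_ids;
}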
@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/distributed/rpc_common.h"
+#include "paddle/fluid/operators/distributed/communicator_common.h"
 
 namespace paddle {
 namespace operators {
@@ -26,7 +26,10 @@ namespace distributed {
 
 template <typename T>
 struct ParameterRecv {
-  void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
+                  bool barrier);
+
+  void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
 };
 
 };  // namespace distributed
......
@@ -41,42 +41,67 @@ using DDim = framework::DDim;
 typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext( inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext(
const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) { const CommContext &rpc_ctx, const framework::Scope &scope,
int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs; EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name); auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) { if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must >=1"); PADDLE_ENFORCE_GE(multi_parts, 1,
platform::errors::InvalidArgument(
"multi_parts must == 1 in parameter send, now is: %d",
multi_parts));
if (multi_parts == 1) { for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
table_pairs.push_back( table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i])); std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i]));
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (int x = 0; x < multi_parts; x++) {
auto table =
string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
}
}
} }
} else if (send_var->IsType<framework::LoDTensor>()) {
PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
} else { } else {
PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!"); PADDLE_THROW(platform::errors::InvalidArgument(
"GetMultiFieldCommContext unsupported LoDTensor current!"));
} }
return table_pairs; return table_pairs;
} // namespace distributed } // namespace distributed
void SendByNotifyRPC(const CommContext &rpc_ctx,
const framework::Scope &scope) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto &send_var_name = rpc_ctx.var_name;
std::vector<distributed::VarHandlePtr> rets;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
if (NeedSend(scope, send_var_name)) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope,
send_var_name));
VLOG(4) << "send var " << send_var_name << " by notify RPC done";
}
} else {
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name;
}
for (auto &handle : rets) {
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
template <typename T> template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool sync, const framework::Scope &scope, bool sync,
int multi_parts) { int multi_parts) {
if (rpc_ctx.var_name == STEP_COUNTER) {
SendByNotifyRPC(rpc_ctx, scope);
return;
}
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
@@ -86,11 +111,10 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
 
   std::vector<distributed::VarHandlePtr> rets;
   auto *send_var = scope.FindVar(rpc_ctx.var_name);
 
   if (send_var->IsType<framework::LoDTensor>()) {
-    size_t out_num = rpc_ctx.splited_var_names.size();
+    size_t out_num = rpc_ctx.splited_varnames.size();
     if (out_num > 1) {
       auto &send_tensor = send_var->Get<framework::LoDTensor>();
       auto &send_tensor_dims = send_tensor.dims();
@@ -110,24 +134,23 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
       // create output var in local scope
       size_t row_offset = 0;
       for (size_t i = 0; i < out_num; ++i) {
-        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
+        framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i])
                                      ->GetMutable<framework::LoDTensor>();
         *out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
         row_offset += outs_dims[i][0];
       }
     } else {
       auto &send_tensor = send_var->Get<framework::LoDTensor>();
-      framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[0])
+      framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0])
                                    ->GetMutable<framework::LoDTensor>();
       out->ShareDataWith(send_tensor);
     }
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i]; auto &send_var_name = rpc_ctx.splited_varnames[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i]; auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint; VLOG(4) << " send var name: " << send_var_name
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name); << "endpoint: " << endpoint;
if (NeedSend(*local_scope.get(), send_var_name)) { if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint; VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar( rets.push_back(rpc_client->AsyncSendVar(
@@ -135,47 +158,25 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
         VLOG(4) << "send var " << send_var_name << " async handle done";
       } else {
         VLOG(3) << "don't send non-initialized variable: "
-                << rpc_ctx.splited_var_names[i];
+                << rpc_ctx.splited_varnames[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} }
} }
} else if (send_var->IsType<framework::SelectedRows>()) { } else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>(); auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows(); auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) { if (send_rows.size() == 0) {
LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which " LOG(WARNING)
<< "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of " "may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader async mode, you need to " "use_double_buffer in pyreader/dataloader async mode, you need to "
"turn it false."; "turn it false.";
} }
std::vector<std::vector<size_t>> outs_rows_idx; std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx; std::vector<std::vector<size_t>> outs_dense_idx;
auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts); auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1);
outs_rows_idx.resize(table_pairs.size()); outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size()); outs_dense_idx.resize(table_pairs.size());
...@@ -190,32 +191,77 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -190,32 +191,77 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs.push_back(out); outs.push_back(out);
} }
if (!rpc_ctx.is_distributed) {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = send_rows[i] % pserver_num;
auto id = send_rows[i] / pserver_num;
outs_rows_idx[ep_idx].push_back(id);
outs_dense_idx[ep_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
} else {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars // split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) { for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = GetSectionIndex(send_rows[i], abs_sections); auto out_idx = send_rows[i] % pserver_num;
auto table_idx = send_rows[i] % multi_parts;
auto out_idx = ep_idx * multi_parts + table_idx;
outs_rows_idx[out_idx].push_back(send_rows[i]); outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i); outs_dense_idx[out_idx].push_back(i);
} }
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
for (size_t ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) { for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
for (int part = 0; part < multi_parts; part++) { out_idx++) {
auto out_idx = ctx * multi_parts + part;
auto rows_idx = outs_rows_idx[out_idx]; auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims(); auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size(); dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]); outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear(); outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place()); outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) { if (rows_idx.size() > 0) {
for (auto idx : rows_idx) { for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]); outs[out_idx]->mutable_rows()->push_back(idx);
} }
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place); auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) { for (size_t j = 0; j < rows_idx.size(); j++) {
@@ -225,12 +271,15 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
               src + outs_dense_idx[out_idx][j] * row_numel,
               sizeof(T) * row_numel);
          } else {
-           PADDLE_THROW("do not support GPU now");
+           PADDLE_THROW(
+               platform::errors::Unimplemented("do not support GPU now"));
          }
        }
      }
-     PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
-                       "rows should has the same size with tensor dim 0");
+     PADDLE_ENFORCE_EQ(
+         rows_idx.size(), outs[out_idx]->rows().size(),
+         platform::errors::InvalidArgument(
+             "rows should has the same size with tensor dim 0"));
    }
  }
@@ -240,8 +289,8 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
       auto need_send = NeedSend(*local_scope.get(), send_var_name);
 
       VLOG(4) << "send var name: " << send_var_name
-              << "send var endpoint: " << endpoint
-              << "need send: " << need_send;
+              << " send var endpoint: " << endpoint
+              << " need send: " << need_send;
 
       if (need_send) {
         VLOG(4) << "sending " << send_var_name << " to " << endpoint;
@@ -251,7 +300,7 @@ void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
         VLOG(4) << "send var " << send_var_name << " async handle done";
       } else {
         VLOG(4) << "don't send non-initialized variable: "
-                << rpc_ctx.splited_var_names[i];
+                << rpc_ctx.splited_varnames[i];
       }
     }
   } else {
...@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
if (sync) { if (sync) {
for (auto &handle : rets) { for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle; VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient"); PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
} }
} }
} }
......
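Reviewer note: the parameter_send.cc hunks above replace the old section-based routing (GetSectionIndex over abs_sections combined with multi_parts sub-tables) with a plain modulo split across parameter servers, and row ids are now pushed as-is instead of being offset by abs_sections. Below is a minimal standalone sketch of that routing in plain C++ with STL containers and made-up sample values; it illustrates the logic of the hunk, not Paddle's SelectedRows implementation.

// Standalone sketch of the modulo-based row sharding used by the new
// ParameterSend path. Each sparse row id is routed to exactly one parameter
// server, and its position inside the local gradient is recorded so the
// dense values can be copied out afterwards.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const size_t pserver_num = 3;                       // number of endpoints (sample value)
  std::vector<int64_t> send_rows = {0, 4, 7, 9, 12};  // sparse row ids (sample values)

  std::vector<std::vector<int64_t>> outs_rows_idx(pserver_num);
  std::vector<std::vector<size_t>> outs_dense_idx(pserver_num);

  for (size_t i = 0; i < send_rows.size(); ++i) {
    // Modulo routing: replaces the old abs_sections / multi_parts lookup.
    const size_t out_idx = static_cast<size_t>(send_rows[i]) % pserver_num;
    outs_rows_idx[out_idx].push_back(send_rows[i]);  // row id kept unshifted
    outs_dense_idx[out_idx].push_back(i);            // position in the local tensor
  }

  for (size_t p = 0; p < pserver_num; ++p) {
    std::cout << "pserver " << p << " gets rows:";
    for (auto r : outs_rows_idx[p]) std::cout << " " << r;
    std::cout << " (dense positions:";
    for (auto d : outs_dense_idx[p]) std::cout << " " << d;
    std::cout << ")\n";
  }
  return 0;
}

With three servers in this sketch, rows 0, 9 and 12 land on server 0, rows 4 and 7 on server 1, and server 2 receives an empty slice, which corresponds to the rows_idx.size() > 0 branch in the hunk.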
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h" #include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,7 +26,7 @@ namespace distributed { ...@@ -26,7 +26,7 @@ namespace distributed {
template <typename T> template <typename T>
struct ParameterSend { struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope, void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool sync, int multi_parts); bool sync, int multi_parts);
}; };
......
...@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000; ...@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000;
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@" #define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -38,13 +39,13 @@ namespace distributed { ...@@ -38,13 +39,13 @@ namespace distributed {
// to directory specified. // to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Sync // Sync
...@@ -82,16 +83,34 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -82,16 +83,34 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope->Rename(varname, run_varname); scope->Rename(varname, run_varname);
} }
if (distributed_mode_ == DistributedMode::kGeo && auto *var = scope->FindVar(run_varname);
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr = // for sparse ids
if (var->IsType<framework::SelectedRows>()) {
if (distributed_mode_ == DistributedMode::kAsync ||
distributed_mode_ == DistributedMode::kHalfAsync) {
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->GradInLargeScale(run_varname)) {
auto *large_scale_var = ins->GetByGrad(run_varname);
for (auto name : large_scale_var->CachedVarnames()) {
scope->Var(name);
}
}
}
if (distributed_mode_ == DistributedMode::kGeo) {
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(
run_varname)) {
auto &grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>(); scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname, AsyncSparseParamUpdateRecorder::GetInstance()->Update(
grad_slr.rows()); run_varname, grad_slr.rows());
}
}
} }
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope); scope);
return true; return true;
} else { // sync } else { // sync
rpc_server_->WaitCond(kRequestSend); rpc_server_->WaitCond(kRequestSend);
...@@ -104,13 +123,13 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -104,13 +123,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name; << " table_name: " << table_name;
...@@ -138,39 +157,38 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -138,39 +157,38 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name; VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
} }
VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo && if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) { !table_name.empty()) {
VLOG(3) << "AsyncSparseParamUpdateRecorder " << varname << " exist ";
std::vector<int64_t> updated_rows; std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows); varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) { if (VLOG_IS_ON(3)) {
std::ostringstream sstream; std::ostringstream sstream;
sstream << "["; sstream << "[";
for (auto& row_id : updated_rows) { for (auto &row_id : updated_rows) {
sstream << row_id << ", "; sstream << row_id << ", ";
} }
sstream << "]"; sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " " VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str(); << sstream.str();
} }
auto& origin_tensor =
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>(); scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>(); auto *origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims(); auto &dims = origin_tensor.dims();
*outvar = scope->Var(); *outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>(); auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows); out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]); out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim( auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]}); {static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>( auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place()); out_dims, origin_tensor.place());
auto width = dims[1]; auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) { for (size_t i = 0; i < updated_rows.size(); ++i) {
...@@ -186,13 +204,13 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -186,13 +204,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestGetNoBarrierHandler::Handle(const std::string& varname, bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name; << " out_var_name: " << out_var_name;
...@@ -212,77 +230,96 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname, ...@@ -212,77 +230,96 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestPrefetchHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname; VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
(*outvar)->GetMutable<framework::LoDTensor>(); (*outvar)->GetMutable<framework::LoDTensor>();
VLOG(1) << "Prefetch "
<< "tablename: " << table_name << " ids:" << varname
<< " out: " << out_var_name;
paddle::platform::CPUPlace cpu_place;
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->ParamInLargeScale(table_name)) {
auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name);
lookup_table_op->Run(*scope, cpu_place);
} else {
auto lookup_table_op = auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name); BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place); lookup_table_op->Run(*scope, cpu_place);
} }
return true; return true;
} }
bool RequestCheckpointHandler::Handle(const std::string& varname, bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
PADDLE_ENFORCE( VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name);
// TODO(tangwei12): find out why scope will be error. // auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); // paddle::platform::CPUPlace cpu_place;
lt_var->clear(); // checkpoint_op->Run(*scope_, cpu_place);
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
return true; return true;
} }
bool RequestNotifyHandler::Handle(const std::string& varname, bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname; VLOG(3) << "RequestNotifyHandler: " << varname
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id; << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER); string::Piece decay_piece(STEP_COUNTER);
string::Piece var_name_piece = string::Piece(varname); string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) { if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update"; VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1, auto *send_var = scope->FindVar(varname);
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>(); auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value = auto *send_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place()); send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
auto counter = decay_counters.at(trainer_id);
counter += send_value[0];
decay_counters.at(trainer_id) = counter;
auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER);
if (global_step_var == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find LEARNING_RATE_DECAY_COUNTER "));
}
auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
auto global_counter = 0;
for (auto &trainer_counter : decay_counters) {
global_counter += trainer_counter.second;
}
value[0] = global_counter;
if (lr_decay_prepared_ctx_.get() == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find decay block for executor"));
}
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
} }
return true; return true;
......
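Reviewer note: the RequestNotifyHandler change above switches from bumping a single counter variable to keeping one counter per trainer (decay_counters) and writing their sum into LEARNING_RATE_DECAY_COUNTER before the prepared decay block runs. The sketch below shows only that bookkeeping in plain C++; the class name and sample values are illustrative stand-ins, not Paddle API.

// Standalone sketch of the per-trainer step-counter aggregation performed by
// the new RequestNotifyHandler: each trainer reports its local step delta via
// a @PS_STEP_COUNTER@ variable, the server keeps one counter per trainer and
// exposes the sum as the global learning-rate decay counter.
#include <cstdint>
#include <iostream>
#include <unordered_map>

class StepCounterAggregator {
 public:
  explicit StepCounterAggregator(int trainers) {
    // Mirrors the constructor change: counters are zero-initialized per trainer.
    for (int i = 0; i < trainers; ++i) decay_counters_[i] = 0;
  }

  // Called when a trainer sends its local step delta.
  void Report(int trainer_id, int64_t steps) {
    decay_counters_.at(trainer_id) += steps;
  }

  // Value that would be written into LEARNING_RATE_DECAY_COUNTER before the
  // decay block is executed.
  int64_t GlobalCounter() const {
    int64_t total = 0;
    for (const auto &kv : decay_counters_) total += kv.second;
    return total;
  }

 private:
  std::unordered_map<int, int64_t> decay_counters_;
};

int main() {
  StepCounterAggregator agg(/*trainers=*/2);
  agg.Report(0, 10);
  agg.Report(1, 7);
  agg.Report(0, 3);
  std::cout << "global step counter = " << agg.GlobalCounter() << "\n";  // prints 20
  return 0;
}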
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& table_name = "") override; const std::string& table_name = "") override;
private: private:
std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
framework::OpDesc desc;
desc.SetType("lookup_sparse_table_read");
desc.SetInput("Ids", {id_name});
desc.SetOutput("Out", std::vector<std::string>({out_name}));
desc.SetAttr("tablename", {table_name});
desc.SetAttr("init", true);
desc.SetAttr("value_names", std::vector<std::string>({"Param"}));
auto op = paddle::framework::OpRegistry::CreateOp(desc);
return op;
}
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp( std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name, const std::string& table_name, const std::string& id_name,
const std::string& out_name) { const std::string& out_name) {
...@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler { ...@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler { class RequestCheckpointHandler final : public RequestHandler {
public: public:
explicit RequestCheckpointHandler(int distributed_mode, explicit RequestCheckpointHandler(int distributed_mode)
int checkpoint_notify_id) : RequestHandler(distributed_mode) {}
: RequestHandler(distributed_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {} virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar, framework::Variable* var, framework::Variable** outvar,
...@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler { ...@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler {
const std::string& table_name = "") override; const std::string& table_name = "") override;
private: private:
int checkpoint_notify_id; std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
const std::string& varname, const std::string& file_path) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("save");
BuildVar("X", {varname.data()}, op_desc.add_inputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("file_path");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(file_path);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
}; };
class RequestNotifyHandler final : public RequestHandler { class RequestNotifyHandler final : public RequestHandler {
public: public:
explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id) explicit RequestNotifyHandler(int distributed_mode, int trainers)
: RequestHandler(distributed_mode) { : RequestHandler(distributed_mode) {
this->lr_decay_block_id = lr_decay_block_id; this->trainers = trainers;
for (int i = 0; i < trainers; i++) {
decay_counters[i] = 0;
}
} }
virtual ~RequestNotifyHandler() {} virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope, bool Handle(const std::string& varname, framework::Scope* scope,
...@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler { ...@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler {
const std::string& table_name = "") override; const std::string& table_name = "") override;
private: private:
int lr_decay_block_id; int trainers;
std::unordered_map<int, int64_t> decay_counters;
}; };
} // namespace distributed } // namespace distributed
......
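Reviewer note: both new private helpers in request_handler.h above (PullLargeScaleOp and BuildCheckpointOp) follow the same pattern: describe the operator (type, inputs, outputs, attributes) and only materialize it through OpRegistry::CreateOp when a request arrives. The sketch below illustrates that descriptor-to-callable pattern with plain C++ stand-ins; the OpDesc struct and CreateOp function here are simplified placeholders, not the framework types.

// Simplified sketch of the descriptor-driven op construction used by the new
// request-handler helpers: a small record describes the op, and a factory
// turns it into something runnable on demand.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct OpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs, outputs;
  std::map<std::string, std::string> attrs;
};

// Stand-in for OpRegistry::CreateOp: turn the description into a callable.
std::function<void()> CreateOp(const OpDesc &desc) {
  return [desc]() {
    std::cout << "running " << desc.type << " on Ids=" << desc.inputs.at("Ids")[0]
              << " -> " << desc.outputs.at("Out")[0]
              << " (table " << desc.attrs.at("tablename") << ")\n";
  };
}

int main() {
  OpDesc desc;
  desc.type = "lookup_sparse_table_read";
  desc.inputs["Ids"] = {"ids"};
  desc.outputs["Out"] = {"out"};
  desc.attrs["tablename"] = "embedding";  // sample table name
  auto op = CreateOp(desc);
  op();  // what Handle() does via op->Run(*scope, cpu_place)
  return 0;
}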
...@@ -77,8 +77,8 @@ class RPCClient { ...@@ -77,8 +77,8 @@ class RPCClient {
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify( virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir, const std::string& ep, const std::string& dirname,
int64_t time_out = FLAGS_rpc_deadline) = 0; const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify( virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx, const std::string& ep, const platform::DeviceContext& ctx,
......
...@@ -34,7 +34,7 @@ namespace framework = paddle::framework; ...@@ -34,7 +34,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed; namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table); USE_NO_KERNEL_OP(lookup_sparse_table_read);
std::unique_ptr<distributed::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
...@@ -46,10 +46,12 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { ...@@ -46,10 +46,12 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}}); framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}}); framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp(); auto op = block->AppendOp();
op->SetType("lookup_sparse_table"); op->SetType("lookup_sparse_table_read");
op->SetInput("W", {"w"}); op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"}); op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"}); op->SetOutput("Out", {"out"});
op->SetAttr("tablename", {"w"});
op->SetAttr("value_names", {"Param"});
auto& out = *root_block->Var("out"); auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::LOD_TENSOR); out.SetType(framework::proto::VarType::LOD_TENSOR);
...@@ -99,16 +101,10 @@ void StartServer(const std::string& rpc_name) { ...@@ -99,16 +101,10 @@ void StartServer(const std::string& rpc_name) {
platform::CPUPlace place; platform::CPUPlace place;
framework::Executor exe(place); framework::Executor exe(place);
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared; prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program); g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared); g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
...@@ -128,49 +124,6 @@ void StartServer(const std::string& rpc_name) { ...@@ -128,49 +124,6 @@ void StartServer(const std::string& rpc_name) {
server_thread.join(); server_thread.join();
} }
TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
}
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) { TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -35,19 +32,31 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -35,19 +32,31 @@ class CheckpointNotifyOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap =
std::string dir = Attr<std::string>("dir"); Attr<std::vector<std::string>>("endpoints");
std::string lookup_table_name = Attr<std::string>("lookup_table"); std::string dirname = Attr<std::string>("dirname");
int trainer_id = Attr<int>("trainer_id"); std::string varname = Attr<std::string>("varname");
auto is_slice = Attr<bool>("is_slice");
VLOG(1) << "is_slice: " << is_slice;
std::vector<std::string> slice_varnames =
Attr<std::vector<std::string>>("slice_varnames");
std::vector<std::string> remote_varnames =
Attr<std::vector<std::string>>("remote_varnames");
distributed::RPCClient* rpc_client = distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (size_t i = 0; i < epmap.size(); i++) { for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir = auto save_path =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i); string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name rpc_client->AsyncCheckpointNotify(epmap[i], save_path,
<< " and dir:" << dir << " to " << epmap[i]; remote_varnames[i]);
VLOG(3) << "checkpoint notify sending with path: " << save_path
<< " and var:" << slice_varnames[i] << " to " << epmap[i];
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
rpc_client->Wait(), true, rpc_client->Wait(), true,
...@@ -59,18 +68,22 @@ class CheckpointNotifyOp : public framework::OperatorBase { ...@@ -59,18 +68,22 @@ class CheckpointNotifyOp : public framework::OperatorBase {
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddAttr<std::vector<std::string>>("epmap", AddAttr<std::vector<std::string>>(
"(string vector, default 127.0.0.1:6164)" "endpoints",
"Parameter Server endpoints in the order") "(string vector)"
.SetDefault({"127.0.0.1:6164"}); "Parameter Server endpoints in the order");
AddAttr<std::string>( AddAttr<std::string>("dirname",
"dir", "(string, default '') indicate the folder checkpoint will use"); "(string) indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table", AddAttr<std::string>("varname", "(string) the var need to be saved");
"(string, default '') the lookup table name"); AddAttr<std::vector<std::string>>(
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); "slice_varnames", "(string vector) the slice vars need to be saved");
AddAttr<std::vector<std::string>>(
"remote_varnames", "(string vector) the slice vars need to be saved");
AddAttr<bool>(
"is_slice",
"is_slice=True means the var has been slice by parameter server");
AddComment(R"DOC( AddComment(R"DOC(
CheckpointNotify operator CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server. the parameter server.
)DOC"); )DOC");
......
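Reviewer note: the reworked CheckpointNotifyOp above notifies every endpoint with a save path of the form dirname/varname/slice_varname together with the remote (sliced) variable name the server should persist. A small standalone sketch of that loop follows; the RPC call is replaced by a print, and all endpoints, names and paths are made-up sample values.

// Standalone sketch of the per-endpoint checkpoint-notify loop: build one
// save path per parameter-server slice and pair it with the remote variable
// name. Only the argument construction is illustrated here.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::string dirname = "/tmp/checkpoints";   // sample values
  std::string varname = "embedding";          // the logical (unsliced) parameter
  std::vector<std::string> epmap = {"127.0.0.1:6170", "127.0.0.1:6171"};
  std::vector<std::string> slice_varnames = {"embedding.block0", "embedding.block1"};
  std::vector<std::string> remote_varnames = {"embedding.block0", "embedding.block1"};

  for (size_t i = 0; i < epmap.size(); ++i) {
    std::string save_path = dirname + "/" + varname + "/" + slice_varnames[i];
    // In the real op this is rpc_client->AsyncCheckpointNotify(epmap[i],
    // save_path, remote_varnames[i]); followed by rpc_client->Wait().
    std::cout << "notify " << epmap[i] << ": save " << remote_varnames[i]
              << " to " << save_path << "\n";
  }
  return 0;
}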
...@@ -19,9 +19,10 @@ limitations under the License. */ ...@@ -19,9 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h" #include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -41,6 +42,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -41,6 +42,7 @@ class RecvOp : public framework::OperatorBase {
VLOG(3) << "recv do not run!"; VLOG(3) << "recv do not run!";
return; return;
} }
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames = std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames"); Attr<std::vector<std::string>>("varnames");
...@@ -59,10 +61,13 @@ class RecvOp : public framework::OperatorBase { ...@@ -59,10 +61,13 @@ class RecvOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("recv_varnames"); Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) { if (recv_varnames.size() > 0) {
auto recv_functor = distributed::ParameterRecv<float>(); auto *communicator = distributed::Communicator::GetInstance();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {},
trainer_id); if (communicator == nullptr) {
recv_functor(rpc_ctx, scope); PADDLE_THROW(platform::errors::InvalidArgument(
"need run fleet.init_worker first"));
}
communicator->RecvNoBarrier();
} else { } else {
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) { if (with_barrier) {
......
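Reviewer note: after the recv_op.cc change above, the merged-variable path no longer builds its own RpcContext and ParameterRecv call; it requires the process-wide Communicator (created by fleet.init_worker) and delegates to RecvNoBarrier, failing fast if the communicator is missing. Below is a minimal sketch of that guard with plain C++ stand-ins for the Communicator singleton and the op body; it is an illustration of the control flow, not the Paddle classes.

// Sketch of the singleton guard used by the updated recv op.
#include <iostream>
#include <memory>
#include <stdexcept>

class Communicator {
 public:
  static Communicator *GetInstance() { return instance_.get(); }
  static void Init() { instance_.reset(new Communicator()); }  // what fleet.init_worker sets up
  void RecvNoBarrier() { std::cout << "pulling parameters from pservers\n"; }

 private:
  Communicator() = default;
  static std::unique_ptr<Communicator> instance_;
};
std::unique_ptr<Communicator> Communicator::instance_;

void RecvOpRun() {
  auto *communicator = Communicator::GetInstance();
  if (communicator == nullptr) {
    // Mirrors the new InvalidArgument error in the op.
    throw std::runtime_error("need run fleet.init_worker first");
  }
  communicator->RecvNoBarrier();
}

int main() {
  Communicator::Init();  // skip this line to see the guard fire
  RecvOpRun();
  return 0;
}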
...@@ -52,8 +52,6 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -52,8 +52,6 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
// for remote prefetch // for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch"); auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (remote_prefetch && !epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
...@@ -62,8 +60,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> { ...@@ -62,8 +60,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, embedding_name, false, operators::distributed::prefetch(id_name, out_name, embedding_name, false,
table_names, epmap, height_sections, table_names, epmap, context,
context, context.scope()); context.scope());
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
......
(14 more file diffs in this commit are collapsed and not shown.)