未验证 提交 caa90a65 编写于 作者: T tangwei12 提交者: GitHub

Integrated Trainer of Parameter Server (API add...

Integrated Trainer of Parameter Server (API add `fluid.contrib.layers.sparse_embedding` only) (#22957)

* Integrated Trainer of Parameter Server
上级 af74675b
......@@ -42,53 +42,18 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
}
}
// get RpcContext and remote send and recv op
// get CommContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames =
BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("epmap"));
auto height_section = BOOST_GET_CONST(
std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
auto trainer_id =
BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler = BOOST_GET_CONST(
bool, node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
auto *instance = operators::distributed::Communicator::GetInstance();
auto initialized = instance ? true : false;
PADDLE_ENFORCE_EQ(initialized, true,
platform::errors::InvalidArgument(
"Communicator is not Initialized, you may use "
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"develop/markdown_doc/transpiler)"));
}
auto *instance = operators::distributed::Communicator::GetInstance();
auto initialized = instance ? true : false;
PADDLE_ENFORCE_EQ(initialized, true,
platform::errors::InvalidArgument(
"Communicator is not Initialized, you may use "
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"develop/markdown_doc/transpiler)"));
#endif
}
......
......@@ -122,7 +122,7 @@ class SelectedRows {
/*
* @brief Get the index of the key from id_to_index_ map.
*/
inline int64_t GetIndexFromId(int64_t key) {
inline int64_t GetIndexFromId(int64_t key) const {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
......
......@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
PADDLE_THROW("unknown var type to copy");
}
}
} // namespace framework
} // namespace paddle
......@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
......
......@@ -13,6 +13,7 @@ cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_rec
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
......@@ -26,7 +27,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......@@ -50,12 +51,12 @@ else()
set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op)
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_op)
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
......@@ -446,11 +446,12 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
const std::string& dirname,
const std::string& varname,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
req.set_varname(varname);
req.set_out_varname(dirname);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
......
......@@ -102,7 +102,8 @@ class BRPCClient : public RPCClient {
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
const std::string& ep, const std::string& dirname,
const std::string& varname,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
......
......@@ -22,67 +22,69 @@ namespace paddle {
namespace operators {
namespace distributed {
struct RpcContext {
RpcContext() = default;
RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections, int id,
bool merge_add_ = true, bool use_send_handler_ = true)
struct CommContext {
CommContext() = default;
CommContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false)
: var_name(name),
splited_var_names(names),
splited_varnames(names),
epmap(emap),
height_sections(sections),
origin_varnames(origin_names),
trainer_id(id),
merge_add(merge_add_),
use_send_handler(use_send_handler_) {}
is_sparse(is_sparse_),
is_distributed(is_distributed_) {}
RpcContext(const RpcContext &ctx) {
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
splited_varnames = ctx.splited_varnames;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
use_send_handler = ctx.use_send_handler;
is_sparse = ctx.is_sparse;
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
}
std::string var_name;
std::vector<std::string> splited_var_names;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
int trainer_id;
bool merge_add;
bool use_send_handler;
};
std::string print() const {
std::stringstream ss;
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << "{";
os << "var_name: " << rpc_ctx.var_name << "\n";
ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
os << "splited_var_names: [";
for (auto &name : rpc_ctx.splited_var_names) {
os << name << ", ";
}
os << "]\n";
for (size_t i = 0; i < splited_varnames.size(); i++) {
ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
<< " section: " << height_sections[i] << " ";
}
os << "epmap: [";
for (auto &ep : rpc_ctx.epmap) {
os << ep << ", ";
}
os << "]\n";
ss << "origin varnames: ";
for (size_t i = 0; i < origin_varnames.size(); i++) {
ss << origin_varnames[i] << " ";
}
ss << " aggregation->add: " << merge_add << " ";
ss << " is_sparse: " << is_sparse << "\n";
ss << " is_distributed: " << is_distributed << "\n";
os << "height_sections: [";
for (auto &section : rpc_ctx.height_sections) {
os << section << ", ";
return ss.str();
}
os << "]\n";
os << "merge add: " << rpc_ctx.merge_add;
os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
os << "}";
return os;
}
std::string var_name;
std::vector<std::string> splited_varnames;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
std::vector<std::string> origin_varnames;
int trainer_id;
bool merge_add;
bool is_sparse;
bool is_distributed;
};
} // namespace distributed
} // namespace operators
......
......@@ -409,7 +409,8 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
}
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
const std::string& dirname,
const std::string& varname,
int64_t time_out) {
const auto ch = GetChannel(ep);
......@@ -422,8 +423,8 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
req.set_varname(varname);
req.set_out_varname(dirname);
platform::RecordRPCEvent record_event(method);
......
......@@ -222,7 +222,8 @@ class GRPCClient : public RPCClient {
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
const std::string& ep, const std::string& dirname,
const std::string& varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
......
......@@ -103,11 +103,13 @@ class RequestSend final : public RequestBase {
void Process() override {
std::string varname = GetReqName();
VLOG(4) << "RequestSend var_name:" << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
......@@ -332,8 +334,9 @@ class RequestPrefetch final : public RequestBase {
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag LargeScaleKV::init_flag_;
std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
此差异已折叠。
......@@ -41,39 +41,55 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector,
const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids;
for (auto id : ids_vector) {
all_ids.insert(id);
}
auto abs_sections = ToAbsoluteSection(height_section);
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names,
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
const std::vector<int64_t> &in_ids,
const std::vector<std::string> &in_varnames, const int tables,
const int pservers, const bool is_distibuted, framework::Scope *scope,
std::vector<std::vector<int64_t>> *splited_ids,
std::vector<std::vector<int64_t>> *origin_ids) {
PADDLE_ENFORCE_EQ(
in_varnames.size(), tables,
platform::errors::OutOfRange(
"send varnames size: %d not equal table number: %d, internal error",
in_varnames.size(), tables));
PADDLE_ENFORCE_LE(
tables, pservers,
platform::errors::OutOfRange("table number %d not equal or less than "
"pserver number: %d, internal error",
tables, pservers));
auto place = platform::CPUPlace();
for (size_t i = 0; i < in_var_names.size(); ++i) {
auto* id_tensor =
scope->Var(in_var_names[i])->GetMutable<framework::LoDTensor>();
auto& ids = splited_ids[i];
std::set<int64_t> st(in_ids.begin(), in_ids.end());
std::vector<int64_t> all_ids;
all_ids.assign(st.begin(), st.end());
splited_ids->resize(tables);
origin_ids->resize(tables);
if (is_distibuted) {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*splited_ids)[pserver_id].push_back(id);
(*origin_ids)[pserver_id].push_back(id);
}
} else {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*origin_ids)[pserver_id].push_back(id);
id = id / pservers;
(*splited_ids)[pserver_id].push_back(id);
}
}
for (size_t i = 0; i < in_varnames.size(); ++i) {
auto *id_tensor =
scope->Var(in_varnames[i])->GetMutable<framework::LoDTensor>();
auto &ids = (*splited_ids)[i];
if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
auto *id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
}
......@@ -83,12 +99,18 @@ static void SplitIdsIntoMultipleVarsBySection(
typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::Scope& scope,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
const std::vector<int64_t> &ids, const TableAndEndpoints &tables,
const framework::ExecutionContext &context, const framework::Scope &scope,
const bool is_distributed,
std::unordered_map<int64_t, std::vector<float>> *recved_vec_map) {
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
int pservers = context.Attr<int>("pserver_num");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &actual_ctx = *pool.Get(context.GetPlace());
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
......@@ -99,19 +121,17 @@ void prefetch_core(
out_var_names.push_back("prefetch_recv@" + tables[i].second);
}
auto splited_ids = SplitIds(ids, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope.get());
std::vector<std::vector<int64_t>> split_ids;
std::vector<std::vector<int64_t>> origin_ids;
SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers,
is_distributed, local_scope.get(),
&split_ids, &origin_ids);
// create output var in local scope
for (auto& name : out_var_names) {
for (auto &name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
}
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
......@@ -126,20 +146,18 @@ void prefetch_core(
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), "");
for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) {
auto &ids_in_this_section = origin_ids[o_idx];
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) {
auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx])
->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims();
auto &prefetch_out_var =
local_scope->Var(out_var_names[o_idx])->Get<framework::LoDTensor>();
const auto *out_var_data = prefetch_out_var.data<float>();
auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, "");
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
......@@ -147,8 +165,7 @@ void prefetch_core(
auto row_numel = dims[1];
for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
auto origin_id = ids_in_this_section[i];
std::vector<float> vecs(row_numel);
std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
(*recved_vec_map)[origin_id] = vecs;
......@@ -159,38 +176,35 @@ void prefetch_core(
}
}
void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names,
endpoints, height_sections, context, scope);
void prefetch(const std::string &id_name, const std::string &out_name,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed,
table_names, endpoints, context, scope);
}
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
void prefetchs(const std::vector<std::string> &id_var_names,
const std::vector<std::string> &out_var_names,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
auto vec_dim_1 = 0;
framework::Variable* var = scope.FindVar(persistable_var_name);
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"prefetch can only support LodTensor only"));
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
auto vec_dim_0 = 0;
framework::Variable *var = scope.FindVar(persistable_var_name);
if (var->IsType<SelectedRows>()) {
vec_dim_1 = var->Get<framework::SelectedRows>().value().dims()[1];
} else {
vec_dim_0 = var->Get<framework::LoDTensor>().dims()[0];
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
}
PADDLE_ENFORCE_GT(vec_dim_1, 0,
platform::errors::InvalidArgument(
......@@ -203,37 +217,38 @@ void prefetchs(const std::vector<std::string>& id_var_names,
PADDLE_THROW("multi prefetch only support CPU currently");
}
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables;
for (auto& id_name : id_var_names) {
auto* id_tensor =
scope.FindVar(id_name)->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
id_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
ids_group.push_back(ids);
ids_lods.push_back(id_tensor->lod());
for (auto &id_name : id_var_names) {
auto *in_var = scope.FindVar(id_name);
auto &id_tensor = in_var->Get<framework::LoDTensor>();
std::copy_n(id_tensor.data<int64_t>(), id_tensor.numel(),
back_inserter(ids_union));
}
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end());
for (auto &i : ids_union) {
PADDLE_ENFORCE_GE(
i, 0, platform::errors::OutOfRange(
"each element in embedding should be larger or equal 0"));
if (!is_distributed) {
PADDLE_ENFORCE_LT(
i, vec_dim_0,
platform::errors::OutOfRange(
"embedding id must in [0, %d) when is_distributed False",
vec_dim_0));
}
}
for (size_t i = 0; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i]));
}
std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, height_sections, context, scope,
prefetch_core(ids_union, tables, context, scope, is_distributed,
&recved_vec_map);
auto padding_idx = distributed::kNoPadding;
......@@ -242,20 +257,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
padding_idx = context.Attr<int64_t>("padding_idx");
}
// copy vectors to out vars
for (size_t i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i];
auto* out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->Resize(
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto* out_d = out_t->mutable_data<float>(place);
auto *in_var = scope.FindVar(id_var_names[i]);
auto &id_tensor = in_var->Get<framework::LoDTensor>();
auto ids_size = id_tensor.dims()[0];
const auto *id_data = id_tensor.data<int64_t>();
for (size_t idx = 0; idx < ids.size(); idx++) {
const auto& id = ids[idx];
auto *out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->set_lod(id_tensor.lod());
out_t->Resize(framework::make_ddim({ids_size, vec_dim_1}));
auto *out_d = out_t->mutable_data<float>(place);
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = id_data[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else {
......
......@@ -31,7 +31,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
......@@ -39,7 +38,6 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <set>
#include <string>
......@@ -40,153 +41,131 @@ using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(2) << "ParameterRecv in " << rpc_ctx.var_name;
void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
int64_t height = 0;
int64_t ids_num = 0;
int64_t width = 0;
std::vector<int64_t> all_ids;
auto pserver_num = rpc_ctx.splited_varnames.size();
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
height += recv_t.height();
ids_num += recv_t.rows().size();
width = recv_t.value().dims()[1];
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
}
auto *var = scope.FindVar(rpc_ctx.var_name);
auto *t_ = var->GetMutable<framework::SelectedRows>();
T *out_data =
t_->mutable_value()->mutable_data<T>({ids_num, width}, cpu_place);
t_->set_height(height);
t_->set_rows(all_ids);
int64_t cnt = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto rows = recv_t.rows().size();
const T *in_data = recv_t.value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
}
t_->SyncIndex();
}
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name);
// recv all vars to local scope
if (recv_var->IsType<framework::LoDTensor>() ||
recv_var->IsType<framework::SelectedRows>()) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
if (recv_var->IsType<framework::LoDTensor>()) {
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(),
recv_var_name, recv_var_name));
} else {
// sparse param in pserver_scope is SelectedRows
rets.push_back(rpc_client->AsyncGetVar(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
}
std::vector<distributed::VarHandlePtr> rets;
// variable do not spilt
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
return;
} else {
PADDLE_THROW("unsupported var type to recv!");
PADDLE_ENFORCE(false, platform::errors::Unimplemented(
"ParameterRecv can not recv dense with multi "
"parts now, add it soon."));
}
}
// concat recved tensor into one var
if (recv_var->IsType<framework::LoDTensor>()) {
size_t output_offset = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext();
int64_t recv_numel = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *recv_var = local_scope->FindVar(recv_var_name);
if (recv_var->IsType<framework::LoDTensor>()) {
auto &in = recv_var->Get<framework::LoDTensor>();
recv_numel += in.numel();
auto in_stride = framework::stride_numel(in.dims());
auto out_stride = framework::stride_numel(recv_tensor->dims());
StridedNumelCopyWithAxis<T>(
dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
in.data<T>(), in_stride, in_stride[0]);
output_offset += in_stride[0];
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1];
recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
<< sstream.str();
}
for (size_t i = 0; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
}
row_offset += recv_slr.height();
} else {
PADDLE_THROW("unsupported recieved var type");
}
}
auto numel = recv_tensor->numel();
PADDLE_ENFORCE_EQ(
recv_numel, numel,
platform::errors::InvalidArgument(
"The number of receive tensor's elements are not valid. The "
"recevie tensor numel is %d, the actual tensor numel is %d.",
recv_numel, numel));
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto cpu_place = platform::CPUPlace();
auto *slr = recv_var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
int64_t width = 0;
int64_t height = 0;
std::vector<int64_t> new_rows{};
// trans sparse ids from local to global
std::vector<int64_t> abs_sections =
ToAbsoluteSection(rpc_ctx.height_sections);
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
width = var_slr->mutable_value()->dims()[1];
height += var_slr->height();
auto row_offset = abs_sections[i];
VLOG(4) << "Recv split_var " << recv_var_name << " Row size "
<< var_slr_row->size();
for (size_t j = 0; j < var_slr_row->size(); j++) {
new_rows.push_back(row_offset + (*var_slr_row)[j]);
}
}
slr->set_rows(new_rows);
slr->set_height(height);
slr->mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(slr->mutable_rows()->size()), width}),
cpu_place);
auto *slr_data = slr->mutable_value()->data<float>();
size_t row_offset = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
auto var_slr_row_size = var_slr_row->size();
auto *var_slr_data = var_slr->mutable_value()->data<float>();
memcpy(slr_data + row_offset * width, var_slr_data,
sizeof(float) * width * var_slr_row_size);
row_offset += var_slr_row_size;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool barrier) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
platform::errors::InvalidArgument(
"origin_varnames.size() >= 1 is permitted"));
if (rpc_ctx.is_sparse) {
RecvSelectedRows<T>(rpc_ctx, scope);
} else {
RecvLodTensor<T>(rpc_ctx, scope);
}
VLOG(2) << "ParameterRecv out " << rpc_ctx.var_name;
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
}
template struct ParameterRecv<float>;
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
......@@ -26,7 +26,10 @@ namespace distributed {
template <typename T>
struct ParameterRecv {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool barrier);
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
};
}; // namespace distributed
......
......@@ -41,42 +41,67 @@ using DDim = framework::DDim;
typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext(
const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) {
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext(
const CommContext &rpc_ctx, const framework::Scope &scope,
int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must >=1");
if (multi_parts == 1) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i]));
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (int x = 0; x < multi_parts; x++) {
auto table =
string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
}
}
PADDLE_ENFORCE_GE(multi_parts, 1,
platform::errors::InvalidArgument(
"multi_parts must == 1 in parameter send, now is: %d",
multi_parts));
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i]));
}
} else if (send_var->IsType<framework::LoDTensor>()) {
PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
} else {
PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!");
PADDLE_THROW(platform::errors::InvalidArgument(
"GetMultiFieldCommContext unsupported LoDTensor current!"));
}
return table_pairs;
} // namespace distributed
void SendByNotifyRPC(const CommContext &rpc_ctx,
const framework::Scope &scope) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto &send_var_name = rpc_ctx.var_name;
std::vector<distributed::VarHandlePtr> rets;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
if (NeedSend(scope, send_var_name)) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope,
send_var_name));
VLOG(4) << "send var " << send_var_name << " by notify RPC done";
}
} else {
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name;
}
for (auto &handle : rets) {
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool sync,
int multi_parts) {
if (rpc_ctx.var_name == STEP_COUNTER) {
SendByNotifyRPC(rpc_ctx, scope);
return;
}
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -86,11 +111,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::LoDTensor>()) {
size_t out_num = rpc_ctx.splited_var_names.size();
size_t out_num = rpc_ctx.splited_varnames.size();
if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims();
......@@ -110,72 +134,49 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
// create output var in local scope
size_t row_offset = 0;
for (size_t i = 0; i < out_num; ++i) {
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i])
->GetMutable<framework::LoDTensor>();
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0];
}
} else {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[0])
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0])
->GetMutable<framework::LoDTensor>();
out->ShareDataWith(send_tensor);
}
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &send_var_name = rpc_ctx.splited_varnames[i];
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << " send var name: " << send_var_name
<< "endpoint: " << endpoint;
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_varnames[i];
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) {
LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader async mode, you need to "
"turn it false.";
LOG(WARNING)
<< "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader/dataloader async mode, you need to "
"turn it false.";
}
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts);
auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1);
outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size());
......@@ -190,32 +191,77 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs.push_back(out);
}
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
auto table_idx = send_rows[i] % multi_parts;
auto out_idx = ep_idx * multi_parts + table_idx;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
if (!rpc_ctx.is_distributed) {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = send_rows[i] % pserver_num;
auto id = send_rows[i] / pserver_num;
outs_rows_idx[ep_idx].push_back(id);
outs_dense_idx[ep_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
} else {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto out_idx = send_rows[i] % pserver_num;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace();
auto place = platform::CPUPlace();
for (size_t ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) {
for (int part = 0; part < multi_parts; part++) {
auto out_idx = ctx * multi_parts + part;
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]);
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]);
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
......@@ -225,12 +271,15 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW("do not support GPU now");
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
"rows should has the same size with tensor dim 0");
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
}
......@@ -240,8 +289,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto need_send = NeedSend(*local_scope.get(), send_var_name);
VLOG(4) << "send var name: " << send_var_name
<< "send var endpoint: " << endpoint
<< "need send: " << need_send;
<< " send var endpoint: " << endpoint
<< " need send: " << need_send;
if (need_send) {
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
......@@ -251,7 +300,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(4) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
<< rpc_ctx.splited_varnames[i];
}
}
} else {
......@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
if (sync) {
for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
}
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
......@@ -26,7 +26,7 @@ namespace distributed {
template <typename T>
struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool sync, int multi_parts);
};
......
......@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000;
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
......@@ -38,13 +39,13 @@ namespace distributed {
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
......@@ -82,16 +83,34 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope->Rename(varname, run_varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
auto *var = scope->FindVar(run_varname);
// for sparse ids
if (var->IsType<framework::SelectedRows>()) {
if (distributed_mode_ == DistributedMode::kAsync ||
distributed_mode_ == DistributedMode::kHalfAsync) {
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->GradInLargeScale(run_varname)) {
auto *large_scale_var = ins->GetByGrad(run_varname);
for (auto name : large_scale_var->CachedVarnames()) {
scope->Var(name);
}
}
}
if (distributed_mode_ == DistributedMode::kGeo) {
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(
run_varname)) {
auto &grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(
run_varname, grad_slr.rows());
}
}
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
......@@ -104,13 +123,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
......@@ -138,39 +157,38 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
VLOG(3) << "AsyncSparseParamUpdateRecorder " << varname << " exist ";
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
for (auto &row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
......@@ -186,13 +204,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
......@@ -212,77 +230,96 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
(*outvar)->GetMutable<framework::LoDTensor>();
VLOG(1) << "Prefetch "
<< "tablename: " << table_name << " ids:" << varname
<< " out: " << out_var_name;
paddle::platform::CPUPlace cpu_place;
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->ParamInLargeScale(table_name)) {
auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name);
lookup_table_op->Run(*scope, cpu_place);
} else {
(*outvar)->GetMutable<framework::LoDTensor>();
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
// TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name);
// auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
// paddle::platform::CPUPlace cpu_place;
// checkpoint_op->Run(*scope_, cpu_place);
return true;
}
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname;
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestNotifyHandler: " << varname
<< ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece decay_piece(STEP_COUNTER);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto *send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
auto *send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
auto counter = decay_counters.at(trainer_id);
counter += send_value[0];
decay_counters.at(trainer_id) = counter;
auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER);
if (global_step_var == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find LEARNING_RATE_DECAY_COUNTER "));
}
auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
auto global_counter = 0;
for (auto &trainer_counter : decay_counters) {
global_counter += trainer_counter.second;
}
value[0] = global_counter;
if (lr_decay_prepared_ctx_.get() == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find decay block for executor"));
}
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
......
......@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
......@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
framework::OpDesc desc;
desc.SetType("lookup_sparse_table_read");
desc.SetInput("Ids", {id_name});
desc.SetOutput("Out", std::vector<std::string>({out_name}));
desc.SetAttr("tablename", {table_name});
desc.SetAttr("init", true);
desc.SetAttr("value_names", std::vector<std::string>({"Param"}));
auto op = paddle::framework::OpRegistry::CreateOp(desc);
return op;
}
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
......@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(int distributed_mode,
int checkpoint_notify_id)
: RequestHandler(distributed_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
explicit RequestCheckpointHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
......@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
int checkpoint_notify_id;
std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
const std::string& varname, const std::string& file_path) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("save");
BuildVar("X", {varname.data()}, op_desc.add_inputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("file_path");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(file_path);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id)
explicit RequestNotifyHandler(int distributed_mode, int trainers)
: RequestHandler(distributed_mode) {
this->lr_decay_block_id = lr_decay_block_id;
this->trainers = trainers;
for (int i = 0; i < trainers; i++) {
decay_counters[i] = 0;
}
}
virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
......@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
int lr_decay_block_id;
int trainers;
std::unordered_map<int, int64_t> decay_counters;
};
} // namespace distributed
......
......@@ -77,8 +77,8 @@ class RPCClient {
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
const std::string& ep, const std::string& dirname,
const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
......
......@@ -34,7 +34,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table);
USE_NO_KERNEL_OP(lookup_sparse_table_read);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
......@@ -46,10 +46,12 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_sparse_table");
op->SetType("lookup_sparse_table_read");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
op->SetAttr("tablename", {"w"});
op->SetAttr("value_names", {"Param"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::LOD_TENSOR);
......@@ -99,16 +101,10 @@ void StartServer(const std::string& rpc_name) {
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
......@@ -128,49 +124,6 @@ void StartServer(const std::string& rpc_name) {
server_thread.join();
}
TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
}
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
......
......@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
......@@ -42,6 +43,7 @@ void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
......@@ -109,6 +111,19 @@ static int64_t GetTimestamp() {
return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}
// For sync, sparse variables need recover grad type from LodTensor to
// SelectedRows
void ResetSparseVarsType(framework::Scope *recv_scope) {
auto *ins = distributed::LargeScaleKV::GetInstance();
auto grads = ins->GetAllGrads();
for (auto &grad : grads) {
auto *v = recv_scope->FindVar(grad);
v->Clear();
v->GetMutable<framework::SelectedRows>();
}
}
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
......@@ -179,6 +194,7 @@ void ListenAndServOp::RunSyncLoop(
VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
ResetSparseVarsType(recv_scope);
VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
......@@ -372,12 +388,12 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestGetHandler(distributed_mode, dc_sgd));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(distributed_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
distributed_mode, checkpoint_block_id));
request_checkpoint_handler_.reset(
new distributed::RequestCheckpointHandler(distributed_mode));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(new distributed::RequestNotifyHandler(
distributed_mode, lr_decay_block_id));
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(distributed_mode, fan_in));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num);
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册