未验证 提交 caa90a65 编写于 作者: T tangwei12 提交者: GitHub

Integrated Trainer of Parameter Server (API add...

Integrated Trainer of Parameter Server (API add `fluid.contrib.layers.sparse_embedding` only) (#22957)

* Integrated Trainer of Parameter Server
上级 af74675b
......@@ -42,45 +42,10 @@ inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
}
}
// get RpcContext and remote send and recv op
// get CommContext and remote send and recv op
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
for (auto &node : graphs[0]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames =
BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = BOOST_GET_CONST(std::vector<std::string>,
node->Op()->GetNullableAttr("epmap"));
auto height_section = BOOST_GET_CONST(
std::vector<int64_t>, node->Op()->GetNullableAttr("sections"));
auto trainer_id =
BOOST_GET_CONST(int, node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
BOOST_GET_CONST(bool, node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler = BOOST_GET_CONST(
bool, node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
}
}
}
// init communicator here
if (send_varname_to_ctx.size() > 0) {
auto *instance = operators::distributed::Communicator::GetInstance();
auto initialized = instance ? true : false;
PADDLE_ENFORCE_EQ(initialized, true,
......@@ -88,7 +53,7 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
"Communicator is not Initialized, you may use "
"FleetAPI(https://github.com/PaddlePaddle/Fleet/tree/"
"develop/markdown_doc/transpiler)"));
}
#endif
}
......
......@@ -122,7 +122,7 @@ class SelectedRows {
/*
* @brief Get the index of the key from id_to_index_ map.
*/
inline int64_t GetIndexFromId(int64_t key) {
inline int64_t GetIndexFromId(int64_t key) const {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
......
......@@ -79,5 +79,6 @@ void CopyVariable(const Variable &src_var, Variable *dst_var) {
PADDLE_THROW("unknown var type to copy");
}
}
} // namespace framework
} // namespace paddle
......@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/variable.h"
namespace paddle {
namespace framework {
......
......@@ -13,6 +13,7 @@ cc_library(async_sparse_param_update_recorder SRCS async_sparse_param_update_rec
cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_recorder_test.cc DEPS async_sparse_param_update_recorder)
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_library(large_scale_kv SRCS large_scale_kv.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
......@@ -26,7 +27,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor large_scale_kv)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......@@ -50,12 +51,12 @@ else()
set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op)
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_read_op)
endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_op)
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
......@@ -446,11 +446,12 @@ VarHandlePtr BRPCClient::AsyncSendMessage(const std::string& ep,
}
VarHandlePtr BRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
const std::string& dirname,
const std::string& varname,
int64_t time_out) {
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
req.set_varname(varname);
req.set_out_varname(dirname);
return AsyncSendVarMessage(ep, "CheckPointNotifyRPC", req, time_out);
}
......
......@@ -102,7 +102,8 @@ class BRPCClient : public RPCClient {
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
const std::string& ep, const std::string& dirname,
const std::string& varname,
int64_t time_out = FLAGS_rpc_deadline) override;
bool Wait() override;
......
......@@ -22,67 +22,69 @@ namespace paddle {
namespace operators {
namespace distributed {
struct RpcContext {
RpcContext() = default;
struct CommContext {
CommContext() = default;
RpcContext(const std::string &name, const std::vector<std::string> &names,
CommContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections, int id,
bool merge_add_ = true, bool use_send_handler_ = true)
const std::vector<int64_t> &sections,
const std::vector<std::string> &origin_names, int id,
bool merge_add_ = true, bool is_sparse_ = true,
bool is_distributed_ = false)
: var_name(name),
splited_var_names(names),
splited_varnames(names),
epmap(emap),
height_sections(sections),
origin_varnames(origin_names),
trainer_id(id),
merge_add(merge_add_),
use_send_handler(use_send_handler_) {}
is_sparse(is_sparse_),
is_distributed(is_distributed_) {}
RpcContext(const RpcContext &ctx) {
CommContext(const CommContext &ctx) {
var_name = ctx.var_name;
splited_var_names = ctx.splited_var_names;
splited_varnames = ctx.splited_varnames;
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
use_send_handler = ctx.use_send_handler;
is_sparse = ctx.is_sparse;
origin_varnames = ctx.origin_varnames;
is_distributed = ctx.is_distributed;
}
std::string var_name;
std::vector<std::string> splited_var_names;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
int trainer_id;
bool merge_add;
bool use_send_handler;
};
std::string print() const {
std::stringstream ss;
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << "{";
os << "var_name: " << rpc_ctx.var_name << "\n";
ss << "varname: " << var_name << " trainer_id: " << trainer_id << " ";
os << "splited_var_names: [";
for (auto &name : rpc_ctx.splited_var_names) {
os << name << ", ";
for (size_t i = 0; i < splited_varnames.size(); i++) {
ss << "slice varname: " << splited_varnames[i] << " ep: " << epmap[i]
<< " section: " << height_sections[i] << " ";
}
os << "]\n";
os << "epmap: [";
for (auto &ep : rpc_ctx.epmap) {
os << ep << ", ";
ss << "origin varnames: ";
for (size_t i = 0; i < origin_varnames.size(); i++) {
ss << origin_varnames[i] << " ";
}
os << "]\n";
os << "height_sections: [";
for (auto &section : rpc_ctx.height_sections) {
os << section << ", ";
ss << " aggregation->add: " << merge_add << " ";
ss << " is_sparse: " << is_sparse << "\n";
ss << " is_distributed: " << is_distributed << "\n";
return ss.str();
}
os << "]\n";
os << "merge add: " << rpc_ctx.merge_add;
os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
os << "}";
return os;
}
std::string var_name;
std::vector<std::string> splited_varnames;
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
std::vector<std::string> origin_varnames;
int trainer_id;
bool merge_add;
bool is_sparse;
bool is_distributed;
};
} // namespace distributed
} // namespace operators
......
......@@ -409,7 +409,8 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
}
VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
const std::string& dirname,
const std::string& varname,
int64_t time_out) {
const auto ch = GetChannel(ep);
......@@ -422,8 +423,8 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
req.set_varname(varname);
req.set_out_varname(dirname);
platform::RecordRPCEvent record_event(method);
......
......@@ -222,7 +222,8 @@ class GRPCClient : public RPCClient {
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
const std::string& ep, const std::string& dirname,
const std::string& varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
......
......@@ -103,11 +103,13 @@ class RequestSend final : public RequestBase {
void Process() override {
std::string varname = GetReqName();
VLOG(4) << "RequestSend var_name:" << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSend var_name:" << varname << " trainer: " << trainer_id;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
......@@ -332,8 +334,9 @@ class RequestPrefetch final : public RequestBase {
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name;
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
namespace distributed {
std::once_flag LargeScaleKV::init_flag_;
std::shared_ptr<LargeScaleKV> LargeScaleKV::scale_kv_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
此差异已折叠。
......@@ -41,39 +41,55 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static std::vector<std::vector<int64_t>> SplitIds(
const std::vector<int64_t>& ids_vector,
const std::vector<int64_t>& height_section) {
std::set<int64_t> all_ids;
for (auto id : ids_vector) {
all_ids.insert(id);
}
auto abs_sections = ToAbsoluteSection(height_section);
std::vector<std::vector<int64_t>> splited_ids;
splited_ids.resize(height_section.size() + 1);
for (auto& id : all_ids) {
auto section_index = GetSectionIndex(id, abs_sections);
splited_ids[section_index].push_back(id - abs_sections[section_index]);
}
return splited_ids;
}
static void SplitIdsIntoMultipleVarsBySection(
const std::vector<std::string>& in_var_names,
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
framework::Scope* scope) {
PADDLE_ENFORCE_EQ(in_var_names.size(), height_section.size(), "");
const std::vector<int64_t> &in_ids,
const std::vector<std::string> &in_varnames, const int tables,
const int pservers, const bool is_distibuted, framework::Scope *scope,
std::vector<std::vector<int64_t>> *splited_ids,
std::vector<std::vector<int64_t>> *origin_ids) {
PADDLE_ENFORCE_EQ(
in_varnames.size(), tables,
platform::errors::OutOfRange(
"send varnames size: %d not equal table number: %d, internal error",
in_varnames.size(), tables));
PADDLE_ENFORCE_LE(
tables, pservers,
platform::errors::OutOfRange("table number %d not equal or less than "
"pserver number: %d, internal error",
tables, pservers));
auto place = platform::CPUPlace();
for (size_t i = 0; i < in_var_names.size(); ++i) {
auto* id_tensor =
scope->Var(in_var_names[i])->GetMutable<framework::LoDTensor>();
auto& ids = splited_ids[i];
std::set<int64_t> st(in_ids.begin(), in_ids.end());
std::vector<int64_t> all_ids;
all_ids.assign(st.begin(), st.end());
splited_ids->resize(tables);
origin_ids->resize(tables);
if (is_distibuted) {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*splited_ids)[pserver_id].push_back(id);
(*origin_ids)[pserver_id].push_back(id);
}
} else {
for (auto &id : all_ids) {
auto pserver_id = id % pservers;
(*origin_ids)[pserver_id].push_back(id);
id = id / pservers;
(*splited_ids)[pserver_id].push_back(id);
}
}
for (size_t i = 0; i < in_varnames.size(); ++i) {
auto *id_tensor =
scope->Var(in_varnames[i])->GetMutable<framework::LoDTensor>();
auto &ids = (*splited_ids)[i];
if (!ids.empty()) {
auto* id_tensor_data = id_tensor->mutable_data<int64_t>(
auto *id_tensor_data = id_tensor->mutable_data<int64_t>(
framework::make_ddim({static_cast<int64_t>(ids.size()), 1}), place);
memcpy(id_tensor_data, ids.data(), sizeof(int64_t) * ids.size());
}
......@@ -83,12 +99,18 @@ static void SplitIdsIntoMultipleVarsBySection(
typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::Scope& scope,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
const std::vector<int64_t> &ids, const TableAndEndpoints &tables,
const framework::ExecutionContext &context, const framework::Scope &scope,
const bool is_distributed,
std::unordered_map<int64_t, std::vector<float>> *recved_vec_map) {
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
int pservers = context.Attr<int>("pserver_num");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &actual_ctx = *pool.Get(context.GetPlace());
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
......@@ -99,19 +121,17 @@ void prefetch_core(
out_var_names.push_back("prefetch_recv@" + tables[i].second);
}
auto splited_ids = SplitIds(ids, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope.get());
std::vector<std::vector<int64_t>> split_ids;
std::vector<std::vector<int64_t>> origin_ids;
SplitIdsIntoMultipleVarsBySection(ids, in_var_names, tables.size(), pservers,
is_distributed, local_scope.get(),
&split_ids, &origin_ids);
// create output var in local scope
for (auto& name : out_var_names) {
for (auto &name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
}
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
......@@ -126,20 +146,18 @@ void prefetch_core(
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), "");
for (size_t o_idx = 0; o_idx < out_var_names.size(); ++o_idx) {
auto &ids_in_this_section = origin_ids[o_idx];
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) {
auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx])
->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims();
auto &prefetch_out_var =
local_scope->Var(out_var_names[o_idx])->Get<framework::LoDTensor>();
const auto *out_var_data = prefetch_out_var.data<float>();
auto &dims = prefetch_out_var.dims();
PADDLE_ENFORCE_EQ(dims.size(), 2, "");
PADDLE_ENFORCE_EQ(ids_in_this_section.size(), dims[0]);
......@@ -147,8 +165,7 @@ void prefetch_core(
auto row_numel = dims[1];
for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
auto origin_id = ids_in_this_section[i];
std::vector<float> vecs(row_numel);
std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
(*recved_vec_map)[origin_id] = vecs;
......@@ -159,38 +176,35 @@ void prefetch_core(
}
}
void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names,
endpoints, height_sections, context, scope);
void prefetch(const std::string &id_name, const std::string &out_name,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
prefetchs({id_name}, {out_name}, persistable_var_name, is_distributed,
table_names, endpoints, context, scope);
}
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
void prefetchs(const std::vector<std::string> &id_var_names,
const std::vector<std::string> &out_var_names,
const std::string &persistable_var_name,
const bool is_distributed,
const std::vector<std::string> &table_names,
const std::vector<std::string> &endpoints,
const framework::ExecutionContext &context,
const framework::Scope &scope) {
auto vec_dim_1 = 0;
framework::Variable* var = scope.FindVar(persistable_var_name);
PADDLE_ENFORCE_EQ(var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"prefetch can only support LodTensor only"));
auto vec_dim_0 = 0;
framework::Variable *var = scope.FindVar(persistable_var_name);
if (var->IsType<SelectedRows>()) {
vec_dim_1 = var->Get<framework::SelectedRows>().value().dims()[1];
} else {
vec_dim_0 = var->Get<framework::LoDTensor>().dims()[0];
vec_dim_1 = var->Get<framework::LoDTensor>().dims()[1];
}
PADDLE_ENFORCE_GT(vec_dim_1, 0,
platform::errors::InvalidArgument(
......@@ -203,37 +217,38 @@ void prefetchs(const std::vector<std::string>& id_var_names,
PADDLE_THROW("multi prefetch only support CPU currently");
}
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables;
for (auto& id_name : id_var_names) {
auto* id_tensor =
scope.FindVar(id_name)->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
id_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0] * id_dims[1]), 1}));
auto* id_data = id_tensor->data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor->numel(); ++i) {
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
ids_group.push_back(ids);
ids_lods.push_back(id_tensor->lod());
for (auto &id_name : id_var_names) {
auto *in_var = scope.FindVar(id_name);
auto &id_tensor = in_var->Get<framework::LoDTensor>();
std::copy_n(id_tensor.data<int64_t>(), id_tensor.numel(),
back_inserter(ids_union));
}
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end());
for (auto &i : ids_union) {
PADDLE_ENFORCE_GE(
i, 0, platform::errors::OutOfRange(
"each element in embedding should be larger or equal 0"));
if (!is_distributed) {
PADDLE_ENFORCE_LT(
i, vec_dim_0,
platform::errors::OutOfRange(
"embedding id must in [0, %d) when is_distributed False",
vec_dim_0));
}
}
for (size_t i = 0; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i]));
}
std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, height_sections, context, scope,
prefetch_core(ids_union, tables, context, scope, is_distributed,
&recved_vec_map);
auto padding_idx = distributed::kNoPadding;
......@@ -242,20 +257,20 @@ void prefetchs(const std::vector<std::string>& id_var_names,
padding_idx = context.Attr<int64_t>("padding_idx");
}
// copy vectors to out vars
for (size_t i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i];
auto* out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->Resize(
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto *in_var = scope.FindVar(id_var_names[i]);
auto &id_tensor = in_var->Get<framework::LoDTensor>();
auto ids_size = id_tensor.dims()[0];
const auto *id_data = id_tensor.data<int64_t>();
auto* out_d = out_t->mutable_data<float>(place);
for (size_t idx = 0; idx < ids.size(); idx++) {
const auto& id = ids[idx];
auto *out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->set_lod(id_tensor.lod());
out_t->Resize(framework::make_ddim({ids_size, vec_dim_1}));
auto *out_d = out_t->mutable_data<float>(place);
for (auto idx = 0; idx < static_cast<int>(ids_size); idx++) {
const auto &id = id_data[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else {
......
......@@ -31,7 +31,6 @@ void prefetchs(const std::vector<std::string>& id_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
......@@ -39,7 +38,6 @@ void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <set>
#include <string>
......@@ -40,153 +41,131 @@ using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
void RecvSelectedRows(const CommContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(2) << "ParameterRecv in " << rpc_ctx.var_name;
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &cpu_ctx = *pool.Get(platform::CPUPlace());
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
auto *recv_var = scope.FindVar(rpc_ctx.var_name);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
// recv all vars to local scope
if (recv_var->IsType<framework::LoDTensor>() ||
recv_var->IsType<framework::SelectedRows>()) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
local_scope->Var(recv_var_name);
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
if (recv_var->IsType<framework::LoDTensor>()) {
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(),
*local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
} else {
// sparse param in pserver_scope is SelectedRows
rets.push_back(rpc_client->AsyncGetVar(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
} else {
PADDLE_THROW("unsupported var type to recv!");
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
// concat recved tensor into one var
if (recv_var->IsType<framework::LoDTensor>()) {
size_t output_offset = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor =
recv_var->GetMutable<framework::LoDTensor>();
auto dev_ctx = paddle::platform::CPUDeviceContext();
int64_t recv_numel = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
int64_t height = 0;
int64_t ids_num = 0;
int64_t width = 0;
std::vector<int64_t> all_ids;
auto pserver_num = rpc_ctx.splited_varnames.size();
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
if (recv_var->IsType<framework::LoDTensor>()) {
auto &in = recv_var->Get<framework::LoDTensor>();
recv_numel += in.numel();
auto in_stride = framework::stride_numel(in.dims());
auto out_stride = framework::stride_numel(recv_tensor->dims());
StridedNumelCopyWithAxis<T>(
dev_ctx, 0, recv_tensor->data<T>() + output_offset, out_stride,
in.data<T>(), in_stride, in_stride[0]);
output_offset += in_stride[0];
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto &recv_slr = recv_var->Get<framework::SelectedRows>();
auto &recv_dims = recv_tensor->dims();
int64_t width = recv_dims[1];
recv_numel += recv_slr.height() * width;
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[1], width);
PADDLE_ENFORCE_EQ(recv_slr.value().dims()[0], recv_slr.rows().size());
VLOG(3) << "recv slr " << recv_var_name << " dims "
<< recv_slr.value().dims();
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : recv_slr.rows()) {
sstream << row_id << ", ";
auto &recv_t = recv_var->Get<framework::SelectedRows>();
height += recv_t.height();
ids_num += recv_t.rows().size();
width = recv_t.value().dims()[1];
std::transform(recv_t.rows().begin(), recv_t.rows().end(),
std::back_inserter(all_ids),
[&](int64_t id) { return id * pserver_num + i; });
}
sstream << "]";
VLOG(3) << "recv_slr size: " << recv_slr.rows().size() << " "
<< sstream.str();
auto *var = scope.FindVar(rpc_ctx.var_name);
auto *t_ = var->GetMutable<framework::SelectedRows>();
T *out_data =
t_->mutable_value()->mutable_data<T>({ids_num, width}, cpu_place);
t_->set_height(height);
t_->set_rows(all_ids);
int64_t cnt = 0;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_varnames[i];
auto *recv_var = local_scope->FindVar(recv_var_name);
auto &recv_t = recv_var->Get<framework::SelectedRows>();
auto rows = recv_t.rows().size();
const T *in_data = recv_t.value().data<T>();
std::copy_n(in_data, rows * width, out_data + cnt);
cnt += rows * width;
}
t_->SyncIndex();
}
template <typename T>
void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto cpu_place = platform::CPUPlace();
auto &cpu_ctx = *pool.Get(cpu_place);
for (size_t i = 0; i < recv_slr.rows().size(); ++i) {
auto row_id = recv_slr.rows()[i] + row_offset;
PADDLE_ENFORCE_LT(row_id, recv_dims[0]);
memcpy(recv_tensor->data<T>() + row_id * width,
recv_slr.value().data<T>() + i * width, sizeof(T) * width);
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
// variable do not spilt
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
rets[i]->Wait(), 0U,
platform::errors::ExecutionTimeout("internal error in RPCClient"));
}
row_offset += recv_slr.height();
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
return;
} else {
PADDLE_THROW("unsupported recieved var type");
}
PADDLE_ENFORCE(false, platform::errors::Unimplemented(
"ParameterRecv can not recv dense with multi "
"parts now, add it soon."));
}
auto numel = recv_tensor->numel();
PADDLE_ENFORCE_EQ(
recv_numel, numel,
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool barrier) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
PADDLE_ENFORCE_GE(rpc_ctx.origin_varnames.size(), 1,
platform::errors::InvalidArgument(
"The number of receive tensor's elements are not valid. The "
"recevie tensor numel is %d, the actual tensor numel is %d.",
recv_numel, numel));
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto cpu_place = platform::CPUPlace();
auto *slr = recv_var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
int64_t width = 0;
int64_t height = 0;
std::vector<int64_t> new_rows{};
// trans sparse ids from local to global
std::vector<int64_t> abs_sections =
ToAbsoluteSection(rpc_ctx.height_sections);
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
width = var_slr->mutable_value()->dims()[1];
height += var_slr->height();
auto row_offset = abs_sections[i];
VLOG(4) << "Recv split_var " << recv_var_name << " Row size "
<< var_slr_row->size();
for (size_t j = 0; j < var_slr_row->size(); j++) {
new_rows.push_back(row_offset + (*var_slr_row)[j]);
}
}
slr->set_rows(new_rows);
slr->set_height(height);
slr->mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(slr->mutable_rows()->size()), width}),
cpu_place);
auto *slr_data = slr->mutable_value()->data<float>();
size_t row_offset = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
auto var_slr_row_size = var_slr_row->size();
auto *var_slr_data = var_slr->mutable_value()->data<float>();
memcpy(slr_data + row_offset * width, var_slr_data,
sizeof(float) * width * var_slr_row_size);
row_offset += var_slr_row_size;
}
"origin_varnames.size() >= 1 is permitted"));
if (rpc_ctx.is_sparse) {
RecvSelectedRows<T>(rpc_ctx, scope);
} else {
RecvLodTensor<T>(rpc_ctx, scope);
}
VLOG(2) << "ParameterRecv out " << rpc_ctx.var_name;
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
}
template <typename T>
void ParameterRecv<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope) {
this->operator()(rpc_ctx, scope, true);
}
template struct ParameterRecv<float>;
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
......@@ -26,7 +26,10 @@ namespace distributed {
template <typename T>
struct ParameterRecv {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope);
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool barrier);
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope);
};
}; // namespace distributed
......
......@@ -41,42 +41,67 @@ using DDim = framework::DDim;
typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext(
const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) {
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldCommContext(
const CommContext &rpc_ctx, const framework::Scope &scope,
int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must >=1");
PADDLE_ENFORCE_GE(multi_parts, 1,
platform::errors::InvalidArgument(
"multi_parts must == 1 in parameter send, now is: %d",
multi_parts));
if (multi_parts == 1) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i]));
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (int x = 0; x < multi_parts; x++) {
auto table =
string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
}
}
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_varnames[i]));
}
} else if (send_var->IsType<framework::LoDTensor>()) {
PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
} else {
PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!");
PADDLE_THROW(platform::errors::InvalidArgument(
"GetMultiFieldCommContext unsupported LoDTensor current!"));
}
return table_pairs;
} // namespace distributed
void SendByNotifyRPC(const CommContext &rpc_ctx,
const framework::Scope &scope) {
auto cpu_ctx = paddle::platform::CPUDeviceContext();
auto &send_var_name = rpc_ctx.var_name;
std::vector<distributed::VarHandlePtr> rets;
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
if (NeedSend(scope, send_var_name)) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(endpoint, cpu_ctx, scope,
send_var_name));
VLOG(4) << "send var " << send_var_name << " by notify RPC done";
}
} else {
VLOG(3) << "don't send non-initialized variable: " << rpc_ctx.var_name;
}
for (auto &handle : rets) {
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
void ParameterSend<T>::operator()(const CommContext &rpc_ctx,
const framework::Scope &scope, bool sync,
int multi_parts) {
if (rpc_ctx.var_name == STEP_COUNTER) {
SendByNotifyRPC(rpc_ctx, scope);
return;
}
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -86,11 +111,10 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::LoDTensor>()) {
size_t out_num = rpc_ctx.splited_var_names.size();
size_t out_num = rpc_ctx.splited_varnames.size();
if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims();
......@@ -110,24 +134,23 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
// create output var in local scope
size_t row_offset = 0;
for (size_t i = 0; i < out_num; ++i) {
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[i])
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[i])
->GetMutable<framework::LoDTensor>();
*out = send_tensor.Slice(row_offset, row_offset + outs_dims[i][0]);
row_offset += outs_dims[i][0];
}
} else {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_var_names[0])
framework::Tensor *out = local_scope->Var(rpc_ctx.splited_varnames[0])
->GetMutable<framework::LoDTensor>();
out->ShareDataWith(send_tensor);
}
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
for (size_t i = 0; i < rpc_ctx.splited_varnames.size(); i++) {
auto &send_var_name = rpc_ctx.splited_varnames[i];
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
VLOG(4) << " send var name: " << send_var_name
<< "endpoint: " << endpoint;
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
......@@ -135,47 +158,25 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
<< rpc_ctx.splited_varnames[i];
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
auto &send_rows = send_slr.rows();
if (send_rows.size() == 0) {
LOG(WARNING) << "WARNING: The variable sent to pserver is empty, which "
LOG(WARNING)
<< "WARNING: The variable sent to pserver is empty, which "
"may cause an unknown error. Please check the state of "
"use_double_buffer in pyreader async mode, you need to "
"use_double_buffer in pyreader/dataloader async mode, you need to "
"turn it false.";
}
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts);
auto table_pairs = GetMultiFieldCommContext(rpc_ctx, scope, 1);
outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size());
......@@ -190,32 +191,77 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
outs.push_back(out);
}
if (!rpc_ctx.is_distributed) {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = send_rows[i] % pserver_num;
auto id = send_rows[i] / pserver_num;
outs_rows_idx[ep_idx].push_back(id);
outs_dense_idx[ep_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
} else {
auto pserver_num = rpc_ctx.epmap.size();
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
auto table_idx = send_rows[i] % multi_parts;
auto out_idx = ep_idx * multi_parts + table_idx;
auto out_idx = send_rows[i] % pserver_num;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) {
for (int part = 0; part < multi_parts; part++) {
auto out_idx = ctx * multi_parts + part;
for (size_t out_idx = 0; out_idx < rpc_ctx.splited_varnames.size();
out_idx++) {
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]);
outs[out_idx]->set_height(rpc_ctx.height_sections[out_idx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]);
outs[out_idx]->mutable_rows()->push_back(idx);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
......@@ -225,12 +271,15 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW("do not support GPU now");
PADDLE_THROW(
platform::errors::Unimplemented("do not support GPU now"));
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
"rows should has the same size with tensor dim 0");
PADDLE_ENFORCE_EQ(
rows_idx.size(), outs[out_idx]->rows().size(),
platform::errors::InvalidArgument(
"rows should has the same size with tensor dim 0"));
}
}
......@@ -240,8 +289,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
auto need_send = NeedSend(*local_scope.get(), send_var_name);
VLOG(4) << "send var name: " << send_var_name
<< "send var endpoint: " << endpoint
<< "need send: " << need_send;
<< " send var endpoint: " << endpoint
<< " need send: " << need_send;
if (need_send) {
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
......@@ -251,7 +300,7 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(4) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
<< rpc_ctx.splited_varnames[i];
}
}
} else {
......@@ -262,7 +311,8 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
if (sync) {
for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
PADDLE_ENFORCE_NE(handle->Wait(), 0U, platform::errors::ExecutionTimeout(
"internal error in RPCClient"));
}
}
}
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
namespace paddle {
namespace operators {
......@@ -26,7 +26,7 @@ namespace distributed {
template <typename T>
struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
void operator()(const CommContext &rpc_ctx, const framework::Scope &scope,
bool sync, int multi_parts);
};
......
......@@ -65,6 +65,7 @@ constexpr int64_t kPrefetchTimeout = 60000;
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -29,6 +29,7 @@
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
namespace paddle {
namespace operators {
......@@ -38,13 +39,13 @@ namespace distributed {
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
// Sync
......@@ -82,16 +83,34 @@ bool RequestSendHandler::Handle(const std::string& varname,
scope->Rename(varname, run_varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr =
auto *var = scope->FindVar(run_varname);
// for sparse ids
if (var->IsType<framework::SelectedRows>()) {
if (distributed_mode_ == DistributedMode::kAsync ||
distributed_mode_ == DistributedMode::kHalfAsync) {
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->GradInLargeScale(run_varname)) {
auto *large_scale_var = ins->GetByGrad(run_varname);
for (auto name : large_scale_var->CachedVarnames()) {
scope->Var(name);
}
}
}
if (distributed_mode_ == DistributedMode::kGeo) {
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(
run_varname)) {
auto &grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
AsyncSparseParamUpdateRecorder::GetInstance()->Update(
run_varname, grad_slr.rows());
}
}
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
......@@ -104,13 +123,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
......@@ -138,39 +157,38 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
VLOG(3) << "AsyncSparseParamUpdateRecorder " << varname << " exist ";
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
for (auto &row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
......@@ -186,13 +204,13 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
......@@ -212,77 +230,96 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
auto var_desc = program_->Block(0).FindVar(out_var_name);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(
(*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
} else {
(*outvar)->GetMutable<framework::LoDTensor>();
VLOG(1) << "Prefetch "
<< "tablename: " << table_name << " ids:" << varname
<< " out: " << out_var_name;
paddle::platform::CPUPlace cpu_place;
auto *ins = distributed::LargeScaleKV::GetInstance();
if (ins->ParamInLargeScale(table_name)) {
auto lookup_table_op = PullLargeScaleOp(table_name, varname, out_var_name);
lookup_table_op->Run(*scope, cpu_place);
} else {
auto lookup_table_op =
BuildLookupTableOp(table_name, varname, out_var_name);
paddle::platform::CPUPlace cpu_place;
lookup_table_op->Run(*scope, cpu_place);
}
return true;
}
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
// TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
const std::string &out_var_name,
const std::string &table_name) {
VLOG(4) << "receive save var " << varname << " with path " << out_var_name;
auto *ins = distributed::LargeScaleKV::GetInstance();
ins->Get(varname)->Save(out_var_name);
// auto checkpoint_op = BuildCheckpointOp(varname, out_var_name);
// paddle::platform::CPUPlace cpu_place;
// checkpoint_op->Run(*scope_, cpu_place);
return true;
}
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname;
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "RequestNotifyHandler: " << varname
<< ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece decay_piece(STEP_COUNTER);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto *send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
auto *send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
auto counter = decay_counters.at(trainer_id);
counter += send_value[0];
decay_counters.at(trainer_id) = counter;
auto *global_step_var = this->scope()->FindVar(LEARNING_RATE_DECAY_COUNTER);
if (global_step_var == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find LEARNING_RATE_DECAY_COUNTER "));
}
auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
auto global_counter = 0;
for (auto &trainer_counter : decay_counters) {
global_counter += trainer_counter.second;
}
value[0] = global_counter;
if (lr_decay_prepared_ctx_.get() == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"can not find decay block for executor"));
}
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
......
......@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
......@@ -98,6 +99,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
std::unique_ptr<paddle::framework::OperatorBase> PullLargeScaleOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
framework::OpDesc desc;
desc.SetType("lookup_sparse_table_read");
desc.SetInput("Ids", {id_name});
desc.SetOutput("Out", std::vector<std::string>({out_name}));
desc.SetAttr("tablename", {table_name});
desc.SetAttr("init", true);
desc.SetAttr("value_names", std::vector<std::string>({"Param"}));
auto op = paddle::framework::OpRegistry::CreateOp(desc);
return op;
}
std::unique_ptr<paddle::framework::OperatorBase> BuildLookupTableOp(
const std::string& table_name, const std::string& id_name,
const std::string& out_name) {
......@@ -114,11 +130,9 @@ class RequestPrefetchHandler final : public RequestHandler {
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(int distributed_mode,
int checkpoint_notify_id)
: RequestHandler(distributed_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
explicit RequestCheckpointHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
......@@ -126,14 +140,30 @@ class RequestCheckpointHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
int checkpoint_notify_id;
std::unique_ptr<paddle::framework::OperatorBase> BuildCheckpointOp(
const std::string& varname, const std::string& file_path) {
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("save");
BuildVar("X", {varname.data()}, op_desc.add_inputs());
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("file_path");
attr->set_type(paddle::framework::proto::AttrType::STRING);
attr->set_s(file_path);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
return op;
}
};
class RequestNotifyHandler final : public RequestHandler {
public:
explicit RequestNotifyHandler(int distributed_mode, int lr_decay_block_id)
explicit RequestNotifyHandler(int distributed_mode, int trainers)
: RequestHandler(distributed_mode) {
this->lr_decay_block_id = lr_decay_block_id;
this->trainers = trainers;
for (int i = 0; i < trainers; i++) {
decay_counters[i] = 0;
}
}
virtual ~RequestNotifyHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
......@@ -142,7 +172,8 @@ class RequestNotifyHandler final : public RequestHandler {
const std::string& table_name = "") override;
private:
int lr_decay_block_id;
int trainers;
std::unordered_map<int, int64_t> decay_counters;
};
} // namespace distributed
......
......@@ -77,8 +77,8 @@ class RPCClient {
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncCheckpointNotify(
const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
const std::string& ep, const std::string& dirname,
const std::string& varname, int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
......
......@@ -34,7 +34,7 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table);
USE_NO_KERNEL_OP(lookup_sparse_table_read);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
......@@ -46,10 +46,12 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_sparse_table");
op->SetType("lookup_sparse_table_read");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
op->SetAttr("tablename", {"w"});
op->SetAttr("value_names", {"Param"});
auto& out = *root_block->Var("out");
out.SetType(framework::proto::VarType::LOD_TENSOR);
......@@ -99,16 +101,10 @@ void StartServer(const std::string& rpc_name) {
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto* block = AppendPrefetchBlcok(&program);
std::string in_var_name("ids");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared;
prefetch_var_name_to_prepared[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(&prefetch_var_name_to_prepared);
......@@ -128,49 +124,6 @@ void StartServer(const std::string& rpc_name) {
server_thread.join();
}
TEST(PREFETCH, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestPrefetchHandler(
distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
std::thread server_thread(StartServer, distributed::kRequestPrefetch);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
client->AsyncPrefetchVar(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value->dims()[1]], static_cast<float>(i * 2));
}
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -35,19 +32,31 @@ class CheckpointNotifyOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
int trainer_id = Attr<int>("trainer_id");
std::vector<std::string> epmap =
Attr<std::vector<std::string>>("endpoints");
std::string dirname = Attr<std::string>("dirname");
std::string varname = Attr<std::string>("varname");
auto is_slice = Attr<bool>("is_slice");
VLOG(1) << "is_slice: " << is_slice;
std::vector<std::string> slice_varnames =
Attr<std::vector<std::string>>("slice_varnames");
std::vector<std::string> remote_varnames =
Attr<std::vector<std::string>>("remote_varnames");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
auto save_path =
string::Sprintf("%s/%s/%s", dirname, varname, slice_varnames[i]);
rpc_client->AsyncCheckpointNotify(epmap[i], save_path,
remote_varnames[i]);
VLOG(3) << "checkpoint notify sending with path: " << save_path
<< " and var:" << slice_varnames[i] << " to " << epmap[i];
}
PADDLE_ENFORCE_EQ(
rpc_client->Wait(), true,
......@@ -59,18 +68,22 @@ class CheckpointNotifyOp : public framework::OperatorBase {
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicate the folder checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>(
"endpoints",
"(string vector)"
"Parameter Server endpoints in the order");
AddAttr<std::string>("dirname",
"(string) indicate the folder checkpoint will use");
AddAttr<std::string>("varname", "(string) the var need to be saved");
AddAttr<std::vector<std::string>>(
"slice_varnames", "(string vector) the slice vars need to be saved");
AddAttr<std::vector<std::string>>(
"remote_varnames", "(string vector) the slice vars need to be saved");
AddAttr<bool>(
"is_slice",
"is_slice=True means the var has been slice by parameter server");
AddComment(R"DOC(
CheckpointNotify operator
This operator will send lookup table and it's checkpoint direcoty to listen_and_serve op at
the parameter server.
)DOC");
......
......@@ -19,9 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
......@@ -41,6 +42,7 @@ class RecvOp : public framework::OperatorBase {
VLOG(3) << "recv do not run!";
return;
}
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
......@@ -59,10 +61,13 @@ class RecvOp : public framework::OperatorBase {
Attr<std::vector<std::string>>("recv_varnames");
if (recv_varnames.size() > 0) {
auto recv_functor = distributed::ParameterRecv<float>();
auto rpc_ctx = distributed::RpcContext(outs[0], recv_varnames, epmap, {},
trainer_id);
recv_functor(rpc_ctx, scope);
auto *communicator = distributed::Communicator::GetInstance();
if (communicator == nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"need run fleet.init_worker first"));
}
communicator->RecvNoBarrier();
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
......
......@@ -52,8 +52,6 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (remote_prefetch && !epmap.empty()) {
......@@ -62,8 +60,8 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, embedding_name, false,
table_names, epmap, height_sections,
context, context.scope());
table_names, epmap, context,
context.scope());
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册