未验证 提交 416922e2 编写于 作者: T tangwei12 提交者: GitHub

Distributed training cherry-pick for Release 1.5 (#19486)

* fix bug in Class MultiSlotDataGenerator's function _gen_str, test=develop (#18222)
* fix some bug when merge sparse embedding parameters, test=develop (#18223)
* fix communicator with pyreader (#18350)
* delete AllocatorFacade destructor  (#18606)
* fix distribute transpiler GRPC error code 4, RPC Deadline (#18984)
* merge pr #18441
上级 0edeb838
...@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator"; VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope); if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::GetInstance()->Start(); operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
} }
#endif #endif
} }
......
...@@ -133,13 +133,6 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { ...@@ -133,13 +133,6 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
VLOG(1) << "set recv op do_not_run to true"; VLOG(1) << "set recv op do_not_run to true";
node->Op()->SetAttr("do_not_run", 1); node->Op()->SetAttr("do_not_run", 1);
node->Op()->Flush(); node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// in async_mode, we do not need remote prefetch, because communicator
// will do async parameter recv.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
} }
return false; return false;
} }
......
...@@ -248,6 +248,8 @@ class ExecutionContext { ...@@ -248,6 +248,8 @@ class ExecutionContext {
return op_.Attr<T>(name); return op_.Attr<T>(name);
} }
bool HasAttr(const std::string& name) const { return op_.HasAttr(name); }
bool HasInput(const std::string& name) const; bool HasInput(const std::string& name) const;
bool HasOutput(const std::string& name) const; bool HasOutput(const std::string& name) const;
......
...@@ -295,7 +295,9 @@ class AllocatorFacadePrivate { ...@@ -295,7 +295,9 @@ class AllocatorFacadePrivate {
// Pimpl. Make interface clean. // Pimpl. Make interface clean.
AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {} AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {}
AllocatorFacade::~AllocatorFacade() { delete m_; } // delete m_ may cause core dump when the destructor of python in conflict with
// cpp.
AllocatorFacade::~AllocatorFacade() {}
AllocatorFacade& AllocatorFacade::Instance() { AllocatorFacade& AllocatorFacade::Instance() {
static AllocatorFacade instance; static AllocatorFacade instance;
......
...@@ -73,14 +73,26 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, ...@@ -73,14 +73,26 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
VLOG(0) << "communicator_max_merge_var_num: " VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num; << FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc; VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) { if (send_varname_to_ctx.size() == 0) {
send_varname_to_queue_[iter.first] = VLOG(0) << "nothing need to be send, will not start send_thread";
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>( } else {
FLAGS_communicator_send_queue_size); send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
}
send_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
if (recv_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be received, will not start recv_thread";
} else {
recv_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
} }
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
} }
Communicator::~Communicator() { Communicator::~Communicator() {
...@@ -157,18 +169,28 @@ void Communicator::SendThread() { ...@@ -157,18 +169,28 @@ void Communicator::SendThread() {
task_f.wait(); task_f.wait();
} }
auto after_run_send_graph = GetCurrentUS(); auto after_run_send_graph = GetCurrentUS();
auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
if (send_graph_use_time > 100) { VLOG(3) << "run send graph use time "
VLOG(1) << "run send graph use time " << after_run_send_graph - before_run_send_graph;
<< after_run_send_graph - before_run_send_graph; RecvNonIndependent();
}
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll();
}
} }
VLOG(0) << "communicator stopped, send thread exit"; VLOG(0) << "communicator stopped, send thread exit";
} }
void Communicator::RecvNonIndependent() {
if (!FLAGS_communicator_independent_recv_thread) {
return;
}
auto grad_num = grad_num_.load();
if (grad_num > 0) {
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
void Communicator::RecvAll() { void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph"; VLOG(3) << "parallel run recv graph";
if (!running_) return; if (!running_) return;
......
...@@ -167,12 +167,15 @@ class Communicator { ...@@ -167,12 +167,15 @@ class Communicator {
void Start(); void Start();
void Stop(); void Stop();
bool IsRunning() { return running_; }
// send grad // send grad
void Send(const std::string& var_name, const framework::Scope& scope); void Send(const std::string& var_name, const framework::Scope& scope);
private: private:
// recv all parameter // recv all parameter
void RecvAll(); void RecvAll();
void RecvNonIndependent();
void SendThread(); void SendThread();
void RecvThread(); void RecvThread();
......
...@@ -12,10 +12,12 @@ ...@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include <memory> #include <memory>
#include <set> #include <set>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/operators/distributed/parameter_prefetch.h" #include "paddle/fluid/operators/distributed/parameter_prefetch.h"
...@@ -78,45 +80,64 @@ static void SplitIdsIntoMultipleVarsBySection( ...@@ -78,45 +80,64 @@ static void SplitIdsIntoMultipleVarsBySection(
} }
} }
static void MergeMultipleVarsIntoOneBySection( typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
const std::string& id_name, const std::vector<int64_t>& ids_vector,
const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope,
platform::DeviceContext* actual_ctx) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
auto cpu_place = platform::CPUPlace(); void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::Scope& scope,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
auto abs_sections = ToAbsoluteSection(height_section); std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::unordered_map<int64_t, std::vector<size_t>> id_to_offset;
for (size_t i = 0; i < ids_vector.size(); ++i) { std::vector<std::string> in_var_names;
id_to_offset[ids_vector[i]].push_back(i); std::vector<std::string> out_var_names;
for (size_t i = 0; i < tables.size(); ++i) {
in_var_names.push_back("prefetch_send@" + tables[i].second);
out_var_names.push_back("prefetch_recv@" + tables[i].second);
} }
auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>(); auto splited_ids = SplitIds(ids, height_sections);
auto* out_tensor = SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
scope->FindVar(out_name)->GetMutable<framework::LoDTensor>(); local_scope.get());
// create output var in local scope
for (auto& name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
}
PADDLE_ENFORCE_GT( distributed::RPCClient* rpc_client =
out_tensor->numel(), 0, distributed::RPCClient::GetInstance<RPCCLIENT_T>(
"When calling this method, the LoDTensor's numel must larger than zero. " context.Attr<int>("trainer_id"));
"Please check LoDTensor::Resize has been called first.");
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place()); std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << tables[i].second
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
tables[i].second, actual_ctx, *local_scope.get(), in_var_names[i],
out_var_names[i], tables[i].first));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
}
bool is_on_cpu_place = true; for (size_t i = 0; i < rets.size(); i++) {
if (!platform::is_cpu_place(id_tensor.place())) { PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
is_on_cpu_place = false;
} }
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), "");
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size(); for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) { ++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx]; auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) { if (!ids_in_this_section.empty()) {
auto& prefetch_out_var = auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx])
scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>(); ->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>(); const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims(); auto& dims = prefetch_out_var.dims();
...@@ -128,26 +149,9 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -128,26 +149,9 @@ static void MergeMultipleVarsIntoOneBySection(
for (int64_t i = 0; i < dims[0]; ++i) { for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i]; auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx]; auto origin_id = id + abs_sections[section_idx];
auto& offsets = id_to_offset[origin_id]; std::vector<float> vecs(row_numel);
for (auto& offset : offsets) { std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
// should support GPU tensor (*recved_vec_map)[origin_id] = vecs;
if (is_on_cpu_place) {
memory::Copy(cpu_place, out_tensor_data + offset * row_numel,
cpu_place, out_var_data + i * row_numel,
sizeof(float) * row_numel);
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto stream =
static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(id_tensor.place()),
out_tensor_data + offset * row_numel, cpu_place,
out_var_data + i * row_numel,
sizeof(float) * row_numel, stream);
#endif
}
}
} }
} else { } else {
VLOG(3) << "ids in this section is empty"; VLOG(3) << "ids in this section is empty";
...@@ -156,84 +160,107 @@ static void MergeMultipleVarsIntoOneBySection( ...@@ -156,84 +160,107 @@ static void MergeMultipleVarsIntoOneBySection(
} }
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope) { const framework::Scope& scope) {
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope(); prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names,
endpoints, height_sections, context, scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); }
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
auto& actual_ctx = *pool.Get(context.GetPlace());
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
std::vector<std::string> in_var_names; void prefetchs(const std::vector<std::string>& id_var_names,
std::vector<std::string> out_var_names; const std::vector<std::string>& out_var_names,
for (size_t i = 0; i < epmap.size(); ++i) { const std::string& persistable_var_name, const bool backfill,
in_var_names.push_back(id_name + "@" + epmap[i]); const std::vector<std::string>& table_names,
out_var_names.push_back(out_name + "@" + epmap[i]); const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
auto* reconstruct_var =
scope.FindVar(persistable_var_name)->GetMutable<framework::LoDTensor>();
const auto vec_dim_1 = reconstruct_var->dims()[1];
const auto place =
scope.FindVar(id_var_names[0])->Get<framework::LoDTensor>().place();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("multi prefetch only support CPU currently");
} }
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>(); std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_vector; std::vector<int64_t> ids_union;
if (platform::is_cpu_place(id_tensor.place())) { std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables;
for (auto& id_name : id_var_names) {
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* id_data = id_tensor.data<int64_t>(); auto* id_data = id_tensor.data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor.numel(); ++i) { for (int64_t i = 0; i < id_tensor.numel(); ++i) {
ids_vector.push_back(id_data[i]); ids.push_back(id_data[i]);
} ids_union.push_back(id_data[i]);
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto cpu_place = platform::CPUPlace();
framework::LoDTensor cpu_tensor;
auto* cpu_tensor_data =
cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(cpu_place, cpu_tensor_data,
boost::get<platform::CUDAPlace>(id_tensor.place()),
id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
stream);
for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
ids_vector.push_back(cpu_tensor_data[i]);
} }
#endif ids_group.push_back(ids);
ids_lods.push_back(id_tensor.lod());
} }
auto splited_ids = SplitIds(ids_vector, height_sections); std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids, ids_union.assign(s.begin(), s.end());
local_scope.get());
// create output var in local scope for (int i; i < table_names.size(); i++) {
for (auto& name : out_var_names) { tables.push_back(std::make_pair(table_names[i], endpoints[i]));
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
} }
std::vector<distributed::VarHandlePtr> rets; std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
for (size_t i = 0; i < in_var_names.size(); i++) { prefetch_core(ids_union, tables, height_sections, context, scope,
if (NeedSend(*local_scope.get(), in_var_names[i])) { &recved_vec_map);
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back"; auto padding_idx = distributed::kNoPadding;
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, *local_scope.get(), in_var_names[i], if (context.HasAttr("padding_idx")) {
out_var_names[i], table_names[i])); padding_idx = context.Attr<int64_t>("padding_idx");
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
} }
for (size_t i = 0; i < rets.size(); i++) { // copy vectors to out vars
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); for (int i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i];
auto* out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->Resize(
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto* out_d = out_t->mutable_data<float>(place);
for (int idx = 0; idx < ids.size(); idx++) {
const auto& id = ids[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else {
std::copy_n(recved_vec_map[id].begin(), vec_dim_1,
out_d + idx * vec_dim_1);
}
}
} }
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name, if (backfill) {
out_var_names, height_sections, splited_ids, VLOG(3) << "backfill persistable var's id with vecs";
context, local_scope.get(), &actual_ctx);
auto* reconstruct_d = reconstruct_var->data<float>();
for (auto& id : ids_union) {
std::copy(recved_vec_map[id].begin(), recved_vec_map[id].end(),
reconstruct_d + id * vec_dim_1);
}
}
} }
}; // namespace distributed }; // namespace distributed
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -23,61 +24,25 @@ namespace paddle { ...@@ -23,61 +24,25 @@ namespace paddle {
namespace operators { namespace operators {
namespace distributed { namespace distributed {
constexpr int64_t kNoPadding = -1;
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
void prefetch(const std::string& id_name, const std::string& out_name, void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names, const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap, const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections, const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::ExecutionContext& context,
const framework::Scope& scope); const framework::Scope& scope);
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope,
framework::LoDTensor* original) {
prefetch(id_name, out_name, table_names, epmap, height_sections, context,
scope);
auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* original_value = original->data<T>();
auto* out_value = out.data<T>();
size_t original_width = original->numel() / original->dims()[0];
bool is_on_cpu_place = true;
if (!platform::is_cpu_place(ids.place())) {
is_on_cpu_place = false;
}
if (is_on_cpu_place) {
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row =
original_value + original_width * ids.data<int64_t>()[i];
std::memcpy(original_row, out_rows, original_width * sizeof(T));
}
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row =
original_value + original_width * ids.data<int64_t>()[i];
auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
platform::CPUPlace(), out_rows, original_width * sizeof(T),
stream);
}
#endif
}
}
}; // namespace distributed }; // namespace distributed
}; // namespace operators }; // namespace operators
}; // namespace paddle }; // namespace paddle
...@@ -116,42 +116,7 @@ bool RequestGetHandler::Handle(const std::string& varname, ...@@ -116,42 +116,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name; VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
} }
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && *outvar = scope_->FindVar(varname);
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (auto i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
} else {
*outvar = scope_->FindVar(varname);
}
} }
} }
return true; return true;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class DistributedLookupTableOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
"Output(Outs) of LookupTableOp should not be null.");
auto ids_dims = ctx->GetInputsDim("Ids");
auto table_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported.");
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2.");
PADDLE_ENFORCE_EQ(ids_dim[1], 1,
"The last dimension of the 'Ids' tensor must be 1.");
}
auto lookup_tables =
ctx->Attrs().Get<std::vector<std::string>>("table_names");
auto height_sections =
ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
lookup_tables.size() == endpoints.size() &&
lookup_tables.size() != 0,
"Attrs lookup_tables/height_sections/endpoints must have "
"save size and can not be 0.");
auto outputs_dims = std::vector<framework::DDim>();
for (auto &ids_dim : ids_dims) {
outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
}
ctx->SetOutputsDim("Outputs", outputs_dims);
ctx->ShareLoD("Ids", /*->*/ "Outputs");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto ids_vars = context.MultiInputVar("Ids");
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
auto id_names = context.Inputs("Ids");
auto embedding_name = context.Inputs("W").front();
auto out_names = context.Outputs("Outputs");
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
operators::distributed::prefetchs(
id_names, out_names, embedding_name, false, lookup_tables, endpoints,
height_sections, context, context.scope());
}
};
class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.")
.AsDuplicable();
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddOutput("Outputs",
"(LoDTensor) The lookup results, which have the same type as W.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, such as emb_block0, emb_block1)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({""});
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>(
"endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(distributed::kNoPadding);
AddComment(R"DOC(
Lookup Tablel Prefetch Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp,
ops::DistributedLookupTableOpMaker);
REGISTER_OP_CPU_KERNEL(distributed_lookup_table,
ops::DistributedLookupTableKernel<float>);
...@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>( distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id")); Attr<int>("trainer_id"));
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient"); std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep; VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep); rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames = std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames"); Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out"); auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier"); bool with_barrier = Attr<bool>("with_barrier");
...@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase { ...@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase {
trainer_id); trainer_id);
recv_functor(rpc_ctx, scope); recv_functor(rpc_ctx, scope);
} else { } else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) { if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
...@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase {
rets.push_back( rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i])); rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
} }
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else { } else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i]; std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with " VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
...@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase { ...@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope, rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i])); varname, outs[i]));
} }
for (size_t i = 0; i < rets.size(); i++) { }
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); for (size_t i = 0; i < rets.size(); i++) {
} VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
} }
} }
} }
...@@ -112,10 +108,6 @@ This operator can get variables from server side. ...@@ -112,10 +108,6 @@ This operator can get variables from server side.
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier", AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use " "(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately") "AsyncGetVarNoBarrier get variable from pserver immediately")
......
...@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync"; VLOG(3) << "SendBarrierOp sync";
// need to wait before sending send_barrier message std::vector<distributed::VarHandlePtr> rets;
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
} }
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
} }
}; };
......
...@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase { ...@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X"); auto ins = Inputs("X");
auto epmap = Attr<std::vector<std::string>>("epmap"); auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id"); auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames"); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
...@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase { ...@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i]; VLOG(3) << "don't send no-initialied variable: " << ins[i];
} }
} }
if (sync_send) { for (size_t i = 0; i < rets.size(); i++) {
for (size_t i = 0; i < rets.size(); i++) { VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i]; PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient"); VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
} }
} }
} }
...@@ -98,10 +95,6 @@ Send operator ...@@ -98,10 +95,6 @@ Send operator
This operator will send variables to listen_and_serve op at the parameter server. This operator will send variables to listen_and_serve op at the parameter server.
)DOC"); )DOC");
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap", AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
......
...@@ -97,10 +97,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -97,10 +97,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
// w_Out is set to used by prefetch, never change it in other cases // w_Out is set to used by prefetch, never change it in other cases
auto* w_out = ctx.Output<framework::LoDTensor>("W_Out"); auto weight = ctx.Outputs("W_Out").front();
operators::distributed::prefetch_with_reconstruct<T>( operators::distributed::prefetch("Ids@Prefetch", "W@Prefetch", weight,
"Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections, true, table_names, epmap,
ctx, local_scope, w_out); height_sections, ctx, local_scope);
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
......
...@@ -82,46 +82,27 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> { ...@@ -82,46 +82,27 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto id_name = context.Inputs("Ids").front(); auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front(); auto out_name = context.Outputs("Out").front();
// for remote prefetch size_t N = table_t->dims()[0];
auto epmap = context.Attr<std::vector<std::string>>("epmap"); size_t D = table_t->dims()[1];
auto height_sections = size_t K = ids_t->numel();
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
if (!epmap.empty()) { auto *output = output_t->mutable_data<T>(context.GetPlace());
// if epmap is not empty, then the parameter will be fetched from remote
// parameter dim3 threads(128, 8);
// server dim3 grids(8, 1);
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap, if (padding_idx == -1)
height_sections, context, LookupTable<
context.scope()); T, 128, 8, 8,
#else false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
PADDLE_THROW( output, table, ids, N, K, D, padding_idx);
"paddle is not compiled with distribute support, can not do " else
"parameter prefetch!"); LookupTable<
#endif T, 128, 8, 8,
} else { true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
size_t N = table_t->dims()[0]; output, table, ids, N, K, D, padding_idx);
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<T, 128, 8, 8, false><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<T, 128, 8, 8, true><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
} }
}; };
......
...@@ -46,6 +46,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -46,6 +46,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *table_var = context.InputVar("W"); auto *table_var = context.InputVar("W");
auto id_name = context.Inputs("Ids").front(); auto id_name = context.Inputs("Ids").front();
auto embedding_name = context.Inputs("W").front();
auto out_name = context.Outputs("Out").front(); auto out_name = context.Outputs("Out").front();
// for remote prefetch // for remote prefetch
...@@ -57,12 +58,12 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -57,12 +58,12 @@ class LookupTableKernel : public framework::OpKernel<T> {
if (remote_prefetch && !epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter server
// server
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap, operators::distributed::prefetch(id_name, out_name, embedding_name, false,
height_sections, context, table_names, epmap, height_sections,
context.scope()); context, context.scope());
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
......
...@@ -195,9 +195,10 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -195,9 +195,10 @@ class NCEKernel : public framework::OpKernel<T> {
w_tensor->Resize(framework::make_ddim(w_dims)); w_tensor->Resize(framework::make_ddim(w_dims));
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
auto weight = context.Inputs("Weight").front();
operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch", operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
table_names, epmap, height_sections, weight, false, table_names, epmap,
context, local_scope); height_sections, context, local_scope);
#else #else
PADDLE_THROW( PADDLE_THROW(
"paddle is not compiled with distribute support, can not do " "paddle is not compiled with distribute support, can not do "
......
...@@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel<T> { ...@@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
void SaveSelectedRows(const framework::ExecutionContext &ctx, void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place, const platform::Place &place,
const framework::Variable *var) const { const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
auto file_path = ctx.Attr<std::string>("file_path"); auto file_path = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite"); auto overwrite = ctx.Attr<bool>("overwrite");
std::string filename = file_path; std::string filename = file_path;
VLOG(4) << "SaveSelectedRows output file_path: " << file_path;
framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH);
if (out_put_var != nullptr) { if (out_put_var != nullptr) {
auto *lt_var = out_put_var->GetMutable<std::string>(); auto *lt_var = out_put_var->GetMutable<std::string>();
filename = *lt_var; if (lt_var->length() > 0) {
VLOG(4) << "SaveSelectedRows output var name: " << *lt_var;
filename = *lt_var;
}
} }
if (FileExists(filename) && !overwrite) { if (FileExists(filename) && !overwrite) {
......
...@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) { ...@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) {
return Communicator::GetInstantcePtr(); return Communicator::GetInstantcePtr();
})) }))
.def("stop", &Communicator::Stop) .def("stop", &Communicator::Stop)
.def("start", &Communicator::Start); .def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning);
} }
} // namespace pybind } // namespace pybind
......
...@@ -86,3 +86,21 @@ class Communicator(object): ...@@ -86,3 +86,21 @@ class Communicator(object):
comm.stop() comm.stop()
""" """
self.communicator_.stop() self.communicator_.stop()
def is_running(self):
"""
Get communicator is running or stop.
Returns:
bool
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.is_running()
"""
self.communicator_.is_running()
...@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program, ...@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program,
}) })
sums.append(param_var) sums.append(param_var)
global_block.append_op( global_block.append_op(
type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={}) type='merge_sparse_lookup_table',
inputs={"X": sums},
outputs={'Out': emb_var},
attrs={})
global_block.append_op(
type='save',
inputs={"X": [emb_var]},
outputs={},
attrs={
'file_path': os.path.join(lookup_table_dirname, emb_var.name)
})
global_block.append_op(type='delete_var', inputs={'X': sums}) global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program) executor.run(convert_program)
......
...@@ -312,7 +312,7 @@ class MultiSlotDataGenerator(DataGenerator): ...@@ -312,7 +312,7 @@ class MultiSlotDataGenerator(DataGenerator):
) )
if name != self._proto_info[index][0]: if name != self._proto_info[index][0]:
raise ValueError( raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>." "the field name of two given line are not match: require<%s>, get<%s>."
% (self._proto_info[index][0], name)) % (self._proto_info[index][0], name))
if output: if output:
output += " " output += " "
......
...@@ -158,7 +158,7 @@ class MPIRoleMaker(RoleMakerBase): ...@@ -158,7 +158,7 @@ class MPIRoleMaker(RoleMakerBase):
""" """
finalize the current MPI instance. finalize the current MPI instance.
""" """
pass self.MPI.Finalize()
def _get_ips(self): def _get_ips(self):
""" """
...@@ -316,6 +316,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -316,6 +316,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints) print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
self.endpoints = self.endpoints.split(",") self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints self._server_endpoints = self.endpoints
self._worker_endpoints = self.endpoints
if self.role.upper() == "PSERVER": if self.role.upper() == "PSERVER":
self._current_id = self.endpoints.index(self.current_endpoint) self._current_id = self.endpoints.index(self.current_endpoint)
self._role = Role.SERVER self._role = Role.SERVER
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import warnings
import paddle.fluid.io as io import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator from paddle.fluid.communicator import Communicator
...@@ -25,6 +26,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo ...@@ -25,6 +26,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
class DistributedTranspiler(Fleet): class DistributedTranspiler(Fleet):
...@@ -51,9 +53,20 @@ class DistributedTranspiler(Fleet): ...@@ -51,9 +53,20 @@ class DistributedTranspiler(Fleet):
Returns: Returns:
None None
""" """
# if MPISymetricRoleMaker is defined
# we suppose a user wants to submit job on mpi cluster
if isinstance(self._role_maker, MPISymetricRoleMaker):
# check whether server has been initialized
from paddle.fluid.transpiler.details.checkport import wait_server_ready
wait_server_ready(fleet.server_endpoints(to_string=False))
if not self._transpile_config.sync_mode: if not self._transpile_config.sync_mode:
self._communicator = Communicator(self.main_program) self._communicator = Communicator(self.main_program)
self._communicator.start()
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
def init_server(self, model_dir=None): def init_server(self, model_dir=None):
""" """
...@@ -104,10 +117,14 @@ class DistributedTranspiler(Fleet): ...@@ -104,10 +117,14 @@ class DistributedTranspiler(Fleet):
Returns: Returns:
None None
""" """
if not self._transpile_config.sync_mode: if not self._transpile_config.sync_mode and self._communicator.is_running(
):
self._communicator.stop() self._communicator.stop()
self._executor.close() self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker):
self._role_maker._finalize()
def distributed_optimizer(self, optimizer, strategy=None): def distributed_optimizer(self, optimizer, strategy=None):
""" """
Optimizer for distributed training. Optimizer for distributed training.
...@@ -193,13 +210,23 @@ class DistributedTranspiler(Fleet): ...@@ -193,13 +210,23 @@ class DistributedTranspiler(Fleet):
self._transpile_config = config self._transpile_config = config
self._transpiler = OriginTranspiler(config) self._transpiler = OriginTranspiler(config)
print("server endpoints")
print(fleet.server_endpoints(to_string=True))
print("worker index: %d" % fleet.worker_index())
print("worker num: %d" % fleet.worker_num())
if self.is_worker(): if self.is_worker():
self._transpiler.transpile( self._transpiler.transpile(
trainer_id=fleet.worker_index(), trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True), pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(), trainers=fleet.worker_num(),
sync_mode=config.sync_mode) sync_mode=config.sync_mode)
self.main_program = self._transpiler.get_trainer_program()
if isinstance(self._role_maker, MPISymetricRoleMaker):
config.wait_port = False
self.main_program = self._transpiler.get_trainer_program(
wait_port=config.wait_port)
self.startup_program = default_startup_program() self.startup_program = default_startup_program()
else: else:
self._transpiler.transpile( self._transpiler.transpile(
......
...@@ -88,13 +88,21 @@ def is_persistable(var): ...@@ -88,13 +88,21 @@ def is_persistable(var):
def _clone_var_in_block_(block, var): def _clone_var_in_block_(block, var):
assert isinstance(var, Variable) assert isinstance(var, Variable)
return block.create_var( if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
name=var.name, return block.create_var(
shape=var.shape, name=var.name,
dtype=var.dtype, shape=var.shape,
type=var.type, dtype=var.dtype,
lod_level=var.lod_level, type=var.type,
persistable=True) lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def save_vars(executor, def save_vars(executor,
......
...@@ -15,12 +15,51 @@ ...@@ -15,12 +15,51 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import time
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator from paddle.fluid.communicator import Communicator
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
class TestCommunicator(unittest.TestCase): class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def test_communicator_init_and_start(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(0.01)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = True
strategy.wait_port = False
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
comm = Communicator(fleet.main_program)
comm.start()
time.sleep(10)
comm.stop()
class TestCommunicator2(unittest.TestCase):
def test_communicator_init_and_start(self): def test_communicator_init_and_start(self):
prog = fluid.Program() prog = fluid.Program()
comm = Communicator(prog) comm = Communicator(prog)
......
...@@ -18,6 +18,18 @@ import unittest ...@@ -18,6 +18,18 @@ import unittest
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
def skip_ci(func):
on_ci = bool(int(os.environ.get("SKIP_UNSTABLE_CI", '0')))
def __func__(*args, **kwargs):
if on_ci:
return
return func(*args, **kwargs)
return __func__
@skip_ci
class TestDistCTR2x2(TestDistBase): class TestDistCTR2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -27,6 +39,7 @@ class TestDistCTR2x2(TestDistBase): ...@@ -27,6 +39,7 @@ class TestDistCTR2x2(TestDistBase):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
@skip_ci
class TestDistCTRWithL2Decay2x2(TestDistBase): class TestDistCTRWithL2Decay2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -37,7 +50,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase): ...@@ -37,7 +50,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
self.check_with_place( self.check_with_place(
"dist_ctr.py", "dist_ctr.py",
delta=1e-7, delta=1e-7,
check_error_log=False, check_error_log=True,
need_envs=need_envs) need_envs=need_envs)
......
...@@ -19,6 +19,18 @@ import unittest ...@@ -19,6 +19,18 @@ import unittest
from test_dist_fleet_base import TestFleetBase from test_dist_fleet_base import TestFleetBase
def skip_ci(func):
on_ci = bool(int(os.environ.get("SKIP_UNSTABLE_CI", '0')))
def __func__(*args, **kwargs):
if on_ci:
return
return func(*args, **kwargs)
return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase): class TestDistMnist2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
......
...@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase ...@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
class TestDistW2V2x2(TestDistBase): class TestDistW2V2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4) self.check_with_place("dist_word2vec.py", delta=1e-4)
...@@ -29,6 +30,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase): ...@@ -29,6 +30,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
self._mem_opt = True self._mem_opt = True
self._enforce_place = "CPU"
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4) self.check_with_place("dist_word2vec.py", delta=1e-4)
...@@ -37,6 +39,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase): ...@@ -37,6 +39,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase):
class TestDistW2V2x2Async(TestDistBase): class TestDistW2V2x2Async(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU"
def test_dist_train(self): def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=100) self.check_with_place("dist_word2vec.py", delta=100)
......
...@@ -185,8 +185,6 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -185,8 +185,6 @@ class TestListenAndServOp(unittest.TestCase):
port1 = self._get_pserver_port(p1.pid) port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places: for place in places:
self._run_lookup_table_op_one_pserver(place, port0) self._run_lookup_table_op_one_pserver(place, port0)
......
...@@ -314,14 +314,49 @@ class DistributeTranspiler(object): ...@@ -314,14 +314,49 @@ class DistributeTranspiler(object):
sparse_update_ops.append(op) sparse_update_ops.append(op)
return sparse_update_ops return sparse_update_ops
def _update_remote_sparse_update_op(self, param_varname, height_sections, def _update_remote_sparse_update_op(self, program, param_varname,
endpint_map, table_names): height_sections, endpoints,
table_names):
ops = []
op_type = ""
for op in self.sparse_update_ops: for op in self.sparse_update_ops:
if param_varname in op.input_arg_names: if param_varname in op.input_arg_names and op_type == "":
op._set_attr('epmap', endpint_map) op_type = op.type
op._set_attr('table_names', table_names) ops.append(op)
op._set_attr('height_sections', height_sections)
op._set_attr('trainer_id', self.trainer_id) elif param_varname in op.input_arg_names and op_type == op.type:
ops.append(op)
if op_type == "lookup_table":
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
program.global_block()._insert_op(
index=op_idxs[0],
type="distributed_lookup_table",
inputs={"Ids": inputs,
'W': w},
outputs={"Outputs": outputs},
attrs={
"table_names": table_names,
"height_sections": height_sections,
"endpoints": endpoints,
"padding_idx": padding_idx,
"trainer_id": self.trainer_id
})
def _is_input_of_remote_sparse_update_op(self, param_name): def _is_input_of_remote_sparse_update_op(self, param_name):
for op in self.sparse_update_ops: for op in self.sparse_update_ops:
...@@ -456,17 +491,12 @@ class DistributeTranspiler(object): ...@@ -456,17 +491,12 @@ class DistributeTranspiler(object):
splited_grad_varname = splited_vars[0].name splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg( index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True) program.global_block(), splited_grad_varname, reverse=True)
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[
grad_varname]
if self._is_input_of_remote_sparse_update_op(
sparse_param_name):
self.sparse_param_to_height_sections[
sparse_param_name] = [splited_vars[0].shape[0]]
elif len(splited_vars) > 1: elif len(splited_vars) > 1:
orig_var = program.global_block().vars[splited_grad_varname] orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg( index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True) program.global_block(), splited_grad_varname, reverse=True)
if not self.config.runtime_split_send_recv: if not self.config.runtime_split_send_recv:
self._insert_split_op(program, orig_var, index, self._insert_split_op(program, orig_var, index,
splited_vars) splited_vars)
...@@ -475,6 +505,13 @@ class DistributeTranspiler(object): ...@@ -475,6 +505,13 @@ class DistributeTranspiler(object):
AssertionError("Can not insert the send op by original " AssertionError("Can not insert the send op by original "
"variable name :", splited_grad_varname) "variable name :", splited_grad_varname)
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[grad_varname]
if self._is_input_of_remote_sparse_update_op(sparse_param_name):
self.sparse_param_to_height_sections[sparse_param_name] = [
splited_var.shape[0] for splited_var in splited_vars
]
dummy_output = program.global_block().create_var( dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
...@@ -507,8 +544,7 @@ class DistributeTranspiler(object): ...@@ -507,8 +544,7 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: [ OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname], self.grad_name_to_param_name[grad_varname],
splited_grad_varname splited_grad_varname
], ]
"sync_mode": not self.sync_mode,
}) })
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
send_vars.append(var) send_vars.append(var)
...@@ -528,7 +564,6 @@ class DistributeTranspiler(object): ...@@ -528,7 +564,6 @@ class DistributeTranspiler(object):
outputs={"Out": send_barrier_out}, outputs={"Out": send_barrier_out},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
...@@ -574,7 +609,6 @@ class DistributeTranspiler(object): ...@@ -574,7 +609,6 @@ class DistributeTranspiler(object):
recv_op_role_var_name = splited_trainer_grad[0].name recv_op_role_var_name = splited_trainer_grad[0].name
if param_varname in self.sparse_param_to_height_sections: if param_varname in self.sparse_param_to_height_sections:
for table_name in table_names: for table_name in table_names:
distributed_var = self.vars_overview.get_distributed_var_by_slice( distributed_var = self.vars_overview.get_distributed_var_by_slice(
table_name) table_name)
...@@ -583,7 +617,7 @@ class DistributeTranspiler(object): ...@@ -583,7 +617,7 @@ class DistributeTranspiler(object):
height_sections = self.sparse_param_to_height_sections[ height_sections = self.sparse_param_to_height_sections[
param_varname] param_varname]
self._update_remote_sparse_update_op( self._update_remote_sparse_update_op(
param_varname, height_sections, eps, table_names) program, param_varname, height_sections, eps, table_names)
else: else:
recv_varnames = [] recv_varnames = []
if self.config.runtime_split_send_recv: if self.config.runtime_split_send_recv:
...@@ -602,8 +636,7 @@ class DistributeTranspiler(object): ...@@ -602,8 +636,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name], [param_varname, recv_op_role_var_name]
"sync_mode": not self.sync_mode
}) })
if self.sync_mode: if self.sync_mode:
...@@ -1481,7 +1514,6 @@ class DistributeTranspiler(object): ...@@ -1481,7 +1514,6 @@ class DistributeTranspiler(object):
if self.sync_mode else [] if self.sync_mode else []
}, },
attrs={ attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints, "epmap": pserver_endpoints,
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册