未验证 提交 416922e2 编写于 作者: T tangwei12 提交者: GitHub

Distributed training cherry-pick for Release 1.5 (#19486)

* fix bug in Class MultiSlotDataGenerator's function _gen_str, test=develop (#18222)
* fix some bug when merge sparse embedding parameters, test=develop (#18223)
* fix communicator with pyreader (#18350)
* delete AllocatorFacade destructor  (#18606)
* fix distribute transpiler GRPC error code 4, RPC Deadline (#18984)
* merge pr #18441
上级 0edeb838
......@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
}
#endif
}
......
......@@ -133,13 +133,6 @@ class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
VLOG(1) << "set recv op do_not_run to true";
node->Op()->SetAttr("do_not_run", 1);
node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// in async_mode, we do not need remote prefetch, because communicator
// will do async parameter recv.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
}
return false;
}
......
......@@ -248,6 +248,8 @@ class ExecutionContext {
return op_.Attr<T>(name);
}
bool HasAttr(const std::string& name) const { return op_.HasAttr(name); }
bool HasInput(const std::string& name) const;
bool HasOutput(const std::string& name) const;
......
......@@ -295,7 +295,9 @@ class AllocatorFacadePrivate {
// Pimpl. Make interface clean.
AllocatorFacade::AllocatorFacade() : m_(new AllocatorFacadePrivate()) {}
AllocatorFacade::~AllocatorFacade() { delete m_; }
// delete m_ may cause core dump when the destructor of python in conflict with
// cpp.
AllocatorFacade::~AllocatorFacade() {}
AllocatorFacade& AllocatorFacade::Instance() {
static AllocatorFacade instance;
......
......@@ -73,14 +73,26 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
VLOG(0) << "communicator_max_merge_var_num: "
<< FLAGS_communicator_max_merge_var_num;
VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
if (send_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be send, will not start send_thread";
} else {
send_scope_.reset(new Scope());
for (auto &iter : send_varname_to_ctx_) {
send_varname_to_queue_[iter.first] =
std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(
FLAGS_communicator_send_queue_size);
}
send_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
if (recv_varname_to_ctx.size() == 0) {
VLOG(0) << "nothing need to be received, will not start recv_thread";
} else {
recv_threadpool_.reset(
new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
send_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
recv_threadpool_.reset(new ::ThreadPool(FLAGS_communicator_thread_pool_size));
}
Communicator::~Communicator() {
......@@ -157,18 +169,28 @@ void Communicator::SendThread() {
task_f.wait();
}
auto after_run_send_graph = GetCurrentUS();
auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
if (send_graph_use_time > 100) {
VLOG(1) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
}
if (!FLAGS_communicator_independent_recv_thread) {
RecvAll();
}
VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
RecvNonIndependent();
}
VLOG(0) << "communicator stopped, send thread exit";
}
void Communicator::RecvNonIndependent() {
if (!FLAGS_communicator_independent_recv_thread) {
return;
}
auto grad_num = grad_num_.load();
if (grad_num > 0) {
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
if (!running_) return;
......
......@@ -167,12 +167,15 @@ class Communicator {
void Start();
void Stop();
bool IsRunning() { return running_; }
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
private:
// recv all parameter
void RecvAll();
void RecvNonIndependent();
void SendThread();
void RecvThread();
......
......@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
......@@ -78,45 +80,64 @@ static void SplitIdsIntoMultipleVarsBySection(
}
}
static void MergeMultipleVarsIntoOneBySection(
const std::string& id_name, const std::vector<int64_t>& ids_vector,
const std::string& out_name, const std::vector<std::string>& out_var_names,
const std::vector<int64_t>& height_section,
const std::vector<std::vector<int64_t>>& splited_ids,
const framework::ExecutionContext& context, framework::Scope* scope,
platform::DeviceContext* actual_ctx) {
PADDLE_ENFORCE_EQ(out_var_names.size(), height_section.size(), "");
typedef std::vector<std::pair<std::string, std::string>> TableAndEndpoints;
auto cpu_place = platform::CPUPlace();
void prefetch_core(
const std::vector<int64_t>& ids, const TableAndEndpoints& tables,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context, const framework::Scope& scope,
std::unordered_map<int64_t, std::vector<float>>* recved_vec_map) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
auto abs_sections = ToAbsoluteSection(height_section);
std::unordered_map<int64_t, std::vector<size_t>> id_to_offset;
for (size_t i = 0; i < ids_vector.size(); ++i) {
id_to_offset[ids_vector[i]].push_back(i);
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < tables.size(); ++i) {
in_var_names.push_back("prefetch_send@" + tables[i].second);
out_var_names.push_back("prefetch_recv@" + tables[i].second);
}
auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
auto* out_tensor =
scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
auto splited_ids = SplitIds(ids, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope.get());
// create output var in local scope
for (auto& name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
}
PADDLE_ENFORCE_GT(
out_tensor->numel(), 0,
"When calling this method, the LoDTensor's numel must larger than zero. "
"Please check LoDTensor::Resize has been called first.");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << tables[i].second
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
tables[i].second, actual_ctx, *local_scope.get(), in_var_names[i],
out_var_names[i], tables[i].first));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
}
bool is_on_cpu_place = true;
if (!platform::is_cpu_place(id_tensor.place())) {
is_on_cpu_place = false;
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
PADDLE_ENFORCE_EQ(out_var_names.size(), height_sections.size(), "");
auto abs_sections = ToAbsoluteSection(height_sections);
for (size_t section_idx = 0; section_idx < out_var_names.size();
++section_idx) {
auto& ids_in_this_section = splited_ids[section_idx];
if (!ids_in_this_section.empty()) {
auto& prefetch_out_var =
scope->Var(out_var_names[section_idx])->Get<framework::LoDTensor>();
auto& prefetch_out_var = local_scope->Var(out_var_names[section_idx])
->Get<framework::LoDTensor>();
const auto* out_var_data = prefetch_out_var.data<float>();
auto& dims = prefetch_out_var.dims();
......@@ -128,26 +149,9 @@ static void MergeMultipleVarsIntoOneBySection(
for (int64_t i = 0; i < dims[0]; ++i) {
auto id = ids_in_this_section[i];
auto origin_id = id + abs_sections[section_idx];
auto& offsets = id_to_offset[origin_id];
for (auto& offset : offsets) {
// should support GPU tensor
if (is_on_cpu_place) {
memory::Copy(cpu_place, out_tensor_data + offset * row_numel,
cpu_place, out_var_data + i * row_numel,
sizeof(float) * row_numel);
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto stream =
static_cast<platform::CUDADeviceContext*>(actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(id_tensor.place()),
out_tensor_data + offset * row_numel, cpu_place,
out_var_data + i * row_numel,
sizeof(float) * row_numel, stream);
#endif
}
}
std::vector<float> vecs(row_numel);
std::copy_n(out_var_data + i * row_numel, row_numel, vecs.begin());
(*recved_vec_map)[origin_id] = vecs;
}
} else {
VLOG(3) << "ids in this section is empty";
......@@ -156,84 +160,107 @@ static void MergeMultipleVarsIntoOneBySection(
}
void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& cpu_ctx = *pool.Get(platform::CPUPlace());
auto& actual_ctx = *pool.Get(context.GetPlace());
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
context.Attr<int>("trainer_id"));
prefetchs({id_name}, {out_name}, persistable_var_name, backfill, table_names,
endpoints, height_sections, context, scope);
}
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (size_t i = 0; i < epmap.size(); ++i) {
in_var_names.push_back(id_name + "@" + epmap[i]);
out_var_names.push_back(out_name + "@" + epmap[i]);
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope) {
PADDLE_ENFORCE_GT(id_var_names.size(), 0, "");
PADDLE_ENFORCE_EQ(id_var_names.size(), out_var_names.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), endpoints.size(), "");
PADDLE_ENFORCE_EQ(table_names.size(), height_sections.size(), "");
auto* reconstruct_var =
scope.FindVar(persistable_var_name)->GetMutable<framework::LoDTensor>();
const auto vec_dim_1 = reconstruct_var->dims()[1];
const auto place =
scope.FindVar(id_var_names[0])->Get<framework::LoDTensor>().place();
if (!platform::is_cpu_place(place)) {
PADDLE_THROW("multi prefetch only support CPU currently");
}
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
std::vector<int64_t> ids_vector;
if (platform::is_cpu_place(id_tensor.place())) {
std::vector<std::vector<int64_t>> ids_group;
std::vector<int64_t> ids_union;
std::vector<framework::LoD> ids_lods;
TableAndEndpoints tables;
for (auto& id_name : id_var_names) {
auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* id_data = id_tensor.data<int64_t>();
std::vector<int64_t> ids;
for (int64_t i = 0; i < id_tensor.numel(); ++i) {
ids_vector.push_back(id_data[i]);
}
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
auto cpu_place = platform::CPUPlace();
framework::LoDTensor cpu_tensor;
auto* cpu_tensor_data =
cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(cpu_place, cpu_tensor_data,
boost::get<platform::CUDAPlace>(id_tensor.place()),
id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
stream);
for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
ids_vector.push_back(cpu_tensor_data[i]);
ids.push_back(id_data[i]);
ids_union.push_back(id_data[i]);
}
#endif
ids_group.push_back(ids);
ids_lods.push_back(id_tensor.lod());
}
auto splited_ids = SplitIds(ids_vector, height_sections);
SplitIdsIntoMultipleVarsBySection(in_var_names, height_sections, splited_ids,
local_scope.get());
std::unordered_set<int64_t> s(ids_union.begin(), ids_union.end());
ids_union.assign(s.begin(), s.end());
// create output var in local scope
for (auto& name : out_var_names) {
local_scope->Var(name)->GetMutable<framework::LoDTensor>();
for (int i; i < table_names.size(); i++) {
tables.push_back(std::make_pair(table_names[i], endpoints[i]));
}
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < in_var_names.size(); i++) {
if (NeedSend(*local_scope.get(), in_var_names[i])) {
VLOG(3) << "sending " << in_var_names[i] << " to " << epmap[i]
<< " to get " << out_var_names[i] << " back";
rets.push_back(rpc_client->AsyncPrefetchVar(
epmap[i], cpu_ctx, *local_scope.get(), in_var_names[i],
out_var_names[i], table_names[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << out_var_names[i];
}
std::unordered_map<int64_t, std::vector<float>> recved_vec_map;
prefetch_core(ids_union, tables, height_sections, context, scope,
&recved_vec_map);
auto padding_idx = distributed::kNoPadding;
if (context.HasAttr("padding_idx")) {
padding_idx = context.Attr<int64_t>("padding_idx");
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
// copy vectors to out vars
for (int i = 0; i < out_var_names.size(); i++) {
auto& ids = ids_group[i];
auto* out_t =
scope.FindVar(out_var_names[i])->GetMutable<framework::LoDTensor>();
out_t->Resize(
framework::make_ddim({static_cast<int64_t>(ids.size()), vec_dim_1}));
out_t->set_lod(ids_lods[i]);
auto* out_d = out_t->mutable_data<float>(place);
for (int idx = 0; idx < ids.size(); idx++) {
const auto& id = ids[idx];
if (padding_idx != distributed::kNoPadding && id == padding_idx) {
memset(out_d + idx * vec_dim_1, 0, sizeof(float) * vec_dim_1);
} else {
std::copy_n(recved_vec_map[id].begin(), vec_dim_1,
out_d + idx * vec_dim_1);
}
}
}
MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
out_var_names, height_sections, splited_ids,
context, local_scope.get(), &actual_ctx);
if (backfill) {
VLOG(3) << "backfill persistable var's id with vecs";
auto* reconstruct_d = reconstruct_var->data<float>();
for (auto& id : ids_union) {
std::copy(recved_vec_map[id].begin(), recved_vec_map[id].end(),
reconstruct_d + id * vec_dim_1);
}
}
}
}; // namespace distributed
......
......@@ -15,6 +15,7 @@
#pragma once
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/operator.h"
......@@ -23,61 +24,25 @@ namespace paddle {
namespace operators {
namespace distributed {
constexpr int64_t kNoPadding = -1;
void prefetchs(const std::vector<std::string>& id_var_names,
const std::vector<std::string>& out_var_names,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
void prefetch(const std::string& id_name, const std::string& out_name,
const std::string& persistable_var_name, const bool backfill,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<std::string>& endpoints,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope);
template <typename T>
void prefetch_with_reconstruct(const std::string& id_name,
const std::string& out_name,
const std::vector<std::string>& table_names,
const std::vector<std::string>& epmap,
const std::vector<int64_t>& height_sections,
const framework::ExecutionContext& context,
const framework::Scope& scope,
framework::LoDTensor* original) {
prefetch(id_name, out_name, table_names, epmap, height_sections, context,
scope);
auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
auto* original_value = original->data<T>();
auto* out_value = out.data<T>();
size_t original_width = original->numel() / original->dims()[0];
bool is_on_cpu_place = true;
if (!platform::is_cpu_place(ids.place())) {
is_on_cpu_place = false;
}
if (is_on_cpu_place) {
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row =
original_value + original_width * ids.data<int64_t>()[i];
std::memcpy(original_row, out_rows, original_width * sizeof(T));
}
} else {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("paddle is not compiled with CUDA!");
#else
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& actual_ctx = *pool.Get(context.GetPlace());
for (int64_t i = 0; i < ids.numel(); i++) {
const T* out_rows = out_value + original_width * i;
T* original_row =
original_value + original_width * ids.data<int64_t>()[i];
auto stream =
static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
platform::CPUPlace(), out_rows, original_width * sizeof(T),
stream);
}
#endif
}
}
}; // namespace distributed
}; // namespace operators
}; // namespace paddle
......@@ -116,42 +116,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (auto i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
} else {
*outvar = scope_->FindVar(varname);
}
*outvar = scope_->FindVar(varname);
}
}
return true;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
class DistributedLookupTableOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
"Output(Outs) of LookupTableOp should not be null.");
auto ids_dims = ctx->GetInputsDim("Ids");
auto table_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported.");
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2.");
PADDLE_ENFORCE_EQ(ids_dim[1], 1,
"The last dimension of the 'Ids' tensor must be 1.");
}
auto lookup_tables =
ctx->Attrs().Get<std::vector<std::string>>("table_names");
auto height_sections =
ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
lookup_tables.size() == endpoints.size() &&
lookup_tables.size() != 0,
"Attrs lookup_tables/height_sections/endpoints must have "
"save size and can not be 0.");
auto outputs_dims = std::vector<framework::DDim>();
for (auto &ids_dim : ids_dims) {
outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
}
ctx->SetOutputsDim("Outputs", outputs_dims);
ctx->ShareLoD("Ids", /*->*/ "Outputs");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class DistributedLookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto ids_vars = context.MultiInputVar("Ids");
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
auto id_names = context.Inputs("Ids");
auto embedding_name = context.Inputs("W").front();
auto out_names = context.Outputs("Outputs");
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
operators::distributed::prefetchs(
id_names, out_names, embedding_name, false, lookup_tables, endpoints,
height_sections, context, context.scope());
}
};
class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Ids",
"(LoDTensor) Ids's type should be LoDTensor"
"THe ids to be looked up in W.")
.AsDuplicable();
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddOutput("Outputs",
"(LoDTensor) The lookup results, which have the same type as W.")
.AsDuplicable();
AddAttr<std::vector<std::string>>(
"table_names",
"(string vector, such as emb_block0, emb_block1)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({""});
AddAttr<std::vector<int64_t>>("height_sections",
"Height for each output SelectedRows.")
.SetDefault(std::vector<int64_t>({}));
AddAttr<std::vector<std::string>>(
"endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input variables for mapping")
.SetDefault({"127.0.0.1:6164"});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(distributed::kNoPadding);
AddComment(R"DOC(
Lookup Tablel Prefetch Operator.
This operator is used to perform lookup on parameter W,
then concatenated into a sparse tensor.
The type of Ids(Input) is SelectedRows, the rows of Ids contains
the ids to be looked up in W;
if the Id is not in the sparse table, this operator will return a
random value and set the value into the table for the next looking up.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp,
ops::DistributedLookupTableOpMaker);
REGISTER_OP_CPU_KERNEL(distributed_lookup_table,
ops::DistributedLookupTableKernel<float>);
......@@ -40,13 +40,15 @@ class FetchBarrierOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
rets.push_back(rpc_client->AsyncSendFetchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -44,7 +44,7 @@ class RecvOp : public framework::OperatorBase {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
......@@ -64,8 +64,8 @@ class RecvOp : public framework::OperatorBase {
trainer_id);
recv_functor(rpc_ctx, scope);
} else {
std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
......@@ -73,13 +73,7 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
......@@ -87,9 +81,11 @@ class RecvOp : public framework::OperatorBase {
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
}
}
}
......@@ -112,10 +108,6 @@ This operator can get variables from server side.
"variables for mapping")
.SetDefault({});
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
......
......@@ -44,13 +44,16 @@ class SendBarrierOp : public framework::OperatorBase {
VLOG(3) << "SendBarrierOp sync";
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
std::vector<distributed::VarHandlePtr> rets;
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
rets.push_back(rpc_client->AsyncSendBatchBarrier(ep));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
}
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
}
};
......
......@@ -41,7 +41,6 @@ class SendOp : public framework::OperatorBase {
auto ins = Inputs("X");
auto epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_mode");
auto trainer_id = Attr<int>("trainer_id");
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
......@@ -75,12 +74,10 @@ class SendOp : public framework::OperatorBase {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
for (size_t i = 0; i < rets.size(); i++) {
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
}
}
}
......@@ -98,10 +95,6 @@ Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
......
......@@ -97,10 +97,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_DISTRIBUTE
// w_Out is set to used by prefetch, never change it in other cases
auto* w_out = ctx.Output<framework::LoDTensor>("W_Out");
operators::distributed::prefetch_with_reconstruct<T>(
"Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections,
ctx, local_scope, w_out);
auto weight = ctx.Outputs("W_Out").front();
operators::distributed::prefetch("Ids@Prefetch", "W@Prefetch", weight,
true, table_names, epmap,
height_sections, ctx, local_scope);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
......
......@@ -82,46 +82,27 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context,
context.scope());
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
} else {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<T, 128, 8, 8, false><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<T, 128, 8, 8, true><<<
grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
if (padding_idx == -1)
LookupTable<
T, 128, 8, 8,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTable<
T, 128, 8, 8,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
};
......
......@@ -46,6 +46,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *table_var = context.InputVar("W");
auto id_name = context.Inputs("Ids").front();
auto embedding_name = context.Inputs("W").front();
auto out_name = context.Outputs("Out").front();
// for remote prefetch
......@@ -57,12 +58,12 @@ class LookupTableKernel : public framework::OpKernel<T> {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
// parameter server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_names, epmap,
height_sections, context,
context.scope());
operators::distributed::prefetch(id_name, out_name, embedding_name, false,
table_names, epmap, height_sections,
context, context.scope());
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
......
......@@ -195,9 +195,10 @@ class NCEKernel : public framework::OpKernel<T> {
w_tensor->Resize(framework::make_ddim(w_dims));
#ifdef PADDLE_WITH_DISTRIBUTE
auto weight = context.Inputs("Weight").front();
operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
table_names, epmap, height_sections,
context, local_scope);
weight, false, table_names, epmap,
height_sections, context, local_scope);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
......
......@@ -102,16 +102,19 @@ class SaveOpKernel : public framework::OpKernel<T> {
void SaveSelectedRows(const framework::ExecutionContext &ctx,
const platform::Place &place,
const framework::Variable *var) const {
framework::Variable *out_put_var = ctx.OutputVar(LOOKUP_TABLE_PATH);
auto file_path = ctx.Attr<std::string>("file_path");
auto overwrite = ctx.Attr<bool>("overwrite");
std::string filename = file_path;
VLOG(4) << "SaveSelectedRows output file_path: " << file_path;
framework::Variable *out_put_var = ctx.scope().FindVar(LOOKUP_TABLE_PATH);
if (out_put_var != nullptr) {
auto *lt_var = out_put_var->GetMutable<std::string>();
filename = *lt_var;
if (lt_var->length() > 0) {
VLOG(4) << "SaveSelectedRows output var name: " << *lt_var;
filename = *lt_var;
}
}
if (FileExists(filename) && !overwrite) {
......
......@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) {
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)
.def("start", &Communicator::Start);
.def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning);
}
} // namespace pybind
......
......@@ -86,3 +86,21 @@ class Communicator(object):
comm.stop()
"""
self.communicator_.stop()
def is_running(self):
"""
Get communicator is running or stop.
Returns:
bool
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.is_running()
"""
self.communicator_.is_running()
......@@ -363,7 +363,17 @@ def load_persistables_for_inference(dirname, executor, program,
})
sums.append(param_var)
global_block.append_op(
type='sum', inputs={"X": sums}, outputs={'Out': emb_var}, attrs={})
type='merge_sparse_lookup_table',
inputs={"X": sums},
outputs={'Out': emb_var},
attrs={})
global_block.append_op(
type='save',
inputs={"X": [emb_var]},
outputs={},
attrs={
'file_path': os.path.join(lookup_table_dirname, emb_var.name)
})
global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program)
......
......@@ -312,7 +312,7 @@ class MultiSlotDataGenerator(DataGenerator):
)
if name != self._proto_info[index][0]:
raise ValueError(
"the field name of two given line are not match: require<%s>, get<%d>."
"the field name of two given line are not match: require<%s>, get<%s>."
% (self._proto_info[index][0], name))
if output:
output += " "
......
......@@ -158,7 +158,7 @@ class MPIRoleMaker(RoleMakerBase):
"""
finalize the current MPI instance.
"""
pass
self.MPI.Finalize()
def _get_ips(self):
"""
......@@ -316,6 +316,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
self.endpoints = self.endpoints.split(",")
self._server_endpoints = self.endpoints
self._worker_endpoints = self.endpoints
if self.role.upper() == "PSERVER":
self._current_id = self.endpoints.index(self.current_endpoint)
self._role = Role.SERVER
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator
......@@ -25,6 +26,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerCo
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
class DistributedTranspiler(Fleet):
......@@ -51,9 +53,20 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
# if MPISymetricRoleMaker is defined
# we suppose a user wants to submit job on mpi cluster
if isinstance(self._role_maker, MPISymetricRoleMaker):
# check whether server has been initialized
from paddle.fluid.transpiler.details.checkport import wait_server_ready
wait_server_ready(fleet.server_endpoints(to_string=False))
if not self._transpile_config.sync_mode:
self._communicator = Communicator(self.main_program)
self._communicator.start()
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
def init_server(self, model_dir=None):
"""
......@@ -104,10 +117,14 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
if not self._transpile_config.sync_mode:
if not self._transpile_config.sync_mode and self._communicator.is_running(
):
self._communicator.stop()
self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker):
self._role_maker._finalize()
def distributed_optimizer(self, optimizer, strategy=None):
"""
Optimizer for distributed training.
......@@ -193,13 +210,23 @@ class DistributedTranspiler(Fleet):
self._transpile_config = config
self._transpiler = OriginTranspiler(config)
print("server endpoints")
print(fleet.server_endpoints(to_string=True))
print("worker index: %d" % fleet.worker_index())
print("worker num: %d" % fleet.worker_num())
if self.is_worker():
self._transpiler.transpile(
trainer_id=fleet.worker_index(),
pservers=fleet.server_endpoints(to_string=True),
trainers=fleet.worker_num(),
sync_mode=config.sync_mode)
self.main_program = self._transpiler.get_trainer_program()
if isinstance(self._role_maker, MPISymetricRoleMaker):
config.wait_port = False
self.main_program = self._transpiler.get_trainer_program(
wait_port=config.wait_port)
self.startup_program = default_startup_program()
else:
self._transpiler.transpile(
......
......@@ -88,13 +88,21 @@ def is_persistable(var):
def _clone_var_in_block_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=True)
else:
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
def save_vars(executor,
......
......@@ -15,12 +15,51 @@
from __future__ import print_function
import unittest
import time
import paddle.fluid as fluid
from paddle.fluid.communicator import Communicator
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
class TestCommunicator(unittest.TestCase):
def net(self):
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def test_communicator_init_and_start(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.WORKER,
worker_num=2,
server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
fleet.init(role)
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(0.01)
strategy = DistributeTranspilerConfig()
strategy.sync_mode = True
strategy.wait_port = False
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
comm = Communicator(fleet.main_program)
comm.start()
time.sleep(10)
comm.stop()
class TestCommunicator2(unittest.TestCase):
def test_communicator_init_and_start(self):
prog = fluid.Program()
comm = Communicator(prog)
......
......@@ -18,6 +18,18 @@ import unittest
from test_dist_base import TestDistBase
def skip_ci(func):
on_ci = bool(int(os.environ.get("SKIP_UNSTABLE_CI", '0')))
def __func__(*args, **kwargs):
if on_ci:
return
return func(*args, **kwargs)
return __func__
@skip_ci
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......@@ -27,6 +39,7 @@ class TestDistCTR2x2(TestDistBase):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
@skip_ci
class TestDistCTRWithL2Decay2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......@@ -37,7 +50,7 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
self.check_with_place(
"dist_ctr.py",
delta=1e-7,
check_error_log=False,
check_error_log=True,
need_envs=need_envs)
......
......@@ -19,6 +19,18 @@ import unittest
from test_dist_fleet_base import TestFleetBase
def skip_ci(func):
on_ci = bool(int(os.environ.get("SKIP_UNSTABLE_CI", '0')))
def __func__(*args, **kwargs):
if on_ci:
return
return func(*args, **kwargs)
return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
......
......@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
class TestDistW2V2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4)
......@@ -29,6 +30,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
self._enforce_place = "CPU"
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=1e-4)
......@@ -37,6 +39,7 @@ class TestDistW2V2x2WithMemOpt(TestDistBase):
class TestDistW2V2x2Async(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._enforce_place = "CPU"
def test_dist_train(self):
self.check_with_place("dist_word2vec.py", delta=100)
......
......@@ -185,8 +185,6 @@ class TestListenAndServOp(unittest.TestCase):
port1 = self._get_pserver_port(p1.pid)
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self._run_lookup_table_op_one_pserver(place, port0)
......
......@@ -314,14 +314,49 @@ class DistributeTranspiler(object):
sparse_update_ops.append(op)
return sparse_update_ops
def _update_remote_sparse_update_op(self, param_varname, height_sections,
endpint_map, table_names):
def _update_remote_sparse_update_op(self, program, param_varname,
height_sections, endpoints,
table_names):
ops = []
op_type = ""
for op in self.sparse_update_ops:
if param_varname in op.input_arg_names:
op._set_attr('epmap', endpint_map)
op._set_attr('table_names', table_names)
op._set_attr('height_sections', height_sections)
op._set_attr('trainer_id', self.trainer_id)
if param_varname in op.input_arg_names and op_type == "":
op_type = op.type
ops.append(op)
elif param_varname in op.input_arg_names and op_type == op.type:
ops.append(op)
if op_type == "lookup_table":
all_ops = program.global_block().ops
op_idxs = [all_ops.index(op) for op in ops]
inputs = [
program.global_block().vars[op.input("Ids")[0]] for op in ops
]
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
]
for idx in op_idxs[::-1]:
program.global_block()._remove_op(idx)
program.global_block()._insert_op(
index=op_idxs[0],
type="distributed_lookup_table",
inputs={"Ids": inputs,
'W': w},
outputs={"Outputs": outputs},
attrs={
"table_names": table_names,
"height_sections": height_sections,
"endpoints": endpoints,
"padding_idx": padding_idx,
"trainer_id": self.trainer_id
})
def _is_input_of_remote_sparse_update_op(self, param_name):
for op in self.sparse_update_ops:
......@@ -456,17 +491,12 @@ class DistributeTranspiler(object):
splited_grad_varname = splited_vars[0].name
index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True)
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[
grad_varname]
if self._is_input_of_remote_sparse_update_op(
sparse_param_name):
self.sparse_param_to_height_sections[
sparse_param_name] = [splited_vars[0].shape[0]]
elif len(splited_vars) > 1:
orig_var = program.global_block().vars[splited_grad_varname]
index = find_op_by_output_arg(
program.global_block(), splited_grad_varname, reverse=True)
if not self.config.runtime_split_send_recv:
self._insert_split_op(program, orig_var, index,
splited_vars)
......@@ -475,6 +505,13 @@ class DistributeTranspiler(object):
AssertionError("Can not insert the send op by original "
"variable name :", splited_grad_varname)
if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
sparse_param_name = self.grad_name_to_param_name[grad_varname]
if self._is_input_of_remote_sparse_update_op(sparse_param_name):
self.sparse_param_to_height_sections[sparse_param_name] = [
splited_var.shape[0] for splited_var in splited_vars
]
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
......@@ -507,8 +544,7 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode,
]
})
for _, var in enumerate(splited_vars):
send_vars.append(var)
......@@ -528,7 +564,6 @@ class DistributeTranspiler(object):
outputs={"Out": send_barrier_out},
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
......@@ -574,7 +609,6 @@ class DistributeTranspiler(object):
recv_op_role_var_name = splited_trainer_grad[0].name
if param_varname in self.sparse_param_to_height_sections:
for table_name in table_names:
distributed_var = self.vars_overview.get_distributed_var_by_slice(
table_name)
......@@ -583,7 +617,7 @@ class DistributeTranspiler(object):
height_sections = self.sparse_param_to_height_sections[
param_varname]
self._update_remote_sparse_update_op(
param_varname, height_sections, eps, table_names)
program, param_varname, height_sections, eps, table_names)
else:
recv_varnames = []
if self.config.runtime_split_send_recv:
......@@ -602,8 +636,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
[param_varname, recv_op_role_var_name]
})
if self.sync_mode:
......@@ -1481,7 +1514,6 @@ class DistributeTranspiler(object):
if self.sync_mode else []
},
attrs={
"sync_mode": not self.sync_mode,
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册