Unverified commit 2467c137 authored by Chengmo, committed by GitHub

Add GEO-SGD distributed training algorithm (#20018) (#20133)

* refactor geo sgd & communicator
Parent 43f11b5e
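
For context, a rough end-to-end sketch of how the new mode is meant to be enabled from user code, pieced together from the config fields, fleet changes, and tests in this diff. The role maker arguments and build_net() are placeholders; only the two geo_sgd_* strategy fields and the transpiler/communicator wiring come from this change.

import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

role = role_maker.UserDefinedRoleMaker(
    current_id=0,
    role=role_maker.Role.WORKER,
    worker_num=2,
    server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
fleet.init(role)

strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
strategy.geo_sgd_mode = True              # selects GeoSgdTranspiler + GeoSgdCommunicator
strategy.geo_sgd_need_push_nums = 100     # local batches between two pushes

avg_cost = build_net()                    # placeholder: build the model, return the loss
optimizer = fluid.optimizer.SGD(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)

if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
else:
    fleet.init_worker()                   # starts the geo-sgd communicator
    # ... run fleet.main_program with an executor ...
    fleet.stop_worker()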
......@@ -14,16 +14,17 @@ limitations under the License. */
#pragma once
#include <ThreadPool.h>
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <ThreadPool.h>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/distributed/rpc_common.h"
......@@ -170,6 +171,11 @@ class Communicator {
virtual void Send(const std::string& var_name,
const framework::Scope& scope) = 0;
virtual void Send(const std::vector<std::string>& sparse_var_names,
const std::vector<std::string>& sparse_var_tables,
const framework::Scope& scope) = 0;
virtual void Recv() = 0;
virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx,
......@@ -179,6 +185,13 @@ class Communicator {
virtual void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) = 0;
// for geo-sgd
virtual void InitImpl(
const paddle::framework::ProgramDesc& program, Scope* param_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) = 0;
static Communicator* GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
......@@ -194,6 +207,26 @@ class Communicator {
return communicator_.get();
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
return communicator_.get();
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) {
std::call_once(init_flag_, &Communicator::InitWithTranspilerInfo<T>,
program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums));
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx,
......@@ -206,14 +239,6 @@ class Communicator {
}
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
return communicator_.get();
}
template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) {
......@@ -223,12 +248,28 @@ class Communicator {
}
}
template <typename T>
static void InitWithTranspilerInfo(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(program, training_scope, std::ref(vars_info),
std::ref(trainers), std::ref(geo_need_push_nums));
}
}
protected:
bool running_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
};
using SparseIdsMap =
std::unordered_map<std::string, std::unordered_set<int64_t>>;
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() {}
......@@ -251,6 +292,16 @@ class AsyncCommunicator : public Communicator {
void SendThread();
void RecvThread();
void Send(const std::vector<std::string>& sparse_var_names,
const std::vector<std::string>& sparse_var_tables,
const framework::Scope& scope) override;
void InitImpl(
const paddle::framework::ProgramDesc& program, Scope* param_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) override;
private:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
......@@ -266,6 +317,93 @@ class AsyncCommunicator : public Communicator {
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
};
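// Trainer-side communicator for the GEO-SGD algorithm: instead of sending a
// gradient for every batch, Send() only records which parameters (and which
// sparse ids) were touched; roughly every geo_need_push_nums batches a
// background send thread pushes the local parameter delta (current value
// minus the copy kept in old_scope_) to the pservers and pulls the updated
// global parameters back.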
class GeoSgdCommunicator : public Communicator {
public:
GeoSgdCommunicator() {}
~GeoSgdCommunicator();
void InitImpl(
const paddle::framework::ProgramDesc& program, Scope* training_scope,
std::map<std::string, std::map<std::string, std::vector<std::string>>>&
vars_info,
const int& trainers, const int& geo_need_push_nums) override;
void Start() override;
void Stop() override;
void Send(const std::string& var_name,
const framework::Scope& scope) override;
void Send(const std::vector<std::string>& sparse_var_names,
const std::vector<std::string>& sparse_var_tables,
const framework::Scope& scope) override;
void Recv() override;
void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) override;
void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override;
private:
void SendThread();
void RecvAll();
std::unordered_set<int64_t> SparseIdsMerge(
const std::vector<SparseIdsMap>& ids_send_vec,
const std::string& var_name);
void SendUpdateDenseVars(const std::string& var_name);
void SendUpdateSparseVars(const std::string& var_name,
const std::unordered_set<int64_t>& ids_table);
void RecvUpdateVars(const std::string& var_name);
void GeoSgdDenseParamInit(framework::Scope* scope_x,
framework::Scope* scope_y,
const std::string var_name);
void GeoSgdSparseParamInit(framework::Scope* scope_x,
framework::Scope* scope_y,
const std::string var_name);
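// Map a parameter name to the name of its ".delta" variable (the variable
// that is actually sent to the pserver) and back.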
const std::string VarToDeltaVar(const std::string var_name) {
std::string delta_name = var_name;
const std::string send_name = delta_name.append(".delta");
return send_name;
}
const std::string DeltaVarToVar(const std::string var_name) {
std::string origin_name = var_name;
origin_name.erase(origin_name.find(".delta"), 6);
const std::string param_name = origin_name;
return param_name;
}
private:
int trainer_nums_ = 1;
int geo_need_push_nums_ = 100;
bool is_geo_sgd_ = false;
Scope* training_scope_;
std::shared_ptr<Scope> delta_scope_; // parameter local delta: recv - old
std::shared_ptr<Scope>
old_scope_; // local parameter copy: stores the param value after the last recv
std::shared_ptr<Scope> pserver_scope_; // parameter on pserver, global scope
RpcCtxMap send_varname_to_ctx_;
RpcCtxMap recv_varname_to_ctx_;
std::atomic_uint have_push_{0};
std::unordered_map<std::string, bool>
var_list_; // true if the var is sparse (stored as SelectedRows)
std::shared_ptr<BlockingQueue<std::shared_ptr<SparseIdsMap>>>
need_push_queue_;
std::vector<SparseIdsMap> ids_send_vec_;
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::unique_ptr<std::thread> send_thread_{nullptr};
};
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -42,7 +42,7 @@ using DDim = framework::DDim;
template <typename T>
void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope) {
VLOG(3) << "ParameterRecv in " << rpc_ctx.var_name;
VLOG(2) << "ParameterRecv in " << rpc_ctx.var_name;
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -54,15 +54,24 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
auto *recv_var = scope.FindVar(rpc_ctx.var_name);
// recv all vars to local scope
if (recv_var->IsType<framework::LoDTensor>()) {
if (recv_var->IsType<framework::LoDTensor>() ||
recv_var->IsType<framework::SelectedRows>()) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
local_scope->Var(recv_var_name);
VLOG(3) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(), recv_var_name,
recv_var_name));
VLOG(4) << "recv " << recv_var_name << " from " << rpc_ctx.epmap[i];
if (recv_var->IsType<framework::LoDTensor>()) {
// sparse param in recv_scope is LoDTensor
rets.push_back(rpc_client->AsyncGetVar(rpc_ctx.epmap[i], cpu_ctx,
*local_scope.get(),
recv_var_name, recv_var_name));
} else {
// sparse param in pserver_scope is SelectedRows
rets.push_back(rpc_client->AsyncGetVar(
rpc_ctx.epmap[i], cpu_ctx, *local_scope.get(), recv_var_name,
recv_var_name, recv_var_name));
}
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
......@@ -72,7 +81,7 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
}
// concat recved tensor into one var
{
if (recv_var->IsType<framework::LoDTensor>()) {
size_t output_offset = 0;
size_t row_offset = 0;
framework::Tensor *recv_tensor =
......@@ -126,9 +135,56 @@ void ParameterRecv<T>::operator()(const RpcContext &rpc_ctx,
LOG(FATAL) << "recv_numel: " << recv_numel << " actual numel: " << numel;
}
PADDLE_ENFORCE_EQ(recv_numel, numel);
} else if (recv_var->IsType<framework::SelectedRows>()) {
auto cpu_place = platform::CPUPlace();
auto *slr = recv_var->GetMutable<framework::SelectedRows>();
slr->mutable_rows()->clear();
slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
int64_t width = 0;
int64_t height = 0;
std::vector<int64_t> new_rows{};
// trans sparse ids from local to global
std::vector<int64_t> abs_sections =
ToAbsoluteSection(rpc_ctx.height_sections);
for (int i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &recv_var_name = rpc_ctx.splited_var_names[i];
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
width = var_slr->mutable_value()->dims()[1];
height += var_slr->height();
auto row_offset = abs_sections[i];
VLOG(4) << "Recv split_var " << recv_var_name << " Row size "
<< var_slr_row->size();
for (size_t j = 0; j < var_slr_row->size(); j++) {
new_rows.push_back(row_offset + (*var_slr_row)[j]);
}
}
slr->set_rows(new_rows);
slr->set_height(height);
slr->mutable_value()->mutable_data<float>(
framework::make_ddim(
{static_cast<int64_t>(slr->mutable_rows()->size()), width}),
cpu_place);
auto *slr_data = slr->mutable_value()->data<float>();
size_t row_offset = 0;
for (auto &recv_var_name : rpc_ctx.splited_var_names) {
auto *var = local_scope->FindVar(recv_var_name);
auto *var_slr = var->GetMutable<framework::SelectedRows>();
auto *var_slr_row = var_slr->mutable_rows();
auto var_slr_row_size = var_slr_row->size();
auto *var_slr_data = var_slr->mutable_value()->data<float>();
memcpy(slr_data + row_offset * width, var_slr_data,
sizeof(float) * width * var_slr_row_size);
row_offset += var_slr_row_size;
}
}
VLOG(3) << "ParameterRecv out " << rpc_ctx.var_name;
VLOG(2) << "ParameterRecv out " << rpc_ctx.var_name;
}
template struct ParameterRecv<float>;
......
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
......@@ -28,6 +28,7 @@
#include "paddle/fluid/operators/distributed/rpc_client.h"
#include "paddle/fluid/operators/distributed/variable_response.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
......@@ -38,9 +39,44 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
typedef std::vector<std::pair<std::string, std::string>> EP_SPLIT_TABLE_PAIRS;
inline EP_SPLIT_TABLE_PAIRS GetMultiFieldRpcContext(
const RpcContext &rpc_ctx, const framework::Scope &scope, int multi_parts) {
EP_SPLIT_TABLE_PAIRS table_pairs;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
if (send_var->IsType<framework::SelectedRows>()) {
PADDLE_ENFORCE_GT(multi_parts, 0, "multi_parts must be >= 1");
if (multi_parts == 1) {
for (int i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
table_pairs.push_back(
std::make_pair(rpc_ctx.epmap[i], rpc_ctx.splited_var_names[i]));
}
} else {
for (int i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (int x = 0; x < multi_parts; x++) {
auto table =
string::Sprintf("%s@%d@PIECE", rpc_ctx.splited_var_names[i], x);
table_pairs.push_back(std::make_pair(rpc_ctx.epmap[i], table));
}
}
}
} else if (send_var->IsType<framework::LoDTensor>()) {
PADDLE_THROW("GetMultiFieldRpcContext can not support LoDTensor current!");
} else {
PADDLE_THROW("GetMultiFieldRpcContext unsupported var type!");
}
return table_pairs;
}
template <typename T>
void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
const framework::Scope &scope, bool sync) {
const framework::Scope &scope, bool sync,
int multi_parts) {
std::unique_ptr<framework::Scope> local_scope = scope.NewTmpScope();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
......@@ -49,9 +85,12 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
std::vector<distributed::VarHandlePtr> rets;
auto *send_var = scope.FindVar(rpc_ctx.var_name);
size_t out_num = rpc_ctx.splited_var_names.size();
if (send_var->IsType<framework::LoDTensor>()) {
size_t out_num = rpc_ctx.splited_var_names.size();
if (out_num > 1) {
auto &send_tensor = send_var->Get<framework::LoDTensor>();
auto &send_tensor_dims = send_tensor.dims();
......@@ -77,6 +116,24 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
row_offset += outs_dims[i][0];
}
}
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
......@@ -85,84 +142,94 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
std::vector<std::vector<size_t>> outs_rows_idx;
std::vector<std::vector<size_t>> outs_dense_idx;
outs_rows_idx.resize(out_num);
outs_dense_idx.resize(out_num);
auto table_pairs = GetMultiFieldRpcContext(rpc_ctx, scope, multi_parts);
outs_rows_idx.resize(table_pairs.size());
outs_dense_idx.resize(table_pairs.size());
auto row_numel = send_slr.value().numel() / send_slr.value().dims()[0];
auto *src = send_slr.value().data<T>();
// create output var in local scope
std::vector<framework::SelectedRows *> outs;
for (auto &name : rpc_ctx.splited_var_names) {
auto *out = local_scope->Var(name)->GetMutable<framework::SelectedRows>();
for (auto &table : table_pairs) {
auto *out =
local_scope->Var(table.second)->GetMutable<framework::SelectedRows>();
outs.push_back(out);
}
// split rows index into output sparse vars
for (size_t i = 0; i < send_rows.size(); ++i) {
size_t out_idx = GetSectionIndex(send_rows[i], abs_sections);
auto ep_idx = GetSectionIndex(send_rows[i], abs_sections);
auto table_idx = send_rows[i] % multi_parts;
auto out_idx = ep_idx * multi_parts + table_idx;
outs_rows_idx[out_idx].push_back(send_rows[i]);
outs_dense_idx[out_idx].push_back(i);
}
auto place = platform::CPUPlace();
for (size_t i = 0; i < outs_rows_idx.size(); ++i) {
auto rows_idx = outs_rows_idx[i];
outs[i]->set_height(rpc_ctx.height_sections[i]);
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[i]->mutable_rows()->clear();
outs[i]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[i]->mutable_rows()->push_back(idx - abs_sections[i]);
}
auto dst = outs[i]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(
platform::CPUPlace(), dst + j * row_numel, platform::CPUPlace(),
src + outs_dense_idx[i][j] * row_numel, sizeof(T) * row_numel);
} else {
PADDLE_THROW("do not support GPU now");
/*
#ifdef PADDLE_WITH_CUDA
auto stream = ctx.cuda_device_context().stream();
memory::Copy(platform::CUDAPlace(), dst + j * row_numel,
platform::CUDAPlace(),
src + outs_dense_idx[i][j] * row_numel,
sizeof(T) * row_numel, stream);
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
*/
for (int ctx = 0; ctx < rpc_ctx.splited_var_names.size(); ctx++) {
for (int part = 0; part < multi_parts; part++) {
auto out_idx = ctx * multi_parts + part;
auto rows_idx = outs_rows_idx[out_idx];
auto dims = send_slr.GetCompleteDims();
dims[0] = rows_idx.size();
outs[out_idx]->set_height(rpc_ctx.height_sections[ctx]);
outs[out_idx]->mutable_rows()->clear();
outs[out_idx]->mutable_value()->mutable_data<T>(dims, send_slr.place());
if (rows_idx.size() > 0) {
for (auto idx : rows_idx) {
outs[out_idx]->mutable_rows()->push_back(idx - abs_sections[ctx]);
}
auto dst = outs[out_idx]->mutable_value()->mutable_data<T>(place);
for (size_t j = 0; j < rows_idx.size(); j++) {
if (platform::is_cpu_place(place)) {
memory::Copy(platform::CPUPlace(), dst + j * row_numel,
platform::CPUPlace(),
src + outs_dense_idx[out_idx][j] * row_numel,
sizeof(T) * row_numel);
} else {
PADDLE_THROW("do not support GPU now");
}
}
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[out_idx]->rows().size(),
"rows should has the same size with tensor dim 0");
}
PADDLE_ENFORCE_EQ(rows_idx.size(), outs[i]->rows().size(),
"rows should has the same size with tensor dim 0");
}
} else {
PADDLE_THROW("unsupported var type to send!");
}
for (size_t i = 0; i < table_pairs.size(); i++) {
auto &send_var_name = table_pairs[i].second;
auto &endpoint = table_pairs[i].first;
auto need_send = NeedSend(*local_scope.get(), send_var_name);
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
auto &endpoint = rpc_ctx.epmap[i];
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name
<< "send var endpoint: " << endpoint
<< "need send: " << need_send;
if (need_send) {
VLOG(4) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(4) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} else {
PADDLE_THROW("unsupported var type to send!");
}
VLOG(4) << "Prepare to send var " << rpc_ctx.var_name;
if (sync) {
for (auto &handle : rets) {
VLOG(4) << "Wait send var to pserver handle: " << handle;
PADDLE_ENFORCE(handle->Wait(), "internal error in RPCClient");
}
}
......
......@@ -27,7 +27,7 @@ namespace distributed {
template <typename T>
struct ParameterSend {
void operator()(const RpcContext &rpc_ctx, const framework::Scope &scope,
bool sync);
bool sync, int multi_parts);
};
}; // namespace distributed
......
......@@ -26,6 +26,7 @@
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"
namespace paddle {
namespace operators {
......@@ -60,13 +61,26 @@ bool RequestSendHandler::Handle(const std::string& varname,
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(varname)) {
std::string run_varname = varname;
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@');
PADDLE_ENFORCE_EQ(varname_splits.size(), 3);
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto& grad_slr =
scope->FindVar(varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(varname,
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
......@@ -116,7 +130,46 @@ bool RequestGetHandler::Handle(const std::string& varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
*outvar = scope_->FindVar(varname);
VLOG(1) << "Table name empty? " << table_name.empty();
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
std::vector<int64_t> updated_rows;
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (auto i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
} else {
*outvar = scope_->FindVar(varname);
}
}
}
return true;
......
......@@ -47,8 +47,12 @@ class SendOp : public framework::OperatorBase {
auto height_sections = Attr<std::vector<int64_t>>("sections");
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
distributed::Communicator::GetInstance()->Send(ins[0], scope);
if (ins.size() > 1) {
distributed::Communicator::GetInstance()->Send(ins, send_varnames,
scope);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#include "paddle/fluid/pybind/communicator_py.h"
#include <Python.h>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "pybind11/pybind11.h"
......@@ -27,6 +29,7 @@ namespace py = pybind11;
using paddle::framework::ProgramDesc;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::AsyncCommunicator;
using paddle::operators::distributed::GeoSgdCommunicator;
using paddle::framework::Scope;
namespace paddle {
......@@ -37,9 +40,20 @@ void BindCommunicator(py::module* m) {
py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope) {
VLOG(0) << "using communicator";
Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
return Communicator::GetInstantcePtr();
}))
.def(py::init([](
const ProgramDesc& program, Scope* training_scope,
std::map<std::string,
std::map<std::string, std::vector<std::string>>>& vars_info,
int& trainers, int& geo_need_push_nums) {
VLOG(0) << "using geo sgd communicator";
Communicator::InitInstance<GeoSgdCommunicator>(
program, training_scope, vars_info, trainers, geo_need_push_nums);
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)
.def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning);
......
......@@ -16,7 +16,11 @@ limitations under the License. */
#include <Python.h>
#include "pybind11/chrono.h"
#include "pybind11/complex.h"
#include "pybind11/functional.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace paddle {
namespace pybind {
......
......@@ -195,6 +195,7 @@ def __bootstrap__():
read_env_flags.append('communicator_min_send_grad_num_before_recv')
read_env_flags.append('communicator_thread_pool_size')
read_env_flags.append('communicator_max_merge_var_num')
read_env_flags.append('communicator_merge_sparse_bucket')
read_env_flags.append('communicator_fake_rpc')
read_env_flags.append('communicator_send_wait_times')
read_env_flags.append('communicator_merge_sparse_grad')
......
......@@ -13,6 +13,10 @@
# limitations under the License.
from .executor import global_scope
"""
Communicator is used for async distributed training in distribute_transpiler mode.
It's a wrapper of the cpp class Communicator and should be used inside the fleet API.
"""
from . import core
from .framework import Program
......@@ -20,7 +24,11 @@ __all__ = ['Communicator']
class Communicator(object):
def __init__(self, program):
def __init__(self,
program,
vars_info=None,
trainers=None,
geo_sgd_need_push_nums=None):
"""
Communicator is used for async distributed training in distribute_transpiler mode.
It's a wrapper of the cpp class Communicator and should be used inside the fleet API.
......@@ -47,7 +55,15 @@ class Communicator(object):
for op in program.block(0).ops:
if op.type == "recv":
op._set_attr('do_not_run', True)
self.communicator_ = core.DistCommunicator(program.desc, global_scope())
# Todo: Add check
if vars_info and trainers and geo_sgd_need_push_nums:
# for geo sgd
self.communicator_ = core.DistCommunicator(
program.desc,
global_scope(), vars_info, trainers, geo_sgd_need_push_nums)
else:
self.communicator_ = core.DistCommunicator(program.desc,
global_scope())
def start(self):
"""
......
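As a reference for the Python-side change above, a minimal sketch of the two construction paths (the trainer_program and vars_info objects are illustrative; in practice the fleet API shown below builds them from the GeoSgdTranspiler):

from paddle.fluid.communicator import Communicator

# existing async mode
comm = Communicator(trainer_program)

# geo-sgd mode: vars_info comes from GeoSgdTranspiler._get_vars_info()
comm = Communicator(
    trainer_program,
    vars_info=vars_info,
    trainers=2,
    geo_sgd_need_push_nums=100)
comm.start()
# ... training ...
comm.stop()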
......@@ -13,7 +13,9 @@
# limitations under the License.
import os
import warnings
"""
Convert the fluid program to distributed data-parallelism programs.
"""
import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator
from paddle.fluid.framework import default_main_program
......@@ -24,6 +26,7 @@ from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
from paddle.fluid.transpiler.geo_sgd_transpiler import GeoSgdTranspiler
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
......@@ -64,7 +67,13 @@ class DistributedTranspiler(Fleet):
wait_server_ready(fleet.server_endpoints(to_string=False))
if not self._transpile_config.sync_mode:
self._communicator = Communicator(self.main_program)
if self._transpile_config.geo_sgd_mode:
self._communicator = Communicator(
self.main_program, self.vars_info,
fleet.worker_num(),
self._transpile_config.geo_sgd_need_push_nums)
else:
self._communicator = Communicator(self.main_program)
if not self._communicator.is_running():
self._communicator.start()
......@@ -124,7 +133,6 @@ class DistributedTranspiler(Fleet):
):
self._communicator.stop()
self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker):
self._role_maker._finalize()
......@@ -239,7 +247,10 @@ class DistributedTranspiler(Fleet):
self._origin_program = default_main_program().clone(for_test=False)
self._transpile_config = config
self._transpiler = OriginTranspiler(config)
if config.geo_sgd_mode:
self._transpiler = GeoSgdTranspiler(config)
else:
self._transpiler = OriginTranspiler(config)
if self.is_worker():
self._transpiler.transpile(
......@@ -254,6 +265,9 @@ class DistributedTranspiler(Fleet):
self.main_program = self._transpiler.get_trainer_program(
wait_port=config.wait_port)
self.startup_program = default_startup_program()
if self._transpile_config.geo_sgd_mode:
self.vars_info = self._transpiler._get_vars_info()
self.startup_program = self._transpiler.trainer_startup_program
else:
self._transpiler.transpile(
trainer_id=fleet.worker_index(),
......@@ -262,7 +276,8 @@ class DistributedTranspiler(Fleet):
sync_mode=config.sync_mode,
current_endpoint=self.server_endpoints()[self.server_index()])
self.main_program, self.startup_program = \
self._transpiler.get_pserver_programs(self.server_endpoints()[self.server_index()])
self._transpiler.get_pserver_programs(
self.server_endpoints()[self.server_index()])
fleet = DistributedTranspiler()
......
......@@ -24,6 +24,7 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_nce_remote_table_op)
LIST(REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op)
LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr)
LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_geo)
endif(NOT WITH_DISTRIBUTE)
......
......@@ -13,7 +13,9 @@
# limitations under the License.
from __future__ import print_function
"""
High-level unit tests for distributed fleet.
"""
import argparse
import os
import pickle
......@@ -29,6 +31,7 @@ from contextlib import closing
import six
import unittest
import numpy as np
import tempfile
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
......@@ -40,6 +43,12 @@ LEARNING_RATE = 0.01
class FleetDistRunnerBase(object):
"""
run_pserver, run_trainer : after initializing the role, split the program with the transpiler
net : implemented by the child class, the network of the model
do_training : run the program with the executor
"""
def run_pserver(self, args):
if args.role.upper() != "PSERVER":
raise ValueError("args role must be PSERVER")
......@@ -54,6 +63,8 @@ class FleetDistRunnerBase(object):
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
strategy.geo_sgd_mode = args.geo_sgd_mode
strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
avg_cost = self.net()
......@@ -78,14 +89,14 @@ class FleetDistRunnerBase(object):
strategy = DistributeTranspilerConfig()
strategy.sync_mode = args.sync_mode
strategy.geo_sgd_mode = args.geo_sgd_mode
strategy.geo_sgd_need_push_nums = args.geo_sgd_need_push_nums
avg_cost = self.net()
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
self.do_training(fleet)
out = self.do_training(fleet)
def net(self, batch_size=4, lr=0.01):
......@@ -98,6 +109,11 @@ class FleetDistRunnerBase(object):
class TestFleetBase(unittest.TestCase):
"""
start_pserver, start_trainer : build the start commands used by the test
run_cluster : use multiple processes to test the distributed program
"""
def _setup_config(self):
raise NotImplementedError("tests should have _setup_config implemented")
......@@ -109,6 +125,8 @@ class TestFleetBase(unittest.TestCase):
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._geo_sgd = False
self._geo_sgd_need_push_nums = 5
self._setup_config()
def _find_free_port(self):
......@@ -127,8 +145,8 @@ class TestFleetBase(unittest.TestCase):
def _start_pserver(self, cmd, required_envs):
ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)
ps0_pipe = open("/tmp/ps0_err.log", "wb+")
ps1_pipe = open("/tmp/ps1_err.log", "wb+")
ps0_pipe = open(tempfile.gettempdir() + "/ps0_err.log", "wb+")
ps1_pipe = open(tempfile.gettempdir() + "/ps1_err.log", "wb+")
ps0_proc = subprocess.Popen(
ps0_cmd.strip().split(" "),
......@@ -140,14 +158,13 @@ class TestFleetBase(unittest.TestCase):
stdout=subprocess.PIPE,
stderr=ps1_pipe,
env=required_envs)
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
def _start_trainer(self, cmd, required_envs):
tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)
tr0_pipe = open("/tmp/tr0_err.log", "wb+")
tr1_pipe = open("/tmp/tr1_err.log", "wb+")
tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
......@@ -164,18 +181,29 @@ class TestFleetBase(unittest.TestCase):
def _run_cluster(self, model, envs):
env = {'CPU_NUM': '1'}
python_path = self._python_interp
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
python_path += " -m coverage run --branch -p"
env.update(envs)
tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3}".format(
self._python_interp, model, self._ps_endpoints, self._trainers)
python_path, model, self._ps_endpoints, self._trainers)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3}".format(
self._python_interp, model, self._ps_endpoints, self._trainers)
python_path, model, self._ps_endpoints, self._trainers)
if self._sync_mode:
tr_cmd += " --sync_mode"
ps_cmd += " --sync_mode"
if self._geo_sgd:
tr_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
self._geo_sgd, self._geo_sgd_need_push_nums)
ps_cmd += " --geo_sgd_mode {0} --geo_sgd_need_push_nums {1}".format(
self._geo_sgd, self._geo_sgd_need_push_nums)
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
......@@ -259,7 +287,10 @@ def runtime_main(test_class):
parser.add_argument('--current_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument(
'--geo_sgd_mode', type=bool, required=False, default=False)
parser.add_argument(
'--geo_sgd_need_push_nums', type=int, required=False, default=2)
args = parser.parse_args()
model = test_class()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from test_dist_fleet_base import TestFleetBase
from dist_simnet_bow import train_network
def skip_ci(func):
on_ci = bool(int(os.environ.get("SKIP_UNSTABLE_CI", '0')))
def __func__(*args, **kwargs):
if on_ci:
return
return func(*args, **kwargs)
return __func__
class TestDistGeoCtr_2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
self._geo_sgd = True
self._geo_sgd_need_push_nums = 5
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestGeoSgdTranspiler(unittest.TestCase):
def test_pserver(self):
role = role_maker.UserDefinedRoleMaker(
current_id=0,
role=role_maker.Role.SERVER,
worker_num=2,
server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
fleet.init(role)
batch_size = 128
is_sparse = True
is_distribute = False
strategy = DistributeTranspilerConfig()
strategy.sync_mode = False
strategy.geo_sgd_mode = True
strategy.geo_sgd_need_push_nums = 5
avg_cost, _, _ = train_network(batch_size, is_distribute, is_sparse)
optimizer = fluid.optimizer.SGD(0.1)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
pserver_startup_program = fleet.startup_program
pserver_main_program = fleet.main_program
if __name__ == "__main__":
unittest.main()
......@@ -180,6 +180,10 @@ class DistributeTranspilerConfig(object):
_runtime_split_send_recv = False
_sync_mode = True
# Geo-sgd algorithm
geo_sgd_mode = False
geo_sgd_need_push_nums = 100
nccl_comm_num = 1
#The picture here illustrates the principle:
#https://github.com/PaddlePaddle/Paddle/pull/17263#discussion_r285411396
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
"""
Steps to transpile trainer:
1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
2. create delta variables in the global scope which are used for sending
3. add send op to send sparse ids to communicator
Steps to transpile pserver:
1. create new program for parameter server.
2. create param variables that are assigned to the current server instance.
3. create a sub-block in the server side program
4. append sum ops that should run on current server instance.
5. add listen_and_serv op
"""
import sys
import collections
import six
import numpy as np
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, Block, Parameter
from .details import wait_server_ready, VarsDistributed
from .details import delete_ops
from ..distribute_lookup_table import find_distributed_lookup_table
from .distribute_transpiler import DistributeTranspiler, DistributeTranspilerConfig, slice_variable, same_or_split_var
RPC_OP_ROLE_ATTR_NAME = op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
RPC_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.RPC
class GeoSgdTranspiler(DistributeTranspiler):
def __init__(self, config=None):
if config is not None:
self.config = config
else:
self.config = DistributeTranspilerConfig()
if self.config.split_method is None:
self.config.split_method = RoundRobin
assert (self.config.min_block_size >= 8192)
assert (self.config.split_method.__bases__[0] == PSDispatcher)
def transpile(self,
trainer_id,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
sync_mode=False,
startup_program=None,
current_endpoint="127.0.0.1:6174"):
if program is None:
program = default_main_program()
if startup_program is None:
startup_program = default_startup_program()
self.origin_program = program
self.startup_program = startup_program
self.origin_startup_program = self.startup_program.clone()
self.trainer_num = trainers
# geo-sgd only supports async mode
self.sync_mode = False
self.trainer_id = trainer_id
pserver_endpoints = pservers.split(",")
self.pserver_endpoints = pserver_endpoints
self.vars_overview = VarsDistributed()
self.optimize_ops, self.params_grads = self._get_optimize_pass()
ps_dispatcher = self.config.split_method(self.pserver_endpoints)
self.param_name_to_grad_name = dict()
self.grad_name_to_param_name = dict()
for param_var, grad_var in self.params_grads:
self.param_name_to_grad_name[param_var.name] = grad_var.name
self.grad_name_to_param_name[grad_var.name] = param_var.name
# distribute lookup table
self.table_name = find_distributed_lookup_table(self.origin_program)
self.has_distributed_lookup_table = self.table_name is not None
self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
self.origin_program._ps_endpoint = current_endpoint
self.origin_program._is_chief = self.trainer_id == 0
# program info send to geo-sgd communicator
self.vars_info = collections.OrderedDict()
self.split_to_origin_mapping = collections.OrderedDict()
self.delta_vars_list = []
self.sparse_var_list = []
self.sparse_var_splited_list = []
# split and create vars, then put splited vars in dicts for later use.
# step 1. split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars()
# step 3. create send recv var (param after optimize)
send_vars = []
ps_dispatcher.reset()
param_var_mapping_items = list(six.iteritems(self.param_var_mapping))
# send_vars are the parameters split by the communicator and sent to the pserver, not the origin parameters
for _, splited_vars in param_var_mapping_items:
for _, var in enumerate(splited_vars):
send_vars.append(var)
recv_vars = send_vars
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
for i, ep in enumerate(eplist):
self.param_opt_ep_mapping[ep]["params"].append(recv_vars[i])
distributed_var = self.vars_overview.get_distributed_var_by_slice(
recv_vars[i].name)
distributed_var.endpoint = ep
origin_name = self.split_to_origin_mapping[recv_vars[i].name]
self.vars_info[origin_name]["epmap"].append(ep)
self.origin_program._parameters_on_pservers = self.vars_overview
# send sparse id to communicator
self.sparse_var = []
self.sparse_tables = []
for op in self.origin_program.global_block().ops:
if op.type == "lookup_table":
op._set_attr('remote_prefetch', False)
for input_var_name, sparse_var_name in zip(
op.input("Ids"), op.input("W")):
if sparse_var_name in self.sparse_var_list:
input_var = program.global_block().var(input_var_name)
self.sparse_var.append(input_var)
self.sparse_tables.append(sparse_var_name)
# batch training loop end flag
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
program.global_block().append_op(
type="send",
inputs={"X": self.sparse_var},
outputs={"Out": dummy_output},
attrs={"send_varnames": self.sparse_tables})
# add param_init flag in trainer startup program
self.trainer_startup_program = self._get_trainer_startup_program(
recv_vars=recv_vars, eplist=eplist)
for delta_var in self.delta_vars_list:
self.trainer_startup_program.global_block().create_var(
name=delta_var.name,
persistable=delta_var.persistable,
dtype=delta_var.dtype,
type=delta_var.type,
shape=delta_var.shape)
dummy_output = self.trainer_startup_program.global_block().create_var(
name=framework.generate_control_dev_var_name())
param_init = self.trainer_startup_program.global_block().create_var(
name="param_init")
self.trainer_startup_program.global_block().append_op(
type="send",
inputs={"X": [param_init]},
outputs={"Out": dummy_output},
attrs={"send_varnames": [param_init.name]})
def _get_vars_info(self):
return self.vars_info
def get_trainer_program(self, wait_port=True):
# if wait_port:
# wait_server_ready(self.pserver_endpoints)
return self.origin_program
def get_pserver_programs(self, endpoint):
pserver_prog = self.get_pserver_program(endpoint)
self.param_grad_ep_mapping = self.param_opt_ep_mapping
pserver_startup = self.get_startup_program(
endpoint, pserver_program=pserver_prog)
return pserver_prog, pserver_startup
def get_pserver_program(self, endpoint):
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
pserver_program._copy_dist_param_info_from(self.origin_program)
# step2: Create vars to receive vars at parameter servers.
recv_inputs = []
for v in self.param_opt_ep_mapping[endpoint]["params"]:
self._clone_var(pserver_program.global_block(), v)
optimize_block = []
param_to_block_id = []
sparse_grad_to_param = []
# append op to the current block
pre_block_idx = pserver_program.num_blocks - 1
for var in self.param_opt_ep_mapping[endpoint]["params"]:
per_opt_block = pserver_program._create_block(pre_block_idx)
optimize_block.append(per_opt_block)
var_name = var.name
pserver_block = per_opt_block.program.global_block()
param = pserver_block.vars[var_name]
delta_var_name = "%s.delta" % (param.name)
if var.name in self.sparse_var_splited_list:
delta_type = core.VarDesc.VarType.SELECTED_ROWS
sparse_grad_to_param.append(":".join(
[delta_var_name, param.name]))
else:
delta_type = param.type
delta_var = pserver_block.create_var(
name=delta_var_name,
persistable=False,
type=delta_type,
dtype=param.dtype,
shape=param.shape)
per_opt_block.append_op(
type="sum",
inputs={"X": [param, delta_var]},
outputs={"Out": param})
param_to_block_id.append(delta_var_name + ":" + str(
per_opt_block.idx))
attrs = {
"optimize_blocks": optimize_block,
"endpoint": endpoint,
"Fanin": self.trainer_num,
"sync_mode": self.sync_mode,
"grad_to_block_id": param_to_block_id,
"sparse_grad_to_param": sparse_grad_to_param
}
# step5 append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={'X': recv_inputs},
outputs={},
attrs=attrs)
pserver_program._sync_with_cpp()
# save pserver program to generate pserver side startup relatively.
self.pserver_program = pserver_program
return pserver_program
def _init_splited_vars(self):
param_list = []
grad_list = []
param_grad_set = set()
# step 1. create param_list
for p, g in self.params_grads:
if type(p) == Parameter and p.trainable == False:
continue
if p.name not in param_grad_set:
param_list.append(p)
param_grad_set.add(p.name)
if g.name not in param_grad_set:
grad_list.append(g)
param_grad_set.add(g.name)
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
self.sparse_var_list.append(p.name)
# step 2. Slice vars into numbers of piece with block_size
# when we slice var up into blocks, we will slice the var according to
# pserver services' count. A pserver may have two or more listening ports.
param_blocks = slice_variable(param_list,
len(self.pserver_endpoints),
self.config.min_block_size)
# step 3. Create splited param from split blocks
# origin_param_name -> [splited_param_vars]
# Todo: update _create_vars_from_blocklist
self.param_var_mapping = self._create_vars_from_blocklist(
self.origin_program, param_blocks)
# step 4. Create mapping of endpoint -> split var to create pserver side program
self.param_opt_ep_mapping = collections.OrderedDict()
[
self.param_opt_ep_mapping.update({
ep: {
"params": [],
}
}) for ep in self.pserver_endpoints
]
# step 5. Create delta vars of Geo-Sgd & record vars information
for origin_name, splited_vars in self.param_var_mapping.items():
origin_var = self.origin_program.global_block().var(origin_name)
self.vars_info[origin_name] = collections.OrderedDict()
self.vars_info[origin_name]["var_names"] = []
vars_section = self._get_splited_var_sections(splited_vars)
self.vars_info[origin_name]["sections"] = [
str(i) for i in vars_section
]
self.vars_info[origin_name]["epmap"] = []
self.vars_info[origin_name]["is_sparse"] = []
# todo: add var shape (may be unnecessary, because the recv scope already has it)
if origin_name in self.sparse_var_list:
delta_type = core.VarDesc.VarType.SELECTED_ROWS
self.vars_info[origin_name]["is_sparse"].append("True")
else:
delta_type = origin_var.type
self.vars_info[origin_name]["is_sparse"].append("False")
delta_var = self.origin_program.global_block().create_var(
name=".".join([origin_name, "delta"]),
persistable=False,
dtype=origin_var.dtype,
type=delta_type,
shape=origin_var.shape)
self.delta_vars_list.append(delta_var)
for splited_var in splited_vars:
is_slice, block_id, offset = self._get_slice_var_info(
splited_var)
self.vars_overview.add_distributed_var(
origin_var=origin_var,
slice_var=splited_var,
block_id=block_id,
offset=offset,
is_slice=is_slice,
vtype="Param")
self.split_to_origin_mapping[splited_var.name] = origin_name
if origin_name in self.sparse_var_list:
self.sparse_var_splited_list.append(splited_var.name)
self.vars_info[origin_name]["var_names"].append(
splited_var.name)
if len(splited_vars) != 1:
self.origin_program.global_block().create_var(
name=".".join([splited_var.name, "delta"]),
persistable=False,
dtype=splited_var.dtype,
type=delta_type,
shape=splited_var.shape)