Unverified commit 84b0ec97, authored by zhaocaibei123, committed by GitHub

Accessor 20211112 2 (#37181)

Parent 12339fa0
@@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
VLOG(3) << "Going to get client info";
auto* communicator = Communicator::GetInstance();
std::vector<uint64_t> res = communicator->GetClientInfo();
return res;
}

void FleetWrapper::CreateClient2ClientConnection() {
VLOG(1) << "Going to create client2client connection";
auto* communicator = Communicator::GetInstance();
communicator->_worker_ptr->create_client2client_connection(
client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
client2client_max_retry_);
}

@@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync(
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status, float scale_datanorm,
int batch_size) {
auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
<< table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
<< g[tensor->numel() - 1];
}
auto* communicator =
dynamic_cast<AsyncCommunicator*>(Communicator::GetInstance());
auto push_status = communicator->_worker_ptr->push_dense(
regions.data(), regions.size(), table_id);
communicator->PushDensePostProcessing();
}
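For orientation, a minimal sketch of the Region-style buffer view assumed by the loop above: each entry just wraps a tensor's float buffer and its length, so push_dense can read all variables as one dense blob. DemoRegion and BuildRegions are illustrative stand-ins, not the real paddle::distributed::Region API.

#include <cstddef>
#include <vector>

// Simplified stand-in for paddle::distributed::Region (illustrative only):
// a non-owning view over a tensor's float buffer.
struct DemoRegion {
  float* data;
  size_t size;
};

// Mirrors the loop above: wrap each tensor buffer without copying it.
std::vector<DemoRegion> BuildRegions(const std::vector<float*>& buffers,
                                     const std::vector<size_t>& lengths) {
  std::vector<DemoRegion> regions;
  regions.reserve(buffers.size());
  for (size_t i = 0; i < buffers.size(); ++i) {
    regions.push_back({buffers[i], lengths[i]});
  }
  return regions;
}
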
void FleetWrapper::PushSparseVarsAsync(

@@ -417,10 +433,140 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
return;
}

void FleetWrapper::PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, std::vector<const LoDTensor*>* inputs,
const LoDTensor* shows, const LoDTensor* clks,
std::vector<LoDTensor*>* outputs) {
int batch_size = -1;
for (auto* input : *inputs) {
int cur_batch_size =
input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
CHECK(batch_size == cur_batch_size); // NOLINT
}
}
CHECK(batch_size > 0); // NOLINT
int show_size =
shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
CHECK(show_size == batch_size || show_size == 1);
int clk_size =
clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
CHECK(clk_size == batch_size || clk_size == 1);
std::vector<float> g;
for (framework::LoDTensor* g_tensor : *outputs) {
float* g_ori = g_tensor->data<float>();
// no cvm
if (true) { // TODO(zhaocaibei123): add config
// scale_sparse_gradient_with_batch_size_
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
g_mat.rightCols(fea_dim) *= batch_size;
}
size_t origin = g.size();
size_t add = g_tensor->numel();
g.resize(origin + add);
memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
}
std::vector<uint64_t> push_keys;
push_keys.reserve(MAX_FEASIGN_NUM / 100);
std::vector<std::vector<float>> push_values;
push_values.reserve(MAX_FEASIGN_NUM / 100);
size_t output_len = 0;
size_t input_idx = 0;
VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;
// TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
// const long int* show_tensor = shows->data<int64_t>();
// const long int* clk_tensor = clks->data<int64_t>();
const int64_t* show_tensor = shows->data<int64_t>();
const int64_t* clk_tensor = clks->data<int64_t>();
for (size_t index = 0; index < inputs->size(); ++index) {
const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
if (tensor->lod().size() > 0) {
for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
++j, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[j]);
if (real_id == padding_id) {
continue;
}
push_keys.emplace_back(real_id);
push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
++input_idx;
}
}
} else {
for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) {
continue;
}
push_keys.emplace_back(real_id);
push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
++input_idx;
}
}
}
VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
CHECK(output_len == g.size());
std::vector<float*> push_g_vec(input_idx, nullptr);
for (auto i = 0u; i < push_keys.size(); ++i) {
push_g_vec[i] = push_values.at(i).data();
}
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
auto status = communicator->_worker_ptr->push_sparse(
table_id, push_keys.data(), (const float**)push_g_vec.data(),
push_keys.size());
}
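As the comments above note, each per-key buffer follows the CtrCommonPushValue-style layout: index 0 is the slot, 1 the show count, 2 the click count, and the remaining fea_dim floats are the gradient. A self-contained sketch of building one such record (the helper name is made up for illustration):

#include <cstring>
#include <vector>

// Illustrative helper (hypothetical name): one push record per feasign,
// laid out as [slot, show, click, grad...], matching the code above.
std::vector<float> MakePushRecord(float slot, float show, float click,
                                  const float* grad, int fea_dim) {
  std::vector<float> rec(fea_dim + 3);
  rec[0] = slot;   // slot id (hard-coded to 2 in the diff, marked TODO)
  rec[1] = show;   // show count for this instance
  rec[2] = click;  // click count for this instance
  std::memcpy(rec.data() + 3, grad, sizeof(float) * fea_dim);
  return rec;
}
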
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
@@ -562,16 +708,16 @@ void FleetWrapper::ClientFlush() {
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
auto* communicator = Communicator::GetInstance();
return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
}

std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
auto* communicator = Communicator::GetInstance();
return communicator->_worker_ptr->send_client2client_msg(msg_type,
to_client_id, msg);
}
...
@@ -157,7 +157,12 @@ class FleetWrapper {
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs,   // NOLINT
std::vector<const LoDTensor*>* outputs);  // NOLINT
void PushSparseFromTensorAsync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs,
const LoDTensor* shows,
const LoDTensor* clicks,
std::vector<LoDTensor*>* outputs);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status

@@ -200,7 +205,7 @@ class FleetWrapper {
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/service/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"

@@ -87,7 +88,7 @@ void Communicator::InitBrpcClient(
servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
_worker_ptr = std::unique_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(_ps_param));
_worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env,
trainer_id_);

@@ -95,6 +96,19 @@ void Communicator::InitBrpcClient(
return;
}
std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info();
for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr;
}
return res;
}
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node);
}
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvDense");

@@ -130,6 +144,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());
float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_recv_data[0]
<< " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
LoDTensor *temp_tensor =

@@ -519,6 +538,7 @@ void AsyncCommunicator::SendByCommunicator() {
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
}

if (ctx.is_tensor_table) {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {

@@ -547,6 +567,13 @@ void AsyncCommunicator::SendByCommunicator() {
return;
}
void AsyncCommunicator::PushDensePostProcessing() {
if (independent_recv_) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
return;
}
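PushDensePostProcessing above only bumps an atomic gradient counter when the independent receive thread is enabled. A minimal sketch of that counter pattern, with illustrative names, assuming the recv side polls the counter to decide when to pull fresh dense parameters:

#include <atomic>

std::atomic<int> grad_num{0};  // stands in for grad_num_ above

// Called after a dense push completes (compare PushDensePostProcessing).
void OnDensePushed(bool independent_recv) {
  if (independent_recv) {
    grad_num.fetch_add(1, std::memory_order_relaxed);
  }
}

// Hypothetical consumer on the recv thread: pull params once enough
// gradients have been pushed since the last pull.
bool ShouldRecv(int min_grads) {
  return grad_num.load(std::memory_order_relaxed) >= min_grads;
}
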
void AsyncCommunicator::MainThread() {
VLOG(3) << "AsyncCommunicator MainThread start and wait";

@@ -627,13 +654,13 @@ void AsyncCommunicator::Start() {
}

void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop begin";
_worker_ptr->finalize_worker();
VLOG(0) << "Communicator finalize_worker done";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
if (recv_thread_) {
VLOG(1) << "stop recv thread";
recv_thread_->join();
...
@@ -245,6 +245,11 @@ class Communicator {
virtual void InitBrpcClient(const std::string &dist_desc,
const std::vector<std::string> &host_sign_list);

virtual std::vector<uint64_t> GetClientInfo();
virtual int SetClients(std::vector<uint64_t> &host_sign_list);  // NOLINT

// 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope);

@@ -271,6 +276,7 @@ class Communicator {
virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);

// note: only for pull dense param first before training
virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);

virtual void Start() = 0;

@@ -296,6 +302,13 @@ class Communicator {
rets.wait();
}
virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
_worker_ptr->create_client2client_connection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}

@@ -342,13 +355,13 @@ class Communicator {
PSClient *GetPsClient() { return _worker_ptr.get(); }
std::unique_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return std::move(_worker_ptr);
}
RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }
std::unique_ptr<PSClient> _worker_ptr;  // pointer to worker

protected:
bool running_ = false;

@@ -434,6 +447,8 @@ class AsyncCommunicator : public Communicator {
virtual void BarrierWeakUp() {}

void PushDensePostProcessing();

protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>

@@ -542,14 +557,15 @@ class GeoCommunicator : public AsyncCommunicator {
Scope *recv_scope) override;
void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
void InitDense(std::vector<std::string> &varnames, int table_id);  // NOLINT
void InitSparse(const std::string &var_name, int table_id);
void SendDense(const CommContext &send_ctx);
void RecvDense(const CommContext &send_ctx);
std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids,  // NOLINT
int table_id, int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx);
...
@@ -19,6 +19,8 @@
namespace paddle {
namespace distributed {

int FLAGS_pslib_table_save_max_retry_dense = 3;

void CommonDenseTable::create_initializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");

@@ -56,6 +58,7 @@ int32_t CommonDenseTable::initialize_value() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
values_.resize(size);
total_dim_ = 0;
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];

@@ -63,7 +66,9 @@ int32_t CommonDenseTable::initialize_value() {
param_dim_ = dim;
param_idx_ = x;
}
auto& initializer = common.initializers()[x];
total_dim_ += dim;
create_initializer(initializer, varname);
values_[x].resize(dim);

@@ -74,6 +79,22 @@ int32_t CommonDenseTable::initialize_value() {
}
}
fixed_len_params_dim_ = 0;
for (int x = 0; x < size; ++x) {
auto& dim = common.dims()[x];
if (dim != param_dim_) {
fixed_len_params_dim_ += dim;
} else {
param_col_ids_.push_back(x);
}
}
if (_config.common().name() == "adam_d2sum") {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
}
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0;
}
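To make the new bookkeeping concrete, here is a small worked example with assumed values (three dense parameters with dims {8, 8, 1}, so param_dim_ is 8): columns whose dim equals param_dim_ land in param_col_ids_ and are saved/loaded row by row, the scalar column is counted into fixed_len_params_dim_, and for adam_d2sum a -1 placeholder is spliced in at position 1 for the avg_w column, mirroring the code above.

#include <iostream>
#include <vector>

int main() {
  // Assumed toy dims for illustration only; not taken from a real config.
  std::vector<int> dims = {8, 8, 1};
  int param_dim = 8;
  int fixed_len_params_dim = 0;
  std::vector<int> param_col_ids;
  for (int x = 0; x < static_cast<int>(dims.size()); ++x) {
    if (dims[x] != param_dim) {
      fixed_len_params_dim += dims[x];  // e.g. a scalar learning rate
    } else {
      param_col_ids.push_back(x);       // saved/loaded per dense row
    }
  }
  bool adam_d2sum = true;               // assumed optimizer for illustration
  if (adam_d2sum) {
    // Placeholder column for avg_w, as in the diff above.
    param_col_ids.insert(param_col_ids.begin() + 1, -1);
  }
  std::cout << "fixed_len_params_dim=" << fixed_len_params_dim << "\n";
  for (int id : param_col_ids) std::cout << "col " << id << "\n";
  // Prints fixed_len_params_dim=1 and columns 0, -1, 1.
  return 0;
}
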
@@ -89,6 +110,9 @@ int32_t CommonDenseTable::initialize_optimizer() {
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); //no use
} else if (name == "sum") { } else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_); optimizer_ = std::make_shared<DSUM>(common, &values_);
} else { } else {
...@@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { ...@@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait(); tasks[shard_id].wait();
} }
VLOG(2) << "debug CommonDenseTable::_push_dense done";
return 0;
}
int32_t CommonDenseTable::load(const std::string& path,
const std::string& param) {
if (param_dim_ <= 0) {
return 0;
}
std::string table_path = table_dir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) {
VLOG(1) << "load dense table file list: " << ff;
}
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ on the last node may differ from _config.accessor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
end_file_idx =
end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1;
VLOG(2) << "load dense table start_file_idx: " << start_file_idx
<< " end_file_idx: " << end_file_idx;
int load_param = atoi(param.c_str());
FsChannelConfig channel_config;
channel_config.converter = _value_accesor->converter(load_param).converter;
channel_config.deconverter =
_value_accesor->converter(load_param).deconverter;
bool is_read_failed = false;
int err_no = 0;
int retry_num = 0;
do {
is_read_failed = false;
try {
size_t dim_idx = 0;
float data_buffer[5];
float* data_buff_ptr = data_buffer;
std::string line_data;
int size = static_cast<int>(values_.size());
auto common = _config.common();
for (int i = start_file_idx; i < end_file_idx + 1; ++i) {
channel_config.path = file_list[i];
err_no = 0;
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
size_t file_start_idx = start_dim_idx - i * dim_num_per_file;
// not every file contains this shard's params, and the last file that
// contains them may hold fewer rows than the others
size_t file_dim_idx = 0;
for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) {
if (read_channel->read_line(line_data) != 0) {
break;
}
if (dim_idx >= param_dim_) {
break;
}
if (file_dim_idx < file_start_idx) {
continue;
}
auto str_len =
paddle::string::str_to_float(line_data.data(), data_buff_ptr);
CHECK(str_len == param_col_ids_.size())
<< "expect " << param_col_ids_.size() << " float, but got "
<< str_len;
for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
if (param_col_ids_[col_idx] < 0) {
continue;
}
values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx];
VLOG(2) << "CommonDenseTable::load param x: "
<< param_col_ids_[col_idx] << " y: " << dim_idx
<< " value: " << values_[param_col_ids_[col_idx]][dim_idx]
<< " line " << file_dim_idx;
}
++dim_idx;
}
read_channel->close();
VLOG(1) << "DownpourDenseTable load done " << channel_config.path
<< " file_start_idx: " << file_start_idx
<< " dim_idx: " << dim_idx;
if (err_no == -1) {
if (retry_num > FLAGS_pslib_table_save_max_retry_dense) {
LOG(ERROR) << "DownpourDenseTable load failed reach max limit!";
exit(-1);
}
++retry_num;
--i;
LOG(ERROR)
<< "DownpourDenseTable load failed after read , retry it! path:"
<< channel_config.path << ", retry_num=" << retry_num;
continue;
}
retry_num = 0;
start_dim_idx += file_dim_idx - file_start_idx;
LOG(INFO) << "DownpourDenseTable load success, path:"
<< channel_config.path;
}
} catch (...) {
is_read_failed = true;
LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:"
<< channel_config.path;
}
} while (is_read_failed);
return 0;
}
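A hedged numeric walk-through of the file/shard index arithmetic in load above, with made-up sizes (a 1000-dim dense table saved as four part files, served by two shards); it shows which files the second shard opens and how many leading rows of the first file it skips:

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed illustrative sizes, not taken from a real deployment.
  size_t fea_dim = 1000;   // _config.accessor().fea_dim()
  size_t num_files = 4;    // file_list.size()
  size_t shard_num = 2;    // _shard_num
  size_t shard_idx = 1;    // _shard_idx (this server)
  size_t param_dim = 499;  // rows owned by the last shard

  size_t dim_num_per_file = fea_dim / num_files + 1;                     // 251
  size_t dim_num_per_shard = fea_dim / shard_num + 1;                    // 501
  size_t start_dim_idx = dim_num_per_shard * shard_idx;                  // 501
  size_t start_file_idx = start_dim_idx / dim_num_per_file;              // 1
  size_t end_file_idx = (start_dim_idx + param_dim) / dim_num_per_file;  // 3
  end_file_idx = std::min(end_file_idx, num_files - 1);                  // 3
  size_t file_start_idx = start_dim_idx - start_file_idx * dim_num_per_file;  // 250

  std::printf("files %zu..%zu, skip first %zu rows of the first file\n",
              start_file_idx, end_file_idx, file_start_idx);
  return 0;
}
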
int32_t CommonDenseTable::save(const std::string& path,
const std::string& param) {
int save_param = atoi(param.c_str());
uint32_t feasign_size;
VLOG(0) << "CommonDenseTable::save path " << path;
FsChannelConfig channel_config;
if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
} else {
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx);
}
_afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->converter(save_param).converter;
channel_config.deconverter =
_value_accesor->converter(save_param).deconverter;
bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param(
param_dim_, std::vector<std::string>());
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
std::ostringstream os;
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
VLOG(0) << "CommonDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear();
os.str("");
os << values_[x][y];
if (dim == param_dim_) {
result_buffer_param[y].emplace_back(std::move(os.str()));
} else {
result_buffer_fixed_len.emplace_back(std::move(os.str()));
}
}
}
int retry_num = 0;
int err_no = 0;
do {
err_no = 0;
is_write_failed = false;
feasign_size = 0;
// 40M
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
"path:"
<< channel_config.path << ", retry_num=" << retry_num;
break;
}
}
++feasign_size;
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! "
<< "path:" << channel_config.path
<< ", retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num >
paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) {
LOG(ERROR) << "DownpourDenseTable save failed reach max limit!";
exit(-1);
}
} while (is_write_failed);
LOG(INFO) << "DownpourDenseTable save success, path:" << channel_config.path;
return feasign_size;
}
} // namespace distributed
} // namespace paddle
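For reference, save above emits one whitespace-joined text line per dense row, with the columns ordered by param_col_ids_ and, for adam_d2sum, a literal "0" spliced in at position 1 as the avg_w placeholder. A tiny sketch with made-up values:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // One dense row: assumed columns {w, moment} before the avg_w placeholder.
  std::vector<std::string> row = {"0.125", "0.003"};
  bool adam_d2sum = true;              // assumed optimizer for illustration
  if (adam_d2sum) {
    row.insert(row.begin() + 1, "0");  // avg_w placeholder, as in the diff
  }
  std::string line;
  for (size_t i = 0; i < row.size(); ++i) {
    if (i) line += ' ';
    line += row[i];
  }
  std::cout << line << "\n";           // prints: 0.125 0 0.003
  return 0;
}
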
@@ -32,33 +32,26 @@ class DenseOptimizer;
class CommonDenseTable : public DenseTable {
public:
CommonDenseTable() {}
virtual ~CommonDenseTable() {}
int32_t initialize() override;
int32_t initialize_shard() override { return 0; }
virtual void create_initializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override;
int32_t pour() override;
int32_t set_global_lr(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override;
int32_t flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; }
void clear() override { return; }

protected:
int32_t _push_dense(const float* values, size_t num);
@@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable {
ReservoirValue<float> pull_reservoir_;
std::unordered_map<std::string, Initializer*> initializers_;
std::unordered_map<std::string, int> names_index_;
int total_dim_ = 0;
int fixed_len_params_dim_ = 0;  // used for save/load
std::vector<int> param_col_ids_;  // used for save/load
};

} // namespace distributed
...