From 84b0ec970b549c95c627f8a1be38590454c4b7a6 Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Mon, 15 Nov 2021 11:21:29 +0800 Subject: [PATCH] Accessor 20211112 2 (#37181) --- paddle/fluid/distributed/fleet.cc | 182 ++++++++++++-- paddle/fluid/distributed/fleet.h | 9 +- .../fluid/distributed/service/communicator.cc | 35 ++- .../fluid/distributed/service/communicator.h | 26 +- .../distributed/table/common_dense_table.cc | 222 ++++++++++++++++++ .../distributed/table/common_dense_table.h | 36 ++- 6 files changed, 461 insertions(+), 49 deletions(-) diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc index 9e2a0b35224..4a3dfc3e485 100644 --- a/paddle/fluid/distributed/fleet.cc +++ b/paddle/fluid/distributed/fleet.cc @@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) { std::vector FleetWrapper::GetClientsInfo() { VLOG(3) << "Going to get client info"; - return pserver_ptr_->get_client_info(); - return std::vector(); + auto* communicator = Communicator::GetInstance(); + std::vector res = communicator->GetClientInfo(); + return res; } void FleetWrapper::CreateClient2ClientConnection() { - VLOG(3) << "Going to create client2client connection"; - pserver_ptr_->create_client2client_connection( + VLOG(1) << "Going to create client2client connection"; + auto* communicator = Communicator::GetInstance(); + communicator->_worker_ptr->create_client2client_connection( client2client_request_timeout_ms_, client2client_connect_timeout_ms_, client2client_max_retry_); } @@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync( const std::vector& var_names, std::vector>* push_sparse_status, float scale_datanorm, int batch_size) { - auto* communicator = Communicator::GetInstance(); - PADDLE_ENFORCE_EQ( - communicator->Check(table_id), true, - platform::errors::InvalidArgument( - "can not find table: %s, please check your config", table_id)); - communicator->Send(var_names, scope); + auto place = platform::CPUPlace(); + std::vector regions; + for (auto& t : var_names) { + Variable* var = scope.FindVar(t); + CHECK(var != nullptr) << "var[" << t << "] not found"; + LoDTensor* tensor = var->GetMutable(); + float* g = tensor->mutable_data(place); + paddle::distributed::Region reg(g, tensor->numel()); + regions.emplace_back(std::move(reg)); + VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id " + << table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] " + << g[tensor->numel() - 1]; + } + + auto* communicator = + dynamic_cast(Communicator::GetInstance()); + auto push_status = communicator->_worker_ptr->push_dense( + regions.data(), regions.size(), table_id); + + communicator->PushDensePostProcessing(); } void FleetWrapper::PushSparseVarsAsync( @@ -417,10 +433,140 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync( return; } -void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) { +void FleetWrapper::PushSparseFromTensorAsync( + const uint64_t table_id, int fea_dim, uint64_t padding_id, + platform::Place place, std::vector* inputs, + const LoDTensor* shows, const LoDTensor* clks, + std::vector* outputs) { + int batch_size = -1; + for (auto* input : *inputs) { + int cur_batch_size = + input->lod().size() ? 
input->lod()[0].size() - 1 : input->dims()[0];
+    if (batch_size == -1) {
+      batch_size = cur_batch_size;
+    } else {
+      CHECK(batch_size == cur_batch_size);  // NOLINT
+    }
+  }
+  CHECK(batch_size > 0);  // NOLINT
+
+  int show_size =
+      shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
+  CHECK(show_size == batch_size || show_size == 1);
+  int clk_size =
+      clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
+  CHECK(clk_size == batch_size || clk_size == 1);
+
+  std::vector<float> g;
+  for (framework::LoDTensor* g_tensor : *outputs) {
+    float* g_ori = g_tensor->data<float>();
+    // no cvm
+    if (true) {  // TODO(zhaocaibei123): add config
+                 // scale_sparse_gradient_with_batch_size_
+      Eigen::Map<
+          Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+          g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
+      g_mat.rightCols(fea_dim) *= batch_size;
+    }
+
+    size_t origin = g.size();
+    size_t add = g_tensor->numel();
+    g.resize(origin + add);
+
+    memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
+  }
+
+  std::vector<uint64_t> push_keys;
+  push_keys.reserve(MAX_FEASIGN_NUM / 100);
+  std::vector<std::vector<float>> push_values;
+  push_values.reserve(MAX_FEASIGN_NUM / 100);
+  size_t output_len = 0;
+  size_t input_idx = 0;
+
+  VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;
+
+  // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
+  // const long int* show_tensor = shows->data<int64_t>();
+  // const long int* clk_tensor = clks->data<int64_t>();
+  const int64_t* show_tensor = shows->data<int64_t>();
+  const int64_t* clk_tensor = clks->data<int64_t>();
+
+  for (size_t index = 0; index < inputs->size(); ++index) {
+    const framework::LoDTensor* tensor = inputs->at(index);
+    const int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+
+    if (tensor->lod().size() > 0) {
+      for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
+        for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
+             ++j, output_len += fea_dim) {
+          uint64_t real_id = static_cast<uint64_t>(ids[j]);
+          if (real_id == padding_id) {
+            continue;
+          }
+          push_keys.emplace_back(real_id);
+          push_values.emplace_back(fea_dim + 3);
+          // slot show clk grad... consistent with CtrCommonPushValue defined in
+          // ctr_accessor.h
+          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+          push_values.back()[1] =
+              (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+          push_values.back()[2] =
+              (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+
+          float* data = push_values.back().data() + 3;
+
+          memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
+
+          ++input_idx;
+        }
+      }
+    } else {
+      for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
+        uint64_t real_id = static_cast<uint64_t>(ids[i]);
+        if (real_id == padding_id) {
+          continue;
+        }
+        push_keys.emplace_back(real_id);
+        push_values.emplace_back(fea_dim + 3);
+        // slot show clk grad... consistent with CtrCommonPushValue defined in
+        // ctr_accessor.h
+        push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+        push_values.back()[1] =
+            (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
+        push_values.back()[2] =
+            (i >= clk_size ?
0 : static_cast(clk_tensor[i])); + + float* data = push_values.back().data() + 3; + + memcpy(data, g.data() + output_len, sizeof(float) * fea_dim); + + ++input_idx; + } + } + } + VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size(); + CHECK(output_len == g.size()); + + std::vector push_g_vec(input_idx, nullptr); + + for (auto i = 0u; i < push_keys.size(); ++i) { + push_g_vec[i] = push_values.at(i).data(); + } + auto* communicator = Communicator::GetInstance(); - auto ret = communicator->_worker_ptr->load(path, mode); - // auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode)); + PADDLE_ENFORCE_EQ( + communicator->Check(table_id), true, + platform::errors::InvalidArgument( + "can not find table: %s, please check your config", table_id)); + auto status = communicator->_worker_ptr->push_sparse( + table_id, push_keys.data(), (const float**)push_g_vec.data(), + push_keys.size()); +} + +void FleetWrapper::LoadModel(const std::string& path, const int mode) { + auto* communicator = Communicator::GetInstance(); + auto ret = communicator->_worker_ptr->load(path, std::to_string(mode)); ret.wait(); if (ret.get() != 0) { LOG(ERROR) << "load model from path:" << path << " failed"; @@ -562,16 +708,16 @@ void FleetWrapper::ClientFlush() { int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler) { - VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; - VLOG(3) << "pserver_ptr_=" << pserver_ptr_; - VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr; - return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type, + VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler"; + auto* communicator = Communicator::GetInstance(); + return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type, handler); } std::future FleetWrapper::SendClientToClientMsg( int msg_type, int to_client_id, const std::string& msg) { - return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type, + auto* communicator = Communicator::GetInstance(); + return communicator->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg); } diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h index 1b2bde85de0..6d9ce01535e 100644 --- a/paddle/fluid/distributed/fleet.h +++ b/paddle/fluid/distributed/fleet.h @@ -157,7 +157,12 @@ class FleetWrapper { const std::vector& input_names, std::vector* inputs, // NOLINT std::vector* outputs); // NOLINT - + void PushSparseFromTensorAsync(const uint64_t table_id, int fea_dim, + uint64_t padding_id, platform::Place place, + std::vector* inputs, + const LoDTensor* shows, + const LoDTensor* clicks, + std::vector* outputs); // Push sparse variables to server in Async mode // Param: scope, table_id, fea_keys, sparse_grad_names // Param: push_values, push_sparse_status @@ -200,7 +205,7 @@ class FleetWrapper { void PrintTableStat(const uint64_t table_id); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff - void LoadModel(const std::string& path, const std::string& mode); + void LoadModel(const std::string& path, const int mode); // mode = 0, load all feature // mode = 1, load delta feature, which means load diff void LoadModelOneTable(const uint64_t table_id, const std::string& path, diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc index 00fae6e276e..f51ffbcf811 100644 --- a/paddle/fluid/distributed/service/communicator.cc +++ 
b/paddle/fluid/distributed/service/communicator.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/distributed/service/communicator.h" + #include #include "gflags/gflags.h" @@ -87,7 +88,7 @@ void Communicator::InitBrpcClient( servers_ = host_sign_list.size(); _ps_env = paddle::distributed::PaddlePSEnvironment(); _ps_env.set_ps_servers(&host_sign_list, servers_); - _worker_ptr = std::shared_ptr( + _worker_ptr = std::unique_ptr( paddle::distributed::PSClientFactory::create(_ps_param)); _worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env, trainer_id_); @@ -95,6 +96,19 @@ void Communicator::InitBrpcClient( return; } +std::vector Communicator::GetClientInfo() { + std::vector res = _ps_env.get_client_info(); + for (auto rr : res) { + VLOG(2) << "Communicator::GetClientInfo " << rr; + } + return res; +} + +int Communicator::SetClients(std::vector &host_sign_list) { + int node = host_sign_list.size(); + return _ps_env.set_ps_clients(host_sign_list.data(), node); +} + void Communicator::RpcRecvDense(const std::vector &varnames, int table_id, Scope *scope) { platform::RecordEvent record_event("Communicator->RpcRecvDense"); @@ -130,6 +144,11 @@ void Communicator::RpcRecvDense(const std::vector &varnames, LoDTensor *tensor = var->GetMutable(); VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? " << platform::is_gpu_place(tensor->place()); + + float *temp_recv_data = tensor->mutable_data(platform::CPUPlace()); + VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id " + << table_id << " Temp_data[0] " << temp_recv_data[0] + << " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1]; if (platform::is_gpu_place(tensor->place())) { #ifdef PADDLE_WITH_CUDA LoDTensor *temp_tensor = @@ -519,6 +538,7 @@ void AsyncCommunicator::SendByCommunicator() { MergeVars(var_name, vars[i], send_scope_.get(), 1); } } + if (ctx.is_tensor_table) { SendGlobalStep(ctx, merged_var_num, send_scope_.get()); } else if (ctx.is_sparse) { @@ -547,6 +567,13 @@ void AsyncCommunicator::SendByCommunicator() { return; } +void AsyncCommunicator::PushDensePostProcessing() { + if (independent_recv_) { + grad_num_.fetch_add(1, std::memory_order_relaxed); + } + return; +} + void AsyncCommunicator::MainThread() { VLOG(3) << "AsyncCommunicator MainThread start and wait"; @@ -627,13 +654,13 @@ void AsyncCommunicator::Start() { } void AsyncCommunicator::Stop() { - VLOG(1) << "Communicator stop"; - _worker_ptr->finalize_worker(); - VLOG(0) << "Communicator finalize_worker done"; + VLOG(1) << "Communicator stop begin"; running_ = false; if (!communicator_) { VLOG(0) << "Communicator is not inited, do nothing"; } else { + _worker_ptr->finalize_worker(); + VLOG(1) << "client finalize_worker done"; if (recv_thread_) { VLOG(1) << "stop recv thread"; recv_thread_->join(); diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h index 01ec3c617d5..8714918dc8e 100644 --- a/paddle/fluid/distributed/service/communicator.h +++ b/paddle/fluid/distributed/service/communicator.h @@ -245,6 +245,11 @@ class Communicator { virtual void InitBrpcClient(const std::string &dist_desc, const std::vector &host_sign_list); + + virtual std::vector GetClientInfo(); + + virtual int SetClients(std::vector &host_sign_list); // NOLINT + // 1. 
recv dense param virtual void RpcRecvDense(const std::vector &varnames, int table_id, Scope *scope); @@ -271,6 +276,7 @@ class Communicator { virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx); + // note: only for pull dense param first before training virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx); virtual void Start() = 0; @@ -296,6 +302,13 @@ class Communicator { rets.wait(); } + virtual void CreateC2CConnection(int pserver_timeout_ms, + int pserver_connect_timeout_ms, + int max_retry) { + _worker_ptr->create_client2client_connection( + pserver_timeout_ms, pserver_connect_timeout_ms, max_retry); + } + virtual void BarrierTriggerDecrement() {} virtual void BarrierTriggerReset(int init_counter) {} @@ -342,13 +355,13 @@ class Communicator { PSClient *GetPsClient() { return _worker_ptr.get(); } - std::shared_ptr GetPsClientPtr() { - return _worker_ptr; + std::unique_ptr GetPsClientPtr() { + return std::move(_worker_ptr); } RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; } - std::shared_ptr _worker_ptr; // pointer to worker + std::unique_ptr _worker_ptr; // pointer to worker protected: bool running_ = false; @@ -434,6 +447,8 @@ class AsyncCommunicator : public Communicator { virtual void BarrierWeakUp() {} + void PushDensePostProcessing(); + protected: std::unordered_map>>> @@ -542,14 +557,15 @@ class GeoCommunicator : public AsyncCommunicator { Scope *recv_scope) override; void InitParams(const RecvCtxMap &recv_varname_to_ctx) override; - void InitDense(std::vector &varnames, int table_id); + void InitDense(std::vector &varnames, int table_id); // NOLINT void InitSparse(const std::string &var_name, int table_id); void SendDense(const CommContext &send_ctx); void RecvDense(const CommContext &send_ctx); std::vector MergeSparseIds(const std::string &varname); - void SendSparse(const std::string &varname, std::vector &sparse_ids, + void SendSparse(const std::string &varname, + std::vector &sparse_ids, // NOLINT int table_id, int ep_idx); void RecvSparse(const std::string &varname, int table_id, int ep_idx); diff --git a/paddle/fluid/distributed/table/common_dense_table.cc b/paddle/fluid/distributed/table/common_dense_table.cc index 8d8b43b3740..b34b143a3ce 100644 --- a/paddle/fluid/distributed/table/common_dense_table.cc +++ b/paddle/fluid/distributed/table/common_dense_table.cc @@ -19,6 +19,8 @@ namespace paddle { namespace distributed { +int FLAGS_pslib_table_save_max_retry_dense = 3; + void CommonDenseTable::create_initializer(const std::string& attr, const std::string& name) { auto slices = string::split_string(attr, "&"); @@ -56,6 +58,7 @@ int32_t CommonDenseTable::initialize_value() { auto common = _config.common(); int size = static_cast(common.params().size()); values_.resize(size); + total_dim_ = 0; for (int x = 0; x < size; ++x) { auto& varname = common.params()[x]; auto& dim = common.dims()[x]; @@ -63,7 +66,9 @@ int32_t CommonDenseTable::initialize_value() { param_dim_ = dim; param_idx_ = x; } + auto& initializer = common.initializers()[x]; + total_dim_ += dim; create_initializer(initializer, varname); values_[x].resize(dim); @@ -74,6 +79,22 @@ int32_t CommonDenseTable::initialize_value() { } } + fixed_len_params_dim_ = 0; + for (int x = 0; x < size; ++x) { + auto& dim = common.dims()[x]; + if (dim != param_dim_) { + fixed_len_params_dim_ += dim; + } else { + param_col_ids_.push_back(x); + } + } + if (_config.common().name() == "adam_d2sum") { + param_col_ids_.insert(param_col_ids_.begin() + 1, -1); + } + + VLOG(1) << 
"CommonDenseTable::initialize_value total dim: " << total_dim_ + << " fixed_len_params_dim: " << fixed_len_params_dim_; + pull_reservoir_ = ReservoirValue(param_dim_); return 0; } @@ -89,6 +110,9 @@ int32_t CommonDenseTable::initialize_optimizer() { } else if (name == "adam") { optimizer_ = std::make_shared(common, &values_); optimizer_->set_global_lr(_global_lr); + } else if (name == "adam_d2sum") { + optimizer_ = std::make_shared(common, &values_); + // optimizer_->set_global_lr(_global_lr); //no use } else if (name == "sum") { optimizer_ = std::make_shared(common, &values_); } else { @@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) { for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } + VLOG(2) << "debug CommonDenseTable::_push_dense done"; + return 0; +} + +int32_t CommonDenseTable::load(const std::string& path, + const std::string& param) { + if (param_dim_ <= 0) { + return 0; + } + std::string table_path = table_dir(path); + auto file_list = _afs_client.list(table_path); + std::sort(file_list.begin(), file_list.end()); + for (auto ff : file_list) { + VLOG(1) << "load dense table file list: " << ff; + } + size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1; + // param_dim_ in last node != _config.accesor().fea_dim() / _shard_num + 1 + size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1; + size_t start_dim_idx = dim_num_per_shard * _shard_idx; + size_t start_file_idx = start_dim_idx / dim_num_per_file; + size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file; + end_file_idx = + end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1; + VLOG(2) << "load dense table start_file_idx: " << start_file_idx + << " end_file_idx: " << end_file_idx; + + int load_param = atoi(param.c_str()); + FsChannelConfig channel_config; + + channel_config.converter = _value_accesor->converter(load_param).converter; + channel_config.deconverter = + _value_accesor->converter(load_param).deconverter; + bool is_read_failed = false; + int err_no = 0; + int retry_num = 0; + do { + is_read_failed = false; + try { + size_t dim_idx = 0; + float data_buffer[5]; + float* data_buff_ptr = data_buffer; + std::string line_data; + int size = static_cast(values_.size()); + auto common = _config.common(); + + for (int i = start_file_idx; i < end_file_idx + 1; ++i) { + channel_config.path = file_list[i]; + err_no = 0; + auto read_channel = _afs_client.open_r(channel_config, 0, &err_no); + size_t file_start_idx = start_dim_idx - i * dim_num_per_file; + + // not all file contains param and the length of last file containing + // param may not equal to others + size_t file_dim_idx = 0; + for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) { + if (read_channel->read_line(line_data) != 0) { + break; + } + if (dim_idx >= param_dim_) { + break; + } + if (file_dim_idx < file_start_idx) { + continue; + } + auto str_len = + paddle::string::str_to_float(line_data.data(), data_buff_ptr); + CHECK(str_len == param_col_ids_.size()) + << "expect " << param_col_ids_.size() << " float, but got " + << str_len; + for (size_t col_idx = 0; col_idx < str_len; ++col_idx) { + if (param_col_ids_[col_idx] < 0) { + continue; + } + values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx]; + VLOG(2) << "CommonDenseTable::load param x: " + << param_col_ids_[col_idx] << " y: " << dim_idx + << " value: " << values_[param_col_ids_[col_idx]][dim_idx] + << " line " << file_dim_idx; + } + ++dim_idx; 
+ } + read_channel->close(); + VLOG(1) << "DownpourDenseTable load done " << channel_config.path + << " file_start_idx: " << file_start_idx + << " dim_idx: " << dim_idx; + if (err_no == -1) { + if (retry_num > FLAGS_pslib_table_save_max_retry_dense) { + LOG(ERROR) << "DownpourDenseTable load failed reach max limit!"; + exit(-1); + } + ++retry_num; + --i; + LOG(ERROR) + << "DownpourDenseTable load failed after read , retry it! path:" + << channel_config.path << ", retry_num=" << retry_num; + continue; + } + retry_num = 0; + start_dim_idx += file_dim_idx - file_start_idx; + LOG(INFO) << "DownpourDenseTable load success, path:" + << channel_config.path; + } + } catch (...) { + is_read_failed = true; + LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:" + << channel_config.path; + } + } while (is_read_failed); return 0; } +int32_t CommonDenseTable::save(const std::string& path, + const std::string& param) { + int save_param = atoi(param.c_str()); + uint32_t feasign_size; + VLOG(0) << "CommonDenseTable::save path " << path; + + FsChannelConfig channel_config; + if (_config.compress_in_save()) { + channel_config.path = paddle::string::format_string( + "%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx); + } else { + channel_config.path = paddle::string::format_string( + "%s/part-%03d", table_dir(path).c_str(), _shard_idx); + } + _afs_client.remove(channel_config.path); + channel_config.converter = _value_accesor->converter(save_param).converter; + channel_config.deconverter = + _value_accesor->converter(save_param).deconverter; + + bool is_write_failed = false; + std::vector> result_buffer_param( + param_dim_, std::vector()); + std::vector result_buffer_fixed_len; + result_buffer_fixed_len.reserve(fixed_len_params_dim_); + + auto common = _config.common(); + int size = static_cast(common.params().size()); + std::ostringstream os; + for (int x = 0; x < size; ++x) { + auto& varname = common.params()[x]; + auto& dim = common.dims()[x]; + VLOG(0) << "CommonDenseTable::save dim " << x << " size: " << dim; + for (int y = 0; y < dim; ++y) { + os.clear(); + os.str(""); + os << values_[x][y]; + if (dim == param_dim_) { + result_buffer_param[y].emplace_back(std::move(os.str())); + } else { + result_buffer_fixed_len.emplace_back(std::move(os.str())); + } + } + } + + int retry_num = 0; + int err_no = 0; + do { + err_no = 0; + is_write_failed = false; + feasign_size = 0; + // 40M + auto write_channel = + _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); + for (auto& t : result_buffer_param) { + if (_config.common().name() == "adam_d2sum") { + t.insert(t.begin() + 1, "0"); // avg_w + } + if (0 != + write_channel->write_line(paddle::string::join_strings(t, ' '))) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "DownpourDenseTable save failed, retry it! " + "path:" + << channel_config.path << ", retry_num=" << retry_num; + break; + } + } + + ++feasign_size; + write_channel->close(); + if (err_no == -1) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! 
" + << "path:" << channel_config.path + << ", retry_num=" << retry_num; + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + if (retry_num > + paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) { + LOG(ERROR) << "DownpourDenseTable save failed reach max limit!"; + exit(-1); + } + } while (is_write_failed); + LOG(INFO) << "DownpourDenseTable save success, path:" << channel_config.path; + return feasign_size; +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/table/common_dense_table.h b/paddle/fluid/distributed/table/common_dense_table.h index 74366f03588..1fa0226decd 100644 --- a/paddle/fluid/distributed/table/common_dense_table.h +++ b/paddle/fluid/distributed/table/common_dense_table.h @@ -32,33 +32,26 @@ class DenseOptimizer; class CommonDenseTable : public DenseTable { public: - explicit CommonDenseTable() {} + CommonDenseTable() {} virtual ~CommonDenseTable() {} - virtual int32_t initialize() override; - virtual int32_t initialize_shard() override { return 0; } + int32_t initialize() override; + int32_t initialize_shard() override { return 0; } virtual void create_initializer(const std::string& attr, const std::string& name); virtual int32_t initialize_value(); virtual int32_t initialize_optimizer(); - virtual int32_t pull_dense(float* pull_values, size_t num) override; - virtual int32_t push_dense_param(const float* values, size_t num) override; - virtual int32_t push_dense(const float* values, size_t num) override; - virtual int32_t pour() override; - virtual int32_t set_global_lr(float* lr) override; + int32_t pull_dense(float* pull_values, size_t num) override; + int32_t push_dense_param(const float* values, size_t num) override; + int32_t push_dense(const float* values, size_t num) override; + int32_t pour() override; + int32_t set_global_lr(float* lr) override; - int32_t load(const std::string& path, const std::string& param) override { - VLOG(0) << "WARNING: dense variables will load on No.0 trainer"; - return 0; - } + int32_t load(const std::string& path, const std::string& param) override; + int32_t save(const std::string& path, const std::string& param) override; - int32_t save(const std::string& path, const std::string& param) override { - VLOG(0) << "WARNING: dense variables will save on No.0 trainer"; - return 0; - } - - virtual int32_t flush() override { return 0; } - virtual int32_t shrink(const std::string& param) override { return 0; } - virtual void clear() override { return; } + int32_t flush() override { return 0; } + int32_t shrink(const std::string& param) override { return 0; } + void clear() override { return; } protected: int32_t _push_dense(const float* values, size_t num); @@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable { ReservoirValue pull_reservoir_; std::unordered_map initializers_; std::unordered_map names_index_; + int total_dim_ = 0; + int fixed_len_params_dim_ = 0; // used for save/load + std::vector param_col_ids_; // used for save/load }; } // namespace distributed -- GitLab