Unverified commit 84b0ec97, authored by zhaocaibei123, committed by GitHub

Accessor 20211112 2 (#37181)

Parent 12339fa0
......@@ -135,13 +135,15 @@ uint64_t FleetWrapper::RunServer(const std::string& ip, uint32_t port) {
std::vector<uint64_t> FleetWrapper::GetClientsInfo() {
VLOG(3) << "Going to get client info";
return pserver_ptr_->get_client_info();
return std::vector<uint64_t>();
auto* communicator = Communicator::GetInstance();
std::vector<uint64_t> res = communicator->GetClientInfo();
return res;
}
void FleetWrapper::CreateClient2ClientConnection() {
VLOG(3) << "Going to create client2client connection";
pserver_ptr_->create_client2client_connection(
VLOG(1) << "Going to create client2client connection";
auto* communicator = Communicator::GetInstance();
communicator->_worker_ptr->create_client2client_connection(
client2client_request_timeout_ms_, client2client_connect_timeout_ms_,
client2client_max_retry_);
}
......@@ -370,12 +372,26 @@ void FleetWrapper::PushDenseVarsAsync(
const std::vector<std::string>& var_names,
std::vector<std::future<int32_t>>* push_sparse_status, float scale_datanorm,
int batch_size) {
auto* communicator = Communicator::GetInstance();
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
communicator->Send(var_names, scope);
auto place = platform::CPUPlace();
std::vector<paddle::distributed::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::distributed::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg));
VLOG(3) << "FleetWrapper::PushDenseVarsAsync Var " << t << " talbe_id "
<< table_id << " Temp_data[0] " << g[0] << " Temp_data[-1] "
<< g[tensor->numel() - 1];
}
auto* communicator =
dynamic_cast<AsyncCommunicator*>(Communicator::GetInstance());
auto push_status = communicator->_worker_ptr->push_dense(
regions.data(), regions.size(), table_id);
communicator->PushDensePostProcessing();
}
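A minimal caller-side sketch of the new dense-push path above, assuming the usual FleetWrapper::GetInstance() singleton accessor and that the elided leading parameters of PushDenseVarsAsync are the scope and the table id; the variable names, table id, and batch size below are illustrative, not taken from this commit:
// `scope` is assumed to be an existing paddle::framework::Scope that already
// holds the dense gradient tensors named below.
auto fleet = paddle::distributed::FleetWrapper::GetInstance();
std::vector<std::string> dense_grad_names = {"fc_0.w_0@GRAD", "fc_0.b_0@GRAD"};
std::vector<std::future<int32_t>> push_status;
fleet->PushDenseVarsAsync(scope, /*table_id=*/0, dense_grad_names, &push_status,
/*scale_datanorm=*/-1, /*batch_size=*/32);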
void FleetWrapper::PushSparseVarsAsync(
......@@ -417,10 +433,140 @@ void FleetWrapper::PushSparseFromTensorWithLabelAsync(
return;
}
void FleetWrapper::LoadModel(const std::string& path, const std::string& mode) {
void FleetWrapper::PushSparseFromTensorAsync(
const uint64_t table_id, int fea_dim, uint64_t padding_id,
platform::Place place, std::vector<const LoDTensor*>* inputs,
const LoDTensor* shows, const LoDTensor* clks,
std::vector<LoDTensor*>* outputs) {
int batch_size = -1;
for (auto* input : *inputs) {
int cur_batch_size =
input->lod().size() ? input->lod()[0].size() - 1 : input->dims()[0];
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
CHECK(batch_size == cur_batch_size); // NOLINT
}
}
CHECK(batch_size > 0); // NOLINT
int show_size =
shows->lod().size() ? shows->lod()[0].size() - 1 : shows->dims()[0];
CHECK(show_size == batch_size || show_size == 1);
int clk_size =
clks->lod().size() ? clks->lod()[0].size() - 1 : clks->dims()[0];
CHECK(clk_size == batch_size || clk_size == 1);
std::vector<float> g;
for (framework::LoDTensor* g_tensor : *outputs) {
float* g_ori = g_tensor->data<float>();
// no cvm
if (true) { // TODO(zhaocaibei123): add config
// scale_sparse_gradient_with_batch_size_
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g_ori, g_tensor->numel() / fea_dim, fea_dim);
g_mat.rightCols(fea_dim) *= batch_size;
}
size_t origin = g.size();
size_t add = g_tensor->numel();
g.resize(origin + add);
memcpy(g.data() + origin, g_tensor->data<float>(), add * sizeof(float));
}
std::vector<uint64_t> push_keys;
push_keys.reserve(MAX_FEASIGN_NUM / 100);
std::vector<std::vector<float>> push_values;
push_values.reserve(MAX_FEASIGN_NUM / 100);
size_t output_len = 0;
size_t input_idx = 0;
VLOG(2) << "fleet.cc::emb_dim: " << fea_dim;
// TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
// const long int* show_tensor = shows->data<int64_t>();
// const long int* clk_tensor = clks->data<int64_t>();
const int64_t* show_tensor = shows->data<int64_t>();
const int64_t* clk_tensor = clks->data<int64_t>();
for (size_t index = 0; index < inputs->size(); ++index) {
const framework::LoDTensor* tensor = inputs->at(index);
const int64_t* ids = tensor->data<int64_t>();
size_t len = tensor->numel();
if (tensor->lod().size() > 0) {
for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) {
for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
++j, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[j]);
if (real_id == padding_id) {
continue;
}
push_keys.emplace_back(real_id);
push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
++input_idx;
}
}
} else {
for (size_t i = 0; i < len; ++i, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) {
continue;
}
push_keys.emplace_back(real_id);
push_values.emplace_back(fea_dim + 3);
// slot show clk grad... consistent with CtrCommonPushValue defined in
// ctr_accessor.h
push_values.back()[0] = 2; // TODO(zhaocaibei123): slot
push_values.back()[1] =
(i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
push_values.back()[2] =
(i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
float* data = push_values.back().data() + 3;
memcpy(data, g.data() + output_len, sizeof(float) * fea_dim);
++input_idx;
}
}
}
VLOG(1) << "output_len: " << output_len << " g.size(): " << g.size();
CHECK(output_len == g.size());
std::vector<float*> push_g_vec(input_idx, nullptr);
for (auto i = 0u; i < push_keys.size(); ++i) {
push_g_vec[i] = push_values.at(i).data();
}
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->load(path, mode);
// auto ret = pserver_ptr_->_worker_ptr->load(path, std::to_string(mode));
PADDLE_ENFORCE_EQ(
communicator->Check(table_id), true,
platform::errors::InvalidArgument(
"can not find table: %s, please check your config", table_id));
auto status = communicator->_worker_ptr->push_sparse(
table_id, push_keys.data(), (const float**)push_g_vec.data(),
push_keys.size());
}
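For readability, a sketch of the per-key buffer the loops above fill: each entry of push_values is fea_dim + 3 floats and, per the in-loop comment, is meant to line up with CtrCommonPushValue in ctr_accessor.h (whose exact field names are not shown in this hunk):
// push_values[k] layout (fea_dim + 3 floats), as filled above:
// [0]                 slot (hard-coded to 2 for now, see the TODO)
// [1]                 show (show_tensor[i], or 1 when i >= show_size)
// [2]                 clk  (clk_tensor[i],  or 0 when i >= clk_size)
// [3 .. 3 + fea_dim)  gradient slice copied from g at offset output_len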
void FleetWrapper::LoadModel(const std::string& path, const int mode) {
auto* communicator = Communicator::GetInstance();
auto ret = communicator->_worker_ptr->load(path, std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model from path:" << path << " failed";
......@@ -562,16 +708,16 @@ void FleetWrapper::ClientFlush() {
int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
MsgHandlerFunc handler) {
VLOG(3) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
VLOG(3) << "pserver_ptr_=" << pserver_ptr_;
VLOG(3) << "_worker_ptr=" << pserver_ptr_->_worker_ptr;
return pserver_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
VLOG(1) << "calling FleetWrapper::RegisterClientToClientMsgHandler";
auto* communicator = Communicator::GetInstance();
return communicator->_worker_ptr->registe_client2client_msg_handler(msg_type,
handler);
}
std::future<int32_t> FleetWrapper::SendClientToClientMsg(
int msg_type, int to_client_id, const std::string& msg) {
return pserver_ptr_->_worker_ptr->send_client2client_msg(msg_type,
auto* communicator = Communicator::GetInstance();
return communicator->_worker_ptr->send_client2client_msg(msg_type,
to_client_id, msg);
}
......
......@@ -157,7 +157,12 @@ class FleetWrapper {
const std::vector<std::string>& input_names,
std::vector<const LoDTensor*>* inputs, // NOLINT
std::vector<const LoDTensor*>* outputs); // NOLINT
void PushSparseFromTensorAsync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
std::vector<const LoDTensor*>* inputs,
const LoDTensor* shows,
const LoDTensor* clicks,
std::vector<LoDTensor*>* outputs);
// Push sparse variables to server in Async mode
// Param<In>: scope, table_id, fea_keys, sparse_grad_names
// Param<Out>: push_values, push_sparse_status
......@@ -200,7 +205,7 @@ class FleetWrapper {
void PrintTableStat(const uint64_t table_id);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModel(const std::string& path, const std::string& mode);
void LoadModel(const std::string& path, const int mode);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
void LoadModelOneTable(const uint64_t table_id, const std::string& path,
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/distributed/service/communicator.h"
#include <google/protobuf/text_format.h>
#include "gflags/gflags.h"
......@@ -87,7 +88,7 @@ void Communicator::InitBrpcClient(
servers_ = host_sign_list.size();
_ps_env = paddle::distributed::PaddlePSEnvironment();
_ps_env.set_ps_servers(&host_sign_list, servers_);
_worker_ptr = std::shared_ptr<paddle::distributed::PSClient>(
_worker_ptr = std::unique_ptr<paddle::distributed::PSClient>(
paddle::distributed::PSClientFactory::create(_ps_param));
_worker_ptr->configure(_ps_param, _dense_pull_regions, _ps_env,
trainer_id_);
......@@ -95,6 +96,19 @@ void Communicator::InitBrpcClient(
return;
}
std::vector<uint64_t> Communicator::GetClientInfo() {
std::vector<uint64_t> res = _ps_env.get_client_info();
for (auto rr : res) {
VLOG(2) << "Communicator::GetClientInfo " << rr;
}
return res;
}
int Communicator::SetClients(std::vector<uint64_t> &host_sign_list) {
int node = host_sign_list.size();
return _ps_env.set_ps_clients(host_sign_list.data(), node);
}
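Read together with the FleetWrapper changes above, a sketch of how the new client bookkeeping could be driven from a worker; the gather step between workers and the timeout values are illustrative, only the per-process calls appear in this commit:
auto* communicator = paddle::distributed::Communicator::GetInstance();
// 1. each worker exports its own client endpoint info
std::vector<uint64_t> local_info = communicator->GetClientInfo();
// 2. the infos from all workers are gathered out of band (e.g. via a barrier
// service); here we pretend this worker is the only client
std::vector<uint64_t> all_clients = local_info;
// 3. register the full client list, then open client-to-client connections
communicator->SetClients(all_clients);
communicator->CreateC2CConnection(/*pserver_timeout_ms=*/500000,
/*pserver_connect_timeout_ms=*/10000,
/*max_retry=*/3);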
void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope) {
platform::RecordEvent record_event("Communicator->RpcRecvDense");
......@@ -130,6 +144,11 @@ void Communicator::RpcRecvDense(const std::vector<std::string> &varnames,
LoDTensor *tensor = var->GetMutable<LoDTensor>();
VLOG(1) << "AsyncCommunicator::RecvNoBarrier Var " << t << " On gpu? "
<< platform::is_gpu_place(tensor->place());
float *temp_recv_data = tensor->mutable_data<float>(platform::CPUPlace());
VLOG(1) << "AsyncCommunicator::RpcRecvDense Var " << t << " table_id "
<< table_id << " Temp_data[0] " << temp_recv_data[0]
<< " Temp_data[-1] " << temp_recv_data[tensor->numel() - 1];
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
LoDTensor *temp_tensor =
......@@ -519,6 +538,7 @@ void AsyncCommunicator::SendByCommunicator() {
MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
}
}
if (ctx.is_tensor_table) {
SendGlobalStep(ctx, merged_var_num, send_scope_.get());
} else if (ctx.is_sparse) {
......@@ -547,6 +567,13 @@ void AsyncCommunicator::SendByCommunicator() {
return;
}
void AsyncCommunicator::PushDensePostProcessing() {
if (independent_recv_) {
grad_num_.fetch_add(1, std::memory_order_relaxed);
}
return;
}
void AsyncCommunicator::MainThread() {
VLOG(3) << "AsyncCommunicator MainThread start and wait";
......@@ -627,13 +654,13 @@ void AsyncCommunicator::Start() {
}
void AsyncCommunicator::Stop() {
VLOG(1) << "Communicator stop";
_worker_ptr->finalize_worker();
VLOG(0) << "Communicator finalize_worker done";
VLOG(1) << "Communicator stop begin";
running_ = false;
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
} else {
_worker_ptr->finalize_worker();
VLOG(1) << "client finalize_worker done";
if (recv_thread_) {
VLOG(1) << "stop recv thread";
recv_thread_->join();
......
......@@ -245,6 +245,11 @@ class Communicator {
virtual void InitBrpcClient(const std::string &dist_desc,
const std::vector<std::string> &host_sign_list);
virtual std::vector<uint64_t> GetClientInfo();
virtual int SetClients(std::vector<uint64_t> &host_sign_list); // NOLINT
// 1. recv dense param
virtual void RpcRecvDense(const std::vector<std::string> &varnames,
int table_id, Scope *scope);
......@@ -271,6 +276,7 @@ class Communicator {
virtual void InitParams(const RecvCtxMap &recv_varname_to_ctx);
// note: only for pull dense param first before training
virtual void PullDense(const RecvCtxMap &recv_varname_to_ctx);
virtual void Start() = 0;
......@@ -296,6 +302,13 @@ class Communicator {
rets.wait();
}
virtual void CreateC2CConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
int max_retry) {
_worker_ptr->create_client2client_connection(
pserver_timeout_ms, pserver_connect_timeout_ms, max_retry);
}
virtual void BarrierTriggerDecrement() {}
virtual void BarrierTriggerReset(int init_counter) {}
......@@ -342,13 +355,13 @@ class Communicator {
PSClient *GetPsClient() { return _worker_ptr.get(); }
std::shared_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return _worker_ptr;
std::unique_ptr<paddle::distributed::PSClient> GetPsClientPtr() {
return std::move(_worker_ptr);
}
RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }
std::shared_ptr<PSClient> _worker_ptr; // pointer to worker
std::unique_ptr<PSClient> _worker_ptr; // pointer to worker
protected:
bool running_ = false;
......@@ -434,6 +447,8 @@ class AsyncCommunicator : public Communicator {
virtual void BarrierWeakUp() {}
void PushDensePostProcessing();
protected:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
......@@ -542,14 +557,15 @@ class GeoCommunicator : public AsyncCommunicator {
Scope *recv_scope) override;
void InitParams(const RecvCtxMap &recv_varname_to_ctx) override;
void InitDense(std::vector<std::string> &varnames, int table_id);
void InitDense(std::vector<std::string> &varnames, int table_id); // NOLINT
void InitSparse(const std::string &var_name, int table_id);
void SendDense(const CommContext &send_ctx);
void RecvDense(const CommContext &send_ctx);
std::vector<int64_t> MergeSparseIds(const std::string &varname);
void SendSparse(const std::string &varname, std::vector<int64_t> &sparse_ids,
void SendSparse(const std::string &varname,
std::vector<int64_t> &sparse_ids, // NOLINT
int table_id, int ep_idx);
void RecvSparse(const std::string &varname, int table_id, int ep_idx);
......
......@@ -19,6 +19,8 @@
namespace paddle {
namespace distributed {
int FLAGS_pslib_table_save_max_retry_dense = 3;
void CommonDenseTable::create_initializer(const std::string& attr,
const std::string& name) {
auto slices = string::split_string<std::string>(attr, "&");
......@@ -56,6 +58,7 @@ int32_t CommonDenseTable::initialize_value() {
auto common = _config.common();
int size = static_cast<int>(common.params().size());
values_.resize(size);
total_dim_ = 0;
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
......@@ -63,7 +66,9 @@ int32_t CommonDenseTable::initialize_value() {
param_dim_ = dim;
param_idx_ = x;
}
auto& initializer = common.initializers()[x];
total_dim_ += dim;
create_initializer(initializer, varname);
values_[x].resize(dim);
......@@ -74,6 +79,22 @@ int32_t CommonDenseTable::initialize_value() {
}
}
fixed_len_params_dim_ = 0;
for (int x = 0; x < size; ++x) {
auto& dim = common.dims()[x];
if (dim != param_dim_) {
fixed_len_params_dim_ += dim;
} else {
param_col_ids_.push_back(x);
}
}
if (_config.common().name() == "adam_d2sum") {
param_col_ids_.insert(param_col_ids_.begin() + 1, -1);
}
VLOG(1) << "CommonDenseTable::initialize_value total dim: " << total_dim_
<< " fixed_len_params_dim: " << fixed_len_params_dim_;
pull_reservoir_ = ReservoirValue<float>(param_dim_);
return 0;
}
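A worked example of the new bookkeeping above, for a hypothetical adam_d2sum configuration; the parameter names and dims are illustrative and only meant to show how param_col_ids_ and fixed_len_params_dim_ come out:
// Suppose common.params() = {"Param", "D2Sum", "G2Sum", "Moment", "LearningRate"}
// with common.dims()      = { D,       D,       D,       D,        1 },
// and param_dim_ ends up as D (the dim of the trainable parameter). Then:
// total_dim_            = 4 * D + 1
// fixed_len_params_dim_ = 1            (only the scalar learning rate)
// param_col_ids_        = {0, 1, 2, 3} before the adam_d2sum fix-up, and
// {0, -1, 1, 2, 3} after it: the -1 reserves a save/load column for avg_w,
// which save() writes as "0" and load() skips.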
......@@ -89,6 +110,9 @@ int32_t CommonDenseTable::initialize_optimizer() {
} else if (name == "adam") {
optimizer_ = std::make_shared<DAdam>(common, &values_);
optimizer_->set_global_lr(_global_lr);
} else if (name == "adam_d2sum") {
optimizer_ = std::make_shared<DAdamD2Sum>(common, &values_);
// optimizer_->set_global_lr(_global_lr); // not used for adam_d2sum
} else if (name == "sum") {
optimizer_ = std::make_shared<DSUM>(common, &values_);
} else {
......@@ -162,8 +186,206 @@ int32_t CommonDenseTable::_push_dense(const float* values, size_t num) {
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
VLOG(2) << "debug CommonDenseTable::_push_dense done";
return 0;
}
int32_t CommonDenseTable::load(const std::string& path,
const std::string& param) {
if (param_dim_ <= 0) {
return 0;
}
std::string table_path = table_dir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto ff : file_list) {
VLOG(1) << "load dense table file list: " << ff;
}
size_t dim_num_per_file = _config.accessor().fea_dim() / file_list.size() + 1;
// param_dim_ on the last node may differ from _config.accessor().fea_dim() / _shard_num + 1
size_t dim_num_per_shard = _value_accesor->fea_dim() / _shard_num + 1;
size_t start_dim_idx = dim_num_per_shard * _shard_idx;
size_t start_file_idx = start_dim_idx / dim_num_per_file;
size_t end_file_idx = (start_dim_idx + param_dim_) / dim_num_per_file;
end_file_idx =
end_file_idx < file_list.size() ? end_file_idx : file_list.size() - 1;
VLOG(2) << "load dense table start_file_idx: " << start_file_idx
<< " end_file_idx: " << end_file_idx;
int load_param = atoi(param.c_str());
FsChannelConfig channel_config;
channel_config.converter = _value_accesor->converter(load_param).converter;
channel_config.deconverter =
_value_accesor->converter(load_param).deconverter;
bool is_read_failed = false;
int err_no = 0;
int retry_num = 0;
do {
is_read_failed = false;
try {
size_t dim_idx = 0;
float data_buffer[5];
float* data_buff_ptr = data_buffer;
std::string line_data;
int size = static_cast<int>(values_.size());
auto common = _config.common();
for (int i = start_file_idx; i < end_file_idx + 1; ++i) {
channel_config.path = file_list[i];
err_no = 0;
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
size_t file_start_idx = start_dim_idx - i * dim_num_per_file;
// not all files contain params, and the length of the last file containing
// params may differ from the others
size_t file_dim_idx = 0;
for (; file_dim_idx < dim_num_per_file; ++file_dim_idx) {
if (read_channel->read_line(line_data) != 0) {
break;
}
if (dim_idx >= param_dim_) {
break;
}
if (file_dim_idx < file_start_idx) {
continue;
}
auto str_len =
paddle::string::str_to_float(line_data.data(), data_buff_ptr);
CHECK(str_len == param_col_ids_.size())
<< "expect " << param_col_ids_.size() << " float, but got "
<< str_len;
for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
if (param_col_ids_[col_idx] < 0) {
continue;
}
values_[param_col_ids_[col_idx]][dim_idx] = data_buffer[col_idx];
VLOG(2) << "CommonDenseTable::load param x: "
<< param_col_ids_[col_idx] << " y: " << dim_idx
<< " value: " << values_[param_col_ids_[col_idx]][dim_idx]
<< " line " << file_dim_idx;
}
++dim_idx;
}
read_channel->close();
VLOG(1) << "DownpourDenseTable load done " << channel_config.path
<< " file_start_idx: " << file_start_idx
<< " dim_idx: " << dim_idx;
if (err_no == -1) {
if (retry_num > FLAGS_pslib_table_save_max_retry_dense) {
LOG(ERROR) << "DownpourDenseTable load failed reach max limit!";
exit(-1);
}
++retry_num;
--i;
LOG(ERROR)
<< "DownpourDenseTable load failed after read , retry it! path:"
<< channel_config.path << ", retry_num=" << retry_num;
continue;
}
retry_num = 0;
start_dim_idx += file_dim_idx - file_start_idx;
LOG(INFO) << "DownpourDenseTable load success, path:"
<< channel_config.path;
}
} catch (...) {
is_read_failed = true;
LOG(ERROR) << "DownpourDenseTable load failed, retry it! path:"
<< channel_config.path;
}
} while (is_read_failed);
return 0;
}
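A small worked example, with hypothetical numbers, of the shard/file index arithmetic used in load() above:
// Assume fea_dim = 1000, 4 part files, 4 shards, and this is shard _shard_idx = 2:
// dim_num_per_file  = 1000 / 4 + 1 = 251
// dim_num_per_shard = 1000 / 4 + 1 = 251
// start_dim_idx     = 251 * 2      = 502
// start_file_idx    = 502 / 251    = 2
// end_file_idx      = (502 + param_dim_) / 251, clamped to file_list.size() - 1
// file_start_idx    = 502 - 2 * 251 = 0   (rows to skip in the first file read)
// param_dim_ on the last shard may be smaller than 251, which is why end_file_idx
// is clamped and the inner loop also stops once dim_idx reaches param_dim_.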
int32_t CommonDenseTable::save(const std::string& path,
const std::string& param) {
int save_param = atoi(param.c_str());
uint32_t feasign_size;
VLOG(0) << "CommonDenseTable::save path " << path;
FsChannelConfig channel_config;
if (_config.compress_in_save()) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d.gz", table_dir(path).c_str(), _shard_idx);
} else {
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_dir(path).c_str(), _shard_idx);
}
_afs_client.remove(channel_config.path);
channel_config.converter = _value_accesor->converter(save_param).converter;
channel_config.deconverter =
_value_accesor->converter(save_param).deconverter;
bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param(
param_dim_, std::vector<std::string>());
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
std::ostringstream os;
for (int x = 0; x < size; ++x) {
auto& varname = common.params()[x];
auto& dim = common.dims()[x];
VLOG(0) << "CommonDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear();
os.str("");
os << values_[x][y];
if (dim == param_dim_) {
result_buffer_param[y].emplace_back(std::move(os.str()));
} else {
result_buffer_fixed_len.emplace_back(std::move(os.str()));
}
}
}
int retry_num = 0;
int err_no = 0;
do {
err_no = 0;
is_write_failed = false;
feasign_size = 0;
// 40M
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
"path:"
<< channel_config.path << ", retry_num=" << retry_num;
break;
}
}
++feasign_size;
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed after write, retry it! "
<< "path:" << channel_config.path
<< ", retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num >
paddle::distributed::FLAGS_pslib_table_save_max_retry_dense) {
LOG(ERROR) << "DownpourDenseTable save failed reach max limit!";
exit(-1);
}
} while (is_write_failed);
LOG(INFO) << "DownpourDenseTable save success, path:" << channel_config.path;
return feasign_size;
}
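For reference, a sketch of what one saved part file could contain under the logic above (values illustrative, matching the hypothetical adam_d2sum layout sketched earlier): each of the param_dim_ rows joins one column per parameter whose dim equals param_dim_, space separated, with an extra avg_w column of "0" inserted at position 1 for adam_d2sum; the fixed-length columns collected in result_buffer_fixed_len are not shown being written in this hunk.
// part-002 (columns: Param, avg_w, D2Sum, G2Sum, Moment):
// 0.012 0 1.7 0.33 -0.004
// -0.880 0 1.2 0.41 0.019
// 0.051 0 2.0 0.27 -0.100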
} // namespace distributed
} // namespace paddle
......@@ -32,33 +32,26 @@ class DenseOptimizer;
class CommonDenseTable : public DenseTable {
public:
explicit CommonDenseTable() {}
CommonDenseTable() {}
virtual ~CommonDenseTable() {}
virtual int32_t initialize() override;
virtual int32_t initialize_shard() override { return 0; }
int32_t initialize() override;
int32_t initialize_shard() override { return 0; }
virtual void create_initializer(const std::string& attr,
const std::string& name);
virtual int32_t initialize_value();
virtual int32_t initialize_optimizer();
virtual int32_t pull_dense(float* pull_values, size_t num) override;
virtual int32_t push_dense_param(const float* values, size_t num) override;
virtual int32_t push_dense(const float* values, size_t num) override;
virtual int32_t pour() override;
virtual int32_t set_global_lr(float* lr) override;
int32_t pull_dense(float* pull_values, size_t num) override;
int32_t push_dense_param(const float* values, size_t num) override;
int32_t push_dense(const float* values, size_t num) override;
int32_t pour() override;
int32_t set_global_lr(float* lr) override;
int32_t load(const std::string& path, const std::string& param) override {
VLOG(0) << "WARNING: dense variables will load on No.0 trainer";
return 0;
}
int32_t load(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override;
int32_t save(const std::string& path, const std::string& param) override {
VLOG(0) << "WARNING: dense variables will save on No.0 trainer";
return 0;
}
virtual int32_t flush() override { return 0; }
virtual int32_t shrink(const std::string& param) override { return 0; }
virtual void clear() override { return; }
int32_t flush() override { return 0; }
int32_t shrink(const std::string& param) override { return 0; }
void clear() override { return; }
protected:
int32_t _push_dense(const float* values, size_t num);
......@@ -74,6 +67,9 @@ class CommonDenseTable : public DenseTable {
ReservoirValue<float> pull_reservoir_;
std::unordered_map<std::string, Initializer*> initializers_;
std::unordered_map<std::string, int> names_index_;
int total_dim_ = 0;
int fixed_len_params_dim_ = 0; // used for save/load
std::vector<int> param_col_ids_; // used for save/load
};
} // namespace distributed
......