Unverified commit 703487c6 authored by zhaocaibei123, committed by GitHub

memory sparse table (#36909)

Parent 41a09113
@@ -37,7 +37,9 @@ set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPI
set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_accessor SRCS ctr_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
cc_library(table SRCS table.cc DEPS common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost ctr_accessor)
cc_library(table SRCS table.cc DEPS memory_sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
@@ -221,15 +221,6 @@ class DAdamD2Sum : public DenseOptimizer {
void update(const float* update_values, size_t num, int begin,
int end) override {
auto update_numel = end - begin;
/*
// for debug
std::cout << "before update:\n";
for (int i = 0; i < 3; ++ i) {
std::cout << "param: " << i << " " << *(param+begin+i) <<
"grad: " << *(update_values+begin+i) << "\n";
}*/
std::vector<float> grad, grad2, scale;
grad.resize(update_numel);
grad2.resize(update_numel);
@@ -240,57 +231,21 @@ class DAdamD2Sum : public DenseOptimizer {
blas.VCOPY(update_numel, update_values + begin, grad.data());
blas.VCOPY(update_numel, update_values + begin, grad2.data());
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "copy grad: " << i << " " << *(grad.data()+begin+i) <<
"copy grad2: " << *(grad2.data()+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "d2sum before: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}*/
// d2sum: ada_d2sum = ada_decay_rate * ada_d2sum + 1
blas.SCAL(update_numel, ada_decay_rate[0], ada_d2sum + begin);
ADD<float>(update_numel, ada_d2sum + begin, 1, ada_d2sum + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "d2sum update: " << i << " " << *(ada_d2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "g2sum before: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}*/
// g2sum: ada_g2sum = ada_decay_rate * ada_g2sum + grad^2
blas.SCAL(update_numel, ada_decay_rate[0], ada_g2sum + begin);
blas.VSQUARE(update_numel, grad2.data(), grad2.data());
blas.VADD(update_numel, ada_g2sum + begin, grad2.data(), ada_g2sum + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "g2sum update: " << i << " " << *(ada_g2sum+begin+i) << "\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "mom before: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}*/
// mom: mom_velocity = mom_decay_rate * mom_velocity + (1 - mom_decay_rate) * grad
blas.SCAL(update_numel, mom_decay_rate[0], mom_velocity + begin);
blas.SCAL(update_numel, 1 - mom_decay_rate[0], grad.data());
blas.VADD(update_numel, mom_velocity + begin, grad.data(),
mom_velocity + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "mom update: " << i << " " << *(mom_velocity+begin+i) <<
"\n";
}
for (int i = 0; i < 3; ++ i) {
std::cout << "scale before: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
// scale: start from g2sum / d2sum; epsilon, sqrt and the learning rate are folded in below
float* scale_ = scale.data();
blas.VDIV(update_numel, ada_g2sum + begin, ada_d2sum + begin, scale_);
@@ -298,30 +253,13 @@ class DAdamD2Sum : public DenseOptimizer {
DIV<float>(update_numel, 1 + ada_epsilon[0], scale_, scale_);
SQRT<float>(update_numel, scale_, scale_);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
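// Net effect of the remaining lines (sketch; one line is elided from this
// hunk): param -= learning_rate * mom_velocity *
//                 sqrt((1 + ada_epsilon) / (g2sum / d2sum + ...))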
blas.SCAL(update_numel, learning_rate[0], scale_);
// TODO(zhaocaibei123): check if there exists elementwise_multiply in blas
// TODO(zhaocaibei123): blas.VMUL
ELE_MUL<float>(update_numel, scale_, mom_velocity + begin, scale_);
/*
for (int i = 0; i < 3; ++ i) {
std::cout << "scale update2: " << i << " " << *(scale.data()+begin+i) <<
"\n";
}*/
blas.VSUB(update_numel, param + begin, scale_, param + begin);
/*
for (int i = 0; i < end-begin; ++ i) {
std::cout << "param update " << i << " " << *(param+begin+i) << "\n";
}*/
}
float* learning_rate;
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sstream>
#include "paddle/fluid/distributed/table/memory_sparse_table.h"
#include "paddle/fluid/framework/io/fs.h"
#include "boost/lexical_cast.hpp"
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace distributed {
// TODO(zhaocaibei123): configure
bool FLAGS_pslib_create_value_when_push = false;
int FLAGS_pslib_table_save_max_retry = 3;
bool FLAGS_pslib_enable_create_feasign_randomly = false;
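// Each local shard is always served by the same single-threaded pool below,
// so per-shard reads and writes are serialized without extra locking.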
int32_t MemorySparseTable::initialize() {
shards_task_pool_.resize(task_pool_size_);
for (int i = 0; i < shards_task_pool_.size(); ++i) {
shards_task_pool_[i].reset(new ::ThreadPool(1));
}
initialize_value();
VLOG(0) << "initalize MemorySparseTable succ";
return 0;
}
int32_t MemorySparseTable::initialize_value() {
sparse_table_shard_num_ = static_cast<int>(_config.shard_num());
avg_local_shard_num_ =
SparseTable::sparse_local_shard_num(sparse_table_shard_num_, _shard_num);
real_local_shard_num_ = avg_local_shard_num_;
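// avg_local_shard_num_ is the per-server quota; when sparse_table_shard_num_
// is not an exact multiple of the server count, trailing servers own fewer
// (possibly zero) local shards.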
if (real_local_shard_num_ * (_shard_idx + 1) > sparse_table_shard_num_) {
real_local_shard_num_ =
sparse_table_shard_num_ - real_local_shard_num_ * _shard_idx;
real_local_shard_num_ =
real_local_shard_num_ < 0 ? 0 : real_local_shard_num_;
}
VLOG(1) << "memory sparse table avg_local_shard_num_: "
<< avg_local_shard_num_
<< " real_local_shard_num_: " << real_local_shard_num_;
shard_values_.reserve(real_local_shard_num_);
for (int x = 0; x < real_local_shard_num_; ++x) {
auto shard = std::make_shared<SparseTableShard>();
shard_values_.emplace_back(shard);
}
return 0;
}
int32_t MemorySparseTable::load(const std::string& path,
const std::string& param) {
std::string table_path = table_dir(path);
auto file_list = _afs_client.list(table_path);
std::sort(file_list.begin(), file_list.end());
for (auto file : file_list) {
VLOG(1) << "MemorySparseTable::load() file list: " << file;
}
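// load_param selects the converter pair and mirrors save_param below
// (checkpoint:0 xbox delta:1 xbox base:2).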
int load_param = atoi(param.c_str());
auto expect_shard_num = sparse_table_shard_num_;
if (file_list.size() != expect_shard_num) {
LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
<< " not equal to expect_shard_num:" << expect_shard_num;
return -1;
}
if (file_list.size() == 0) {
LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
return -1;
}
size_t file_start_idx = _shard_idx * avg_local_shard_num_;
size_t feature_value_size = _value_accesor->size() / sizeof(float);
// TODO(zhaocaibei123): multi-thread
// int thread_num = shard_values_.size() < 15 ? shard_values_.size() : 15;
// omp_set_num_threads(thread_num);
// #pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < real_local_shard_num_; ++i) {
FsChannelConfig channel_config;
channel_config.path = file_list[file_start_idx + i];
VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
<< " into local shard " << i;
channel_config.converter = _value_accesor->converter(load_param).converter;
channel_config.deconverter =
_value_accesor->converter(load_param).deconverter;
bool is_read_failed = false;
int retry_num = 0;
int err_no = 0;
do {
is_read_failed = false;
err_no = 0;
std::string line_data;
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
char* end = NULL;
auto& shard = shard_values_[i];
try {
while (read_channel->read_line(line_data) == 0 &&
line_data.size() > 1) {
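// Each line is "<key> <accessor payload>"; strtoul stops at the
// separator, and ++end skips it before the payload reaches the accessor.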
uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto* value = shard->Init(key);
value->resize(feature_value_size);
int parse_size =
_value_accesor->parse_from_string(++end, value->data());
value->resize(parse_size);
// for debug
for (int ii = 0; ii < parse_size; ++ii) {
VLOG(2) << "MemorySparseTable::load key: " << key << " value " << ii
<< ": " << value->data()[ii] << " local_shard: " << i;
}
}
read_channel->close();
if (err_no == -1) {
++retry_num;
is_read_failed = true;
LOG(ERROR)
<< "MemorySparseTable load failed after read, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
}
} catch (...) {
++retry_num;
is_read_failed = true;
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pslib_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
} while (is_read_failed);
}
LOG(INFO) << "MemorySparseTable load success, path from "
<< file_list[file_start_idx] << " to "
<< file_list[file_start_idx + real_local_shard_num_ - 1];
return 0;
}
int32_t MemorySparseTable::load_local_fs(const std::string& path,
const std::string& param) {
std::string table_path = table_dir(path);
auto file_list = paddle::framework::localfs_list(table_path);
int load_param = atoi(param.c_str());
auto expect_shard_num = sparse_table_shard_num_;
if (file_list.size() != expect_shard_num) {
LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
<< " not equal to expect_shard_num:" << expect_shard_num;
return -1;
}
if (file_list.size() == 0) {
LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
return -1;
}
size_t file_start_idx = _shard_idx * avg_local_shard_num_;
size_t feature_value_size = _value_accesor->size() / sizeof(float);
// int thread_num = shard_values_.size() < 15 ? shard_values_.size() : 15;
// omp_set_num_threads(thread_num);
// #pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < real_local_shard_num_; ++i) {
bool is_read_failed = false;
int retry_num = 0;
int err_no = 0;
do {
is_read_failed = false;
err_no = 0;
std::string line_data;
std::ifstream file(file_list[file_start_idx + i]);
char* end = NULL;
auto& shard = shard_values_[i];
try {
while (std::getline(file, line_data) && line_data.size() > 1) {
uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto* value = shard->Init(key);
value->resize(feature_value_size);
int parse_size =
_value_accesor->parse_from_string(++end, value->data());
value->resize(parse_size);
// value->shrink_to_fit();
}
file.close();
if (err_no == -1) {
++retry_num;
is_read_failed = true;
LOG(ERROR)
<< "MemorySparseTable load failed after read, retry it! path:"
<< file_list[file_start_idx + i] << " , retry_num=" << retry_num;
}
} catch (...) {
++retry_num;
is_read_failed = true;
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< file_list[file_start_idx + i]
<< " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pslib_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
} while (is_read_failed);
}
LOG(INFO) << "MemorySparseTable load success, path from "
<< file_list[file_start_idx] << " to "
<< file_list[file_start_idx + real_local_shard_num_ - 1];
return 0;
}
int32_t MemorySparseTable::save(const std::string& dirname,
const std::string& param) {
VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname);
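// Remove this server's stale part files first; each local shard is then
// written to one part-%03d-%05d file (gzipped when compress_in_save is set
// and save_param is 0 or 3).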
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
// int thread_num = shard_values_.size() < 20 ? shard_values_.size() : 20;
std::atomic<uint32_t> feasign_size_all{0};
size_t file_start_idx = avg_local_shard_num_ * _shard_idx;
// TODO(zhaocaibei123): openmp
// omp_set_num_threads(thread_num);
// #pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < real_local_shard_num_; ++i) {
FsChannelConfig channel_config;
if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx,
file_start_idx + i);
} else {
channel_config.path =
paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
_shard_idx, file_start_idx + i);
}
channel_config.converter = _value_accesor->converter(save_param).converter;
channel_config.deconverter =
_value_accesor->converter(save_param).deconverter;
bool is_write_failed = false;
int feasign_size = 0;
int retry_num = 0;
int err_no = 0;
auto& shard = shard_values_[i];
do {
err_no = 0;
feasign_size = 0;
is_write_failed = false;
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& table : shard->values_) {
for (auto& value : table) {
if (_value_accesor->save(value.second->data(), save_param)) {
std::string format_value = _value_accesor->parse_to_string(
value.second->data(), value.second->size());
if (0 !=
write_channel->write_line(paddle::string::format_string(
"%lu %s", value.first, format_value.c_str()))) {
++retry_num;
is_write_failed = true;
LOG(ERROR)
<< "MemorySparseTable save prefix failed, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
break;
}
++feasign_size;
}
}
}
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR)
<< "MemorySparseTable save prefix failed after write, retry it! "
<< "path:" << channel_config.path << " , retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num > paddle::distributed::FLAGS_pslib_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
exit(-1);
}
} while (is_write_failed);
feasign_size_all += feasign_size;
for (auto& table : shard->values_) {
for (auto& value : table) {
_value_accesor->update_stat_after_save(value.second->data(),
save_param);
}
}
LOG(INFO) << "MemorySparseTable save prefix success, path: "
<< channel_config.path;
}
// int32 may overflow; the return value needs to change
return 0;
}
int32_t MemorySparseTable::save_local_fs(const std::string& dirname,
const std::string& param,
const std::string& prefix) {
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = table_dir(dirname);
int feasign_cnt = 0;
size_t file_start_idx = avg_local_shard_num_ * _shard_idx;
for (size_t i = 0; i < real_local_shard_num_; ++i) {
feasign_cnt = 0;
auto& shard = shard_values_[i];
std::string file_name = paddle::string::format_string(
"%s/part-%s-%03d-%05d", table_path.c_str(), prefix.c_str(), _shard_idx,
file_start_idx + i);
std::ofstream os;
os.open(file_name);
for (auto& table : shard->values_) {
for (auto& value : table) {
if (_value_accesor->save(value.second->data(), save_param)) {
std::string format_value = _value_accesor->parse_to_string(
value.second->data(), value.second->size());
std::string out_line = paddle::string::format_string(
"%lu %s\n", value.first, format_value.c_str());
// VLOG(2) << out_line.c_str();
os.write(out_line.c_str(), sizeof(char) * out_line.size());
++feasign_cnt;
}
}
}
os.close();
LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
<< "feasign_cnt: " << feasign_cnt;
}
return 0;
}
std::pair<int64_t, int64_t> MemorySparseTable::print_table_stat() {
int64_t feasign_size = 0;
int64_t mf_size = 0;
for (auto& shard : shard_values_) {
for (auto& table : shard->values_) {
feasign_size += table.size();
}
}
return {feasign_size, mf_size};
}
int32_t MemorySparseTable::pull_sparse(float* pull_values,
const PullSparseValue& pull_value) {
std::vector<std::future<int>> tasks(real_local_shard_num_);
const size_t value_size = _value_accesor->size() / sizeof(float);
size_t mf_value_size = _value_accesor->mf_size() / sizeof(float);
size_t select_value_size = _value_accesor->select_size() / sizeof(float);
// std::atomic<uint32_t> missed_keys{0};
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
real_local_shard_num_);
size_t num = pull_value.numel_;
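// Route each key to a local shard: global shard = key % sparse_table_shard_num_,
// local index = global shard % avg_local_shard_num_.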
for (size_t i = 0; i < num; ++i) {
int shard_id = (pull_value.feasigns_[i] % sparse_table_shard_num_) %
avg_local_shard_num_;
task_keys[shard_id].push_back({pull_value.feasigns_[i], i});
}
for (int shard_id = 0; shard_id < real_local_shard_num_; ++shard_id) {
tasks[shard_id] =
shards_task_pool_[shard_id % shards_task_pool_.size()]->enqueue(
[this, shard_id, &task_keys, value_size, pull_values, mf_value_size,
select_value_size]() -> int {
auto& local_shard = shard_values_[shard_id];
float data_buffer[value_size]; // NOLINT
float* data_buffer_ptr = data_buffer;
auto& keys = task_keys[shard_id];
for (size_t i = 0; i < keys.size(); i++) {
uint64_t key = keys[i].first;
auto itr = local_shard->Find(key);
size_t data_size = value_size - mf_value_size;
if (itr == local_shard->end()) {
// ++missed_keys;
if (FLAGS_pslib_create_value_when_push) {
memset(data_buffer, 0, sizeof(float) * data_size);
} else {
auto* feature_value = local_shard->Init(key);
feature_value->resize(data_size);
float* data_ptr = feature_value->data();
_value_accesor->create(&data_buffer_ptr, 1);
memcpy(data_ptr, data_buffer_ptr,
data_size * sizeof(float));
}
} else {
data_size = itr->second->size();
memcpy(data_buffer_ptr, itr->second->data(),
data_size * sizeof(float));
}
for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
data_buffer[mf_idx] = 0.0;
}
auto offset = keys[i].second;
float* select_data = pull_values + select_value_size * offset;
_value_accesor->select(&select_data,
(const float**)&data_buffer_ptr, 1);
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t MemorySparseTable::pull_sparse_ptr(char** pull_values,
const uint64_t* keys, size_t num) {
return 0;
}
int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
const float* values, size_t num) {
std::vector<std::future<int>> tasks(real_local_shard_num_);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
real_local_shard_num_);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % sparse_table_shard_num_) % avg_local_shard_num_;
task_keys[shard_id].push_back({keys[i], i});
}
const size_t value_col = _value_accesor->size() / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float);
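// value_col: full feature-value width; mf_value_col: width of the lazily
// extended mf (embedx) suffix; update_value_col: width of one pushed gradient.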
for (size_t shard_id = 0; shard_id < real_local_shard_num_; ++shard_id) {
tasks[shard_id] = shards_task_pool_[shard_id % task_pool_size_]->enqueue(
[this, shard_id, value_col, mf_value_col, update_value_col, values,
&task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = shard_values_[shard_id];
float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer;
for (int i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
uint64_t push_data_idx = keys[i].second;
const float* update_data =
values + push_data_idx * update_value_col;
auto itr = local_shard->Find(key);
if (itr == local_shard->end()) {
VLOG(0) << "sparse table push_sparse: " << key << "not found!";
if (FLAGS_pslib_enable_create_feasign_randomly &&
!_value_accesor->create_value(1, update_data)) {
continue;
}
auto value_size = value_col - mf_value_col;
auto* feature_value = local_shard->Init(key);
feature_value->resize(value_size);
_value_accesor->create(&data_buffer_ptr, 1);
memcpy(feature_value->data(), data_buffer_ptr,
value_size * sizeof(float));
itr = local_shard->Find(key);
} else {
VLOG(2) << "sparse table debug push_sparse: " << key << " found!";
}
auto* feature_value = itr->second;
float* value_data = feature_value->data();
size_t value_size = feature_value->size();
if (value_size == value_col) { // already extended to the max size, update in place
_value_accesor->update(&value_data, &update_data, 1);
} else {
// copy into a buffer for the update, then write back; mf values that are not needed are discarded on write-back
memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
_value_accesor->update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->need_extend_mf(data_buffer)) {
feature_value->resize(value_col);
value_data = feature_value->data();
_value_accesor->create(&value_data, 1);
}
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
}
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
const float** values, size_t num) {
_push_sparse(keys, values, num);
return 0;
}
int32_t MemorySparseTable::_push_sparse(const uint64_t* keys,
const float** values, size_t num) {
std::vector<std::future<int>> tasks(real_local_shard_num_);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
real_local_shard_num_);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % sparse_table_shard_num_) % avg_local_shard_num_;
task_keys[shard_id].push_back({keys[i], i});
}
size_t value_col = _value_accesor->size() / sizeof(float);
size_t mf_value_col = _value_accesor->mf_size() / sizeof(float);
size_t update_value_col = _value_accesor->update_size() / sizeof(float);
for (int shard_id = 0; shard_id < real_local_shard_num_; ++shard_id) {
tasks[shard_id] = shards_task_pool_[shard_id % task_pool_size_]->enqueue(
[this, shard_id, value_col, mf_value_col, update_value_col, values,
&task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = shard_values_[shard_id];
float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer;
for (int i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
uint64_t push_data_idx = keys[i].second;
const float* update_data = values[push_data_idx];
auto itr = local_shard->Find(key);
if (itr == local_shard->end()) {
if (FLAGS_pslib_enable_create_feasign_randomly &&
!_value_accesor->create_value(1, update_data)) {
continue;
}
auto value_size = value_col - mf_value_col;
auto* feature_value = local_shard->Init(key);
feature_value->resize(value_size);
_value_accesor->create(&data_buffer_ptr, 1);
memcpy(feature_value->data(), data_buffer_ptr,
value_size * sizeof(float));
itr = local_shard->Find(key);
}
auto* feature_value = itr->second;
float* value_data = feature_value->data();
size_t value_size = feature_value->size();
if (value_size == value_col) { // already extended to the max size, update in place
_value_accesor->update(&value_data, &update_data, 1);
} else {
// copy into a buffer for the update, then write back; mf values that are not needed are discarded on write-back
memcpy(data_buffer_ptr, value_data, value_size * sizeof(float));
_value_accesor->update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->need_extend_mf(data_buffer)) {
feature_value->resize(value_col);
value_data = feature_value->data();
_value_accesor->create(&value_data, 1);
}
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
}
}
return 0;
});
}
for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
tasks[shard_id].wait();
}
return 0;
}
int32_t MemorySparseTable::flush() { return 0; }
int32_t MemorySparseTable::shrink(const std::string& param) {
VLOG(0) << "MemorySparseTable::shrink";
// TODO(zhaocaibei123): implement with multi-thread
for (int shard_id = 0; shard_id < real_local_shard_num_; ++shard_id) {
// shrink
auto& shard = shard_values_[shard_id];
for (auto& table : shard->values_) {
for (auto iter = table.begin(); iter != table.end();) {
if (_value_accesor->shrink(iter->second->data())) {
VLOG(1) << "shrink erase key: " << iter->first;
butil::return_object(iter->second);
iter = table.erase(iter);
} else {
++iter;
}
}
}
}
return 0;
}
void MemorySparseTable::clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <assert.h>
#include <pthread.h>
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/common_table.h"
#include "paddle/fluid/distributed/table/depends/feature_value.h"
#include "paddle/fluid/string/string_helper.h"
#define PSERVER_SAVE_SUFFIX ".shard"
namespace paddle {
namespace distributed {
class MemorySparseTable : public SparseTable {
public:
MemorySparseTable() {}
virtual ~MemorySparseTable() {}
// unused method begin
virtual int32_t pull_dense(float* pull_values, size_t num) { return 0; }
virtual int32_t push_dense_param(const float* values, size_t num) {
return 0;
}
virtual int32_t push_dense(const float* values, size_t num) { return 0; }
// unused method end
virtual int32_t initialize();
virtual int32_t initialize_shard() { return 0; }
virtual int32_t initialize_value();
virtual int32_t load(const std::string& path, const std::string& param);
virtual int32_t save(const std::string& path, const std::string& param);
int32_t load_local_fs(const std::string& path, const std::string& param);
int32_t save_local_fs(const std::string& path, const std::string& param,
const std::string& prefix);
virtual std::pair<int64_t, int64_t> print_table_stat();
virtual int32_t pull_sparse(float* values, const PullSparseValue& pull_value);
virtual int32_t pull_sparse_ptr(char** pull_values, const uint64_t* keys,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float* values,
size_t num);
virtual int32_t push_sparse(const uint64_t* keys, const float** values,
size_t num);
virtual int32_t flush();
virtual int32_t shrink(const std::string& param);
virtual void clear();
protected:
virtual int32_t _push_sparse(const uint64_t* keys, const float** values,
size_t num);
protected:
const int task_pool_size_ = 24;
size_t avg_local_shard_num_;
size_t real_local_shard_num_;
size_t sparse_table_shard_num_;
std::vector<std::shared_ptr<::ThreadPool>> shards_task_pool_;
std::vector<std::shared_ptr<SparseTableShard>> shard_values_;
};
} // namespace distributed
} // namespace paddle
@@ -24,6 +24,8 @@
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/table/ssd_sparse_table.h"
#endif
#include "paddle/fluid/distributed/table/ctr_accessor.h"
#include "paddle/fluid/distributed/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/table/tensor_accessor.h"
#include "paddle/fluid/distributed/table/tensor_table.h"
@@ -40,7 +42,13 @@ REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, StdAdaGradSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdamSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseNaiveSGDRule);
REGISTER_PSCORE_CLASS(SparseValueSGDRule, SparseAdaGradSGDRule);
int32_t TableManager::initialize() {
static bool initialized = false;
@@ -58,6 +66,11 @@ int32_t Table::initialize(const TableParameter &config,
LOG(WARNING) << "Table accessor initialize failed";
return -1;
}
if (_afs_client.initialize(fs_config) != 0) {
LOG(WARNING) << "Table fs_client initialize failed";
// return -1;
}
return initialize();
}
@@ -67,6 +80,9 @@ int32_t Table::initialize_accessor() {
<< _config.table_id();
return -1;
}
LOG(INFO) << "accessor initializing: table_id: " << _config.table_id()
<< ", accessor_name: " << _config.accessor().accessor_class();
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
_config.accessor().accessor_class());
if (accessor == NULL) {
......
@@ -20,6 +20,7 @@
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/distributed/common/afs_warpper.h"
#include "paddle/fluid/distributed/table/accessor.h"
#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/graph/graph_node.h"
@@ -103,10 +104,10 @@ class Table {
virtual int32_t flush() = 0;
virtual int32_t shrink(const std::string &param) = 0;
//specify the load path
// specify the load path
virtual int32_t load(const std::string &path,
const std::string &converter) = 0;
//specify the save path
// specify the save path
virtual int32_t save(const std::string &path,
const std::string &converter) = 0;
@@ -137,6 +138,7 @@ class Table {
TableParameter _config;
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
AfsClient _afs_client;
};
REGISTER_PSCORE_REGISTERER(Table);
......
@@ -29,3 +29,6 @@ cc_test(sparse_sgd_rule_test SRCS sparse_sgd_rule_test.cc DEPS ${COMMON_DEPS} bo
set_source_files_properties(ctr_accessor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(ctr_accessor_test SRCS ctr_accessor_test.cc DEPS ${COMMON_DEPS} boost table)
set_source_files_properties(memory_sparse_table_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(memory_sparse_table_test SRCS memory_sparse_table_test.cc DEPS ${COMMON_DEPS} boost table)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ThreadPool.h>
#include <unistd.h>
#include <string>
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/distributed/ps.pb.h"
#include "paddle/fluid/distributed/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/table/table.h"
namespace paddle {
namespace distributed {
TEST(MemorySparseTable, SGD) {
int emb_dim = 8;
int trainers = 2;
TableParameter table_config;
table_config.set_table_class("MemorySparseTable");
table_config.set_shard_num(10);
FsClientParameter fs_config;
Table *table = new MemorySparseTable();
table->set_shard(0, 1);
TableAccessorParameter *accessor_config = table_config.mutable_accessor();
accessor_config->set_accessor_class("CtrCommonAccessor");
accessor_config->set_fea_dim(11);
accessor_config->set_embedx_dim(8);
accessor_config->set_embedx_threshold(5);
accessor_config->mutable_ctr_accessor_param()->set_nonclk_coeff(0.2);
accessor_config->mutable_ctr_accessor_param()->set_click_coeff(1);
accessor_config->mutable_ctr_accessor_param()->set_base_threshold(0.5);
accessor_config->mutable_ctr_accessor_param()->set_delta_threshold(0.2);
accessor_config->mutable_ctr_accessor_param()->set_delta_keep_days(16);
accessor_config->mutable_ctr_accessor_param()->set_show_click_decay_rate(
0.99);
accessor_config->mutable_embed_sgd_param()->set_name("SparseNaiveSGDRule");
auto *naive_param =
accessor_config->mutable_embed_sgd_param()->mutable_naive();
naive_param->set_learning_rate(0.1);
naive_param->set_initial_range(0.3);
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
accessor_config->mutable_embedx_sgd_param()->set_name("SparseNaiveSGDRule");
naive_param = accessor_config->mutable_embedx_sgd_param()->mutable_naive();
naive_param->set_learning_rate(0.1);
naive_param->set_initial_range(0.3);
naive_param->add_weight_bounds(-10.0);
naive_param->add_weight_bounds(10.0);
auto ret = table->initialize(table_config, fs_config);
ASSERT_EQ(ret, 0);
// pull parameters to create entries and check the initial values
std::vector<uint64_t> init_keys = {0, 1, 2, 3, 4};
std::vector<uint32_t> init_fres = {1, 1, 1, 1, 1};
std::vector<float> init_values;
init_values.resize(init_keys.size() * (emb_dim + 1));
auto value = PullSparseValue(init_keys, init_fres, emb_dim);
table->pull_sparse(init_values.data(), value);
// for check
std::vector<float> total_gradients;
total_gradients.resize(init_keys.size() * (4 + emb_dim));
memset(total_gradients.data(), 0, sizeof(float) * total_gradients.size());
// push gradient
std::vector<std::vector<uint64_t>> trainer_keys;
std::vector<std::vector<float>> trainer_gradient_values;
trainer_keys.resize(trainers);
trainer_gradient_values.resize(trainers);
float start = 0.0;
for (int i = 0; i < trainers; i++) {
start = 0.0;
trainer_keys[i] = init_keys;
for (size_t j = 0; j < trainer_keys[i].size(); j++) {
auto id = trainer_keys[i][j];
for (int k = 0; k < emb_dim + 4; k++) {
trainer_gradient_values[i].push_back(start);
total_gradients[id * (emb_dim + 4) + k] += start;
start += 0.1;
}
}
}
std::shared_ptr<::ThreadPool> pool_ =
std::make_shared<::ThreadPool>(trainers);
std::vector<std::future<void>> task_status;
for (int i = 0; i < trainers; i++) {
auto &push_keys = trainer_keys[i];
auto &push_values = trainer_gradient_values[i];
auto task = [table, &push_keys, &push_values] {
table->push_sparse(push_keys.data(), push_values.data(),
push_keys.size());
};
task_status.push_back(pool_->enqueue(std::move(task)));
}
for (auto &status : task_status) {
status.wait();
}
std::vector<float> pull_values;
pull_values.resize(init_keys.size() * (emb_dim + 1));
table->pull_sparse(pull_values.data(), value);
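// With SparseNaiveSGDRule (lr = 0.1) the expected weight is
// w' = w - 0.1 * sum(grad); the "+ 3" offset skips the three leading
// gradient columns (presumably slot/show/click), which carry no embedding
// gradient.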
for (size_t i = 0; i < init_keys.size(); ++i) {
for (size_t j = 0; j < emb_dim + 1; ++j) {
auto update_val = init_values[i * (emb_dim + 1) + j] -
0.1 * total_gradients[3 + i * (emb_dim + 4) + j];
VLOG(3) << total_gradients[i * (emb_dim + 4) + j + 3] << ":"
<< init_values[i * (emb_dim + 1) + j];
VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j];
}
}
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->save_local_fs("./work/table.save", "0", "test");
}
} // namespace distributed
} // namespace paddle
cc_library(fs SRCS fs.cc DEPS string_helper glog boost enforce)
cc_library(shell SRCS shell.cc DEPS string_helper glog timer enforce)
cc_library(fs SRCS fs.cc DEPS string_helper glog boost enforce shell)
cc_test(test_fs SRCS test_fs.cc DEPS fs shell)
if (WITH_CRYPTO)
......