Unverified commit c49f35cf authored by zhangchunle, committed by GitHub

[part1] fix sign-compare warning (#43276)

* fix sign-compare warning

* fix sign-compare 2
Parent caa57498
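Every hunk below is the same class of fix: GCC's and Clang's -Wsign-compare fires when a signed integer is compared against an unsigned one, so each offending loop index (or variable) is retyped to match the bound it is compared with. A minimal sketch of both directions of the fix, using hypothetical names not taken from the patch:

    #include <cstddef>
    #include <vector>

    // Built with -Wall (or -Wsign-compare), the commented-out loops warn:
    // "comparison of integer expressions of different signedness".
    float Sum(const std::vector<float>& values) {
      float total = 0.0f;
      // for (int i = 0; i < values.size(); ++i)     // int vs size_t: warns
      for (size_t i = 0; i < values.size(); ++i) {   // index matches size_t
        total += values[i];
      }
      return total;
    }

    int CountTables(int table_param_size) {  // protobuf *_size() returns int
      int count = 0;
      // for (size_t i = 0; i < table_param_size; ++i)  // size_t vs int: warns
      for (int i = 0; i < table_param_size; ++i) {      // index matches int
        ++count;
      }
      return count;
    }

This is why the diff moves in both directions: indexes compared against std::vector::size() or other size_t bounds become size_t, while indexes compared against protobuf *_size() results or int shard counts become int.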
@@ -197,7 +197,7 @@ int32_t BrpcPsClient::Initialize() {
   // initialize the async push request queue
   const auto &worker_param = _config.worker_param().downpour_worker_param();
-  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
+  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
     auto type = worker_param.downpour_table_param(i).type();
     auto table_id = worker_param.downpour_table_param(i).table_id();
     if (type == PS_DENSE_TABLE) {
@@ -662,7 +662,7 @@ std::future<int32_t> BrpcPsClient::PushSparseParam(size_t table_id,
   char *push_data_ptr = const_cast<char *>(push_data->data());
   memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
   push_data_ptr += kv_size * sizeof(uint64_t);
-  for (int i = 0; i < kv_size; ++i) {
+  for (size_t i = 0; i < kv_size; ++i) {
     memcpy(push_data_ptr, value_ptr[i], value_size);
     push_data_ptr += value_size;
   }
@@ -882,7 +882,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradient(
   memcpy(push_data_ptr, kvs.data(), kv_size * sizeof(uint64_t));
   push_data_ptr += kv_size * sizeof(uint64_t);
-  for (int i = 0; i < kv_size; ++i) {
+  for (size_t i = 0; i < kv_size; ++i) {
     memcpy(push_data_ptr, value_ptr[i], value_size);
     push_data_ptr += value_size;
   }
@@ -1237,7 +1237,7 @@ std::future<int32_t> BrpcPsClient::PushSparseRawGradientPartial(
   char *push_data_ptr = const_cast<char *>(push_data->data());
   memcpy(push_data_ptr, keys, num * sizeof(uint64_t));
   push_data_ptr += num * sizeof(uint64_t);
-  for (int i = 0; i < num; ++i) {
+  for (uint32_t i = 0; i < num; ++i) {
     memcpy(push_data_ptr, update_values[i], value_size);
     push_data_ptr += value_size;
   }
@@ -1257,7 +1257,7 @@ int32_t BrpcPsClient::RecvAndSaveTable(const uint64_t table_id,
   int64_t var_shape = 0;
   std::string table_class;
   const auto &worker_param = _config.worker_param().downpour_worker_param();
-  for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
+  for (int i = 0; i < worker_param.downpour_table_param_size(); ++i) {
     if (worker_param.downpour_table_param(i).table_id() == table_id) {
       var_name = worker_param.downpour_table_param(i).common().table_name();
       var_num = worker_param.downpour_table_param(i).common().table_num();
@@ -1481,13 +1481,13 @@ void BrpcPsClient::PushSparseTaskConsume() {
       closure->add_timer(rpc_timer);
       std::vector<std::future<int>> merge_status(request_call_num);
-      for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
+      for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
        merge_status[shard_idx] =
            async_push_sparse_shard_threads.enqueue(std::bind(
                &BrpcPsClient::PushSparseAsyncShardPush, this, task_list,
                request_kv_num, table_id, shard_idx, closure, accessor));
      }
-      for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
+      for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
        merge_status[shard_idx].wait();
      }
      merge_status.clear();
@@ -1497,13 +1497,13 @@ void BrpcPsClient::PushSparseTaskConsume() {
      auto queue_size = task_queue->Size();
    } else {  // below the threshold: only do multi-way merge
      std::vector<std::future<int>> merge_status(request_call_num);
-      for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
+      for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
        merge_status[shard_idx] =
            async_push_sparse_shard_threads.enqueue(std::bind(
                &BrpcPsClient::PushSparseAsyncShardMerge, this, task_list,
                request_kv_num, table_id, shard_idx, accessor));
      }
-      for (int shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
+      for (size_t shard_idx = 0; shard_idx < request_call_num; ++shard_idx) {
        merge_status[shard_idx].wait();
      }
@@ -1529,7 +1529,7 @@ void sparse_local_merge(ValueAccessor *accessor, float *merge_data,
   size_t col_num = accessor->GetAccessorInfo().update_dim;
   float *merge_data_shell[col_num];
   const float *another_data_shell[col_num];
-  for (int i = 0; i < col_num; ++i) {
+  for (size_t i = 0; i < col_num; ++i) {
     merge_data_shell[i] = merge_data + i;
     another_data_shell[i] = another_data + i;
   }
@@ -1546,12 +1546,12 @@ int BrpcPsClient::PushSparseAsyncShardMerge(
   thread_local std::vector<std::pair<uint64_t, const float *>> sorted_kv_list;
   sorted_kv_list.clear();
-  for (int i = 1; i < task_list.size(); ++i) {
+  for (size_t i = 1; i < task_list.size(); ++i) {
     size_t kv_num = task_list[i]->data()->shared_data[shard_idx].kv_num;
     auto &key_list = task_list[i]->data()->shared_data[shard_idx].key_list;
     auto &value_list = task_list[i]->data()->shared_data[shard_idx].value_list;
-    for (int j = 0; j < kv_num; ++j) {
+    for (size_t j = 0; j < kv_num; ++j) {
       if (value_list[j].size() < value_size) {
         LOG(WARNING) << "value_list[" << j << "]: " << value_list[j].c_str()
                      << "is invalid.";
@@ -1654,7 +1654,7 @@ int BrpcPsClient::PushSparseAsyncShardPush(
   memcpy(push_data_ptr, merged_key_list.data(),
          merged_kv_count * sizeof(uint64_t));
   push_data_ptr += merged_kv_count * sizeof(uint64_t);
-  for (int i = 0; i < merged_kv_count; ++i) {
+  for (size_t i = 0; i < merged_kv_count; ++i) {
     const char *task_data_ptr = merged_value_list[i].data();
     memcpy(push_data_ptr, (float *)(task_data_ptr),  // NOLINT
@@ -1778,7 +1778,7 @@ void BrpcPsClient::PushDenseTaskConsume() {
       });
       ++merge_count;
     }
-    for (int i = 0; i < merge_count; ++i) {
+    for (uint32_t i = 0; i < merge_count; ++i) {
      merge_status[i].wait();
    }
...
@@ -60,7 +60,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
     std::vector<std::vector<std::string>> &res) {
   std::vector<int> request2server;
   std::vector<int> server2request(server_size, -1);
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     if (server2request[server_index] == -1) {
       server2request[server_index] = request2server.size();
@@ -70,7 +70,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
   size_t request_call_num = request2server.size();
   std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
   std::vector<std::vector<int>> query_idx_buckets(request_call_num);
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     int request_idx = server2request[server_index];
     node_id_buckets[request_idx].push_back(node_ids[query_idx]);
@@ -83,7 +83,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
         int ret = 0;
         auto *closure = (DownpourBrpcClosure *)done;
         size_t fail_num = 0;
-        for (int request_idx = 0; request_idx < request_call_num;
+        for (size_t request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
           if (closure->check_response(request_idx, PS_GRAPH_GET_NODE_FEAT) !=
               0) {
@@ -122,7 +122,7 @@ std::future<int32_t> GraphBrpcClient::get_node_feat(
   closure->add_promise(promise);
   std::future<int> fut = promise->get_future();
-  for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
+  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
     closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
     closure->request(request_idx)->set_table_id(table_id);
@@ -271,7 +271,7 @@ std::future<int32_t> GraphBrpcClient::remove_graph_node(
       request_call_num, [&, request_call_num](void *done) {
         int ret = 0;
         auto *closure = (DownpourBrpcClosure *)done;
-        int fail_num = 0;
+        size_t fail_num = 0;
         for (size_t request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
           if (closure->check_response(request_idx,
@@ -378,7 +378,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
   std::vector<int> server2request(server_size, -1);
   res.clear();
   res_weight.clear();
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     if (server2request[server_index] == -1) {
       server2request[server_index] = request2server.size();
@@ -393,7 +393,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
   size_t request_call_num = request2server.size();
   std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
   std::vector<std::vector<int>> query_idx_buckets(request_call_num);
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     int request_idx = server2request[server_index];
     node_id_buckets[request_idx].push_back(node_ids[query_idx]);
@@ -454,7 +454,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample_neighbors(
   closure->add_promise(promise);
   std::future<int> fut = promise->get_future();
-  for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
+  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
     closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
     closure->request(request_idx)->set_table_id(table_id);
@@ -492,7 +492,7 @@ std::future<int32_t> GraphBrpcClient::random_sample_nodes(
         size_t bytes_size = io_buffer_itr.bytes_left();
         char *buffer = new char[bytes_size];
         auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
-        int index = 0;
+        size_t index = 0;
         while (index < bytes_size) {
           ids.push_back(*(int64_t *)(buffer + index));
           index += GraphNode::id_size;
@@ -534,7 +534,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
         size_t bytes_size = io_buffer_itr.bytes_left();
         char *buffer = new char[bytes_size];
         io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
-        int index = 0;
+        size_t index = 0;
         while (index < bytes_size) {
           FeatureNode node;
           node.recover_from_buffer(buffer + index);
@@ -570,7 +570,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
     const std::vector<std::vector<std::string>> &features) {
   std::vector<int> request2server;
   std::vector<int> server2request(server_size, -1);
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     if (server2request[server_index] == -1) {
       server2request[server_index] = request2server.size();
@@ -582,7 +582,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
   std::vector<std::vector<int>> query_idx_buckets(request_call_num);
   std::vector<std::vector<std::vector<std::string>>> features_idx_buckets(
       request_call_num);
-  for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
     int server_index = get_server_index_by_id(node_ids[query_idx]);
     int request_idx = server2request[server_index];
     node_id_buckets[request_idx].push_back(node_ids[query_idx]);
@@ -590,7 +590,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
     if (features_idx_buckets[request_idx].size() == 0) {
       features_idx_buckets[request_idx].resize(feature_names.size());
     }
-    for (int feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
+    for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
       features_idx_buckets[request_idx][feat_idx].push_back(
           features[feat_idx][query_idx]);
     }
@@ -602,7 +602,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
         int ret = 0;
         auto *closure = (DownpourBrpcClosure *)done;
         size_t fail_num = 0;
-        for (int request_idx = 0; request_idx < request_call_num;
+        for (size_t request_idx = 0; request_idx < request_call_num;
              ++request_idx) {
           if (closure->check_response(request_idx, PS_GRAPH_SET_NODE_FEAT) !=
               0) {
@@ -619,7 +619,7 @@ std::future<int32_t> GraphBrpcClient::set_node_feat(
   closure->add_promise(promise);
   std::future<int> fut = promise->get_future();
-  for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
+  for (size_t request_idx = 0; request_idx < request_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
     closure->request(request_idx)->set_cmd_id(PS_GRAPH_SET_NODE_FEAT);
     closure->request(request_idx)->set_table_id(table_id);
...
@@ -516,7 +516,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
   std::vector<int64_t> local_id;
   std::vector<int> local_query_idx;
   size_t rank = GetRank();
-  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
     int server_index =
         ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
     if (server2request[server_index] == -1) {
@@ -538,7 +538,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
   std::vector<size_t> seq;
   std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
   std::vector<std::vector<int>> query_idx_buckets(request_call_num);
-  for (int query_idx = 0; query_idx < node_num; ++query_idx) {
+  for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
     int server_index =
         ((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
     int request_idx = server2request[server_index];
@@ -614,7 +614,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
   closure->add_promise(promise);
   std::future<int> fut = promise->get_future();
-  for (int request_idx = 0; request_idx < remote_call_num; ++request_idx) {
+  for (size_t request_idx = 0; request_idx < remote_call_num; ++request_idx) {
     int server_index = request2server[request_idx];
     closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBORS);
     closure->request(request_idx)->set_table_id(request.table_id());
...
@@ -196,7 +196,7 @@ bool CtrCommonAccessor::NeedExtendMF(float* value) {
   return score >= _config.embedx_threshold();
 }
-bool CtrCommonAccessor::HasMF(size_t size) {
+bool CtrCommonAccessor::HasMF(int size) {
   return size > common_feature_value.EmbedxG2SumIndex();
 }
@@ -227,11 +227,11 @@ int32_t CtrCommonAccessor::Merge(float** update_values,
                                  const float** other_update_values,
                                  size_t num) {
   auto embedx_dim = _config.embedx_dim();
-  size_t total_dim = CtrCommonPushValue::Dim(embedx_dim);
+  int total_dim = CtrCommonPushValue::Dim(embedx_dim);
   for (size_t value_item = 0; value_item < num; ++value_item) {
     float* update_value = update_values[value_item];
     const float* other_update_value = other_update_values[value_item];
-    for (auto i = 0u; i < total_dim; ++i) {
+    for (int i = 0; i < total_dim; ++i) {
       if (i != CtrCommonPushValue::SlotIndex()) {
         update_value[i] += other_update_value[i];
       }
...
@@ -143,7 +143,7 @@ class CtrCommonAccessor : public ValueAccessor {
   // decide whether this value should be saved to ssd
   // virtual bool save_ssd(float* value);
   virtual bool NeedExtendMF(float* value);
-  virtual bool HasMF(size_t size);
+  virtual bool HasMF(int size);
   // decide whether this value is dumped in the save stage;
   // param identifies the save stage, e.g. downpour's xbox vs batch_model
   // param = 0, save all feature
...
@@ -139,7 +139,7 @@ bool CtrDoubleAccessor::Save(float* value, int param) {
     }
     default:
       return true;
-  };
+  }
 }
 void CtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
@@ -166,7 +166,7 @@ void CtrDoubleAccessor::UpdateStatAfterSave(float* value, int param) {
       return;
     default:
       return;
-  };
+  }
 }
 int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
@@ -175,7 +175,7 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
     float* value = values[value_item];
     value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
     value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
-    *(double*)(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
+    *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
     *(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
     value[CtrDoubleFeatureValue::SlotIndex()] = -1;
     _embed_sgd_rule->InitValue(
@@ -233,7 +233,7 @@ int32_t CtrDoubleAccessor::Merge(float** update_values,
     for (auto i = 3u; i < total_dim; ++i) {
       update_value[i] += other_update_value[i];
     }*/
-    for (auto i = 0u; i < total_dim; ++i) {
+    for (size_t i = 0; i < total_dim; ++i) {
       if (i != CtrDoublePushValue::SlotIndex()) {
         update_value[i] += other_update_value[i];
       }
@@ -320,7 +320,7 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
   auto score = ShowClickScore(show, click);
   if (score >= _config.embedx_threshold() && param_size > 9) {
     os << " " << v[9];
-    for (auto i = 0; i < _config.embedx_dim(); ++i) {
+    for (size_t i = 0; i < _config.embedx_dim(); ++i) {
       os << " " << v[10 + i];
     }
   }
...
@@ -198,7 +198,7 @@ bool CtrDymfAccessor::NeedExtendMF(float* value) {
   return score >= _config.embedx_threshold();
 }
-bool CtrDymfAccessor::HasMF(size_t size) {
+bool CtrDymfAccessor::HasMF(int size) {
   return size > common_feature_value.EmbedxG2SumIndex();
 }
...
@@ -158,7 +158,7 @@ class CtrDymfAccessor : public ValueAccessor {
   // decide whether this value should be saved to ssd
   // virtual bool save_ssd(float* value);
   virtual bool NeedExtendMF(float* value);
-  virtual bool HasMF(size_t size);
+  virtual bool HasMF(int size);
   // decide whether this value is dumped in the save stage;
   // param identifies the save stage, e.g. downpour's xbox vs batch_model
   // param = 0, save all feature
...
@@ -41,7 +41,7 @@ void MemoryDenseTable::CreateInitializer(const std::string& attr,
 int32_t MemoryDenseTable::Initialize() {
   _shards_task_pool.resize(task_pool_size_);
-  for (int i = 0; i < _shards_task_pool.size(); ++i) {
+  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
     _shards_task_pool[i].reset(new ::ThreadPool(1));
   }
@@ -74,14 +74,14 @@ int32_t MemoryDenseTable::InitializeValue() {
     values_[x].resize(dim);
     names_index_[varname] = x;
-    for (int y = 0; y < dim; ++y) {
+    for (size_t y = 0; y < dim; ++y) {
       values_[x][y] = initializers_[varname]->GetValue();
     }
   }
   fixed_len_params_dim_ = 0;
   for (int x = 0; x < size; ++x) {
-    auto& dim = common.dims()[x];
+    int dim = common.dims()[x];
     if (dim != param_dim_) {
       fixed_len_params_dim_ += dim;
     } else {
@@ -245,14 +245,14 @@ int32_t MemoryDenseTable::Load(const std::string& path,
   do {
     is_read_failed = false;
     try {
-      size_t dim_idx = 0;
+      int dim_idx = 0;
       float data_buffer[5];
       float* data_buff_ptr = data_buffer;
       std::string line_data;
       int size = static_cast<int>(values_.size());
       auto common = _config.common();
-      for (int i = start_file_idx; i < end_file_idx + 1; ++i) {
+      for (size_t i = start_file_idx; i < end_file_idx + 1; ++i) {
        channel_config.path = file_list[i];
        err_no = 0;
        auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
@@ -271,12 +271,12 @@ int32_t MemoryDenseTable::Load(const std::string& path,
         if (file_dim_idx < file_start_idx) {
           continue;
         }
-        auto str_len =
+        size_t str_len =
             paddle::string::str_to_float(line_data.data(), data_buff_ptr);
         CHECK(str_len == param_col_ids_.size())
             << "expect " << param_col_ids_.size() << " float, but got "
             << str_len;
-        for (size_t col_idx = 0; col_idx < str_len; ++col_idx) {
+        for (int col_idx = 0; col_idx < str_len; ++col_idx) {
          if (param_col_ids_[col_idx] < 0) {
            continue;
          }
@@ -355,7 +355,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
     std::ostringstream os;
     for (int x = 0; x < size; ++x) {
       auto& varname = common.params()[x];
-      auto& dim = common.dims()[x];
+      int dim = common.dims()[x];
       VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
       for (int y = 0; y < dim; ++y) {
         os.clear();
...
@@ -49,7 +49,7 @@ int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
   std::vector<std::vector<uint64_t>> offset_bucket;
   offset_bucket.resize(shard_num);
-  for (int x = 0; x < num; ++x) {
+  for (size_t x = 0; x < num; ++x) {
     auto y = keys[x] % shard_num;
     offset_bucket[y].push_back(x);
     if (x < 10) {
@@ -66,7 +66,7 @@ int32_t MemorySparseGeoTable::PushSparseParam(const uint64_t* keys,
     auto& local_shard = _local_shards[shard_id];
     auto& offsets = offset_bucket[shard_id];
-    for (int i = 0; i < offsets.size(); ++i) {
+    for (size_t i = 0; i < offsets.size(); ++i) {
       auto offset = offsets[i];
       auto id = keys[offset];
       auto& feature_value = local_shard[id];
@@ -132,7 +132,7 @@ int32_t MemorySparseGeoTable::Initialize() {
   _dim = _config.common().dims()[0];
   _shards_task_pool.resize(_task_pool_size);
-  for (int i = 0; i < _shards_task_pool.size(); ++i) {
+  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
     _shards_task_pool[i].reset(new ::ThreadPool(1));
   }
@@ -200,14 +200,14 @@ int32_t MemorySparseGeoTable::_PushSparse(const uint64_t* keys,
     task_keys[shard_id].push_back({keys[i], i});
   }
-  for (size_t shard_id = 0; shard_id < shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < shard_num; ++shard_id) {
     tasks[shard_id] = _shards_task_pool[shard_id]->enqueue(
         [this, shard_id, values, &task_keys]() -> int {
           auto& keys = task_keys[shard_id];
           auto& local_shard = _local_shards[shard_id];
           auto blas = GetBlas<float>();
-          for (int i = 0; i < keys.size(); ++i) {
+          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values + push_data_idx * _dim;
...
@@ -37,7 +37,7 @@ namespace distributed {
 int32_t MemorySparseTable::Initialize() {
   _shards_task_pool.resize(_task_pool_size);
-  for (int i = 0; i < _shards_task_pool.size(); ++i) {
+  for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
     _shards_task_pool[i].reset(new ::ThreadPool(1));
   }
   auto& profiler = CostProfiler::instance();
@@ -79,7 +79,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
   }
   int load_param = atoi(param.c_str());
-  auto expect_shard_num = _sparse_table_shard_num;
+  size_t expect_shard_num = _sparse_table_shard_num;
   if (file_list.size() != expect_shard_num) {
     LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                  << " not equal to expect_shard_num:" << expect_shard_num;
@@ -98,7 +98,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
   int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     FsChannelConfig channel_config;
     channel_config.path = file_list[file_start_idx + i];
     VLOG(1) << "MemorySparseTable::load begin load " << channel_config.path
@@ -164,7 +164,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
   auto file_list = paddle::framework::localfs_list(table_path);
   int load_param = atoi(param.c_str());
-  auto expect_shard_num = _sparse_table_shard_num;
+  size_t expect_shard_num = _sparse_table_shard_num;
   if (file_list.size() != expect_shard_num) {
     LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
                  << " not equal to expect_shard_num:" << expect_shard_num;
@@ -183,7 +183,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
   int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     bool is_read_failed = false;
     int retry_num = 0;
     int err_no = 0;
@@ -244,7 +244,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
   int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     FsChannelConfig channel_config;
     if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
       channel_config.path = paddle::string::format_string(
@@ -326,7 +326,7 @@ int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     feasign_cnt = 0;
     auto& shard = _local_shards[i];
     std::string file_name = paddle::string::format_string(
@@ -354,7 +354,7 @@ int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
 int64_t MemorySparseTable::LocalSize() {
   int64_t local_size = 0;
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     local_size += _local_shards[i].size();
   }
   return local_size;
@@ -364,7 +364,7 @@ int64_t MemorySparseTable::LocalMFSize() {
   std::vector<int64_t> size_arr(_real_local_shard_num, 0);
   std::vector<std::future<int>> tasks(_real_local_shard_num);
   int64_t ret_size = 0;
-  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
     tasks[shard_id] =
         _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
             [this, shard_id, &size_arr]() -> int {
@@ -378,7 +378,7 @@ int64_t MemorySparseTable::LocalMFSize() {
               return 0;
             });
   }
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     tasks[i].wait();
   }
   for (auto x : size_arr) {
@@ -469,7 +469,7 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
                 memcpy(data_buffer_ptr, itr.value().data(),
                        data_size * sizeof(float));
               }
-              for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
+              for (size_t mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
                 data_buffer[mf_idx] = 0.0;
               }
               auto offset = keys[i].second;
@@ -503,7 +503,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
     task_keys[shard_id].push_back({keys[i], i});
   }
   // std::atomic<uint32_t> missed_keys{0};
-  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
     tasks[shard_id] =
         _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
             [this, shard_id, &task_keys, pull_values, value_size,
@@ -512,7 +512,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
               auto& local_shard = _local_shards[shard_id];
               float data_buffer[value_size];
               float* data_buffer_ptr = data_buffer;
-              for (int i = 0; i < keys.size(); ++i) {
+              for (size_t i = 0; i < keys.size(); ++i) {
                 uint64_t key = keys[i].first;
                 auto itr = local_shard.find(key);
                 size_t data_size = value_size - mf_value_size;
@@ -558,7 +558,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
   size_t update_value_col =
       _value_accesor->GetAccessorInfo().update_size / sizeof(float);
-  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
     tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue(
         [this, shard_id, value_col, mf_value_col, update_value_col, values,
          &task_keys]() -> int {
@@ -566,7 +566,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, const float* values,
           auto& local_shard = _local_shards[shard_id];
           float data_buffer[value_col];  // NOLINT
           float* data_buffer_ptr = data_buffer;
-          for (int i = 0; i < keys.size(); ++i) {
+          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data =
@@ -639,7 +639,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
           auto& local_shard = _local_shards[shard_id];
           float data_buffer[value_col];  // NOLINT
           float* data_buffer_ptr = data_buffer;
-          for (int i = 0; i < keys.size(); ++i) {
+          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
            const float* update_data = values[push_data_idx];
...
@@ -171,7 +171,7 @@ bool SparseAccessor::NeedExtendMF(float* value) {
   return score >= _config.embedx_threshold();
 }
-bool SparseAccessor::HasMF(size_t size) {
+bool SparseAccessor::HasMF(int size) {
   return size > sparse_feature_value.EmbedxG2SumIndex();
 }
@@ -201,7 +201,7 @@ int32_t SparseAccessor::Merge(float** update_values,
   for (size_t value_item = 0; value_item < num; ++value_item) {
     float* update_value = update_values[value_item];
     const float* other_update_value = other_update_values[value_item];
-    for (auto i = 0u; i < total_dim; ++i) {
+    for (size_t i = 0; i < total_dim; ++i) {
       if (i != SparsePushValue::SlotIndex()) {
         update_value[i] += other_update_value[i];
       }
...
@@ -130,7 +130,7 @@ class SparseAccessor : public ValueAccessor {
   // decide whether this value should be saved to ssd
   // virtual bool save_ssd(float* value);
   virtual bool NeedExtendMF(float* value);
-  virtual bool HasMF(size_t size);
+  virtual bool HasMF(int size);
   // decide whether this value is dumped in the save stage;
   // param identifies the save stage, e.g. downpour's xbox vs batch_model
   // param = 0, save all feature
...
@@ -90,7 +90,7 @@ void SparseAdaGradSGDRule::UpdateValueWork(float* w, float* sgd,
   float& g2sum = sgd[G2SumIndex()];
   double add_g2sum = 0;
-  for (int i = 0; i < _embedding_dim; i++) {
+  for (size_t i = 0; i < _embedding_dim; i++) {
     double scaled_grad = grad[i] / scale;
     w[i] -= learning_rate_ * scaled_grad *
             sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
@@ -103,7 +103,7 @@ void SparseAdaGradSGDRule::UpdateValueWork(float* w, float* sgd,
 void SparseAdaGradSGDRule::InitValueWork(float* value, float* sgd,
                                          bool zero_init) {
-  for (int i = 0; i < _embedding_dim; ++i) {
+  for (size_t i = 0; i < _embedding_dim; ++i) {
     if (zero_init) {
       value[i] = 0.0;
       BoundValue(value[i]);
@@ -141,7 +141,7 @@ void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
 void StdAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
                                         float scale) {
-  for (int i = 0; i < _embedding_dim; i++) {
+  for (size_t i = 0; i < _embedding_dim; i++) {
     float& g2sum = sgd[G2SumIndex() + i];
     double scaled_grad = grad[i] / scale;
     w[i] -= learning_rate_ * scaled_grad *
@@ -153,7 +153,7 @@ void StdAdaGradSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
 void StdAdaGradSGDRule::InitValueWork(float* value, float* sgd,
                                       bool zero_init) {
-  for (int i = 0; i < _embedding_dim; ++i) {
+  for (size_t i = 0; i < _embedding_dim; ++i) {
     if (zero_init) {
       value[i] = 0.0;
       BoundValue(value[i]);
@@ -204,7 +204,7 @@ void SparseAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
   // lr not change in one update
   lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
-  for (int i = 0; i < _embedding_dim; i++) {
+  for (size_t i = 0; i < _embedding_dim; i++) {
     // Calculation
     gsum[i] = _beta1_decay_rate * gsum[i] + (1 - _beta1_decay_rate) * g[i];
     g2sum[i] =
@@ -219,7 +219,7 @@ void SparseAdamSGDRule::UpdateValueWork(float* w, float* sgd, const float* grad,
 void SparseAdamSGDRule::InitValueWork(float* value, float* sgd,
                                       bool zero_init) {
-  for (int i = 0; i < _embedding_dim; ++i) {
+  for (size_t i = 0; i < _embedding_dim; ++i) {
     if (zero_init) {
       value[i] = 0.0;
       BoundValue(value[i]);
@@ -233,7 +233,7 @@ void SparseAdamSGDRule::InitValueWork(float* value, float* sgd,
     }
   }
   // init rule gsum and g2sum
-  for (int i = GSumIndex(); i < Beta1PowIndex(); i++) {
+  for (size_t i = GSumIndex(); i < Beta1PowIndex(); i++) {
     sgd[i] = 0.0;
   }
   // init beta1_pow and beta2_pow
...
@@ -58,7 +58,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys,
   }
   std::atomic<uint32_t> missed_keys{0};
-  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
     tasks[shard_id] =
         _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
             [this, shard_id, &task_keys, value_size, mf_value_size,
@@ -67,7 +67,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys,
               auto& local_shard = _local_shards[shard_id];
               float data_buffer[value_size];
               float* data_buffer_ptr = data_buffer;
-              for (int i = 0; i < keys.size(); ++i) {
+              for (size_t i = 0; i < keys.size(); ++i) {
                 uint64_t key = keys[i].first;
                 auto itr = local_shard.find(key);
                 size_t data_size = value_size - mf_value_size;
@@ -105,7 +105,8 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys,
                 memcpy(data_buffer_ptr, itr.value().data(),
                        data_size * sizeof(float));
               }
-              for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
+              for (size_t mf_idx = data_size; mf_idx < value_size;
+                   ++mf_idx) {
                data_buffer[mf_idx] = 0.0;
              }
              int pull_data_idx = keys[i].second;
@@ -117,7 +118,7 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys,
               return 0;
             });
   }
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     tasks[i].wait();
   }
   if (FLAGS_pserver_print_missed_key_num_every_push) {
@@ -145,7 +146,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values,
     int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
     task_keys[shard_id].push_back({keys[i], i});
   }
-  for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
+  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
     tasks[shard_id] =
         _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
             [this, shard_id, value_col, mf_value_col, update_value_col,
@@ -154,7 +155,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values,
               auto& local_shard = _local_shards[shard_id];
               float data_buffer[value_col];
               float* data_buffer_ptr = data_buffer;
-              for (int i = 0; i < keys.size(); ++i) {
+              for (size_t i = 0; i < keys.size(); ++i) {
                 uint64_t key = keys[i].first;
                 uint64_t push_data_idx = keys[i].second;
                 const float* update_data =
@@ -196,7 +197,7 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values,
               return 0;
             });
   }
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     tasks[i].wait();
   }
 }
@@ -228,7 +229,7 @@ int32_t SSDSparseTable::Shrink(const std::string& param) {
   int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     uint64_t mem_count = 0;
     uint64_t ssd_count = 0;
@@ -264,7 +265,7 @@ int32_t SSDSparseTable::UpdateTable() {
 int32_t SSDSparseTable::UpdateTable() {
   // TODO implement with multi-thread
   int count = 0;
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     auto& shard = _local_shards[i];
     // from mem to ssd
     for (auto it = shard.begin(); it != shard.end();) {
@@ -285,7 +286,7 @@ int32_t SSDSparseTable::UpdateTable() {
 int64_t SSDSparseTable::LocalSize() {
   int64_t local_size = 0;
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     local_size += _local_shards[i].size();
   }
   // TODO rocksdb size
@@ -328,7 +329,7 @@ int32_t SSDSparseTable::Save(const std::string& path,
   omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
     FsChannelConfig channel_config;
     if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
       channel_config.path = paddle::string::format_string(
@@ -484,14 +485,14 @@ int64_t SSDSparseTable::CacheShuffle(
   int feasign_size = 0;
   std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
tmp_channels; tmp_channels;
for (size_t i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
tmp_channels.push_back( tmp_channels.push_back(
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>()); paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
} }
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer = paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[i]; writers[i];
// std::shared_ptr<paddle::framework::ChannelObject<std::pair<uint64_t, // std::shared_ptr<paddle::framework::ChannelObject<std::pair<uint64_t,
...@@ -520,7 +521,7 @@ int64_t SSDSparseTable::CacheShuffle( ...@@ -520,7 +521,7 @@ int64_t SSDSparseTable::CacheShuffle(
<< " and start sparse cache data shuffle real local shard num: " << " and start sparse cache data shuffle real local shard num: "
<< _real_local_shard_num; << _real_local_shard_num;
std::vector<std::pair<uint64_t, std::string>> local_datas; std::vector<std::pair<uint64_t, std::string>> local_datas;
for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) { for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer = paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[idx_shard]; writers[idx_shard];
auto channel = writer.channel(); auto channel = writer.channel();
...@@ -543,8 +544,8 @@ int64_t SSDSparseTable::CacheShuffle( ...@@ -543,8 +544,8 @@ int64_t SSDSparseTable::CacheShuffle(
send_index[i] = i; send_index[i] = i;
} }
std::random_shuffle(send_index.begin(), send_index.end()); std::random_shuffle(send_index.begin(), send_index.end());
for (auto index = 0u; index < shuffle_node_num; ++index) { for (int index = 0; index < shuffle_node_num; ++index) {
int i = send_index[index]; size_t i = send_index[index];
if (i == _shard_idx) { if (i == _shard_idx) {
continue; continue;
} }
...@@ -624,7 +625,7 @@ int32_t SSDSparseTable::Load(const std::string& path, ...@@ -624,7 +625,7 @@ int32_t SSDSparseTable::Load(const std::string& path,
} }
// Load data files in [start_idx, end_idx) under the path directory // Load data files in [start_idx, end_idx) under the path directory
int32_t SSDSparseTable::Load(size_t start_idx, size_t end_idx, int32_t SSDSparseTable::Load(size_t start_idx, int end_idx,
const std::vector<std::string>& file_list, const std::vector<std::string>& file_list,
const std::string& param) { const std::string& param) {
if (start_idx >= file_list.size()) { if (start_idx >= file_list.size()) {
...@@ -688,7 +689,7 @@ int32_t SSDSparseTable::Load(size_t start_idx, size_t end_idx, ...@@ -688,7 +689,7 @@ int32_t SSDSparseTable::Load(size_t start_idx, size_t end_idx,
continue; continue;
} }
} }
int value_size = size_t value_size =
_value_accesor->ParseFromString(++end, data_buffer_ptr); _value_accesor->ParseFromString(++end, data_buffer_ptr);
// ssd or mem // ssd or mem
if (_value_accesor->SaveSSD(data_buffer_ptr)) { if (_value_accesor->SaveSSD(data_buffer_ptr)) {
......
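Every hunk above applies the same rule: the loop index takes the signedness of its bound, in whichever direction that points — size_t for STL `.size()` bounds, int for protobuf repeated-field `_size()` accessors and int members such as `_real_local_shard_num` (which also keeps the `#pragma omp parallel for` indices signed, as older OpenMP specs required). A minimal sketch of both directions; the names here are illustrative, not taken from the patch:

    #include <cstddef>
    #include <vector>

    // `table_param_size` mimics a protobuf `_size()` accessor, which returns int.
    void fix_patterns(const std::vector<float>& values, int table_param_size) {
      // Warns under -Wsign-compare: int index vs. std::size_t bound.
      // for (int i = 0; i < values.size(); ++i) {}

      // Fix 1: unsigned index when the bound is unsigned (STL .size()).
      for (std::size_t i = 0; i < values.size(); ++i) {
        (void)values[i];
      }

      // Fix 2: signed index when the bound is signed (protobuf _size(),
      // int members like _real_local_shard_num).
      for (int i = 0; i < table_param_size; ++i) {
      }
    }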
...@@ -55,7 +55,7 @@ class SSDSparseTable : public MemorySparseTable { ...@@ -55,7 +55,7 @@ class SSDSparseTable : public MemorySparseTable {
int32_t Flush() override { return 0; } int32_t Flush() override { return 0; }
virtual int32_t Shrink(const std::string& param) override; virtual int32_t Shrink(const std::string& param) override;
virtual void Clear() override { virtual void Clear() override {
for (size_t i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
_local_shards[i].clear(); _local_shards[i].clear();
} }
} }
...@@ -79,7 +79,7 @@ class SSDSparseTable : public MemorySparseTable { ...@@ -79,7 +79,7 @@ class SSDSparseTable : public MemorySparseTable {
virtual int32_t Load(const std::string& path, virtual int32_t Load(const std::string& path,
const std::string& param) override; const std::string& param) override;
// Load data files in [start_idx, end_idx) under the path directory // Load data files in [start_idx, end_idx) under the path directory
virtual int32_t Load(size_t start_idx, size_t end_idx, virtual int32_t Load(size_t start_idx, int end_idx,
const std::vector<std::string>& file_list, const std::vector<std::string>& file_list,
const std::string& param); const std::string& param);
int64_t LocalSize(); int64_t LocalSize();
......
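The new `Load` overload mixes signedness within one signature (`size_t start_idx`, `int end_idx`), so any direct `start_idx < end_idx` comparison inside would reintroduce the warning. A hedged sketch of one way such a range loop can stay warning-free — the function name and body are illustrative, not the actual implementation:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Assumed shape only; the real SSDSparseTable::Load parses shard files.
    int LoadRange(std::size_t start_idx, int end_idx,
                  const std::vector<std::string>& file_list) {
      if (end_idx < 0) return -1;  // reject before converting signedness
      std::size_t end = static_cast<std::size_t>(end_idx);
      if (end > file_list.size()) end = file_list.size();
      for (std::size_t i = start_idx; i < end; ++i) {
        // open and parse file_list[i] ...
      }
      return 0;
    }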
...@@ -536,8 +536,8 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -536,8 +536,8 @@ void FleetWrapper::PushSparseFromTensorAsync(
output_len = 0; output_len = 0;
if (tensor->lod().size() > 0) { if (tensor->lod().size() > 0) {
for (size_t i = 0; i < tensor->lod()[0].size() - 1; ++i) { for (int i = 0; i < tensor->lod()[0].size() - 1; ++i) {
for (int j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1]; for (size_t j = tensor->lod()[0][i]; j < tensor->lod()[0][i + 1];
++j, output_len += fea_dim) { ++j, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[j]); uint64_t real_id = static_cast<uint64_t>(ids[j]);
if (real_id == padding_id) { if (real_id == padding_id) {
...@@ -566,7 +566,7 @@ void FleetWrapper::PushSparseFromTensorAsync( ...@@ -566,7 +566,7 @@ void FleetWrapper::PushSparseFromTensorAsync(
} }
} }
} else { } else {
for (size_t i = 0; i < len; ++i, output_len += fea_dim) { for (int i = 0; i < len; ++i, output_len += fea_dim) {
uint64_t real_id = static_cast<uint64_t>(ids[i]); uint64_t real_id = static_cast<uint64_t>(ids[i]);
if (real_id == padding_id) { if (real_id == padding_id) {
continue; continue;
......
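In `PushSparseFromTensorAsync`, the inner offset walker `j` now matches the LoD offsets, which Paddle stores as `size_t`; initializing an `int` from `tensor->lod()[0][i]` is what mixed signedness before. A standalone sketch of that inner-loop pattern, assuming a plain `std::vector<std::size_t>` stands in for one LoD level:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative only: walks [lod[i], lod[i+1]) spans of id offsets.
    void walk_lod(const std::vector<std::size_t>& lod_level0,
                  std::int64_t fea_dim) {
      std::int64_t output_len = 0;
      for (std::size_t i = 0; i + 1 < lod_level0.size(); ++i) {
        // `j` is assigned from a size_t vector, so it is size_t as well;
        // an int here is what triggered the sign-compare warning.
        for (std::size_t j = lod_level0[i]; j < lod_level0[i + 1];
             ++j, output_len += fea_dim) {
          // process ids[j] ...
        }
      }
    }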
...@@ -222,7 +222,7 @@ void RunBrpcPushDense() { ...@@ -222,7 +222,7 @@ void RunBrpcPushDense() {
worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0); worker_ptr_->PullDense(temp_region.data(), temp_region.size(), 0);
pull_status.wait(); pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (int64_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(temp[idx], 1.0); EXPECT_FLOAT_EQ(temp[idx], 1.0);
} }
...@@ -236,7 +236,7 @@ void RunBrpcPushDense() { ...@@ -236,7 +236,7 @@ void RunBrpcPushDense() {
pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0); pull_status = worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_status.wait(); pull_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (int64_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(w[idx], float(idx)); EXPECT_FLOAT_EQ(w[idx], float(idx));
} }
...@@ -265,7 +265,7 @@ void RunBrpcPushDense() { ...@@ -265,7 +265,7 @@ void RunBrpcPushDense() {
worker_ptr_->PullDense(regions.data(), regions.size(), 0); worker_ptr_->PullDense(regions.data(), regions.size(), 0);
pull_update_status.wait(); pull_update_status.wait();
for (size_t idx = 0; idx < tensor->numel(); ++idx) { for (int64_t idx = 0; idx < tensor->numel(); ++idx) {
EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0); EXPECT_FLOAT_EQ(w[idx], float(idx) - 1.0);
} }
......
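The dense-pull test loops now index with `int64_t` because `Tensor::numel()` returns a signed `int64_t` element count, so a `size_t` index compared against it warned. A minimal stand-alone version of the pattern, with a stub tensor in place of the framework type:

    #include <cstdint>
    #include <vector>

    struct FakeTensor {  // stand-in for the test's tensor, assumption only
      std::vector<float> data;
      std::int64_t numel() const {
        return static_cast<std::int64_t>(data.size());
      }
    };

    void check_all(const FakeTensor& t) {
      // numel() is signed, so the index is int64_t to match the bound.
      for (std::int64_t idx = 0; idx < t.numel(); ++idx) {
        (void)t.data[idx];
      }
    }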
...@@ -89,25 +89,25 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { ...@@ -89,25 +89,25 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) {
rule.InitValue(w, w + 10, true); rule.InitValue(w, w + 10, true);
for (auto i = 0u; i < kEmbSize; ++i) { for (int i = 0; i < kEmbSize; ++i) {
ASSERT_FLOAT_EQ(w[i], 0); ASSERT_FLOAT_EQ(w[i], 0);
} }
ASSERT_FLOAT_EQ(w[kEmbSize], 0); ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check init_value for random // check init_value for random
rule.InitValue(w, w + 10, false); rule.InitValue(w, w + 10, false);
for (auto i = 0u; i < kEmbSize; ++i) { for (int i = 0; i < kEmbSize; ++i) {
ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound()); ASSERT_TRUE(w[i] >= rule.MinBound() && w[i] <= rule.MaxBound());
} }
ASSERT_FLOAT_EQ(w[kEmbSize], 0); ASSERT_FLOAT_EQ(w[kEmbSize], 0);
// check update_value for one field // check update_value for one field
for (auto i = 0u; i < kEmbSize; ++i) { for (int i = 0; i < kEmbSize; ++i) {
w[i] = 0; w[i] = 0;
} }
w[kEmbSize] = 0; w[kEmbSize] = 0;
float grad[kEmbSize]; float grad[kEmbSize];
for (auto i = 0u; i < kEmbSize; ++i) { for (int i = 0; i < kEmbSize; ++i) {
grad[i] = (i + 1) * 1.0; grad[i] = (i + 1) * 1.0;
} }
...@@ -185,7 +185,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { ...@@ -185,7 +185,7 @@ TEST(downpour_sparse_adam_test, test_init_and_update) {
rule.UpdateValue(value, value + embed_dim, grad); rule.UpdateValue(value, value + embed_dim, grad);
for (auto i = 0u; i < value_dim; ++i) { // check update for (int i = 0; i < value_dim; ++i) { // check update
ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i; ASSERT_FLOAT_EQ(value[i], label[i]) << "i is " << i;
} }
} }
......
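The optimizer-rule tests previously wrote `auto i = 0u`, which deduces `unsigned int` and warns when compared against a signed constant; changing the index to `int` matches `kEmbSize`, assumed here to be declared as a signed constant. A compilable sketch of the before/after:

    #include <cassert>

    int main() {
      const int kEmbSize = 8;  // assumed signed, as the fix implies
      float w[kEmbSize + 1] = {};

      // `auto i = 0u` deduces unsigned int; `i < kEmbSize` then compares
      // unsigned against int and warns under -Wsign-compare.
      for (int i = 0; i < kEmbSize; ++i) {
        assert(w[i] == 0.0f);
      }
      return 0;
    }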