Unverified · commit be273ea9, authored by Li-fAngyU, committed by GitHub

fix build warning: [Wsign-compare] on linux (#46644)
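For context: GCC/Clang's -Wsign-compare fires whenever a signed and an unsigned integer meet in a comparison, most commonly an `int` loop index tested against an unsigned `size()`. Below is a minimal sketch (illustrative names, not from the Paddle sources) of the warning and the two fix patterns this commit applies throughout: switch the index to `size_t`, or keep it signed and cast the bound explicitly with `static_cast`.

#include <cstddef>
#include <vector>

void demo(const std::vector<int> &v) {
  // Warns under -Wall/-Wextra: comparison of integer expressions of
  // different signedness: 'int' and 'std::vector<int>::size_type'.
  for (int i = 0; i < v.size(); i++) {
  }

  // Fix 1 (used for most loops below): make the index unsigned.
  for (size_t i = 0; i < v.size(); i++) {
  }

  // Fix 2 (used where the index must stay signed): cast the bound once,
  // visibly, instead of with a C-style cast.
  for (int i = 0; i < static_cast<int>(v.size()); i++) {
  }
}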

Parent: ddf317ed
@@ -78,7 +78,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
 paddle::framework::GpuPsFeaInfo x;
 std::vector<uint64_t> feature_ids;
 for (size_t j = 0; j < bags[i].size(); j++) {
-// TODO use FEATURE_TABLE instead
+// TODO(danleifeng): use FEATURE_TABLE instead
 Node *v = find_node(1, bags[i][j]);
 node_id = bags[i][j];
 if (v == NULL) {
@@ -109,7 +109,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
 }));
 }
 }
-for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
 paddle::framework::GpuPsCommGraphFea res;
 uint64_t tot_len = 0;
 for (int i = 0; i < task_pool_size_; i++) {
@@ -120,7 +120,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
 res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num);
 unsigned int offset = 0, ind = 0;
 for (int i = 0; i < task_pool_size_; i++) {
-for (int j = 0; j < (int)node_id_array[i].size(); j++) {
+for (size_t j = 0; j < node_id_array[i].size(); j++) {
 res.node_list[ind] = node_id_array[i][j];
 res.fea_info_list[ind] = node_fea_info_array[i][j];
 res.fea_info_list[ind++].feature_offset += offset;
@@ -177,7 +177,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
 }));
 }
 }
-for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
 int64_t tot_len = 0;
 for (int i = 0; i < task_pool_size_; i++) {
@@ -188,7 +188,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
 res.init_on_cpu(tot_len, ids.size());
 int64_t offset = 0, ind = 0;
 for (int i = 0; i < task_pool_size_; i++) {
-for (int j = 0; j < (int)node_array[i].size(); j++) {
+for (size_t j = 0; j < node_array[i].size(); j++) {
 res.node_list[ind] = node_array[i][j];
 res.node_info_list[ind] = info_array[i][j];
 res.node_info_list[ind++].neighbor_offset += offset;
@@ -213,7 +213,7 @@ int32_t GraphTable::add_node_to_ssd(
 ch,
 sizeof(int) * 2 + sizeof(uint64_t),
 str) == 0) {
-uint64_t *stored_data = ((uint64_t *)str.c_str());
+uint64_t *stored_data = ((uint64_t *)str.c_str());  // NOLINT
 int n = str.size() / sizeof(uint64_t);
 char *new_data = new char[n * sizeof(uint64_t) + len];
 memcpy(new_data, stored_data, n * sizeof(uint64_t));
@@ -221,14 +221,14 @@ int32_t GraphTable::add_node_to_ssd(
 _db->put(src_id % shard_num % task_pool_size_,
 ch,
 sizeof(int) * 2 + sizeof(uint64_t),
-(char *)new_data,
+(char *)new_data,  // NOLINT
 n * sizeof(uint64_t) + len);
 delete[] new_data;
 } else {
 _db->put(src_id % shard_num % task_pool_size_,
 ch,
 sizeof(int) * 2 + sizeof(uint64_t),
-(char *)data,
+(char *)data,  // NOLINT
 len);
 }
 }
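Aside on the `// NOLINT` annotations in these hunks: the C-style pointer casts on the packed key/value buffers are kept as-is and exempted from the linter (presumably cpplint's readability/casting check; the commit message does not say). A sketch of the lint-clean alternative, with an illustrative helper name and the usual aliasing caveats left aside, just as in the original code:

#include <cstdint>
#include <string>

// Assume str holds a packed array of uint64_t ids, as the SSD helpers do.
uint64_t first_id(const std::string &str) {
  // The diff keeps: uint64_t *data = ((uint64_t *)str.c_str());  // NOLINT
  // The explicit-cast spelling that needs no lint waiver:
  const uint64_t *data = reinterpret_cast<const uint64_t *>(str.c_str());
  return data[0];
}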
@@ -254,7 +254,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
 ch,
 sizeof(int) * 2 + sizeof(uint64_t),
 str) == 0) {
-uint64_t *data = ((uint64_t *)str.c_str());
+uint64_t *data = ((uint64_t *)str.c_str());  // NOLINT
 int n = str.size() / sizeof(uint64_t);
 std::unordered_map<int, int> m;
 // std::vector<uint64_t> res;
@@ -281,7 +281,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
 // res.push_back(data[pos]);
 }
 for (int i = 0; i < actual_size; i += 8) {
-VLOG(2) << "sampled an neighbor " << *(uint64_t *)&buff[i];
+VLOG(2) << "sampled an neighbor " << *(uint64_t *)&buff[i];  // NOLINT
 }
 return buff;
 }
@@ -310,8 +310,8 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
 std::string str;
 if (_db->get(i, ch, sizeof(int) * 2 + sizeof(uint64_t), str) == 0) {
 count[i] += (int64_t)str.size();
-for (size_t j = 0; j < (int)str.size(); j += sizeof(uint64_t)) {
-uint64_t id = *(uint64_t *)(str.c_str() + j);
+for (size_t j = 0; j < str.size(); j += sizeof(uint64_t)) {
+uint64_t id = *(uint64_t *)(str.c_str() + j);  // NOLINT
 add_comm_edge(idx, v, id);
 }
 }
@@ -321,7 +321,7 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
 }
 }
-for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
 int64_t tot = 0;
 for (auto x : count) tot += x;
 return tot;
@@ -354,9 +354,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
 iters.push_back(_db->get_iterator(i));
 iters[i]->SeekToFirst();
 }
-int next = 0;
+size_t next = 0;
 while (iters.size()) {
-if (next >= (int)iters.size()) {
+if (next >= iters.size()) {
 next = 0;
 }
 if (!iters[next]->Valid()) {
@@ -364,15 +364,16 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
 continue;
 }
 std::string key = iters[next]->key().ToString();
-int type_idx = *(int *)key.c_str();
-int temp_idx = *(int *)(key.c_str() + sizeof(int));
+int type_idx = *(int *)key.c_str();  // NOLINT
+int temp_idx = *(int *)(key.c_str() + sizeof(int));  // NOLINT
 if (type_idx != 0 || temp_idx != idx) {
 iters[next]->Next();
 next++;
 continue;
 }
 std::string value = iters[next]->value().ToString();
-std::uint64_t i_key = *(uint64_t *)(key.c_str() + sizeof(int) * 2);
+std::uint64_t i_key =
+    *(uint64_t *)(key.c_str() + sizeof(int) * 2);  // NOLINT
 for (int i = 0; i < part_len; i++) {
 if (memory_remaining[i] < (int64_t)value.size()) {
 score[i] = -100000.0;
@@ -380,8 +381,8 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
 score[i] = 0;
 }
 }
-for (size_t j = 0; j < (int)value.size(); j += sizeof(uint64_t)) {
-uint64_t v = *((uint64_t *)(value.c_str() + j));
+for (size_t j = 0; j < value.size(); j += sizeof(uint64_t)) {
+uint64_t v = *((uint64_t *)(value.c_str() + j));  // NOLINT
 int index = -1;
 if (id_map.find(v) != id_map.end()) {
 index = id_map[v];
@@ -398,9 +399,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
 int index = 0;
 for (int i = 0; i < part_len; i++) {
 base = gb_size_by_discount - memory_remaining[i] + value.size();
-if (has_weight)
+if (has_weight) {
 weight_base = weight_cost[i] + w * weight_param;
-else {
+} else {
 weight_base = 0;
 }
 score[i] -= a * y * std::pow(1.0 * base, y - 1) + weight_base;
@@ -434,7 +435,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
 int part_len = partitions[idx].size();
 if (part_len == 0) return;
 if (file_path == "") file_path = ".";
-if (file_path[(int)file_path.size() - 1] != '/') {
+if (file_path[file_path.size() - 1] != '/') {
 file_path += "/";
 }
 std::vector<std::future<int>> tasks;
@@ -459,7 +460,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
 }));
 }
-for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
 }
 void GraphTable::clear_graph(int idx) {
 for (auto p : edge_shards[idx]) {
@@ -472,7 +473,7 @@ void GraphTable::clear_graph(int idx) {
 }
 }
 int32_t GraphTable::load_next_partition(int idx) {
-if (next_partition >= (int)partitions[idx].size()) {
+if (next_partition >= static_cast<int>(partitions[idx].size())) {
 VLOG(0) << "partition iteration is done";
 return -1;
 }
@@ -518,8 +519,8 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path,
 add_node_to_ssd(0,
 idx,
 src_id,
-(char *)dist_data.data(),
-(int)(dist_data.size() * sizeof(uint64_t)));
+(char *)dist_data.data(),  // NOLINT
+static_cast<int>(dist_data.size() * sizeof(uint64_t)));
 }
 }
 VLOG(0) << "total memory cost = " << total_memory_cost << " bytes";
@@ -537,14 +538,14 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) {
 std::vector<Node *> &v = shards[i]->get_bucket();
 for (size_t j = 0; j < v.size(); j++) {
 std::vector<uint64_t> s;
-for (size_t k = 0; k < (int)v[j]->get_neighbor_size(); k++) {
+for (size_t k = 0; k < v[j]->get_neighbor_size(); k++) {
 s.push_back(v[j]->get_neighbor_id(k));
 }
 cost += v[j]->get_neighbor_size() * sizeof(uint64_t);
 add_node_to_ssd(0,
 idx,
 v[j]->get_id(),
-(char *)s.data(),
+(char *)s.data(),  // NOLINT
 s.size() * sizeof(uint64_t));
 }
 return cost;
@@ -901,7 +902,8 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
 std::vector<Node *> GraphShard::get_batch(int start, int end, int step) {
 if (start < 0) start = 0;
 std::vector<Node *> res;
-for (int pos = start; pos < std::min(end, (int)bucket.size()); pos += step) {
+for (int pos = start; pos < std::min(end, (int)bucket.size());  // NOLINT
+     pos += step) {
 res.push_back(bucket[pos]);
 }
 return res;
@@ -990,7 +992,7 @@ void GraphShard::delete_node(uint64_t id) {
 if (iter == node_location.end()) return;
 int pos = iter->second;
 delete bucket[pos];
-if (pos != (int)bucket.size() - 1) {
+if (pos != static_cast<int>(bucket.size()) - 1) {
 bucket[pos] = bucket.back();
 node_location[bucket.back()->get_id()] = pos;
 }
@@ -1002,7 +1004,7 @@ GraphNode *GraphShard::add_graph_node(uint64_t id) {
 node_location[id] = bucket.size();
 bucket.push_back(new GraphNode(id));
 }
-return (GraphNode *)bucket[node_location[id]];
+return (GraphNode *)bucket[node_location[id]];  // NOLINT
 }
 GraphNode *GraphShard::add_graph_node(Node *node) {
@@ -1011,17 +1013,17 @@ GraphNode *GraphShard::add_graph_node(Node *node) {
 node_location[id] = bucket.size();
 bucket.push_back(node);
 }
-return (GraphNode *)bucket[node_location[id]];
+return (GraphNode *)bucket[node_location[id]];  // NOLINT
 }
 FeatureNode *GraphShard::add_feature_node(uint64_t id, bool is_overlap) {
 if (node_location.find(id) == node_location.end()) {
 node_location[id] = bucket.size();
 bucket.push_back(new FeatureNode(id));
-return (FeatureNode *)bucket[node_location[id]];
+return (FeatureNode *)bucket[node_location[id]];  // NOLINT
 }
 if (is_overlap) {
-return (FeatureNode *)bucket[node_location[id]];
+return (FeatureNode *)bucket[node_location[id]];  // NOLINT
 }
 return NULL;
@@ -1037,14 +1039,14 @@ Node *GraphShard::find_node(uint64_t id) {
 }
 GraphTable::~GraphTable() {
-for (int i = 0; i < (int)edge_shards.size(); i++) {
+for (size_t i = 0; i < edge_shards.size(); i++) {
 for (auto p : edge_shards[i]) {
 delete p;
 }
 edge_shards[i].clear();
 }
-for (int i = 0; i < (int)feature_shards.size(); i++) {
+for (size_t i = 0; i < feature_shards.size(); i++) {
 for (auto p : feature_shards[i]) {
 delete p;
 }
@@ -1070,7 +1072,7 @@ int32_t GraphTable::Load(const std::string &path, const std::string &param) {
 std::string GraphTable::get_inverse_etype(std::string &etype) {
 auto etype_split = paddle::string::split_string<std::string>(etype, "2");
 std::string res;
-if ((int)etype_split.size() == 3) {
+if (etype_split.size() == 3) {
 res = etype_split[2] + "2" + etype_split[1] + "2" + etype_split[0];
 } else {
 res = etype_split[1] + "2" + etype_split[0];
@@ -1099,7 +1101,8 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
 std::string etype_path = epath + "/" + etypes[i];
 auto etype_path_list = paddle::framework::localfs_list(etype_path);
 std::string etype_path_str;
-if (part_num > 0 && part_num < (int)etype_path_list.size()) {
+if (part_num > 0 &&
+    part_num < (int)etype_path_list.size()) {  // NOLINT
 std::vector<std::string> sub_etype_path_list(
 etype_path_list.begin(), etype_path_list.begin() + part_num);
 etype_path_str =
@@ -1116,7 +1119,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
 } else {
 auto npath_list = paddle::framework::localfs_list(npath);
 std::string npath_str;
-if (part_num > 0 && part_num < (int)npath_list.size()) {
+if (part_num > 0 && part_num < (int)npath_list.size()) {  // NOLINT
 std::vector<std::string> sub_npath_list(
 npath_list.begin(), npath_list.begin() + part_num);
 npath_str = paddle::string::join_strings(sub_npath_list, delim);
@@ -1140,7 +1143,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
 return 0;
 }));
 }
-for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
+for (size_t i = 0; i < tasks.size(); i++) tasks[i].get();
 return 0;
 }
@@ -1154,13 +1157,14 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
 res.clear();
 auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
 std::vector<std::future<size_t>> tasks;
-for (size_t i = 0; i < shards.size() && index < (int)ranges.size(); i++) {
+for (size_t i = 0; i < shards.size() && index < (int)ranges.size();  // NOLINT
+     i++) {
 end = total_size + shards[i]->get_size();
 start = total_size;
-while (start < end && index < (int)ranges.size()) {
-if (ranges[index].second <= start)
+while (start < end && index < static_cast<int>(ranges.size())) {
+if (ranges[index].second <= start) {
 index++;
-else if (ranges[index].first >= end) {
+} else if (ranges[index].first >= end) {
 break;
 } else {
 int first = std::max(ranges[index].first, start);
@@ -1178,7 +1182,8 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
 res.reserve(res.size() + num);
 for (auto &id : keys) {
 res.push_back(id);
-std::swap(res[rand() % res.size()], res[(int)res.size() - 1]);
+std::swap(res[rand() % res.size()],
+          res[(int)res.size() - 1]);  // NOLINT
 }
 mutex.unlock();
@@ -1291,7 +1296,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
 return {local_count, local_valid_count};
 }
-// TODO opt load all node_types in once reading
+// TODO(danleifeng): opt load all node_types in once reading
 int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
 auto paths = paddle::string::split_string<std::string>(path, ";");
 uint64_t count = 0;
@@ -1308,7 +1313,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
 return parse_node_file(paths[i]);
 }));
 }
-for (int i = 0; i < (int)tasks.size(); i++) {
+for (size_t i = 0; i < tasks.size(); i++) {
 auto res = tasks[i].get();
 count += res.first;
 valid_count += res.second;
@@ -1434,13 +1439,13 @@ int32_t GraphTable::load_edges(const std::string &path,
 VLOG(0) << "Begin GraphTable::load_edges() edge_type[" << edge_type << "]";
 if (FLAGS_graph_load_in_parallel) {
 std::vector<std::future<std::pair<uint64_t, uint64_t>>> tasks;
-for (int i = 0; i < paths.size(); i++) {
+for (size_t i = 0; i < paths.size(); i++) {
 tasks.push_back(load_node_edge_task_pool->enqueue(
 [&, i, idx, this]() -> std::pair<uint64_t, uint64_t> {
 return parse_edge_file(paths[i], idx, reverse_edge);
 }));
 }
-for (int j = 0; j < (int)tasks.size(); j++) {
+for (size_t j = 0; j < tasks.size(); j++) {
 auto res = tasks[j].get();
 count += res.first;
 valid_count += res.second;
@@ -1543,7 +1548,7 @@ int32_t GraphTable::random_sample_nodes(int type_id,
 int &actual_size) {
 int total_size = 0;
 auto &shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
-for (int i = 0; i < (int)shards.size(); i++) {
+for (size_t i = 0; i < shards.size(); i++) {
 total_size += shards[i]->get_size();
 }
 if (sample_size > total_size) sample_size = total_size;
@@ -1554,9 +1559,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
 int remain = sample_size, last_pos = -1, num;
 std::set<int> separator_set;
 for (int i = 0; i < range_num - 1; i++) {
-while (separator_set.find(num = rand() % (sample_size - 1)) !=
-separator_set.end())
-;
+unsigned int seed = time(0);
+while (separator_set.find(num = rand_r(&seed) % (sample_size - 1)) !=
+separator_set.end()) {
+continue;
+}
 separator_set.insert(num);
 }
 for (auto p : separator_set) {
@@ -1567,8 +1574,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
 remain = total_size - sample_size + range_num;
 separator_set.clear();
 for (int i = 0; i < range_num; i++) {
-while (separator_set.find(num = rand() % remain) != separator_set.end())
-;
+unsigned int seed = time(0);
+while (separator_set.find(num = rand_r(&seed) % remain) !=
+separator_set.end()) {
+continue;
+}
 separator_set.insert(num);
 }
 int used = 0, index = 0;
@@ -1580,12 +1590,13 @@ int32_t GraphTable::random_sample_nodes(int type_id,
 used += ranges_len[index++];
 }
 std::vector<std::pair<int, int>> first_half, second_half;
-int start_index = rand() % total_size;
+unsigned int seed = time(0);
+int start_index = rand_r(&seed) % total_size;
 for (size_t i = 0; i < ranges_len.size() && i < ranges_pos.size(); i++) {
-if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size)
+if (ranges_pos[i] + ranges_len[i] - 1 + start_index < total_size) {
 first_half.push_back({ranges_pos[i] + start_index,
 ranges_pos[i] + ranges_len[i] + start_index});
-else if (ranges_pos[i] + start_index >= total_size) {
+} else if (ranges_pos[i] + start_index >= total_size) {
 second_half.push_back(
 {ranges_pos[i] + start_index - total_size,
 ranges_pos[i] + ranges_len[i] + start_index - total_size});
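The three hunks above also swap `rand()` for `rand_r()`. `rand()` relies on hidden global state, which lint flags inside these thread-pool tasks; POSIX `rand_r()` takes its state as an explicit `unsigned int *`, so each caller owns a seed. A standalone sketch of the pattern, mirroring the new code's seeding from `time(0)` (which, as written, repeats the sequence for calls within the same second):

#include <cstdlib>
#include <ctime>

// Draw a value in [0, bound) using caller-owned PRNG state.
int draw(unsigned int *seed, int bound) { return rand_r(seed) % bound; }

int main() {
  unsigned int seed = time(0);  // per-call-site seed, as in the patch
  int a = draw(&seed, 100);
  int b = draw(&seed, 100);  // advances the same explicit state
  return (a + b) % 100;
}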
@@ -1623,7 +1634,7 @@ int32_t GraphTable::random_sample_neighbors(
 id_list[index].emplace_back(idx, node_ids[idy], sample_size, need_weight);
 }
-for (int i = 0; i < (int)seq_id.size(); i++) {
+for (size_t i = 0; i < seq_id.size(); i++) {
 if (seq_id[i].size() == 0) continue;
 tasks.push_back(_shards_task_pool[i]->enqueue([&, i, this]() -> int {
 uint64_t node_id;
@@ -1633,12 +1644,12 @@ int32_t GraphTable::random_sample_neighbors(
 response =
 scaled_lru->query(i, id_list[i].data(), id_list[i].size(), r);
 }
-int index = 0;
+size_t index = 0;
 std::vector<SampleResult> sample_res;
 std::vector<SampleKey> sample_keys;
 auto &rng = _shards_task_rng_pool[i];
 for (size_t k = 0; k < id_list[i].size(); k++) {
-if (index < (int)r.size() &&
+if (index < r.size() &&
 r[index].first.node_key == id_list[i][k].node_key) {
 int idy = seq_id[i][k];
 actual_sizes[idy] = r[index].second.actual_size;
@@ -1722,7 +1733,7 @@ int32_t GraphTable::get_node_feat(int idx,
 if (node == nullptr) {
 return 0;
 }
-for (int feat_idx = 0; feat_idx < (int)feature_names.size();
+for (size_t feat_idx = 0; feat_idx < feature_names.size();
 ++feat_idx) {
 const std::string &feature_name = feature_names[feat_idx];
 if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
@@ -1755,7 +1766,7 @@ int32_t GraphTable::set_node_feat(
 size_t index = node_id % this->shard_num - this->shard_start;
 auto node = feature_shards[idx][index]->add_feature_node(node_id);
 node->set_feature_size(this->feat_name[idx].size());
-for (int feat_idx = 0; feat_idx < (int)feature_names.size();
+for (size_t feat_idx = 0; feat_idx < feature_names.size();
 ++feat_idx) {
 const std::string &feature_name = feature_names[feat_idx];
 if (feat_id_map[idx].find(feature_name) != feat_id_map[idx].end()) {
@@ -1893,8 +1904,8 @@ int GraphTable::get_all_id(int type_id,
 MergeShardVector shard_merge(output, slice_num);
 auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
 std::vector<std::future<size_t>> tasks;
-for (int idx = 0; idx < search_shards.size(); idx++) {
-for (int j = 0; j < search_shards[idx].size(); j++) {
+for (size_t idx = 0; idx < search_shards.size(); idx++) {
+for (size_t j = 0; j < search_shards[idx].size(); j++) {
 tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
 [&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
 std::vector<std::vector<uint64_t>> shard_keys;
@@ -1917,8 +1928,8 @@ int GraphTable::get_all_neighbor_id(
 MergeShardVector shard_merge(output, slice_num);
 auto &search_shards = type_id == 0 ? edge_shards : feature_shards;
 std::vector<std::future<size_t>> tasks;
-for (int idx = 0; idx < search_shards.size(); idx++) {
-for (int j = 0; j < search_shards[idx].size(); j++) {
+for (size_t idx = 0; idx < search_shards.size(); idx++) {
+for (size_t j = 0; j < search_shards[idx].size(); j++) {
 tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue(
 [&search_shards, idx, j, slice_num, &shard_merge]() -> size_t {
 std::vector<std::vector<uint64_t>> shard_keys;
@@ -1970,7 +1981,7 @@ int GraphTable::get_all_neighbor_id(
 auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
 std::vector<std::future<size_t>> tasks;
 VLOG(3) << "begin task, task_pool_size_[" << task_pool_size_ << "]";
-for (int i = 0; i < search_shards.size(); i++) {
+for (size_t i = 0; i < search_shards.size(); i++) {
 tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
 [&search_shards, i, slice_num, &shard_merge]() -> size_t {
 std::vector<std::vector<uint64_t>> shard_keys;
@@ -1996,7 +2007,7 @@ int GraphTable::get_all_feature_ids(
 MergeShardVector shard_merge(output, slice_num);
 auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx];
 std::vector<std::future<size_t>> tasks;
-for (int i = 0; i < search_shards.size(); i++) {
+for (size_t i = 0; i < search_shards.size(); i++) {
 tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue(
 [&search_shards, i, slice_num, &shard_merge]() -> size_t {
 std::vector<std::vector<uint64_t>> shard_keys;
@@ -2139,7 +2150,8 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
 if (use_cache) {
 cache_size_limit = graph.cache_size_limit();
 cache_ttl = graph.cache_ttl();
-make_neighbor_sample_cache((size_t)cache_size_limit, (size_t)cache_ttl);
+make_neighbor_sample_cache((size_t)cache_size_limit,  // NOLINT
+                           (size_t)cache_ttl);  // NOLINT
 }
 _shards_task_pool.resize(task_pool_size_);
 for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
@@ -2205,14 +2217,14 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
 #ifdef PADDLE_WITH_HETERPS
 partitions.resize(id_to_edge.size());
 #endif
-for (int k = 0; k < (int)edge_shards.size(); k++) {
+for (size_t k = 0; k < edge_shards.size(); k++) {
 for (size_t i = 0; i < shard_num_per_server; i++) {
 edge_shards[k].push_back(new GraphShard());
 }
 }
 node_weight[1].resize(id_to_feature.size());
 feature_shards.resize(id_to_feature.size());
-for (int k = 0; k < (int)feature_shards.size(); k++) {
+for (size_t k = 0; k < feature_shards.size(); k++) {
 for (size_t i = 0; i < shard_num_per_server; i++) {
 feature_shards[k].push_back(new GraphShard());
 }
...
@@ -21,8 +21,8 @@ namespace distributed {
 int FLAGS_pslib_table_save_max_retry_dense = 3;
-void MemoryDenseTable::CreateInitializer(const std::string& attr,
-const std::string& name) {
+void MemoryDenseTable::CreateInitializer(const std::string &attr,
+const std::string &name) {
 auto slices = string::split_string<std::string>(attr, "&");
 if (slices[0] == "gaussian_random") {
@@ -60,14 +60,14 @@ int32_t MemoryDenseTable::InitializeValue() {
 values_.resize(size);
 total_dim_ = 0;
 for (int x = 0; x < size; ++x) {
-auto& varname = common.params()[x];
-auto& dim = common.dims()[x];
+auto &varname = common.params()[x];
+auto &dim = common.dims()[x];
 if (varname == "Param") {
 param_dim_ = dim;
 param_idx_ = x;
 }
-auto& initializer = common.initializers()[x];
+auto &initializer = common.initializers()[x];
 total_dim_ += dim;
 CreateInitializer(initializer, varname);
@@ -81,7 +81,7 @@ int32_t MemoryDenseTable::InitializeValue() {
 fixed_len_params_dim_ = 0;
 for (int x = 0; x < size; ++x) {
-auto& dim = common.dims()[x];
+auto &dim = common.dims()[x];
 if (static_cast<int>(dim) != param_dim_) {
 fixed_len_params_dim_ += dim;
 } else {
@@ -124,19 +124,19 @@ int32_t MemoryDenseTable::InitializeOptimizer() {
 return 0;
 }
-int32_t MemoryDenseTable::SetGlobalLR(float* lr) {
+int32_t MemoryDenseTable::SetGlobalLR(float *lr) {
 _global_lr = lr;
 optimizer_->SetGlobalLR(_global_lr);
 return 0;
 }
-int32_t MemoryDenseTable::Pull(TableContext& context) {
+int32_t MemoryDenseTable::Pull(TableContext &context) {
 CHECK(context.value_type == Dense);
-float* pull_values = context.pull_context.values;
+float *pull_values = context.pull_context.values;
 return PullDense(pull_values, context.num);
 }
-int32_t MemoryDenseTable::Push(TableContext& context) {
+int32_t MemoryDenseTable::Push(TableContext &context) {
 CHECK(context.value_type == Dense);
 if (context.push_context.values != nullptr) {
 if (!context.push_context.is_param) {
@@ -148,13 +148,13 @@ int32_t MemoryDenseTable::Push(TableContext& context) {
 return 0;
 }
-int32_t MemoryDenseTable::PullDense(float* pull_values, size_t num) {
+int32_t MemoryDenseTable::PullDense(float *pull_values, size_t num) {
 std::copy(
 values_[param_idx_].begin(), values_[param_idx_].end(), pull_values);
 return 0;
 }
-int32_t MemoryDenseTable::PushDenseParam(const float* values, size_t num) {
+int32_t MemoryDenseTable::PushDenseParam(const float *values, size_t num) {
 PADDLE_ENFORCE_GE(
 num,
 param_dim_,
@@ -171,7 +171,7 @@ int32_t MemoryDenseTable::Pour() {
 return 0;
 }
-int32_t MemoryDenseTable::PushDense(const float* values, size_t num) {
+int32_t MemoryDenseTable::PushDense(const float *values, size_t num) {
 if (sync) {
 std::future<int> task =
 _shards_task_pool[0]->enqueue([this, &values]() -> int {
@@ -185,7 +185,7 @@ int32_t MemoryDenseTable::PushDense(const float* values, size_t num) {
 return 0;
 }
-int32_t MemoryDenseTable::_PushDense(const float* values, size_t num) {
+int32_t MemoryDenseTable::_PushDense(const float *values, size_t num) {
 PADDLE_ENFORCE_GE(
 num,
 param_dim_,
@@ -212,8 +212,8 @@ int32_t MemoryDenseTable::_PushDense(const float* values, size_t num) {
 return 0;
 }
-int32_t MemoryDenseTable::Load(const std::string& path,
-const std::string& param) {
+int32_t MemoryDenseTable::Load(const std::string &path,
+const std::string &param) {
 if (param_dim_ <= 0) {
 return 0;
 }
@@ -249,7 +249,7 @@ int32_t MemoryDenseTable::Load(const std::string& path,
 try {
 int dim_idx = 0;
 float data_buffer[5];
-float* data_buff_ptr = data_buffer;
+float *data_buff_ptr = data_buffer;
 std::string line_data;
 auto common = _config.common();
@@ -319,8 +319,8 @@ int32_t MemoryDenseTable::Load(const std::string& path,
 return 0;
 }
-int32_t MemoryDenseTable::Save(const std::string& path,
-const std::string& param) {
+int32_t MemoryDenseTable::Save(const std::string &path,
+const std::string &param) {
 int save_param = atoi(param.c_str());
 uint32_t feasign_size;
 VLOG(0) << "MemoryDenseTable::save path " << path;
@@ -353,7 +353,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
 os.clear();
 os.str("");
 os << values_[param_col_ids_[0]][y] << " 0";
-for (int x = 2; x < param_col_ids_.size(); ++x) {
+for (size_t x = 2; x < param_col_ids_.size(); ++x) {
 os << " ";
 os << values_[param_col_ids_[x]][y];
 }
@@ -365,7 +365,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
 os.clear();
 os.str("");
 os << values_[param_col_ids_[0]][y];
-for (int x = 1; x < param_col_ids_.size(); ++x) {
+for (size_t x = 1; x < param_col_ids_.size(); ++x) {
 os << " ";
 os << values_[param_col_ids_[x]][y];
 }
@@ -383,7 +383,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
 auto write_channel =
 _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
-for (auto& t : result_buffer_param) {
+for (auto &t : result_buffer_param) {
 if (0 != write_channel->write_line(t)) {
 ++retry_num;
 is_write_failed = true;
...
@@ -41,12 +41,12 @@ namespace paddle {
 namespace distributed {
 int32_t MemorySparseTable::Initialize() {
-auto& profiler = CostProfiler::instance();
+auto &profiler = CostProfiler::instance();
 profiler.register_profiler("pserver_sparse_update_all");
 profiler.register_profiler("pserver_sparse_select_all");
 InitializeValue();
 _shards_task_pool.resize(_task_pool_size);
-for (int i = 0; i < _shards_task_pool.size(); ++i) {
+for (size_t i = 0; i < _shards_task_pool.size(); ++i) {
 _shards_task_pool[i].reset(new ::ThreadPool(1));
 }
 VLOG(0) << "initalize MemorySparseTable succ";
@@ -102,8 +102,8 @@ int32_t MemorySparseTable::InitializeValue() {
 return 0;
 }
-int32_t MemorySparseTable::Load(const std::string& path,
-const std::string& param) {
+int32_t MemorySparseTable::Load(const std::string &path,
+const std::string &param) {
 std::string table_path = TableDir(path);
 auto file_list = _afs_client.list(table_path);
@@ -157,13 +157,13 @@ int32_t MemorySparseTable::Load(const std::string& path,
 err_no = 0;
 std::string line_data;
 auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
-char* end = NULL;
-auto& shard = _local_shards[i];
+char *end = NULL;
+auto &shard = _local_shards[i];
 try {
 while (read_channel->read_line(line_data) == 0 &&
 line_data.size() > 1) {
 uint64_t key = std::strtoul(line_data.data(), &end, 10);
-auto& value = shard[key];
+auto &value = shard[key];
 value.resize(feature_value_size);
 int parse_size = _value_accesor->ParseFromString(++end, value.data());
 value.resize(parse_size);
@@ -200,7 +200,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
 return 0;
 }
-int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
+int32_t MemorySparseTable::LoadPatch(const std::vector<std::string> &file_list,
 int load_param) {
 if (!_config.enable_revert()) {
 LOG(INFO) << "MemorySparseTable should be enabled revert.";
@@ -213,7 +213,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
 int o_start_idx = _shard_idx * _avg_local_shard_num;
 int o_end_idx = o_start_idx + _real_local_shard_num;
-if (start_idx >= file_list.size()) {
+if (start_idx >= static_cast<int>(file_list.size())) {
 return 0;
 }
 size_t feature_value_size =
@@ -224,7 +224,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
 omp_set_num_threads(thread_num);
 #pragma omp parallel for schedule(dynamic)
-for (size_t i = start_idx; i < end_idx; ++i) {
+for (int i = start_idx; i < end_idx; ++i) {
 FsChannelConfig channel_config;
 channel_config.path = file_list[i];
 channel_config.converter = _value_accesor->Converter(load_param).converter;
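Note the index here moves in the opposite direction, `size_t` back to `int`. The likely reason (my inference; the message only cites -Wsign-compare): the bounds `start_idx`/`end_idx` are signed ints, so a `size_t` index makes `i < end_idx` itself a mixed-sign comparison, and matching the index to the signed bounds also keeps the loop in the canonical signed form that older OpenMP `parallel for` implementations expect. A minimal sketch:

// end is int, so 'i < end' with a size_t i is a mixed-sign comparison;
// an int index matches the bounds and satisfies OpenMP 2.x loop rules.
int sum_range(int start, int end) {
  int total = 0;
  for (int i = start; i < end; ++i) total += i;
  return total;
}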
...@@ -239,11 +239,11 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list, ...@@ -239,11 +239,11 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
err_no = 0; err_no = 0;
std::string line_data; std::string line_data;
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no); auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
char* end = NULL; char *end = NULL;
int m_local_shard_id = i % _m_avg_local_shard_num; int m_local_shard_id = i % _m_avg_local_shard_num;
std::unordered_set<size_t> global_shard_idx; std::unordered_set<size_t> global_shard_idx;
std::string global_shard_idx_str; std::string global_shard_idx_str;
for (size_t j = o_start_idx; j < o_end_idx; ++j) { for (int j = o_start_idx; j < o_end_idx; ++j) {
if ((j % _avg_local_shard_num) % _m_real_local_shard_num == if ((j % _avg_local_shard_num) % _m_real_local_shard_num ==
m_local_shard_id) { m_local_shard_id) {
global_shard_idx.insert(j); global_shard_idx.insert(j);
...@@ -267,9 +267,9 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list, ...@@ -267,9 +267,9 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
continue; continue;
} }
size_t local_shard_idx = *index_iter % _avg_local_shard_num; size_t local_shard_idx = *index_iter % _avg_local_shard_num;
auto& shard = _local_shards[local_shard_idx]; auto &shard = _local_shards[local_shard_idx];
auto& value = shard[key]; auto &value = shard[key];
value.resize(feature_value_size); value.resize(feature_value_size);
int parse_size = _value_accesor->ParseFromString(++end, value.data()); int parse_size = _value_accesor->ParseFromString(++end, value.data());
value.resize(parse_size); value.resize(parse_size);
...@@ -300,7 +300,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list, ...@@ -300,7 +300,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
} }
void MemorySparseTable::Revert() { void MemorySparseTable::Revert() {
for (size_t i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
_local_shards_new[i].clear(); _local_shards_new[i].clear();
} }
} }
...@@ -309,8 +309,8 @@ void MemorySparseTable::CheckSavePrePatchDone() { ...@@ -309,8 +309,8 @@ void MemorySparseTable::CheckSavePrePatchDone() {
_save_patch_model_thread.join(); _save_patch_model_thread.join();
} }
int32_t MemorySparseTable::Save(const std::string& dirname, int32_t MemorySparseTable::Save(const std::string &dirname,
const std::string& param) { const std::string &param) {
if (_real_local_shard_num == 0) { if (_real_local_shard_num == 0) {
_local_show_threshold = -1; _local_show_threshold = -1;
return 0; return 0;
...@@ -368,7 +368,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname, ...@@ -368,7 +368,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
int feasign_size = 0; int feasign_size = 0;
int retry_num = 0; int retry_num = 0;
int err_no = 0; int err_no = 0;
auto& shard = _local_shards[i]; auto &shard = _local_shards[i];
do { do {
err_no = 0; err_no = 0;
feasign_size = 0; feasign_size = 0;
...@@ -426,7 +426,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname, ...@@ -426,7 +426,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
return 0; return 0;
} }
int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) { int32_t MemorySparseTable::SavePatch(const std::string &path, int save_param) {
if (!_config.enable_revert()) { if (!_config.enable_revert()) {
LOG(INFO) << "MemorySparseTable should be enabled revert."; LOG(INFO) << "MemorySparseTable should be enabled revert.";
return 0; return 0;
...@@ -441,7 +441,7 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) { ...@@ -441,7 +441,7 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _m_real_local_shard_num; ++i) { for (int i = 0; i < _m_real_local_shard_num; ++i) {
FsChannelConfig channel_config; FsChannelConfig channel_config;
channel_config.path = paddle::string::format_string("%s/part-%03d-%05d", channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
table_path.c_str(), table_path.c_str(),
...@@ -463,9 +463,9 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) { ...@@ -463,9 +463,9 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
auto write_channel = auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (size_t j = 0; j < _real_local_shard_num; ++j) { for (int j = 0; j < _real_local_shard_num; ++j) {
if (j % _m_real_local_shard_num == i) { if (j % _m_real_local_shard_num == i) {
auto& shard = _local_shards_patch_model[j]; auto &shard = _local_shards_patch_model[j];
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->Save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString( std::string format_value = _value_accesor->ParseToString(
...@@ -515,14 +515,14 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) { ...@@ -515,14 +515,14 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
} }
int64_t MemorySparseTable::CacheShuffle( int64_t MemorySparseTable::CacheShuffle(
const std::string& path, const std::string &path,
const std::string& param, const std::string &param,
double cache_threshold, double cache_threshold,
std::function<std::future<int32_t>( std::function<std::future<int32_t>(
-    int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
-    paddle::framework::Channel<std::pair<uint64_t, std::string>>&
-        shuffled_channel,
-    const std::vector<Table*>& table_ptrs) {
+    int msg_type, int to_pserver_id, std::string &msg)> send_msg_func,
+    paddle::framework::Channel<std::pair<uint64_t, std::string>>
+        &shuffled_channel,
+    const std::vector<Table *> &table_ptrs) {
  LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold;
  int save_param = atoi(param.c_str());  // batch_model:0  xbox:1
  if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
@@ -546,22 +546,22 @@ int64_t MemorySparseTable::CacheShuffle(
  int feasign_size = 0;
  std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
      tmp_channels;
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
+  for (int i = 0; i < _real_local_shard_num; ++i) {
    tmp_channels.push_back(
        paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
  }
  omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
-  for (size_t i = 0; i < _real_local_shard_num; ++i) {
-    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
+  for (int i = 0; i < _real_local_shard_num; ++i) {
+    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer =
        writers[i];
    writer.Reset(tmp_channels[i].get());
    for (size_t idx = 0; idx < table_ptrs.size(); idx++) {
-      Table* table_ptr = table_ptrs[idx];
+      Table *table_ptr = table_ptrs[idx];
      auto value_accesor = table_ptr->ValueAccesor();
-      shard_type* shard_ptr = static_cast<shard_type*>(table_ptr->GetShard(i));
+      shard_type *shard_ptr = static_cast<shard_type *>(table_ptr->GetShard(i));
      for (auto it = shard_ptr->begin(); it != shard_ptr->end(); ++it) {
        if (value_accesor->SaveCache(
@@ -581,14 +581,14 @@ int64_t MemorySparseTable::CacheShuffle(
  // size: " << feasign_size << " and start sparse cache data shuffle real local
  // shard num: " << _real_local_shard_num;
  std::vector<std::pair<uint64_t, std::string>> local_datas;
-  for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
-    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
+  for (int idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
+    paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>> &writer =
        writers[idx_shard];
    auto channel = writer.channel();
-    std::vector<std::pair<uint64_t, std::string>>& data = datas[idx_shard];
+    std::vector<std::pair<uint64_t, std::string>> &data = datas[idx_shard];
    std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
    while (channel->Read(data)) {
-      for (auto& t : data) {
+      for (auto &t : data) {
        auto pserver_id =
            paddle::distributed::local_random_engine()() % shuffle_node_num;
        if (pserver_id != _shard_idx) {
@@ -604,9 +604,9 @@ int64_t MemorySparseTable::CacheShuffle(
    send_index[i] = i;
  }
  std::random_shuffle(send_index.begin(), send_index.end());
-  for (auto index = 0u; index < shuffle_node_num; ++index) {
+  for (int index = 0; index < shuffle_node_num; ++index) {
    int i = send_index[index];
-    if (i == _shard_idx) {
+    if (i == static_cast<int>(_shard_idx)) {
      continue;
    }
    if (ars[i].Length() == 0) {
@@ -617,7 +617,7 @@ int64_t MemorySparseTable::CacheShuffle(
    total_status.push_back(std::move(ret));
    send_data_size[i] += ars[i].Length();
  }
-  for (auto& t : total_status) {
+  for (auto &t : total_status) {
    t.wait();
  }
  ars.clear();
@@ -630,10 +630,10 @@ int64_t MemorySparseTable::CacheShuffle(
}
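The loop-counter changes above are the core of this commit. `-Wsign-compare` fires whenever a signed and an unsigned integer meet in a comparison, because the signed operand is converted to unsigned first, so a negative value would wrap around before the comparison. Below is a minimal sketch of the two fix patterns used in CacheShuffle; the names are illustrative, and the member types (`_real_local_shard_num` as plain int, `_shard_idx` as an unsigned member) are assumptions read off the diff:

#include <cstdint>

int real_local_shard_num = 8;  // assumed: the bound is a plain int
uint32_t shard_idx = 3;        // assumed: _shard_idx is unsigned

int main() {
  // Before: `for (size_t i = 0; i < real_local_shard_num; ++i)` warns,
  // since the int bound is converted to the unsigned type of `i`.
  // Fix 1: give the counter the same type as its bound.
  for (int i = 0; i < real_local_shard_num; ++i) {
  }
  // Fix 2: when both sides must keep their types, cast one side explicitly,
  // mirroring `i == static_cast<int>(_shard_idx)` in the hunk above.
  int i = 3;
  return (i == static_cast<int>(shard_idx)) ? 0 : 1;
}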
int32_t MemorySparseTable::SaveCache(
-    const std::string& path,
-    const std::string& param,
-    paddle::framework::Channel<std::pair<uint64_t, std::string>>&
-        shuffled_channel) {
+    const std::string &path,
+    const std::string &param,
+    paddle::framework::Channel<std::pair<uint64_t, std::string>>
+        &shuffled_channel) {
  if (_shard_idx >= _config.sparse_table_cache_file_num()) {
    return 0;
  }
@@ -656,7 +656,7 @@ int32_t MemorySparseTable::SaveCache(
  bool is_write_failed = false;
  shuffled_channel->Close();
  while (shuffled_channel->Read(data)) {
-    for (auto& t : data) {
+    for (auto &t : data) {
      ++feasign_size;
      if (0 != write_channel->write_line(paddle::string::format_string(
                   "%lu %s", t.first, t.second.c_str()))) {
@@ -695,7 +695,7 @@ int64_t MemorySparseTable::LocalMFSize() {
    tasks[shard_id] =
        _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
            [this, shard_id, &size_arr]() -> int {
-              auto& local_shard = _local_shards[shard_id];
+              auto &local_shard = _local_shards[shard_id];
              for (auto it = local_shard.begin(); it != local_shard.end();
                   ++it) {
                if (_value_accesor->HasMF(it.value().size())) {
@@ -720,20 +720,20 @@ std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
  return {feasign_size, mf_size};
}

-int32_t MemorySparseTable::Pull(TableContext& context) {
+int32_t MemorySparseTable::Pull(TableContext &context) {
  CHECK(context.value_type == Sparse);
  if (context.use_ptr) {
-    char** pull_values = context.pull_context.ptr_values;
-    const uint64_t* keys = context.pull_context.keys;
+    char **pull_values = context.pull_context.ptr_values;
+    const uint64_t *keys = context.pull_context.keys;
    return PullSparsePtr(pull_values, keys, context.num);
  } else {
-    float* pull_values = context.pull_context.values;
-    const PullSparseValue& pull_value = context.pull_context.pull_value;
+    float *pull_values = context.pull_context.values;
+    const PullSparseValue &pull_value = context.pull_context.pull_value;
    return PullSparse(pull_values, pull_value);
  }
}

-int32_t MemorySparseTable::Push(TableContext& context) {
+int32_t MemorySparseTable::Push(TableContext &context) {
  CHECK(context.value_type == Sparse);
  if (!context.use_ptr) {
    return PushSparse(
@@ -745,8 +745,8 @@ int32_t MemorySparseTable::Push(TableContext& context) {
  }
}

-int32_t MemorySparseTable::PullSparse(float* pull_values,
-                                      const PullSparseValue& pull_value) {
+int32_t MemorySparseTable::PullSparse(float *pull_values,
+                                      const PullSparseValue &pull_value) {
  CostTimer timer("pserver_sparse_select_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
@@ -776,11 +776,11 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
                   pull_values,
                   mf_value_size,
                   select_value_size]() -> int {
-          auto& local_shard = _local_shards[shard_id];
+          auto &local_shard = _local_shards[shard_id];
          float data_buffer[value_size];  // NOLINT
-          float* data_buffer_ptr = data_buffer;
-          auto& keys = task_keys[shard_id];
+          float *data_buffer_ptr = data_buffer;
+          auto &keys = task_keys[shard_id];
          for (size_t i = 0; i < keys.size(); i++) {
            uint64_t key = keys[i].first;
            auto itr = local_shard.find(key);
@@ -790,9 +790,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
            if (FLAGS_pserver_create_value_when_push) {
              memset(data_buffer, 0, sizeof(float) * data_size);
            } else {
-              auto& feature_value = local_shard[key];
+              auto &feature_value = local_shard[key];
              feature_value.resize(data_size);
-              float* data_ptr = feature_value.data();
+              float *data_ptr = feature_value.data();
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(
                  data_ptr, data_buffer_ptr, data_size * sizeof(float));
@@ -807,9 +807,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
              data_buffer[mf_idx] = 0.0;
            }
            auto offset = keys[i].second;
-            float* select_data = pull_values + select_value_size * offset;
+            float *select_data = pull_values + select_value_size * offset;
            _value_accesor->Select(
-                &select_data, (const float**)&data_buffer_ptr, 1);
+                &select_data, (const float **)&data_buffer_ptr, 1);
          }
          return 0;
@@ -822,8 +822,8 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
  return 0;
}

-int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
-                                         const uint64_t* keys,
+int32_t MemorySparseTable::PullSparsePtr(char **pull_values,
+                                         const uint64_t *keys,
                                         size_t num) {
  CostTimer timer("pscore_sparse_select_all");
  size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
@@ -847,20 +847,20 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
                   pull_values,
                   value_size,
                   mf_value_size]() -> int {
-          auto& keys = task_keys[shard_id];
-          auto& local_shard = _local_shards[shard_id];
+          auto &keys = task_keys[shard_id];
+          auto &local_shard = _local_shards[shard_id];
          float data_buffer[value_size];  // NOLINT
-          float* data_buffer_ptr = data_buffer;
+          float *data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            auto itr = local_shard.find(key);
            size_t data_size = value_size - mf_value_size;
-            FixedFeatureValue* ret = NULL;
+            FixedFeatureValue *ret = NULL;
            if (itr == local_shard.end()) {
              // ++missed_keys;
-              auto& feature_value = local_shard[key];
+              auto &feature_value = local_shard[key];
              feature_value.resize(data_size);
-              float* data_ptr = feature_value.data();
+              float *data_ptr = feature_value.data();
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(data_ptr, data_buffer_ptr, data_size * sizeof(float));
              ret = &feature_value;
@@ -868,7 +868,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
              ret = itr.value_ptr();
            }
            int pull_data_idx = keys[i].second;
-            pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
+            pull_values[pull_data_idx] = reinterpret_cast<char *>(ret);
          }
          return 0;
        });
@@ -879,8 +879,8 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
  return 0;
}

-int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
-                                      const float* values,
+int32_t MemorySparseTable::PushSparse(const uint64_t *keys,
+                                      const float *values,
                                      size_t num) {
  CostTimer timer("pserver_sparse_update_all");
  std::vector<std::future<int>> tasks(_real_local_shard_num);
@@ -907,15 +907,15 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                   update_value_col,
                   values,
                   &task_keys]() -> int {
-          auto& keys = task_keys[shard_id];
-          auto& local_shard = _local_shards[shard_id];
-          auto& local_shard_new = _local_shards_new[shard_id];
+          auto &keys = task_keys[shard_id];
+          auto &local_shard = _local_shards[shard_id];
+          auto &local_shard_new = _local_shards_new[shard_id];
          float data_buffer[value_col];  // NOLINT
-          float* data_buffer_ptr = data_buffer;
+          float *data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
-            const float* update_data =
+            const float *update_data =
                values + push_data_idx * update_value_col;
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
@@ -924,7 +924,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                continue;
              }
              auto value_size = value_col - mf_value_col;
-              auto& feature_value = local_shard[key];
+              auto &feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
@@ -933,8 +933,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
              itr = local_shard.find(key);
            }
-            auto& feature_value = itr.value();
-            float* value_data = feature_value.data();
+            auto &feature_value = itr.value();
+            float *value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {  // already expanded to the max size, update in place
@@ -952,7 +952,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
              memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
            }
            if (_config.enable_revert()) {
-              FixedFeatureValue* feature_value_new = &(local_shard_new[key]);
+              FixedFeatureValue *feature_value_new = &(local_shard_new[key]);
              auto new_size = feature_value.size();
              feature_value_new->resize(new_size);
              memcpy(feature_value_new->data(),
@@ -970,8 +970,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
  return 0;
}

-int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
-                                      const float** values,
+int32_t MemorySparseTable::PushSparse(const uint64_t *keys,
+                                      const float **values,
                                      size_t num) {
  std::vector<std::future<int>> tasks(_real_local_shard_num);
  std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
@@ -996,14 +996,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                   update_value_col,
                   values,
                   &task_keys]() -> int {
-          auto& keys = task_keys[shard_id];
-          auto& local_shard = _local_shards[shard_id];
+          auto &keys = task_keys[shard_id];
+          auto &local_shard = _local_shards[shard_id];
          float data_buffer[value_col];  // NOLINT
-          float* data_buffer_ptr = data_buffer;
+          float *data_buffer_ptr = data_buffer;
          for (size_t i = 0; i < keys.size(); ++i) {
            uint64_t key = keys[i].first;
            uint64_t push_data_idx = keys[i].second;
-            const float* update_data = values[push_data_idx];
+            const float *update_data = values[push_data_idx];
            auto itr = local_shard.find(key);
            if (itr == local_shard.end()) {
              if (FLAGS_pserver_enable_create_feasign_randomly &&
@@ -1011,7 +1011,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                continue;
              }
              auto value_size = value_col - mf_value_col;
-              auto& feature_value = local_shard[key];
+              auto &feature_value = local_shard[key];
              feature_value.resize(value_size);
              _value_accesor->Create(&data_buffer_ptr, 1);
              memcpy(feature_value.data(),
@@ -1019,8 +1019,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
                     value_size * sizeof(float));
              itr = local_shard.find(key);
            }
-            auto& feature_value = itr.value();
-            float* value_data = feature_value.data();
+            auto &feature_value = itr.value();
+            float *value_data = feature_value.data();
            size_t value_size = feature_value.size();
            if (value_size == value_col) {  // already expanded to the max size, update in place
              _value_accesor->Update(&value_data, &update_data, 1);
@@ -1048,12 +1048,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
int32_t MemorySparseTable::Flush() { return 0; }

-int32_t MemorySparseTable::Shrink(const std::string& param) {
+int32_t MemorySparseTable::Shrink(const std::string &param) {
  VLOG(0) << "MemorySparseTable::Shrink";
  // TODO(zhaocaibei123): implement with multi-thread
  for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
    // Shrink
-    auto& shard = _local_shards[shard_id];
+    auto &shard = _local_shards[shard_id];
    for (auto it = shard.begin(); it != shard.end();) {
      if (_value_accesor->Shrink(it.value().data())) {
        it = shard.erase(it);
...
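Aside from the signedness fixes, most of the churn in the MemorySparseTable hunks only moves `*` and `&` from the type to the variable name (`float* p` becomes `float *p`). Both spellings declare the same thing; the sweep is consistent with a clang-format pointer-alignment pass (for example `PointerAlignment: Right` in a `.clang-format` file, which is an assumption here, since the config itself is not part of this diff):

int main() {
  float x = 0.0f;
  float* lhs = &x;  // alignment before this commit
  float *rhs = &x;  // alignment after this commit; identical declaration
  return (lhs == rhs) ? 0 : 1;
}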
@@ -23,7 +23,7 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace paddle {
namespace distributed {

-void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
+void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter &param,
                                    size_t emb_dim) {
  _embedding_dim = emb_dim;
  auto naive_param = param.naive();
@@ -41,9 +41,9 @@ void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
  }
}

-void SparseNaiveSGDRule::UpdateValueWork(float* w,
-                                         float* sgd,
-                                         const float* push_value,
+void SparseNaiveSGDRule::UpdateValueWork(float *w,
+                                         float *sgd,
+                                         const float *push_value,
                                         float scale) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    w[i] -= learning_rate_ * push_value[i];
@@ -51,8 +51,8 @@ void SparseNaiveSGDRule::UpdateValueWork(float* w,
  }
}

-void SparseNaiveSGDRule::InitValueWork(float* value,
-                                       float* sgd,
+void SparseNaiveSGDRule::InitValueWork(float *value,
+                                       float *sgd,
                                       bool zero_init) {
  if (zero_init) {
    for (size_t i = 0; i < _embedding_dim; ++i) {
@@ -68,7 +68,7 @@ void SparseNaiveSGDRule::InitValueWork(float* value,
    }
  }
}

-void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
+void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter &param,
                                      size_t emb_dim) {
  _embedding_dim = emb_dim;
  auto adagrad_param = param.adagrad();
@@ -88,11 +88,11 @@ void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
  }
}

-void SparseAdaGradSGDRule::UpdateValueWork(float* w,
-                                           float* sgd,
-                                           const float* grad,
+void SparseAdaGradSGDRule::UpdateValueWork(float *w,
+                                           float *sgd,
+                                           const float *grad,
                                           float scale) {
-  float& g2sum = sgd[G2SumIndex()];
+  float &g2sum = sgd[G2SumIndex()];
  double add_g2sum = 0;
  for (size_t i = 0; i < _embedding_dim; i++) {
@@ -106,8 +106,8 @@ void SparseAdaGradSGDRule::UpdateValueWork(float* w,
  g2sum += add_g2sum / _embedding_dim;
}

-void SparseAdaGradSGDRule::InitValueWork(float* value,
-                                         float* sgd,
+void SparseAdaGradSGDRule::InitValueWork(float *value,
+                                         float *sgd,
                                         bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
@@ -125,7 +125,7 @@ void SparseAdaGradSGDRule::InitValueWork(float* value,
  sgd[G2SumIndex()] = 0;
}

-void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
+void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter &param,
                                   size_t emb_dim) {
  _embedding_dim = emb_dim;
  auto adagrad_param = param.adagrad();
@@ -145,12 +145,12 @@ void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
  }
}

-void StdAdaGradSGDRule::UpdateValueWork(float* w,
-                                        float* sgd,
-                                        const float* grad,
+void StdAdaGradSGDRule::UpdateValueWork(float *w,
+                                        float *sgd,
+                                        const float *grad,
                                        float scale) {
  for (size_t i = 0; i < _embedding_dim; i++) {
-    float& g2sum = sgd[G2SumIndex() + i];
+    float &g2sum = sgd[G2SumIndex() + i];
    double scaled_grad = grad[i] / scale;
    w[i] -= learning_rate_ * scaled_grad *
            sqrt(_initial_g2sum / (_initial_g2sum + g2sum));
@@ -159,8 +159,8 @@ void StdAdaGradSGDRule::UpdateValueWork(float* w,
  }
}
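For reference, the update these two AdaGrad rules implement can be read off the hunks above, where s is the `scale` argument, \eta is `learning_rate_`, G_0 is `_initial_g2sum`, and d is `_embedding_dim`. The accumulator update line is partly elided in this view, so the second formula is inferred from `add_g2sum`:

w_i \leftarrow w_i - \eta \, \frac{g_i}{s} \, \sqrt{\frac{G_0}{G_0 + G}}, \qquad
G \leftarrow G + \frac{1}{d} \sum_{i=1}^{d} \left( \frac{g_i}{s} \right)^2

SparseAdaGradSGDRule keeps a single accumulator G shared across all d dimensions (hence the `add_g2sum / _embedding_dim` average), while StdAdaGradSGDRule stores one accumulator per dimension at `sgd[G2SumIndex() + i]`.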
-void StdAdaGradSGDRule::InitValueWork(float* value,
-                                      float* sgd,
+void StdAdaGradSGDRule::InitValueWork(float *value,
+                                      float *sgd,
                                      bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
@@ -178,7 +178,7 @@ void StdAdaGradSGDRule::InitValueWork(float* value,
  }
}

-void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
+void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter &param,
                                   size_t emb_dim) {
  _embedding_dim = emb_dim;
  auto adam_param = param.adam();
@@ -199,15 +199,15 @@ void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
  }
}

-void SparseAdamSGDRule::UpdateValueWork(float* w,
-                                        float* sgd,
-                                        const float* grad,
+void SparseAdamSGDRule::UpdateValueWork(float *w,
+                                        float *sgd,
+                                        const float *grad,
                                        float scale) {
-  float* gsum = sgd + GSumIndex();
-  float* g2sum = sgd + G2SumIndex();
-  float* beta1_pow = sgd + Beta1PowIndex();
-  float* beta2_pow = sgd + Beta2PowIndex();
-  const float* g = grad;
+  float *gsum = sgd + GSumIndex();
+  float *g2sum = sgd + G2SumIndex();
+  float *beta1_pow = sgd + Beta1PowIndex();
+  float *beta2_pow = sgd + Beta2PowIndex();
+  const float *g = grad;
  float lr = learning_rate_;
  float beta1_pow_ = *beta1_pow;
@@ -227,8 +227,8 @@ void SparseAdamSGDRule::UpdateValueWork(float* w,
  (*beta2_pow) *= _beta2_decay_rate;
}

-void SparseAdamSGDRule::InitValueWork(float* value,
-                                      float* sgd,
+void SparseAdamSGDRule::InitValueWork(float *value,
+                                      float *sgd,
                                      bool zero_init) {
  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
@@ -253,7 +253,7 @@ void SparseAdamSGDRule::InitValueWork(float* value,
}

void SparseSharedAdamSGDRule::LoadConfig(
-    const SparseCommonSGDRuleParameter& param, size_t emb_dim) {
+    const SparseCommonSGDRuleParameter &param, size_t emb_dim) {
  _embedding_dim = emb_dim;
  auto adam_param = param.adam();
  learning_rate_ = adam_param.learning_rate();
@@ -273,15 +273,15 @@ void SparseSharedAdamSGDRule::LoadConfig(
  }
}

-void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
-                                              float* sgd,
-                                              const float* grad,
+void SparseSharedAdamSGDRule::UpdateValueWork(float *w,
+                                              float *sgd,
+                                              const float *grad,
                                              float scale) {
-  float* gsum = sgd + GSumIndex();
-  float* g2sum = sgd + G2SumIndex();
-  float* beta1_pow = sgd + Beta1PowIndex();
-  float* beta2_pow = sgd + Beta2PowIndex();
-  const float* g = grad;
+  float *gsum = sgd + GSumIndex();
+  float *g2sum = sgd + G2SumIndex();
+  float *beta1_pow = sgd + Beta1PowIndex();
+  float *beta2_pow = sgd + Beta2PowIndex();
+  const float *g = grad;
  float lr = learning_rate_;
  float beta1_pow_ = *beta1_pow;
@@ -292,7 +292,7 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
  lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
  double sum_gsum = 0.0;
  double sum_g2sum = 0.0;
-  for (int i = 0; i < _embedding_dim; i++) {
+  for (size_t i = 0; i < _embedding_dim; i++) {
    // Calculation
    double new_gsum =
        _beta1_decay_rate * gsum_ + (1 - _beta1_decay_rate) * g[i];
@@ -310,10 +310,10 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
  (*beta2_pow) *= _beta2_decay_rate;
}
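SparseAdamSGDRule and SparseSharedAdamSGDRule share the bias-corrected step size visible at `lr *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_)`, with `beta1_pow_`/`beta2_pow_` playing the role of \beta_1^t and \beta_2^t. The weight update itself falls in the elided lines, so the following is the standard Adam form these moments imply, not a verbatim transcription:

\hat{\eta} = \eta \, \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}, \qquad
m \leftarrow \beta_1 m + (1 - \beta_1) \, g_i, \qquad
v \leftarrow \beta_2 v + (1 - \beta_2) \, g_i^2, \qquad
w_i \leftarrow w_i - \hat{\eta} \, \frac{m}{\sqrt{v} + \varepsilon}

The shared variant accumulates `sum_gsum`/`sum_g2sum` across the embedding and appears to keep a single averaged m and v for all dimensions, which is why its per-value optimizer state is only two moments plus the two beta powers.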
-void SparseSharedAdamSGDRule::InitValueWork(float* value,
-                                            float* sgd,
+void SparseSharedAdamSGDRule::InitValueWork(float *value,
+                                            float *sgd,
                                            bool zero_init) {
-  for (int i = 0; i < _embedding_dim; ++i) {
+  for (size_t i = 0; i < _embedding_dim; ++i) {
    if (zero_init) {
      value[i] = 0.0;
      BoundValue(value[i]);
@@ -327,7 +327,7 @@ void SparseSharedAdamSGDRule::InitValueWork(float* value,
    }
  }
  // init rule gsum and g2sum
-  for (int i = GSumIndex(); i < Beta1PowIndex(); i++) {
+  for (size_t i = GSumIndex(); i < Beta1PowIndex(); i++) {
    sgd[i] = 0.0;
  }
  // init beta1_pow and beta2_pow
...
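Note that the signedness fixes run in the opposite direction here from the MemorySparseTable hunks: there the bounds (`_real_local_shard_num`, `shuffle_node_num`) are signed, so the counters became `int`; here the bounds (`_embedding_dim`, and by the look of the last hunk the return values of `GSumIndex()` and `Beta1PowIndex()`) are `size_t`, so the counters became `size_t`. A sketch of the rule of thumb, with illustrative names (the member types are assumptions):

#include <cstddef>

struct SgdRuleSketch {
  size_t embedding_dim = 8;  // unsigned bound -> use an unsigned counter
  int shard_num = 4;         // signed bound -> use a signed counter

  int Touch() const {
    int n = 0;
    for (size_t i = 0; i < embedding_dim; ++i) ++n;  // clean under -Wsign-compare
    for (int j = 0; j < shard_num; ++j) ++n;         // clean as well
    return n;
  }
};

int main() { return SgdRuleSketch{}.Touch() == 12 ? 0 : 1; }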