未验证 提交 d4b3bfab 编写于 作者: W wangzhen38 提交者: GitHub

[code_style fix] graph_brpc_server cpplint (#49462)

上级 36c6c589
......@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <string>
#include <thread> // NOLINT
#include <utility>
......@@ -125,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_);
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
(reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
return 0;
}
......@@ -142,14 +143,16 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list;
if (request.params_size() == 3) {
size_t weight_list_size = request.params(2).size() / sizeof(bool);
bool *is_weighted_buffer = (bool *)(request.params(2).c_str());
const bool *is_weighted_buffer =
reinterpret_cast<const bool *>(request.params(2).c_str());
is_weighted_list = std::vector<bool>(is_weighted_buffer,
is_weighted_buffer + weight_list_size);
}
......@@ -161,7 +164,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
// weight_list_size);
// }
((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list);
(reinterpret_cast<GraphTable *>(table))
->add_graph_node(idx_, node_ids, is_weighted_list);
return 0;
}
int32_t GraphBrpcService::remove_graph_node(Table *table,
......@@ -176,12 +180,13 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"remove_graph_node request requires at least 2 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(idx_, node_ids);
(reinterpret_cast<GraphTable *>(table))->remove_graph_node(idx_, node_ids);
return 0;
}
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
......@@ -338,7 +343,7 @@ int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
GraphBrpcServer *p_server = reinterpret_cast<GraphBrpcServer *>(_server);
std::thread t_stop([p_server]() {
p_server->Stop();
LOG(INFO) << "Server Stoped";
......@@ -375,14 +380,14 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments");
return 0;
}
int type_id = *(int *)(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str());
int step = *(int *)(request.params(4).c_str());
int type_id = std::stoi(request.params(0).c_str());
int idx = std::stoi(request.params(1).c_str());
int start = std::stoi(request.params(2).c_str());
int size = std::stoi(request.params(3).c_str());
int step = std::stoi(request.params(4).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->pull_graph_list(
type_id, idx, start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size);
......@@ -401,14 +406,16 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
int sample_size = *(int *)(request.params(2).c_str());
bool need_weight = *(bool *)(request.params(3).c_str());
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); // NOLINT
const int sample_size =
*reinterpret_cast<const int *>(request.params(2).c_str());
const bool need_weight =
*reinterpret_cast<const bool *>(request.params(3).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors(
idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
......@@ -425,18 +432,18 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str());
int type_id = std::stoi(request.params(0).c_str());
int idx_ = std::stoi(request.params(1).c_str());
size_t size = std::stoull(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
if (((GraphTable *)table)
->random_sample_nodes(type_id, idx_, size, buffer, actual_size) ==
0) {
if (reinterpret_cast<GraphTable *>(table)->random_sample_nodes(
type_id, idx_, size, buffer, actual_size) == 0) {
cntl->response_attachment().append(buffer.get(), actual_size);
} else
} else {
cntl->response_attachment().append(NULL, 0);
}
return 0;
}
......@@ -453,9 +460,10 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names =
......@@ -464,7 +472,8 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
std::vector<std::vector<std::string>> feature(
feature_names.size(), std::vector<std::string>(node_num));
((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature);
(reinterpret_cast<GraphTable *>(table))
->get_node_feat(idx_, node_ids, feature_names, feature);
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
......@@ -492,11 +501,12 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
int sample_size = *(int *)(request.params(2).c_str());
bool need_weight = *(bool *)(request.params(3).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
int sample_size = std::stoi(request.params(2).c_str());
bool need_weight = std::stoi(request.params(3).c_str());
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
......@@ -504,8 +514,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> local_query_idx;
size_t rank = GetRank();
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
int server_index = (reinterpret_cast<GraphTable *>(table))
->get_server_index_by_id(node_data[query_idx]);
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
......@@ -514,10 +524,10 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) {
auto pos = server2request[rank];
std::swap(request2server[pos],
request2server[(int)request2server.size() - 1]);
request2server[static_cast<int>(request2server.size()) - 1]);
server2request[request2server[pos]] = pos;
server2request[request2server[(int)request2server.size() - 1]] =
request2server.size() - 1;
server2request[request2server[static_cast<int>(request2server.size()) -
1]] = request2server.size() - 1;
}
size_t request_call_num = request2server.size();
std::vector<std::shared_ptr<char>> local_buffers;
......@@ -526,8 +536,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index =
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]);
int server_index = (reinterpret_cast<GraphTable *>(table))
->get_server_index_by_id(node_data[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_data[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
......@@ -550,7 +560,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
request_call_num](void *done) {
local_fut.get();
std::vector<int> actual_size;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
remote_call_num);
size_t fail_num = 0;
......@@ -610,17 +620,19 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)->add_params((char *)&idx_, sizeof(int));
closure->request(request_idx)
->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
->add_params(
reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
->add_params(reinterpret_cast<char *>(&sample_size), sizeof(int));
closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool));
PsService_Stub rpc_stub(
((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index));
->add_params(reinterpret_cast<char *>(&need_weight), sizeof(bool));
PsService_Stub rpc_stub((reinterpret_cast<GraphBrpcServer *>(GetServer())
->GetCmdChannel(server_index)));
// GraphPsService_Stub rpc_stub =
// getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
......@@ -630,7 +642,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure);
}
if (server2request[rank] != -1) {
((GraphTable *)table)
(reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors(idx_,
node_id_buckets.back().data(),
sample_size,
......@@ -655,10 +667,11 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments");
return 0;
}
int idx_ = *(int *)(request.params(0).c_str());
int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
// std::vector<std::string> feature_names =
......@@ -675,7 +688,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer);
const size_t feat_len = *reinterpret_cast<const size_t *>(buffer);
buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat;
......@@ -683,7 +696,8 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
}
}
((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features);
(reinterpret_cast<GraphTable *>(table))
->set_node_feat(idx_, node_ids, feature_names, features);
return 0;
}
......
......@@ -174,7 +174,7 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
value[CtrDoubleFeatureValue::SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(),
......@@ -188,8 +188,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
return 0;
}
bool CtrDoubleAccessor::NeedExtendMF(float* value) {
auto show = ((double*)(value + CtrDoubleFeatureValue::ShowIndex()))[0];
auto click = ((double*)(value + CtrDoubleFeatureValue::ClickIndex()))[0];
auto show = (reinterpret_cast<double*>(
value + CtrDoubleFeatureValue::ShowIndex()))[0];
auto click = (reinterpret_cast<double*>(
value + CtrDoubleFeatureValue::ClickIndex()))[0];
// float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff()
auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff();
......@@ -204,10 +206,11 @@ int32_t CtrDoubleAccessor::Select(float** select_values,
for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item];
float* value = const_cast<float*>(values[value_item]);
select_value[CtrDoublePullValue::ShowIndex()] =
(float)*(double*)(value + CtrDoubleFeatureValue::ShowIndex());
select_value[CtrDoublePullValue::ShowIndex()] = static_cast<float>(
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()));
select_value[CtrDoublePullValue::ClickIndex()] =
(float)*(double*)(value + CtrDoubleFeatureValue::ClickIndex());
static_cast<float>(*reinterpret_cast<double*>(
value + CtrDoubleFeatureValue::ClickIndex()));
select_value[CtrDoublePullValue::EmbedWIndex()] =
value[CtrDoubleFeatureValue::EmbedWIndex()];
memcpy(select_value + CtrDoublePullValue::EmbedxWIndex(),
......@@ -254,15 +257,17 @@ int32_t CtrDoubleAccessor::Update(float** update_values,
float push_show = push_value[CtrDoublePushValue::ShowIndex()];
float push_click = push_value[CtrDoublePushValue::ClickIndex()];
float slot = push_value[CtrDoublePushValue::SlotIndex()];
*(double*)(update_value + CtrDoubleFeatureValue::ShowIndex()) +=
(double)push_show;
*(double*)(update_value + CtrDoubleFeatureValue::ClickIndex()) +=
(double)push_click;
*reinterpret_cast<double*>(update_value +
CtrDoubleFeatureValue::ShowIndex()) +=
static_cast<double>(push_show);
*reinterpret_cast<double*>(update_value +
CtrDoubleFeatureValue::ClickIndex()) +=
static_cast<double>(push_click);
update_value[CtrDoubleFeatureValue::SlotIndex()] = slot;
update_value[CtrDoubleFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff();
update_value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
if (!_show_scale) {
......@@ -315,9 +320,11 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
thread_local std::ostringstream os;
os.clear();
os.str("");
os << v[0] << " " << v[1] << " " << (float)((double*)(v + 2))[0] << " "
<< (float)((double*)(v + 4))[0] << " " << v[6] << " " << v[7] << " "
<< v[8];
os << v[0] << " " << v[1] << " "
<< static_cast<const float>((reinterpret_cast<const double*>(v + 2))[0])
<< " "
<< static_cast<const float>((reinterpret_cast<const double*>(v + 4))[0])
<< " " << v[6] << " " << v[7] << " " << v[8];
auto show = CtrDoubleFeatureValue::Show(const_cast<float*>(v));
auto click = CtrDoubleFeatureValue::Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click);
......@@ -331,7 +338,7 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
}
int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim();
float data_buff[_accessor_info.dim + 2];
float data_buff[_accessor_info.dim + 2]; // NOLINT
float* data_buff_ptr = data_buff;
_embedx_sgd_rule->InitValue(
data_buff_ptr + CtrDoubleFeatureValue::EmbedxWIndex(),
......@@ -350,8 +357,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score
memcpy(value, data_buff_ptr, show_index * sizeof(float));
// copy show & click
*(double*)(value + show_index) = (double)data_buff_ptr[2];
*(double*)(value + click_index) = (double)data_buff_ptr[3];
*reinterpret_cast<double*>(value + show_index) =
static_cast<double>(data_buff_ptr[2]);
*reinterpret_cast<double*>(value + click_index) =
static_cast<double>(data_buff_ptr[3]);
// copy others
value[CtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
value[CtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
......@@ -362,8 +371,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score
memcpy(value, data_buff_ptr, show_index * sizeof(float));
// copy show & click
*(double*)(value + show_index) = (double)data_buff_ptr[2];
*(double*)(value + click_index) = (double)data_buff_ptr[3];
*reinterpret_cast<double*>(value + show_index) =
static_cast<double>(data_buff_ptr[2]);
*reinterpret_cast<double*>(value + click_index) =
static_cast<double>(data_buff_ptr[3]);
// copy embed_w..embedx_w
memcpy(value + embed_w_index,
data_buff_ptr + 4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册