未验证 提交 d4b3bfab 编写于 作者: W wangzhen38 提交者: GitHub

[code_style fix] graph_brpc_server cpplint (#49462)

上级 36c6c589
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h" #include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
...@@ -125,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table, ...@@ -125,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str()); int type_id = std::stoi(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str()); int idx_ = std::stoi(request.params(1).c_str());
((GraphTable *)table)->clear_nodes(type_id, idx_); (reinterpret_cast<GraphTable *>(table))->clear_nodes(type_id, idx_);
return 0; return 0;
} }
...@@ -142,14 +143,16 @@ int32_t GraphBrpcService::add_graph_node(Table *table, ...@@ -142,14 +143,16 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(int64_t); size_t node_num = request.params(1).size() / sizeof(int64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<bool> is_weighted_list; std::vector<bool> is_weighted_list;
if (request.params_size() == 3) { if (request.params_size() == 3) {
size_t weight_list_size = request.params(2).size() / sizeof(bool); size_t weight_list_size = request.params(2).size() / sizeof(bool);
bool *is_weighted_buffer = (bool *)(request.params(2).c_str()); const bool *is_weighted_buffer =
reinterpret_cast<const bool *>(request.params(2).c_str());
is_weighted_list = std::vector<bool>(is_weighted_buffer, is_weighted_list = std::vector<bool>(is_weighted_buffer,
is_weighted_buffer + weight_list_size); is_weighted_buffer + weight_list_size);
} }
...@@ -161,7 +164,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table, ...@@ -161,7 +164,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
// weight_list_size); // weight_list_size);
// } // }
((GraphTable *)table)->add_graph_node(idx_, node_ids, is_weighted_list); (reinterpret_cast<GraphTable *>(table))
->add_graph_node(idx_, node_ids, is_weighted_list);
return 0; return 0;
} }
int32_t GraphBrpcService::remove_graph_node(Table *table, int32_t GraphBrpcService::remove_graph_node(Table *table,
...@@ -176,12 +180,13 @@ int32_t GraphBrpcService::remove_graph_node(Table *table, ...@@ -176,12 +180,13 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"remove_graph_node request requires at least 2 arguments"); "remove_graph_node request requires at least 2 arguments");
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t); size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<uint64_t> node_ids(node_data, node_data + node_num);
((GraphTable *)table)->remove_graph_node(idx_, node_ids); (reinterpret_cast<GraphTable *>(table))->remove_graph_node(idx_, node_ids);
return 0; return 0;
} }
int32_t GraphBrpcServer::Port() { return _server.listen_address().port; } int32_t GraphBrpcServer::Port() { return _server.listen_address().port; }
...@@ -338,7 +343,7 @@ int32_t GraphBrpcService::StopServer(Table *table, ...@@ -338,7 +343,7 @@ int32_t GraphBrpcService::StopServer(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
GraphBrpcServer *p_server = (GraphBrpcServer *)_server; GraphBrpcServer *p_server = reinterpret_cast<GraphBrpcServer *>(_server);
std::thread t_stop([p_server]() { std::thread t_stop([p_server]() {
p_server->Stop(); p_server->Stop();
LOG(INFO) << "Server Stoped"; LOG(INFO) << "Server Stoped";
...@@ -375,14 +380,14 @@ int32_t GraphBrpcService::pull_graph_list(Table *table, ...@@ -375,14 +380,14 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response, -1, "pull_graph_list request requires at least 5 arguments"); response, -1, "pull_graph_list request requires at least 5 arguments");
return 0; return 0;
} }
int type_id = *(int *)(request.params(0).c_str()); int type_id = std::stoi(request.params(0).c_str());
int idx = *(int *)(request.params(1).c_str()); int idx = std::stoi(request.params(1).c_str());
int start = *(int *)(request.params(2).c_str()); int start = std::stoi(request.params(2).c_str());
int size = *(int *)(request.params(3).c_str()); int size = std::stoi(request.params(3).c_str());
int step = *(int *)(request.params(4).c_str()); int step = std::stoi(request.params(4).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
((GraphTable *)table) (reinterpret_cast<GraphTable *>(table))
->pull_graph_list( ->pull_graph_list(
type_id, idx, start, size, buffer, actual_size, false, step); type_id, idx, start, size, buffer, actual_size, false, step);
cntl->response_attachment().append(buffer.get(), actual_size); cntl->response_attachment().append(buffer.get(), actual_size);
...@@ -401,14 +406,16 @@ int32_t GraphBrpcService::graph_random_sample_neighbors( ...@@ -401,14 +406,16 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments"); "graph_random_sample_neighbors request requires at least 3 arguments");
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t); size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); // NOLINT
int sample_size = *(int *)(request.params(2).c_str()); const int sample_size =
bool need_weight = *(bool *)(request.params(3).c_str()); *reinterpret_cast<const int *>(request.params(2).c_str());
const bool need_weight =
*reinterpret_cast<const bool *>(request.params(3).c_str());
std::vector<std::shared_ptr<char>> buffers(node_num); std::vector<std::shared_ptr<char>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0); std::vector<int> actual_sizes(node_num, 0);
((GraphTable *)table) (reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors( ->random_sample_neighbors(
idx_, node_data, sample_size, buffers, actual_sizes, need_weight); idx_, node_data, sample_size, buffers, actual_sizes, need_weight);
...@@ -425,18 +432,18 @@ int32_t GraphBrpcService::graph_random_sample_nodes( ...@@ -425,18 +432,18 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
int type_id = *(int *)(request.params(0).c_str()); int type_id = std::stoi(request.params(0).c_str());
int idx_ = *(int *)(request.params(1).c_str()); int idx_ = std::stoi(request.params(1).c_str());
size_t size = *(uint64_t *)(request.params(2).c_str()); size_t size = std::stoull(request.params(2).c_str());
// size_t size = *(int64_t *)(request.params(0).c_str()); // size_t size = *(int64_t *)(request.params(0).c_str());
std::unique_ptr<char[]> buffer; std::unique_ptr<char[]> buffer;
int actual_size; int actual_size;
if (((GraphTable *)table) if (reinterpret_cast<GraphTable *>(table)->random_sample_nodes(
->random_sample_nodes(type_id, idx_, size, buffer, actual_size) == type_id, idx_, size, buffer, actual_size) == 0) {
0) {
cntl->response_attachment().append(buffer.get(), actual_size); cntl->response_attachment().append(buffer.get(), actual_size);
} else } else {
cntl->response_attachment().append(NULL, 0); cntl->response_attachment().append(NULL, 0);
}
return 0; return 0;
} }
...@@ -453,9 +460,10 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -453,9 +460,10 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 3 arguments"); "graph_get_node_feat request requires at least 3 arguments");
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t); size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<uint64_t> node_ids(node_data, node_data + node_num);
std::vector<std::string> feature_names = std::vector<std::string> feature_names =
...@@ -464,7 +472,8 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table, ...@@ -464,7 +472,8 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
std::vector<std::vector<std::string>> feature( std::vector<std::vector<std::string>> feature(
feature_names.size(), std::vector<std::string>(node_num)); feature_names.size(), std::vector<std::string>(node_num));
((GraphTable *)table)->get_node_feat(idx_, node_ids, feature_names, feature); (reinterpret_cast<GraphTable *>(table))
->get_node_feat(idx_, node_ids, feature_names, feature);
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
...@@ -492,11 +501,12 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -492,11 +501,12 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t); size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); const uint64_t *node_data =
int sample_size = *(int *)(request.params(2).c_str()); reinterpret_cast<const uint64_t *>(request.params(1).c_str());
bool need_weight = *(bool *)(request.params(3).c_str()); int sample_size = std::stoi(request.params(2).c_str());
bool need_weight = std::stoi(request.params(3).c_str());
std::vector<int> request2server; std::vector<int> request2server;
std::vector<int> server2request(server_size, -1); std::vector<int> server2request(server_size, -1);
...@@ -504,8 +514,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -504,8 +514,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<int> local_query_idx; std::vector<int> local_query_idx;
size_t rank = GetRank(); size_t rank = GetRank();
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) { for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index = (reinterpret_cast<GraphTable *>(table))
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); ->get_server_index_by_id(node_data[query_idx]);
if (server2request[server_index] == -1) { if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size(); server2request[server_index] = request2server.size();
request2server.push_back(server_index); request2server.push_back(server_index);
...@@ -514,10 +524,10 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -514,10 +524,10 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if (server2request[rank] != -1) { if (server2request[rank] != -1) {
auto pos = server2request[rank]; auto pos = server2request[rank];
std::swap(request2server[pos], std::swap(request2server[pos],
request2server[(int)request2server.size() - 1]); request2server[static_cast<int>(request2server.size()) - 1]);
server2request[request2server[pos]] = pos; server2request[request2server[pos]] = pos;
server2request[request2server[(int)request2server.size() - 1]] = server2request[request2server[static_cast<int>(request2server.size()) -
request2server.size() - 1; 1]] = request2server.size() - 1;
} }
size_t request_call_num = request2server.size(); size_t request_call_num = request2server.size();
std::vector<std::shared_ptr<char>> local_buffers; std::vector<std::shared_ptr<char>> local_buffers;
...@@ -526,8 +536,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -526,8 +536,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num); std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num); std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (size_t query_idx = 0; query_idx < node_num; ++query_idx) { for (size_t query_idx = 0; query_idx < node_num; ++query_idx) {
int server_index = int server_index = (reinterpret_cast<GraphTable *>(table))
((GraphTable *)table)->get_server_index_by_id(node_data[query_idx]); ->get_server_index_by_id(node_data[query_idx]);
int request_idx = server2request[server_index]; int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_data[query_idx]); node_id_buckets[request_idx].push_back(node_data[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx); query_idx_buckets[request_idx].push_back(query_idx);
...@@ -550,7 +560,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -550,7 +560,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
request_call_num](void *done) { request_call_num](void *done) {
local_fut.get(); local_fut.get();
std::vector<int> actual_size; std::vector<int> actual_size;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res( std::vector<std::unique_ptr<butil::IOBufBytesIterator>> res(
remote_call_num); remote_call_num);
size_t fail_num = 0; size_t fail_num = 0;
...@@ -610,17 +620,19 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -610,17 +620,19 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure->request(request_idx)->set_client_id(rank); closure->request(request_idx)->set_client_id(rank);
size_t node_num = node_id_buckets[request_idx].size(); size_t node_num = node_id_buckets[request_idx].size();
closure->request(request_idx)->add_params((char *)&idx_, sizeof(int)); closure->request(request_idx)
->add_params(reinterpret_cast<char *>(&idx_), sizeof(int));
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(), ->add_params(
sizeof(uint64_t) * node_num); reinterpret_cast<char *>(node_id_buckets[request_idx].data()),
sizeof(uint64_t) * node_num);
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int)); ->add_params(reinterpret_cast<char *>(&sample_size), sizeof(int));
closure->request(request_idx) closure->request(request_idx)
->add_params((char *)&need_weight, sizeof(bool)); ->add_params(reinterpret_cast<char *>(&need_weight), sizeof(bool));
PsService_Stub rpc_stub( PsService_Stub rpc_stub((reinterpret_cast<GraphBrpcServer *>(GetServer())
((GraphBrpcServer *)GetServer())->GetCmdChannel(server_index)); ->GetCmdChannel(server_index)));
// GraphPsService_Stub rpc_stub = // GraphPsService_Stub rpc_stub =
// getServiceStub(GetCmdChannel(server_index)); // getServiceStub(GetCmdChannel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms()); closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
...@@ -630,7 +642,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( ...@@ -630,7 +642,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure); closure);
} }
if (server2request[rank] != -1) { if (server2request[rank] != -1) {
((GraphTable *)table) (reinterpret_cast<GraphTable *>(table))
->random_sample_neighbors(idx_, ->random_sample_neighbors(idx_,
node_id_buckets.back().data(), node_id_buckets.back().data(),
sample_size, sample_size,
...@@ -655,10 +667,11 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, ...@@ -655,10 +667,11 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments"); "graph_set_node_feat request requires at least 3 arguments");
return 0; return 0;
} }
int idx_ = *(int *)(request.params(0).c_str()); int idx_ = std::stoi(request.params(0).c_str());
size_t node_num = request.params(1).size() / sizeof(uint64_t); size_t node_num = request.params(1).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(1).c_str()); const uint64_t *node_data =
reinterpret_cast<const uint64_t *>(request.params(1).c_str());
std::vector<uint64_t> node_ids(node_data, node_data + node_num); std::vector<uint64_t> node_ids(node_data, node_data + node_num);
// std::vector<std::string> feature_names = // std::vector<std::string> feature_names =
...@@ -675,7 +688,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, ...@@ -675,7 +688,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) { for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) { for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
size_t feat_len = *(size_t *)(buffer); const size_t feat_len = *reinterpret_cast<const size_t *>(buffer);
buffer += sizeof(size_t); buffer += sizeof(size_t);
auto feat = std::string(buffer, feat_len); auto feat = std::string(buffer, feat_len);
features[feat_idx][node_idx] = feat; features[feat_idx][node_idx] = feat;
...@@ -683,7 +696,8 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table, ...@@ -683,7 +696,8 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
} }
} }
((GraphTable *)table)->set_node_feat(idx_, node_ids, feature_names, features); (reinterpret_cast<GraphTable *>(table))
->set_node_feat(idx_, node_ids, feature_names, features);
return 0; return 0;
} }
......
...@@ -174,7 +174,7 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) { ...@@ -174,7 +174,7 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0; value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0; value[CtrDoubleFeatureValue::DeltaScoreIndex()] = 0;
*reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0; *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()) = 0;
*(double*)(value + CtrDoubleFeatureValue::ClickIndex()) = 0; *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ClickIndex()) = 0;
value[CtrDoubleFeatureValue::SlotIndex()] = -1; value[CtrDoubleFeatureValue::SlotIndex()] = -1;
bool zero_init = _config.ctr_accessor_param().zero_init(); bool zero_init = _config.ctr_accessor_param().zero_init();
_embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(), _embed_sgd_rule->InitValue(value + CtrDoubleFeatureValue::EmbedWIndex(),
...@@ -188,8 +188,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) { ...@@ -188,8 +188,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
return 0; return 0;
} }
bool CtrDoubleAccessor::NeedExtendMF(float* value) { bool CtrDoubleAccessor::NeedExtendMF(float* value) {
auto show = ((double*)(value + CtrDoubleFeatureValue::ShowIndex()))[0]; auto show = (reinterpret_cast<double*>(
auto click = ((double*)(value + CtrDoubleFeatureValue::ClickIndex()))[0]; value + CtrDoubleFeatureValue::ShowIndex()))[0];
auto click = (reinterpret_cast<double*>(
value + CtrDoubleFeatureValue::ClickIndex()))[0];
// float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() // float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff()
auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() + auto score = (show - click) * _config.ctr_accessor_param().nonclk_coeff() +
click * _config.ctr_accessor_param().click_coeff(); click * _config.ctr_accessor_param().click_coeff();
...@@ -204,10 +206,11 @@ int32_t CtrDoubleAccessor::Select(float** select_values, ...@@ -204,10 +206,11 @@ int32_t CtrDoubleAccessor::Select(float** select_values,
for (size_t value_item = 0; value_item < num; ++value_item) { for (size_t value_item = 0; value_item < num; ++value_item) {
float* select_value = select_values[value_item]; float* select_value = select_values[value_item];
float* value = const_cast<float*>(values[value_item]); float* value = const_cast<float*>(values[value_item]);
select_value[CtrDoublePullValue::ShowIndex()] = select_value[CtrDoublePullValue::ShowIndex()] = static_cast<float>(
(float)*(double*)(value + CtrDoubleFeatureValue::ShowIndex()); *reinterpret_cast<double*>(value + CtrDoubleFeatureValue::ShowIndex()));
select_value[CtrDoublePullValue::ClickIndex()] = select_value[CtrDoublePullValue::ClickIndex()] =
(float)*(double*)(value + CtrDoubleFeatureValue::ClickIndex()); static_cast<float>(*reinterpret_cast<double*>(
value + CtrDoubleFeatureValue::ClickIndex()));
select_value[CtrDoublePullValue::EmbedWIndex()] = select_value[CtrDoublePullValue::EmbedWIndex()] =
value[CtrDoubleFeatureValue::EmbedWIndex()]; value[CtrDoubleFeatureValue::EmbedWIndex()];
memcpy(select_value + CtrDoublePullValue::EmbedxWIndex(), memcpy(select_value + CtrDoublePullValue::EmbedxWIndex(),
...@@ -254,15 +257,17 @@ int32_t CtrDoubleAccessor::Update(float** update_values, ...@@ -254,15 +257,17 @@ int32_t CtrDoubleAccessor::Update(float** update_values,
float push_show = push_value[CtrDoublePushValue::ShowIndex()]; float push_show = push_value[CtrDoublePushValue::ShowIndex()];
float push_click = push_value[CtrDoublePushValue::ClickIndex()]; float push_click = push_value[CtrDoublePushValue::ClickIndex()];
float slot = push_value[CtrDoublePushValue::SlotIndex()]; float slot = push_value[CtrDoublePushValue::SlotIndex()];
*(double*)(update_value + CtrDoubleFeatureValue::ShowIndex()) += *reinterpret_cast<double*>(update_value +
(double)push_show; CtrDoubleFeatureValue::ShowIndex()) +=
*(double*)(update_value + CtrDoubleFeatureValue::ClickIndex()) += static_cast<double>(push_show);
(double)push_click; *reinterpret_cast<double*>(update_value +
CtrDoubleFeatureValue::ClickIndex()) +=
static_cast<double>(push_click);
update_value[CtrDoubleFeatureValue::SlotIndex()] = slot; update_value[CtrDoubleFeatureValue::SlotIndex()] = slot;
update_value[CtrDoubleFeatureValue::DeltaScoreIndex()] += update_value[CtrDoubleFeatureValue::DeltaScoreIndex()] +=
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
push_click * _config.ctr_accessor_param().click_coeff(); push_click * _config.ctr_accessor_param().click_coeff();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() + // (push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff(); // push_click * _config.ctr_accessor_param().click_coeff();
update_value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0; update_value[CtrDoubleFeatureValue::UnseenDaysIndex()] = 0;
if (!_show_scale) { if (!_show_scale) {
...@@ -315,9 +320,11 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) { ...@@ -315,9 +320,11 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
thread_local std::ostringstream os; thread_local std::ostringstream os;
os.clear(); os.clear();
os.str(""); os.str("");
os << v[0] << " " << v[1] << " " << (float)((double*)(v + 2))[0] << " " os << v[0] << " " << v[1] << " "
<< (float)((double*)(v + 4))[0] << " " << v[6] << " " << v[7] << " " << static_cast<const float>((reinterpret_cast<const double*>(v + 2))[0])
<< v[8]; << " "
<< static_cast<const float>((reinterpret_cast<const double*>(v + 4))[0])
<< " " << v[6] << " " << v[7] << " " << v[8];
auto show = CtrDoubleFeatureValue::Show(const_cast<float*>(v)); auto show = CtrDoubleFeatureValue::Show(const_cast<float*>(v));
auto click = CtrDoubleFeatureValue::Click(const_cast<float*>(v)); auto click = CtrDoubleFeatureValue::Click(const_cast<float*>(v));
auto score = ShowClickScore(show, click); auto score = ShowClickScore(show, click);
...@@ -331,7 +338,7 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) { ...@@ -331,7 +338,7 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
} }
int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) { int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
int embedx_dim = _config.embedx_dim(); int embedx_dim = _config.embedx_dim();
float data_buff[_accessor_info.dim + 2]; float data_buff[_accessor_info.dim + 2]; // NOLINT
float* data_buff_ptr = data_buff; float* data_buff_ptr = data_buff;
_embedx_sgd_rule->InitValue( _embedx_sgd_rule->InitValue(
data_buff_ptr + CtrDoubleFeatureValue::EmbedxWIndex(), data_buff_ptr + CtrDoubleFeatureValue::EmbedxWIndex(),
...@@ -350,8 +357,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) { ...@@ -350,8 +357,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score // copy unseen_days..delta_score
memcpy(value, data_buff_ptr, show_index * sizeof(float)); memcpy(value, data_buff_ptr, show_index * sizeof(float));
// copy show & click // copy show & click
*(double*)(value + show_index) = (double)data_buff_ptr[2]; *reinterpret_cast<double*>(value + show_index) =
*(double*)(value + click_index) = (double)data_buff_ptr[3]; static_cast<double>(data_buff_ptr[2]);
*reinterpret_cast<double*>(value + click_index) =
static_cast<double>(data_buff_ptr[3]);
// copy others // copy others
value[CtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4]; value[CtrDoubleFeatureValue::EmbedWIndex()] = data_buff_ptr[4];
value[CtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5]; value[CtrDoubleFeatureValue::EmbedG2SumIndex()] = data_buff_ptr[5];
...@@ -362,8 +371,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) { ...@@ -362,8 +371,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score // copy unseen_days..delta_score
memcpy(value, data_buff_ptr, show_index * sizeof(float)); memcpy(value, data_buff_ptr, show_index * sizeof(float));
// copy show & click // copy show & click
*(double*)(value + show_index) = (double)data_buff_ptr[2]; *reinterpret_cast<double*>(value + show_index) =
*(double*)(value + click_index) = (double)data_buff_ptr[3]; static_cast<double>(data_buff_ptr[2]);
*reinterpret_cast<double*>(value + click_index) =
static_cast<double>(data_buff_ptr[3]);
// copy embed_w..embedx_w // copy embed_w..embedx_w
memcpy(value + embed_w_index, memcpy(value + embed_w_index,
data_buff_ptr + 4, data_buff_ptr + 4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册