Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d4b3bfab
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d4b3bfab
编写于
1月 09, 2023
作者:
W
wangzhen38
提交者:
GitHub
1月 09, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[code_style fix] graph_brpc_server cpplint (#49462)
上级
36c6c589
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
100 addition
and
75 deletion
+100
-75
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
+70
-56
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
+30
-19
未找到文件。
paddle/fluid/distributed/ps/service/graph_brpc_server.cc
浏览文件 @
d4b3bfab
...
...
@@ -14,6 +14,7 @@
#include "paddle/fluid/distributed/ps/service/graph_brpc_server.h"
#include <string>
#include <thread> // NOLINT
#include <utility>
...
...
@@ -125,9 +126,9 @@ int32_t GraphBrpcService::clear_nodes(Table *table,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int
type_id
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
*
(
int
*
)
(
request
.
params
(
1
).
c_str
());
(
(
GraphTable
*
)
table
)
->
clear_nodes
(
type_id
,
idx_
);
int
type_id
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
1
).
c_str
());
(
reinterpret_cast
<
GraphTable
*>
(
table
)
)
->
clear_nodes
(
type_id
,
idx_
);
return
0
;
}
...
...
@@ -142,14 +143,16 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
int64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
const
uint64_t
*
node_data
=
reinterpret_cast
<
const
uint64_t
*>
(
request
.
params
(
1
).
c_str
());
std
::
vector
<
uint64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
bool
>
is_weighted_list
;
if
(
request
.
params_size
()
==
3
)
{
size_t
weight_list_size
=
request
.
params
(
2
).
size
()
/
sizeof
(
bool
);
bool
*
is_weighted_buffer
=
(
bool
*
)(
request
.
params
(
2
).
c_str
());
const
bool
*
is_weighted_buffer
=
reinterpret_cast
<
const
bool
*>
(
request
.
params
(
2
).
c_str
());
is_weighted_list
=
std
::
vector
<
bool
>
(
is_weighted_buffer
,
is_weighted_buffer
+
weight_list_size
);
}
...
...
@@ -161,7 +164,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
// weight_list_size);
// }
((
GraphTable
*
)
table
)
->
add_graph_node
(
idx_
,
node_ids
,
is_weighted_list
);
(
reinterpret_cast
<
GraphTable
*>
(
table
))
->
add_graph_node
(
idx_
,
node_ids
,
is_weighted_list
);
return
0
;
}
int32_t
GraphBrpcService
::
remove_graph_node
(
Table
*
table
,
...
...
@@ -176,12 +180,13 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
"remove_graph_node request requires at least 2 arguments"
);
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
const
uint64_t
*
node_data
=
reinterpret_cast
<
const
uint64_t
*>
(
request
.
params
(
1
).
c_str
());
std
::
vector
<
uint64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
(
(
GraphTable
*
)
table
)
->
remove_graph_node
(
idx_
,
node_ids
);
(
reinterpret_cast
<
GraphTable
*>
(
table
)
)
->
remove_graph_node
(
idx_
,
node_ids
);
return
0
;
}
int32_t
GraphBrpcServer
::
Port
()
{
return
_server
.
listen_address
().
port
;
}
...
...
@@ -338,7 +343,7 @@ int32_t GraphBrpcService::StopServer(Table *table,
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
GraphBrpcServer
*
p_server
=
(
GraphBrpcServer
*
)
_server
;
GraphBrpcServer
*
p_server
=
reinterpret_cast
<
GraphBrpcServer
*>
(
_server
)
;
std
::
thread
t_stop
([
p_server
]()
{
p_server
->
Stop
();
LOG
(
INFO
)
<<
"Server Stoped"
;
...
...
@@ -375,14 +380,14 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
response
,
-
1
,
"pull_graph_list request requires at least 5 arguments"
);
return
0
;
}
int
type_id
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx
=
*
(
int
*
)
(
request
.
params
(
1
).
c_str
());
int
start
=
*
(
int
*
)
(
request
.
params
(
2
).
c_str
());
int
size
=
*
(
int
*
)
(
request
.
params
(
3
).
c_str
());
int
step
=
*
(
int
*
)
(
request
.
params
(
4
).
c_str
());
int
type_id
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
int
idx
=
std
::
stoi
(
request
.
params
(
1
).
c_str
());
int
start
=
std
::
stoi
(
request
.
params
(
2
).
c_str
());
int
size
=
std
::
stoi
(
request
.
params
(
3
).
c_str
());
int
step
=
std
::
stoi
(
request
.
params
(
4
).
c_str
());
std
::
unique_ptr
<
char
[]
>
buffer
;
int
actual_size
;
(
(
GraphTable
*
)
table
)
(
reinterpret_cast
<
GraphTable
*>
(
table
)
)
->
pull_graph_list
(
type_id
,
idx
,
start
,
size
,
buffer
,
actual_size
,
false
,
step
);
cntl
->
response_attachment
().
append
(
buffer
.
get
(),
actual_size
);
...
...
@@ -401,14 +406,16 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
"graph_random_sample_neighbors request requires at least 3 arguments"
);
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
int
sample_size
=
*
(
int
*
)(
request
.
params
(
2
).
c_str
());
bool
need_weight
=
*
(
bool
*
)(
request
.
params
(
3
).
c_str
());
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
// NOLINT
const
int
sample_size
=
*
reinterpret_cast
<
const
int
*>
(
request
.
params
(
2
).
c_str
());
const
bool
need_weight
=
*
reinterpret_cast
<
const
bool
*>
(
request
.
params
(
3
).
c_str
());
std
::
vector
<
std
::
shared_ptr
<
char
>>
buffers
(
node_num
);
std
::
vector
<
int
>
actual_sizes
(
node_num
,
0
);
(
(
GraphTable
*
)
table
)
(
reinterpret_cast
<
GraphTable
*>
(
table
)
)
->
random_sample_neighbors
(
idx_
,
node_data
,
sample_size
,
buffers
,
actual_sizes
,
need_weight
);
...
...
@@ -425,18 +432,18 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
const
PsRequestMessage
&
request
,
PsResponseMessage
&
response
,
brpc
::
Controller
*
cntl
)
{
int
type_id
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
*
(
int
*
)
(
request
.
params
(
1
).
c_str
());
size_t
size
=
*
(
uint64_t
*
)
(
request
.
params
(
2
).
c_str
());
int
type_id
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
1
).
c_str
());
size_t
size
=
std
::
stoull
(
request
.
params
(
2
).
c_str
());
// size_t size = *(int64_t *)(request.params(0).c_str());
std
::
unique_ptr
<
char
[]
>
buffer
;
int
actual_size
;
if
(((
GraphTable
*
)
table
)
->
random_sample_nodes
(
type_id
,
idx_
,
size
,
buffer
,
actual_size
)
==
0
)
{
if
(
reinterpret_cast
<
GraphTable
*>
(
table
)
->
random_sample_nodes
(
type_id
,
idx_
,
size
,
buffer
,
actual_size
)
==
0
)
{
cntl
->
response_attachment
().
append
(
buffer
.
get
(),
actual_size
);
}
else
}
else
{
cntl
->
response_attachment
().
append
(
NULL
,
0
);
}
return
0
;
}
...
...
@@ -453,9 +460,10 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
"graph_get_node_feat request requires at least 3 arguments"
);
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
const
uint64_t
*
node_data
=
reinterpret_cast
<
const
uint64_t
*>
(
request
.
params
(
1
).
c_str
());
std
::
vector
<
uint64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
std
::
vector
<
std
::
string
>
feature_names
=
...
...
@@ -464,7 +472,8 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
std
::
vector
<
std
::
vector
<
std
::
string
>>
feature
(
feature_names
.
size
(),
std
::
vector
<
std
::
string
>
(
node_num
));
((
GraphTable
*
)
table
)
->
get_node_feat
(
idx_
,
node_ids
,
feature_names
,
feature
);
(
reinterpret_cast
<
GraphTable
*>
(
table
))
->
get_node_feat
(
idx_
,
node_ids
,
feature_names
,
feature
);
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
...
...
@@ -492,11 +501,12 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
int
sample_size
=
*
(
int
*
)(
request
.
params
(
2
).
c_str
());
bool
need_weight
=
*
(
bool
*
)(
request
.
params
(
3
).
c_str
());
const
uint64_t
*
node_data
=
reinterpret_cast
<
const
uint64_t
*>
(
request
.
params
(
1
).
c_str
());
int
sample_size
=
std
::
stoi
(
request
.
params
(
2
).
c_str
());
bool
need_weight
=
std
::
stoi
(
request
.
params
(
3
).
c_str
());
std
::
vector
<
int
>
request2server
;
std
::
vector
<
int
>
server2request
(
server_size
,
-
1
);
...
...
@@ -504,8 +514,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std
::
vector
<
int
>
local_query_idx
;
size_t
rank
=
GetRank
();
for
(
size_t
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
int
server_index
=
(
reinterpret_cast
<
GraphTable
*>
(
table
))
->
get_server_index_by_id
(
node_data
[
query_idx
]);
if
(
server2request
[
server_index
]
==
-
1
)
{
server2request
[
server_index
]
=
request2server
.
size
();
request2server
.
push_back
(
server_index
);
...
...
@@ -514,10 +524,10 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
if
(
server2request
[
rank
]
!=
-
1
)
{
auto
pos
=
server2request
[
rank
];
std
::
swap
(
request2server
[
pos
],
request2server
[
(
int
)
request2server
.
size
(
)
-
1
]);
request2server
[
static_cast
<
int
>
(
request2server
.
size
()
)
-
1
]);
server2request
[
request2server
[
pos
]]
=
pos
;
server2request
[
request2server
[
(
int
)
request2server
.
size
()
-
1
]]
=
request2server
.
size
()
-
1
;
server2request
[
request2server
[
static_cast
<
int
>
(
request2server
.
size
())
-
1
]]
=
request2server
.
size
()
-
1
;
}
size_t
request_call_num
=
request2server
.
size
();
std
::
vector
<
std
::
shared_ptr
<
char
>>
local_buffers
;
...
...
@@ -526,8 +536,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
std
::
vector
<
std
::
vector
<
uint64_t
>>
node_id_buckets
(
request_call_num
);
std
::
vector
<
std
::
vector
<
int
>>
query_idx_buckets
(
request_call_num
);
for
(
size_t
query_idx
=
0
;
query_idx
<
node_num
;
++
query_idx
)
{
int
server_index
=
((
GraphTable
*
)
table
)
->
get_server_index_by_id
(
node_data
[
query_idx
]);
int
server_index
=
(
reinterpret_cast
<
GraphTable
*>
(
table
))
->
get_server_index_by_id
(
node_data
[
query_idx
]);
int
request_idx
=
server2request
[
server_index
];
node_id_buckets
[
request_idx
].
push_back
(
node_data
[
query_idx
]);
query_idx_buckets
[
request_idx
].
push_back
(
query_idx
);
...
...
@@ -550,7 +560,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
request_call_num
](
void
*
done
)
{
local_fut
.
get
();
std
::
vector
<
int
>
actual_size
;
auto
*
closure
=
(
DownpourBrpcClosure
*
)
done
;
auto
*
closure
=
reinterpret_cast
<
DownpourBrpcClosure
*>
(
done
)
;
std
::
vector
<
std
::
unique_ptr
<
butil
::
IOBufBytesIterator
>>
res
(
remote_call_num
);
size_t
fail_num
=
0
;
...
...
@@ -610,17 +620,19 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure
->
request
(
request_idx
)
->
set_client_id
(
rank
);
size_t
node_num
=
node_id_buckets
[
request_idx
].
size
();
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
&
idx_
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
->
add_params
(
reinterpret_cast
<
char
*>
(
&
idx_
),
sizeof
(
int
));
closure
->
request
(
request_idx
)
->
add_params
((
char
*
)
node_id_buckets
[
request_idx
].
data
(),
sizeof
(
uint64_t
)
*
node_num
);
->
add_params
(
reinterpret_cast
<
char
*>
(
node_id_buckets
[
request_idx
].
data
()),
sizeof
(
uint64_t
)
*
node_num
);
closure
->
request
(
request_idx
)
->
add_params
(
(
char
*
)
&
sample_size
,
sizeof
(
int
));
->
add_params
(
reinterpret_cast
<
char
*>
(
&
sample_size
)
,
sizeof
(
int
));
closure
->
request
(
request_idx
)
->
add_params
(
(
char
*
)
&
need_weight
,
sizeof
(
bool
));
PsService_Stub
rpc_stub
(
((
GraphBrpcServer
*
)
GetServer
())
->
GetCmdChannel
(
server_index
));
->
add_params
(
reinterpret_cast
<
char
*>
(
&
need_weight
)
,
sizeof
(
bool
));
PsService_Stub
rpc_stub
(
(
reinterpret_cast
<
GraphBrpcServer
*>
(
GetServer
())
->
GetCmdChannel
(
server_index
)
));
// GraphPsService_Stub rpc_stub =
// getServiceStub(GetCmdChannel(server_index));
closure
->
cntl
(
request_idx
)
->
set_log_id
(
butil
::
gettimeofday_ms
());
...
...
@@ -630,7 +642,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
closure
);
}
if
(
server2request
[
rank
]
!=
-
1
)
{
(
(
GraphTable
*
)
table
)
(
reinterpret_cast
<
GraphTable
*>
(
table
)
)
->
random_sample_neighbors
(
idx_
,
node_id_buckets
.
back
().
data
(),
sample_size
,
...
...
@@ -655,10 +667,11 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
"graph_set_node_feat request requires at least 3 arguments"
);
return
0
;
}
int
idx_
=
*
(
int
*
)
(
request
.
params
(
0
).
c_str
());
int
idx_
=
std
::
stoi
(
request
.
params
(
0
).
c_str
());
size_t
node_num
=
request
.
params
(
1
).
size
()
/
sizeof
(
uint64_t
);
uint64_t
*
node_data
=
(
uint64_t
*
)(
request
.
params
(
1
).
c_str
());
const
uint64_t
*
node_data
=
reinterpret_cast
<
const
uint64_t
*>
(
request
.
params
(
1
).
c_str
());
std
::
vector
<
uint64_t
>
node_ids
(
node_data
,
node_data
+
node_num
);
// std::vector<std::string> feature_names =
...
...
@@ -675,7 +688,7 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
for
(
size_t
node_idx
=
0
;
node_idx
<
node_num
;
++
node_idx
)
{
size_t
feat_len
=
*
(
size_t
*
)
(
buffer
);
const
size_t
feat_len
=
*
reinterpret_cast
<
const
size_t
*>
(
buffer
);
buffer
+=
sizeof
(
size_t
);
auto
feat
=
std
::
string
(
buffer
,
feat_len
);
features
[
feat_idx
][
node_idx
]
=
feat
;
...
...
@@ -683,7 +696,8 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
}
}
((
GraphTable
*
)
table
)
->
set_node_feat
(
idx_
,
node_ids
,
feature_names
,
features
);
(
reinterpret_cast
<
GraphTable
*>
(
table
))
->
set_node_feat
(
idx_
,
node_ids
,
feature_names
,
features
);
return
0
;
}
...
...
paddle/fluid/distributed/ps/table/ctr_double_accessor.cc
浏览文件 @
d4b3bfab
...
...
@@ -174,7 +174,7 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
value
[
CtrDoubleFeatureValue
::
UnseenDaysIndex
()]
=
0
;
value
[
CtrDoubleFeatureValue
::
DeltaScoreIndex
()]
=
0
;
*
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ShowIndex
())
=
0
;
*
(
double
*
)
(
value
+
CtrDoubleFeatureValue
::
ClickIndex
())
=
0
;
*
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ClickIndex
())
=
0
;
value
[
CtrDoubleFeatureValue
::
SlotIndex
()]
=
-
1
;
bool
zero_init
=
_config
.
ctr_accessor_param
().
zero_init
();
_embed_sgd_rule
->
InitValue
(
value
+
CtrDoubleFeatureValue
::
EmbedWIndex
(),
...
...
@@ -188,8 +188,10 @@ int32_t CtrDoubleAccessor::Create(float** values, size_t num) {
return
0
;
}
bool
CtrDoubleAccessor
::
NeedExtendMF
(
float
*
value
)
{
auto
show
=
((
double
*
)(
value
+
CtrDoubleFeatureValue
::
ShowIndex
()))[
0
];
auto
click
=
((
double
*
)(
value
+
CtrDoubleFeatureValue
::
ClickIndex
()))[
0
];
auto
show
=
(
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ShowIndex
()))[
0
];
auto
click
=
(
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ClickIndex
()))[
0
];
// float score = (show - click) * _config.ctr_accessor_param().nonclk_coeff()
auto
score
=
(
show
-
click
)
*
_config
.
ctr_accessor_param
().
nonclk_coeff
()
+
click
*
_config
.
ctr_accessor_param
().
click_coeff
();
...
...
@@ -204,10 +206,11 @@ int32_t CtrDoubleAccessor::Select(float** select_values,
for
(
size_t
value_item
=
0
;
value_item
<
num
;
++
value_item
)
{
float
*
select_value
=
select_values
[
value_item
];
float
*
value
=
const_cast
<
float
*>
(
values
[
value_item
]);
select_value
[
CtrDoublePullValue
::
ShowIndex
()]
=
(
float
)
*
(
double
*
)(
value
+
CtrDoubleFeatureValue
::
ShowIndex
(
));
select_value
[
CtrDoublePullValue
::
ShowIndex
()]
=
static_cast
<
float
>
(
*
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ShowIndex
()
));
select_value
[
CtrDoublePullValue
::
ClickIndex
()]
=
(
float
)
*
(
double
*
)(
value
+
CtrDoubleFeatureValue
::
ClickIndex
());
static_cast
<
float
>
(
*
reinterpret_cast
<
double
*>
(
value
+
CtrDoubleFeatureValue
::
ClickIndex
()));
select_value
[
CtrDoublePullValue
::
EmbedWIndex
()]
=
value
[
CtrDoubleFeatureValue
::
EmbedWIndex
()];
memcpy
(
select_value
+
CtrDoublePullValue
::
EmbedxWIndex
(),
...
...
@@ -254,15 +257,17 @@ int32_t CtrDoubleAccessor::Update(float** update_values,
float
push_show
=
push_value
[
CtrDoublePushValue
::
ShowIndex
()];
float
push_click
=
push_value
[
CtrDoublePushValue
::
ClickIndex
()];
float
slot
=
push_value
[
CtrDoublePushValue
::
SlotIndex
()];
*
(
double
*
)(
update_value
+
CtrDoubleFeatureValue
::
ShowIndex
())
+=
(
double
)
push_show
;
*
(
double
*
)(
update_value
+
CtrDoubleFeatureValue
::
ClickIndex
())
+=
(
double
)
push_click
;
*
reinterpret_cast
<
double
*>
(
update_value
+
CtrDoubleFeatureValue
::
ShowIndex
())
+=
static_cast
<
double
>
(
push_show
);
*
reinterpret_cast
<
double
*>
(
update_value
+
CtrDoubleFeatureValue
::
ClickIndex
())
+=
static_cast
<
double
>
(
push_click
);
update_value
[
CtrDoubleFeatureValue
::
SlotIndex
()]
=
slot
;
update_value
[
CtrDoubleFeatureValue
::
DeltaScoreIndex
()]
+=
(
push_show
-
push_click
)
*
_config
.
ctr_accessor_param
().
nonclk_coeff
()
+
push_click
*
_config
.
ctr_accessor_param
().
click_coeff
();
//(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
//
(push_show - push_click) * _config.ctr_accessor_param().nonclk_coeff() +
// push_click * _config.ctr_accessor_param().click_coeff();
update_value
[
CtrDoubleFeatureValue
::
UnseenDaysIndex
()]
=
0
;
if
(
!
_show_scale
)
{
...
...
@@ -315,9 +320,11 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
thread_local
std
::
ostringstream
os
;
os
.
clear
();
os
.
str
(
""
);
os
<<
v
[
0
]
<<
" "
<<
v
[
1
]
<<
" "
<<
(
float
)((
double
*
)(
v
+
2
))[
0
]
<<
" "
<<
(
float
)((
double
*
)(
v
+
4
))[
0
]
<<
" "
<<
v
[
6
]
<<
" "
<<
v
[
7
]
<<
" "
<<
v
[
8
];
os
<<
v
[
0
]
<<
" "
<<
v
[
1
]
<<
" "
<<
static_cast
<
const
float
>
((
reinterpret_cast
<
const
double
*>
(
v
+
2
))[
0
])
<<
" "
<<
static_cast
<
const
float
>
((
reinterpret_cast
<
const
double
*>
(
v
+
4
))[
0
])
<<
" "
<<
v
[
6
]
<<
" "
<<
v
[
7
]
<<
" "
<<
v
[
8
];
auto
show
=
CtrDoubleFeatureValue
::
Show
(
const_cast
<
float
*>
(
v
));
auto
click
=
CtrDoubleFeatureValue
::
Click
(
const_cast
<
float
*>
(
v
));
auto
score
=
ShowClickScore
(
show
,
click
);
...
...
@@ -331,7 +338,7 @@ std::string CtrDoubleAccessor::ParseToString(const float* v, int param_size) {
}
int
CtrDoubleAccessor
::
ParseFromString
(
const
std
::
string
&
str
,
float
*
value
)
{
int
embedx_dim
=
_config
.
embedx_dim
();
float
data_buff
[
_accessor_info
.
dim
+
2
];
float
data_buff
[
_accessor_info
.
dim
+
2
];
// NOLINT
float
*
data_buff_ptr
=
data_buff
;
_embedx_sgd_rule
->
InitValue
(
data_buff_ptr
+
CtrDoubleFeatureValue
::
EmbedxWIndex
(),
...
...
@@ -350,8 +357,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score
memcpy
(
value
,
data_buff_ptr
,
show_index
*
sizeof
(
float
));
// copy show & click
*
(
double
*
)(
value
+
show_index
)
=
(
double
)
data_buff_ptr
[
2
];
*
(
double
*
)(
value
+
click_index
)
=
(
double
)
data_buff_ptr
[
3
];
*
reinterpret_cast
<
double
*>
(
value
+
show_index
)
=
static_cast
<
double
>
(
data_buff_ptr
[
2
]);
*
reinterpret_cast
<
double
*>
(
value
+
click_index
)
=
static_cast
<
double
>
(
data_buff_ptr
[
3
]);
// copy others
value
[
CtrDoubleFeatureValue
::
EmbedWIndex
()]
=
data_buff_ptr
[
4
];
value
[
CtrDoubleFeatureValue
::
EmbedG2SumIndex
()]
=
data_buff_ptr
[
5
];
...
...
@@ -362,8 +371,10 @@ int CtrDoubleAccessor::ParseFromString(const std::string& str, float* value) {
// copy unseen_days..delta_score
memcpy
(
value
,
data_buff_ptr
,
show_index
*
sizeof
(
float
));
// copy show & click
*
(
double
*
)(
value
+
show_index
)
=
(
double
)
data_buff_ptr
[
2
];
*
(
double
*
)(
value
+
click_index
)
=
(
double
)
data_buff_ptr
[
3
];
*
reinterpret_cast
<
double
*>
(
value
+
show_index
)
=
static_cast
<
double
>
(
data_buff_ptr
[
2
]);
*
reinterpret_cast
<
double
*>
(
value
+
click_index
)
=
static_cast
<
double
>
(
data_buff_ptr
[
3
]);
// copy embed_w..embedx_w
memcpy
(
value
+
embed_w_index
,
data_buff_ptr
+
4
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录