Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
be273ea9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
be273ea9
编写于
10月 19, 2022
作者:
L
Li-fAngyU
提交者:
GitHub
10月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix build warning: [Wsign-compare] on linux (#46644)
上级
ddf317ed
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
244 addition
and
232 deletion
+244
-232
paddle/fluid/distributed/ps/table/common_graph_table.cc
paddle/fluid/distributed/ps/table/common_graph_table.cc
+85
-73
paddle/fluid/distributed/ps/table/memory_dense_table.cc
paddle/fluid/distributed/ps/table/memory_dense_table.cc
+22
-22
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
+92
-92
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
+45
-45
未找到文件。
paddle/fluid/distributed/ps/table/common_graph_table.cc
浏览文件 @
be273ea9
...
@@ -78,7 +78,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
...
@@ -78,7 +78,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
paddle
::
framework
::
GpuPsFeaInfo
x
;
paddle
::
framework
::
GpuPsFeaInfo
x
;
std
::
vector
<
uint64_t
>
feature_ids
;
std
::
vector
<
uint64_t
>
feature_ids
;
for
(
size_t
j
=
0
;
j
<
bags
[
i
].
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
bags
[
i
].
size
();
j
++
)
{
// TODO use FEATURE_TABLE instead
// TODO
(danleifeng):
use FEATURE_TABLE instead
Node
*
v
=
find_node
(
1
,
bags
[
i
][
j
]);
Node
*
v
=
find_node
(
1
,
bags
[
i
][
j
]);
node_id
=
bags
[
i
][
j
];
node_id
=
bags
[
i
][
j
];
if
(
v
==
NULL
)
{
if
(
v
==
NULL
)
{
...
@@ -109,7 +109,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
...
@@ -109,7 +109,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
}));
}));
}
}
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
paddle
::
framework
::
GpuPsCommGraphFea
res
;
paddle
::
framework
::
GpuPsCommGraphFea
res
;
uint64_t
tot_len
=
0
;
uint64_t
tot_len
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
...
@@ -120,7 +120,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
...
@@ -120,7 +120,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
res
.
init_on_cpu
(
tot_len
,
(
unsigned
int
)
node_ids
.
size
(),
slot_num
);
res
.
init_on_cpu
(
tot_len
,
(
unsigned
int
)
node_ids
.
size
(),
slot_num
);
unsigned
int
offset
=
0
,
ind
=
0
;
unsigned
int
offset
=
0
,
ind
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
node_id_array
[
i
].
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
node_id_array
[
i
].
size
();
j
++
)
{
res
.
node_list
[
ind
]
=
node_id_array
[
i
][
j
];
res
.
node_list
[
ind
]
=
node_id_array
[
i
][
j
];
res
.
fea_info_list
[
ind
]
=
node_fea_info_array
[
i
][
j
];
res
.
fea_info_list
[
ind
]
=
node_fea_info_array
[
i
][
j
];
res
.
fea_info_list
[
ind
++
].
feature_offset
+=
offset
;
res
.
fea_info_list
[
ind
++
].
feature_offset
+=
offset
;
...
@@ -177,7 +177,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
...
@@ -177,7 +177,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
}));
}));
}
}
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
int64_t
tot_len
=
0
;
int64_t
tot_len
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
...
@@ -188,7 +188,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
...
@@ -188,7 +188,7 @@ paddle::framework::GpuPsCommGraph GraphTable::make_gpu_ps_graph(
res
.
init_on_cpu
(
tot_len
,
ids
.
size
());
res
.
init_on_cpu
(
tot_len
,
ids
.
size
());
int64_t
offset
=
0
,
ind
=
0
;
int64_t
offset
=
0
,
ind
=
0
;
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
task_pool_size_
;
i
++
)
{
for
(
int
j
=
0
;
j
<
(
int
)
node_array
[
i
].
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
node_array
[
i
].
size
();
j
++
)
{
res
.
node_list
[
ind
]
=
node_array
[
i
][
j
];
res
.
node_list
[
ind
]
=
node_array
[
i
][
j
];
res
.
node_info_list
[
ind
]
=
info_array
[
i
][
j
];
res
.
node_info_list
[
ind
]
=
info_array
[
i
][
j
];
res
.
node_info_list
[
ind
++
].
neighbor_offset
+=
offset
;
res
.
node_info_list
[
ind
++
].
neighbor_offset
+=
offset
;
...
@@ -213,7 +213,7 @@ int32_t GraphTable::add_node_to_ssd(
...
@@ -213,7 +213,7 @@ int32_t GraphTable::add_node_to_ssd(
ch
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
str
)
==
0
)
{
uint64_t
*
stored_data
=
((
uint64_t
*
)
str
.
c_str
());
uint64_t
*
stored_data
=
((
uint64_t
*
)
str
.
c_str
());
// NOLINT
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
char
*
new_data
=
new
char
[
n
*
sizeof
(
uint64_t
)
+
len
];
char
*
new_data
=
new
char
[
n
*
sizeof
(
uint64_t
)
+
len
];
memcpy
(
new_data
,
stored_data
,
n
*
sizeof
(
uint64_t
));
memcpy
(
new_data
,
stored_data
,
n
*
sizeof
(
uint64_t
));
...
@@ -221,14 +221,14 @@ int32_t GraphTable::add_node_to_ssd(
...
@@ -221,14 +221,14 @@ int32_t GraphTable::add_node_to_ssd(
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
ch
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
(
char
*
)
new_data
,
(
char
*
)
new_data
,
// NOLINT
n
*
sizeof
(
uint64_t
)
+
len
);
n
*
sizeof
(
uint64_t
)
+
len
);
delete
[]
new_data
;
delete
[]
new_data
;
}
else
{
}
else
{
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
_db
->
put
(
src_id
%
shard_num
%
task_pool_size_
,
ch
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
(
char
*
)
data
,
(
char
*
)
data
,
// NOLINT
len
);
len
);
}
}
}
}
...
@@ -254,7 +254,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
...
@@ -254,7 +254,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
ch
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
str
)
==
0
)
{
uint64_t
*
data
=
((
uint64_t
*
)
str
.
c_str
());
uint64_t
*
data
=
((
uint64_t
*
)
str
.
c_str
());
// NOLINT
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
int
n
=
str
.
size
()
/
sizeof
(
uint64_t
);
std
::
unordered_map
<
int
,
int
>
m
;
std
::
unordered_map
<
int
,
int
>
m
;
// std::vector<uint64_t> res;
// std::vector<uint64_t> res;
...
@@ -281,7 +281,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
...
@@ -281,7 +281,7 @@ char *GraphTable::random_sample_neighbor_from_ssd(
// res.push_back(data[pos]);
// res.push_back(data[pos]);
}
}
for
(
int
i
=
0
;
i
<
actual_size
;
i
+=
8
)
{
for
(
int
i
=
0
;
i
<
actual_size
;
i
+=
8
)
{
VLOG
(
2
)
<<
"sampled an neighbor "
<<
*
(
uint64_t
*
)
&
buff
[
i
];
VLOG
(
2
)
<<
"sampled an neighbor "
<<
*
(
uint64_t
*
)
&
buff
[
i
];
// NOLINT
}
}
return
buff
;
return
buff
;
}
}
...
@@ -310,8 +310,8 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
...
@@ -310,8 +310,8 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
std
::
string
str
;
std
::
string
str
;
if
(
_db
->
get
(
i
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
if
(
_db
->
get
(
i
,
ch
,
sizeof
(
int
)
*
2
+
sizeof
(
uint64_t
),
str
)
==
0
)
{
count
[
i
]
+=
(
int64_t
)
str
.
size
();
count
[
i
]
+=
(
int64_t
)
str
.
size
();
for
(
size_t
j
=
0
;
j
<
(
int
)
str
.
size
();
j
+=
sizeof
(
uint64_t
))
{
for
(
size_t
j
=
0
;
j
<
str
.
size
();
j
+=
sizeof
(
uint64_t
))
{
uint64_t
id
=
*
(
uint64_t
*
)(
str
.
c_str
()
+
j
);
uint64_t
id
=
*
(
uint64_t
*
)(
str
.
c_str
()
+
j
);
// NOLINT
add_comm_edge
(
idx
,
v
,
id
);
add_comm_edge
(
idx
,
v
,
id
);
}
}
}
}
...
@@ -321,7 +321,7 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
...
@@ -321,7 +321,7 @@ int64_t GraphTable::load_graph_to_memory_from_ssd(int idx,
}
}
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
int64_t
tot
=
0
;
int64_t
tot
=
0
;
for
(
auto
x
:
count
)
tot
+=
x
;
for
(
auto
x
:
count
)
tot
+=
x
;
return
tot
;
return
tot
;
...
@@ -354,9 +354,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
...
@@ -354,9 +354,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
iters
.
push_back
(
_db
->
get_iterator
(
i
));
iters
.
push_back
(
_db
->
get_iterator
(
i
));
iters
[
i
]
->
SeekToFirst
();
iters
[
i
]
->
SeekToFirst
();
}
}
in
t
next
=
0
;
size_
t
next
=
0
;
while
(
iters
.
size
())
{
while
(
iters
.
size
())
{
if
(
next
>=
(
int
)
iters
.
size
())
{
if
(
next
>=
iters
.
size
())
{
next
=
0
;
next
=
0
;
}
}
if
(
!
iters
[
next
]
->
Valid
())
{
if
(
!
iters
[
next
]
->
Valid
())
{
...
@@ -364,15 +364,16 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
...
@@ -364,15 +364,16 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
continue
;
continue
;
}
}
std
::
string
key
=
iters
[
next
]
->
key
().
ToString
();
std
::
string
key
=
iters
[
next
]
->
key
().
ToString
();
int
type_idx
=
*
(
int
*
)
key
.
c_str
();
int
type_idx
=
*
(
int
*
)
key
.
c_str
();
// NOLINT
int
temp_idx
=
*
(
int
*
)(
key
.
c_str
()
+
sizeof
(
int
));
int
temp_idx
=
*
(
int
*
)(
key
.
c_str
()
+
sizeof
(
int
));
// NOLINT
if
(
type_idx
!=
0
||
temp_idx
!=
idx
)
{
if
(
type_idx
!=
0
||
temp_idx
!=
idx
)
{
iters
[
next
]
->
Next
();
iters
[
next
]
->
Next
();
next
++
;
next
++
;
continue
;
continue
;
}
}
std
::
string
value
=
iters
[
next
]
->
value
().
ToString
();
std
::
string
value
=
iters
[
next
]
->
value
().
ToString
();
std
::
uint64_t
i_key
=
*
(
uint64_t
*
)(
key
.
c_str
()
+
sizeof
(
int
)
*
2
);
std
::
uint64_t
i_key
=
*
(
uint64_t
*
)(
key
.
c_str
()
+
sizeof
(
int
)
*
2
);
// NOLINT
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
if
(
memory_remaining
[
i
]
<
(
int64_t
)
value
.
size
())
{
if
(
memory_remaining
[
i
]
<
(
int64_t
)
value
.
size
())
{
score
[
i
]
=
-
100000.0
;
score
[
i
]
=
-
100000.0
;
...
@@ -380,8 +381,8 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
...
@@ -380,8 +381,8 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
score
[
i
]
=
0
;
score
[
i
]
=
0
;
}
}
}
}
for
(
size_t
j
=
0
;
j
<
(
int
)
value
.
size
();
j
+=
sizeof
(
uint64_t
))
{
for
(
size_t
j
=
0
;
j
<
value
.
size
();
j
+=
sizeof
(
uint64_t
))
{
uint64_t
v
=
*
((
uint64_t
*
)(
value
.
c_str
()
+
j
));
uint64_t
v
=
*
((
uint64_t
*
)(
value
.
c_str
()
+
j
));
// NOLINT
int
index
=
-
1
;
int
index
=
-
1
;
if
(
id_map
.
find
(
v
)
!=
id_map
.
end
())
{
if
(
id_map
.
find
(
v
)
!=
id_map
.
end
())
{
index
=
id_map
[
v
];
index
=
id_map
[
v
];
...
@@ -398,9 +399,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
...
@@ -398,9 +399,9 @@ void GraphTable::make_partitions(int idx, int64_t byte_size, int device_len) {
int
index
=
0
;
int
index
=
0
;
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
for
(
int
i
=
0
;
i
<
part_len
;
i
++
)
{
base
=
gb_size_by_discount
-
memory_remaining
[
i
]
+
value
.
size
();
base
=
gb_size_by_discount
-
memory_remaining
[
i
]
+
value
.
size
();
if
(
has_weight
)
if
(
has_weight
)
{
weight_base
=
weight_cost
[
i
]
+
w
*
weight_param
;
weight_base
=
weight_cost
[
i
]
+
w
*
weight_param
;
else
{
}
else
{
weight_base
=
0
;
weight_base
=
0
;
}
}
score
[
i
]
-=
a
*
y
*
std
::
pow
(
1.0
*
base
,
y
-
1
)
+
weight_base
;
score
[
i
]
-=
a
*
y
*
std
::
pow
(
1.0
*
base
,
y
-
1
)
+
weight_base
;
...
@@ -434,7 +435,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
...
@@ -434,7 +435,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
int
part_len
=
partitions
[
idx
].
size
();
int
part_len
=
partitions
[
idx
].
size
();
if
(
part_len
==
0
)
return
;
if
(
part_len
==
0
)
return
;
if
(
file_path
==
""
)
file_path
=
"."
;
if
(
file_path
==
""
)
file_path
=
"."
;
if
(
file_path
[
(
int
)
file_path
.
size
()
-
1
]
!=
'/'
)
{
if
(
file_path
[
file_path
.
size
()
-
1
]
!=
'/'
)
{
file_path
+=
"/"
;
file_path
+=
"/"
;
}
}
std
::
vector
<
std
::
future
<
int
>>
tasks
;
std
::
vector
<
std
::
future
<
int
>>
tasks
;
...
@@ -459,7 +460,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
...
@@ -459,7 +460,7 @@ void GraphTable::export_partition_files(int idx, std::string file_path) {
}));
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
}
}
void
GraphTable
::
clear_graph
(
int
idx
)
{
void
GraphTable
::
clear_graph
(
int
idx
)
{
for
(
auto
p
:
edge_shards
[
idx
])
{
for
(
auto
p
:
edge_shards
[
idx
])
{
...
@@ -472,7 +473,7 @@ void GraphTable::clear_graph(int idx) {
...
@@ -472,7 +473,7 @@ void GraphTable::clear_graph(int idx) {
}
}
}
}
int32_t
GraphTable
::
load_next_partition
(
int
idx
)
{
int32_t
GraphTable
::
load_next_partition
(
int
idx
)
{
if
(
next_partition
>=
(
int
)
partitions
[
idx
].
size
(
))
{
if
(
next_partition
>=
static_cast
<
int
>
(
partitions
[
idx
].
size
()
))
{
VLOG
(
0
)
<<
"partition iteration is done"
;
VLOG
(
0
)
<<
"partition iteration is done"
;
return
-
1
;
return
-
1
;
}
}
...
@@ -518,8 +519,8 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path,
...
@@ -518,8 +519,8 @@ int32_t GraphTable::load_edges_to_ssd(const std::string &path,
add_node_to_ssd
(
0
,
add_node_to_ssd
(
0
,
idx
,
idx
,
src_id
,
src_id
,
(
char
*
)
dist_data
.
data
(),
(
char
*
)
dist_data
.
data
(),
// NOLINT
(
int
)
(
dist_data
.
size
()
*
sizeof
(
uint64_t
)));
static_cast
<
int
>
(
dist_data
.
size
()
*
sizeof
(
uint64_t
)));
}
}
}
}
VLOG
(
0
)
<<
"total memory cost = "
<<
total_memory_cost
<<
" bytes"
;
VLOG
(
0
)
<<
"total memory cost = "
<<
total_memory_cost
<<
" bytes"
;
...
@@ -537,14 +538,14 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) {
...
@@ -537,14 +538,14 @@ int32_t GraphTable::dump_edges_to_ssd(int idx) {
std
::
vector
<
Node
*>
&
v
=
shards
[
i
]
->
get_bucket
();
std
::
vector
<
Node
*>
&
v
=
shards
[
i
]
->
get_bucket
();
for
(
size_t
j
=
0
;
j
<
v
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
v
.
size
();
j
++
)
{
std
::
vector
<
uint64_t
>
s
;
std
::
vector
<
uint64_t
>
s
;
for
(
size_t
k
=
0
;
k
<
(
int
)
v
[
j
]
->
get_neighbor_size
();
k
++
)
{
for
(
size_t
k
=
0
;
k
<
v
[
j
]
->
get_neighbor_size
();
k
++
)
{
s
.
push_back
(
v
[
j
]
->
get_neighbor_id
(
k
));
s
.
push_back
(
v
[
j
]
->
get_neighbor_id
(
k
));
}
}
cost
+=
v
[
j
]
->
get_neighbor_size
()
*
sizeof
(
uint64_t
);
cost
+=
v
[
j
]
->
get_neighbor_size
()
*
sizeof
(
uint64_t
);
add_node_to_ssd
(
0
,
add_node_to_ssd
(
0
,
idx
,
idx
,
v
[
j
]
->
get_id
(),
v
[
j
]
->
get_id
(),
(
char
*
)
s
.
data
(),
(
char
*
)
s
.
data
(),
// NOLINT
s
.
size
()
*
sizeof
(
uint64_t
));
s
.
size
()
*
sizeof
(
uint64_t
));
}
}
return
cost
;
return
cost
;
...
@@ -901,7 +902,8 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
...
@@ -901,7 +902,8 @@ void BasicBfsGraphSampler::init(size_t gpu_num, GraphTable *graph_table,
std
::
vector
<
Node
*>
GraphShard
::
get_batch
(
int
start
,
int
end
,
int
step
)
{
std
::
vector
<
Node
*>
GraphShard
::
get_batch
(
int
start
,
int
end
,
int
step
)
{
if
(
start
<
0
)
start
=
0
;
if
(
start
<
0
)
start
=
0
;
std
::
vector
<
Node
*>
res
;
std
::
vector
<
Node
*>
res
;
for
(
int
pos
=
start
;
pos
<
std
::
min
(
end
,
(
int
)
bucket
.
size
());
pos
+=
step
)
{
for
(
int
pos
=
start
;
pos
<
std
::
min
(
end
,
(
int
)
bucket
.
size
());
// NOLINT
pos
+=
step
)
{
res
.
push_back
(
bucket
[
pos
]);
res
.
push_back
(
bucket
[
pos
]);
}
}
return
res
;
return
res
;
...
@@ -990,7 +992,7 @@ void GraphShard::delete_node(uint64_t id) {
...
@@ -990,7 +992,7 @@ void GraphShard::delete_node(uint64_t id) {
if
(
iter
==
node_location
.
end
())
return
;
if
(
iter
==
node_location
.
end
())
return
;
int
pos
=
iter
->
second
;
int
pos
=
iter
->
second
;
delete
bucket
[
pos
];
delete
bucket
[
pos
];
if
(
pos
!=
(
int
)
bucket
.
size
(
)
-
1
)
{
if
(
pos
!=
static_cast
<
int
>
(
bucket
.
size
()
)
-
1
)
{
bucket
[
pos
]
=
bucket
.
back
();
bucket
[
pos
]
=
bucket
.
back
();
node_location
[
bucket
.
back
()
->
get_id
()]
=
pos
;
node_location
[
bucket
.
back
()
->
get_id
()]
=
pos
;
}
}
...
@@ -1002,7 +1004,7 @@ GraphNode *GraphShard::add_graph_node(uint64_t id) {
...
@@ -1002,7 +1004,7 @@ GraphNode *GraphShard::add_graph_node(uint64_t id) {
node_location
[
id
]
=
bucket
.
size
();
node_location
[
id
]
=
bucket
.
size
();
bucket
.
push_back
(
new
GraphNode
(
id
));
bucket
.
push_back
(
new
GraphNode
(
id
));
}
}
return
(
GraphNode
*
)
bucket
[
node_location
[
id
]];
return
(
GraphNode
*
)
bucket
[
node_location
[
id
]];
// NOLINT
}
}
GraphNode
*
GraphShard
::
add_graph_node
(
Node
*
node
)
{
GraphNode
*
GraphShard
::
add_graph_node
(
Node
*
node
)
{
...
@@ -1011,17 +1013,17 @@ GraphNode *GraphShard::add_graph_node(Node *node) {
...
@@ -1011,17 +1013,17 @@ GraphNode *GraphShard::add_graph_node(Node *node) {
node_location
[
id
]
=
bucket
.
size
();
node_location
[
id
]
=
bucket
.
size
();
bucket
.
push_back
(
node
);
bucket
.
push_back
(
node
);
}
}
return
(
GraphNode
*
)
bucket
[
node_location
[
id
]];
return
(
GraphNode
*
)
bucket
[
node_location
[
id
]];
// NOLINT
}
}
FeatureNode
*
GraphShard
::
add_feature_node
(
uint64_t
id
,
bool
is_overlap
)
{
FeatureNode
*
GraphShard
::
add_feature_node
(
uint64_t
id
,
bool
is_overlap
)
{
if
(
node_location
.
find
(
id
)
==
node_location
.
end
())
{
if
(
node_location
.
find
(
id
)
==
node_location
.
end
())
{
node_location
[
id
]
=
bucket
.
size
();
node_location
[
id
]
=
bucket
.
size
();
bucket
.
push_back
(
new
FeatureNode
(
id
));
bucket
.
push_back
(
new
FeatureNode
(
id
));
return
(
FeatureNode
*
)
bucket
[
node_location
[
id
]];
return
(
FeatureNode
*
)
bucket
[
node_location
[
id
]];
// NOLINT
}
}
if
(
is_overlap
)
{
if
(
is_overlap
)
{
return
(
FeatureNode
*
)
bucket
[
node_location
[
id
]];
return
(
FeatureNode
*
)
bucket
[
node_location
[
id
]];
// NOLINT
}
}
return
NULL
;
return
NULL
;
...
@@ -1037,14 +1039,14 @@ Node *GraphShard::find_node(uint64_t id) {
...
@@ -1037,14 +1039,14 @@ Node *GraphShard::find_node(uint64_t id) {
}
}
GraphTable
::~
GraphTable
()
{
GraphTable
::~
GraphTable
()
{
for
(
int
i
=
0
;
i
<
(
int
)
edge_shards
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
edge_shards
.
size
();
i
++
)
{
for
(
auto
p
:
edge_shards
[
i
])
{
for
(
auto
p
:
edge_shards
[
i
])
{
delete
p
;
delete
p
;
}
}
edge_shards
[
i
].
clear
();
edge_shards
[
i
].
clear
();
}
}
for
(
int
i
=
0
;
i
<
(
int
)
feature_shards
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
feature_shards
.
size
();
i
++
)
{
for
(
auto
p
:
feature_shards
[
i
])
{
for
(
auto
p
:
feature_shards
[
i
])
{
delete
p
;
delete
p
;
}
}
...
@@ -1070,7 +1072,7 @@ int32_t GraphTable::Load(const std::string &path, const std::string ¶m) {
...
@@ -1070,7 +1072,7 @@ int32_t GraphTable::Load(const std::string &path, const std::string ¶m) {
std
::
string
GraphTable
::
get_inverse_etype
(
std
::
string
&
etype
)
{
std
::
string
GraphTable
::
get_inverse_etype
(
std
::
string
&
etype
)
{
auto
etype_split
=
paddle
::
string
::
split_string
<
std
::
string
>
(
etype
,
"2"
);
auto
etype_split
=
paddle
::
string
::
split_string
<
std
::
string
>
(
etype
,
"2"
);
std
::
string
res
;
std
::
string
res
;
if
(
(
int
)
etype_split
.
size
()
==
3
)
{
if
(
etype_split
.
size
()
==
3
)
{
res
=
etype_split
[
2
]
+
"2"
+
etype_split
[
1
]
+
"2"
+
etype_split
[
0
];
res
=
etype_split
[
2
]
+
"2"
+
etype_split
[
1
]
+
"2"
+
etype_split
[
0
];
}
else
{
}
else
{
res
=
etype_split
[
1
]
+
"2"
+
etype_split
[
0
];
res
=
etype_split
[
1
]
+
"2"
+
etype_split
[
0
];
...
@@ -1099,7 +1101,8 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
...
@@ -1099,7 +1101,8 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
std
::
string
etype_path
=
epath
+
"/"
+
etypes
[
i
];
std
::
string
etype_path
=
epath
+
"/"
+
etypes
[
i
];
auto
etype_path_list
=
paddle
::
framework
::
localfs_list
(
etype_path
);
auto
etype_path_list
=
paddle
::
framework
::
localfs_list
(
etype_path
);
std
::
string
etype_path_str
;
std
::
string
etype_path_str
;
if
(
part_num
>
0
&&
part_num
<
(
int
)
etype_path_list
.
size
())
{
if
(
part_num
>
0
&&
part_num
<
(
int
)
etype_path_list
.
size
())
{
// NOLINT
std
::
vector
<
std
::
string
>
sub_etype_path_list
(
std
::
vector
<
std
::
string
>
sub_etype_path_list
(
etype_path_list
.
begin
(),
etype_path_list
.
begin
()
+
part_num
);
etype_path_list
.
begin
(),
etype_path_list
.
begin
()
+
part_num
);
etype_path_str
=
etype_path_str
=
...
@@ -1116,7 +1119,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
...
@@ -1116,7 +1119,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
}
else
{
}
else
{
auto
npath_list
=
paddle
::
framework
::
localfs_list
(
npath
);
auto
npath_list
=
paddle
::
framework
::
localfs_list
(
npath
);
std
::
string
npath_str
;
std
::
string
npath_str
;
if
(
part_num
>
0
&&
part_num
<
(
int
)
npath_list
.
size
())
{
if
(
part_num
>
0
&&
part_num
<
(
int
)
npath_list
.
size
())
{
// NOLINT
std
::
vector
<
std
::
string
>
sub_npath_list
(
std
::
vector
<
std
::
string
>
sub_npath_list
(
npath_list
.
begin
(),
npath_list
.
begin
()
+
part_num
);
npath_list
.
begin
(),
npath_list
.
begin
()
+
part_num
);
npath_str
=
paddle
::
string
::
join_strings
(
sub_npath_list
,
delim
);
npath_str
=
paddle
::
string
::
join_strings
(
sub_npath_list
,
delim
);
...
@@ -1140,7 +1143,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
...
@@ -1140,7 +1143,7 @@ int32_t GraphTable::load_node_and_edge_file(std::string etype,
return
0
;
return
0
;
}));
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
tasks
[
i
].
get
();
return
0
;
return
0
;
}
}
...
@@ -1154,13 +1157,14 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
...
@@ -1154,13 +1157,14 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
res
.
clear
();
res
.
clear
();
auto
&
shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
auto
&
shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
for
(
size_t
i
=
0
;
i
<
shards
.
size
()
&&
index
<
(
int
)
ranges
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
shards
.
size
()
&&
index
<
(
int
)
ranges
.
size
();
// NOLINT
i
++
)
{
end
=
total_size
+
shards
[
i
]
->
get_size
();
end
=
total_size
+
shards
[
i
]
->
get_size
();
start
=
total_size
;
start
=
total_size
;
while
(
start
<
end
&&
index
<
(
int
)
ranges
.
size
(
))
{
while
(
start
<
end
&&
index
<
static_cast
<
int
>
(
ranges
.
size
()
))
{
if
(
ranges
[
index
].
second
<=
start
)
if
(
ranges
[
index
].
second
<=
start
)
{
index
++
;
index
++
;
else
if
(
ranges
[
index
].
first
>=
end
)
{
}
else
if
(
ranges
[
index
].
first
>=
end
)
{
break
;
break
;
}
else
{
}
else
{
int
first
=
std
::
max
(
ranges
[
index
].
first
,
start
);
int
first
=
std
::
max
(
ranges
[
index
].
first
,
start
);
...
@@ -1178,7 +1182,8 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
...
@@ -1178,7 +1182,8 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
res
.
reserve
(
res
.
size
()
+
num
);
res
.
reserve
(
res
.
size
()
+
num
);
for
(
auto
&
id
:
keys
)
{
for
(
auto
&
id
:
keys
)
{
res
.
push_back
(
id
);
res
.
push_back
(
id
);
std
::
swap
(
res
[
rand
()
%
res
.
size
()],
res
[(
int
)
res
.
size
()
-
1
]);
std
::
swap
(
res
[
rand
()
%
res
.
size
()],
res
[(
int
)
res
.
size
()
-
1
]);
// NOLINT
}
}
mutex
.
unlock
();
mutex
.
unlock
();
...
@@ -1291,7 +1296,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
...
@@ -1291,7 +1296,7 @@ std::pair<uint64_t, uint64_t> GraphTable::parse_node_file(
return
{
local_count
,
local_valid_count
};
return
{
local_count
,
local_valid_count
};
}
}
//
TODO
opt load all node_types in once reading
//
// TODO(danleifeng):
opt load all node_types in once reading
int32_t
GraphTable
::
load_nodes
(
const
std
::
string
&
path
,
std
::
string
node_type
)
{
int32_t
GraphTable
::
load_nodes
(
const
std
::
string
&
path
,
std
::
string
node_type
)
{
auto
paths
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
";"
);
auto
paths
=
paddle
::
string
::
split_string
<
std
::
string
>
(
path
,
";"
);
uint64_t
count
=
0
;
uint64_t
count
=
0
;
...
@@ -1308,7 +1313,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
...
@@ -1308,7 +1313,7 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
return
parse_node_file
(
paths
[
i
]);
return
parse_node_file
(
paths
[
i
]);
}));
}));
}
}
for
(
int
i
=
0
;
i
<
(
int
)
tasks
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
tasks
.
size
();
i
++
)
{
auto
res
=
tasks
[
i
].
get
();
auto
res
=
tasks
[
i
].
get
();
count
+=
res
.
first
;
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
valid_count
+=
res
.
second
;
...
@@ -1434,13 +1439,13 @@ int32_t GraphTable::load_edges(const std::string &path,
...
@@ -1434,13 +1439,13 @@ int32_t GraphTable::load_edges(const std::string &path,
VLOG
(
0
)
<<
"Begin GraphTable::load_edges() edge_type["
<<
edge_type
<<
"]"
;
VLOG
(
0
)
<<
"Begin GraphTable::load_edges() edge_type["
<<
edge_type
<<
"]"
;
if
(
FLAGS_graph_load_in_parallel
)
{
if
(
FLAGS_graph_load_in_parallel
)
{
std
::
vector
<
std
::
future
<
std
::
pair
<
uint64_t
,
uint64_t
>>>
tasks
;
std
::
vector
<
std
::
future
<
std
::
pair
<
uint64_t
,
uint64_t
>>>
tasks
;
for
(
in
t
i
=
0
;
i
<
paths
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
paths
.
size
();
i
++
)
{
tasks
.
push_back
(
load_node_edge_task_pool
->
enqueue
(
tasks
.
push_back
(
load_node_edge_task_pool
->
enqueue
(
[
&
,
i
,
idx
,
this
]()
->
std
::
pair
<
uint64_t
,
uint64_t
>
{
[
&
,
i
,
idx
,
this
]()
->
std
::
pair
<
uint64_t
,
uint64_t
>
{
return
parse_edge_file
(
paths
[
i
],
idx
,
reverse_edge
);
return
parse_edge_file
(
paths
[
i
],
idx
,
reverse_edge
);
}));
}));
}
}
for
(
int
j
=
0
;
j
<
(
int
)
tasks
.
size
();
j
++
)
{
for
(
size_t
j
=
0
;
j
<
tasks
.
size
();
j
++
)
{
auto
res
=
tasks
[
j
].
get
();
auto
res
=
tasks
[
j
].
get
();
count
+=
res
.
first
;
count
+=
res
.
first
;
valid_count
+=
res
.
second
;
valid_count
+=
res
.
second
;
...
@@ -1543,7 +1548,7 @@ int32_t GraphTable::random_sample_nodes(int type_id,
...
@@ -1543,7 +1548,7 @@ int32_t GraphTable::random_sample_nodes(int type_id,
int
&
actual_size
)
{
int
&
actual_size
)
{
int
total_size
=
0
;
int
total_size
=
0
;
auto
&
shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
auto
&
shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
for
(
int
i
=
0
;
i
<
(
int
)
shards
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
shards
.
size
();
i
++
)
{
total_size
+=
shards
[
i
]
->
get_size
();
total_size
+=
shards
[
i
]
->
get_size
();
}
}
if
(
sample_size
>
total_size
)
sample_size
=
total_size
;
if
(
sample_size
>
total_size
)
sample_size
=
total_size
;
...
@@ -1554,9 +1559,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
...
@@ -1554,9 +1559,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
int
remain
=
sample_size
,
last_pos
=
-
1
,
num
;
int
remain
=
sample_size
,
last_pos
=
-
1
,
num
;
std
::
set
<
int
>
separator_set
;
std
::
set
<
int
>
separator_set
;
for
(
int
i
=
0
;
i
<
range_num
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
range_num
-
1
;
i
++
)
{
while
(
separator_set
.
find
(
num
=
rand
()
%
(
sample_size
-
1
))
!=
unsigned
int
seed
=
time
(
0
);
separator_set
.
end
())
while
(
separator_set
.
find
(
num
=
rand_r
(
&
seed
)
%
(
sample_size
-
1
))
!=
;
separator_set
.
end
())
{
continue
;
}
separator_set
.
insert
(
num
);
separator_set
.
insert
(
num
);
}
}
for
(
auto
p
:
separator_set
)
{
for
(
auto
p
:
separator_set
)
{
...
@@ -1567,8 +1574,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
...
@@ -1567,8 +1574,11 @@ int32_t GraphTable::random_sample_nodes(int type_id,
remain
=
total_size
-
sample_size
+
range_num
;
remain
=
total_size
-
sample_size
+
range_num
;
separator_set
.
clear
();
separator_set
.
clear
();
for
(
int
i
=
0
;
i
<
range_num
;
i
++
)
{
for
(
int
i
=
0
;
i
<
range_num
;
i
++
)
{
while
(
separator_set
.
find
(
num
=
rand
()
%
remain
)
!=
separator_set
.
end
())
unsigned
int
seed
=
time
(
0
);
;
while
(
separator_set
.
find
(
num
=
rand_r
(
&
seed
)
%
remain
)
!=
separator_set
.
end
())
{
continue
;
}
separator_set
.
insert
(
num
);
separator_set
.
insert
(
num
);
}
}
int
used
=
0
,
index
=
0
;
int
used
=
0
,
index
=
0
;
...
@@ -1580,12 +1590,13 @@ int32_t GraphTable::random_sample_nodes(int type_id,
...
@@ -1580,12 +1590,13 @@ int32_t GraphTable::random_sample_nodes(int type_id,
used
+=
ranges_len
[
index
++
];
used
+=
ranges_len
[
index
++
];
}
}
std
::
vector
<
std
::
pair
<
int
,
int
>>
first_half
,
second_half
;
std
::
vector
<
std
::
pair
<
int
,
int
>>
first_half
,
second_half
;
int
start_index
=
rand
()
%
total_size
;
unsigned
int
seed
=
time
(
0
);
int
start_index
=
rand_r
(
&
seed
)
%
total_size
;
for
(
size_t
i
=
0
;
i
<
ranges_len
.
size
()
&&
i
<
ranges_pos
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
ranges_len
.
size
()
&&
i
<
ranges_pos
.
size
();
i
++
)
{
if
(
ranges_pos
[
i
]
+
ranges_len
[
i
]
-
1
+
start_index
<
total_size
)
if
(
ranges_pos
[
i
]
+
ranges_len
[
i
]
-
1
+
start_index
<
total_size
)
{
first_half
.
push_back
({
ranges_pos
[
i
]
+
start_index
,
first_half
.
push_back
({
ranges_pos
[
i
]
+
start_index
,
ranges_pos
[
i
]
+
ranges_len
[
i
]
+
start_index
});
ranges_pos
[
i
]
+
ranges_len
[
i
]
+
start_index
});
else
if
(
ranges_pos
[
i
]
+
start_index
>=
total_size
)
{
}
else
if
(
ranges_pos
[
i
]
+
start_index
>=
total_size
)
{
second_half
.
push_back
(
second_half
.
push_back
(
{
ranges_pos
[
i
]
+
start_index
-
total_size
,
{
ranges_pos
[
i
]
+
start_index
-
total_size
,
ranges_pos
[
i
]
+
ranges_len
[
i
]
+
start_index
-
total_size
});
ranges_pos
[
i
]
+
ranges_len
[
i
]
+
start_index
-
total_size
});
...
@@ -1623,7 +1634,7 @@ int32_t GraphTable::random_sample_neighbors(
...
@@ -1623,7 +1634,7 @@ int32_t GraphTable::random_sample_neighbors(
id_list
[
index
].
emplace_back
(
idx
,
node_ids
[
idy
],
sample_size
,
need_weight
);
id_list
[
index
].
emplace_back
(
idx
,
node_ids
[
idy
],
sample_size
,
need_weight
);
}
}
for
(
int
i
=
0
;
i
<
(
int
)
seq_id
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
seq_id
.
size
();
i
++
)
{
if
(
seq_id
[
i
].
size
()
==
0
)
continue
;
if
(
seq_id
[
i
].
size
()
==
0
)
continue
;
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
tasks
.
push_back
(
_shards_task_pool
[
i
]
->
enqueue
([
&
,
i
,
this
]()
->
int
{
uint64_t
node_id
;
uint64_t
node_id
;
...
@@ -1633,12 +1644,12 @@ int32_t GraphTable::random_sample_neighbors(
...
@@ -1633,12 +1644,12 @@ int32_t GraphTable::random_sample_neighbors(
response
=
response
=
scaled_lru
->
query
(
i
,
id_list
[
i
].
data
(),
id_list
[
i
].
size
(),
r
);
scaled_lru
->
query
(
i
,
id_list
[
i
].
data
(),
id_list
[
i
].
size
(),
r
);
}
}
in
t
index
=
0
;
size_
t
index
=
0
;
std
::
vector
<
SampleResult
>
sample_res
;
std
::
vector
<
SampleResult
>
sample_res
;
std
::
vector
<
SampleKey
>
sample_keys
;
std
::
vector
<
SampleKey
>
sample_keys
;
auto
&
rng
=
_shards_task_rng_pool
[
i
];
auto
&
rng
=
_shards_task_rng_pool
[
i
];
for
(
size_t
k
=
0
;
k
<
id_list
[
i
].
size
();
k
++
)
{
for
(
size_t
k
=
0
;
k
<
id_list
[
i
].
size
();
k
++
)
{
if
(
index
<
(
int
)
r
.
size
()
&&
if
(
index
<
r
.
size
()
&&
r
[
index
].
first
.
node_key
==
id_list
[
i
][
k
].
node_key
)
{
r
[
index
].
first
.
node_key
==
id_list
[
i
][
k
].
node_key
)
{
int
idy
=
seq_id
[
i
][
k
];
int
idy
=
seq_id
[
i
][
k
];
actual_sizes
[
idy
]
=
r
[
index
].
second
.
actual_size
;
actual_sizes
[
idy
]
=
r
[
index
].
second
.
actual_size
;
...
@@ -1722,7 +1733,7 @@ int32_t GraphTable::get_node_feat(int idx,
...
@@ -1722,7 +1733,7 @@ int32_t GraphTable::get_node_feat(int idx,
if
(
node
==
nullptr
)
{
if
(
node
==
nullptr
)
{
return
0
;
return
0
;
}
}
for
(
int
feat_idx
=
0
;
feat_idx
<
(
int
)
feature_names
.
size
();
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
++
feat_idx
)
{
const
std
::
string
&
feature_name
=
feature_names
[
feat_idx
];
const
std
::
string
&
feature_name
=
feature_names
[
feat_idx
];
if
(
feat_id_map
[
idx
].
find
(
feature_name
)
!=
feat_id_map
[
idx
].
end
())
{
if
(
feat_id_map
[
idx
].
find
(
feature_name
)
!=
feat_id_map
[
idx
].
end
())
{
...
@@ -1755,7 +1766,7 @@ int32_t GraphTable::set_node_feat(
...
@@ -1755,7 +1766,7 @@ int32_t GraphTable::set_node_feat(
size_t
index
=
node_id
%
this
->
shard_num
-
this
->
shard_start
;
size_t
index
=
node_id
%
this
->
shard_num
-
this
->
shard_start
;
auto
node
=
feature_shards
[
idx
][
index
]
->
add_feature_node
(
node_id
);
auto
node
=
feature_shards
[
idx
][
index
]
->
add_feature_node
(
node_id
);
node
->
set_feature_size
(
this
->
feat_name
[
idx
].
size
());
node
->
set_feature_size
(
this
->
feat_name
[
idx
].
size
());
for
(
int
feat_idx
=
0
;
feat_idx
<
(
int
)
feature_names
.
size
();
for
(
size_t
feat_idx
=
0
;
feat_idx
<
feature_names
.
size
();
++
feat_idx
)
{
++
feat_idx
)
{
const
std
::
string
&
feature_name
=
feature_names
[
feat_idx
];
const
std
::
string
&
feature_name
=
feature_names
[
feat_idx
];
if
(
feat_id_map
[
idx
].
find
(
feature_name
)
!=
feat_id_map
[
idx
].
end
())
{
if
(
feat_id_map
[
idx
].
find
(
feature_name
)
!=
feat_id_map
[
idx
].
end
())
{
...
@@ -1893,8 +1904,8 @@ int GraphTable::get_all_id(int type_id,
...
@@ -1893,8 +1904,8 @@ int GraphTable::get_all_id(int type_id,
MergeShardVector
shard_merge
(
output
,
slice_num
);
MergeShardVector
shard_merge
(
output
,
slice_num
);
auto
&
search_shards
=
type_id
==
0
?
edge_shards
:
feature_shards
;
auto
&
search_shards
=
type_id
==
0
?
edge_shards
:
feature_shards
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
for
(
in
t
idx
=
0
;
idx
<
search_shards
.
size
();
idx
++
)
{
for
(
size_
t
idx
=
0
;
idx
<
search_shards
.
size
();
idx
++
)
{
for
(
in
t
j
=
0
;
j
<
search_shards
[
idx
].
size
();
j
++
)
{
for
(
size_
t
j
=
0
;
j
<
search_shards
[
idx
].
size
();
j
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
j
%
task_pool_size_
]
->
enqueue
(
tasks
.
push_back
(
_shards_task_pool
[
j
%
task_pool_size_
]
->
enqueue
(
[
&
search_shards
,
idx
,
j
,
slice_num
,
&
shard_merge
]()
->
size_t
{
[
&
search_shards
,
idx
,
j
,
slice_num
,
&
shard_merge
]()
->
size_t
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
...
@@ -1917,8 +1928,8 @@ int GraphTable::get_all_neighbor_id(
...
@@ -1917,8 +1928,8 @@ int GraphTable::get_all_neighbor_id(
MergeShardVector
shard_merge
(
output
,
slice_num
);
MergeShardVector
shard_merge
(
output
,
slice_num
);
auto
&
search_shards
=
type_id
==
0
?
edge_shards
:
feature_shards
;
auto
&
search_shards
=
type_id
==
0
?
edge_shards
:
feature_shards
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
for
(
in
t
idx
=
0
;
idx
<
search_shards
.
size
();
idx
++
)
{
for
(
size_
t
idx
=
0
;
idx
<
search_shards
.
size
();
idx
++
)
{
for
(
in
t
j
=
0
;
j
<
search_shards
[
idx
].
size
();
j
++
)
{
for
(
size_
t
j
=
0
;
j
<
search_shards
[
idx
].
size
();
j
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
j
%
task_pool_size_
]
->
enqueue
(
tasks
.
push_back
(
_shards_task_pool
[
j
%
task_pool_size_
]
->
enqueue
(
[
&
search_shards
,
idx
,
j
,
slice_num
,
&
shard_merge
]()
->
size_t
{
[
&
search_shards
,
idx
,
j
,
slice_num
,
&
shard_merge
]()
->
size_t
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
...
@@ -1970,7 +1981,7 @@ int GraphTable::get_all_neighbor_id(
...
@@ -1970,7 +1981,7 @@ int GraphTable::get_all_neighbor_id(
auto
&
search_shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
auto
&
search_shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
VLOG
(
3
)
<<
"begin task, task_pool_size_["
<<
task_pool_size_
<<
"]"
;
VLOG
(
3
)
<<
"begin task, task_pool_size_["
<<
task_pool_size_
<<
"]"
;
for
(
in
t
i
=
0
;
i
<
search_shards
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
search_shards
.
size
();
i
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
[
&
search_shards
,
i
,
slice_num
,
&
shard_merge
]()
->
size_t
{
[
&
search_shards
,
i
,
slice_num
,
&
shard_merge
]()
->
size_t
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
...
@@ -1996,7 +2007,7 @@ int GraphTable::get_all_feature_ids(
...
@@ -1996,7 +2007,7 @@ int GraphTable::get_all_feature_ids(
MergeShardVector
shard_merge
(
output
,
slice_num
);
MergeShardVector
shard_merge
(
output
,
slice_num
);
auto
&
search_shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
auto
&
search_shards
=
type_id
==
0
?
edge_shards
[
idx
]
:
feature_shards
[
idx
];
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
std
::
vector
<
std
::
future
<
size_t
>>
tasks
;
for
(
in
t
i
=
0
;
i
<
search_shards
.
size
();
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
search_shards
.
size
();
i
++
)
{
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
tasks
.
push_back
(
_shards_task_pool
[
i
%
task_pool_size_
]
->
enqueue
(
[
&
search_shards
,
i
,
slice_num
,
&
shard_merge
]()
->
size_t
{
[
&
search_shards
,
i
,
slice_num
,
&
shard_merge
]()
->
size_t
{
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
std
::
vector
<
std
::
vector
<
uint64_t
>>
shard_keys
;
...
@@ -2139,7 +2150,8 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
...
@@ -2139,7 +2150,8 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
if
(
use_cache
)
{
if
(
use_cache
)
{
cache_size_limit
=
graph
.
cache_size_limit
();
cache_size_limit
=
graph
.
cache_size_limit
();
cache_ttl
=
graph
.
cache_ttl
();
cache_ttl
=
graph
.
cache_ttl
();
make_neighbor_sample_cache
((
size_t
)
cache_size_limit
,
(
size_t
)
cache_ttl
);
make_neighbor_sample_cache
((
size_t
)
cache_size_limit
,
// NOLINT
(
size_t
)
cache_ttl
);
// NOLINT
}
}
_shards_task_pool
.
resize
(
task_pool_size_
);
_shards_task_pool
.
resize
(
task_pool_size_
);
for
(
size_t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
...
@@ -2205,14 +2217,14 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
...
@@ -2205,14 +2217,14 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
#ifdef PADDLE_WITH_HETERPS
#ifdef PADDLE_WITH_HETERPS
partitions
.
resize
(
id_to_edge
.
size
());
partitions
.
resize
(
id_to_edge
.
size
());
#endif
#endif
for
(
int
k
=
0
;
k
<
(
int
)
edge_shards
.
size
();
k
++
)
{
for
(
size_t
k
=
0
;
k
<
edge_shards
.
size
();
k
++
)
{
for
(
size_t
i
=
0
;
i
<
shard_num_per_server
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
shard_num_per_server
;
i
++
)
{
edge_shards
[
k
].
push_back
(
new
GraphShard
());
edge_shards
[
k
].
push_back
(
new
GraphShard
());
}
}
}
}
node_weight
[
1
].
resize
(
id_to_feature
.
size
());
node_weight
[
1
].
resize
(
id_to_feature
.
size
());
feature_shards
.
resize
(
id_to_feature
.
size
());
feature_shards
.
resize
(
id_to_feature
.
size
());
for
(
int
k
=
0
;
k
<
(
int
)
feature_shards
.
size
();
k
++
)
{
for
(
size_t
k
=
0
;
k
<
feature_shards
.
size
();
k
++
)
{
for
(
size_t
i
=
0
;
i
<
shard_num_per_server
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
shard_num_per_server
;
i
++
)
{
feature_shards
[
k
].
push_back
(
new
GraphShard
());
feature_shards
[
k
].
push_back
(
new
GraphShard
());
}
}
...
...
paddle/fluid/distributed/ps/table/memory_dense_table.cc
浏览文件 @
be273ea9
...
@@ -21,8 +21,8 @@ namespace distributed {
...
@@ -21,8 +21,8 @@ namespace distributed {
int
FLAGS_pslib_table_save_max_retry_dense
=
3
;
int
FLAGS_pslib_table_save_max_retry_dense
=
3
;
void
MemoryDenseTable
::
CreateInitializer
(
const
std
::
string
&
attr
,
void
MemoryDenseTable
::
CreateInitializer
(
const
std
::
string
&
attr
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
auto
slices
=
string
::
split_string
<
std
::
string
>
(
attr
,
"&"
);
auto
slices
=
string
::
split_string
<
std
::
string
>
(
attr
,
"&"
);
if
(
slices
[
0
]
==
"gaussian_random"
)
{
if
(
slices
[
0
]
==
"gaussian_random"
)
{
...
@@ -60,14 +60,14 @@ int32_t MemoryDenseTable::InitializeValue() {
...
@@ -60,14 +60,14 @@ int32_t MemoryDenseTable::InitializeValue() {
values_
.
resize
(
size
);
values_
.
resize
(
size
);
total_dim_
=
0
;
total_dim_
=
0
;
for
(
int
x
=
0
;
x
<
size
;
++
x
)
{
for
(
int
x
=
0
;
x
<
size
;
++
x
)
{
auto
&
varname
=
common
.
params
()[
x
];
auto
&
varname
=
common
.
params
()[
x
];
auto
&
dim
=
common
.
dims
()[
x
];
auto
&
dim
=
common
.
dims
()[
x
];
if
(
varname
==
"Param"
)
{
if
(
varname
==
"Param"
)
{
param_dim_
=
dim
;
param_dim_
=
dim
;
param_idx_
=
x
;
param_idx_
=
x
;
}
}
auto
&
initializer
=
common
.
initializers
()[
x
];
auto
&
initializer
=
common
.
initializers
()[
x
];
total_dim_
+=
dim
;
total_dim_
+=
dim
;
CreateInitializer
(
initializer
,
varname
);
CreateInitializer
(
initializer
,
varname
);
...
@@ -81,7 +81,7 @@ int32_t MemoryDenseTable::InitializeValue() {
...
@@ -81,7 +81,7 @@ int32_t MemoryDenseTable::InitializeValue() {
fixed_len_params_dim_
=
0
;
fixed_len_params_dim_
=
0
;
for
(
int
x
=
0
;
x
<
size
;
++
x
)
{
for
(
int
x
=
0
;
x
<
size
;
++
x
)
{
auto
&
dim
=
common
.
dims
()[
x
];
auto
&
dim
=
common
.
dims
()[
x
];
if
(
static_cast
<
int
>
(
dim
)
!=
param_dim_
)
{
if
(
static_cast
<
int
>
(
dim
)
!=
param_dim_
)
{
fixed_len_params_dim_
+=
dim
;
fixed_len_params_dim_
+=
dim
;
}
else
{
}
else
{
...
@@ -124,19 +124,19 @@ int32_t MemoryDenseTable::InitializeOptimizer() {
...
@@ -124,19 +124,19 @@ int32_t MemoryDenseTable::InitializeOptimizer() {
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
SetGlobalLR
(
float
*
lr
)
{
int32_t
MemoryDenseTable
::
SetGlobalLR
(
float
*
lr
)
{
_global_lr
=
lr
;
_global_lr
=
lr
;
optimizer_
->
SetGlobalLR
(
_global_lr
);
optimizer_
->
SetGlobalLR
(
_global_lr
);
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
Pull
(
TableContext
&
context
)
{
int32_t
MemoryDenseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
CHECK
(
context
.
value_type
==
Dense
);
float
*
pull_values
=
context
.
pull_context
.
values
;
float
*
pull_values
=
context
.
pull_context
.
values
;
return
PullDense
(
pull_values
,
context
.
num
);
return
PullDense
(
pull_values
,
context
.
num
);
}
}
int32_t
MemoryDenseTable
::
Push
(
TableContext
&
context
)
{
int32_t
MemoryDenseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Dense
);
CHECK
(
context
.
value_type
==
Dense
);
if
(
context
.
push_context
.
values
!=
nullptr
)
{
if
(
context
.
push_context
.
values
!=
nullptr
)
{
if
(
!
context
.
push_context
.
is_param
)
{
if
(
!
context
.
push_context
.
is_param
)
{
...
@@ -148,13 +148,13 @@ int32_t MemoryDenseTable::Push(TableContext& context) {
...
@@ -148,13 +148,13 @@ int32_t MemoryDenseTable::Push(TableContext& context) {
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
PullDense
(
float
*
pull_values
,
size_t
num
)
{
int32_t
MemoryDenseTable
::
PullDense
(
float
*
pull_values
,
size_t
num
)
{
std
::
copy
(
std
::
copy
(
values_
[
param_idx_
].
begin
(),
values_
[
param_idx_
].
end
(),
pull_values
);
values_
[
param_idx_
].
begin
(),
values_
[
param_idx_
].
end
(),
pull_values
);
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
PushDenseParam
(
const
float
*
values
,
size_t
num
)
{
int32_t
MemoryDenseTable
::
PushDenseParam
(
const
float
*
values
,
size_t
num
)
{
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
num
,
num
,
param_dim_
,
param_dim_
,
...
@@ -171,7 +171,7 @@ int32_t MemoryDenseTable::Pour() {
...
@@ -171,7 +171,7 @@ int32_t MemoryDenseTable::Pour() {
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
PushDense
(
const
float
*
values
,
size_t
num
)
{
int32_t
MemoryDenseTable
::
PushDense
(
const
float
*
values
,
size_t
num
)
{
if
(
sync
)
{
if
(
sync
)
{
std
::
future
<
int
>
task
=
std
::
future
<
int
>
task
=
_shards_task_pool
[
0
]
->
enqueue
([
this
,
&
values
]()
->
int
{
_shards_task_pool
[
0
]
->
enqueue
([
this
,
&
values
]()
->
int
{
...
@@ -185,7 +185,7 @@ int32_t MemoryDenseTable::PushDense(const float* values, size_t num) {
...
@@ -185,7 +185,7 @@ int32_t MemoryDenseTable::PushDense(const float* values, size_t num) {
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
_PushDense
(
const
float
*
values
,
size_t
num
)
{
int32_t
MemoryDenseTable
::
_PushDense
(
const
float
*
values
,
size_t
num
)
{
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
num
,
num
,
param_dim_
,
param_dim_
,
...
@@ -212,8 +212,8 @@ int32_t MemoryDenseTable::_PushDense(const float* values, size_t num) {
...
@@ -212,8 +212,8 @@ int32_t MemoryDenseTable::_PushDense(const float* values, size_t num) {
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
Load
(
const
std
::
string
&
path
,
int32_t
MemoryDenseTable
::
Load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
const
std
::
string
&
param
)
{
if
(
param_dim_
<=
0
)
{
if
(
param_dim_
<=
0
)
{
return
0
;
return
0
;
}
}
...
@@ -249,7 +249,7 @@ int32_t MemoryDenseTable::Load(const std::string& path,
...
@@ -249,7 +249,7 @@ int32_t MemoryDenseTable::Load(const std::string& path,
try
{
try
{
int
dim_idx
=
0
;
int
dim_idx
=
0
;
float
data_buffer
[
5
];
float
data_buffer
[
5
];
float
*
data_buff_ptr
=
data_buffer
;
float
*
data_buff_ptr
=
data_buffer
;
std
::
string
line_data
;
std
::
string
line_data
;
auto
common
=
_config
.
common
();
auto
common
=
_config
.
common
();
...
@@ -319,8 +319,8 @@ int32_t MemoryDenseTable::Load(const std::string& path,
...
@@ -319,8 +319,8 @@ int32_t MemoryDenseTable::Load(const std::string& path,
return
0
;
return
0
;
}
}
int32_t
MemoryDenseTable
::
Save
(
const
std
::
string
&
path
,
int32_t
MemoryDenseTable
::
Save
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
const
std
::
string
&
param
)
{
int
save_param
=
atoi
(
param
.
c_str
());
int
save_param
=
atoi
(
param
.
c_str
());
uint32_t
feasign_size
;
uint32_t
feasign_size
;
VLOG
(
0
)
<<
"MemoryDenseTable::save path "
<<
path
;
VLOG
(
0
)
<<
"MemoryDenseTable::save path "
<<
path
;
...
@@ -353,7 +353,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
...
@@ -353,7 +353,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
os
.
clear
();
os
.
clear
();
os
.
str
(
""
);
os
.
str
(
""
);
os
<<
values_
[
param_col_ids_
[
0
]][
y
]
<<
" 0"
;
os
<<
values_
[
param_col_ids_
[
0
]][
y
]
<<
" 0"
;
for
(
in
t
x
=
2
;
x
<
param_col_ids_
.
size
();
++
x
)
{
for
(
size_
t
x
=
2
;
x
<
param_col_ids_
.
size
();
++
x
)
{
os
<<
" "
;
os
<<
" "
;
os
<<
values_
[
param_col_ids_
[
x
]][
y
];
os
<<
values_
[
param_col_ids_
[
x
]][
y
];
}
}
...
@@ -365,7 +365,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
...
@@ -365,7 +365,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
os
.
clear
();
os
.
clear
();
os
.
str
(
""
);
os
.
str
(
""
);
os
<<
values_
[
param_col_ids_
[
0
]][
y
];
os
<<
values_
[
param_col_ids_
[
0
]][
y
];
for
(
in
t
x
=
1
;
x
<
param_col_ids_
.
size
();
++
x
)
{
for
(
size_
t
x
=
1
;
x
<
param_col_ids_
.
size
();
++
x
)
{
os
<<
" "
;
os
<<
" "
;
os
<<
values_
[
param_col_ids_
[
x
]][
y
];
os
<<
values_
[
param_col_ids_
[
x
]][
y
];
}
}
...
@@ -383,7 +383,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
...
@@ -383,7 +383,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
auto
write_channel
=
auto
write_channel
=
_afs_client
.
open_w
(
channel_config
,
1024
*
1024
*
40
,
&
err_no
);
_afs_client
.
open_w
(
channel_config
,
1024
*
1024
*
40
,
&
err_no
);
for
(
auto
&
t
:
result_buffer_param
)
{
for
(
auto
&
t
:
result_buffer_param
)
{
if
(
0
!=
write_channel
->
write_line
(
t
))
{
if
(
0
!=
write_channel
->
write_line
(
t
))
{
++
retry_num
;
++
retry_num
;
is_write_failed
=
true
;
is_write_failed
=
true
;
...
...
paddle/fluid/distributed/ps/table/memory_sparse_table.cc
浏览文件 @
be273ea9
...
@@ -41,12 +41,12 @@ namespace paddle {
...
@@ -41,12 +41,12 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
int32_t
MemorySparseTable
::
Initialize
()
{
int32_t
MemorySparseTable
::
Initialize
()
{
auto
&
profiler
=
CostProfiler
::
instance
();
auto
&
profiler
=
CostProfiler
::
instance
();
profiler
.
register_profiler
(
"pserver_sparse_update_all"
);
profiler
.
register_profiler
(
"pserver_sparse_update_all"
);
profiler
.
register_profiler
(
"pserver_sparse_select_all"
);
profiler
.
register_profiler
(
"pserver_sparse_select_all"
);
InitializeValue
();
InitializeValue
();
_shards_task_pool
.
resize
(
_task_pool_size
);
_shards_task_pool
.
resize
(
_task_pool_size
);
for
(
in
t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
_shards_task_pool
.
size
();
++
i
)
{
_shards_task_pool
[
i
].
reset
(
new
::
ThreadPool
(
1
));
_shards_task_pool
[
i
].
reset
(
new
::
ThreadPool
(
1
));
}
}
VLOG
(
0
)
<<
"initalize MemorySparseTable succ"
;
VLOG
(
0
)
<<
"initalize MemorySparseTable succ"
;
...
@@ -102,8 +102,8 @@ int32_t MemorySparseTable::InitializeValue() {
...
@@ -102,8 +102,8 @@ int32_t MemorySparseTable::InitializeValue() {
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
Load
(
const
std
::
string
&
path
,
int32_t
MemorySparseTable
::
Load
(
const
std
::
string
&
path
,
const
std
::
string
&
param
)
{
const
std
::
string
&
param
)
{
std
::
string
table_path
=
TableDir
(
path
);
std
::
string
table_path
=
TableDir
(
path
);
auto
file_list
=
_afs_client
.
list
(
table_path
);
auto
file_list
=
_afs_client
.
list
(
table_path
);
...
@@ -157,13 +157,13 @@ int32_t MemorySparseTable::Load(const std::string& path,
...
@@ -157,13 +157,13 @@ int32_t MemorySparseTable::Load(const std::string& path,
err_no
=
0
;
err_no
=
0
;
std
::
string
line_data
;
std
::
string
line_data
;
auto
read_channel
=
_afs_client
.
open_r
(
channel_config
,
0
,
&
err_no
);
auto
read_channel
=
_afs_client
.
open_r
(
channel_config
,
0
,
&
err_no
);
char
*
end
=
NULL
;
char
*
end
=
NULL
;
auto
&
shard
=
_local_shards
[
i
];
auto
&
shard
=
_local_shards
[
i
];
try
{
try
{
while
(
read_channel
->
read_line
(
line_data
)
==
0
&&
while
(
read_channel
->
read_line
(
line_data
)
==
0
&&
line_data
.
size
()
>
1
)
{
line_data
.
size
()
>
1
)
{
uint64_t
key
=
std
::
strtoul
(
line_data
.
data
(),
&
end
,
10
);
uint64_t
key
=
std
::
strtoul
(
line_data
.
data
(),
&
end
,
10
);
auto
&
value
=
shard
[
key
];
auto
&
value
=
shard
[
key
];
value
.
resize
(
feature_value_size
);
value
.
resize
(
feature_value_size
);
int
parse_size
=
_value_accesor
->
ParseFromString
(
++
end
,
value
.
data
());
int
parse_size
=
_value_accesor
->
ParseFromString
(
++
end
,
value
.
data
());
value
.
resize
(
parse_size
);
value
.
resize
(
parse_size
);
...
@@ -200,7 +200,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
...
@@ -200,7 +200,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
LoadPatch
(
const
std
::
vector
<
std
::
string
>
&
file_list
,
int32_t
MemorySparseTable
::
LoadPatch
(
const
std
::
vector
<
std
::
string
>
&
file_list
,
int
load_param
)
{
int
load_param
)
{
if
(
!
_config
.
enable_revert
())
{
if
(
!
_config
.
enable_revert
())
{
LOG
(
INFO
)
<<
"MemorySparseTable should be enabled revert."
;
LOG
(
INFO
)
<<
"MemorySparseTable should be enabled revert."
;
...
@@ -213,7 +213,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
...
@@ -213,7 +213,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
int
o_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
int
o_start_idx
=
_shard_idx
*
_avg_local_shard_num
;
int
o_end_idx
=
o_start_idx
+
_real_local_shard_num
;
int
o_end_idx
=
o_start_idx
+
_real_local_shard_num
;
if
(
start_idx
>=
file_list
.
size
(
))
{
if
(
start_idx
>=
static_cast
<
int
>
(
file_list
.
size
()
))
{
return
0
;
return
0
;
}
}
size_t
feature_value_size
=
size_t
feature_value_size
=
...
@@ -224,7 +224,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
...
@@ -224,7 +224,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
omp_set_num_threads
(
thread_num
);
omp_set_num_threads
(
thread_num
);
#pragma omp parallel for schedule(dynamic)
#pragma omp parallel for schedule(dynamic)
for
(
size_
t
i
=
start_idx
;
i
<
end_idx
;
++
i
)
{
for
(
in
t
i
=
start_idx
;
i
<
end_idx
;
++
i
)
{
FsChannelConfig
channel_config
;
FsChannelConfig
channel_config
;
channel_config
.
path
=
file_list
[
i
];
channel_config
.
path
=
file_list
[
i
];
channel_config
.
converter
=
_value_accesor
->
Converter
(
load_param
).
converter
;
channel_config
.
converter
=
_value_accesor
->
Converter
(
load_param
).
converter
;
...
@@ -239,11 +239,11 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
...
@@ -239,11 +239,11 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
err_no
=
0
;
err_no
=
0
;
std
::
string
line_data
;
std
::
string
line_data
;
auto
read_channel
=
_afs_client
.
open_r
(
channel_config
,
0
,
&
err_no
);
auto
read_channel
=
_afs_client
.
open_r
(
channel_config
,
0
,
&
err_no
);
char
*
end
=
NULL
;
char
*
end
=
NULL
;
int
m_local_shard_id
=
i
%
_m_avg_local_shard_num
;
int
m_local_shard_id
=
i
%
_m_avg_local_shard_num
;
std
::
unordered_set
<
size_t
>
global_shard_idx
;
std
::
unordered_set
<
size_t
>
global_shard_idx
;
std
::
string
global_shard_idx_str
;
std
::
string
global_shard_idx_str
;
for
(
size_
t
j
=
o_start_idx
;
j
<
o_end_idx
;
++
j
)
{
for
(
in
t
j
=
o_start_idx
;
j
<
o_end_idx
;
++
j
)
{
if
((
j
%
_avg_local_shard_num
)
%
_m_real_local_shard_num
==
if
((
j
%
_avg_local_shard_num
)
%
_m_real_local_shard_num
==
m_local_shard_id
)
{
m_local_shard_id
)
{
global_shard_idx
.
insert
(
j
);
global_shard_idx
.
insert
(
j
);
...
@@ -267,9 +267,9 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
...
@@ -267,9 +267,9 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
continue
;
continue
;
}
}
size_t
local_shard_idx
=
*
index_iter
%
_avg_local_shard_num
;
size_t
local_shard_idx
=
*
index_iter
%
_avg_local_shard_num
;
auto
&
shard
=
_local_shards
[
local_shard_idx
];
auto
&
shard
=
_local_shards
[
local_shard_idx
];
auto
&
value
=
shard
[
key
];
auto
&
value
=
shard
[
key
];
value
.
resize
(
feature_value_size
);
value
.
resize
(
feature_value_size
);
int
parse_size
=
_value_accesor
->
ParseFromString
(
++
end
,
value
.
data
());
int
parse_size
=
_value_accesor
->
ParseFromString
(
++
end
,
value
.
data
());
value
.
resize
(
parse_size
);
value
.
resize
(
parse_size
);
...
@@ -300,7 +300,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
...
@@ -300,7 +300,7 @@ int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
}
}
void
MemorySparseTable
::
Revert
()
{
void
MemorySparseTable
::
Revert
()
{
for
(
size_
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
for
(
in
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
_local_shards_new
[
i
].
clear
();
_local_shards_new
[
i
].
clear
();
}
}
}
}
...
@@ -309,8 +309,8 @@ void MemorySparseTable::CheckSavePrePatchDone() {
...
@@ -309,8 +309,8 @@ void MemorySparseTable::CheckSavePrePatchDone() {
_save_patch_model_thread
.
join
();
_save_patch_model_thread
.
join
();
}
}
int32_t
MemorySparseTable
::
Save
(
const
std
::
string
&
dirname
,
int32_t
MemorySparseTable
::
Save
(
const
std
::
string
&
dirname
,
const
std
::
string
&
param
)
{
const
std
::
string
&
param
)
{
if
(
_real_local_shard_num
==
0
)
{
if
(
_real_local_shard_num
==
0
)
{
_local_show_threshold
=
-
1
;
_local_show_threshold
=
-
1
;
return
0
;
return
0
;
...
@@ -368,7 +368,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
...
@@ -368,7 +368,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
int
feasign_size
=
0
;
int
feasign_size
=
0
;
int
retry_num
=
0
;
int
retry_num
=
0
;
int
err_no
=
0
;
int
err_no
=
0
;
auto
&
shard
=
_local_shards
[
i
];
auto
&
shard
=
_local_shards
[
i
];
do
{
do
{
err_no
=
0
;
err_no
=
0
;
feasign_size
=
0
;
feasign_size
=
0
;
...
@@ -426,7 +426,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
...
@@ -426,7 +426,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
SavePatch
(
const
std
::
string
&
path
,
int
save_param
)
{
int32_t
MemorySparseTable
::
SavePatch
(
const
std
::
string
&
path
,
int
save_param
)
{
if
(
!
_config
.
enable_revert
())
{
if
(
!
_config
.
enable_revert
())
{
LOG
(
INFO
)
<<
"MemorySparseTable should be enabled revert."
;
LOG
(
INFO
)
<<
"MemorySparseTable should be enabled revert."
;
return
0
;
return
0
;
...
@@ -441,7 +441,7 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
...
@@ -441,7 +441,7 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
omp_set_num_threads
(
thread_num
);
omp_set_num_threads
(
thread_num
);
#pragma omp parallel for schedule(dynamic)
#pragma omp parallel for schedule(dynamic)
for
(
size_
t
i
=
0
;
i
<
_m_real_local_shard_num
;
++
i
)
{
for
(
in
t
i
=
0
;
i
<
_m_real_local_shard_num
;
++
i
)
{
FsChannelConfig
channel_config
;
FsChannelConfig
channel_config
;
channel_config
.
path
=
paddle
::
string
::
format_string
(
"%s/part-%03d-%05d"
,
channel_config
.
path
=
paddle
::
string
::
format_string
(
"%s/part-%03d-%05d"
,
table_path
.
c_str
(),
table_path
.
c_str
(),
...
@@ -463,9 +463,9 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
...
@@ -463,9 +463,9 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
auto
write_channel
=
auto
write_channel
=
_afs_client
.
open_w
(
channel_config
,
1024
*
1024
*
40
,
&
err_no
);
_afs_client
.
open_w
(
channel_config
,
1024
*
1024
*
40
,
&
err_no
);
for
(
size_
t
j
=
0
;
j
<
_real_local_shard_num
;
++
j
)
{
for
(
in
t
j
=
0
;
j
<
_real_local_shard_num
;
++
j
)
{
if
(
j
%
_m_real_local_shard_num
==
i
)
{
if
(
j
%
_m_real_local_shard_num
==
i
)
{
auto
&
shard
=
_local_shards_patch_model
[
j
];
auto
&
shard
=
_local_shards_patch_model
[
j
];
for
(
auto
it
=
shard
.
begin
();
it
!=
shard
.
end
();
++
it
)
{
for
(
auto
it
=
shard
.
begin
();
it
!=
shard
.
end
();
++
it
)
{
if
(
_value_accesor
->
Save
(
it
.
value
().
data
(),
save_param
))
{
if
(
_value_accesor
->
Save
(
it
.
value
().
data
(),
save_param
))
{
std
::
string
format_value
=
_value_accesor
->
ParseToString
(
std
::
string
format_value
=
_value_accesor
->
ParseToString
(
...
@@ -515,14 +515,14 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
...
@@ -515,14 +515,14 @@ int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
}
}
int64_t
MemorySparseTable
::
CacheShuffle
(
int64_t
MemorySparseTable
::
CacheShuffle
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
const
std
::
string
&
param
,
const
std
::
string
&
param
,
double
cache_threshold
,
double
cache_threshold
,
std
::
function
<
std
::
future
<
int32_t
>
(
std
::
function
<
std
::
future
<
int32_t
>
(
int
msg_type
,
int
to_pserver_id
,
std
::
string
&
msg
)
>
send_msg_func
,
int
msg_type
,
int
to_pserver_id
,
std
::
string
&
msg
)
>
send_msg_func
,
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
shuffled_channel
,
&
shuffled_channel
,
const
std
::
vector
<
Table
*>&
table_ptrs
)
{
const
std
::
vector
<
Table
*>
&
table_ptrs
)
{
LOG
(
INFO
)
<<
"cache shuffle with cache threshold: "
<<
cache_threshold
;
LOG
(
INFO
)
<<
"cache shuffle with cache threshold: "
<<
cache_threshold
;
int
save_param
=
atoi
(
param
.
c_str
());
// batch_model:0 xbox:1
int
save_param
=
atoi
(
param
.
c_str
());
// batch_model:0 xbox:1
if
(
!
_config
.
enable_sparse_table_cache
()
||
cache_threshold
<
0
)
{
if
(
!
_config
.
enable_sparse_table_cache
()
||
cache_threshold
<
0
)
{
...
@@ -546,22 +546,22 @@ int64_t MemorySparseTable::CacheShuffle(
...
@@ -546,22 +546,22 @@ int64_t MemorySparseTable::CacheShuffle(
int
feasign_size
=
0
;
int
feasign_size
=
0
;
std
::
vector
<
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>>
std
::
vector
<
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>>
tmp_channels
;
tmp_channels
;
for
(
size_
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
for
(
in
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
tmp_channels
.
push_back
(
tmp_channels
.
push_back
(
paddle
::
framework
::
MakeChannel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
());
paddle
::
framework
::
MakeChannel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
());
}
}
omp_set_num_threads
(
thread_num
);
omp_set_num_threads
(
thread_num
);
#pragma omp parallel for schedule(dynamic)
#pragma omp parallel for schedule(dynamic)
for
(
size_
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
for
(
in
t
i
=
0
;
i
<
_real_local_shard_num
;
++
i
)
{
paddle
::
framework
::
ChannelWriter
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
writer
=
paddle
::
framework
::
ChannelWriter
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
writer
=
writers
[
i
];
writers
[
i
];
writer
.
Reset
(
tmp_channels
[
i
].
get
());
writer
.
Reset
(
tmp_channels
[
i
].
get
());
for
(
size_t
idx
=
0
;
idx
<
table_ptrs
.
size
();
idx
++
)
{
for
(
size_t
idx
=
0
;
idx
<
table_ptrs
.
size
();
idx
++
)
{
Table
*
table_ptr
=
table_ptrs
[
idx
];
Table
*
table_ptr
=
table_ptrs
[
idx
];
auto
value_accesor
=
table_ptr
->
ValueAccesor
();
auto
value_accesor
=
table_ptr
->
ValueAccesor
();
shard_type
*
shard_ptr
=
static_cast
<
shard_type
*>
(
table_ptr
->
GetShard
(
i
));
shard_type
*
shard_ptr
=
static_cast
<
shard_type
*>
(
table_ptr
->
GetShard
(
i
));
for
(
auto
it
=
shard_ptr
->
begin
();
it
!=
shard_ptr
->
end
();
++
it
)
{
for
(
auto
it
=
shard_ptr
->
begin
();
it
!=
shard_ptr
->
end
();
++
it
)
{
if
(
value_accesor
->
SaveCache
(
if
(
value_accesor
->
SaveCache
(
...
@@ -581,14 +581,14 @@ int64_t MemorySparseTable::CacheShuffle(
...
@@ -581,14 +581,14 @@ int64_t MemorySparseTable::CacheShuffle(
// size: " << feasign_size << " and start sparse cache data shuffle real local
// size: " << feasign_size << " and start sparse cache data shuffle real local
// shard num: " << _real_local_shard_num;
// shard num: " << _real_local_shard_num;
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
local_datas
;
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
local_datas
;
for
(
size_
t
idx_shard
=
0
;
idx_shard
<
_real_local_shard_num
;
++
idx_shard
)
{
for
(
in
t
idx_shard
=
0
;
idx_shard
<
_real_local_shard_num
;
++
idx_shard
)
{
paddle
::
framework
::
ChannelWriter
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
writer
=
paddle
::
framework
::
ChannelWriter
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
writer
=
writers
[
idx_shard
];
writers
[
idx_shard
];
auto
channel
=
writer
.
channel
();
auto
channel
=
writer
.
channel
();
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
data
=
datas
[
idx_shard
];
std
::
vector
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
data
=
datas
[
idx_shard
];
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
ars
(
shuffle_node_num
);
std
::
vector
<
paddle
::
framework
::
BinaryArchive
>
ars
(
shuffle_node_num
);
while
(
channel
->
Read
(
data
))
{
while
(
channel
->
Read
(
data
))
{
for
(
auto
&
t
:
data
)
{
for
(
auto
&
t
:
data
)
{
auto
pserver_id
=
auto
pserver_id
=
paddle
::
distributed
::
local_random_engine
()()
%
shuffle_node_num
;
paddle
::
distributed
::
local_random_engine
()()
%
shuffle_node_num
;
if
(
pserver_id
!=
_shard_idx
)
{
if
(
pserver_id
!=
_shard_idx
)
{
...
@@ -604,9 +604,9 @@ int64_t MemorySparseTable::CacheShuffle(
...
@@ -604,9 +604,9 @@ int64_t MemorySparseTable::CacheShuffle(
send_index
[
i
]
=
i
;
send_index
[
i
]
=
i
;
}
}
std
::
random_shuffle
(
send_index
.
begin
(),
send_index
.
end
());
std
::
random_shuffle
(
send_index
.
begin
(),
send_index
.
end
());
for
(
auto
index
=
0u
;
index
<
shuffle_node_num
;
++
index
)
{
for
(
int
index
=
0
;
index
<
shuffle_node_num
;
++
index
)
{
int
i
=
send_index
[
index
];
int
i
=
send_index
[
index
];
if
(
i
==
_shard_idx
)
{
if
(
i
==
static_cast
<
int
>
(
_shard_idx
)
)
{
continue
;
continue
;
}
}
if
(
ars
[
i
].
Length
()
==
0
)
{
if
(
ars
[
i
].
Length
()
==
0
)
{
...
@@ -617,7 +617,7 @@ int64_t MemorySparseTable::CacheShuffle(
...
@@ -617,7 +617,7 @@ int64_t MemorySparseTable::CacheShuffle(
total_status
.
push_back
(
std
::
move
(
ret
));
total_status
.
push_back
(
std
::
move
(
ret
));
send_data_size
[
i
]
+=
ars
[
i
].
Length
();
send_data_size
[
i
]
+=
ars
[
i
].
Length
();
}
}
for
(
auto
&
t
:
total_status
)
{
for
(
auto
&
t
:
total_status
)
{
t
.
wait
();
t
.
wait
();
}
}
ars
.
clear
();
ars
.
clear
();
...
@@ -630,10 +630,10 @@ int64_t MemorySparseTable::CacheShuffle(
...
@@ -630,10 +630,10 @@ int64_t MemorySparseTable::CacheShuffle(
}
}
int32_t
MemorySparseTable
::
SaveCache
(
int32_t
MemorySparseTable
::
SaveCache
(
const
std
::
string
&
path
,
const
std
::
string
&
path
,
const
std
::
string
&
param
,
const
std
::
string
&
param
,
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
&
paddle
::
framework
::
Channel
<
std
::
pair
<
uint64_t
,
std
::
string
>>
shuffled_channel
)
{
&
shuffled_channel
)
{
if
(
_shard_idx
>=
_config
.
sparse_table_cache_file_num
())
{
if
(
_shard_idx
>=
_config
.
sparse_table_cache_file_num
())
{
return
0
;
return
0
;
}
}
...
@@ -656,7 +656,7 @@ int32_t MemorySparseTable::SaveCache(
...
@@ -656,7 +656,7 @@ int32_t MemorySparseTable::SaveCache(
bool
is_write_failed
=
false
;
bool
is_write_failed
=
false
;
shuffled_channel
->
Close
();
shuffled_channel
->
Close
();
while
(
shuffled_channel
->
Read
(
data
))
{
while
(
shuffled_channel
->
Read
(
data
))
{
for
(
auto
&
t
:
data
)
{
for
(
auto
&
t
:
data
)
{
++
feasign_size
;
++
feasign_size
;
if
(
0
!=
write_channel
->
write_line
(
paddle
::
string
::
format_string
(
if
(
0
!=
write_channel
->
write_line
(
paddle
::
string
::
format_string
(
"%lu %s"
,
t
.
first
,
t
.
second
.
c_str
())))
{
"%lu %s"
,
t
.
first
,
t
.
second
.
c_str
())))
{
...
@@ -695,7 +695,7 @@ int64_t MemorySparseTable::LocalMFSize() {
...
@@ -695,7 +695,7 @@ int64_t MemorySparseTable::LocalMFSize() {
tasks
[
shard_id
]
=
tasks
[
shard_id
]
=
_shards_task_pool
[
shard_id
%
_shards_task_pool
.
size
()]
->
enqueue
(
_shards_task_pool
[
shard_id
%
_shards_task_pool
.
size
()]
->
enqueue
(
[
this
,
shard_id
,
&
size_arr
]()
->
int
{
[
this
,
shard_id
,
&
size_arr
]()
->
int
{
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
for
(
auto
it
=
local_shard
.
begin
();
it
!=
local_shard
.
end
();
for
(
auto
it
=
local_shard
.
begin
();
it
!=
local_shard
.
end
();
++
it
)
{
++
it
)
{
if
(
_value_accesor
->
HasMF
(
it
.
value
().
size
()))
{
if
(
_value_accesor
->
HasMF
(
it
.
value
().
size
()))
{
...
@@ -720,20 +720,20 @@ std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
...
@@ -720,20 +720,20 @@ std::pair<int64_t, int64_t> MemorySparseTable::PrintTableStat() {
return
{
feasign_size
,
mf_size
};
return
{
feasign_size
,
mf_size
};
}
}
int32_t
MemorySparseTable
::
Pull
(
TableContext
&
context
)
{
int32_t
MemorySparseTable
::
Pull
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
CHECK
(
context
.
value_type
==
Sparse
);
if
(
context
.
use_ptr
)
{
if
(
context
.
use_ptr
)
{
char
**
pull_values
=
context
.
pull_context
.
ptr_values
;
char
**
pull_values
=
context
.
pull_context
.
ptr_values
;
const
uint64_t
*
keys
=
context
.
pull_context
.
keys
;
const
uint64_t
*
keys
=
context
.
pull_context
.
keys
;
return
PullSparsePtr
(
pull_values
,
keys
,
context
.
num
);
return
PullSparsePtr
(
pull_values
,
keys
,
context
.
num
);
}
else
{
}
else
{
float
*
pull_values
=
context
.
pull_context
.
values
;
float
*
pull_values
=
context
.
pull_context
.
values
;
const
PullSparseValue
&
pull_value
=
context
.
pull_context
.
pull_value
;
const
PullSparseValue
&
pull_value
=
context
.
pull_context
.
pull_value
;
return
PullSparse
(
pull_values
,
pull_value
);
return
PullSparse
(
pull_values
,
pull_value
);
}
}
}
}
int32_t
MemorySparseTable
::
Push
(
TableContext
&
context
)
{
int32_t
MemorySparseTable
::
Push
(
TableContext
&
context
)
{
CHECK
(
context
.
value_type
==
Sparse
);
CHECK
(
context
.
value_type
==
Sparse
);
if
(
!
context
.
use_ptr
)
{
if
(
!
context
.
use_ptr
)
{
return
PushSparse
(
return
PushSparse
(
...
@@ -745,8 +745,8 @@ int32_t MemorySparseTable::Push(TableContext& context) {
...
@@ -745,8 +745,8 @@ int32_t MemorySparseTable::Push(TableContext& context) {
}
}
}
}
int32_t
MemorySparseTable
::
PullSparse
(
float
*
pull_values
,
int32_t
MemorySparseTable
::
PullSparse
(
float
*
pull_values
,
const
PullSparseValue
&
pull_value
)
{
const
PullSparseValue
&
pull_value
)
{
CostTimer
timer
(
"pserver_sparse_select_all"
);
CostTimer
timer
(
"pserver_sparse_select_all"
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
...
@@ -776,11 +776,11 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
...
@@ -776,11 +776,11 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
pull_values
,
pull_values
,
mf_value_size
,
mf_value_size
,
select_value_size
]()
->
int
{
select_value_size
]()
->
int
{
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
float
data_buffer
[
value_size
];
// NOLINT
float
data_buffer
[
value_size
];
// NOLINT
float
*
data_buffer_ptr
=
data_buffer
;
float
*
data_buffer_ptr
=
data_buffer
;
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
keys
=
task_keys
[
shard_id
];
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
i
++
)
{
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
key
=
keys
[
i
].
first
;
auto
itr
=
local_shard
.
find
(
key
);
auto
itr
=
local_shard
.
find
(
key
);
...
@@ -790,9 +790,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
...
@@ -790,9 +790,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
if
(
FLAGS_pserver_create_value_when_push
)
{
if
(
FLAGS_pserver_create_value_when_push
)
{
memset
(
data_buffer
,
0
,
sizeof
(
float
)
*
data_size
);
memset
(
data_buffer
,
0
,
sizeof
(
float
)
*
data_size
);
}
else
{
}
else
{
auto
&
feature_value
=
local_shard
[
key
];
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
data_size
);
feature_value
.
resize
(
data_size
);
float
*
data_ptr
=
feature_value
.
data
();
float
*
data_ptr
=
feature_value
.
data
();
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
memcpy
(
memcpy
(
data_ptr
,
data_buffer_ptr
,
data_size
*
sizeof
(
float
));
data_ptr
,
data_buffer_ptr
,
data_size
*
sizeof
(
float
));
...
@@ -807,9 +807,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
...
@@ -807,9 +807,9 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
data_buffer
[
mf_idx
]
=
0.0
;
data_buffer
[
mf_idx
]
=
0.0
;
}
}
auto
offset
=
keys
[
i
].
second
;
auto
offset
=
keys
[
i
].
second
;
float
*
select_data
=
pull_values
+
select_value_size
*
offset
;
float
*
select_data
=
pull_values
+
select_value_size
*
offset
;
_value_accesor
->
Select
(
_value_accesor
->
Select
(
&
select_data
,
(
const
float
**
)
&
data_buffer_ptr
,
1
);
&
select_data
,
(
const
float
**
)
&
data_buffer_ptr
,
1
);
}
}
return
0
;
return
0
;
...
@@ -822,8 +822,8 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
...
@@ -822,8 +822,8 @@ int32_t MemorySparseTable::PullSparse(float* pull_values,
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
PullSparsePtr
(
char
**
pull_values
,
int32_t
MemorySparseTable
::
PullSparsePtr
(
char
**
pull_values
,
const
uint64_t
*
keys
,
const
uint64_t
*
keys
,
size_t
num
)
{
size_t
num
)
{
CostTimer
timer
(
"pscore_sparse_select_all"
);
CostTimer
timer
(
"pscore_sparse_select_all"
);
size_t
value_size
=
_value_accesor
->
GetAccessorInfo
().
size
/
sizeof
(
float
);
size_t
value_size
=
_value_accesor
->
GetAccessorInfo
().
size
/
sizeof
(
float
);
...
@@ -847,20 +847,20 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
...
@@ -847,20 +847,20 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
pull_values
,
pull_values
,
value_size
,
value_size
,
mf_value_size
]()
->
int
{
mf_value_size
]()
->
int
{
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
float
data_buffer
[
value_size
];
// NOLINT
float
data_buffer
[
value_size
];
// NOLINT
float
*
data_buffer_ptr
=
data_buffer
;
float
*
data_buffer_ptr
=
data_buffer
;
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
key
=
keys
[
i
].
first
;
auto
itr
=
local_shard
.
find
(
key
);
auto
itr
=
local_shard
.
find
(
key
);
size_t
data_size
=
value_size
-
mf_value_size
;
size_t
data_size
=
value_size
-
mf_value_size
;
FixedFeatureValue
*
ret
=
NULL
;
FixedFeatureValue
*
ret
=
NULL
;
if
(
itr
==
local_shard
.
end
())
{
if
(
itr
==
local_shard
.
end
())
{
// ++missed_keys;
// ++missed_keys;
auto
&
feature_value
=
local_shard
[
key
];
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
data_size
);
feature_value
.
resize
(
data_size
);
float
*
data_ptr
=
feature_value
.
data
();
float
*
data_ptr
=
feature_value
.
data
();
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
memcpy
(
data_ptr
,
data_buffer_ptr
,
data_size
*
sizeof
(
float
));
memcpy
(
data_ptr
,
data_buffer_ptr
,
data_size
*
sizeof
(
float
));
ret
=
&
feature_value
;
ret
=
&
feature_value
;
...
@@ -868,7 +868,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
...
@@ -868,7 +868,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
ret
=
itr
.
value_ptr
();
ret
=
itr
.
value_ptr
();
}
}
int
pull_data_idx
=
keys
[
i
].
second
;
int
pull_data_idx
=
keys
[
i
].
second
;
pull_values
[
pull_data_idx
]
=
reinterpret_cast
<
char
*>
(
ret
);
pull_values
[
pull_data_idx
]
=
reinterpret_cast
<
char
*>
(
ret
);
}
}
return
0
;
return
0
;
});
});
...
@@ -879,8 +879,8 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
...
@@ -879,8 +879,8 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
PushSparse
(
const
uint64_t
*
keys
,
int32_t
MemorySparseTable
::
PushSparse
(
const
uint64_t
*
keys
,
const
float
*
values
,
const
float
*
values
,
size_t
num
)
{
size_t
num
)
{
CostTimer
timer
(
"pserver_sparse_update_all"
);
CostTimer
timer
(
"pserver_sparse_update_all"
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
...
@@ -907,15 +907,15 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -907,15 +907,15 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
update_value_col
,
update_value_col
,
values
,
values
,
&
task_keys
]()
->
int
{
&
task_keys
]()
->
int
{
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard_new
=
_local_shards_new
[
shard_id
];
auto
&
local_shard_new
=
_local_shards_new
[
shard_id
];
float
data_buffer
[
value_col
];
// NOLINT
float
data_buffer
[
value_col
];
// NOLINT
float
*
data_buffer_ptr
=
data_buffer
;
float
*
data_buffer_ptr
=
data_buffer
;
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
push_data_idx
=
keys
[
i
].
second
;
uint64_t
push_data_idx
=
keys
[
i
].
second
;
const
float
*
update_data
=
const
float
*
update_data
=
values
+
push_data_idx
*
update_value_col
;
values
+
push_data_idx
*
update_value_col
;
auto
itr
=
local_shard
.
find
(
key
);
auto
itr
=
local_shard
.
find
(
key
);
if
(
itr
==
local_shard
.
end
())
{
if
(
itr
==
local_shard
.
end
())
{
...
@@ -924,7 +924,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -924,7 +924,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
continue
;
continue
;
}
}
auto
value_size
=
value_col
-
mf_value_col
;
auto
value_size
=
value_col
-
mf_value_col
;
auto
&
feature_value
=
local_shard
[
key
];
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
value_size
);
feature_value
.
resize
(
value_size
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
memcpy
(
feature_value
.
data
(),
memcpy
(
feature_value
.
data
(),
...
@@ -933,8 +933,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -933,8 +933,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
itr
=
local_shard
.
find
(
key
);
itr
=
local_shard
.
find
(
key
);
}
}
auto
&
feature_value
=
itr
.
value
();
auto
&
feature_value
=
itr
.
value
();
float
*
value_data
=
feature_value
.
data
();
float
*
value_data
=
feature_value
.
data
();
size_t
value_size
=
feature_value
.
size
();
size_t
value_size
=
feature_value
.
size
();
if
(
value_size
==
value_col
)
{
// 已拓展到最大size, 则就地update
if
(
value_size
==
value_col
)
{
// 已拓展到最大size, 则就地update
...
@@ -952,7 +952,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -952,7 +952,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
memcpy
(
value_data
,
data_buffer_ptr
,
value_size
*
sizeof
(
float
));
memcpy
(
value_data
,
data_buffer_ptr
,
value_size
*
sizeof
(
float
));
}
}
if
(
_config
.
enable_revert
())
{
if
(
_config
.
enable_revert
())
{
FixedFeatureValue
*
feature_value_new
=
&
(
local_shard_new
[
key
]);
FixedFeatureValue
*
feature_value_new
=
&
(
local_shard_new
[
key
]);
auto
new_size
=
feature_value
.
size
();
auto
new_size
=
feature_value
.
size
();
feature_value_new
->
resize
(
new_size
);
feature_value_new
->
resize
(
new_size
);
memcpy
(
feature_value_new
->
data
(),
memcpy
(
feature_value_new
->
data
(),
...
@@ -970,8 +970,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -970,8 +970,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
return
0
;
return
0
;
}
}
int32_t
MemorySparseTable
::
PushSparse
(
const
uint64_t
*
keys
,
int32_t
MemorySparseTable
::
PushSparse
(
const
uint64_t
*
keys
,
const
float
**
values
,
const
float
**
values
,
size_t
num
)
{
size_t
num
)
{
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
std
::
vector
<
std
::
future
<
int
>>
tasks
(
_real_local_shard_num
);
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
std
::
vector
<
std
::
vector
<
std
::
pair
<
uint64_t
,
int
>>>
task_keys
(
...
@@ -996,14 +996,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -996,14 +996,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
update_value_col
,
update_value_col
,
values
,
values
,
&
task_keys
]()
->
int
{
&
task_keys
]()
->
int
{
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
keys
=
task_keys
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
auto
&
local_shard
=
_local_shards
[
shard_id
];
float
data_buffer
[
value_col
];
// NOLINT
float
data_buffer
[
value_col
];
// NOLINT
float
*
data_buffer_ptr
=
data_buffer
;
float
*
data_buffer_ptr
=
data_buffer
;
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
key
=
keys
[
i
].
first
;
uint64_t
push_data_idx
=
keys
[
i
].
second
;
uint64_t
push_data_idx
=
keys
[
i
].
second
;
const
float
*
update_data
=
values
[
push_data_idx
];
const
float
*
update_data
=
values
[
push_data_idx
];
auto
itr
=
local_shard
.
find
(
key
);
auto
itr
=
local_shard
.
find
(
key
);
if
(
itr
==
local_shard
.
end
())
{
if
(
itr
==
local_shard
.
end
())
{
if
(
FLAGS_pserver_enable_create_feasign_randomly
&&
if
(
FLAGS_pserver_enable_create_feasign_randomly
&&
...
@@ -1011,7 +1011,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -1011,7 +1011,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
continue
;
continue
;
}
}
auto
value_size
=
value_col
-
mf_value_col
;
auto
value_size
=
value_col
-
mf_value_col
;
auto
&
feature_value
=
local_shard
[
key
];
auto
&
feature_value
=
local_shard
[
key
];
feature_value
.
resize
(
value_size
);
feature_value
.
resize
(
value_size
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
_value_accesor
->
Create
(
&
data_buffer_ptr
,
1
);
memcpy
(
feature_value
.
data
(),
memcpy
(
feature_value
.
data
(),
...
@@ -1019,8 +1019,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -1019,8 +1019,8 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
value_size
*
sizeof
(
float
));
value_size
*
sizeof
(
float
));
itr
=
local_shard
.
find
(
key
);
itr
=
local_shard
.
find
(
key
);
}
}
auto
&
feature_value
=
itr
.
value
();
auto
&
feature_value
=
itr
.
value
();
float
*
value_data
=
feature_value
.
data
();
float
*
value_data
=
feature_value
.
data
();
size_t
value_size
=
feature_value
.
size
();
size_t
value_size
=
feature_value
.
size
();
if
(
value_size
==
value_col
)
{
// 已拓展到最大size, 则就地update
if
(
value_size
==
value_col
)
{
// 已拓展到最大size, 则就地update
_value_accesor
->
Update
(
&
value_data
,
&
update_data
,
1
);
_value_accesor
->
Update
(
&
value_data
,
&
update_data
,
1
);
...
@@ -1048,12 +1048,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
...
@@ -1048,12 +1048,12 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
int32_t
MemorySparseTable
::
Flush
()
{
return
0
;
}
int32_t
MemorySparseTable
::
Flush
()
{
return
0
;
}
int32_t
MemorySparseTable
::
Shrink
(
const
std
::
string
&
param
)
{
int32_t
MemorySparseTable
::
Shrink
(
const
std
::
string
&
param
)
{
VLOG
(
0
)
<<
"MemorySparseTable::Shrink"
;
VLOG
(
0
)
<<
"MemorySparseTable::Shrink"
;
// TODO(zhaocaibei123): implement with multi-thread
// TODO(zhaocaibei123): implement with multi-thread
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
for
(
int
shard_id
=
0
;
shard_id
<
_real_local_shard_num
;
++
shard_id
)
{
// Shrink
// Shrink
auto
&
shard
=
_local_shards
[
shard_id
];
auto
&
shard
=
_local_shards
[
shard_id
];
for
(
auto
it
=
shard
.
begin
();
it
!=
shard
.
end
();)
{
for
(
auto
it
=
shard
.
begin
();
it
!=
shard
.
end
();)
{
if
(
_value_accesor
->
Shrink
(
it
.
value
().
data
()))
{
if
(
_value_accesor
->
Shrink
(
it
.
value
().
data
()))
{
it
=
shard
.
erase
(
it
);
it
=
shard
.
erase
(
it
);
...
...
paddle/fluid/distributed/ps/table/sparse_sgd_rule.cc
浏览文件 @
be273ea9
...
@@ -23,7 +23,7 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
...
@@ -23,7 +23,7 @@ DEFINE_bool(enable_show_scale_gradient, true, "enable show scale gradient");
namespace
paddle
{
namespace
paddle
{
namespace
distributed
{
namespace
distributed
{
void
SparseNaiveSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
void
SparseNaiveSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
size_t
emb_dim
)
{
_embedding_dim
=
emb_dim
;
_embedding_dim
=
emb_dim
;
auto
naive_param
=
param
.
naive
();
auto
naive_param
=
param
.
naive
();
...
@@ -41,9 +41,9 @@ void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
...
@@ -41,9 +41,9 @@ void SparseNaiveSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
}
}
}
}
void
SparseNaiveSGDRule
::
UpdateValueWork
(
float
*
w
,
void
SparseNaiveSGDRule
::
UpdateValueWork
(
float
*
w
,
float
*
sgd
,
float
*
sgd
,
const
float
*
push_value
,
const
float
*
push_value
,
float
scale
)
{
float
scale
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
w
[
i
]
-=
learning_rate_
*
push_value
[
i
];
w
[
i
]
-=
learning_rate_
*
push_value
[
i
];
...
@@ -51,8 +51,8 @@ void SparseNaiveSGDRule::UpdateValueWork(float* w,
...
@@ -51,8 +51,8 @@ void SparseNaiveSGDRule::UpdateValueWork(float* w,
}
}
}
}
void
SparseNaiveSGDRule
::
InitValueWork
(
float
*
value
,
void
SparseNaiveSGDRule
::
InitValueWork
(
float
*
value
,
float
*
sgd
,
float
*
sgd
,
bool
zero_init
)
{
bool
zero_init
)
{
if
(
zero_init
)
{
if
(
zero_init
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
...
@@ -68,7 +68,7 @@ void SparseNaiveSGDRule::InitValueWork(float* value,
...
@@ -68,7 +68,7 @@ void SparseNaiveSGDRule::InitValueWork(float* value,
}
}
}
}
}
}
void
SparseAdaGradSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
void
SparseAdaGradSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
size_t
emb_dim
)
{
_embedding_dim
=
emb_dim
;
_embedding_dim
=
emb_dim
;
auto
adagrad_param
=
param
.
adagrad
();
auto
adagrad_param
=
param
.
adagrad
();
...
@@ -88,11 +88,11 @@ void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
...
@@ -88,11 +88,11 @@ void SparseAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
}
}
}
}
void
SparseAdaGradSGDRule
::
UpdateValueWork
(
float
*
w
,
void
SparseAdaGradSGDRule
::
UpdateValueWork
(
float
*
w
,
float
*
sgd
,
float
*
sgd
,
const
float
*
grad
,
const
float
*
grad
,
float
scale
)
{
float
scale
)
{
float
&
g2sum
=
sgd
[
G2SumIndex
()];
float
&
g2sum
=
sgd
[
G2SumIndex
()];
double
add_g2sum
=
0
;
double
add_g2sum
=
0
;
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
...
@@ -106,8 +106,8 @@ void SparseAdaGradSGDRule::UpdateValueWork(float* w,
...
@@ -106,8 +106,8 @@ void SparseAdaGradSGDRule::UpdateValueWork(float* w,
g2sum
+=
add_g2sum
/
_embedding_dim
;
g2sum
+=
add_g2sum
/
_embedding_dim
;
}
}
void
SparseAdaGradSGDRule
::
InitValueWork
(
float
*
value
,
void
SparseAdaGradSGDRule
::
InitValueWork
(
float
*
value
,
float
*
sgd
,
float
*
sgd
,
bool
zero_init
)
{
bool
zero_init
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
if
(
zero_init
)
{
if
(
zero_init
)
{
...
@@ -125,7 +125,7 @@ void SparseAdaGradSGDRule::InitValueWork(float* value,
...
@@ -125,7 +125,7 @@ void SparseAdaGradSGDRule::InitValueWork(float* value,
sgd
[
G2SumIndex
()]
=
0
;
sgd
[
G2SumIndex
()]
=
0
;
}
}
void
StdAdaGradSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
void
StdAdaGradSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
size_t
emb_dim
)
{
_embedding_dim
=
emb_dim
;
_embedding_dim
=
emb_dim
;
auto
adagrad_param
=
param
.
adagrad
();
auto
adagrad_param
=
param
.
adagrad
();
...
@@ -145,12 +145,12 @@ void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
...
@@ -145,12 +145,12 @@ void StdAdaGradSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
}
}
}
}
void
StdAdaGradSGDRule
::
UpdateValueWork
(
float
*
w
,
void
StdAdaGradSGDRule
::
UpdateValueWork
(
float
*
w
,
float
*
sgd
,
float
*
sgd
,
const
float
*
grad
,
const
float
*
grad
,
float
scale
)
{
float
scale
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
float
&
g2sum
=
sgd
[
G2SumIndex
()
+
i
];
float
&
g2sum
=
sgd
[
G2SumIndex
()
+
i
];
double
scaled_grad
=
grad
[
i
]
/
scale
;
double
scaled_grad
=
grad
[
i
]
/
scale
;
w
[
i
]
-=
learning_rate_
*
scaled_grad
*
w
[
i
]
-=
learning_rate_
*
scaled_grad
*
sqrt
(
_initial_g2sum
/
(
_initial_g2sum
+
g2sum
));
sqrt
(
_initial_g2sum
/
(
_initial_g2sum
+
g2sum
));
...
@@ -159,8 +159,8 @@ void StdAdaGradSGDRule::UpdateValueWork(float* w,
...
@@ -159,8 +159,8 @@ void StdAdaGradSGDRule::UpdateValueWork(float* w,
}
}
}
}
void
StdAdaGradSGDRule
::
InitValueWork
(
float
*
value
,
void
StdAdaGradSGDRule
::
InitValueWork
(
float
*
value
,
float
*
sgd
,
float
*
sgd
,
bool
zero_init
)
{
bool
zero_init
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
if
(
zero_init
)
{
if
(
zero_init
)
{
...
@@ -178,7 +178,7 @@ void StdAdaGradSGDRule::InitValueWork(float* value,
...
@@ -178,7 +178,7 @@ void StdAdaGradSGDRule::InitValueWork(float* value,
}
}
}
}
void
SparseAdamSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
void
SparseAdamSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
size_t
emb_dim
)
{
_embedding_dim
=
emb_dim
;
_embedding_dim
=
emb_dim
;
auto
adam_param
=
param
.
adam
();
auto
adam_param
=
param
.
adam
();
...
@@ -199,15 +199,15 @@ void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
...
@@ -199,15 +199,15 @@ void SparseAdamSGDRule::LoadConfig(const SparseCommonSGDRuleParameter& param,
}
}
}
}
void
SparseAdamSGDRule
::
UpdateValueWork
(
float
*
w
,
void
SparseAdamSGDRule
::
UpdateValueWork
(
float
*
w
,
float
*
sgd
,
float
*
sgd
,
const
float
*
grad
,
const
float
*
grad
,
float
scale
)
{
float
scale
)
{
float
*
gsum
=
sgd
+
GSumIndex
();
float
*
gsum
=
sgd
+
GSumIndex
();
float
*
g2sum
=
sgd
+
G2SumIndex
();
float
*
g2sum
=
sgd
+
G2SumIndex
();
float
*
beta1_pow
=
sgd
+
Beta1PowIndex
();
float
*
beta1_pow
=
sgd
+
Beta1PowIndex
();
float
*
beta2_pow
=
sgd
+
Beta2PowIndex
();
float
*
beta2_pow
=
sgd
+
Beta2PowIndex
();
const
float
*
g
=
grad
;
const
float
*
g
=
grad
;
float
lr
=
learning_rate_
;
float
lr
=
learning_rate_
;
float
beta1_pow_
=
*
beta1_pow
;
float
beta1_pow_
=
*
beta1_pow
;
...
@@ -227,8 +227,8 @@ void SparseAdamSGDRule::UpdateValueWork(float* w,
...
@@ -227,8 +227,8 @@ void SparseAdamSGDRule::UpdateValueWork(float* w,
(
*
beta2_pow
)
*=
_beta2_decay_rate
;
(
*
beta2_pow
)
*=
_beta2_decay_rate
;
}
}
void
SparseAdamSGDRule
::
InitValueWork
(
float
*
value
,
void
SparseAdamSGDRule
::
InitValueWork
(
float
*
value
,
float
*
sgd
,
float
*
sgd
,
bool
zero_init
)
{
bool
zero_init
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
if
(
zero_init
)
{
if
(
zero_init
)
{
...
@@ -253,7 +253,7 @@ void SparseAdamSGDRule::InitValueWork(float* value,
...
@@ -253,7 +253,7 @@ void SparseAdamSGDRule::InitValueWork(float* value,
}
}
void
SparseSharedAdamSGDRule
::
LoadConfig
(
void
SparseSharedAdamSGDRule
::
LoadConfig
(
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
const
SparseCommonSGDRuleParameter
&
param
,
size_t
emb_dim
)
{
_embedding_dim
=
emb_dim
;
_embedding_dim
=
emb_dim
;
auto
adam_param
=
param
.
adam
();
auto
adam_param
=
param
.
adam
();
learning_rate_
=
adam_param
.
learning_rate
();
learning_rate_
=
adam_param
.
learning_rate
();
...
@@ -273,15 +273,15 @@ void SparseSharedAdamSGDRule::LoadConfig(
...
@@ -273,15 +273,15 @@ void SparseSharedAdamSGDRule::LoadConfig(
}
}
}
}
void
SparseSharedAdamSGDRule
::
UpdateValueWork
(
float
*
w
,
void
SparseSharedAdamSGDRule
::
UpdateValueWork
(
float
*
w
,
float
*
sgd
,
float
*
sgd
,
const
float
*
grad
,
const
float
*
grad
,
float
scale
)
{
float
scale
)
{
float
*
gsum
=
sgd
+
GSumIndex
();
float
*
gsum
=
sgd
+
GSumIndex
();
float
*
g2sum
=
sgd
+
G2SumIndex
();
float
*
g2sum
=
sgd
+
G2SumIndex
();
float
*
beta1_pow
=
sgd
+
Beta1PowIndex
();
float
*
beta1_pow
=
sgd
+
Beta1PowIndex
();
float
*
beta2_pow
=
sgd
+
Beta2PowIndex
();
float
*
beta2_pow
=
sgd
+
Beta2PowIndex
();
const
float
*
g
=
grad
;
const
float
*
g
=
grad
;
float
lr
=
learning_rate_
;
float
lr
=
learning_rate_
;
float
beta1_pow_
=
*
beta1_pow
;
float
beta1_pow_
=
*
beta1_pow
;
...
@@ -292,7 +292,7 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
...
@@ -292,7 +292,7 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
lr
*=
sqrt
(
1
-
beta2_pow_
)
/
(
1
-
beta1_pow_
);
lr
*=
sqrt
(
1
-
beta2_pow_
)
/
(
1
-
beta1_pow_
);
double
sum_gsum
=
0.0
;
double
sum_gsum
=
0.0
;
double
sum_g2sum
=
0.0
;
double
sum_g2sum
=
0.0
;
for
(
in
t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
for
(
size_
t
i
=
0
;
i
<
_embedding_dim
;
i
++
)
{
// Calculation
// Calculation
double
new_gsum
=
double
new_gsum
=
_beta1_decay_rate
*
gsum_
+
(
1
-
_beta1_decay_rate
)
*
g
[
i
];
_beta1_decay_rate
*
gsum_
+
(
1
-
_beta1_decay_rate
)
*
g
[
i
];
...
@@ -310,10 +310,10 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
...
@@ -310,10 +310,10 @@ void SparseSharedAdamSGDRule::UpdateValueWork(float* w,
(
*
beta2_pow
)
*=
_beta2_decay_rate
;
(
*
beta2_pow
)
*=
_beta2_decay_rate
;
}
}
void
SparseSharedAdamSGDRule
::
InitValueWork
(
float
*
value
,
void
SparseSharedAdamSGDRule
::
InitValueWork
(
float
*
value
,
float
*
sgd
,
float
*
sgd
,
bool
zero_init
)
{
bool
zero_init
)
{
for
(
in
t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
for
(
size_
t
i
=
0
;
i
<
_embedding_dim
;
++
i
)
{
if
(
zero_init
)
{
if
(
zero_init
)
{
value
[
i
]
=
0.0
;
value
[
i
]
=
0.0
;
BoundValue
(
value
[
i
]);
BoundValue
(
value
[
i
]);
...
@@ -327,7 +327,7 @@ void SparseSharedAdamSGDRule::InitValueWork(float* value,
...
@@ -327,7 +327,7 @@ void SparseSharedAdamSGDRule::InitValueWork(float* value,
}
}
}
}
// init rule gsum and g2sum
// init rule gsum and g2sum
for
(
in
t
i
=
GSumIndex
();
i
<
Beta1PowIndex
();
i
++
)
{
for
(
size_
t
i
=
GSumIndex
();
i
<
Beta1PowIndex
();
i
++
)
{
sgd
[
i
]
=
0.0
;
sgd
[
i
]
=
0.0
;
}
}
// init beta1_pow and beta2_pow
// init beta1_pow and beta2_pow
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录