Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
eeb267da
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
eeb267da
编写于
5月 19, 2020
作者:
W
Weiyue Su
提交者:
GitHub
5月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #79 from WeiyueSu/erniesage
Erniesage
上级
e68b8b25
7cb9cea8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
18 addition
and
10 deletion
+18
-10
examples/erniesage/config/erniesage_v2_cpu.yaml
examples/erniesage/config/erniesage_v2_cpu.yaml
+1
-1
examples/erniesage/config/erniesage_v2_gpu.yaml
examples/erniesage/config/erniesage_v2_gpu.yaml
+1
-1
examples/erniesage/dataset/graph_reader.py
examples/erniesage/dataset/graph_reader.py
+9
-3
examples/erniesage/models/base.py
examples/erniesage/models/base.py
+2
-2
examples/erniesage/models/erniesage_v2.py
examples/erniesage/models/erniesage_v2.py
+1
-1
examples/erniesage/train.py
examples/erniesage/train.py
+4
-2
未找到文件。
examples/erniesage/config/erniesage_v2_cpu.yaml
浏览文件 @
eeb267da
...
@@ -31,7 +31,7 @@ final_fc: true
...
@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm
:
true
final_l2_norm
:
true
loss_type
:
"
hinge"
loss_type
:
"
hinge"
margin
:
0.3
margin
:
0.3
neg_type
:
"
random
_neg"
neg_type
:
"
batch
_neg"
# infer config ------
# infer config ------
infer_model
:
"
./output/last"
infer_model
:
"
./output/last"
...
...
examples/erniesage/config/erniesage_v2_gpu.yaml
浏览文件 @
eeb267da
...
@@ -31,7 +31,7 @@ final_fc: true
...
@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm
:
true
final_l2_norm
:
true
loss_type
:
"
hinge"
loss_type
:
"
hinge"
margin
:
0.3
margin
:
0.3
neg_type
:
"
random
_neg"
neg_type
:
"
batch
_neg"
# infer config ------
# infer config ------
infer_model
:
"
./output/last"
infer_model
:
"
./output/last"
...
...
examples/erniesage/dataset/graph_reader.py
浏览文件 @
eeb267da
...
@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
...
@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class
GraphGenerator
(
BaseDataGenerator
):
class
GraphGenerator
(
BaseDataGenerator
):
def
__init__
(
self
,
graph_wrappers
,
data
,
batch_size
,
samples
,
def
__init__
(
self
,
graph_wrappers
,
data
,
batch_size
,
samples
,
num_workers
,
feed_name_list
,
use_pyreader
,
num_workers
,
feed_name_list
,
use_pyreader
,
phase
,
graph_data_path
,
shuffle
=
True
,
buf_size
=
1000
):
phase
,
graph_data_path
,
shuffle
=
True
,
buf_size
=
1000
,
neg_type
=
"batch_neg"
):
super
(
GraphGenerator
,
self
).
__init__
(
super
(
GraphGenerator
,
self
).
__init__
(
buf_size
=
buf_size
,
buf_size
=
buf_size
,
...
@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
...
@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self
.
phase
=
phase
self
.
phase
=
phase
self
.
load_graph
(
graph_data_path
)
self
.
load_graph
(
graph_data_path
)
self
.
num_layers
=
len
(
graph_wrappers
)
self
.
num_layers
=
len
(
graph_wrappers
)
self
.
neg_type
=
neg_type
def
load_graph
(
self
,
graph_data_path
):
def
load_graph
(
self
,
graph_data_path
):
self
.
graph
=
pgl
.
graph
.
MemmapGraph
(
graph_data_path
)
self
.
graph
=
pgl
.
graph
.
MemmapGraph
(
graph_data_path
)
...
@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
...
@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src
=
np
.
array
(
batch_src
,
dtype
=
"int64"
)
batch_src
=
np
.
array
(
batch_src
,
dtype
=
"int64"
)
batch_dst
=
np
.
array
(
batch_dst
,
dtype
=
"int64"
)
batch_dst
=
np
.
array
(
batch_dst
,
dtype
=
"int64"
)
sampled_batch_neg
=
alias_sample
(
batch_dst
.
shape
,
self
.
alias
,
self
.
events
)
if
self
.
neg_type
==
"batch_neg"
:
neg_shape
=
[
1
]
else
:
neg_shape
=
batch_dst
.
shape
sampled_batch_neg
=
alias_sample
(
neg_shape
,
self
.
alias
,
self
.
events
)
if
len
(
batch_neg
)
>
0
:
if
len
(
batch_neg
)
>
0
:
batch_neg
=
np
.
concatenate
([
batch_neg
,
sampled_batch_neg
],
0
)
batch_neg
=
np
.
concatenate
([
batch_neg
,
sampled_batch_neg
],
0
)
...
@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
...
@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg
=
sampled_batch_neg
batch_neg
=
sampled_batch_neg
if
self
.
phase
==
"train"
:
if
self
.
phase
==
"train"
:
ignore_edges
=
np
.
concatenate
([
np
.
stack
([
batch_src
,
batch_dst
],
1
),
np
.
stack
([
batch_dst
,
batch_src
],
1
)],
0
)
#ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges
=
set
()
else
:
else
:
ignore_edges
=
set
()
ignore_edges
=
set
()
...
...
examples/erniesage/models/base.py
浏览文件 @
eeb267da
...
@@ -191,12 +191,12 @@ def all_gather(X):
...
@@ -191,12 +191,12 @@ def all_gather(X):
for
i
in
range
(
trainer_num
):
for
i
in
range
(
trainer_num
):
copy_X
=
X
*
1
copy_X
=
X
*
1
copy_X
=
L
.
collective
.
_broadcast
(
copy_X
,
i
,
True
)
copy_X
=
L
.
collective
.
_broadcast
(
copy_X
,
i
,
True
)
copy_X
.
stop_gradient
s
=
True
copy_X
.
stop_gradient
=
True
Xs
.
append
(
copy_X
)
Xs
.
append
(
copy_X
)
if
len
(
Xs
)
>
1
:
if
len
(
Xs
)
>
1
:
Xs
=
L
.
concat
(
Xs
,
0
)
Xs
=
L
.
concat
(
Xs
,
0
)
Xs
.
stop_gradient
s
=
True
Xs
.
stop_gradient
=
True
else
:
else
:
Xs
=
Xs
[
0
]
Xs
=
Xs
[
0
]
return
Xs
return
Xs
...
...
examples/erniesage/models/erniesage_v2.py
浏览文件 @
eeb267da
...
@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
...
@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
src_position_ids
=
L
.
expand
(
src_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen * num_b, 1]
src_position_ids
=
L
.
expand
(
src_position_ids
,
[
src_batch
,
1
,
1
])
# [B, slot_seqlen * num_b, 1]
zero
=
L
.
fill_constant
([
1
],
dtype
=
'int64'
,
value
=
0
)
zero
=
L
.
fill_constant
([
1
],
dtype
=
'int64'
,
value
=
0
)
input_mask
=
L
.
cast
(
L
.
equal
(
src_ids
,
zero
),
"int32"
)
# assume pad id == 0 [B, slot_seqlen, 1]
input_mask
=
L
.
cast
(
L
.
equal
(
src_ids
,
zero
),
"int32"
)
# assume pad id == 0 [B, slot_seqlen, 1]
src_pad_len
=
L
.
reduce_sum
(
input_mask
,
1
)
# [B, 1, 1]
src_pad_len
=
L
.
reduce_sum
(
input_mask
,
1
,
keep_dim
=
True
)
# [B, 1, 1]
dst_position_ids
=
L
.
reshape
(
dst_position_ids
=
L
.
reshape
(
L
.
range
(
L
.
range
(
...
...
examples/erniesage/train.py
浏览文件 @
eeb267da
...
@@ -32,8 +32,9 @@ class TrainData(object):
...
@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS_NUM"
,
"1"
))
trainer_count
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS_NUM"
,
"1"
))
log
.
info
(
"trainer_id: %s, trainer_count: %s."
%
(
trainer_id
,
trainer_count
))
log
.
info
(
"trainer_id: %s, trainer_count: %s."
%
(
trainer_id
,
trainer_count
))
edges
=
np
.
load
(
os
.
path
.
join
(
graph_path
,
"edges.npy"
),
allow_pickle
=
True
)
bidirectional_
edges
=
np
.
load
(
os
.
path
.
join
(
graph_path
,
"edges.npy"
),
allow_pickle
=
True
)
# edges is bidirectional.
# edges is bidirectional.
edges
=
bidirectional_edges
[
0
::
2
]
train_usr
=
edges
[
trainer_id
::
trainer_count
,
0
]
train_usr
=
edges
[
trainer_id
::
trainer_count
,
0
]
train_ad
=
edges
[
trainer_id
::
trainer_count
,
1
]
train_ad
=
edges
[
trainer_id
::
trainer_count
,
1
]
returns
=
{
returns
=
{
...
@@ -73,7 +74,8 @@ def main(config):
...
@@ -73,7 +74,8 @@ def main(config):
use_pyreader
=
config
.
use_pyreader
,
use_pyreader
=
config
.
use_pyreader
,
phase
=
"train"
,
phase
=
"train"
,
graph_data_path
=
config
.
graph_path
,
graph_data_path
=
config
.
graph_path
,
shuffle
=
True
)
shuffle
=
True
,
neg_type
=
config
.
neg_type
)
log
.
info
(
"build graph reader done."
)
log
.
info
(
"build graph reader done."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录