Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
4e5c920a
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e5c920a
编写于
5月 19, 2020
作者:
S
suweiyue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
1. dataset with neg_type, 2. never ignore edges
上级
08da20a6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
13 addition
and
5 deletion
+13
-5
examples/erniesage/dataset/graph_reader.py
examples/erniesage/dataset/graph_reader.py
+9
-3
examples/erniesage/train.py
examples/erniesage/train.py
+4
-2
未找到文件。
examples/erniesage/dataset/graph_reader.py
浏览文件 @
4e5c920a
...
...
@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class
GraphGenerator
(
BaseDataGenerator
):
def
__init__
(
self
,
graph_wrappers
,
data
,
batch_size
,
samples
,
num_workers
,
feed_name_list
,
use_pyreader
,
phase
,
graph_data_path
,
shuffle
=
True
,
buf_size
=
1000
):
phase
,
graph_data_path
,
shuffle
=
True
,
buf_size
=
1000
,
neg_type
=
"batch_neg"
):
super
(
GraphGenerator
,
self
).
__init__
(
buf_size
=
buf_size
,
...
...
@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self
.
phase
=
phase
self
.
load_graph
(
graph_data_path
)
self
.
num_layers
=
len
(
graph_wrappers
)
self
.
neg_type
=
neg_type
def
load_graph
(
self
,
graph_data_path
):
self
.
graph
=
pgl
.
graph
.
MemmapGraph
(
graph_data_path
)
...
...
@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src
=
np
.
array
(
batch_src
,
dtype
=
"int64"
)
batch_dst
=
np
.
array
(
batch_dst
,
dtype
=
"int64"
)
sampled_batch_neg
=
alias_sample
(
batch_dst
.
shape
,
self
.
alias
,
self
.
events
)
if
neg_type
==
"batch_neg"
:
neg_shape
=
[
1
]
else
:
neg_shape
=
batch_dst
.
shape
sampled_batch_neg
=
alias_sample
(
neg_shape
,
self
.
alias
,
self
.
events
)
if
len
(
batch_neg
)
>
0
:
batch_neg
=
np
.
concatenate
([
batch_neg
,
sampled_batch_neg
],
0
)
...
...
@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg
=
sampled_batch_neg
if
self
.
phase
==
"train"
:
ignore_edges
=
np
.
concatenate
([
np
.
stack
([
batch_src
,
batch_dst
],
1
),
np
.
stack
([
batch_dst
,
batch_src
],
1
)],
0
)
#ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges
=
set
()
else
:
ignore_edges
=
set
()
...
...
examples/erniesage/train.py
浏览文件 @
4e5c920a
...
...
@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS_NUM"
,
"1"
))
log
.
info
(
"trainer_id: %s, trainer_count: %s."
%
(
trainer_id
,
trainer_count
))
edges
=
np
.
load
(
os
.
path
.
join
(
graph_path
,
"edges.npy"
),
allow_pickle
=
True
)
bidirectional_
edges
=
np
.
load
(
os
.
path
.
join
(
graph_path
,
"edges.npy"
),
allow_pickle
=
True
)
# edges is bidirectional.
edges
=
bidirectional_edges
[
0
::
2
]
train_usr
=
edges
[
trainer_id
::
trainer_count
,
0
]
train_ad
=
edges
[
trainer_id
::
trainer_count
,
1
]
returns
=
{
...
...
@@ -73,7 +74,8 @@ def main(config):
use_pyreader
=
config
.
use_pyreader
,
phase
=
"train"
,
graph_data_path
=
config
.
graph_path
,
shuffle
=
True
)
shuffle
=
True
,
neg_type
=
config
.
neg_type
)
log
.
info
(
"build graph reader done."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录