Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
a4f898a5
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a4f898a5
编写于
8月 18, 2020
作者:
Y
Yelrose
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fixed pgl for pslib; fixed example in citation_network
上级
c820ca0a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
13 addition
and
11 deletion
+13
-11
examples/citation_benchmark/model.py
examples/citation_benchmark/model.py
+6
-8
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+7
-3
未找到文件。
examples/citation_benchmark/model.py
浏览文件 @
a4f898a5
...
...
@@ -72,9 +72,9 @@ class GAT(object):
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
if
phase
==
"train"
:
edge_dropout
=
0
else
:
edge_dropout
=
self
.
edge_dropout
else
:
edge_dropout
=
0
for
i
in
range
(
self
.
num_layers
):
ngw
=
pgl
.
sample
.
edge_drop
(
graph_wrapper
,
edge_dropout
)
...
...
@@ -113,9 +113,9 @@ class APPNP(object):
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
if
phase
==
"train"
:
edge_dropout
=
0
else
:
edge_dropout
=
self
.
edge_dropout
else
:
edge_dropout
=
0
for
i
in
range
(
self
.
num_layers
):
feature
=
L
.
dropout
(
...
...
@@ -169,9 +169,9 @@ class GCNII(object):
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
if
phase
==
"train"
:
edge_dropout
=
0
else
:
edge_dropout
=
self
.
edge_dropout
else
:
edge_dropout
=
0
for
i
in
range
(
self
.
num_layers
):
feature
=
L
.
fc
(
feature
,
self
.
hidden_size
,
act
=
"relu"
,
name
=
"lin%s"
%
i
)
...
...
@@ -191,5 +191,3 @@ class GCNII(object):
feature
=
L
.
fc
(
feature
,
self
.
num_class
,
act
=
None
,
name
=
"output"
)
return
feature
pgl/graph_wrapper.py
浏览文件 @
a4f898a5
...
...
@@ -774,7 +774,8 @@ class BatchGraphWrapper(BaseGraphWrapper):
num_edges (int32 or int64): Shape [ num_graph ].
edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
edges (int32 or int64): Shape [ total_num_edges_in_the_graphs, 2 ]
or Tuple with (src, dst).
node_feats: A dictionary for node features. Each value should be tensor
with shape [ total_num_nodes_in_the_graphs, feature_size]
...
...
@@ -835,8 +836,11 @@ class BatchGraphWrapper(BaseGraphWrapper):
def
__build_edges
(
self
,
edges
,
node_shift
,
edge_lod
):
""" Merge subgraph edges.
"""
src
=
edges
[:,
0
]
dst
=
edges
[:,
1
]
if
len
(
edges
)
==
2
:
src
,
dst
=
edges
else
:
src
=
edges
[:,
0
]
dst
=
edges
[:,
1
]
src
=
L
.
reshape
(
src
,
[
-
1
])
dst
=
L
.
reshape
(
dst
,
[
-
1
])
src
=
paddle_helper
.
ensure_dtype
(
src
,
dtype
=
"int32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录