Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
3eb6d2a6
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3eb6d2a6
编写于
7月 28, 2020
作者:
Y
Yelrose
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add edge drop
上级
141fe25b
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
202 addition
and
49 deletion
+202
-49
examples/citation_benchmark/build_model.py
examples/citation_benchmark/build_model.py
+1
-6
examples/citation_benchmark/config/appnp.yaml
examples/citation_benchmark/config/appnp.yaml
+1
-0
examples/citation_benchmark/config/gat.yaml
examples/citation_benchmark/config/gat.yaml
+1
-0
examples/citation_benchmark/config/gcn.yaml
examples/citation_benchmark/config/gcn.yaml
+2
-1
examples/citation_benchmark/config/sgc.yaml
examples/citation_benchmark/config/sgc.yaml
+4
-0
examples/citation_benchmark/model.py
examples/citation_benchmark/model.py
+70
-11
examples/citation_benchmark/train.py
examples/citation_benchmark/train.py
+6
-5
pgl/__init__.py
pgl/__init__.py
+1
-0
pgl/graph_wrapper.py
pgl/graph_wrapper.py
+74
-19
pgl/layers/conv.py
pgl/layers/conv.py
+18
-7
pgl/sample.py
pgl/sample.py
+7
-0
pgl/utils/paddle_helper.py
pgl/utils/paddle_helper.py
+17
-0
未找到文件。
examples/citation_benchmark/build_model.py
浏览文件 @
3eb6d2a6
...
@@ -13,7 +13,7 @@ def build_model(dataset, config, phase, main_prog):
...
@@ -13,7 +13,7 @@ def build_model(dataset, config, phase, main_prog):
GraphModel
=
getattr
(
model
,
config
.
model_name
)
GraphModel
=
getattr
(
model
,
config
.
model_name
)
m
=
GraphModel
(
config
=
config
,
num_class
=
dataset
.
num_classes
)
m
=
GraphModel
(
config
=
config
,
num_class
=
dataset
.
num_classes
)
logits
=
m
.
forward
(
gw
,
gw
.
node_feat
[
"words"
])
logits
=
m
.
forward
(
gw
,
gw
.
node_feat
[
"words"
]
,
phase
)
node_index
=
fluid
.
layers
.
data
(
node_index
=
fluid
.
layers
.
data
(
"node_index"
,
"node_index"
,
...
@@ -33,11 +33,6 @@ def build_model(dataset, config, phase, main_prog):
...
@@ -33,11 +33,6 @@ def build_model(dataset, config, phase, main_prog):
loss
=
fluid
.
layers
.
mean
(
loss
)
loss
=
fluid
.
layers
.
mean
(
loss
)
if
phase
==
"train"
:
if
phase
==
"train"
:
#adam = fluid.optimizer.Adam(
# learning_rate=config.learning_rate,
# regularization=fluid.regularizer.L2DecayRegularizer(
# regularization_coeff=config.weight_decay))
#adam.minimize(loss)
AdamW
(
loss
=
loss
,
AdamW
(
loss
=
loss
,
learning_rate
=
config
.
learning_rate
,
learning_rate
=
config
.
learning_rate
,
weight_decay
=
config
.
weight_decay
,
weight_decay
=
config
.
weight_decay
,
...
...
examples/citation_benchmark/config/appnp.yaml
浏览文件 @
3eb6d2a6
...
@@ -6,3 +6,4 @@ learning_rate: 0.01
...
@@ -6,3 +6,4 @@ learning_rate: 0.01
dropout
:
0.5
dropout
:
0.5
hidden_size
:
64
hidden_size
:
64
weight_decay
:
0.0005
weight_decay
:
0.0005
edge_dropout
:
0.00
examples/citation_benchmark/config/gat.yaml
浏览文件 @
3eb6d2a6
...
@@ -6,3 +6,4 @@ feat_drop: 0.6
...
@@ -6,3 +6,4 @@ feat_drop: 0.6
attn_drop
:
0.6
attn_drop
:
0.6
num_heads
:
8
num_heads
:
8
hidden_size
:
8
hidden_size
:
8
edge_dropout
:
0.1
examples/citation_benchmark/config/gcn.yaml
浏览文件 @
3eb6d2a6
model_name
:
GCN
model_name
:
GCN
num_layers
:
1
num_layers
:
1
dropout
:
0.5
dropout
:
0.5
hidden_size
:
64
hidden_size
:
16
learning_rate
:
0.01
learning_rate
:
0.01
weight_decay
:
0.0005
weight_decay
:
0.0005
edge_dropout
:
0.0
examples/citation_benchmark/config/sgc.yaml
0 → 100644
浏览文件 @
3eb6d2a6
model_name
:
SGC
num_layers
:
2
learning_rate
:
0.2
weight_decay
:
0.000005
examples/citation_benchmark/model.py
浏览文件 @
3eb6d2a6
...
@@ -2,6 +2,12 @@ import pgl
...
@@ -2,6 +2,12 @@ import pgl
import
paddle.fluid.layers
as
L
import
paddle.fluid.layers
as
L
import
pgl.layers.conv
as
conv
import
pgl.layers.conv
as
conv
def
get_norm
(
indegree
):
norm
=
L
.
pow
(
L
.
cast
(
indegree
,
dtype
=
"float32"
)
+
1e-6
,
factor
=-
0.5
)
norm
=
norm
*
L
.
cast
(
indegree
>
0
,
dtype
=
"float32"
)
return
norm
class
GCN
(
object
):
class
GCN
(
object
):
"""Implement of GCN
"""Implement of GCN
"""
"""
...
@@ -10,14 +16,29 @@ class GCN(object):
...
@@ -10,14 +16,29 @@ class GCN(object):
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
64
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
64
)
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
self
.
edge_dropout
=
config
.
get
(
"edge_dropout"
,
0.0
)
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
feature
=
pgl
.
layers
.
gcn
(
graph_wrapper
,
if
phase
==
"train"
:
ngw
=
pgl
.
sample
.
edge_drop
(
graph_wrapper
,
self
.
edge_dropout
)
norm
=
get_norm
(
ngw
.
indegree
())
else
:
ngw
=
graph_wrapper
norm
=
graph_wrapper
.
node_feat
[
"norm"
]
feature
=
L
.
dropout
(
feature
,
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
feature
=
pgl
.
layers
.
gcn
(
ngw
,
feature
,
feature
,
self
.
hidden_size
,
self
.
hidden_size
,
activation
=
"relu"
,
activation
=
"relu"
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
]
,
norm
=
norm
,
name
=
"layer_%s"
%
i
)
name
=
"layer_%s"
%
i
)
feature
=
L
.
dropout
(
feature
=
L
.
dropout
(
...
@@ -25,11 +46,18 @@ class GCN(object):
...
@@ -25,11 +46,18 @@ class GCN(object):
self
.
dropout
,
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
dropout_implementation
=
'upscale_in_train'
)
feature
=
conv
.
gcn
(
graph_wrapper
,
if
phase
==
"train"
:
ngw
=
pgl
.
sample
.
edge_drop
(
graph_wrapper
,
self
.
edge_dropout
)
norm
=
get_norm
(
ngw
.
indegree
())
else
:
ngw
=
graph_wrapper
norm
=
graph_wrapper
.
node_feat
[
"norm"
]
feature
=
conv
.
gcn
(
ngw
,
feature
,
feature
,
self
.
num_class
,
self
.
num_class
,
activation
=
None
,
activation
=
None
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
]
,
norm
=
norm
,
name
=
"output"
)
name
=
"output"
)
return
feature
return
feature
...
@@ -43,10 +71,18 @@ class GAT(object):
...
@@ -43,10 +71,18 @@ class GAT(object):
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
8
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
8
)
self
.
feat_dropout
=
config
.
get
(
"feat_drop"
,
0.6
)
self
.
feat_dropout
=
config
.
get
(
"feat_drop"
,
0.6
)
self
.
attn_dropout
=
config
.
get
(
"attn_drop"
,
0.6
)
self
.
attn_dropout
=
config
.
get
(
"attn_drop"
,
0.6
)
self
.
edge_dropout
=
config
.
get
(
"edge_dropout"
,
0.0
)
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
if
phase
==
"train"
:
edge_dropout
=
0
else
:
edge_dropout
=
self
.
edge_dropout
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
feature
=
conv
.
gat
(
graph_wrapper
,
ngw
=
pgl
.
sample
.
edge_drop
(
graph_wrapper
,
edge_dropout
)
feature
=
conv
.
gat
(
ngw
,
feature
,
feature
,
self
.
hidden_size
,
self
.
hidden_size
,
activation
=
"elu"
,
activation
=
"elu"
,
...
@@ -55,7 +91,8 @@ class GAT(object):
...
@@ -55,7 +91,8 @@ class GAT(object):
feat_drop
=
self
.
feat_dropout
,
feat_drop
=
self
.
feat_dropout
,
attn_drop
=
self
.
attn_dropout
)
attn_drop
=
self
.
attn_dropout
)
feature
=
conv
.
gat
(
graph_wrapper
,
ngw
=
pgl
.
sample
.
edge_drop
(
graph_wrapper
,
edge_dropout
)
feature
=
conv
.
gat
(
ngw
,
feature
,
feature
,
self
.
num_class
,
self
.
num_class
,
num_heads
=
1
,
num_heads
=
1
,
...
@@ -75,8 +112,14 @@ class APPNP(object):
...
@@ -75,8 +112,14 @@ class APPNP(object):
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
self
.
alpha
=
config
.
get
(
"alpha"
,
0.1
)
self
.
alpha
=
config
.
get
(
"alpha"
,
0.1
)
self
.
k_hop
=
config
.
get
(
"k_hop"
,
10
)
self
.
k_hop
=
config
.
get
(
"k_hop"
,
10
)
self
.
edge_dropout
=
config
.
get
(
"edge_dropout"
,
0.0
)
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
if
phase
==
"train"
:
edge_dropout
=
0
else
:
edge_dropout
=
self
.
edge_dropout
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
feature
=
L
.
dropout
(
feature
=
L
.
dropout
(
feature
,
feature
,
...
@@ -93,8 +136,24 @@ class APPNP(object):
...
@@ -93,8 +136,24 @@ class APPNP(object):
feature
=
conv
.
appnp
(
graph_wrapper
,
feature
=
conv
.
appnp
(
graph_wrapper
,
feature
=
feature
,
feature
=
feature
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
]
,
edge_dropout
=
edge_dropout
,
alpha
=
self
.
alpha
,
alpha
=
self
.
alpha
,
k_hop
=
self
.
k_hop
)
k_hop
=
self
.
k_hop
)
return
feature
return
feature
class
SGC
(
object
):
"""Implement of SGC"""
def
__init__
(
self
,
config
,
num_class
):
self
.
num_class
=
num_class
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
def
forward
(
self
,
graph_wrapper
,
feature
,
phase
):
feature
=
conv
.
appnp
(
graph_wrapper
,
feature
=
feature
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
],
alpha
=
0
,
k_hop
=
self
.
num_layers
)
feature
.
stop_gradient
=
True
feature
=
L
.
fc
(
feature
,
self
.
num_class
,
act
=
None
,
name
=
"output"
)
return
feature
examples/citation_benchmark/train.py
浏览文件 @
3eb6d2a6
...
@@ -63,6 +63,7 @@ def main(args, config):
...
@@ -63,6 +63,7 @@ def main(args, config):
config
=
config
,
config
=
config
,
phase
=
"test"
,
phase
=
"test"
,
main_prog
=
test_program
)
main_prog
=
test_program
)
test_program
=
test_program
.
clone
(
for_test
=
True
)
test_program
=
test_program
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
...
@@ -86,7 +87,7 @@ def main(args, config):
...
@@ -86,7 +87,7 @@ def main(args, config):
cal_val_acc
=
[]
cal_val_acc
=
[]
cal_test_acc
=
[]
cal_test_acc
=
[]
for
epoch
in
range
(
300
):
for
epoch
in
range
(
args
.
epoch
):
if
epoch
>=
3
:
if
epoch
>=
3
:
t0
=
time
.
time
()
t0
=
time
.
time
()
feed_dict
=
gw
.
to_feed
(
dataset
.
graph
)
feed_dict
=
gw
.
to_feed
(
dataset
.
graph
)
...
@@ -123,11 +124,10 @@ def main(args, config):
...
@@ -123,11 +124,10 @@ def main(args, config):
test_loss
=
test_loss
[
0
]
test_loss
=
test_loss
[
0
]
test_acc
=
test_acc
[
0
]
test_acc
=
test_acc
[
0
]
cal_test_acc
.
append
(
test_acc
)
cal_test_acc
.
append
(
test_acc
)
if
epoch
%
10
==
0
:
log
.
info
(
"Epoch %d "
%
epoch
+
log
.
info
(
"Epoch %d "
%
epoch
+
"Train Loss: %f "
%
train_loss
+
"Train Acc: %f "
%
train_acc
"Train Loss: %f "
%
train_loss
+
"Train Acc: %f "
%
train_acc
+
"Val Loss: %f "
%
val_loss
+
"Val Acc: %f "
%
val_acc
+
"Val Loss: %f "
%
val_loss
+
"Val Acc: %f "
%
val_acc
)
+
" Test Loss: %f "
%
test_loss
+
" Test Acc: %f "
%
test_acc
)
cal_val_acc
=
np
.
array
(
cal_val_acc
)
cal_val_acc
=
np
.
array
(
cal_val_acc
)
log
.
info
(
"Model: %s Best Test Accuracy: %f"
%
(
config
.
model_name
,
log
.
info
(
"Model: %s Best Test Accuracy: %f"
%
(
config
.
model_name
,
...
@@ -140,6 +140,7 @@ if __name__ == '__main__':
...
@@ -140,6 +140,7 @@ if __name__ == '__main__':
"--dataset"
,
type
=
str
,
default
=
"cora"
,
help
=
"dataset (cora, pubmed)"
)
"--dataset"
,
type
=
str
,
default
=
"cora"
,
help
=
"dataset (cora, pubmed)"
)
parser
.
add_argument
(
"--use_cuda"
,
action
=
'store_true'
,
help
=
"use_cuda"
)
parser
.
add_argument
(
"--use_cuda"
,
action
=
'store_true'
,
help
=
"use_cuda"
)
parser
.
add_argument
(
"--conf"
,
type
=
str
,
help
=
"config file for models"
)
parser
.
add_argument
(
"--conf"
,
type
=
str
,
help
=
"config file for models"
)
parser
.
add_argument
(
"--epoch"
,
type
=
int
,
default
=
200
,
help
=
"Epoch"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
config
=
edict
(
yaml
.
load
(
open
(
args
.
conf
),
Loader
=
yaml
.
FullLoader
))
config
=
edict
(
yaml
.
load
(
open
(
args
.
conf
),
Loader
=
yaml
.
FullLoader
))
log
.
info
(
args
)
log
.
info
(
args
)
...
...
pgl/__init__.py
浏览文件 @
3eb6d2a6
...
@@ -22,3 +22,4 @@ from pgl import heter_graph
...
@@ -22,3 +22,4 @@ from pgl import heter_graph
from
pgl
import
heter_graph_wrapper
from
pgl
import
heter_graph_wrapper
from
pgl
import
contrib
from
pgl
import
contrib
from
pgl
import
message_passing
from
pgl
import
message_passing
from
pgl
import
sample
pgl/graph_wrapper.py
浏览文件 @
3eb6d2a6
...
@@ -19,6 +19,7 @@ for PaddlePaddle.
...
@@ -19,6 +19,7 @@ for PaddlePaddle.
import
warnings
import
warnings
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
L
from
pgl.utils
import
op
from
pgl.utils
import
op
from
pgl.utils
import
paddle_helper
from
pgl.utils
import
paddle_helper
...
@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try
:
try
:
out_dim
=
msg
.
shape
[
-
1
]
out_dim
=
msg
.
shape
[
-
1
]
init_output
=
fluid
.
layers
.
fill_constant
(
init_output
=
L
.
fill_constant
(
shape
=
[
num_nodes
,
out_dim
],
value
=
0
,
dtype
=
msg
.
dtype
)
shape
=
[
num_nodes
,
out_dim
],
value
=
0
,
dtype
=
msg
.
dtype
)
init_output
.
stop_gradient
=
False
init_output
.
stop_gradient
=
False
empty_msg_flag
=
fluid
.
layers
.
cast
(
num_edges
>
0
,
dtype
=
msg
.
dtype
)
empty_msg_flag
=
L
.
cast
(
num_edges
>
0
,
dtype
=
msg
.
dtype
)
msg
=
msg
*
empty_msg_flag
msg
=
msg
*
empty_msg_flag
output
=
paddle_helper
.
scatter_add
(
init_output
,
dst
,
msg
)
output
=
paddle_helper
.
scatter_add
(
init_output
,
dst
,
msg
)
return
output
return
output
...
@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
"scatter_add is not supported with paddle version <= 1.5"
)
"scatter_add is not supported with paddle version <= 1.5"
)
def
sum_func
(
message
):
def
sum_func
(
message
):
return
fluid
.
layers
.
sequence_pool
(
message
,
"sum"
)
return
L
.
sequence_pool
(
message
,
"sum"
)
reduce_function
=
sum_func
reduce_function
=
sum_func
...
@@ -67,13 +68,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
...
@@ -67,13 +68,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
output
=
reduce_function
(
bucketed_msg
)
output
=
reduce_function
(
bucketed_msg
)
output_dim
=
output
.
shape
[
-
1
]
output_dim
=
output
.
shape
[
-
1
]
empty_msg_flag
=
fluid
.
layers
.
cast
(
num_edges
>
0
,
dtype
=
output
.
dtype
)
empty_msg_flag
=
L
.
cast
(
num_edges
>
0
,
dtype
=
output
.
dtype
)
output
=
output
*
empty_msg_flag
output
=
output
*
empty_msg_flag
init_output
=
fluid
.
layers
.
fill_constant
(
init_output
=
L
.
fill_constant
(
shape
=
[
num_nodes
,
output_dim
],
value
=
0
,
dtype
=
output
.
dtype
)
shape
=
[
num_nodes
,
output_dim
],
value
=
0
,
dtype
=
output
.
dtype
)
init_output
.
stop_gradient
=
True
init_output
.
stop_gradient
=
True
final_output
=
fluid
.
layers
.
scatter
(
init_output
,
uniq_dst
,
output
)
final_output
=
L
.
scatter
(
init_output
,
uniq_dst
,
output
)
return
final_output
return
final_output
...
@@ -104,6 +105,7 @@ class BaseGraphWrapper(object):
...
@@ -104,6 +105,7 @@ class BaseGraphWrapper(object):
self
.
_node_ids
=
None
self
.
_node_ids
=
None
self
.
_graph_lod
=
None
self
.
_graph_lod
=
None
self
.
_num_graph
=
None
self
.
_num_graph
=
None
self
.
_num_edges
=
None
self
.
_data_name_prefix
=
""
self
.
_data_name_prefix
=
""
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
...
@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
class
GraphWrapper
(
BaseGraphWrapper
):
class
GraphWrapper
(
BaseGraphWrapper
):
"""Implement a graph wrapper that creates a graph data holders
"""Implement a graph wrapper that creates a graph data holders
that attributes and features in the graph are :code:`
fluid.layers
.data`.
that attributes and features in the graph are :code:`
L
.data`.
And we provide interface :code:`to_feed` to help converting :code:`Graph`
And we provide interface :code:`to_feed` to help converting :code:`Graph`
data into :code:`feed_dict`.
data into :code:`feed_dict`.
...
@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper):
def
__create_graph_attr_holders
(
self
):
def
__create_graph_attr_holders
(
self
):
"""Create data holders for graph attributes.
"""Create data holders for graph attributes.
"""
"""
self
.
_num_edges
=
fluid
.
layers
.
data
(
self
.
_num_edges
=
L
.
data
(
self
.
_data_name_prefix
+
'/num_edges'
,
self
.
_data_name_prefix
+
'/num_edges'
,
shape
=
[
1
],
shape
=
[
1
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_num_graph
=
fluid
.
layers
.
data
(
self
.
_num_graph
=
L
.
data
(
self
.
_data_name_prefix
+
'/num_graph'
,
self
.
_data_name_prefix
+
'/num_graph'
,
shape
=
[
1
],
shape
=
[
1
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edges_src
=
fluid
.
layers
.
data
(
self
.
_edges_src
=
L
.
data
(
self
.
_data_name_prefix
+
'/edges_src'
,
self
.
_data_name_prefix
+
'/edges_src'
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edges_dst
=
fluid
.
layers
.
data
(
self
.
_edges_dst
=
L
.
data
(
self
.
_data_name_prefix
+
'/edges_dst'
,
self
.
_data_name_prefix
+
'/edges_dst'
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_num_nodes
=
fluid
.
layers
.
data
(
self
.
_num_nodes
=
L
.
data
(
self
.
_data_name_prefix
+
'/num_nodes'
,
self
.
_data_name_prefix
+
'/num_nodes'
,
shape
=
[
1
],
shape
=
[
1
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
'int64'
,
dtype
=
'int64'
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edge_uniq_dst
=
fluid
.
layers
.
data
(
self
.
_edge_uniq_dst
=
L
.
data
(
self
.
_data_name_prefix
+
"/uniq_dst"
,
self
.
_data_name_prefix
+
"/uniq_dst"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_graph_lod
=
fluid
.
layers
.
data
(
self
.
_graph_lod
=
L
.
data
(
self
.
_data_name_prefix
+
"/graph_lod"
,
self
.
_data_name_prefix
+
"/graph_lod"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int32"
,
dtype
=
"int32"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_edge_uniq_dst_count
=
fluid
.
layers
.
data
(
self
.
_edge_uniq_dst_count
=
L
.
data
(
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
self
.
_data_name_prefix
+
"/uniq_dst_count"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int32"
,
dtype
=
"int32"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_node_ids
=
fluid
.
layers
.
data
(
self
.
_node_ids
=
L
.
data
(
self
.
_data_name_prefix
+
"/node_ids"
,
self
.
_data_name_prefix
+
"/node_ids"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
dtype
=
"int64"
,
dtype
=
"int64"
,
stop_gradient
=
True
)
stop_gradient
=
True
)
self
.
_indegree
=
fluid
.
layers
.
data
(
self
.
_indegree
=
L
.
data
(
self
.
_data_name_prefix
+
"/indegree"
,
self
.
_data_name_prefix
+
"/indegree"
,
shape
=
[
None
],
shape
=
[
None
],
append_batch_size
=
False
,
append_batch_size
=
False
,
...
@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper):
node_feat_dtype
):
node_feat_dtype
):
"""Create data holders for node features.
"""Create data holders for node features.
"""
"""
feat_holder
=
fluid
.
layers
.
data
(
feat_holder
=
L
.
data
(
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
self
.
_data_name_prefix
+
'/node_feat/'
+
node_feat_name
,
shape
=
node_feat_shape
,
shape
=
node_feat_shape
,
append_batch_size
=
False
,
append_batch_size
=
False
,
...
@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat_dtype
):
edge_feat_dtype
):
"""Create edge holders for edge features.
"""Create edge holders for edge features.
"""
"""
feat_holder
=
fluid
.
layers
.
data
(
feat_holder
=
L
.
data
(
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
self
.
_data_name_prefix
+
'/edge_feat/'
+
edge_feat_name
,
shape
=
edge_feat_shape
,
shape
=
edge_feat_shape
,
append_batch_size
=
False
,
append_batch_size
=
False
,
...
@@ -719,3 +721,56 @@ class GraphWrapper(BaseGraphWrapper):
...
@@ -719,3 +721,56 @@ class GraphWrapper(BaseGraphWrapper):
"""Return the holder list.
"""Return the holder list.
"""
"""
return
self
.
_holder_list
return
self
.
_holder_list
def
get_degree
(
edge
,
num_nodes
):
init_output
=
L
.
fill_constant
(
shape
=
[
num_nodes
],
value
=
0
,
dtype
=
"float32"
)
init_output
.
stop_gradient
=
True
final_output
=
L
.
scatter
(
init_output
,
edge
,
L
.
full_like
(
edge
,
1
,
dtype
=
"float32"
),
overwrite
=
False
)
return
final_output
class
DropEdgeWrapper
(
BaseGraphWrapper
):
"""Implement of Edge Drop """
def
__init__
(
self
,
graph_wrapper
,
dropout
):
super
(
DropEdgeWrapper
,
self
).
__init__
()
# Copy Node's information
for
key
,
value
in
graph_wrapper
.
node_feat
.
items
():
self
.
node_feat_tensor_dict
[
key
]
=
value
self
.
_num_nodes
=
graph_wrapper
.
num_nodes
self
.
_graph_lod
=
graph_wrapper
.
graph_lod
self
.
_num_graph
=
graph_wrapper
.
num_graph
self
.
_node_ids
=
L
.
range
(
0
,
self
.
_num_nodes
,
step
=
1
,
dtype
=
"int32"
)
# Dropout Edges
src
,
dst
=
graph_wrapper
.
edges
u
=
L
.
uniform_random
(
shape
=
L
.
cast
(
L
.
shape
(
src
),
'int64'
),
min
=
0.
,
max
=
1.
)
# Avoid Empty Edges
keeped
=
L
.
cast
(
u
>
dropout
,
dtype
=
"float32"
)
self
.
_num_edges
=
L
.
reduce_sum
(
L
.
cast
(
keeped
,
"int32"
))
keeped
=
keeped
+
L
.
cast
(
self
.
_num_edges
==
0
,
dtype
=
"float32"
)
keeped
=
(
keeped
>
0.5
)
src
=
paddle_helper
.
masked_select
(
src
,
keeped
)
dst
=
paddle_helper
.
masked_select
(
dst
,
keeped
)
src
.
stop_gradient
=
True
dst
.
stop_gradient
=
True
self
.
_edges_src
=
src
self
.
_edges_dst
=
dst
for
key
,
value
in
graph_wrapper
.
edge_feat
.
items
():
self
.
edge_feat_tensor_dict
[
key
]
=
paddle_helper
.
masked_select
(
value
,
keeped
)
self
.
_edge_uniq_dst
,
_
,
uniq_count
=
L
.
unique_with_counts
(
dst
,
dtype
=
"int32"
)
self
.
_edge_uniq_dst
.
stop_gradient
=
True
last
=
L
.
reduce_sum
(
uniq_count
,
keep_dim
=
True
)
uniq_count
=
L
.
cumsum
(
uniq_count
,
exclusive
=
True
)
self
.
_edge_uniq_dst_count
=
L
.
concat
([
uniq_count
,
last
])
self
.
_edge_uniq_dst_count
.
stop_gradient
=
True
self
.
_indegree
=
get_degree
(
self
.
_edges_dst
,
self
.
_num_nodes
)
pgl/layers/conv.py
浏览文件 @
3eb6d2a6
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
"""This package implements common layers to help building
"""This package implements common layers to help building
graph neural networks.
graph neural networks.
"""
"""
import
pgl
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
pgl.utils
import
paddle_helper
from
pgl.utils
import
paddle_helper
from
pgl
import
message_passing
from
pgl
import
message_passing
...
@@ -404,7 +405,14 @@ def gen_conv(gw,
...
@@ -404,7 +405,14 @@ def gen_conv(gw,
return
output
return
output
def
appnp
(
gw
,
feature
,
norm
=
None
,
alpha
=
0.2
,
k_hop
=
10
):
def
get_norm
(
indegree
):
"""Get Laplacian Normalization"""
norm
=
fluid
.
layers
.
pow
(
fluid
.
layers
.
cast
(
indegree
,
dtype
=
"float32"
)
+
1e-6
,
factor
=-
0.5
)
norm
=
norm
*
fluid
.
layers
.
cast
(
indegree
>
0
,
dtype
=
"float32"
)
return
norm
def
appnp
(
gw
,
feature
,
edge_dropout
=
0
,
alpha
=
0.2
,
k_hop
=
10
):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
meet Personalized PageRank" (ICLR 2019).
...
@@ -413,8 +421,7 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
...
@@ -413,8 +421,7 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
feature: A tensor with shape (num_nodes, feature_size).
feature: A tensor with shape (num_nodes, feature_size).
norm: If :code:`norm` is not None, then the feature will be normalized. Norm must
edge_dropout: Edge dropout rate.
be tensor with shape (num_nodes,) and dtype float32.
k_hop: K Steps for Propagation
k_hop: K Steps for Propagation
...
@@ -427,16 +434,20 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
...
@@ -427,16 +434,20 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
return
feature
return
feature
h0
=
feature
h0
=
feature
ngw
=
gw
norm
=
get_norm
(
ngw
.
indegree
())
for
i
in
range
(
k_hop
):
for
i
in
range
(
k_hop
):
if
norm
is
not
None
:
if
edge_dropout
>
1e-5
:
ngw
=
pgl
.
sample
.
edge_drop
(
gw
,
edge_dropout
)
norm
=
get_norm
(
ngw
.
indegree
())
feature
=
feature
*
norm
feature
=
feature
*
norm
msg
=
gw
.
send
(
send_src_copy
,
nfeat_list
=
[(
"h"
,
feature
)])
msg
=
gw
.
send
(
send_src_copy
,
nfeat_list
=
[(
"h"
,
feature
)])
feature
=
gw
.
recv
(
msg
,
"sum"
)
feature
=
gw
.
recv
(
msg
,
"sum"
)
if
norm
is
not
None
:
feature
=
feature
*
norm
feature
=
feature
*
norm
feature
=
feature
*
(
1
-
alpha
)
+
h0
*
alpha
feature
=
feature
*
(
1
-
alpha
)
+
h0
*
alpha
...
...
pgl/sample.py
浏览文件 @
3eb6d2a6
...
@@ -516,3 +516,10 @@ def graph_saint_random_walk_sample(graph,
...
@@ -516,3 +516,10 @@ def graph_saint_random_walk_sample(graph,
nodes
=
sample_nodes
,
eid
=
eids
,
with_node_feat
=
True
,
with_edge_feat
=
True
)
nodes
=
sample_nodes
,
eid
=
eids
,
with_node_feat
=
True
,
with_edge_feat
=
True
)
subgraph
.
node_feat
[
"index"
]
=
np
.
array
(
sample_nodes
,
dtype
=
"int64"
)
subgraph
.
node_feat
[
"index"
]
=
np
.
array
(
sample_nodes
,
dtype
=
"int64"
)
return
subgraph
return
subgraph
def
edge_drop
(
graph_wrapper
,
dropout_rate
):
if
dropout_rate
<
1e-5
:
return
graph_wrapper
else
:
return
pgl
.
graph_wrapper
.
DropEdgeWrapper
(
graph_wrapper
,
dropout_rate
)
pgl/utils/paddle_helper.py
浏览文件 @
3eb6d2a6
...
@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
...
@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'max'
)
output
=
fluid
.
layers
.
scatter
(
input
,
index
,
updates
,
mode
=
'max'
)
return
output
return
output
def
masked_select
(
input
,
mask
):
"""masked_select
Slice the value from given Mask
Args:
input: Input tensor to be selected
mask: A bool tensor for sliced.
Return:
Part of inputs where mask is True.
"""
index
=
fluid
.
layers
.
where
(
mask
)
return
fluid
.
layers
.
gather
(
input
,
index
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录