PaddlePaddle / PGL

Commit b809f488
Authored Jul 30, 2020 by Yelrose

reproduce paper results

Parent: 3eb6d2a6

Showing 9 changed files with 127 additions and 172 deletions
examples/citation_benchmark/build_model.py     +6   -5
examples/citation_benchmark/config/appnp.yaml  +1   -1
examples/citation_benchmark/config/sgc.yaml    +1   -0
examples/citation_benchmark/model.py           +6   -9
examples/citation_benchmark/optimization.py    +0   -55
examples/citation_benchmark/train.py           +47  -44
pgl/graph_wrapper.py                           +6   -1
pgl/layers/conv.py                             +56  -55
pgl/sample.py                                  +4   -2
examples/citation_benchmark/build_model.py — View file @ b809f488

@@ -4,7 +4,6 @@ from pgl import data_loader
 import paddle.fluid as fluid
 import numpy as np
 import time
-from optimization import AdamW

 def build_model(dataset, config, phase, main_prog):
     gw = pgl.graph_wrapper.GraphWrapper(
@@ -15,6 +14,7 @@ def build_model(dataset, config, phase, main_prog):
     m = GraphModel(config=config, num_class=dataset.num_classes)
     logits = m.forward(gw, gw.node_feat["words"], phase)

+    # Take the last
     node_index = fluid.layers.data("node_index", shape=[None, 1],
@@ -33,10 +33,11 @@ def build_model(dataset, config, phase, main_prog):
     loss = fluid.layers.mean(loss)

     if phase == "train":
-        AdamW(loss=loss,
-              learning_rate=config.learning_rate,
-              weight_decay=config.weight_decay,
-              train_program=main_prog)
+        adam = fluid.optimizer.Adam(
+            learning_rate=config.learning_rate,
+            regularization=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=config.weight_decay))
+        adam.minimize(loss)

     return gw, loss, acc
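The optimizer change above replaces the example's hand-rolled AdamW (decoupled weight decay, implemented in the optimization.py file deleted further down) with fluid.optimizer.Adam plus an L2DecayRegularizer. The two are not mathematically identical: L2 regularization folds the decay term into the gradient before Adam's moment estimates, whereas the old helper subtracted lr * weight_decay * param after the Adam step. A minimal NumPy sketch of the difference, with illustrative values only (not PGL code):

# Contrast the two weight-decay styles touched by this commit on a single
# Adam step; the parameter and gradient values below are made up.
import numpy as np

def adam_step(param, grad, m, v, lr=0.01, b1=0.9, b2=0.999, eps=1e-8, t=1):
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

param = np.array([1.0, -2.0])
grad = np.array([0.1, 0.3])
wd, lr = 5e-4, 0.01

# Style 1: L2 regularization folded into the gradient (fluid's L2DecayRegularizer).
p1, _, _ = adam_step(param, grad + wd * param, np.zeros(2), np.zeros(2), lr=lr)

# Style 2: decoupled weight decay, as in the removed AdamW helper:
# run Adam on the raw gradient, then subtract lr * wd * (pre-update parameter).
p2, _, _ = adam_step(param, grad, np.zeros(2), np.zeros(2), lr=lr)
p2 = p2 - lr * wd * param

print(p1, p2)  # the results differ because Adam rescales the L2 term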
examples/citation_benchmark/config/appnp.yaml — View file @ b809f488

@@ -6,4 +6,4 @@ learning_rate: 0.01
 dropout: 0.5
 hidden_size: 64
 weight_decay: 0.0005
-edge_dropout: 0.00
+edge_dropout: 0.0
examples/citation_benchmark/config/sgc.yaml — View file @ b809f488

@@ -2,3 +2,4 @@ model_name: SGC
 num_layers: 2
 learning_rate: 0.2
 weight_decay: 0.000005
+feature_pre_normalize: False
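SGC is the only config that turns feature pre-normalization off; the other models rely on the row-normalization helper this commit adds to train.py. Whether train.py gates the call on this flag is not visible in the hunks shown here, so treat that wiring as an assumption. The helper itself is simple: each node's bag-of-words vector is divided by its row sum, e.g.:

# Row-normalize a bag-of-words feature matrix, mirroring the normalize()
# helper added to train.py in this commit; the matrix below is a toy example.
import numpy as np

def normalize(feat):
    return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)

words = np.array([[1., 0., 3.],
                  [0., 0., 0.]])   # an all-zero row is left unchanged
print(normalize(words))            # rows sum to 1 (or stay 0)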
examples/citation_benchmark/model.py — View file @ b809f488

@@ -3,8 +3,9 @@ import paddle.fluid.layers as L
 import pgl.layers.conv as conv

 def get_norm(indegree):
-    norm = L.pow(L.cast(indegree, dtype="float32") + 1e-6, factor=-0.5)
-    norm = norm * L.cast(indegree > 0, dtype="float32")
+    float_degree = L.cast(indegree, dtype="float32")
+    float_degree = L.clamp(float_degree, min=1.0)
+    norm = L.pow(float_degree, factor=-0.5)
     return norm
@@ -29,10 +30,6 @@ class GCN(object):
                 ngw = graph_wrapper
                 norm = graph_wrapper.node_feat["norm"]

-            feature = L.dropout(
-                feature,
-                self.dropout,
-                dropout_implementation='upscale_in_train')

             feature = pgl.layers.gcn(ngw,
                 feature,
@@ -41,7 +38,7 @@ class GCN(object):
                 norm=norm,
                 name="layer_%s" % i)

-        feature = L.dropout(
+            feature = L.dropout(
                 feature,
                 self.dropout,
                 dropout_implementation='upscale_in_train')
@@ -150,10 +147,10 @@ class SGC(object):
     def forward(self, graph_wrapper, feature, phase):
         feature = conv.appnp(graph_wrapper,
             feature=feature,
             norm=graph_wrapper.node_feat["norm"],
             edge_dropout=0,
             alpha=0,
             k_hop=self.num_layers)
         feature.stop_gradient = True
-        feature = L.fc(feature, self.num_class, act=None, name="output")
+        feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
         return feature
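The new get_norm clamps the degree at 1 instead of adding 1e-6 and masking zero-degree nodes, so the value changes only for isolated nodes (0 before, 1 after). A NumPy sketch of the two variants on a toy degree vector (illustration, not the fluid code):

# Compare the old and new degree normalization from get_norm(); they differ
# only where indegree == 0.
import numpy as np

indegree = np.array([0, 1, 4], dtype="float32")

old_norm = np.power(indegree + 1e-6, -0.5) * (indegree > 0)   # [0. , 1. , 0.5]
new_norm = np.power(np.maximum(indegree, 1.0), -0.5)          # [1. , 1. , 0.5]

print(old_norm, new_norm)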
examples/citation_benchmark/optimization.py — deleted (100644 → 0) — View file @ 3eb6d2a6

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid


def AdamW(loss, learning_rate, train_program, weight_decay):
    optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)

    def exclude_from_weight_decay(name):
        if name.find("layer_norm") > -1:
            return True
        bias_suffix = ["_bias", "_b", ".b_0"]
        for suffix in bias_suffix:
            if name.endswith(suffix):
                return True
        return False

    param_list = dict()

    for param in train_program.global_block().all_parameters():
        param_list[param.name] = param * 1.0
        param_list[param.name].stop_gradient = True

    _, param_grads = optimizer.minimize(loss)

    if weight_decay > 0:
        for param, grad in param_grads:
            if exclude_from_weight_decay(param.name):
                continue
            with param.block.program._optimized_guard(
                [param, grad]), fluid.framework.name_scope("weight_decay"):
                updated_param = param - param_list[
                    param.name] * weight_decay * learning_rate
                fluid.layers.assign(output=param, input=updated_param)
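The deleted helper skipped weight decay for LayerNorm and bias parameters by name. Those exclusion rules are plain string checks and can be exercised on their own; the parameter names below are hypothetical, chosen only to illustrate each rule:

# Reproduce the name-based exclusion rules from the deleted AdamW helper with
# plain strings; the parameter names are made up for illustration.
def exclude_from_weight_decay(name):
    if name.find("layer_norm") > -1:
        return True
    bias_suffix = ["_bias", "_b", ".b_0"]
    for suffix in bias_suffix:
        if name.endswith(suffix):
            return True
    return False

for name in ["layer_0_weight", "layer_0_bias", "gin_layer_norm_scale", "fc_0.b_0"]:
    print(name, exclude_from_weight_decay(name))
# layer_0_weight -> False; the other three are excluded from weight decay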
examples/citation_benchmark/train.py — View file @ b809f488

@@ -22,33 +22,35 @@ import argparse
 from build_model import build_model
 import yaml
 from easydict import EasyDict as edict
+import tqdm
+
+
+def normalize(feat):
+    return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)
+

-def load(name):
+def load(name, normalized_feature=True):
     if name == 'cora':
         dataset = data_loader.CoraDataset()
     elif name == "pubmed":
-        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=False)
+        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True)
     elif name == "citeseer":
-        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=False)
+        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True)
     else:
         raise ValueError(name + " dataset doesn't exists")
-    return dataset
-
-
-def main(args, config):
-    dataset = load(args.dataset)
     indegree = dataset.graph.indegree()
-    norm = np.zeros_like(indegree, dtype="float32")
-    norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
+    norm = np.maximum(indegree.astype("float32"), 1)
+    norm = np.power(norm, -0.5)
     dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)
+    dataset.graph.node_feat["words"] = normalize(dataset.graph.node_feat["words"])
+    return dataset
+
+
+def main(args, config):
+    dataset = load(args.dataset, args.feature_pre_normalize)
     place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
     train_program = fluid.default_main_program()
     startup_program = fluid.default_startup_program()
     with fluid.program_guard(train_program, startup_program):
         with fluid.unique_name.guard():
             gw, loss, acc = build_model(dataset,
@@ -67,13 +69,11 @@ def main(args, config):
     test_program = test_program.clone(for_test=True)

     exe = fluid.Executor(place)
-    exe.run(startup_program)

     train_index = dataset.train_index
     train_label = np.expand_dims(dataset.y[train_index], -1)
     train_index = np.expand_dims(train_index, -1)
-    log.info("Number of Train %s" % len(train_index))

     val_index = dataset.val_index
     val_label = np.expand_dims(dataset.y[val_index], -1)
@@ -84,54 +84,55 @@ def main(args, config):
     test_index = np.expand_dims(test_index, -1)

-    dur = []
-    cal_val_acc = []
-    cal_test_acc = []
     # Feed data
     feed_dict = gw.to_feed(dataset.graph)
-    for epoch in range(args.epoch):
-        if epoch >= 3:
-            t0 = time.time()
-        feed_dict = gw.to_feed(dataset.graph)
-        feed_dict["node_index"] = np.array(train_index, dtype="int64")
-        feed_dict["node_label"] = np.array(train_label, dtype="int64")
-        train_loss, train_acc = exe.run(train_program,
+
+    best_test = []
+
+    for run in range(args.runs):
+        exe.run(startup_program)
+        cal_val_acc = []
+        cal_test_acc = []
+        cal_val_loss = []
+        cal_test_loss = []
+        for epoch in tqdm.tqdm(range(args.epoch)):
+            feed_dict["node_index"] = np.array(train_index, dtype="int64")
+            feed_dict["node_label"] = np.array(train_label, dtype="int64")
+            train_loss, train_acc = exe.run(train_program,
                                             feed=feed_dict,
                                             fetch_list=[loss, acc],
                                             return_numpy=True)
-        if epoch >= 3:
-            time_per_epoch = 1.0 * (time.time() - t0)
-            dur.append(time_per_epoch)
-        feed_dict = gw.to_feed(dataset.graph)
-        feed_dict["node_index"] = np.array(val_index, dtype="int64")
-        feed_dict["node_label"] = np.array(val_label, dtype="int64")
-        val_loss, val_acc = exe.run(test_program,
+
+            feed_dict["node_index"] = np.array(val_index, dtype="int64")
+            feed_dict["node_label"] = np.array(val_label, dtype="int64")
+            val_loss, val_acc = exe.run(test_program,
                                         feed=feed_dict,
                                         fetch_list=[v_loss, v_acc],
                                         return_numpy=True)
-        val_loss = val_loss[0]
-        val_acc = val_acc[0]
-        cal_val_acc.append(val_acc)
+            cal_val_acc.append(val_acc[0])
+            cal_val_loss.append(val_loss[0])

-        feed_dict["node_index"] = np.array(test_index, dtype="int64")
-        feed_dict["node_label"] = np.array(test_label, dtype="int64")
-        test_loss, test_acc = exe.run(test_program,
+            feed_dict["node_index"] = np.array(test_index, dtype="int64")
+            feed_dict["node_label"] = np.array(test_label, dtype="int64")
+            test_loss, test_acc = exe.run(test_program,
                                           feed=feed_dict,
                                           fetch_list=[v_loss, v_acc],
                                           return_numpy=True)
-        test_loss = test_loss[0]
-        test_acc = test_acc[0]
-        cal_test_acc.append(test_acc)
+            cal_test_acc.append(test_acc[0])
+            cal_test_loss.append(test_loss[0])

-        log.info("Epoch %d " % epoch + "Train Loss: %f " % train_loss
-                 + "Train Acc: %f " % train_acc + "Val Loss: %f " % val_loss
-                 + "Val Acc: %f " % val_acc)
-
-    cal_val_acc = np.array(cal_val_acc)
-    log.info("Model: %s Best Test Accuracy: %f" %
-             (config.model_name, cal_test_acc[np.argmax(cal_val_acc)]))
+        log.info("Runs %s: Model: %s Best Test Accuracy: %f" %
+                 (run, config.model_name, cal_test_acc[np.argmin(cal_val_loss)]))
+        best_test.append(cal_test_acc[np.argmin(cal_val_loss)])
+
+    log.info("Best Test Accuracy: %f ( stddev: %f )" %
+             (np.mean(best_test), np.std(best_test)))

 if __name__ == '__main__':
@@ -141,6 +142,8 @@ if __name__ == '__main__':
     parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
     parser.add_argument("--conf", type=str, help="config file for models")
     parser.add_argument("--epoch", type=int, default=200, help="Epoch")
+    parser.add_argument("--runs", type=int, default=5, help="runs")
+    parser.add_argument("--feature_pre_normalize", type=bool, default=True, help="pre_normalize feature")
     args = parser.parse_args()
     config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
     log.info(args)
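The training loop now repeats the experiment --runs times, reinitializes parameters at the start of each run, picks the test accuracy at the epoch with the lowest validation loss, and reports the mean and standard deviation over runs. A NumPy sketch of that selection and aggregation logic; the per-epoch curves below are fabricated stand-ins, not real results:

# Model selection and aggregation as in the new training loop: per run, take
# the test accuracy at the epoch of minimum validation loss, then report the
# mean and stddev over runs. Loss/accuracy values here are fabricated.
import numpy as np

rng = np.random.default_rng(0)
best_test = []
for run in range(5):
    cal_val_loss = rng.uniform(0.5, 1.5, size=200)   # stand-in validation losses
    cal_test_acc = rng.uniform(0.75, 0.85, size=200)  # stand-in test accuracies
    best_test.append(cal_test_acc[np.argmin(cal_val_loss)])

print("Best Test Accuracy: %f ( stddev: %f )" % (np.mean(best_test), np.std(best_test)))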
pgl/graph_wrapper.py — View file @ b809f488

@@ -735,7 +735,7 @@ def get_degree(edge, num_nodes):
 class DropEdgeWrapper(BaseGraphWrapper):
     """Implement of Edge Drop """
-    def __init__(self, graph_wrapper, dropout):
+    def __init__(self, graph_wrapper, dropout, keep_self_loop=True):
         super(DropEdgeWrapper, self).__init__()

         # Copy Node's information
@@ -750,12 +750,17 @@ class DropEdgeWrapper(BaseGraphWrapper):
         # Dropout Edges
         src, dst = graph_wrapper.edges
         u = L.uniform_random(
             shape=L.cast(L.shape(src), 'int64'), min=0., max=1.)

         # Avoid Empty Edges
         keeped = L.cast(u > dropout, dtype="float32")
         self._num_edges = L.reduce_sum(L.cast(keeped, "int32"))
         keeped = keeped + L.cast(self._num_edges == 0, dtype="float32")
+
+        if keep_self_loop:
+            self_loop = L.cast(src == dst, dtype="float32")
+            keeped = keeped + self_loop
+
         keeped = (keeped > 0.5)
         src = paddle_helper.masked_select(src, keeped)
         dst = paddle_helper.masked_select(dst, keeped)
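Setting keep_self_loop forces self-loop edges to survive the random edge dropout: the keep mask is the union of the random keep decisions and the src == dst test, plus a fallback that keeps everything if the sample would otherwise be empty. A NumPy sketch of the same masking logic on a toy edge list:

# Toy NumPy version of DropEdgeWrapper's keep mask: drop each edge with
# probability `dropout`, but always keep self loops when keep_self_loop=True.
import numpy as np

src = np.array([0, 1, 2, 2])
dst = np.array([1, 2, 0, 2])   # the last edge (2, 2) is a self loop
dropout, keep_self_loop = 0.5, True

u = np.random.uniform(0., 1., size=src.shape)
keeped = (u > dropout).astype("float32")
keeped = keeped + float(keeped.sum() == 0)           # avoid an empty edge set
if keep_self_loop:
    keeped = keeped + (src == dst).astype("float32")
keeped = keeped > 0.5

print(src[keeped], dst[keeped])   # the self loop (2, 2) is always kept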
pgl/layers/conv.py — View file @ b809f488

@@ -16,6 +16,7 @@ graph neural networks.
 """
 import pgl
 import paddle.fluid as fluid
+import paddle.fluid.layers as L
 from pgl.utils import paddle_helper
 from pgl import message_passing
@@ -51,7 +52,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
     size = feature.shape[-1]
     if size > hidden_size:
-        feature = fluid.layers.fc(feature,
+        feature = L.fc(feature,
                        size=hidden_size,
                        bias_attr=False,
                        param_attr=fluid.ParamAttr(name=name))
@@ -65,7 +66,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
         output = gw.recv(msg, "sum")
     else:
         output = gw.recv(msg, "sum")
-        output = fluid.layers.fc(output,
+        output = L.fc(output,
                        size=hidden_size,
                        bias_attr=False,
                        param_attr=fluid.ParamAttr(name=name))
@@ -73,12 +74,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
     if norm is not None:
         output = output * norm
-    bias = fluid.layers.create_parameter(
+    bias = L.create_parameter(
         shape=[hidden_size],
         dtype='float32',
         is_bias=True,
         name=name + '_bias')
-    output = fluid.layers.elementwise_add(output, bias, act=activation)
+    output = L.elementwise_add(output, bias, act=activation)
     return output
@@ -121,7 +122,7 @@ def gat(gw,
     def send_attention(src_feat, dst_feat, edge_feat):
         output = src_feat["left_a"] + dst_feat["right_a"]
-        output = fluid.layers.leaky_relu(
+        output = L.leaky_relu(
             output, alpha=0.2)  # (num_edges, num_heads)
         return {"alpha": output, "h": src_feat["h"]}
@@ -130,54 +131,54 @@ def gat(gw,
         h = msg["h"]
         alpha = paddle_helper.sequence_softmax(alpha)
         old_h = h
-        h = fluid.layers.reshape(h, [-1, num_heads, hidden_size])
-        alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1])
+        h = L.reshape(h, [-1, num_heads, hidden_size])
+        alpha = L.reshape(alpha, [-1, num_heads, 1])
         if attn_drop > 1e-15:
-            alpha = fluid.layers.dropout(
+            alpha = L.dropout(
                 alpha,
                 dropout_prob=attn_drop,
                 is_test=is_test,
                 dropout_implementation="upscale_in_train")
         h = h * alpha
-        h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
-        h = fluid.layers.lod_reset(h, old_h)
-        return fluid.layers.sequence_pool(h, "sum")
+        h = L.reshape(h, [-1, num_heads * hidden_size])
+        h = L.lod_reset(h, old_h)
+        return L.sequence_pool(h, "sum")

     if feat_drop > 1e-15:
-        feature = fluid.layers.dropout(
+        feature = L.dropout(
             feature,
             dropout_prob=feat_drop,
             is_test=is_test,
             dropout_implementation='upscale_in_train')

-    ft = fluid.layers.fc(feature,
+    ft = L.fc(feature,
                          hidden_size * num_heads,
                          bias_attr=False,
                          param_attr=fluid.ParamAttr(name=name + '_weight'))
-    left_a = fluid.layers.create_parameter(
+    left_a = L.create_parameter(
         shape=[num_heads, hidden_size],
         dtype='float32',
         name=name + '_gat_l_A')
-    right_a = fluid.layers.create_parameter(
+    right_a = L.create_parameter(
         shape=[num_heads, hidden_size],
         dtype='float32',
         name=name + '_gat_r_A')
-    reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size])
-    left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1)
-    right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1)
+    reshape_ft = L.reshape(ft, [-1, num_heads, hidden_size])
+    left_a_value = L.reduce_sum(reshape_ft * left_a, -1)
+    right_a_value = L.reduce_sum(reshape_ft * right_a, -1)

     msg = gw.send(
         send_attention,
         nfeat_list=[("h", ft), ("left_a", left_a_value),
                     ("right_a", right_a_value)])
     output = gw.recv(msg, reduce_attention)

-    bias = fluid.layers.create_parameter(
+    bias = L.create_parameter(
         shape=[hidden_size * num_heads],
         dtype='float32',
         is_bias=True,
         name=name + '_bias')
     bias.stop_gradient = True
-    output = fluid.layers.elementwise_add(output, bias, act=activation)
+    output = L.elementwise_add(output, bias, act=activation)
     return output
@@ -220,7 +221,7 @@ def gin(gw,
     def send_src_copy(src_feat, dst_feat, edge_feat):
         return src_feat["h"]

-    epsilon = fluid.layers.create_parameter(
+    epsilon = L.create_parameter(
         shape=[1, 1],
         dtype="float32",
         attr=fluid.ParamAttr(name="%s_eps" % name),
@@ -233,13 +234,13 @@ def gin(gw,
     msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
     output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                              size=hidden_size,
                              act=None,
                              param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                              bias_attr=fluid.ParamAttr(name="%s_b_0" % name))

-    output = fluid.layers.layer_norm(
+    output = L.layer_norm(
         output,
         begin_norm_axis=1,
         param_attr=fluid.ParamAttr(
@@ -250,9 +251,9 @@ def gin(gw,
             initializer=fluid.initializer.Constant(0.0)), )

     if activation is not None:
-        output = getattr(fluid.layers, activation)(output)
+        output = getattr(L, activation)(output)

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                              size=hidden_size,
                              act=activation,
                              param_attr=fluid.ParamAttr(name="%s_w_1" % name),
@@ -270,10 +271,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
         # E * M * D1
         old = feat_query
-        feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
-        feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
+        feat_query = L.reshape(feat_query, [-1, heads, hidden_size_a])
+        feat_key = L.reshape(feat_key, [-1, heads, hidden_size_a])
         # E * M
-        alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
+        alpha = L.reduce_sum(feat_key * feat_query, dim=-1)

         return {'dst_node_feat': dst_feat['node_feat'],
                 'src_node_feat': src_feat['node_feat'],
@@ -287,15 +288,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         # features of each edge's source node
         src_feat = message['src_node_feat']
         # each center node's own features
-        x = fluid.layers.sequence_pool(dst_feat, 'average')
+        x = L.sequence_pool(dst_feat, 'average')
         # mean of the neighbors' features for each center node
-        z = fluid.layers.sequence_pool(src_feat, 'average')
+        z = L.sequence_pool(src_feat, 'average')

         # compute the gate
         feat_gate = message['feat_gate']
-        g_max = fluid.layers.sequence_pool(feat_gate, 'max')
-        g = fluid.layers.concat([x, g_max, z], axis=1)
-        g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
+        g_max = L.sequence_pool(feat_gate, 'max')
+        g = L.concat([x, g_max, z], axis=1)
+        g = L.fc(g, heads, bias_attr=False, act="sigmoid")

         # softmax
         alpha = message['alpha']
@@ -303,19 +304,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
         feat_value = message['feat_value']
         # E * (M * D2)
         old = feat_value
-        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
-        # E * M * D2
-        feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
-        feat_value = fluid.layers.reshape(feat_value, [-1, heads * hidden_size_v])
-        # E * (M * D2)
-        feat_value = fluid.layers.lod_reset(feat_value, old)
+        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])
+        # E * M * D2
+        feat_value = L.elementwise_mul(feat_value, alpha, axis=0)
+        feat_value = L.reshape(feat_value, [-1, heads * hidden_size_v])
+        # E * (M * D2)
+        feat_value = L.lod_reset(feat_value, old)

-        feat_value = fluid.layers.sequence_pool(feat_value, 'sum')  # N * (M * D2)
+        feat_value = L.sequence_pool(feat_value, 'sum')  # N * (M * D2)

-        feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2
+        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2

-        output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
-        output = fluid.layers.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)
+        output = L.elementwise_mul(feat_value, g, axis=0)
+        output = L.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)

-        output = fluid.layers.concat([x, output], axis=1)
+        output = L.concat([x, output], axis=1)
         return output
@@ -324,16 +325,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
     # project what each node needs to send out
     # projected feature vectors
     # N * (D1 * M)
-    feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
+    feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_key'))
     # N * (D2 * M)
-    feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
+    feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_value'))
     # N * (D1 * M)
-    feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
+    feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_query'))
     # N * Dm
-    feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
+    feat_gate = L.fc(feature, hidden_size_m, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_gate'))

     # send stage
@@ -347,10 +348,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
     # aggregate neighbor features
     output = gw.recv(message, recv_func)
-    output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
+    output = L.fc(output, hidden_size_o, bias_attr=False,
                    param_attr=fluid.ParamAttr(name=name + '_project_output'))
-    output = fluid.layers.leaky_relu(output, alpha=0.1)
-    output = fluid.layers.dropout(output, dropout_prob=0.1)
+    output = L.leaky_relu(output, alpha=0.1)
+    output = L.dropout(output, dropout_prob=0.1)

     return output
@@ -377,7 +378,7 @@ def gen_conv(gw,
     """
     if beta == "dynamic":
-        beta = fluid.layers.create_parameter(
+        beta = L.create_parameter(
             shape=[1],
             dtype='float32',
             default_initializer=
@@ -392,13 +393,13 @@ def gen_conv(gw,
     output = message_passing.msg_norm(feature, output, name)
     output = feature + output

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                              feature.shape[-1],
                              bias_attr=False,
                              act="relu",
                              param_attr=fluid.ParamAttr(name=name + '_weight1'))

-    output = fluid.layers.fc(output,
+    output = L.fc(output,
                              feature.shape[-1],
                              bias_attr=False,
                              param_attr=fluid.ParamAttr(name=name + '_weight2'))
@@ -407,9 +408,9 @@ def gen_conv(gw,
 def get_norm(indegree):
     """Get Laplacian Normalization"""
-    norm = fluid.layers.pow(
-        fluid.layers.cast(indegree, dtype="float32") + 1e-6, factor=-0.5)
-    norm = norm * fluid.layers.cast(indegree > 0, dtype="float32")
+    float_degree = L.cast(indegree, dtype="float32")
+    float_degree = L.clamp(float_degree, min=1.0)
+    norm = L.pow(float_degree, factor=-0.5)
     return norm


 def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
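Apart from aliasing fluid.layers as L, conv.py's get_norm now uses the same clamp-based degree normalization as the example's model.py, and the appnp signature shown above (edge_dropout, alpha, k_hop) corresponds to the usual APPNP propagation rule Z ← (1 − α) · Â · Z + α · H repeated k_hop times. A dense NumPy sketch of that rule on a toy graph (illustration only, not the message-passing implementation):

# Dense NumPy sketch of APPNP-style propagation with symmetric normalization:
# z <- (1 - alpha) * A_hat @ z + alpha * h0, repeated k_hop times.
# The adjacency matrix and features below are toy values.
import numpy as np

adj = np.array([[0, 1, 1],
                [1, 0, 0],
                [1, 0, 0]], dtype="float32")
h0 = np.array([[1., 0.], [0., 1.], [1., 1.]])

deg = np.maximum(adj.sum(-1), 1.0)                  # clamp degree at 1, as in get_norm
d_inv_sqrt = np.power(deg, -0.5)
a_hat = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]

alpha, k_hop = 0.2, 10
z = h0
for _ in range(k_hop):
    z = (1 - alpha) * (a_hat @ z) + alpha * h0
print(z)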
pgl/sample.py — View file @ b809f488

@@ -518,8 +518,10 @@ def graph_saint_random_walk_sample(graph,
     return subgraph


-def edge_drop(graph_wrapper, dropout_rate):
+def edge_drop(graph_wrapper, dropout_rate, keep_self_loop=True):
     if dropout_rate < 1e-5:
         return graph_wrapper
     else:
-        return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper, dropout_rate)
+        return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper, dropout_rate,
+                                                 keep_self_loop)