Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
9913672a
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9913672a
编写于
9月 25, 2020
作者:
S
sys1874
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unimp_large
上级
8be7e76a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
588 addition
and
0 deletion
+588
-0
ogb_examples/nodeproppred/unimp/main_arxiv_large.py
ogb_examples/nodeproppred/unimp/main_arxiv_large.py
+196
-0
ogb_examples/nodeproppred/unimp/model_large.py
ogb_examples/nodeproppred/unimp/model_large.py
+147
-0
ogb_examples/nodeproppred/unimp/module/model_unimp_large.py
ogb_examples/nodeproppred/unimp/module/model_unimp_large.py
+245
-0
未找到文件。
ogb_examples/nodeproppred/unimp/main_arxiv_large.py
0 → 100644
浏览文件 @
9913672a
import
math
import
torch
import
paddle
import
pgl
import
numpy
as
np
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
from
pgl.contrib.ogb.nodeproppred.dataset_pgl
import
PglNodePropPredDataset
from
ogb.nodeproppred
import
Evaluator
from
utils
import
to_undirected
,
add_self_loop
,
linear_warmup_decay
from
model_large
import
Arxiv_baseline_model
,
Arxiv_label_embedding_model
from
optimization
import
optimization
import
argparse
from
tqdm
import
tqdm
evaluator
=
Evaluator
(
name
=
'ogbn-arxiv'
)
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
## model_base_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
80
,
type
=
int
)
model_group
.
add_argument
(
'--num_heads'
,
default
=
5
,
type
=
int
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0.1
,
type
=
float
)
## embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.65
,
type
=
float
)
## train_arg
train_group
=
parser
.
add_argument_group
(
'train_arg'
)
train_group
.
add_argument
(
'--runs'
,
default
=
10
,
type
=
int
)
train_group
.
add_argument
(
'--epochs'
,
default
=
2000
,
type
=
int
)
train_group
.
add_argument
(
'--lr'
,
default
=
0.001
,
type
=
float
)
train_group
.
add_argument
(
'--place'
,
default
=-
1
,
type
=
int
)
train_group
.
add_argument
(
'--log_file'
,
default
=
'result_arxiv.txt'
,
type
=
str
)
return
parser
.
parse_args
()
def
optimizer_func
(
lr
=
0.01
):
return
F
.
optimizer
.
AdamOptimizer
(
learning_rate
=
lr
,
regularization
=
F
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.0005
))
def
eval_test
(
parser
,
program
,
model
,
test_exe
,
graph
,
y_true
,
split_idx
):
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
if
parser
.
use_label_e
:
feed_dict
[
'label'
]
=
y_true
feed_dict
[
'label_idx'
]
=
split_idx
[
'train'
]
feed_dict
[
'attn_drop'
]
=-
1
avg_cost_np
=
test_exe
.
run
(
program
=
program
,
feed
=
feed_dict
,
fetch_list
=
[
model
.
out_feat
])
y_pred
=
avg_cost_np
[
0
].
argmax
(
axis
=-
1
)
y_pred
=
np
.
expand_dims
(
y_pred
,
1
)
train_acc
=
evaluator
.
eval
({
'y_true'
:
y_true
[
split_idx
[
'train'
]],
'y_pred'
:
y_pred
[
split_idx
[
'train'
]],
})[
'acc'
]
val_acc
=
evaluator
.
eval
({
'y_true'
:
y_true
[
split_idx
[
'valid'
]],
'y_pred'
:
y_pred
[
split_idx
[
'valid'
]],
})[
'acc'
]
test_acc
=
evaluator
.
eval
({
'y_true'
:
y_true
[
split_idx
[
'test'
]],
'y_pred'
:
y_pred
[
split_idx
[
'test'
]],
})[
'acc'
]
return
train_acc
,
val_acc
,
test_acc
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
exe
.
run
(
start_program
)
max_acc
=
0
max_step
=
0
max_val_acc
=
0
max_cor_acc
=
0
max_cor_step
=
0
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
train_idx_temp
=
split_idx
[
'train'
]
np
.
random
.
shuffle
(
train_idx_temp
)
label_idx
=
train_idx_temp
[
:
int
(
parser
.
label_rate
*
len
(
train_idx_temp
))]
unlabel_idx
=
train_idx_temp
[
int
(
parser
.
label_rate
*
len
(
train_idx_temp
)):
]
feed_dict
[
'label'
]
=
label
feed_dict
[
'label_idx'
]
=
label_idx
feed_dict
[
'train_idx'
]
=
unlabel_idx
feed_dict
[
'attn_drop'
]
=
parser
.
attn_dropout
else
:
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
feed_dict
[
'label'
]
=
label
feed_dict
[
'train_idx'
]
=
split_idx
[
'train'
]
loss
=
exe
.
run
(
main_program
,
feed
=
feed_dict
,
fetch_list
=
[
model
.
avg_cost
])
loss
=
loss
[
0
]
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
max_val_acc
=
max
(
valid_acc
,
max_val_acc
)
if
max_val_acc
==
valid_acc
:
max_cor_acc
=
test_acc
max_cor_step
=
epoch_id
if
max_acc
==
result
[
2
]:
max_step
=
epoch_id
result_t
=
(
f
'Run:
{
run_id
:
02
d
}
, '
f
'Epoch:
{
epoch_id
:
02
d
}
, '
f
'Loss:
{
loss
[
0
]:.
4
f
}
, '
f
'Train:
{
100
*
train_acc
:.
2
f
}
%, '
f
'Valid:
{
100
*
valid_acc
:.
2
f
}
%, '
f
'Test:
{
100
*
test_acc
:.
2
f
}
%
\n
'
f
'max_val:
{
100
*
max_val_acc
:.
2
f
}
%, '
f
'max_val_Test:
{
100
*
max_cor_acc
:.
2
f
}
%, '
f
'max_val_step:
{
max_cor_step
}
\n
'
)
if
(
epoch_id
+
1
)
%
100
==
0
:
print
(
result_t
)
wf
.
write
(
result_t
)
wf
.
write
(
'
\n
'
)
wf
.
flush
()
return
max_cor_acc
if
__name__
==
'__main__'
:
parser
=
get_config
()
print
(
'===========args=============='
)
print
(
parser
)
print
(
'============================='
)
startup_prog
=
F
.
default_startup_program
()
train_prog
=
F
.
default_main_program
()
place
=
F
.
CPUPlace
()
if
parser
.
place
<
0
else
F
.
CUDAPlace
(
parser
.
place
)
dataset
=
PglNodePropPredDataset
(
name
=
"ogbn-arxiv"
)
split_idx
=
dataset
.
get_idx_split
()
graph
,
label
=
dataset
[
0
]
print
(
label
.
shape
)
graph
=
to_undirected
(
graph
)
graph
=
add_self_loop
(
graph
)
with
F
.
unique_name
.
guard
():
with
F
.
program_guard
(
train_prog
,
startup_prog
):
gw
=
pgl
.
graph_wrapper
.
GraphWrapper
(
name
=
"arxiv"
,
node_feat
=
graph
.
node_feat_info
(),
place
=
place
)
if
parser
.
use_label_e
:
model
=
Arxiv_label_embedding_model
(
gw
,
parser
.
hidden_size
,
parser
.
num_heads
,
parser
.
dropout
,
parser
.
num_layers
)
else
:
model
=
Arxiv_baseline_model
(
gw
,
parser
.
hidden_size
,
parser
.
num_heads
,
parser
.
dropout
,
parser
.
num_layers
)
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
model
.
train_program
()
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
adam_optimizer
=
F
.
optimizer
.
RecomputeOptimizer
(
adam_optimizer
)
adam_optimizer
.
_set_checkpoints
(
model
.
checkpoints
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
wf
=
open
(
parser
.
log_file
,
'w'
,
encoding
=
'utf-8'
)
total_test_acc
=
0.0
for
run_i
in
range
(
parser
.
runs
):
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
close
()
# Runned 10 times
# Val Accs: [74.64, 74.74, 74.71, 74.83, 74.82, 74.77, 74.75, 74.86, 74.6, 74.76]
# Test Accs: [73.79, 73.82, 74.0, 73.85, 74.02, 73.67, 73.65, 73.87, 73.66, 73.6]
# Average val accuracy: 74.74799999999999 ± 0.0775628777186617
# Average test accuracy: 73.793 ± 0.13957435294494433
# params: 1162515
\ No newline at end of file
ogb_examples/nodeproppred/unimp/model_large.py
0 → 100644
浏览文件 @
9913672a
'''build label embedding model
'''
import
math
import
pgl
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
from
pgl.utils
import
paddle_helper
from
module.transformer_gat_pgl
import
transformer_gat_pgl
from
module.model_unimp_large
import
graph_transformer
,
linear
,
attn_appnp
class
Arxiv_baseline_model
():
def
__init__
(
self
,
gw
,
hidden_size
,
num_heads
,
dropout
,
num_layers
):
'''Arxiv_baseline_model
'''
self
.
gw
=
gw
self
.
hidden_size
=
hidden_size
self
.
num_heads
=
num_heads
self
.
dropout
=
dropout
self
.
num_layers
=
num_layers
self
.
out_size
=
40
self
.
embed_size
=
128
self
.
checkpoints
=
[]
self
.
build_model
()
def
embed_input
(
self
,
feature
):
lay_norm_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
1
))
lay_norm_bias
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
0
))
feature
=
L
.
layer_norm
(
feature
,
name
=
'layer_norm_feature_input'
,
param_attr
=
lay_norm_attr
,
bias_attr
=
lay_norm_bias
)
return
feature
def
build_model
(
self
):
feature_batch
=
self
.
embed_input
(
self
.
gw
.
node_feat
[
'feat'
])
feature_batch
=
L
.
dropout
(
feature_batch
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
for
i
in
range
(
self
.
num_layers
-
1
):
feature_batch
=
graph_transformer
(
str
(
i
),
self
.
gw
,
feature_batch
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
concat
=
True
,
skip_feat
=
True
,
layer_norm
=
True
,
relu
=
True
,
gate
=
True
)
if
self
.
dropout
>
0
:
feature_batch
=
L
.
dropout
(
feature_batch
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
self
.
checkpoints
.
append
(
feature_batch
)
feature_batch
=
graph_transformer
(
str
(
self
.
num_layers
-
1
),
self
.
gw
,
feature_batch
,
hidden_size
=
self
.
out_size
,
num_heads
=
self
.
num_heads
,
concat
=
False
,
skip_feat
=
True
,
layer_norm
=
False
,
relu
=
False
,
gate
=
True
)
self
.
checkpoints
.
append
(
feature_batch
)
self
.
out_feat
=
feature_batch
def
train_program
(
self
,):
label
=
F
.
data
(
name
=
"label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
train_idx
=
F
.
data
(
name
=
'train_idx'
,
shape
=
[
None
],
dtype
=
"int64"
)
prediction
=
L
.
gather
(
self
.
out_feat
,
train_idx
,
overwrite
=
False
)
label
=
L
.
gather
(
label
,
train_idx
,
overwrite
=
False
)
cost
=
L
.
softmax_with_cross_entropy
(
logits
=
prediction
,
label
=
label
)
avg_cost
=
L
.
mean
(
cost
)
self
.
avg_cost
=
avg_cost
class
Arxiv_label_embedding_model
():
def
__init__
(
self
,
gw
,
hidden_size
,
num_heads
,
dropout
,
num_layers
):
'''Arxiv_label_embedding_model
'''
self
.
gw
=
gw
self
.
hidden_size
=
hidden_size
self
.
num_heads
=
num_heads
self
.
dropout
=
dropout
self
.
num_layers
=
num_layers
self
.
out_size
=
40
self
.
embed_size
=
128
self
.
checkpoints
=
[]
self
.
build_model
()
def
label_embed_input
(
self
,
feature
):
label
=
F
.
data
(
name
=
"label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
label_idx
=
F
.
data
(
name
=
'label_idx'
,
shape
=
[
None
],
dtype
=
"int64"
)
label
=
L
.
reshape
(
label
,
shape
=
[
-
1
])
label
=
L
.
gather
(
label
,
label_idx
,
overwrite
=
False
)
lay_norm_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
1
))
lay_norm_bias
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
0
))
feature
=
L
.
layer_norm
(
feature
,
name
=
'layer_norm_feature_input1'
,
param_attr
=
lay_norm_attr
,
bias_attr
=
lay_norm_bias
)
embed_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
NormalInitializer
(
loc
=
0.0
,
scale
=
1.0
))
embed
=
F
.
embedding
(
input
=
label
,
size
=
(
self
.
out_size
,
self
.
embed_size
),
param_attr
=
embed_attr
)
lay_norm_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
1
))
lay_norm_bias
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
0
))
embed
=
L
.
layer_norm
(
embed
,
name
=
'layer_norm_feature_input2'
,
param_attr
=
lay_norm_attr
,
bias_attr
=
lay_norm_bias
)
embed
=
L
.
relu
(
embed
)
feature_label
=
L
.
gather
(
feature
,
label_idx
,
overwrite
=
False
)
feature_label
=
feature_label
+
embed
feature
=
L
.
scatter
(
feature
,
label_idx
,
feature_label
,
overwrite
=
True
)
return
feature
def
build_model
(
self
):
label_feature
=
self
.
label_embed_input
(
self
.
gw
.
node_feat
[
'feat'
])
feature_batch
=
L
.
dropout
(
label_feature
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
for
i
in
range
(
self
.
num_layers
-
1
):
feature_batch
,
_
,
cks
=
graph_transformer
(
str
(
i
),
self
.
gw
,
feature_batch
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
self
.
num_heads
,
attn_drop
=
True
,
concat
=
True
,
skip_feat
=
True
,
layer_norm
=
True
,
relu
=
True
,
gate
=
True
)
if
self
.
dropout
>
0
:
feature_batch
=
L
.
dropout
(
feature_batch
,
dropout_prob
=
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
self
.
checkpoints
=
self
.
checkpoints
+
cks
feature_batch
,
attn
,
cks
=
graph_transformer
(
str
(
self
.
num_layers
-
1
),
self
.
gw
,
feature_batch
,
hidden_size
=
self
.
out_size
,
num_heads
=
self
.
num_heads
+
1
,
concat
=
False
,
skip_feat
=
True
,
layer_norm
=
False
,
relu
=
False
,
gate
=
True
)
self
.
checkpoints
.
append
(
feature_batch
)
feature_batch
=
attn_appnp
(
self
.
gw
,
feature_batch
,
attn
,
alpha
=
0.2
,
k_hop
=
10
)
self
.
checkpoints
.
append
(
feature_batch
)
self
.
out_feat
=
feature_batch
def
train_program
(
self
,):
label
=
F
.
data
(
name
=
"label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
)
train_idx
=
F
.
data
(
name
=
'train_idx'
,
shape
=
[
None
],
dtype
=
"int64"
)
prediction
=
L
.
gather
(
self
.
out_feat
,
train_idx
,
overwrite
=
False
)
label
=
L
.
gather
(
label
,
train_idx
,
overwrite
=
False
)
cost
=
L
.
softmax_with_cross_entropy
(
logits
=
prediction
,
label
=
label
)
avg_cost
=
L
.
mean
(
cost
)
self
.
avg_cost
=
avg_cost
ogb_examples/nodeproppred/unimp/module/model_unimp_large.py
0 → 100644
浏览文件 @
9913672a
import
pgl
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
from
pgl.utils
import
paddle_helper
from
pgl
import
message_passing
import
math
def
graph_transformer
(
name
,
gw
,
feature
,
hidden_size
,
num_heads
=
4
,
attn_drop
=
False
,
edge_feature
=
None
,
concat
=
True
,
skip_feat
=
True
,
gate
=
False
,
layer_norm
=
True
,
relu
=
True
,
is_test
=
False
):
"""Implementation of graph Transformer from UniMP
This is an implementation of the paper Unified Massage Passing Model for Semi-Supervised Classification
(https://arxiv.org/abs/2009.03509).
Args:
name: Granph Transformer layer names.
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
hidden_size: The hidden size for graph transformer.
num_heads: The head number in graph transformer.
attn_drop: Dropout rate for attention.
edge_feature: A tensor with shape (num_edges, feature_size).
concat: Reshape the output (num_nodes, num_heads, hidden_size) by concat (num_nodes, hidden_size * num_heads) or mean (num_nodes, hidden_size)
skip_feat: Whether use skip connect
gate: Whether add skip_feat and output up with gate weight
layer_norm: Whether use layer_norm for output
relu: Whether use relu activation for output
is_test: Whether in test phrase.
Return:
A tensor with shape (num_nodes, hidden_size * num_heads) or (num_nodes, hidden_size)
"""
def
send_attention
(
src_feat
,
dst_feat
,
edge_feat
):
if
edge_feat
is
None
or
not
edge_feat
:
output
=
src_feat
[
"k_h"
]
*
dst_feat
[
"q_h"
]
output
=
L
.
reduce_sum
(
output
,
-
1
)
output
=
output
/
(
hidden_size
**
0.5
)
# alpha = paddle_helper.sequence_softmax(output)
return
{
"alpha"
:
output
,
"v"
:
src_feat
[
"v_h"
]}
# batch x h batch x h x feat
else
:
edge_feat
=
edge_feat
[
"edge"
]
edge_feat
=
L
.
reshape
(
edge_feat
,
[
-
1
,
num_heads
,
hidden_size
])
output
=
(
src_feat
[
"k_h"
]
+
edge_feat
)
*
dst_feat
[
"q_h"
]
output
=
L
.
reduce_sum
(
output
,
-
1
)
output
=
output
/
(
hidden_size
**
0.5
)
# alpha = paddle_helper.sequence_softmax(output)
return
{
"alpha"
:
output
,
"v"
:
(
src_feat
[
"v_h"
]
+
edge_feat
)}
# batch x h batch x h x feat
class
Reduce_attention
():
def
__init__
(
self
,):
self
.
alpha
=
None
def
__call__
(
self
,
msg
):
alpha
=
msg
[
"alpha"
]
# lod-tensor (batch_size, num_heads)
if
attn_drop
:
old_h
=
alpha
dropout
=
F
.
data
(
name
=
'attn_drop'
,
shape
=
[
1
],
dtype
=
"int64"
)
u
=
L
.
uniform_random
(
shape
=
L
.
cast
(
L
.
shape
(
alpha
)[:
1
],
'int64'
),
min
=
0.
,
max
=
1.
)
keeped
=
L
.
cast
(
u
>
dropout
,
dtype
=
"float32"
)
self_attn_mask
=
L
.
scale
(
x
=
keeped
,
scale
=
10000.0
,
bias
=-
1.0
,
bias_after_scale
=
False
)
n_head_self_attn_mask
=
L
.
stack
(
x
=
[
self_attn_mask
]
*
num_heads
,
axis
=
1
)
n_head_self_attn_mask
.
stop_gradient
=
True
alpha
=
n_head_self_attn_mask
+
alpha
alpha
=
L
.
lod_reset
(
alpha
,
old_h
)
h
=
msg
[
"v"
]
alpha
=
paddle_helper
.
sequence_softmax
(
alpha
)
self
.
alpha
=
alpha
old_h
=
h
h
=
h
*
alpha
h
=
L
.
lod_reset
(
h
,
old_h
)
h
=
L
.
sequence_pool
(
h
,
"sum"
)
if
concat
:
h
=
L
.
reshape
(
h
,
[
-
1
,
num_heads
*
hidden_size
])
else
:
h
=
L
.
reduce_mean
(
h
,
dim
=
1
)
return
h
reduce_attention
=
Reduce_attention
()
q
=
linear
(
feature
,
hidden_size
*
num_heads
,
name
=
name
+
'_q_weight'
,
init_type
=
'gcn'
)
k
=
linear
(
feature
,
hidden_size
*
num_heads
,
name
=
name
+
'_k_weight'
,
init_type
=
'gcn'
)
v
=
linear
(
feature
,
hidden_size
*
num_heads
,
name
=
name
+
'_v_weight'
,
init_type
=
'gcn'
)
reshape_q
=
L
.
reshape
(
q
,
[
-
1
,
num_heads
,
hidden_size
])
reshape_k
=
L
.
reshape
(
k
,
[
-
1
,
num_heads
,
hidden_size
])
reshape_v
=
L
.
reshape
(
v
,
[
-
1
,
num_heads
,
hidden_size
])
msg
=
gw
.
send
(
send_attention
,
nfeat_list
=
[(
"q_h"
,
reshape_q
),
(
"k_h"
,
reshape_k
),
(
"v_h"
,
reshape_v
)],
efeat_list
=
edge_feature
)
out_feat
=
gw
.
recv
(
msg
,
reduce_attention
)
checkpoints
=
[
out_feat
]
if
skip_feat
:
if
concat
:
out_feat
,
cks
=
appnp
(
gw
,
out_feat
,
k_hop
=
1
)
# out_feat, cks = appnp(gw, out_feat, k_hop=3)
checkpoints
.
append
(
out_feat
)
# The UniMP-xxlarge will come soon.
# out_feat, cks = appnp(gw, out_feat, k_hop=6)
# out_feat, cks = appnp(gw, out_feat, k_hop=9)
# checkpoints = checkpoints + cks
skip_feature
=
linear
(
feature
,
hidden_size
*
num_heads
,
name
=
name
+
'_skip_weight'
,
init_type
=
'lin'
)
else
:
skip_feature
=
linear
(
feature
,
hidden_size
,
name
=
name
+
'_skip_weight'
,
init_type
=
'lin'
)
if
gate
:
temp_output
=
L
.
concat
([
skip_feature
,
out_feat
,
out_feat
-
skip_feature
],
axis
=-
1
)
gate_f
=
L
.
sigmoid
(
linear
(
temp_output
,
1
,
name
=
name
+
'_gate_weight'
,
init_type
=
'lin'
))
out_feat
=
skip_feature
*
gate_f
+
out_feat
*
(
1
-
gate_f
)
else
:
out_feat
=
skip_feature
+
out_feat
if
layer_norm
:
lay_norm_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
1
))
lay_norm_bias
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
value
=
0
))
out_feat
=
L
.
layer_norm
(
out_feat
,
name
=
name
+
'_layer_norm'
,
param_attr
=
lay_norm_attr
,
bias_attr
=
lay_norm_bias
)
if
relu
:
out_feat
=
L
.
relu
(
out_feat
)
return
out_feat
,
reduce_attention
.
alpha
,
checkpoints
def
appnp
(
gw
,
feature
,
alpha
=
0.2
,
k_hop
=
10
):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
edge_dropout: Edge dropout rate.
k_hop: K Steps for Propagation
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def
send_src_copy
(
src_feat
,
dst_feat
,
edge_feat
):
feature
=
src_feat
[
"h"
]
return
feature
def
get_norm
(
indegree
):
float_degree
=
L
.
cast
(
indegree
,
dtype
=
"float32"
)
float_degree
=
L
.
clamp
(
float_degree
,
min
=
1.0
)
norm
=
L
.
pow
(
float_degree
,
factor
=-
0.5
)
return
norm
cks
=
[]
h0
=
feature
ngw
=
gw
norm
=
get_norm
(
ngw
.
indegree
())
for
i
in
range
(
k_hop
):
feature
=
feature
*
norm
msg
=
gw
.
send
(
send_src_copy
,
nfeat_list
=
[(
"h"
,
feature
)])
feature
=
gw
.
recv
(
msg
,
"sum"
)
feature
=
feature
*
norm
feature
=
feature
*
(
1
-
alpha
)
+
h0
*
alpha
if
(
i
+
1
)
%
3
==
0
:
cks
.
append
(
feature
)
return
feature
,
cks
def
attn_appnp
(
gw
,
feature
,
attn
,
alpha
=
0.2
,
k_hop
=
10
):
"""Attention based APPNP to Make model output deeper
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
attn: Using the attntion as transition matrix for APPNP
feature: A tensor with shape (num_nodes, feature_size).
k_hop: K Steps for Propagation
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def
send_src_copy
(
src_feat
,
dst_feat
,
edge_feat
):
feature
=
src_feat
[
"h"
]
return
feature
h0
=
feature
attn
=
L
.
reduce_mean
(
attn
,
1
)
for
i
in
range
(
k_hop
):
msg
=
gw
.
send
(
send_src_copy
,
nfeat_list
=
[(
"h"
,
feature
)])
msg
=
msg
*
attn
feature
=
gw
.
recv
(
msg
,
"sum"
)
feature
=
feature
*
(
1
-
alpha
)
+
h0
*
alpha
return
feature
def
linear
(
input
,
hidden_size
,
name
,
with_bias
=
True
,
init_type
=
'gcn'
):
"""fluid.layers.fc with different init_type
"""
if
init_type
==
'gcn'
:
fc_w_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
XavierInitializer
())
fc_bias_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
ConstantInitializer
(
0.0
))
else
:
fan_in
=
input
.
shape
[
-
1
]
bias_bound
=
1.0
/
math
.
sqrt
(
fan_in
)
fc_bias_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
UniformInitializer
(
low
=-
bias_bound
,
high
=
bias_bound
))
negative_slope
=
math
.
sqrt
(
5
)
gain
=
math
.
sqrt
(
2.0
/
(
1
+
negative_slope
**
2
))
std
=
gain
/
math
.
sqrt
(
fan_in
)
weight_bound
=
math
.
sqrt
(
3.0
)
*
std
fc_w_attr
=
F
.
ParamAttr
(
initializer
=
F
.
initializer
.
UniformInitializer
(
low
=-
weight_bound
,
high
=
weight_bound
))
if
not
with_bias
:
fc_bias_attr
=
False
output
=
L
.
fc
(
input
,
hidden_size
,
param_attr
=
fc_w_attr
,
name
=
name
,
bias_attr
=
fc_bias_attr
)
return
output
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录