PaddlePaddle / PGL
Commit fc6e0397
Authored May 26, 2020 by wangwenjin
modify GaAN
Parent 855f0f99
Showing 6 changed files with 381 additions and 94 deletions (+381 −94).
.gitignore               +2    −0
examples/GaAN/conv.py    +354  −0
examples/GaAN/main.sh    +1    −0
examples/GaAN/model.py   +6    −4
examples/GaAN/train.py   +14   −8
pgl/layers/conv.py       +4    −82
.gitignore
...
...
@@ -2,6 +2,8 @@
/examples/GaAN/dataset/
/examples/GaAN/log/
/examples/GaAN/__pycache__/
/examples/GaAN/params/
/DoorGod
# Virtualenv
/.venv/
/venv/
...
...
examples/GaAN/conv.py
0 → 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package implements common layers to help building
graph neural networks.
"""
import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper

__all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
    """Implementation of graph convolutional neural networks (GCN)

    This is an implementation of the paper SEMI-SUPERVISED CLASSIFICATION
    WITH GRAPH CONVOLUTIONAL NETWORKS (https://arxiv.org/pdf/1609.02907.pdf).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
        feature: A tensor with shape (num_nodes, feature_size).
        hidden_size: The hidden size for gcn.
        activation: The activation for the output.
        name: Gcn layer names.
        norm: If :code:`norm` is not None, then the feature will be normalized. Norm must
            be a tensor with shape (num_nodes,) and dtype float32.

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """
    def send_src_copy(src_feat, dst_feat, edge_feat):
        return src_feat["h"]

    size = feature.shape[-1]
    if size > hidden_size:
        feature = fluid.layers.fc(feature,
                                  size=hidden_size,
                                  bias_attr=False,
                                  param_attr=fluid.ParamAttr(name=name))

    if norm is not None:
        feature = feature * norm

    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])

    if size > hidden_size:
        output = gw.recv(msg, "sum")
    else:
        output = gw.recv(msg, "sum")
        output = fluid.layers.fc(output,
                                 size=hidden_size,
                                 bias_attr=False,
                                 param_attr=fluid.ParamAttr(name=name))

    if norm is not None:
        output = output * norm

    bias = fluid.layers.create_parameter(
        shape=[hidden_size], dtype='float32', is_bias=True, name=name + '_bias')
    output = fluid.layers.elementwise_add(output, bias, act=activation)
    return output
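
# Usage sketch: `norm` is typically the symmetric degree normalization from the
# GCN paper. Assuming a `pgl.graph.Graph` named `g` (hypothetical variable), it
# can be precomputed on the numpy side roughly like this:
#
#     import numpy as np
#     degree = g.indegree()
#     norm = np.zeros_like(degree, dtype="float32")
#     norm[degree > 0] = np.power(degree[degree > 0], -0.5)
#     g.node_feat["norm"] = np.expand_dims(norm, -1)
#
# and then fed to this layer through the graph wrapper's node features.
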
def gat(gw,
        feature,
        hidden_size,
        activation,
        name,
        num_heads=8,
        feat_drop=0.6,
        attn_drop=0.6,
        is_test=False):
    """Implementation of graph attention networks (GAT)

    This is an implementation of the paper GRAPH ATTENTION NETWORKS
    (https://arxiv.org/abs/1710.10903).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
        feature: A tensor with shape (num_nodes, feature_size).
        hidden_size: The hidden size for gat.
        activation: The activation for the output.
        name: Gat layer names.
        num_heads: The head number in gat.
        feat_drop: Dropout rate for feature.
        attn_drop: Dropout rate for attention.
        is_test: Whether in test phase.

    Return:
        A tensor with shape (num_nodes, hidden_size * num_heads)
    """
    def send_attention(src_feat, dst_feat, edge_feat):
        output = src_feat["left_a"] + dst_feat["right_a"]
        output = fluid.layers.leaky_relu(
            output, alpha=0.2)  # (num_edges, num_heads)
        return {"alpha": output, "h": src_feat["h"]}

    def reduce_attention(msg):
        alpha = msg["alpha"]  # lod-tensor (batch_size, seq_len, num_heads)
        h = msg["h"]
        alpha = paddle_helper.sequence_softmax(alpha)
        old_h = h
        h = fluid.layers.reshape(h, [-1, num_heads, hidden_size])
        alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1])
        if attn_drop > 1e-15:
            alpha = fluid.layers.dropout(
                alpha,
                dropout_prob=attn_drop,
                is_test=is_test,
                dropout_implementation="upscale_in_train")
        h = h * alpha
        h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
        h = fluid.layers.lod_reset(h, old_h)
        return fluid.layers.sequence_pool(h, "sum")

    if feat_drop > 1e-15:
        feature = fluid.layers.dropout(
            feature,
            dropout_prob=feat_drop,
            is_test=is_test,
            dropout_implementation='upscale_in_train')

    ft = fluid.layers.fc(feature,
                         hidden_size * num_heads,
                         bias_attr=False,
                         param_attr=fluid.ParamAttr(name=name + '_weight'))
    left_a = fluid.layers.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_l_A')
    right_a = fluid.layers.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_r_A')
    reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size])
    left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1)
    right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1)

    msg = gw.send(
        send_attention,
        nfeat_list=[("h", ft), ("left_a", left_a_value),
                    ("right_a", right_a_value)])
    output = gw.recv(msg, reduce_attention)
    bias = fluid.layers.create_parameter(
        shape=[hidden_size * num_heads],
        dtype='float32',
        is_bias=True,
        name=name + '_bias')
    bias.stop_gradient = True
    output = fluid.layers.elementwise_add(output, bias, act=activation)
    return output
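
# Usage sketch: since each layer returns (num_nodes, hidden_size * num_heads),
# a standard two-layer GAT stack (hypothetical sizes and feature name, with
# `gw` built elsewhere) looks like:
#
#     h = gat(gw, gw.node_feat["feat"], hidden_size=8, activation="elu",
#             name="gat_0", num_heads=8)        # -> (num_nodes, 64)
#     h = gat(gw, h, hidden_size=num_classes, activation=None,
#             name="gat_1", num_heads=1)        # -> (num_nodes, num_classes)
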
def gin(gw,
        feature,
        hidden_size,
        activation,
        name,
        init_eps=0.0,
        train_eps=False):
    """Implementation of Graph Isomorphism Network (GIN) layer.

    This is an implementation of the paper How Powerful are Graph Neural Networks?
    (https://arxiv.org/pdf/1810.00826.pdf).

    In their implementation, all MLPs have 2 layers. Batch normalization is applied
    on every hidden layer.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
        feature: A tensor with shape (num_nodes, feature_size).
        name: GIN layer names.
        hidden_size: The hidden size for gin.
        activation: The activation for the output.
        init_eps: float, optional. Initial :math:`\epsilon` value, default is 0.
        train_eps: bool, optional. If True, :math:`\epsilon` will be a learnable parameter.

    Return:
        A tensor with shape (num_nodes, hidden_size).
    """
    def send_src_copy(src_feat, dst_feat, edge_feat):
        return src_feat["h"]

    epsilon = fluid.layers.create_parameter(
        shape=[1, 1],
        dtype="float32",
        attr=fluid.ParamAttr(name="%s_eps" % name),
        default_initializer=fluid.initializer.ConstantInitializer(
            value=init_eps))

    if not train_eps:
        epsilon.stop_gradient = True

    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)

    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=None,
                             param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_0" % name))

    output = fluid.layers.layer_norm(
        output,
        begin_norm_axis=1,
        param_attr=fluid.ParamAttr(
            name="norm_scale_%s" % (name),
            initializer=fluid.initializer.Constant(1.0)),
        bias_attr=fluid.ParamAttr(
            name="norm_bias_%s" % (name),
            initializer=fluid.initializer.Constant(0.0)), )

    if activation is not None:
        output = getattr(fluid.layers, activation)(output)

    output = fluid.layers.fc(output,
                             size=hidden_size,
                             act=activation,
                             param_attr=fluid.ParamAttr(name="%s_w_1" % name),
                             bias_attr=fluid.ParamAttr(name="%s_b_1" % name))

    return output
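
# The update rule implemented above, in the GIN paper's notation:
#
#     h_v' = MLP((1 + eps) * h_v + sum_{u in N(v)} h_u)
#
# `gw.recv(msg, "sum")` is the neighbor sum, `feature * (epsilon + 1.0)` is the
# scaled self term, and the two fc layers form the 2-layer MLP (note this code
# applies layer_norm where the paper's reference implementation uses batch norm).
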
def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m,
         hidden_size_o, heads, name):
    """Implementation of GaAN (Gated Attention Networks for Learning
    on Large and Spatiotemporal Graphs, https://arxiv.org/abs/1803.07294).
    """
    def send_func(src_feat, dst_feat, edge_feat):
        # compute the attention score on each edge:
        # E * (M * D1), every dst node attends over the src nodes of all its edges
        feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
        # E * M * D1
        old = feat_query
        feat_query = fluid.layers.reshape(feat_query,
                                          [-1, heads, hidden_size_a])
        feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
        # E * M
        alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)

        return {'dst_node_feat': dst_feat['node_feat'],
                'src_node_feat': src_feat['node_feat'],
                'feat_value': src_feat['feat_value'],
                'alpha': alpha,
                'feat_gate': src_feat['feat_gate']}

    def recv_func(message):
        # feature of the dst node of each edge
        dst_feat = message['dst_node_feat']
        # feature of the src node of each edge
        src_feat = message['src_node_feat']
        # each center node's own feature
        x = fluid.layers.sequence_pool(dst_feat, 'average')
        # mean feature of each center node's neighbors
        z = fluid.layers.sequence_pool(src_feat, 'average')

        # compute the gate
        feat_gate = message['feat_gate']
        g_max = fluid.layers.sequence_pool(feat_gate, 'max')
        g = fluid.layers.concat([x, g_max, z], axis=1)
        g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")

        # softmax over each node's incoming edges
        alpha = message['alpha']
        alpha = paddle_helper.sequence_softmax(alpha)  # E * M

        feat_value = message['feat_value']  # E * (M * D2)
        old = feat_value
        feat_value = fluid.layers.reshape(feat_value,
                                          [-1, heads, hidden_size_v])  # E * M * D2
        feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
        feat_value = fluid.layers.reshape(feat_value,
                                          [-1, heads * hidden_size_v])  # E * (M * D2)
        feat_value = fluid.layers.lod_reset(feat_value, old)

        feat_value = fluid.layers.sequence_pool(feat_value, 'sum')  # N * (M * D2)
        feat_value = fluid.layers.reshape(feat_value,
                                          [-1, heads, hidden_size_v])  # N * M * D2

        output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
        output = fluid.layers.reshape(output,
                                      [-1, heads * hidden_size_v])  # N * (M * D2)

        output = fluid.layers.concat([x, output], axis=1)

        return output

    # feature: N * D
    # compute what each node sends out: the projected feature vectors
    # N * (D1 * M)
    feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
                               param_attr=fluid.ParamAttr(name=name + '_project_key'))
    # N * (D2 * M)
    feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
                                 param_attr=fluid.ParamAttr(name=name + '_project_value'))
    # N * (D1 * M)
    feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
                                 param_attr=fluid.ParamAttr(name=name + '_project_query'))
    # N * Dm
    feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
                                param_attr=fluid.ParamAttr(name=name + '_project_gate'))

    # send stage
    message = gw.send(
        send_func,
        nfeat_list=[('node_feat', feature), ('feat_key', feat_key),
                    ('feat_value', feat_value), ('feat_query', feat_query),
                    ('feat_gate', feat_gate)],
        efeat_list=None)

    # aggregate neighbor features
    output = gw.recv(message, recv_func)
    output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
                             param_attr=fluid.ParamAttr(name=name + '_project_output'))
    output = fluid.layers.leaky_relu(output, alpha=0.1)
    output = fluid.layers.dropout(output, dropout_prob=0.1)

    return output
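
A rough sketch of calling the layer (hidden_size_a=24 matches the default in
examples/GaAN/model.py below; the other sizes here are hypothetical):

    h = gw.node_feat['node_feat']
    h = gaan(gw, h, hidden_size_a=24, hidden_size_v=32, hidden_size_m=64,
             hidden_size_o=128, heads=8, name='gaan_0')
    # h now has shape (num_nodes, hidden_size_o)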
examples/GaAN/main.sh
0 → 100644
python3 train.py --epochs 100 --lr 1e-2 --rc 0 --batch_size 1024 --gpu_id 4 --exp_id 39
\ No newline at end of file
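
train.py's argument parser is not part of this diff; a sketch of the flags
main.sh passes, assuming plain argparse (names inferred from the command line
above and the `args.*` reads visible in train.py below):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--rc", type=int, default=0)  # meaning not shown in the diff
    parser.add_argument("--batch_size", type=int, default=1024)
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--exp_id", type=int, default=0)
    args = parser.parse_args()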
examples/GaAN/model.py
from paddle import fluid
from pgl.utils import paddle_helper
-from pgl.layers import GaAN
+# from pgl.layers import gaan
+from conv import gaan


class GaANModel(object):
    def __init__(self, num_class, num_layers, hidden_size_a=24,
...
...
@@ -19,9 +21,9 @@ class GaANModel(object):
    def forward(self, gw):
        feature = gw.node_feat['node_feat']
        for i in range(self.num_layers):
-            feature = GaAN(gw, feature, self.hidden_size_a, self.hidden_size_v,
-                           self.hidden_size_m, self.hidden_size_o, self.heads,
-                           self.name + '_' + str(i))
+            feature = gaan(gw, feature, self.hidden_size_a, self.hidden_size_v,
+                           self.hidden_size_m, self.hidden_size_o, self.heads,
+                           self.name + '_' + str(i))
        pred = fluid.layers.fc(feature, self.num_class, act=None,
                               name=self.name + "_pred_output")
...
...
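
A sketch of wiring the model into a Paddle program (the exact GraphWrapper
signature varies across PGL versions; this follows the PGL 1.x examples, with
`g` a pgl Graph and the elided GaANModel constructor arguments assumed to keep
their defaults):

    import pgl
    import paddle.fluid as fluid

    train_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.GraphWrapper(
            name="graph", node_feat=g.node_feat_info())
        model = GaANModel(num_class=2, num_layers=3)  # hypothetical sizes
        pred = model.forward(gw)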
examples/GaAN/train.py
...
...
@@ -119,14 +119,6 @@ if __name__ == "__main__":
    val_program = train_program.clone(for_test=True)

    with fluid.program_guard(train_program, startup_program):
        # adam = fluid.optimizer.Adam(
        #     learning_rate=1e-2,
        #     regularization=fluid.regularizer.L2DecayRegularizer(
        #         regularization_coeff=0.0005))
        # lr = fluid.layers.natural_exp_decay(learning_rate=args.lr,
        #                                     decay_steps=1000,
        #                                     decay_rate=0.5,
        #                                     )
        lr = args.lr
        adam = fluid.optimizer.Adam(
            learning_rate=lr,
...
...
@@ -171,6 +163,7 @@ if __name__ == "__main__":
    start = time.time()
    print("Training Begin".center(50, "="))
    log_text.add_record(0, "Training Begin".center(50, "="))
    best_valid = -1.0
    for epoch in range(args.epochs):
        start_e = time.time()
        # print("Train Epoch {}".format(epoch).center(50, "="))
...
...
@@ -197,9 +190,22 @@ if __name__ == "__main__":
        log_valid_rocauc_epoch.add_record(epoch, valid_rocauc)
        log_time.add_record(epoch, end_e - start_e)

        if valid_rocauc > best_valid:
            print("Update: new {}, old {}".format(valid_rocauc, best_valid))
            best_valid = valid_rocauc
            # fluid.io.save_vars(executor=exe, dirname='./vars/'+str(args.exp_id),
            #                    main_program=val_program, predicate=lambda x: True)
            fluid.io.save_params(executor=exe,
                                 dirname='./params/' + str(args.exp_id),
                                 main_program=val_program)

    print("Test Stage".center(50, "="))
    log_text.add_record(args.epochs + 1, "Test Stage".center(50, "="))

    # fluid.io.load_vars(executor=exe, dirname='./params/'+str(args.exp_id),
    #                    main_program=val_program, predicate=lambda x: True)
    fluid.io.load_params(executor=exe,
                         dirname='./params/' + str(args.exp_id),
                         main_program=val_program)
    test_loss, test_rocauc = valid_epoch(
        test_iter,
        program=val_program,
        exe=exe,
        loss=loss,
        score=score,
        evaluator=evaluator,
        epoch=epoch)
...
...
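
The commit swaps the commented-out per-variable save_vars/load_vars calls for
whole-program checkpoints. Because val_program is cloned from train_program
with for_test=True, the two share parameter names, so parameters saved at the
best validation epoch can be reloaded for the test pass. The round trip in
isolation (a sketch, assuming an initialized executor `exe` and program `prog`):

    import paddle.fluid as fluid

    # persist every parameter of `prog` under ./params/39
    fluid.io.save_params(executor=exe, dirname='./params/39', main_program=prog)
    # later: restore into any program with matching parameter names
    fluid.io.load_params(executor=exe, dirname='./params/39', main_program=prog)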
pgl/layers/conv.py
...
...
@@ -18,7 +18,7 @@ import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper

-__all__ = ['gcn', 'gat', 'gin', 'GaAN']
+__all__ = ['gcn', 'gat', 'gin', 'gaan']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
...
...
@@ -259,87 +259,9 @@ def gin(gw,
    return output
# def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads,
# name):
# """
# This is an implementation of the paper GaAN: Gated Attention Networks for Learning
# on Large and Spatiotemporal Graphs(https://arxiv.org/abs/1803.07294)
# """
# # send function
# def send_func(src_feat, dst_feat, edge_feat):
# print("heads: {}, hidden_size_a: {}".format(heads, hidden_size_a))
# feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
# feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
# alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
# return {'dst_node_feat': dst_feat['node_feat'],
# 'src_node_feat': src_feat['node_feat'],
# 'feat_value': src_feat['feat_value'],
# 'alpha': alpha,
# 'feat_gate': src_feat['feat_gate']}
# # recv function
# def recv_func(message):
# dst_feat = message['dst_node_feat'] # feature of dst nodes on each edge
# src_feat = message['src_node_feat'] # feature of src nodes on each edge
# x = fluid.layers.sequence_pool(dst_feat, 'average') # feature of center nodes
# z = fluid.layers.sequence_pool(src_feat, 'average') # mean feature of neighbors
# # compute gate
# feat_gate = message['feat_gate']
# g_max = fluid.layers.sequence_pool(feat_gate, 'max')
# g = fluid.layers.concat([x, g_max, z], axis=1)
# g = fluid.layers.fc(g, heads, bias_attr=False, act='sigmoid')
# # softmax of attention coefficient
# alpha = message['alpha']
# alpha = paddle_helper.sequence_softmax(alpha)
# feat_value = message['feat_value']
# old = feat_value
# feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
# feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
# feat_value = fluid.layers.reshape(feat_value, [-1, heads * hidden_size_v])
# feat_value = fluid.layers.lod_reset(feat_value, old)
# feat_value = fluid.layers.sequence_pool(feat_value, 'sum')
# feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v])
# output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
# output = fluid.layers.reshape(output, [-1, heads*hidden_size_v])
# output = fluid.layers.concat([x, output], axis=1)
# return output
# # project the feature of nodes into new vector spaces
# feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
# param_attr=fluid.ParamAttr(name=name + '_project_key'))
# feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
# param_attr=fluid.ParamAttr(name=name + '_project_value'))
# feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
# param_attr=fluid.ParamAttr(name=name + '_project_query'))
# feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
# param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# # send stage
# msg = gw.send(send_func, nfeat_list=[('node_feat', feature),
# ('feat_key', feat_key), ('feat_value', feat_value),
# ('feat_query', feat_query), ('feat_gate', feat_gate)],
# efeat_list=None,
# )
# # recv stage
# output = gw.recv(msg, recv_func)
# # output
# output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
# param_attr=fluid.ParamAttr(name=name+'_project_output'))
# outout = fluid.layers.leaky_relu(output, alpha=0.1)
# output = fluid.layers.dropout(output, dropout_prob=0.1)
# return output
-def GaAN(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m,
-         hidden_size_o, heads, name):
+def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m,
+         hidden_size_o, heads, name):
"""Implementation of GaAN"""
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
# 计算每条边上的注意力分数
...
...
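
Since the exported name changes from GaAN to gaan, downstream imports must
change with it; a sketch (assuming a PGL build that ships this module):

    from pgl.layers import gaan    # new lower-case name
    # from pgl.layers import GaAN  # old name, removed by this commit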