Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
5c158603
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5c158603
编写于
7月 02, 2020
作者:
W
Weiyue Su
提交者:
GitHub
7月 02, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #100 from WenjinW/master
modify gaan in conv.py
上级
299328a0
97d3e021
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
9 addition
and
15 deletion
+9
-15
examples/GaAN/main.sh
examples/GaAN/main.sh
+1
-1
pgl/layers/conv.py
pgl/layers/conv.py
+8
-14
未找到文件。
examples/GaAN/main.sh
浏览文件 @
5c158603
python3 train.py
--epochs
100
--lr
1e-2
--rc
0
--batch_size
1024
--gpu_id
0
--exp_id
0
python3 train.py
--epochs
100
--lr
1e-2
--rc
0
--batch_size
1024
--exp_id
0
\ No newline at end of file
\ No newline at end of file
pgl/layers/conv.py
浏览文件 @
5c158603
...
@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN"""
"""Implementation of GaAN"""
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
#
计算每条边上的注意力分数
#
compute attention
# E * (M * D1)
, 每个 dst 点都查询它的全部邻边的 src 点
# E * (M * D1)
feat_query
,
feat_key
=
dst_feat
[
'feat_query'
],
src_feat
[
'feat_key'
]
feat_query
,
feat_key
=
dst_feat
[
'feat_query'
],
src_feat
[
'feat_key'
]
# E * M * D1
# E * M * D1
old
=
feat_query
old
=
feat_query
...
@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate'
:
src_feat
[
'feat_gate'
]}
'feat_gate'
:
src_feat
[
'feat_gate'
]}
def
recv_func
(
message
):
def
recv_func
(
message
):
#
每条边的终点的特征
#
feature of src and dst node on each edge
dst_feat
=
message
[
'dst_node_feat'
]
dst_feat
=
message
[
'dst_node_feat'
]
# 每条边的出发点的特征
src_feat
=
message
[
'src_node_feat'
]
src_feat
=
message
[
'src_node_feat'
]
#
每个中心点自己的特征
#
feature of center node
x
=
fluid
.
layers
.
sequence_pool
(
dst_feat
,
'average'
)
x
=
fluid
.
layers
.
sequence_pool
(
dst_feat
,
'average'
)
#
每个中心点的邻居的特征的平均值
#
feature of neighbors of center node
z
=
fluid
.
layers
.
sequence_pool
(
src_feat
,
'average'
)
z
=
fluid
.
layers
.
sequence_pool
(
src_feat
,
'average'
)
#
计算
gate
#
compute
gate
feat_gate
=
message
[
'feat_gate'
]
feat_gate
=
message
[
'feat_gate'
]
g_max
=
fluid
.
layers
.
sequence_pool
(
feat_gate
,
'max'
)
g_max
=
fluid
.
layers
.
sequence_pool
(
feat_gate
,
'max'
)
g
=
fluid
.
layers
.
concat
([
x
,
g_max
,
z
],
axis
=
1
)
g
=
fluid
.
layers
.
concat
([
x
,
g_max
,
z
],
axis
=
1
)
...
@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return
output
return
output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
# N * (D1 * M)
feat_key
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_a
*
heads
,
bias_attr
=
False
,
feat_key
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_a
*
heads
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_key'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_key'
))
...
@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_m
,
bias_attr
=
False
,
feat_gate
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_m
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_gate'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_gate'
))
# send 阶段
# send
message
=
gw
.
send
(
message
=
gw
.
send
(
send_func
,
send_func
,
nfeat_list
=
[(
'node_feat'
,
feature
),
(
'feat_key'
,
feat_key
),
(
'feat_value'
,
feat_value
),
nfeat_list
=
[(
'node_feat'
,
feature
),
(
'feat_key'
,
feat_key
),
(
'feat_value'
,
feat_value
),
...
@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list
=
None
,
efeat_list
=
None
,
)
)
#
聚合邻居特征
#
recv
output
=
gw
.
recv
(
message
,
recv_func
)
output
=
gw
.
recv
(
message
,
recv_func
)
output
=
fluid
.
layers
.
fc
(
output
,
hidden_size_o
,
bias_attr
=
False
,
output
=
fluid
.
layers
.
fc
(
output
,
hidden_size_o
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_output'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_output'
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录