Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
97d3e021
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
97d3e021
编写于
7月 02, 2020
作者:
S
suweiyue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify gaan in conv
上级
e47f23b9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
8 addition
and
14 deletion
+8
-14
pgl/layers/conv.py
pgl/layers/conv.py
+8
-14
未找到文件。
pgl/layers/conv.py
浏览文件 @
97d3e021
...
...
@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN"""
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
#
计算每条边上的注意力分数
# E * (M * D1)
, 每个 dst 点都查询它的全部邻边的 src 点
#
compute attention
# E * (M * D1)
feat_query
,
feat_key
=
dst_feat
[
'feat_query'
],
src_feat
[
'feat_key'
]
# E * M * D1
old
=
feat_query
...
...
@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate'
:
src_feat
[
'feat_gate'
]}
def
recv_func
(
message
):
#
每条边的终点的特征
#
feature of src and dst node on each edge
dst_feat
=
message
[
'dst_node_feat'
]
# 每条边的出发点的特征
src_feat
=
message
[
'src_node_feat'
]
#
每个中心点自己的特征
#
feature of center node
x
=
fluid
.
layers
.
sequence_pool
(
dst_feat
,
'average'
)
#
每个中心点的邻居的特征的平均值
#
feature of neighbors of center node
z
=
fluid
.
layers
.
sequence_pool
(
src_feat
,
'average'
)
#
计算
gate
#
compute
gate
feat_gate
=
message
[
'feat_gate'
]
g_max
=
fluid
.
layers
.
sequence_pool
(
feat_gate
,
'max'
)
g
=
fluid
.
layers
.
concat
([
x
,
g_max
,
z
],
axis
=
1
)
...
...
@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return
output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
feat_key
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_a
*
heads
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_key'
))
...
...
@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_m
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_gate'
))
# send 阶段
# send
message
=
gw
.
send
(
send_func
,
nfeat_list
=
[(
'node_feat'
,
feature
),
(
'feat_key'
,
feat_key
),
(
'feat_value'
,
feat_value
),
...
...
@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list
=
None
,
)
#
聚合邻居特征
#
recv
output
=
gw
.
recv
(
message
,
recv_func
)
output
=
fluid
.
layers
.
fc
(
output
,
hidden_size_o
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_output'
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录