提交 97d3e021 编写于 作者: S suweiyue

modify gaan in conv

上级 e47f23b9
......@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat):
# 计算每条边上的注意力分数
# E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点
# compute attention
# E * (M * D1)
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1
old = feat_query
......@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate': src_feat['feat_gate']}
def recv_func(message):
# 每条边的终点的特征
# feature of src and dst node on each edge
dst_feat = message['dst_node_feat']
# 每条边的出发点的特征
src_feat = message['src_node_feat']
# 每个中心点自己的特征
# feature of center node
x = fluid.layers.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值
# feature of neighbors of center node
z = fluid.layers.sequence_pool(src_feat, 'average')
# 计算 gate
# compute gate
feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1)
......@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
......@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段
# send
message = gw.send(
send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
......@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list=None,
)
# 聚合邻居特征
# recv
output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册