未验证 提交 5c158603 编写于 作者: W Weiyue Su 提交者: GitHub

Merge pull request #100 from WenjinW/master

modify gaan in conv.py
python3 train.py --epochs 100 --lr 1e-2 --rc 0 --batch_size 1024 --gpu_id 0 --exp_id 0 python3 train.py --epochs 100 --lr 1e-2 --rc 0 --batch_size 1024 --exp_id 0
\ No newline at end of file \ No newline at end of file
...@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN""" """Implementation of GaAN"""
def send_func(src_feat, dst_feat, edge_feat): def send_func(src_feat, dst_feat, edge_feat):
# 计算每条边上的注意力分数 # compute attention
# E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点 # E * (M * D1)
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1 # E * M * D1
old = feat_query old = feat_query
...@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate': src_feat['feat_gate']} 'feat_gate': src_feat['feat_gate']}
def recv_func(message): def recv_func(message):
# 每条边的终点的特征 # feature of src and dst node on each edge
dst_feat = message['dst_node_feat'] dst_feat = message['dst_node_feat']
# 每条边的出发点的特征
src_feat = message['src_node_feat'] src_feat = message['src_node_feat']
# 每个中心点自己的特征 # feature of center node
x = fluid.layers.sequence_pool(dst_feat, 'average') x = fluid.layers.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值 # feature of neighbors of center node
z = fluid.layers.sequence_pool(src_feat, 'average') z = fluid.layers.sequence_pool(src_feat, 'average')
# 计算 gate # compute gate
feat_gate = message['feat_gate'] feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max') g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1) g = fluid.layers.concat([x, g_max, z], axis=1)
...@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return output return output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M) # N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key')) param_attr=fluid.ParamAttr(name=name + '_project_key'))
...@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False, feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate')) param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段 # send
message = gw.send( message = gw.send(
send_func, send_func,
nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value), nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
...@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list=None, efeat_list=None,
) )
# 聚合邻居特征 # recv
output = gw.recv(message, recv_func) output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output')) param_attr=fluid.ParamAttr(name=name + '_project_output'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册