diff --git a/pgl/layers/conv.py b/pgl/layers/conv.py index 68a1d733ed1d297e7a20daa1fb7c14828ff8722b..9bdfa352f107ae2067e8a31c62d9d3a6aa4795b3 100644 --- a/pgl/layers/conv.py +++ b/pgl/layers/conv.py @@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o """Implementation of GaAN""" def send_func(src_feat, dst_feat, edge_feat): - # 计算每条边上的注意力分数 - # E * (M * D1), 每个 dst 点都查询它的全部邻边的 src 点 + # compute attention + # E * (M * D1) feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] # E * M * D1 old = feat_query @@ -281,16 +281,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o 'feat_gate': src_feat['feat_gate']} def recv_func(message): - # 每条边的终点的特征 + # feature of src and dst node on each edge dst_feat = message['dst_node_feat'] - # 每条边的出发点的特征 src_feat = message['src_node_feat'] - # 每个中心点自己的特征 + # feature of center node x = fluid.layers.sequence_pool(dst_feat, 'average') - # 每个中心点的邻居的特征的平均值 + # feature of neighbors of center node z = fluid.layers.sequence_pool(src_feat, 'average') - # 计算 gate + # compute gate feat_gate = message['feat_gate'] g_max = fluid.layers.sequence_pool(feat_gate, 'max') g = fluid.layers.concat([x, g_max, z], axis=1) @@ -318,10 +317,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o return output - # feature N * D - - # 计算每个点自己需要发送出去的内容 - # 投影后的特征向量 # N * (D1 * M) feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_key')) @@ -335,8 +330,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_gate')) - # send 阶段 - + # send message = gw.send( send_func, nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value), @@ -344,7 +338,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o efeat_list=None, ) - # 聚合邻居特征 + # recv output = gw.recv(message, recv_func) output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, param_attr=fluid.ParamAttr(name=name + '_project_output'))