提交 352e017c 编写于 作者: W Webbley

[fix]: fix bug of GATNE

上级 969c1880
......@@ -114,29 +114,29 @@ class GATNE(object):
node_type_embed = fl.gather(node_type_embed, self.train_inputs)
# M_r
tn_initializer = fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size))
trans_weights = fl.create_parameter(
shape=[
self.edge_type_count, self.embedding_u_size,
self.embedding_size // self.att_head
],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
default_initializer=tn_initializer,
dtype='float32',
name='trans_w')
# W_r
trans_weights_s1 = fl.create_parameter(
shape=[self.edge_type_count, self.embedding_u_size, self.dim_a],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
default_initializer=tn_initializer,
dtype='float32',
name='trans_w_s1')
# w_r
trans_weights_s2 = fl.create_parameter(
shape=[self.edge_type_count, self.dim_a, self.att_head],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
default_initializer=tn_initializer,
dtype='float32',
name='trans_w_s2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册