提交 352e017c 编写于 作者: W Webbley

[fix]: fix bug of GATNE

上级 969c1880
...@@ -114,29 +114,29 @@ class GATNE(object): ...@@ -114,29 +114,29 @@ class GATNE(object):
node_type_embed = fl.gather(node_type_embed, self.train_inputs) node_type_embed = fl.gather(node_type_embed, self.train_inputs)
# M_r # M_r
tn_initializer = fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size))
trans_weights = fl.create_parameter( trans_weights = fl.create_parameter(
shape=[ shape=[
self.edge_type_count, self.embedding_u_size, self.edge_type_count, self.embedding_u_size,
self.embedding_size // self.att_head self.embedding_size // self.att_head
], ],
attr=fluid.initializer.TruncatedNormalInitializer( default_initializer=tn_initializer,
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32', dtype='float32',
name='trans_w') name='trans_w')
# W_r # W_r
trans_weights_s1 = fl.create_parameter( trans_weights_s1 = fl.create_parameter(
shape=[self.edge_type_count, self.embedding_u_size, self.dim_a], shape=[self.edge_type_count, self.embedding_u_size, self.dim_a],
attr=fluid.initializer.TruncatedNormalInitializer( default_initializer=tn_initializer,
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32', dtype='float32',
name='trans_w_s1') name='trans_w_s1')
# w_r # w_r
trans_weights_s2 = fl.create_parameter( trans_weights_s2 = fl.create_parameter(
shape=[self.edge_type_count, self.dim_a, self.att_head], shape=[self.edge_type_count, self.dim_a, self.att_head],
attr=fluid.initializer.TruncatedNormalInitializer( default_initializer=tn_initializer,
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32', dtype='float32',
name='trans_w_s2') name='trans_w_s2')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册