From 352e017c490c7ad978119bf66bb12283d2e75bf9 Mon Sep 17 00:00:00 2001 From: Webbley Date: Thu, 24 Sep 2020 10:50:16 +0800 Subject: [PATCH] [fix]: fix bug of GATNE --- examples/GATNE/model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/GATNE/model.py b/examples/GATNE/model.py index 18f83c8..492aa3d 100644 --- a/examples/GATNE/model.py +++ b/examples/GATNE/model.py @@ -114,29 +114,29 @@ class GATNE(object): node_type_embed = fl.gather(node_type_embed, self.train_inputs) # M_r + tn_initializer = fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)) + trans_weights = fl.create_parameter( shape=[ self.edge_type_count, self.embedding_u_size, self.embedding_size // self.att_head ], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w') # W_r trans_weights_s1 = fl.create_parameter( shape=[self.edge_type_count, self.embedding_u_size, self.dim_a], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s1') # w_r trans_weights_s2 = fl.create_parameter( shape=[self.edge_type_count, self.dim_a, self.att_head], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s2') -- GitLab