diff --git a/examples/GATNE/model.py b/examples/GATNE/model.py index 18f83c89a31324256f20ae118372828fe8be955d..492aa3d97e07df5b4335adc3069df650ff320870 100644 --- a/examples/GATNE/model.py +++ b/examples/GATNE/model.py @@ -114,29 +114,29 @@ class GATNE(object): node_type_embed = fl.gather(node_type_embed, self.train_inputs) # M_r + tn_initializer = fluid.initializer.TruncatedNormalInitializer( + loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)) + trans_weights = fl.create_parameter( shape=[ self.edge_type_count, self.embedding_u_size, self.embedding_size // self.att_head ], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w') # W_r trans_weights_s1 = fl.create_parameter( shape=[self.edge_type_count, self.embedding_u_size, self.dim_a], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s1') # w_r trans_weights_s2 = fl.create_parameter( shape=[self.edge_type_count, self.dim_a, self.att_head], - attr=fluid.initializer.TruncatedNormalInitializer( - loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)), + default_initializer=tn_initializer, dtype='float32', name='trans_w_s2')