提交 d2cefa43 编写于 作者: Z ZHUI

fix some typo problems

上级 920a0acc
...@@ -19,10 +19,11 @@ import os ...@@ -19,10 +19,11 @@ import os
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from pgl.utils.logger import log from pgl.utils.logger import log
from pybloom import BloomFilter
#from pybloom import BloomFilter
class KBloader:
class KGLoader:
""" """
load the FB15K load the FB15K
""" """
...@@ -65,8 +66,9 @@ class KBloader: ...@@ -65,8 +66,9 @@ class KBloader:
def training_data_no_filter(self, train_triple_positive): def training_data_no_filter(self, train_triple_positive):
"""faster, no filter for exists triples""" """faster, no filter for exists triples"""
size = len(train_triple_positive) size = len(train_triple_positive) * self._neg_times
train_triple_negative = train_triple_positive + 0 train_triple_negative = train_triple_positive.repeat(
self._neg_times, axis=0)
replace_head_probability = 0.5 * np.ones(size) replace_head_probability = 0.5 * np.ones(size)
replace_entity_id = np.random.randint(self.entity_total, size=size) replace_entity_id = np.random.randint(self.entity_total, size=size)
random_num = np.random.random(size=size) random_num = np.random.random(size=size)
...@@ -122,7 +124,6 @@ class KBloader: ...@@ -122,7 +124,6 @@ class KBloader:
""" """
n = len(self._triple_train) n = len(self._triple_train)
rand_idx = np.random.permutation(n) rand_idx = np.random.permutation(n)
rand_idx = rand_idx % n
n_triple = len(rand_idx) n_triple = len(rand_idx)
start = 0 start = 0
while start < n_triple: while start < n_triple:
......
...@@ -99,8 +99,10 @@ class Evaluate: ...@@ -99,8 +99,10 @@ class Evaluate:
feed=batch_feed_dict) feed=batch_feed_dict)
yield batch_feed_dict["test_triple"], head_score, tail_score yield batch_feed_dict["test_triple"], head_score, tail_score
n_used_eval_triple += 1 n_used_eval_triple += 1
print('[{:.3f}s] #evaluation triple: {}/{}'.format( if n_used_eval_triple % 500 == 0:
timeit.default_timer() - start, n_used_eval_triple, 5000)) print('[{:.3f}s] #evaluation triple: {}/{}'.format(
timeit.default_timer(
) - start, n_used_eval_triple, self.reader.test_num))
res_reader = mp_reader_mapper( res_reader = mp_reader_mapper(
reader=iterator, reader=iterator,
......
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
""" """
RotatE: RotatE:
"Learning entity and relation embeddings for knowledge graph completion." "RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space."
Lin, Yankai, et al. Sun, Zhiqing, et al.
https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9571/9523 https://arxiv.org/abs/1902.10197
""" """
import paddle.fluid as fluid import paddle.fluid as fluid
from .Model import Model from .Model import Model
......
...@@ -65,12 +65,16 @@ def mp_reader_mapper(reader, func, num_works=4): ...@@ -65,12 +65,16 @@ def mp_reader_mapper(reader, func, num_works=4):
all_process.append(p) all_process.append(p)
data_iter = reader() data_iter = reader()
if not hasattr(data_iter, "__next__"):
__next__ = data_iter.next
else:
__next__ = data_iter.__next__
def next_data(): def next_data():
"""next_data""" """next_data"""
_next = None _next = None
try: try:
_next = data_iter.next() _next = __next__()
except StopIteration: except StopIteration:
# log.debug(traceback.format_exc()) # log.debug(traceback.format_exc())
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册