提交 d2cefa43 编写于 作者: Z ZHUI

fix some typo problems

上级 920a0acc
......@@ -19,10 +19,11 @@ import os
import numpy as np
from collections import defaultdict
from pgl.utils.logger import log
from pybloom import BloomFilter
#from pybloom import BloomFilter
class KBloader:
class KGLoader:
"""
load the FB15K
"""
......@@ -65,8 +66,9 @@ class KBloader:
def training_data_no_filter(self, train_triple_positive):
"""faster, no filter for exists triples"""
size = len(train_triple_positive)
train_triple_negative = train_triple_positive + 0
size = len(train_triple_positive) * self._neg_times
train_triple_negative = train_triple_positive.repeat(
self._neg_times, axis=0)
replace_head_probability = 0.5 * np.ones(size)
replace_entity_id = np.random.randint(self.entity_total, size=size)
random_num = np.random.random(size=size)
......@@ -122,7 +124,6 @@ class KBloader:
"""
n = len(self._triple_train)
rand_idx = np.random.permutation(n)
rand_idx = rand_idx % n
n_triple = len(rand_idx)
start = 0
while start < n_triple:
......
......@@ -99,8 +99,10 @@ class Evaluate:
feed=batch_feed_dict)
yield batch_feed_dict["test_triple"], head_score, tail_score
n_used_eval_triple += 1
print('[{:.3f}s] #evaluation triple: {}/{}'.format(
timeit.default_timer() - start, n_used_eval_triple, 5000))
if n_used_eval_triple % 500 == 0:
print('[{:.3f}s] #evaluation triple: {}/{}'.format(
timeit.default_timer(
) - start, n_used_eval_triple, self.reader.test_num))
res_reader = mp_reader_mapper(
reader=iterator,
......
......@@ -13,9 +13,9 @@
# limitations under the License.
"""
RotatE:
"Learning entity and relation embeddings for knowledge graph completion."
Lin, Yankai, et al.
https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9571/9523
"RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space."
Sun, Zhiqing, et al.
https://arxiv.org/abs/1902.10197
"""
import paddle.fluid as fluid
from .Model import Model
......
......@@ -65,12 +65,16 @@ def mp_reader_mapper(reader, func, num_works=4):
all_process.append(p)
data_iter = reader()
if not hasattr(data_iter, "__next__"):
__next__ = data_iter.next
else:
__next__ = data_iter.__next__
def next_data():
"""next_data"""
_next = None
try:
_next = data_iter.next()
_next = __next__()
except StopIteration:
# log.debug(traceback.format_exc())
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册