Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
d2cefa43
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d2cefa43
编写于
4月 22, 2020
作者:
Z
ZHUI
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some typo problems
上级
920a0acc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
18 addition
and
11 deletion
+18
-11
examples/kg/data_loader.py
examples/kg/data_loader.py
+6
-5
examples/kg/evalutate.py
examples/kg/evalutate.py
+4
-2
examples/kg/model/RotatE.py
examples/kg/model/RotatE.py
+3
-3
examples/kg/mp_mapper.py
examples/kg/mp_mapper.py
+5
-1
未找到文件。
examples/kg/data_loader.py
浏览文件 @
d2cefa43
...
...
@@ -19,10 +19,11 @@ import os
import
numpy
as
np
from
collections
import
defaultdict
from
pgl.utils.logger
import
log
from
pybloom
import
BloomFilter
#from pybloom import BloomFilter
class
KBloader
:
class
KGLoader
:
"""
load the FB15K
"""
...
...
@@ -65,8 +66,9 @@ class KBloader:
def
training_data_no_filter
(
self
,
train_triple_positive
):
"""faster, no filter for exists triples"""
size
=
len
(
train_triple_positive
)
train_triple_negative
=
train_triple_positive
+
0
size
=
len
(
train_triple_positive
)
*
self
.
_neg_times
train_triple_negative
=
train_triple_positive
.
repeat
(
self
.
_neg_times
,
axis
=
0
)
replace_head_probability
=
0.5
*
np
.
ones
(
size
)
replace_entity_id
=
np
.
random
.
randint
(
self
.
entity_total
,
size
=
size
)
random_num
=
np
.
random
.
random
(
size
=
size
)
...
...
@@ -122,7 +124,6 @@ class KBloader:
"""
n
=
len
(
self
.
_triple_train
)
rand_idx
=
np
.
random
.
permutation
(
n
)
rand_idx
=
rand_idx
%
n
n_triple
=
len
(
rand_idx
)
start
=
0
while
start
<
n_triple
:
...
...
examples/kg/evalutate.py
浏览文件 @
d2cefa43
...
...
@@ -99,8 +99,10 @@ class Evaluate:
feed
=
batch_feed_dict
)
yield
batch_feed_dict
[
"test_triple"
],
head_score
,
tail_score
n_used_eval_triple
+=
1
print
(
'[{:.3f}s] #evaluation triple: {}/{}'
.
format
(
timeit
.
default_timer
()
-
start
,
n_used_eval_triple
,
5000
))
if
n_used_eval_triple
%
500
==
0
:
print
(
'[{:.3f}s] #evaluation triple: {}/{}'
.
format
(
timeit
.
default_timer
(
)
-
start
,
n_used_eval_triple
,
self
.
reader
.
test_num
))
res_reader
=
mp_reader_mapper
(
reader
=
iterator
,
...
...
examples/kg/model/RotatE.py
浏览文件 @
d2cefa43
...
...
@@ -13,9 +13,9 @@
# limitations under the License.
"""
RotatE:
"
Learning entity and relation embeddings for knowledge graph completion
."
Lin, Yankai
, et al.
https://
www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9571/9523
"
RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space
."
Sun, Zhiqing
, et al.
https://
arxiv.org/abs/1902.10197
"""
import
paddle.fluid
as
fluid
from
.Model
import
Model
...
...
examples/kg/mp_mapper.py
浏览文件 @
d2cefa43
...
...
@@ -65,12 +65,16 @@ def mp_reader_mapper(reader, func, num_works=4):
all_process
.
append
(
p
)
data_iter
=
reader
()
if
not
hasattr
(
data_iter
,
"__next__"
):
__next__
=
data_iter
.
next
else
:
__next__
=
data_iter
.
__next__
def
next_data
():
"""next_data"""
_next
=
None
try
:
_next
=
data_iter
.
next
()
_next
=
__next__
()
except
StopIteration
:
# log.debug(traceback.format_exc())
pass
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录