Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wdafggrt
transe
提交
02f9c820
T
transe
项目概览
wdafggrt
/
transe
与 Fork 源项目一致
从无法访问的项目Fork
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
transe
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
02f9c820
编写于
4月 06, 2016
作者:
W
wuxiyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
commit
上级
eb195a19
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
477 addition
and
0 deletion
+477
-0
pca.py
pca.py
+101
-0
reTranE.py
reTranE.py
+23
-0
test.py
test.py
+152
-0
tranE.py
tranE.py
+201
-0
未找到文件。
pca.py
0 → 100644
浏览文件 @
02f9c820
from
numpy
import
*
import
matplotlib.pyplot
as
plt
import
pylab
from
sklearn.decomposition
import
PCA
def
loadData
(
str
):
fr
=
open
(
str
)
sArr
=
[
line
.
strip
().
split
(
"
\t
"
)
for
line
in
fr
.
readlines
()]
datArr
=
[[
float
(
s
)
for
s
in
line
[
1
][
1
:
-
1
].
split
(
", "
)]
for
line
in
sArr
]
matA
=
mat
(
datArr
)
print
(
matA
.
shape
)
nameArr
=
[
line
[
0
]
for
line
in
sArr
]
return
matA
,
nameArr
def
getEig
(
inputM
):
covM
=
cov
(
inputM
,
rowvar
=
0
)
s
,
V
=
linalg
.
eig
(
covM
)
return
s
,
V
def
judge
(
s
):
s
.
sort
()
s
=
s
[::
-
1
]
bili
=
[]
i
=
0
sum1
=
0.0
sum2
=
s
[
0
]
while
i
<
len
(
s
)
-
1
:
sum1
=
sum1
+
s
[
i
]
sum2
=
sum2
+
s
[
i
+
1
]
bili
.
append
(
sum1
/
sum2
)
i
+=
1
plt
.
plot
(
range
(
len
(
bili
)),
bili
,
'b*'
)
plt
.
plot
(
range
(
len
(
bili
)),
bili
,
'r'
)
for
xy
in
zip
(
range
(
len
(
bili
)),
bili
):
plt
.
annotate
(
xy
[
1
],
xy
=
xy
,
xytext
=
(
-
20
,
10
),
textcoords
=
'offset points'
)
plt
.
xlabel
(
"eigenvector"
)
plt
.
ylabel
(
"eigenvalue"
)
plt
.
title
(
'fangchabili'
)
plt
.
legend
()
plt
.
show
()
return
bili
def
getbaifenbi
(
bili
,
num
):
i
=
1
for
b
in
bili
:
if
b
>
num
:
break
i
+=
1
return
i
def
pca
(
inputM
,
k
):
covM
=
cov
(
inputM
,
rowvar
=
0
)
s
,
V
=
linalg
.
eig
(
covM
)
paixu
=
argsort
(
s
)
paixuk
=
paixu
[:
-
(
k
+
1
):
-
1
]
kwei
=
V
[:,
paixuk
]
outputM
=
inputM
*
kwei
chonggou
=
(
outputM
*
kwei
.
T
)
return
outputM
,
chonggou
def
plotV
(
a
,
labels
):
fig
=
plt
.
figure
()
ax
=
fig
.
add_subplot
(
111
)
print
(
"aaa"
)
font
=
{
'fontname'
:
'Tahoma'
,
'fontsize'
:
0.5
,
'verticalalignment'
:
'top'
,
'horizontalalignment'
:
'center'
}
ax
.
scatter
(
a
[:,
0
],
a
[:,
1
],
marker
=
' '
)
ax
.
set_xlim
(
-
0.8
,
0.8
)
ax
.
set_ylim
(
-
0.8
,
0.8
)
i
=
0
for
label
,
x
,
y
in
zip
(
labels
,
a
[:,
0
],
a
[:,
1
]):
i
+=
1
s
=
random
.
uniform
(
0
,
100
)
if
i
<
14951
:
if
s
>
3.1
:
continue
else
:
if
s
>
6.7
:
continue
ax
.
annotate
(
label
,
xy
=
(
x
,
y
),
xytext
=
None
,
ha
=
'right'
,
va
=
'bottom'
,
**
font
)
#,textcoords = 'offset points',bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
# #arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))
plt
.
title
(
'TransE pca2dim'
)
plt
.
xlabel
(
'X'
)
plt
.
ylabel
(
'Y'
)
print
(
"ddd"
)
plt
.
savefig
(
'plot_with_labels'
,
dpi
=
3000
,
bbox_inches
=
'tight'
,
orientation
=
'landscape'
,
papertype
=
'a0'
)
if
__name__
==
'__main__'
:
dirEntity
=
"c:
\\
entityVector.txt"
dirRelation
=
"c:
\\
relationVector.txt"
matEntity
,
nameEntity
=
loadData
(
dirEntity
)
matRelation
,
nameRelation
=
loadData
(
dirRelation
)
mat
=
row_stack
((
matEntity
,
matRelation
))
print
(
mat
.
shape
)
nameEntity
.
extend
(
nameRelation
)
#s,V = getEig(mat)
#bili= judge(s)
#k = getbaifenbi(bili, 0.9)
k
=
2
a
,
b
=
pca
(
mat
,
k
)
plotV
(
a
,
nameEntity
)
reTranE.py
0 → 100644
浏览文件 @
02f9c820
from
tranE
import
*
def
loadData
(
str
):
fr
=
open
(
str
)
sArr
=
[
line
.
strip
().
split
(
"
\t
"
)
for
line
in
fr
.
readlines
()]
datArr
=
[[
float
(
s
)
for
s
in
line
[
1
][
1
:
-
1
].
split
(
", "
)]
for
line
in
sArr
]
nameArr
=
[
line
[
0
]
for
line
in
sArr
]
dic
=
{}
for
name
,
vec
in
zip
(
nameArr
,
datArr
):
dic
[
name
]
=
vec
return
dic
if
__name__
==
'__main__'
:
dirEntityVector
=
"c:
\\
entityVector.txt"
entityList
=
loadData
(
dirEntityVector
)
dirRelationVector
=
"c:
\\
relationVector.txt"
relationList
=
loadData
(
dirRelationVector
)
dirTrain
=
"C:
\\
data
\\
train.txt"
tripleNum
,
tripleList
=
openTrain
(
dirTrain
)
transE
=
TransE
(
entityList
,
relationList
,
tripleList
,
learingRate
=
0.001
,
dim
=
30
)
transE
.
transE
(
100000
)
transE
.
writeRelationVector
(
"c:
\\
relationVector.txt"
)
transE
.
writeEntilyVector
(
"c:
\\
entityVector.txt"
)
\ No newline at end of file
test.py
0 → 100644
浏览文件 @
02f9c820
from
numpy
import
*
import
operator
class
Test
:
def
__init__
(
self
,
entityList
,
entityVectorList
,
relationList
,
relationVectorList
,
tripleListTrain
,
tripleListTest
,
label
=
"head"
,
isFit
=
False
):
self
.
entityList
=
{}
self
.
relationList
=
{}
for
name
,
vec
in
zip
(
entityList
,
entityVectorList
):
self
.
entityList
[
name
]
=
vec
for
name
,
vec
in
zip
(
relationList
,
relationVectorList
):
self
.
relationList
[
name
]
=
vec
self
.
tripleListTrain
=
tripleListTrain
self
.
tripleListTest
=
tripleListTest
self
.
rank
=
[]
self
.
label
=
label
self
.
isFit
=
isFit
def
writeRank
(
self
,
dir
):
print
(
"写入"
)
file
=
open
(
dir
,
'w'
)
for
r
in
self
.
rank
:
file
.
write
(
str
(
r
[
0
])
+
"
\t
"
)
file
.
write
(
str
(
r
[
1
])
+
"
\t
"
)
file
.
write
(
str
(
r
[
2
])
+
"
\t
"
)
file
.
write
(
str
(
r
[
3
])
+
"
\n
"
)
file
.
close
()
def
getRank
(
self
):
cou
=
0
for
triplet
in
self
.
tripleListTest
:
rankList
=
{}
for
entityTemp
in
self
.
entityList
.
keys
():
if
self
.
label
==
"head"
:
corruptedTriplet
=
(
entityTemp
,
triplet
[
1
],
triplet
[
2
])
if
self
.
isFit
and
(
corruptedTriplet
in
self
.
tripleListTrain
):
continue
rankList
[
entityTemp
]
=
distance
(
self
.
entityList
[
entityTemp
],
self
.
entityList
[
triplet
[
1
]],
self
.
relationList
[
triplet
[
2
]])
else
:
#
corruptedTriplet
=
(
triplet
[
0
],
entityTemp
,
triplet
[
2
])
if
self
.
isFit
and
(
corruptedTriplet
in
self
.
tripleListTrain
):
continue
rankList
[
entityTemp
]
=
distance
(
self
.
entityList
[
triplet
[
0
]],
self
.
entityList
[
entityTemp
],
self
.
relationList
[
triplet
[
2
]])
nameRank
=
sorted
(
rankList
.
items
(),
key
=
operator
.
itemgetter
(
1
))
if
self
.
label
==
'head'
:
numTri
=
0
else
:
numTri
=
1
x
=
1
for
i
in
nameRank
:
if
i
[
0
]
==
triplet
[
numTri
]:
break
x
+=
1
self
.
rank
.
append
((
triplet
,
triplet
[
numTri
],
nameRank
[
0
][
0
],
x
))
print
(
x
)
cou
+=
1
if
cou
%
10000
==
0
:
print
(
cou
)
def
getRelationRank
(
self
):
cou
=
0
self
.
rank
=
[]
for
triplet
in
self
.
tripleListTest
:
rankList
=
{}
for
relationTemp
in
self
.
relationList
.
keys
():
corruptedTriplet
=
(
triplet
[
0
],
triplet
[
1
],
relationTemp
)
if
self
.
isFit
and
(
corruptedTriplet
in
self
.
tripleListTrain
):
continue
rankList
[
relationTemp
]
=
distance
(
self
.
entityList
[
triplet
[
0
]],
self
.
entityList
[
triplet
[
1
]],
self
.
relationList
[
relationTemp
])
nameRank
=
sorted
(
rankList
.
items
(),
key
=
operator
.
itemgetter
(
1
))
x
=
1
for
i
in
nameRank
:
if
i
[
0
]
==
triplet
[
2
]:
break
x
+=
1
self
.
rank
.
append
((
triplet
,
triplet
[
2
],
nameRank
[
0
][
0
],
x
))
print
(
x
)
cou
+=
1
if
cou
%
10000
==
0
:
print
(
cou
)
def
getMeanRank
(
self
):
num
=
0
for
r
in
self
.
rank
:
num
+=
r
[
3
]
return
num
/
len
(
self
.
rank
)
def
distance
(
h
,
t
,
r
):
h
=
array
(
h
)
t
=
array
(
t
)
r
=
array
(
r
)
s
=
h
+
r
-
t
return
linalg
.
norm
(
s
)
def
openD
(
dir
,
sp
=
"
\t
"
):
#triple = (head, tail, relation)
num
=
0
list
=
[]
with
open
(
dir
)
as
file
:
lines
=
file
.
readlines
()
for
line
in
lines
:
triple
=
line
.
strip
().
split
(
sp
)
if
(
len
(
triple
)
<
3
):
continue
list
.
append
(
tuple
(
triple
))
num
+=
1
print
(
num
)
return
num
,
list
def
loadData
(
str
):
fr
=
open
(
str
)
sArr
=
[
line
.
strip
().
split
(
"
\t
"
)
for
line
in
fr
.
readlines
()]
datArr
=
[[
float
(
s
)
for
s
in
line
[
1
][
1
:
-
1
].
split
(
", "
)]
for
line
in
sArr
]
nameArr
=
[
line
[
0
]
for
line
in
sArr
]
return
datArr
,
nameArr
if
__name__
==
'__main__'
:
dirTrain
=
"C:
\\
data
\\
train.txt"
tripleNumTrain
,
tripleListTrain
=
openD
(
dirTrain
)
dirTest
=
"C:
\\
data
\\
test.txt"
tripleNumTest
,
tripleListTest
=
openD
(
dirTest
)
dirEntityVector
=
"c:
\\
entityVector.txt"
entityVectorList
,
entityList
=
loadData
(
dirEntityVector
)
dirRelationVector
=
"c:
\\
relationVector.txt"
relationVectorList
,
relationList
=
loadData
(
dirRelationVector
)
print
(
"kaishitest"
)
testHeadRaw
=
Test
(
entityList
,
entityVectorList
,
relationList
,
relationVectorList
,
tripleListTrain
,
tripleListTest
)
testHeadRaw
.
getRank
()
print
(
testHeadRaw
.
getMeanRank
())
testHeadRaw
.
writeRank
(
"c:
\\
"
+
"testHeadRaw"
+
".txt"
)
testHeadRaw
.
getRelationRank
()
print
(
testHeadRaw
.
getMeanRank
())
testHeadRaw
.
writeRank
(
"c:
\\
"
+
"testRelationRaw"
+
".txt"
)
testTailRaw
=
Test
(
entityList
,
entityVectorList
,
relationList
,
relationVectorList
,
tripleListTrain
,
tripleListTest
,
label
=
"tail"
)
testTailRaw
.
getRank
()
print
(
testTailRaw
.
getMeanRank
())
testTailRaw
.
writeRank
(
"c:
\\
"
+
"testTailRaw"
+
".txt"
)
testHeadFit
=
Test
(
entityList
,
entityVectorList
,
relationList
,
relationVectorList
,
tripleListTrain
,
tripleListTest
,
isFit
=
True
)
testHeadFit
.
getRank
()
print
(
testHeadFit
.
getMeanRank
())
testHeadFit
.
writeRank
(
"c:
\\
"
+
"testHeadFit"
+
".txt"
)
testHeadFit
.
getRelationRank
()
print
(
testHeadFit
.
getMeanRank
())
testHeadFit
.
writeRank
(
"c:
\\
"
+
"testRelationFit"
+
".txt"
)
testTailFit
=
Test
(
entityList
,
entityVectorList
,
relationList
,
relationVectorList
,
tripleListTrain
,
tripleListTest
,
isFit
=
True
,
label
=
"tail"
)
testTailFit
.
getRank
()
print
(
testTailFit
.
getMeanRank
())
testTailFit
.
writeRank
(
"c:
\\
"
+
"testTailFit"
+
".txt"
)
\ No newline at end of file
tranE.py
0 → 100644
浏览文件 @
02f9c820
from
random
import
uniform
,
sample
from
numpy
import
*
class
TransE
:
def
__init__
(
self
,
entityList
,
relationList
,
tripleList
,
margin
=
1
,
learingRate
=
0.01
,
dim
=
10
,
lambd
=
0.4
):
self
.
margin
=
margin
self
.
learingRate
=
learingRate
self
.
dim
=
dim
#向量维度
self
.
entityList
=
entityList
#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量。
self
.
relationList
=
relationList
#理由同上
self
.
tripleList
=
tripleList
#理由同上
self
.
lambd
=
lambd
self
.
loss
=
0
def
initialize
(
self
):
'''
初始化向量
'''
entityVectorList
=
{}
relationVectorList
=
{}
for
entity
in
self
.
entityList
:
n
=
0
entityVector
=
[]
while
n
<
self
.
dim
:
ram
=
init
(
self
.
dim
)
#初始化的范围
entityVector
.
append
(
ram
)
n
+=
1
entityVector
=
norm
(
entityVector
)
#归一化
entityVectorList
[
entity
]
=
entityVector
print
(
"entityVector初始化完成,数量是%d"
%
len
(
entityVectorList
))
for
relation
in
self
.
relationList
:
n
=
0
relationVector
=
[]
while
n
<
self
.
dim
:
ram
=
init
(
self
.
dim
)
#初始化的范围
relationVector
.
append
(
ram
)
n
+=
1
relationVector
=
norm
(
relationVector
)
#归一化
relationVectorList
[
relation
]
=
relationVector
print
(
"relationVectorList初始化完成,数量是%d"
%
len
(
relationVectorList
))
self
.
entityList
=
entityVectorList
self
.
relationList
=
relationVectorList
def
transE
(
self
,
cI
=
20
):
print
(
"训练开始"
)
for
cycleIndex
in
range
(
cI
):
if
cycleIndex
%
10000
==
0
:
print
(
"第%d次循环"
%
cycleIndex
)
print
(
self
.
loss
)
self
.
loss
=
0
self
.
writeRelationVector
(
"c:
\\
relationVector.txt"
)
self
.
writeEntilyVector
(
"c:
\\
entityVector.txt"
)
Sbatch
=
self
.
getSample
()
Tbatch
=
[]
#元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))}
for
sbatch
in
Sbatch
:
tripletWithCorruptedTriplet
=
(
sbatch
,
self
.
getCorruptedTriplet
(
sbatch
))
if
(
tripletWithCorruptedTriplet
not
in
Tbatch
):
Tbatch
.
append
(
tripletWithCorruptedTriplet
)
self
.
update
(
Tbatch
)
def
getSample
(
self
,
size
=
500
):
return
sample
(
self
.
tripleList
,
size
)
def
getCorruptedTriplet
(
self
,
triplet
):
'''
training triplets with either the head or tail replaced by a random entity (but not both at the same time)
:param triplet:
:return corruptedTriplet:
'''
i
=
uniform
(
-
1
,
1
)
if
i
<
0
:
#小于0,打坏三元组的第一项
while
True
:
entityTemp
=
sample
(
self
.
entityList
.
keys
(),
1
)[
0
]
if
entityTemp
!=
triplet
[
0
]:
break
corruptedTriplet
=
(
entityTemp
,
triplet
[
1
],
triplet
[
2
])
else
:
#大于等于0,打坏三元组的第二项
while
True
:
entityTemp
=
sample
(
self
.
entityList
.
keys
(),
1
)[
0
]
if
entityTemp
!=
triplet
[
1
]:
break
corruptedTriplet
=
(
triplet
[
0
],
entityTemp
,
triplet
[
2
])
return
corruptedTriplet
def
update
(
self
,
Tbatch
):
i
=
0
while
i
<
len
(
Tbatch
):
tripletWithCorruptedTriplet
=
Tbatch
[
i
]
headEntityVector
=
array
(
self
.
entityList
[
tripletWithCorruptedTriplet
[
0
][
0
]])
#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
tailEntityVector
=
array
(
self
.
entityList
[
tripletWithCorruptedTriplet
[
0
][
1
]])
relationVector
=
array
(
self
.
relationList
[
tripletWithCorruptedTriplet
[
0
][
2
]])
headEntityVectorWithCorruptedTriplet
=
array
(
self
.
entityList
[
tripletWithCorruptedTriplet
[
1
][
0
]])
tailEntityVectorWithCorruptedTriplet
=
array
(
self
.
entityList
[
tripletWithCorruptedTriplet
[
1
][
1
]])
distTriplet
=
distance
(
headEntityVector
,
tailEntityVector
,
relationVector
)
distCorruptedTriplet
=
distance
(
headEntityVectorWithCorruptedTriplet
,
tailEntityVectorWithCorruptedTriplet
,
relationVector
)
eg
=
self
.
margin
+
distTriplet
-
distCorruptedTriplet
if
eg
>
0
:
#[function]+ 是一个取正值的函数
self
.
loss
+=
eg
tempPositive
=
2
*
self
.
learingRate
*
(
tailEntityVector
-
headEntityVector
-
relationVector
)
tempNegtative
=
2
*
self
.
learingRate
*
(
tailEntityVectorWithCorruptedTriplet
-
headEntityVectorWithCorruptedTriplet
-
relationVector
)
temp1
=
headEntityVector
+
tempPositive
temp2
=
tailEntityVector
-
tempPositive
temp3
=
relationVector
+
tempPositive
-
tempNegtative
temp4
=
headEntityVectorWithCorruptedTriplet
-
tempNegtative
temp5
=
tailEntityVectorWithCorruptedTriplet
+
tempNegtative
headEntityVector
=
temp1
tailEntityVector
=
temp2
relationVector
=
temp3
headEntityVectorWithCorruptedTriplet
=
temp4
tailEntityVectorWithCorruptedTriplet
=
temp5
#只归一化这几个刚更新的向量,而不是按原论文那些一口气全更新了
self
.
entityList
[
tripletWithCorruptedTriplet
[
0
][
0
]]
=
norm
(
headEntityVector
.
tolist
())
self
.
entityList
[
tripletWithCorruptedTriplet
[
0
][
1
]]
=
norm
(
tailEntityVector
.
tolist
())
self
.
relationList
[
tripletWithCorruptedTriplet
[
0
][
2
]]
=
norm
(
relationVector
.
tolist
())
self
.
entityList
[
tripletWithCorruptedTriplet
[
1
][
0
]]
=
norm
(
headEntityVectorWithCorruptedTriplet
.
tolist
())
self
.
entityList
[
tripletWithCorruptedTriplet
[
1
][
1
]]
=
norm
(
tailEntityVectorWithCorruptedTriplet
.
tolist
())
i
+=
1
def
writeEntilyVector
(
self
,
dir
):
print
(
"写入实体"
)
entityVectorFile
=
open
(
dir
,
'w'
)
for
entity
in
self
.
entityList
.
keys
():
entityVectorFile
.
write
(
entity
+
"
\t
"
)
entityVectorFile
.
write
(
str
(
self
.
entityList
[
entity
]))
entityVectorFile
.
write
(
"
\n
"
)
entityVectorFile
.
close
()
def
writeRelationVector
(
self
,
dir
):
print
(
"写入关系"
)
relationVectorFile
=
open
(
dir
,
'w'
)
for
relation
in
self
.
relationList
.
keys
():
relationVectorFile
.
write
(
relation
+
"
\t
"
)
relationVectorFile
.
write
(
str
(
self
.
relationList
[
relation
]))
relationVectorFile
.
write
(
"
\n
"
)
relationVectorFile
.
close
()
def
init
(
dim
):
return
uniform
(
-
6
/
(
dim
**
0.5
),
6
/
(
dim
**
0.5
))
def
distance
(
h
,
t
,
r
):
s
=
h
+
r
-
t
narray
=
array
(
s
)
narray2
=
narray
*
narray
sum
=
narray2
.
sum
()
return
sum
def
norm
(
list
):
'''
归一化
:param 向量:
:return: 向量的平方和的开方后的向量
'''
var
=
linalg
.
norm
(
list
)
i
=
0
while
i
<
len
(
list
):
list
[
i
]
=
list
[
i
]
/
var
i
+=
1
return
list
def
openDetailsAndId
(
dir
,
sp
=
"
\t
"
):
idNum
=
0
list
=
[]
with
open
(
dir
)
as
file
:
lines
=
file
.
readlines
()
for
line
in
lines
:
DetailsAndId
=
line
.
strip
().
split
(
sp
)
list
.
append
(
DetailsAndId
[
0
])
idNum
+=
1
return
idNum
,
list
def
openTrain
(
dir
,
sp
=
"
\t
"
):
num
=
0
list
=
[]
with
open
(
dir
)
as
file
:
lines
=
file
.
readlines
()
for
line
in
lines
:
triple
=
line
.
strip
().
split
(
sp
)
if
(
len
(
triple
)
<
3
):
continue
list
.
append
(
tuple
(
triple
))
num
+=
1
return
num
,
list
if
__name__
==
'__main__'
:
dirEntity
=
"C:
\\
data
\\
entity2id.txt"
entityIdNum
,
entityList
=
openDetailsAndId
(
dirEntity
)
dirRelation
=
"C:
\\
data
\\
relation2id.txt"
relationIdNum
,
relationList
=
openDetailsAndId
(
dirRelation
)
dirTrain
=
"C:
\\
data
\\
train.txt"
tripleNum
,
tripleList
=
openTrain
(
dirTrain
)
print
(
"打开TransE"
)
transE
=
TransE
(
entityList
,
relationList
,
tripleList
,
dim
=
30
)
print
(
"TranE初始化"
)
transE
.
initialize
()
transE
.
transE
(
300000
)
transE
.
writeRelationVector
(
"c:
\\
relationVector.txt"
)
transE
.
writeEntilyVector
(
"c:
\\
entityVector.txt"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录