提交 4adeed38 编写于 作者: W wuxiyu

add L1 distance and fix some bug(althought not convergent)

上级 a404b3ed
from random import uniform, sample
from numpy import *
from copy import deepcopy
class TransE:
def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.01, dim = 10):
self.margin=margin
def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True):
self.margin = margin
self.learingRate = learingRate
self.dim = dim#向量维度
self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量。
self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量(使用narray)
self.relationList = relationList#理由同上
self.tripleList = tripleList#理由同上
self.loss = 0
self.L1 = L1
def initialize(self):
'''
......@@ -43,21 +45,21 @@ class TransE:
def transE(self, cI = 20):
print("训练开始")
for cycleIndex in range(cI):
if cycleIndex%10000==0:
print("第%d次循环"%cycleIndex)
print(self.loss)
self.loss = 0
self.writeRelationVector("c:\\relationVector.txt")
self.writeEntilyVector("c:\\entityVector.txt")
Sbatch = self.getSample()
Sbatch = self.getSample(150)
Tbatch = []#元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))}
for sbatch in Sbatch:
tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch))
if(tripletWithCorruptedTriplet not in Tbatch):
Tbatch.append(tripletWithCorruptedTriplet)
self.update(Tbatch)
if cycleIndex % 100 == 0:
print("第%d次循环"%cycleIndex)
print(self.loss)
self.writeRelationVector("c:\\relationVector.txt")
self.writeEntilyVector("c:\\entityVector.txt")
self.loss = 0
def getSample(self, size = 500):
def getSample(self, size):
return sample(self.tripleList, size)
def getCorruptedTriplet(self, triplet):
......@@ -72,59 +74,84 @@ class TransE:
entityTemp = sample(self.entityList.keys(), 1)[0]
if entityTemp != triplet[0]:
break
corruptedTriplet = (entityTemp,triplet[1], triplet[2])
corruptedTriplet = (entityTemp, triplet[1], triplet[2])
else:#大于等于0,打坏三元组的第二项
while True:
entityTemp = sample(self.entityList.keys(), 1)[0]
if entityTemp != triplet[1]:
break
corruptedTriplet = (triplet[0],entityTemp, triplet[2])
corruptedTriplet = (triplet[0], entityTemp, triplet[2])
return corruptedTriplet
def update(self, Tbatch):
i = 0
while i < len(Tbatch):
tripletWithCorruptedTriplet = Tbatch[i]
headEntityVector = array(self.entityList[tripletWithCorruptedTriplet[0][0]])#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
tailEntityVector = array(self.entityList[tripletWithCorruptedTriplet[0][1]])
relationVector = array(self.relationList[tripletWithCorruptedTriplet[0][2]])
headEntityVectorWithCorruptedTriplet = array(self.entityList[tripletWithCorruptedTriplet[1][0]])
tailEntityVectorWithCorruptedTriplet = array(self.entityList[tripletWithCorruptedTriplet[1][1]])
distTriplet = distance(headEntityVector,tailEntityVector , relationVector)
distCorruptedTriplet = distance(headEntityVectorWithCorruptedTriplet,tailEntityVectorWithCorruptedTriplet , relationVector)
copyEntityList = deepcopy(self.entityList)
copyRelationList = deepcopy(self.relationList)
for tripletWithCorruptedTriplet in Tbatch:
headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
tailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]]
relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]]
headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]]
tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]]
headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple
tailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]]
relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]]
headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]]
tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]]
if self.L1:
distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)
else:
distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch)
distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch)
eg = self.margin + distTriplet - distCorruptedTriplet
if eg > 0: #[function]+ 是一个取正值的函数
self.loss += eg
tempPositive = 2 * self.learingRate * (tailEntityVector - headEntityVector - relationVector)
tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTriplet - headEntityVectorWithCorruptedTriplet - relationVector)
temp1 = headEntityVector + tempPositive
temp2 = tailEntityVector - tempPositive
temp3 = relationVector + tempPositive - tempNegtative
temp4 = headEntityVectorWithCorruptedTriplet - tempNegtative
temp5 = tailEntityVectorWithCorruptedTriplet + tempNegtative
headEntityVector = temp1
tailEntityVector = temp2
relationVector = temp3
headEntityVectorWithCorruptedTriplet = temp4
tailEntityVectorWithCorruptedTriplet = temp5
if self.L1:
tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)
tempPositiveL1 = []
tempNegtativeL1 = []
for i in range(self.dim):#不知道有没有pythonic的写法(比如列表推倒或者numpy的函数)?
if tempPositive[i] >= 0:
tempPositiveL1.append(1)
else:
tempPositiveL1.append(-1)
if tempNegtative[i] >= 0:
tempNegtativeL1.append(1)
else:
tempNegtativeL1.append(-1)
tempPositive = array(tempPositiveL1)
tempNegtative = array(tempNegtativeL1)
else:
tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch)
tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch)
headEntityVector = headEntityVector + tempPositive
tailEntityVector = tailEntityVector - tempPositive
relationVector = relationVector + tempPositive - tempNegtative
headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative
tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative
#只归一化这几个刚更新的向量,而不是按原论文那些一口气全更新了
self.entityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector.tolist())
self.entityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector.tolist())
self.relationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector.tolist())
self.entityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet.tolist())
self.entityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet.tolist())
i += 1
copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector)
copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector)
copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector)
copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet)
copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet)
self.entityList = copyEntityList
self.relationList = copyRelationList
def writeEntilyVector(self, dir):
print("写入实体")
entityVectorFile = open(dir, 'w')
for entity in self.entityList.keys():
entityVectorFile.write(entity+"\t")
entityVectorFile.write(str(self.entityList[entity]))
entityVectorFile.write(str(self.entityList[entity].tolist()))
entityVectorFile.write("\n")
entityVectorFile.close()
......@@ -133,24 +160,27 @@ class TransE:
relationVectorFile = open(dir, 'w')
for relation in self.relationList.keys():
relationVectorFile.write(relation + "\t")
relationVectorFile.write(str(self.relationList[relation]))
relationVectorFile.write(str(self.relationList[relation].tolist()))
relationVectorFile.write("\n")
relationVectorFile.close()
def init(dim):
return uniform(-6/(dim**0.5),6/(dim**0.5))
return uniform(-6/(dim**0.5), 6/(dim**0.5))
def distance(h, t, r):
def distanceL1(h, t ,r):
s = h + r - t
narray = array(s)
narray2 = narray*narray
sum = narray2.sum()
sum = fabs(s).sum()
return sum
def distanceL2(h, t, r):
s = h + r - t
sum = (s*s).sum()
return sum
def norm(list):
'''
归一化
:param 向量:
:param 向量
:return: 向量的平方和的开方后的向量
'''
var = linalg.norm(list)
......@@ -158,7 +188,7 @@ def norm(list):
while i < len(list):
list[i] = list[i]/var
i += 1
return list
return array(list)
def openDetailsAndId(dir,sp="\t"):
idNum = 0
......@@ -192,9 +222,10 @@ if __name__ == '__main__':
dirTrain = "C:\\data\\train.txt"
tripleNum, tripleList = openTrain(dirTrain)
print("打开TransE")
transE = TransE(entityList,relationList,tripleList,dim = 30)
transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100)
print("TranE初始化")
transE.initialize()
transE.transE(300000)
transE.transE(15000)
transE.writeRelationVector("c:\\relationVector.txt")
transE.writeEntilyVector("c:\\entityVector.txt")
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册