diff --git a/tranE.py b/tranE.py index f76a4deafaa215f942ef5175d29e109ee9aeda59..c3c703c5f64e58e5819880f88221dea73fcf012d 100644 --- a/tranE.py +++ b/tranE.py @@ -1,15 +1,17 @@ from random import uniform, sample from numpy import * +from copy import deepcopy class TransE: - def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.01, dim = 10): - self.margin=margin + def __init__(self, entityList, relationList, tripleList, margin = 1, learingRate = 0.00001, dim = 10, L1 = True): + self.margin = margin self.learingRate = learingRate self.dim = dim#向量维度 - self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量。 + self.entityList = entityList#一开始,entityList是entity的list;初始化后,变为字典,key是entity,values是其向量(使用narray)。 self.relationList = relationList#理由同上 self.tripleList = tripleList#理由同上 self.loss = 0 + self.L1 = L1 def initialize(self): ''' @@ -43,21 +45,21 @@ class TransE: def transE(self, cI = 20): print("训练开始") for cycleIndex in range(cI): - if cycleIndex%10000==0: - print("第%d次循环"%cycleIndex) - print(self.loss) - self.loss = 0 - self.writeRelationVector("c:\\relationVector.txt") - self.writeEntilyVector("c:\\entityVector.txt") - Sbatch = self.getSample() + Sbatch = self.getSample(150) Tbatch = []#元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))} for sbatch in Sbatch: tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch)) if(tripletWithCorruptedTriplet not in Tbatch): Tbatch.append(tripletWithCorruptedTriplet) self.update(Tbatch) + if cycleIndex % 100 == 0: + print("第%d次循环"%cycleIndex) + print(self.loss) + self.writeRelationVector("c:\\relationVector.txt") + self.writeEntilyVector("c:\\entityVector.txt") + self.loss = 0 - def getSample(self, size = 500): + def getSample(self, size): return sample(self.tripleList, size) def getCorruptedTriplet(self, triplet): @@ -72,59 +74,84 @@ class TransE: entityTemp = sample(self.entityList.keys(), 1)[0] if entityTemp != triplet[0]: break - corruptedTriplet = (entityTemp,triplet[1], triplet[2]) + corruptedTriplet = (entityTemp, triplet[1], triplet[2]) else:#大于等于0,打坏三元组的第二项 while True: entityTemp = sample(self.entityList.keys(), 1)[0] if entityTemp != triplet[1]: break - corruptedTriplet = (triplet[0],entityTemp, triplet[2]) + corruptedTriplet = (triplet[0], entityTemp, triplet[2]) return corruptedTriplet def update(self, Tbatch): - i = 0 - while i < len(Tbatch): - tripletWithCorruptedTriplet = Tbatch[i] - headEntityVector = array(self.entityList[tripletWithCorruptedTriplet[0][0]])#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple - tailEntityVector = array(self.entityList[tripletWithCorruptedTriplet[0][1]]) - relationVector = array(self.relationList[tripletWithCorruptedTriplet[0][2]]) - headEntityVectorWithCorruptedTriplet = array(self.entityList[tripletWithCorruptedTriplet[1][0]]) - tailEntityVectorWithCorruptedTriplet = array(self.entityList[tripletWithCorruptedTriplet[1][1]]) - - distTriplet = distance(headEntityVector,tailEntityVector , relationVector) - distCorruptedTriplet = distance(headEntityVectorWithCorruptedTriplet,tailEntityVectorWithCorruptedTriplet , relationVector) + copyEntityList = deepcopy(self.entityList) + copyRelationList = deepcopy(self.relationList) + + for tripletWithCorruptedTriplet in Tbatch: + headEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple + tailEntityVector = copyEntityList[tripletWithCorruptedTriplet[0][1]] + relationVector = copyRelationList[tripletWithCorruptedTriplet[0][2]] + headEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][0]] + tailEntityVectorWithCorruptedTriplet = copyEntityList[tripletWithCorruptedTriplet[1][1]] + + headEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][0]]#tripletWithCorruptedTriplet是原三元组和打碎的三元组的元组tuple + tailEntityVectorBeforeBatch = self.entityList[tripletWithCorruptedTriplet[0][1]] + relationVectorBeforeBatch = self.relationList[tripletWithCorruptedTriplet[0][2]] + headEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][0]] + tailEntityVectorWithCorruptedTripletBeforeBatch = self.entityList[tripletWithCorruptedTriplet[1][1]] + + if self.L1: + distTriplet = distanceL1(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch) + distCorruptedTriplet = distanceL1(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch) + else: + distTriplet = distanceL2(headEntityVectorBeforeBatch, tailEntityVectorBeforeBatch, relationVectorBeforeBatch) + distCorruptedTriplet = distanceL2(headEntityVectorWithCorruptedTripletBeforeBatch, tailEntityVectorWithCorruptedTripletBeforeBatch , relationVectorBeforeBatch) eg = self.margin + distTriplet - distCorruptedTriplet if eg > 0: #[function]+ 是一个取正值的函数 self.loss += eg - tempPositive = 2 * self.learingRate * (tailEntityVector - headEntityVector - relationVector) - tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTriplet - headEntityVectorWithCorruptedTriplet - relationVector) - - temp1 = headEntityVector + tempPositive - temp2 = tailEntityVector - tempPositive - temp3 = relationVector + tempPositive - tempNegtative - temp4 = headEntityVectorWithCorruptedTriplet - tempNegtative - temp5 = tailEntityVectorWithCorruptedTriplet + tempNegtative - - headEntityVector = temp1 - tailEntityVector = temp2 - relationVector = temp3 - headEntityVectorWithCorruptedTriplet = temp4 - tailEntityVectorWithCorruptedTriplet = temp5 + if self.L1: + tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch) + tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch) + tempPositiveL1 = [] + tempNegtativeL1 = [] + for i in range(self.dim):#不知道有没有pythonic的写法(比如列表推倒或者numpy的函数)? + if tempPositive[i] >= 0: + tempPositiveL1.append(1) + else: + tempPositiveL1.append(-1) + if tempNegtative[i] >= 0: + tempNegtativeL1.append(1) + else: + tempNegtativeL1.append(-1) + tempPositive = array(tempPositiveL1) + tempNegtative = array(tempNegtativeL1) + + else: + tempPositive = 2 * self.learingRate * (tailEntityVectorBeforeBatch - headEntityVectorBeforeBatch - relationVectorBeforeBatch) + tempNegtative = 2 * self.learingRate * (tailEntityVectorWithCorruptedTripletBeforeBatch - headEntityVectorWithCorruptedTripletBeforeBatch - relationVectorBeforeBatch) + + headEntityVector = headEntityVector + tempPositive + tailEntityVector = tailEntityVector - tempPositive + relationVector = relationVector + tempPositive - tempNegtative + headEntityVectorWithCorruptedTriplet = headEntityVectorWithCorruptedTriplet - tempNegtative + tailEntityVectorWithCorruptedTriplet = tailEntityVectorWithCorruptedTriplet + tempNegtative #只归一化这几个刚更新的向量,而不是按原论文那些一口气全更新了 - self.entityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector.tolist()) - self.entityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector.tolist()) - self.relationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector.tolist()) - self.entityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet.tolist()) - self.entityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet.tolist()) - i += 1 - + copyEntityList[tripletWithCorruptedTriplet[0][0]] = norm(headEntityVector) + copyEntityList[tripletWithCorruptedTriplet[0][1]] = norm(tailEntityVector) + copyRelationList[tripletWithCorruptedTriplet[0][2]] = norm(relationVector) + copyEntityList[tripletWithCorruptedTriplet[1][0]] = norm(headEntityVectorWithCorruptedTriplet) + copyEntityList[tripletWithCorruptedTriplet[1][1]] = norm(tailEntityVectorWithCorruptedTriplet) + + self.entityList = copyEntityList + self.relationList = copyRelationList + def writeEntilyVector(self, dir): print("写入实体") entityVectorFile = open(dir, 'w') for entity in self.entityList.keys(): entityVectorFile.write(entity+"\t") - entityVectorFile.write(str(self.entityList[entity])) + entityVectorFile.write(str(self.entityList[entity].tolist())) entityVectorFile.write("\n") entityVectorFile.close() @@ -133,24 +160,27 @@ class TransE: relationVectorFile = open(dir, 'w') for relation in self.relationList.keys(): relationVectorFile.write(relation + "\t") - relationVectorFile.write(str(self.relationList[relation])) + relationVectorFile.write(str(self.relationList[relation].tolist())) relationVectorFile.write("\n") relationVectorFile.close() def init(dim): - return uniform(-6/(dim**0.5),6/(dim**0.5)) + return uniform(-6/(dim**0.5), 6/(dim**0.5)) -def distance(h, t, r): +def distanceL1(h, t ,r): s = h + r - t - narray = array(s) - narray2 = narray*narray - sum = narray2.sum() + sum = fabs(s).sum() return sum +def distanceL2(h, t, r): + s = h + r - t + sum = (s*s).sum() + return sum + def norm(list): ''' 归一化 - :param 向量: + :param 向量 :return: 向量的平方和的开方后的向量 ''' var = linalg.norm(list) @@ -158,7 +188,7 @@ def norm(list): while i < len(list): list[i] = list[i]/var i += 1 - return list + return array(list) def openDetailsAndId(dir,sp="\t"): idNum = 0 @@ -192,9 +222,10 @@ if __name__ == '__main__': dirTrain = "C:\\data\\train.txt" tripleNum, tripleList = openTrain(dirTrain) print("打开TransE") - transE = TransE(entityList,relationList,tripleList,dim = 30) + transE = TransE(entityList,relationList,tripleList, margin=1, dim = 100) print("TranE初始化") transE.initialize() - transE.transE(300000) + transE.transE(15000) transE.writeRelationVector("c:\\relationVector.txt") transE.writeEntilyVector("c:\\entityVector.txt") +