diff --git a/models/match/dssm/data/preprocess.py b/models/match/dssm/data/preprocess.py index 5985cb901579286a3c4279cd4fb70be540ef4900..ad7ad4b634c2b9ac0c3a8130aba579761f41f675 100644 --- a/models/match/dssm/data/preprocess.py +++ b/models/match/dssm/data/preprocess.py @@ -16,14 +16,9 @@ for line in lines: text = line[0].split(" ") + line[1].split(" ") for word in text: if word in word_dict: - word_dict[word] = word_dict[word] + 1 + continue else: - word_dict[word] = 1 - -word_list = word_dict.items() -word_list = sorted(word_dict.items(), key=lambda item: item[1], reverse=True) -word_list_ids = range(1, len(word_list) + 1) -word_dict = dict(zip([x[0] for x in word_list], word_list_ids)) + word_dict[word] = len(word_dict) + 1 f = open("./zhidao", "r") lines = f.readlines() @@ -74,12 +69,11 @@ for query in test_query: if query not in neg_dict: continue for neg in neg_dict[query]: - test_set.append([query, pos, 0]) + test_set.append([query, neg, 0]) random.shuffle(test_set) #训练集中的query,pos,neg转化为词袋 f = open("train.txt", "w") -f = open("train.txt", "w") for line in train_set: query = line[0].strip().split(" ") pos = line[1].strip().split(" ") diff --git a/models/match/readme.md b/models/match/readme.md index 9e6a5748c0e8e3111ae9f9531625a53dd58c1da6..440ad9796605aefe94c2f822daebb73062bfed24 100755 --- a/models/match/readme.md +++ b/models/match/readme.md @@ -51,8 +51,8 @@ python -m paddlerec.run -m models/contentunderstanding/match-pyramid/config.yaml ### 模型效果 (测试) -| 数据集 | 模型 | auc | map | +| 数据集 | 模型 | 正逆序比 | map | | :------------------: | :--------------------: | :---------: |:---------: | -| zhidao | DSSM | 0.55 | -- | +| zhidao | DSSM | 2.25 | -- | | Letor07 | match-pyramid | -- | 0.42 | -| zhidao | multiview-simnet | 0.59 | -- | +| zhidao | multiview-simnet | 1.72 | -- |