preprocess.py 3.6 KB
Newer Older
Y
yinhaofeng 已提交
1
#encoding=utf-8
Y
yinhaofeng 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yinhaofeng 已提交
15 16 17

import os
import sys
Y
yinhaofeng 已提交
18
import jieba
Y
yinhaofeng 已提交
19 20 21
import numpy as np
import random

Y
yinhaofeng 已提交
22
# Load the raw corpus once and split every non-blank line on tabs. The code
# below reads field 0 and field 1 of each row as the two texts to tokenise.
# Blank lines are skipped so they cannot raise IndexError on row[1].
# The file is decoded explicitly as UTF-8 rather than with the platform
# default encoding, and the with-statement closes it even on error.
# NOTE(review): assumes raw_data.txt is UTF-8 encoded -- confirm.
with open("./raw_data.txt", "r", encoding="utf-8") as f:
    lines = [line.strip().split("\t") for line in f if line.strip()]

# Build the vocabulary: each distinct token receives the next integer id,
# starting from 1.
word_dict = {}

# First pass: tokenise the two texts joined together, so ids are assigned in
# the order the joined text yields them.
for line in lines:
    for word in jieba.cut(line[0] + line[1]):
        if word not in word_dict:
            word_dict[word] = len(word_dict) + 1

# Second pass: tokenise each text on its own, stripped, and append any token
# that the joined text did not produce. Segmenting the joined text can merge
# characters across the boundary between the two fields, so a token obtained
# from a field cut alone is not guaranteed to be in the vocabulary yet;
# without this pass a later word_dict[word] lookup on individually cut text
# can raise KeyError. Ids assigned in the first pass are never changed.
for line in lines:
    for field in (line[0], line[1]):
        for word in jieba.cut(field.strip()):
            if word not in word_dict:
                word_dict[word] = len(word_dict) + 1

# Build a dict keyed by query whose value is the list of negative examples:
# rows whose third field is "0" contribute their second field.
neg_dict = {}
for row in lines:
    if row[2] != "0":
        continue
    neg_dict.setdefault(row[0], []).append(row[1])

# Build a dict keyed by query whose value is the list of positive examples:
# rows whose third field is "1" contribute their second field.
pos_dict = {}
for row in lines:
    if row[2] != "1":
        continue
    pos_dict.setdefault(row[0], []).append(row[1])

Y
change  
yinhaofeng 已提交
62 63
# Split the queries into a training part and a test part.
# Only queries that have at least one positive example are considered.
query_list = list(pos_dict)
print(len(query_list))
random.shuffle(query_list)
# The first 11600 shuffled queries go to training, the remainder to test.
train_size = 11600
train_query, test_query = query_list[:train_size], query_list[train_size:]
Y
change  
yinhaofeng 已提交
68 69 70 71

# Build the training set: one [query, positive, negative] triple for every
# positive/negative combination of a training query. Queries without any
# negative example contribute nothing.
train_set = []
for query in train_query:
    negatives = neg_dict.get(query)
    if negatives is None:
        continue
    for pos in pos_dict[query]:
        train_set.extend([query, pos, neg] for neg in negatives)
random.shuffle(train_set)
Y
yinhaofeng 已提交
78

Y
change  
yinhaofeng 已提交
79 80 81 82 83 84 85 86 87 88
# Build the test set: every positive example of a test query is stored as
# [query, text, 1] and every negative example, if the query has any, as
# [query, text, 0].
test_set = []
for query in test_query:
    test_set.extend([query, pos, 1] for pos in pos_dict[query])
    test_set.extend([query, neg, 0] for neg in neg_dict.get(query, ()))
random.shuffle(test_set)
Y
yinhaofeng 已提交
89 90 91

# Convert the query, positive and negative text of every training triple to
# token ids and write one line per triple to train.txt. Query ids are
# prefixed with "0:", positive ids with "1:" and negative ids with "2:";
# all entries on a line are separated by single spaces.
# The with-statement guarantees the file is flushed and closed even if a
# token lookup in word_dict raises part-way through.
with open("train.txt", "w") as f:
    for line in train_set:
        query_ids = [word_dict[word] for word in jieba.cut(line[0].strip())]
        pos_ids = [word_dict[word] for word in jieba.cut(line[1].strip())]
        neg_ids = [word_dict[word] for word in jieba.cut(line[2].strip())]
        f.write(' '.join(["0:" + str(x) for x in query_ids]) + " " +
                ' '.join(["1:" + str(x) for x in pos_ids]) + " " +
                ' '.join(["2:" + str(x) for x in neg_ids]) + "\n")

# Convert the query and candidate text of every test example to token ids
# and write three parallel files, one line per example:
#   test.txt      - query ids prefixed with "0:" followed by candidate ids
#                   prefixed with "1:", space-separated
#   label.txt     - the example's label (third element of the test entry)
#   testquery.txt - the query ids joined by commas
# A single with-statement owns all three files so they are flushed and
# closed together even if a token lookup in word_dict raises.
with open("test.txt", "w") as f, open("label.txt", "w") as fa, \
        open("testquery.txt", "w") as fb:
    for line in test_set:
        query_ids = [word_dict[word] for word in jieba.cut(line[0].strip())]
        pos_ids = [word_dict[word] for word in jieba.cut(line[1].strip())]
        label = line[2]
        f.write(' '.join(["0:" + str(x) for x in query_ids]) + " " +
                ' '.join(["1:" + str(x) for x in pos_ids]) + "\n")
        fa.write(str(label) + "\n")
        fb.write(','.join([str(x) for x in query_ids]) + "\n")