# preprocess.py (author: yinhaofeng)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#encoding=utf-8
"""Turn ``./raw_data.txt`` into bag-of-words train/test files.

Every non-blank line of ``raw_data.txt`` must have at least three
tab-separated fields: query, document and label, where label ``"1"``
marks a positive document for the query and ``"0"`` a negative one.

Outputs (written to the current directory):

* ``train.txt`` -- one ``query<TAB>pos<TAB>neg`` triple per line; each field
  is a comma-separated 0/1 bag-of-words vector over the vocabulary.
* ``test.txt``  -- one ``query<TAB>doc`` pair per line, same encoding.
* ``label.txt`` -- the 1/0 label of the matching ``test.txt`` line.

NOTE: the query split and the output order are shuffled with an unseeded
``random``, so every run produces a different split.
"""

import os
import random
import sys

import jieba
import numpy as np

RAW_DATA_PATH = "./raw_data.txt"
# Number of (shuffled) queries assigned to the training set; the remaining
# queries form the test set.
TRAIN_QUERY_NUM = 11600


def _read_rows(path):
    """Return the tab-split fields of every non-blank line of *path*.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f, 1):
            fields = raw_line.strip().split("\t")
            if fields == [""]:
                # Blank line (e.g. a trailing newline): nothing to parse.
                continue
            if len(fields) < 3:
                raise ValueError(
                    "%s line %d: expected at least 3 tab-separated fields "
                    "(query, doc, label), got %d"
                    % (path, line_no, len(fields)))
            rows.append(fields)
    return rows


# Read and split the raw data once; every later step reuses these rows.
rows = _read_rows(RAW_DATA_PATH)

# Build the vocabulary: word -> id. Ids start at 1, so slot 0 of every
# bag-of-words vector is never set.
word_dict = {}
for fields in rows:
    text = fields[0].strip() + " " + fields[1].strip()
    for word in jieba.cut(text):
        if word not in word_dict:
            word_dict[word] = len(word_dict) + 1

# Group documents by query: label "0" -> negatives, label "1" -> positives.
# Rows with any other label are ignored here (their words are still in the
# vocabulary).
neg_dict = {}
pos_dict = {}
for fields in rows:
    query, doc, label = fields[0], fields[1], fields[2]
    if label == "0":
        neg_dict.setdefault(query, []).append(doc)
    elif label == "1":
        pos_dict.setdefault(query, []).append(doc)

print("build dict done")

# Split the queries that have at least one positive into train / test.
query_list = list(pos_dict.keys())
random.shuffle(query_list)
if len(query_list) <= TRAIN_QUERY_NUM:
    print("warning: only %d queries have positives; the test set is empty"
          % len(query_list))
train_query = query_list[:TRAIN_QUERY_NUM]
test_query = query_list[TRAIN_QUERY_NUM:]

# Training triples: every (positive, negative) combination of each training
# query. Queries without any negative contribute nothing.
train_set = []
for query in train_query:
    negs = neg_dict.get(query, [])
    for pos in pos_dict[query]:
        for neg in negs:
            train_set.append([query, pos, neg])
random.shuffle(train_set)
print("get train_set done")

# Test pairs: every positive (label 1) and negative (label 0) document of
# each test query.
test_set = []
for query in test_query:
    for pos in pos_dict[query]:
        test_set.append([query, pos, 1])
    for neg in neg_dict.get(query, []):
        test_set.append([query, neg, 0])
random.shuffle(test_set)
print("get test_set done")

# Width of every bag-of-words vector (word ids run from 1 to len(word_dict)).
vocab_size = len(word_dict) + 1


def _to_bow(text):
    """Encode *text* as a comma-separated 0/1 bag-of-words vector string.

    Slot ``word_dict[w]`` is set to 1 for every word ``w`` that jieba yields
    for ``text.strip()``. This relies on every such word already being in
    ``word_dict`` (the vocabulary is built from the same query/doc fields);
    a missing word raises ``KeyError``.
    """
    vector = [0] * vocab_size
    for word in jieba.cut(text.strip()):
        vector[word_dict[word]] = 1
    return ",".join(map(str, vector))


# Training file: query, positive and negative vectors on each line.
with open("train.txt", "w") as f:
    for query, pos, neg in train_set:
        vectors = [_to_bow(query), _to_bow(pos), _to_bow(neg)]
        f.write("\t".join(vectors) + "\n")

# Test file: query and document vectors; the labels are written line-aligned
# to label.txt.
with open("test.txt", "w") as f, open("label.txt", "w") as fa:
    for query, doc, label in test_set:
        f.write(_to_bow(query) + "\t" + _to_bow(doc) + "\n")
        fa.write(str(label) + "\n")