# -*- coding: UTF-8 -*-

# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
1. Tokenize the words and punctuation.
2. Pos sample: rating score 5; neg sample: rating score 1-2.

Usage:
    python preprocess.py -i data_file [-s random_seed]

Outputs (written next to data_file): tmp/pos_<batch>.txt and
tmp/neg_<batch>.txt, dict.txt and labels.list.
"""

import sys
import os
import operator
import gzip
from subprocess import Popen, PIPE
from optparse import OptionParser
import json
from multiprocessing import Queue
from multiprocessing import Pool
import multiprocessing

# Number of review texts collected into one batch before it is queued.
batch_size = 5000
# word -> count; updated by create_dict() and written out by save_batch().
word_count = {}

# Tokenizer process count: two cores are left for the parse and save stages,
# but at least one tokenizer always runs.
num_tokenize = max(1, multiprocessing.cpu_count() - 2)
max_queue_size = 8
# Bounded queues between the pipeline stages:
#   parse_queue:    parse    -> tokenize
#   tokenize_queue: tokenize -> save
# The extra num_tokenize slots presumably leave room for the per-tokenizer
# end markers — TODO confirm.
_queue_slots = max_queue_size + num_tokenize
parse_queue = Queue(maxsize=_queue_slots)
tokenize_queue = Queue(maxsize=_queue_slots)


def create_dict(data, counts=None):
    """
    Update word frequency counts with the words of a batch of texts.

    Each text is lower-cased and split on whitespace, and every resulting
    token increments its entry in ``counts``. This function does not write
    any file; the dictionary file (first line "unk\\t-1") is written by
    save_batch.

    data: iterable of texts (one batch).
    counts: dict mapping word -> count, updated in place. Defaults to the
        module-level ``word_count`` dict.
    return: the updated counts dict.
    """
    if counts is None:
        counts = word_count
    for seq in data:
        try:
            words = seq.lower().split()
        except AttributeError:
            # Not a text (e.g. None): report it and keep going. "%s"
            # formatting so a non-string item cannot crash the handler.
            sys.stderr.write("%s\tERROR\n" % (seq, ))
            continue
        for w in words:
            counts[w] = counts.get(w, 0) + 1
    return counts


def parse(path):
    """
    Lazily read a gzip-compressed file with one JSON record per line.

    path: path to the .gz file.
    yield: the object decoded by json.loads from each line.

    The file is closed when iteration finishes, when the generator is
    closed early, or when decoding a line raises.
    """
    sys.stderr.write("%s\n" % (path, ))
    g = gzip.open(path, 'r')
    try:
        for line in g:
            yield json.loads(line)
    finally:
        g.close()


def tokenize(sentences):
    """
    Use tokenizer.perl (a tool of Moses) to tokenize input sentences.

    sentences: a list of input texts.
    return: a list of tokenized lines, one per line the tokenizer writes
        (expected to be one per input text).

    The texts are sent to the tokenizer joined by newlines, and its output
    is split on newlines. A newline inside a text would therefore split one
    review into several output lines (and several saved samples), so
    embedded newlines are replaced by spaces first.

    Calls sys.exit if ./mosesdecoder-master/scripts/tokenizer/tokenizer.perl
    is missing. A non-zero tokenizer exit status is reported on stderr.
    """
    script = './mosesdecoder-master/scripts/tokenizer/tokenizer.perl'
    if not os.path.exists(script):
        sys.exit(
            "The ./mosesdecoder-master/scripts/tokenizer/tokenizer.perl does not exists."
        )
    tokenizer_cmd = [script, '-l', 'en', '-q', '-']
    assert isinstance(sentences, list)
    text = "\n".join(s.replace("\n", " ") for s in sentences)
    tokenizer = Popen(tokenizer_cmd, stdin=PIPE, stdout=PIPE)
    tok_text, _ = tokenizer.communicate(text)
    if tokenizer.returncode != 0:
        sys.stderr.write(
            "tokenizer.perl exited with status %s\n" % tokenizer.returncode)
    # The output presumably ends with a newline, so the last split element
    # is empty and dropped.
    return tok_text.split('\n')[:-1]


def save_data(instance, data_dir, pre_fix, batch_num):
    """
    Save one batch of texts to data_dir/<pre_fix>_<batch_num>.txt.

    Every line is "<label>\\t<text>". The label is '1' when pre_fix is 'pos'
    and '0' otherwise. data_dir is created if it does not exist, and an empty
    batch produces an empty file.

    instance: list of texts in the batch.
    data_dir: output directory.
    pre_fix: 'pos' or 'neg'; also used as the file-name prefix.
    batch_num: batch index used in the file name.
    """
    label = '1' if pre_fix == 'pos' else '0'
    if data_dir and not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    file_name = os.path.join(data_dir, "%s_%s.txt" % (pre_fix, batch_num))
    lines = ['%s\t%s' % (label, text) for text in instance]
    with open(file_name, 'w') as f:
        if lines:
            f.write('\n'.join(lines) + '\n')


def tokenize_batch(id):
    """
    Tokenizer worker loop.

    Takes (num_batch, texts, pre_fix) items from parse_queue, tokenizes the
    texts, and forwards (num_batch, tokenized_texts, pre_fix) to
    tokenize_queue, writing one '.' to stderr per batch.

    On an item whose num_batch is -1 it forwards its own (-1, None, None)
    marker to tokenize_queue, reports that this worker finished, and
    returns.

    id: worker label, used only in the finish message.

    NOTE(review): if tokenize() raises, the loop ends without forwarding a
    -1 marker, and save_batch keeps waiting for it.
    """
    while True:
        batch_no, texts, prefix = parse_queue.get()
        if batch_no == -1:  # parse_queue finished
            tokenize_queue.put((-1, None, None))
            sys.stderr.write("Thread %s finish\n" % (id))
            return
        tokenized = tokenize(texts)
        tokenize_queue.put((batch_no, tokenized, prefix))
        sys.stderr.write('.')


def save_batch(data_dir, num_tokenize, data_dir_dict):
    """
    Consumer stage: save tokenized batches and build the dictionary file.

    Reads (num_batch, texts, pre_fix) items from tokenize_queue, saves each
    batch with save_data, and adds its words to word_count via create_dict.
    Each tokenizer sends one item with num_batch == -1. Once num_tokenize of
    those have arrived, the dictionary is written to data_dir_dict:
    "unk\\t-1" first, then "word\\tcount" lines ordered by descending count.
    Ties are broken alphabetically so the file is deterministic.

    data_dir: directory for the per-batch files.
    num_tokenize: number of tokenizer processes feeding tokenize_queue.
    data_dir_dict: path of the dictionary file to write.
    """
    finished = 0
    while finished < num_tokenize:
        num_batch, instance, pre_fix = tokenize_queue.get()
        if num_batch == -1:  # one tokenizer has finished
            finished += 1
            continue
        save_data(instance, data_dir, pre_fix, num_batch)
        create_dict(instance)  # update word_count

    sys.stderr.write("save file finish\n")
    ordered = sorted(word_count.items(), key=lambda kv: (-kv[1], kv[0]))
    with open(data_dir_dict, 'w') as f:
        f.write('%s\t%s\n' % ('unk', '-1'))
        for k, v in ordered:
            f.write('%s\t%s\n' % (k, v))
    sys.stderr.write("build dict finish\n")


def parse_batch(data, num_tokenize):
    """
    Producer stage: read reviews and feed batches to the tokenizers.

    parse -> tokenize -> save

    Reviews rated 5.0 go to the 'pos' batch and reviews rated below 3.0 go
    to the 'neg' batch. Other ratings, empty texts, and records without an
    "overall" or "reviewText" value are skipped. Texts are lower-cased.

    Each full batch (batch_size texts) is put on parse_queue as
    (batch_num, texts, pre_fix). The batch number is a single counter shared
    by both classes. Leftovers are flushed at the end, pos first, then neg.

    One (-1, None, None) end marker per tokenizer is always put on
    parse_queue, even if reading the input fails. This lets the tokenizer
    and save stages finish instead of waiting forever.

    data: path to the gzip-compressed JSON-lines review file.
    num_tokenize: number of tokenizer processes reading parse_queue.
    """
    batches = {'pos': [], 'neg': []}
    count = 0
    try:
        sys.stderr.write("extract raw data\n")
        for record in parse(data):
            rating = record.get("overall")
            text = record.get("reviewText")
            if rating is None or not text:
                continue
            text = text.lower()  # convert words to lower case
            if rating == 5.0:
                pre_fix = 'pos'
            elif rating < 3.0:
                pre_fix = 'neg'
            else:
                continue
            batch = batches[pre_fix]
            batch.append(text)
            if len(batch) == batch_size:
                parse_queue.put((count, batch, pre_fix))
                count += 1
                batches[pre_fix] = []

        for pre_fix in ('pos', 'neg'):
            if batches[pre_fix]:
                parse_queue.put((count, batches[pre_fix], pre_fix))
                count += 1
    except Exception as exc:
        sys.stderr.write("parse_batch failed: %r\n" % (exc, ))
        raise
    finally:
        for _ in range(num_tokenize):
            parse_queue.put((-1, None, None))  # for tokenize's input finished
    sys.stderr.write("parsing finish\n")


def option_parser(args=None):
    """
    Parse the command-line options.

    args: list of argument strings to parse. None (the default) makes
        OptionParser read sys.argv[1:].
    return: the (options, positional_args) pair from OptionParser.parse_args.
        options.input is the input data path (None if -i was not given).
        options.seed is the random seed as an int (default 1024).
    """
    parser = OptionParser(usage="usage: python preprocess.py "
                                "-i data_path [options]")
    parser.add_option(
        "-i", "--data", action="store", dest="input", help="Input data path.")
    parser.add_option(
        "-s",
        "--seed",
        action="store",
        type="int",
        dest="seed",
        default=1024,
        help="Set random seed.")
    return parser.parse_args(args)


def main():
    """
    Run the parse -> tokenize -> save pipeline on the file given with -i.

    Everything is written next to the input file: per-batch files under
    tmp/, plus dict.txt and labels.list. Exits with a message if -i is
    missing. Re-raises the first exception raised inside a pipeline worker.
    """
    # Python 2 only: make implicit str/unicode conversions use UTF-8
    # (presumably so unicode review text can be written to plain files).
    reload(sys)
    sys.setdefaultencoding('utf-8')
    options, args = option_parser()
    # NOTE: options.seed is parsed but not used anywhere in this script.
    data = options.input
    if not data:
        sys.exit("Please give the input data path with -i.")
    out_dir = os.path.dirname(data)
    data_dir_dict = os.path.join(out_dir, 'dict.txt')
    data_dir = os.path.join(out_dir, 'tmp')
    # The saver writes into tmp/; create it up front so it cannot fail there.
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)

    pool = Pool(processes=num_tokenize + 2)
    results = [pool.apply_async(parse_batch, args=(data, num_tokenize))]
    for i in range(num_tokenize):
        results.append(pool.apply_async(tokenize_batch, args=(str(i), )))
    results.append(
        pool.apply_async(
            save_batch, args=(data_dir, num_tokenize, data_dir_dict)))
    pool.close()
    pool.join()
    # apply_async keeps worker exceptions in the result objects; surface the
    # first one instead of silently reporting success.
    for result in results:
        result.get()

    with open(os.path.join(out_dir, 'labels.list'), 'w') as f:
        f.write('neg\t0\npos\t1\n')


# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()