retrieve.py
# -*- coding: utf-8 -*-
# ==============================================================================
# Copyright 2019 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
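
"""Retrieve WN18 synsets for the tokens in the train/dev tokenization files.

For every token in the pickled tokenization results, this script looks up the
token's WordNet synsets, keeps those that also occur in WN18 (as listed in
wordnet-mlj12-definitions.txt), and dumps the resulting token-to-synset mapping
as a pickle in --output_dir.
"""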

import pickle
import argparse
import os
import nltk
import logging
import string
from tqdm import tqdm
from nltk.corpus import wordnet as wn

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_token', type=str, default='../tokenization_record/tokens/train.tokenization.uncased.data', help='token file of train set')
    parser.add_argument('--eval_token', type=str, default='../tokenization_record/tokens/dev.tokenization.uncased.data', help='token file of dev set')
    parser.add_argument('--output_dir', type=str, default='output_record/', help='output directory')
    parser.add_argument('--no_stopwords', action='store_true', help='ignore stopwords')
    parser.add_argument('--ignore_length', type=int, default=0, help='ignore words with length <= ignore_length')
    args = parser.parse_args()
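
    # example invocation, assuming the default relative paths above:
    #   python retrieve.py --no_stopwords --ignore_length 1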

    # initialize mapping from offset id to wn18 synset name
    offset_to_wn18name_dict = {} 
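    # each line of wordnet-mlj12-definitions.txt is expected to be tab-separated:
    #   <8-digit WordNet offset>\t<WN18 synset name>\t<gloss>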
    with open('wordnet-mlj12-definitions.txt') as fin:
        for line in fin:
            info = line.strip().split('\t')
            offset_str, synset_name = info[0], info[1]
            offset_to_wn18name_dict[offset_str] = synset_name
    logger.info('Finished loading wn18 definition file.')

    # load pickled samples
    logger.info('Begin to load tokenization results...')
    with open(args.train_token, 'rb') as f:
        train_samples = pickle.load(f)
    with open(args.eval_token, 'rb') as f:
        dev_samples = pickle.load(f)
    logger.info('Finished loading tokenization results.')
    
    # build token set
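    # each sample is a dict whose 'query_tokens' and 'document_tokens' lists hold
    # the tokenized query and document; collect the union of all tokens seen in either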
    all_token_set = set()
    for sample in train_samples + dev_samples:
        for token in sample['query_tokens'] + sample['document_tokens']:
            all_token_set.add(token)
    logger.info('Finished building the token set from the tokenization results.')

    # load stopwords
    stopwords = set(nltk.corpus.stopwords.words('english'))
    logger.info('Finished loading stopwords list.')

    # retrieve synsets
    logger.info('Begin to retrieve synsets...')
    token2synset = dict()
    stopword_cnt = 0
    punctuation_cnt = 0
    punctuation_set = set(string.punctuation)
    for token in tqdm(all_token_set):
        if token in punctuation_set:
            logger.info('{} is punctuation, skipped!'.format(token))
            punctuation_cnt += 1
            continue        
        if args.no_stopwords and token in stopwords:
            logger.info('{} is a stopword, skipped!'.format(token))
            stopword_cnt += 1
            continue
        if args.ignore_length > 0 and len(token) <= args.ignore_length:
            logger.info('{} is too short, skipped!'.format(token))
            continue
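        # look up the token's WordNet synsets and keep only those whose zero-padded
        # 8-digit offset also appears in the WN18 definition file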
        synsets = wn.synsets(token)
        wn18synset_names = []
        for synset in synsets:
            offset_str = str(synset.offset()).zfill(8)
            if offset_str in offset_to_wn18name_dict:
                wn18synset_names.append(offset_to_wn18name_dict[offset_str])
        if len(wn18synset_names) > 0:
            token2synset[token] = wn18synset_names
    logger.info('Finished retrieving synsets.')
    logger.info('{} / {} tokens retrieved at least 1 synset. {} stopwords and {} punctuation tokens skipped.'.format(len(token2synset), len(all_token_set), stopword_cnt, punctuation_cnt))
    
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    with open(os.path.join(args.output_dir, 'retrived_synsets.data'), 'wb') as fout:
        pickle.dump(token2synset, fout)    
    logger.info('Finished dumping retrieved synsets.')

if __name__ == '__main__':
    main()