dataprovider.py 2.2 KB
Newer Older
Z
zhangruiqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *
import collections
import logging
import pdb

# Root logging format: "[LEVEL time file:line] message".
logging.basicConfig(
    format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s')

# Module-wide logger registered under the name 'paddle'; INFO and above pass.
logger = logging.getLogger('paddle')
logger.setLevel(logging.INFO)

Z
zhangruiqing01 已提交
25
# N-gram order: initializer() declares N input slots and process() yields
# windows of exactly N word ids.
N = 5
# Frequency threshold: build_dict() keeps only words whose count is strictly
# greater than this value.
cutoff = 50
Z
zhangruiqing01 已提交
27

Z
zhangruiqing01 已提交
28

Z
zhangruiqing01 已提交
29 30 31 32 33 34 35 36
def build_dict(ftrain, fdict, min_freq=None):
    """Build the word-to-index dictionary from a training corpus.

    Every line of ``ftrain`` is split on whitespace and wrapped in the
    sentence markers ``<s>`` and ``<e>``. Tokens whose count is strictly
    greater than the threshold are kept, ordered by descending count (ties
    broken by the token itself), written to ``fdict`` one per line, and
    numbered in that order starting from 0.

    ``<unk>`` always receives an index: ``process`` looks it up to map
    out-of-vocabulary tokens, so when the corpus does not provide it above
    the threshold it is appended as the last entry.

    Args:
        ftrain: Path of the whitespace-tokenised training text.
        fdict: Writable text file object that receives the vocabulary.
        min_freq: Frequency threshold; when omitted, the module-level
            ``cutoff`` is used.

    Returns:
        dict mapping each kept word to its integer index.
    """
    threshold = cutoff if min_freq is None else min_freq

    # Count line by line so the whole corpus is never held as one token list.
    wordfreq = collections.Counter()
    with open(ftrain) as fin:
        for line in fin:
            wordfreq.update(['<s>'] + line.strip().split() + ['<e>'])

    kept = [item for item in wordfreq.items() if item[1] > threshold]
    kept.sort(key=lambda item: (-item[1], item[0]))

    # Building the list directly (instead of unpacking zip(*kept)) keeps an
    # empty vocabulary from raising ValueError.
    words = [word for word, _ in kept]
    if '<unk>' not in words:
        words.append('<unk>')

    # write() behaves the same on Python 2 and 3, unlike ``print >> fdict``.
    fdict.writelines(word + '\n' for word in words)

    word_idx = {word: idx for idx, word in enumerate(words)}
    logger.info("Dictionary size=%s", len(words))
    return word_idx

Z
zhangruiqing01 已提交
45

Z
zhangruiqing01 已提交
46 47 48 49 50 51 52 53
def initializer(settings, srcText, dictfile, **xargs):
    """Init hook: build the vocabulary and declare the provider's inputs.

    Opens ``dictfile`` for writing, runs ``build_dict`` over ``srcText`` and
    stores the resulting word-to-index mapping on ``settings.dicts``. Then
    sets ``settings.input_types`` to ``N`` separate ``integer_value`` slots,
    each sized to the vocabulary.

    Args:
        settings: Provider settings object; ``dicts`` and ``input_types``
            are assigned on it.
        srcText: Path of the training text used to build the dictionary.
        dictfile: Path the vocabulary is written to (overwritten).
        **xargs: Extra keyword arguments; accepted and ignored.
    """
    with open(dictfile, 'w') as vocab_out:
        settings.dicts = build_dict(srcText, vocab_out)

    # One freshly constructed slot per n-gram position.
    settings.input_types = [
        integer_value(len(settings.dicts)) for _ in range(N)
    ]

Z
zhangruiqing01 已提交
54

Z
zhangruiqing01 已提交
55 56 57 58 59
@provider(init_hook=initializer)
def process(settings, filename):
    """Yield every N-gram of word ids found in ``filename``.

    Each line is split on whitespace, left-padded with ``N - 1`` ``<s>``
    markers and terminated with ``<e>``. Tokens are mapped through
    ``settings.dicts``; tokens missing from it map to the id of ``<unk>``.
    A window of ``N`` consecutive ids is then slid over the line, so a line
    with ``k`` tokens produces ``k + 1`` samples.

    Args:
        settings: Provider settings object; ``settings.dicts`` is the
            word-to-index mapping and must contain ``<unk>``.
        filename: Path of the whitespace-tokenised text file to read.

    Yields:
        list of ``N`` integer word ids.

    Raises:
        KeyError: If ``<unk>`` is absent from ``settings.dicts``.
    """
    vocab = settings.dicts
    unk_id = vocab['<unk>']
    # N - 1 start markers let the first real word end a full window.
    left_pad = ['<s>'] * (N - 1)
    with open(filename) as fin:
        for raw in fin:
            tokens = left_pad + raw.strip().split() + ['<e>']
            ids = [vocab.get(tok, unk_id) for tok in tokens]
            for start in range(len(ids) - N + 1):
                yield ids[start:start + N]