provider.py 2.2 KB
Newer Older
1
import io, os
D
dangqingqing 已提交
2 3 4 5 6
import random
import numpy as np
import six.moves.cPickle as pickle
from paddle.trainer.PyDataProvider2 import *

7

D
dangqingqing 已提交
8 9 10
def remove_unk(x, n_words):
    """Replace every out-of-dictionary word id with 1.

    Args:
        x: iterable of sentences, each an iterable of word ids.
        n_words: dictionary size; any id >= n_words is out of dictionary.

    Returns:
        A new list of lists shaped like ``x`` in which each id that is
        >= n_words has been replaced by 1 (presumably the "unknown word"
        id -- confirm against the dictionary) and all other ids are kept.
    """
    cleaned = []
    for sentence in x:
        ids = []
        for word in sentence:
            ids.append(1 if word >= n_words else word)
        cleaned.append(ids)
    return cleaned

11

D
dangqingqing 已提交
12 13 14 15 16
# ==============================================================
#  TensorFlow uses fixed-length input, whereas PaddlePaddle can
#  process variable-length sequences. Padding is used in this
#  benchmark so that results can be compared with other platforms.
# ==============================================================
17 18 19 20 21 22
def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='post',
                  truncating='post',
                  value=0.):
    """Pad or truncate each sequence to a common length.

    Args:
        sequences: list of sequences (each supporting ``len`` and slicing).
        maxlen: target length; when None, the longest sequence length is
            used (0 if ``sequences`` is empty).
        dtype: dtype of the returned array.
        padding: 'post' writes the values at the start of each row and
            leaves the fill value at the end; 'pre' does the opposite.
        truncating: 'post' keeps the first ``maxlen`` items of an overlong
            sequence; 'pre' keeps the last ``maxlen`` items.
        value: fill value used for padding positions.

    Returns:
        A numpy array of shape ``(len(sequences), maxlen)`` and type
        ``dtype``. Empty sequences produce a row made only of ``value``.

    Raises:
        ValueError: if ``padding`` or ``truncating`` is neither 'pre' nor
            'post'.
    """
    # Validate once instead of on every iteration, and report the
    # offending argument (the truncating message used to print `padding`).
    if truncating not in ('pre', 'post'):
        raise ValueError("Truncating type '%s' not understood" % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError("Padding type '%s' not understood" % padding)

    nb_samples = len(sequences)
    if maxlen is None:
        lengths = [len(s) for s in sequences]
        # An empty batch has no longest sequence; fall back to width 0.
        maxlen = max(lengths) if lengths else 0

    x = (np.ones((nb_samples, maxlen)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty sequence: the row stays filled with `value`

        trunc = s[-maxlen:] if truncating == 'pre' else s[:maxlen]

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        else:
            x[idx, -len(trunc):] = trunc
    return x


def initHook(settings, vocab_size, pad_seq, maxlen, **kwargs):
    """Store the provider configuration on ``settings``.

    Args:
        settings: provider settings object; attributes are set on it.
        vocab_size: dictionary size, also the range of the word-id input.
        pad_seq: stored as ``settings.pad_seq``; ``process`` pads the
            sequences when it is truthy.
        maxlen: stored as ``settings.maxlen``; passed to ``pad_sequences``
            as the target length.
        **kwargs: accepted and ignored.
    """
    settings.vocab_size = vocab_size
    settings.pad_seq = pad_seq
    settings.maxlen = maxlen

    # One sample = (sequence of word ids, label with 2 possible values).
    word_ids = integer_value_sequence(vocab_size)
    label = integer_value(2)
    settings.input_types = [word_ids, label]

D
dangqingqing 已提交
57

58 59
@provider(
    init_hook=initHook, min_pool_size=-1, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file):
    """Yield ``(word_ids, label)`` samples read from a pickled data file.

    Args:
        settings: object populated by ``initHook``; ``vocab_size``,
            ``pad_seq`` and ``maxlen`` are read from it.
        file: path of a pickle file containing an ``(x, y)`` pair, where
            ``x`` holds the word-id sequences and ``y`` the labels.

    Yields:
        A tuple ``(list of int word ids, int label)`` for each label in
        ``y``, matching the two input types declared in ``initHook``.
    """
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    # The context manager closes the file even if unpickling fails.
    with open(file, 'rb') as f:
        train_set = pickle.load(f)
    x, y = train_set

    # remove unk, namely remove the words out of dictionary
    x = remove_unk(x, settings.vocab_size)
    if settings.pad_seq:
        x = pad_sequences(x, maxlen=settings.maxlen, value=0.)

    for i, label in enumerate(y):
        # Build a real list: on Python 3 `map` returns a lazy iterator,
        # not the sequence of ids the provider is declared to yield.
        yield [int(w) for w in x[i]], int(label)