# provider.py (2.1 KB)
# Committed by dangqingqing.
import io,os
import random
import numpy as np
import six.moves.cPickle as pickle
from paddle.trainer.PyDataProvider2 import *

def remove_unk(x, n_words):
    """Replace out-of-vocabulary word ids with 1.

    Every id ``>= n_words`` in each sentence of ``x`` becomes 1; all other
    ids are kept unchanged. A new list of lists is returned and ``x`` itself
    is not modified.
    """
    def _clip(word_id):
        return 1 if word_id >= n_words else word_id

    return [list(map(_clip, sentence)) for sentence in x]

# ==============================================================
#  TensorFlow uses fixed-length sequences, but PaddlePaddle can
#  process variable-length ones. Padding is used in this benchmark
#  so that results are comparable with other platforms.
# ==============================================================
def pad_sequences(sequences, maxlen=None, dtype='int32', padding='post',
                  truncating='post', value=0.):
    """Pad and/or truncate a list of sequences into a 2-D numpy array.

    Args:
        sequences: list of sequences (e.g. lists of word ids).
        maxlen: width of every output row. Defaults to the length of the
            longest sequence (0 when ``sequences`` is empty).
        dtype: dtype of the returned array.
        padding: 'post' places ``value`` after each sequence, 'pre' places
            it before.
        truncating: for sequences longer than ``maxlen``, 'post' keeps the
            first ``maxlen`` items and 'pre' keeps the last ``maxlen`` items.
        value: fill value for padded positions (cast to ``dtype``).

    Returns:
        numpy array of shape ``(len(sequences), maxlen)``.

    Raises:
        ValueError: if ``padding`` or ``truncating`` is not 'pre' or 'post'.
    """
    # Validate up front so bad arguments are reported even when every
    # sequence is empty (the loop below would otherwise never check them).
    if truncating not in ('pre', 'post'):
        raise ValueError("Truncating type '%s' not understood" % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError("Padding type '%s' not understood" % padding)

    lengths = [len(s) for s in sequences]
    nb_samples = len(sequences)
    if maxlen is None:
        # max() of an empty list raises, so an empty batch gets width 0.
        maxlen = max(lengths) if lengths else 0

    x = (np.ones((nb_samples, maxlen)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if len(s) == 0:
            continue  # empty sequence: row stays filled with `value`
        if truncating == 'pre':
            # Use an explicit start index: s[-0:] would return the whole
            # sequence when maxlen == 0.
            trunc = s[max(len(s) - maxlen, 0):]
        else:
            trunc = s[:maxlen]
        if len(trunc) == 0:
            continue  # maxlen == 0: nothing to copy into the row

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        else:
            x[idx, -len(trunc):] = trunc
    return x


def initHook(settings, vocab_size, pad_seq, maxlen, **kwargs):
    """Initialise the provider settings before any data file is read.

    Stores ``vocab_size``, ``pad_seq`` and ``maxlen`` on ``settings`` so the
    ``process`` generator can read them back, then declares two input slots:
    an integer-sequence slot with value range ``vocab_size`` and an integer
    slot with value range 2 (the label). Extra keyword arguments are accepted
    and ignored.
    """
    (settings.vocab_size,
     settings.pad_seq,
     settings.maxlen) = (vocab_size, pad_seq, maxlen)

    word_slot = integer_value_sequence(vocab_size)
    label_slot = integer_value(2)
    settings.input_types = [word_slot, label_slot]

@provider(init_hook=initHook, min_pool_size=-1, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file):
    """Yield ``(word_ids, label)`` samples from one pickled data file.

    Args:
        settings: provider settings populated by ``initHook``; this reads
            ``vocab_size``, ``pad_seq`` and ``maxlen``.
        file: path to a pickle file holding an ``(x, y)`` pair, where ``x``
            is a list of word-id sequences and ``y`` the matching labels.

    Yields:
        ``(list of int, int)`` tuples, one per label in ``y``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # provider at trusted, locally generated data files.
    # `with` closes the file even if unpickling fails.
    with open(file, 'rb') as f:
        x, y = pickle.load(f)

    # Map out-of-vocabulary word ids (>= vocab_size) to id 1.
    x = remove_unk(x, settings.vocab_size)
    if settings.pad_seq:
        x = pad_sequences(x, maxlen=settings.maxlen, value=0.)

    for i, label in enumerate(y):
        # Materialise a list: on Python 3 map() is a one-shot iterator, which
        # would be exhausted after the first pass when CACHE_PASS_IN_MEM
        # replays the cached samples.
        yield [int(w) for w in x[i]], int(label)