# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *

# Id used for any word that is not found in the dictionary; it is the
# fallback value of the word_dict lookups in the data providers below.
UNK_IDX = 0

# initializer is called by the framework during initialization.
# It lets the user describe the data types and set up the data
# structures needed for later use.
# `settings` is an object; initializer must fill in settings.input_types
# and may store any other data needed by process() on it.
# In this example, the word dictionary is stored in settings.
# `dictionary` and `kwargs` are arguments passed from trainer_config.lr.py
def initializer(settings, dictionary, **kwargs):
    # Keep the word dictionary on settings so process() can look words up.
    settings.word_dict = dictionary

    # settings.input_types specifies the data types that the data provider
    # generates.
    settings.input_types = {
        # The first input is a sparse_binary_vector, which means each
        # dimension of the vector is either 0 or 1. It is the
        # bag-of-words (BOW) representation of the text.
        'word': sparse_binary_vector(len(dictionary)),
        # The second input is an integer. It represents the category id of
        # the sample. 2 means there are two labels in the dataset
        # (1 for positive and 0 for negative).
        'label': integer_value(2)
    }

# Declaring a data provider. It uses `initializer` as its init_hook.
# It will cache the generated data of the first pass in memory, so that
# during later passes no on-the-fly data generation is needed.
# `settings` is the same object used by initializer().
# `file_name` is the name of a file listed in the train_list or test_list
# file given to define_py_data_sources2(). See trainer_config.lr.py.
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file_name):
    # Open the input data file.
    with open(file_name, 'r') as f:
        # Read each line.
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. an empty line at the end of the file);
            # they carry no sample and cannot be unpacked below.
            if not line:
                continue

            # Each line contains the label and the text of the comment,
            # separated by \t. Split only on the first tab so that any
            # further tabs stay part of the comment text.
            label, comment = line.split('\t', 1)

            # Split the words into a list.
            words = comment.split()

            # Convert the words into a list of ids by looking them up in
            # word_dict; unknown words map to UNK_IDX.
            word_vector = [settings.word_dict.get(w, UNK_IDX) for w in words]

            # Return the features for the current comment. The first is a
            # list of ids representing a 0-1 binary sparse vector of the
            # text, the second is the integer id of the label.
            yield {'word': word_vector, 'label': int(label)}

# Initializer for the prediction data provider. It stores the word
# dictionary in settings like initializer(), but declares only the 'word'
# input, since no label is generated for prediction.
def predict_initializer(settings, dictionary, **kwargs):
    settings.word_dict = dictionary
    settings.input_types = {'word': sparse_binary_vector(len(dictionary))}

# Declaring a data provider for prediction. The difference from process
# is that no label is generated.
@provider(init_hook=predict_initializer, should_shuffle=False)
def process_predict(settings, file_name):
    with open(file_name, 'r') as f:
        for line in f:
            # Each line holds only the text of a comment; split it into words.
            comment = line.strip().split()
            # Look up each word's id, falling back to UNK_IDX for words
            # missing from the dictionary.
            word_vector = [settings.word_dict.get(w, UNK_IDX) for w in comment]
            yield {'word': word_vector}