# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *

# Id assigned to any word that is not found in the dictionary.
UNK_IDX = 0


# initializer is called by the framework during initialization.
# It allows the user to describe the data types and set up the
# necessary data structures for later use.
# `settings` is an object. initializer needs to properly fill
# settings.input_types. It can also store other data structures that
# process() will need; in this example the dictionary is stored in settings.
# `dictionary` and `kwargs` are arguments passed from trainer_config.lr.py.
def initializer(settings, dictionary, **kwargs):
    """Store the word dictionary and declare the training input types.

    Args:
        settings: provider settings object; `word_dict` and `input_types`
            are set on it.
        dictionary: the word dictionary; its length is used as the
            dimension of the bag-of-words vector.
        **kwargs: extra arguments from the trainer config (unused here).
    """
    # Keep the word dictionary on settings so process() can look words up.
    settings.word_dict = dictionary

    # First input: a sparse_binary_vector, i.e. each dimension of the vector
    # is either 0 or 1. It is the bag-of-words (BOW) representation of the
    # text.
    bow_slot = sparse_binary_vector(len(dictionary))

    # Second input: an integer holding the category id of the sample.
    # 2 means there are two labels in the dataset
    # (1 for positive and 0 for negative).
    label_slot = integer_value(2)

    # settings.input_types specifies the data types this provider generates.
    settings.input_types = [bow_slot, label_slot]


# Declaring a data provider. Its initializer is `initializer`.
# It will cache the generated data of the first pass in memory, so that
# during later passes no on-the-fly data generation is needed.
# `settings` is the same object used by initializer().
# `file_name` is the name of a file listed in the train_list or test_list
# file given to define_py_data_sources2(). See trainer_config.lr.py.
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file_name):
    """Yield (word ids, label) samples read from a tab-separated file.

    Each non-blank line is expected to look like `<label>\\t<comment>`.

    Args:
        settings: provider settings object; `settings.word_dict` is used to
            map words to ids.
        file_name: path of the data file to read.

    Yields:
        A tuple `(word_vector, label)`: the list of word ids of the comment
        (unknown words map to UNK_IDX) and the label as an int.

    Raises:
        ValueError: if a non-blank line has no tab separator or its label
            is not an integer.
    """
    # Open the input data file.
    with open(file_name, 'r') as f:
        # Read each line.
        for line in f:
            line = line.strip()
            # A blank line carries no sample; skip it rather than failing
            # on the unpack below.
            if not line:
                continue
            # Each line contains the label and the text of the comment,
            # separated by \t. Split only on the first tab so that tabs
            # inside the comment text do not break the unpacking.
            label, comment = line.split('\t', 1)
            # Split the words into a list.
            words = comment.split()
            # Convert the words into a list of ids by looking them up in
            # word_dict.
            word_vector = [settings.word_dict.get(w, UNK_IDX) for w in words]
            # Return the features for the current comment. The first is a
            # list of ids representing a 0-1 binary sparse vector of the
            # text, the second is the integer id of the label.
            yield word_vector, int(label)


def predict_initializer(settings, dictionary, **kwargs):
    """Store the word dictionary and declare the prediction input type.

    Same as the training initializer except that only the bag-of-words
    input is declared; no label slot is present.
    """
    settings.word_dict = dictionary
    settings.input_types = [sparse_binary_vector(len(dictionary))]


# Declaring a data provider for prediction. The difference with process
# is that the label is not generated.
@provider(init_hook=predict_initializer, should_shuffle=False)
def process_predict(settings, file_name):
    """Yield the list of word ids for every line of `file_name`.

    Every line, including a blank one (which yields an empty list),
    produces exactly one sample. Unknown words map to UNK_IDX.
    """
    with open(file_name, 'r') as f:
        for line in f:
            words = line.strip().split()
            yield [settings.word_dict.get(w, UNK_IDX) for w in words]