get_train_data.py (2.1 KB) — committed by overlordmax
import scipy.sparse as sp
import numpy as np
from time import time
import args

def _read_positive_pairs(filename):
    """Parse a tab-separated .rating file.

    Each non-blank line must start with ``user<TAB>item<TAB>rating``; any
    further columns are ignored and there is no header line. Every line counts
    towards the largest user/item id, whatever its rating, but only pairs with
    ``rating > 0`` are kept as positives.

    Returns:
        ``(pairs, max_user, max_item)``: ``pairs`` lists each distinct
        positive ``(user, item)`` in order of first appearance; ``max_user``
        and ``max_item`` are the largest ids seen (0 for an empty file).
    """
    max_user, max_item = 0, 0
    positives = {}  # dict used as an insertion-ordered set of (user, item)
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            arr = line.split("\t")
            user, item, rating = int(arr[0]), int(arr[1]), float(arr[2])
            max_user = max(max_user, user)
            max_item = max(max_item, item)
            if rating > 0:
                positives[(user, item)] = None
    return list(positives), max_user, max_item


def get_train_data(filename, write_file, num_negatives):
    """Write pointwise training samples built from a .rating file to a CSV.

    ``filename`` is tab-separated; each non-blank line starts with
    ``user<TAB>item<TAB>rating`` (no header line, extra columns ignored).
    Every distinct ``(user, item)`` pair with ``rating > 0`` is a positive.
    For each positive, in order of first appearance, one line
    ``user,item,1`` is written, followed by ``num_negatives`` lines
    ``user,j,0``. Each ``j`` is drawn uniformly from item ids
    ``0..max_item_id`` using the global ``np.random`` state, and is redrawn
    until it is not a positive for that user.

    Args:
        filename: path of the input rating file.
        write_file: path of the CSV file to (over)write.
        num_negatives: number of negative samples per positive pair.

    Raises:
        ValueError: if ``num_negatives > 0`` and some user has a positive for
            every item id, so no negative item exists (rejection sampling
            would never terminate).
    """
    pairs, _, max_item = _read_positive_pairs(filename)
    # Item ids run from 0 to max_item inclusive, i.e. max_item + 1 candidates.
    # randint's upper bound is exclusive, so randint(max_item) would never
    # pick max_item as a negative.
    num_item_ids = max_item + 1

    rated_by_user = {}
    for user, item in pairs:
        rated_by_user.setdefault(user, set()).add(item)

    if num_negatives > 0:
        for user, rated in rated_by_user.items():
            if len(rated) >= num_item_ids:
                raise ValueError(
                    "user {0} has a positive for every item id 0..{1}; "
                    "cannot sample negatives".format(user, max_item))

    with open(write_file, "w", encoding="utf-8") as out:
        print("writing " + write_file)
        for user, item in pairs:
            out.write("{0},{1},{2}\n".format(user, item, 1))
            rated = rated_by_user[user]
            for _ in range(num_negatives):
                # Rejection sampling: redraw until the item is not a positive.
                j = np.random.randint(num_item_ids)
                while j in rated:
                    j = np.random.randint(num_item_ids)
                out.write("{0},{1},{2}\n".format(user, j, 0))
                
if __name__ == "__main__":
    # Bind the parsed options to their own name instead of rebinding ``args``,
    # which would shadow the imported ``args`` module with its own result.
    cli = args.parse_args()
    rating_path = cli.path + cli.dataset + ".train.rating"
    get_train_data(rating_path, cli.train_data_path, cli.num_neg)