# preprocess.py
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:20:12 2015

@author: Balázs Hidasi
"""

import numpy as np
import pandas as pd
import datetime as dt
import time

# Input and output directories. File names are appended to these values with
# plain string concatenation, so each one must end with a path separator.
PATH_TO_ORIGINAL_DATA = './'
PATH_TO_PROCESSED_DATA = './'

# Load the click log, keeping only its first three columns: column 0 as
# int32, column 1 as a raw string and column 2 as int64.
# header=0 makes pandas consume the first line of the file as a header row.
# NOTE(review): confirm the input file really starts with a header line;
# if it does not, the first click record is silently dropped here.
data = pd.read_csv(
    PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat',
    usecols=[0, 1, 2],
    dtype={0: np.int32, 1: str, 2: np.int64},
    header=0,
    sep=',')
# Give the three loaded columns their working names.
data.columns = ['session_id', 'timestamp', 'item_id']

# Replace the string 'timestamp' column with a numeric 'Time' column holding
# epoch seconds. time.mktime interprets the parsed value as local time, so
# the result is not UTC; the original author notes that this does not really
# matter. timetuple() has no sub-second field, so the fractional part of each
# timestamp is discarded.
data['Time'] = data.pop('timestamp').apply(
    lambda stamp: time.mktime(
        dt.datetime.strptime(stamp, '%Y-%m-%dT%H:%M:%S.%fZ').timetuple()))

# Support-based filtering of the click log, done in three passes.
# Membership masks use Series.isin rather than np.in1d, which is deprecated
# as of NumPy 2.0; the resulting boolean mask selects the same rows.

# Pass 1: drop sessions that consist of a single click.
session_lengths = data.groupby('session_id').size()
data = data[data.session_id.isin(session_lengths[session_lengths > 1].index)]

# Pass 2: keep only items that occur at least 5 times in what is left.
item_supports = data.groupby('item_id').size()
data = data[data.item_id.isin(item_supports[item_supports >= 5].index)]

# Pass 3: removing rare items can shorten sessions, so recompute the session
# lengths and again keep only sessions with at least 2 clicks.
session_lengths = data.groupby('session_id').size()
data = data[data.session_id.isin(session_lengths[session_lengths >= 2].index)]

# Split the filtered data into a full training set and a test set.
# A session goes to the test set when its last event lies within the final
# 86400 seconds (one day) before the latest event in the data; every other
# session goes to the training set.
# Series.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
tmax = data.Time.max()
session_max_times = data.groupby('session_id').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_test = session_max_times[session_max_times >= tmax - 86400].index
train = data[data.session_id.isin(session_train)]
test = data[data.session_id.isin(session_test)]

# Remove test clicks on items that never appear in the training set, then
# drop test sessions that are left with fewer than 2 clicks.
test = test[test.item_id.isin(train.item_id)]
tslength = test.groupby('session_id').size()
test = test[test.session_id.isin(tslength[tslength >= 2].index)]

# Report the size of the full training split (events, distinct sessions,
# distinct items) and write it as a tab-separated file without the index.
print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
    train.shape[0],
    train['session_id'].nunique(),
    train['item_id'].nunique()))
train.to_csv(
    PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', index=False, sep='\t')

# Same report and export for the test split.
print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
    test.shape[0],
    test['session_id'].nunique(),
    test['item_id'].nunique()))
test.to_csv(
    PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', index=False, sep='\t')

# Split the full training set once more into a reduced training set and a
# validation set, using the same rule as the train/test split: a session goes
# to validation when its last event lies within the final 86400 seconds
# (one day) before the latest event in the training data.
# Series.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
tmax = train.Time.max()
session_max_times = train.groupby('session_id').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_valid = session_max_times[session_max_times >= tmax - 86400].index
train_tr = train[train.session_id.isin(session_train)]
valid = train[train.session_id.isin(session_valid)]

# Remove validation clicks on items absent from the reduced training set,
# then drop validation sessions that are left with fewer than 2 clicks.
valid = valid[valid.item_id.isin(train_tr.item_id)]
tslength = valid.groupby('session_id').size()
valid = valid[valid.session_id.isin(tslength[tslength >= 2].index)]

# Report the size of the reduced training split (events, distinct sessions,
# distinct items) and write it as a tab-separated file without the index.
print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
    train_tr.shape[0],
    train_tr['session_id'].nunique(),
    train_tr['item_id'].nunique()))
train_tr.to_csv(
    PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', index=False, sep='\t')

# Same report and export for the validation split.
print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
    valid.shape[0],
    valid['session_id'].nunique(),
    valid['item_id'].nunique()))
valid.to_csv(
    PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', index=False, sep='\t')