build_dataset.py 3.0 KB
Newer Older
Y
yaoxuefeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import random
import pickle

# Fixed seed: the negative sampling and shuffling further down use the
# module-level RNG, so their results repeat from run to run.
random.seed(1234)

print("read and process data")

# remap.pkl contains three objects pickled back to back; they have to be
# read in the order they were written.
# NOTE(review): pickle.load can execute arbitrary code -- only run this
# against a remap.pkl you produced yourself.
with open('./raw_data/remap.pkl', 'rb') as remap_file:
    reviews_df, cate_list, counts = [
        pickle.load(remap_file) for _ in range(3)
    ]
user_count, item_count, cate_count, example_count = counts

# Build per-user samples. For a user whose items are p0..pk (in the order
# they appear in reviews_df):
#   * every prefix p0..p(i-1) with 1 <= i < k yields two training samples,
#     (user, prefix, p_i, 1) and (user, prefix, random negative, 0);
#   * the longest prefix p0..p(k-1) yields one test sample whose target is
#     the pair (p_k, random negative).
train_set = []
test_set = []
for reviewerID, user_reviews in reviews_df.groupby('reviewerID'):
    pos_list = user_reviews['asin'].tolist()
    # Set copy of the positives so each rejection test is O(1) instead of a
    # linear scan of the list.
    pos_set = set(pos_list)

    # Draw one negative item id per position by rejection sampling: keep
    # drawing until the id is not among this user's items.
    # neg_list[0] is never consumed below, but it is still drawn so the RNG
    # stream (and therefore the generated dataset) stays unchanged.
    neg_list = []
    for _ in pos_list:
        neg = random.randint(0, item_count - 1)
        while neg in pos_set:
            neg = random.randint(0, item_count - 1)
        neg_list.append(neg)

    last_idx = len(pos_list) - 1
    for i in range(1, len(pos_list)):
        history = pos_list[:i]
        if i != last_idx:
            train_set.append((reviewerID, history, pos_list[i], 1))
            train_set.append((reviewerID, history, neg_list[i], 0))
        else:
            # The final item is held out for evaluation.
            test_set.append((reviewerID, history, (pos_list[i], neg_list[i])))

random.shuffle(train_set)
random.shuffle(test_set)

# Exactly one test sample is produced per user with at least two items; a
# user with a single item produces none and makes this check fail.
assert len(test_set) == user_count, (
    "expected one test sample per user: got %d samples for %d users"
    % (len(test_set), user_count))


def print_to_file(data, fout):
    """Write *data* to *fout* as a single ';'-terminated field.

    Each item is converted with ``str()``; items are separated by one space
    and the field is closed with ';'. No newline is written.

    Args:
        data: iterable of items to write (any iterable, not only sequences).
        fout: object with a ``write(str)`` method, e.g. an open text file.

    An empty *data* still writes the terminating ';' so the number of fields
    in the record stays the same.
    """
    # One joined write instead of a write call per token and separator.
    fout.write(' '.join(str(item) for item in data) + ';')


print("make train data")
# One line per training sample, fields separated by ';':
#   history items ; history categories ; target item ; target category ; label
with open("paddle_train.txt", "w") as fout:
    for _, history, target, label in train_set:
        # Category id of every item in the history, in the same order.
        cate = [cate_list[item] for item in history]
        print_to_file(history, fout)
        print_to_file(cate, fout)
        tail = [str(target), str(cate_list[target]), str(label)]
        fout.write(";".join(tail) + "\n")

print("make test data")
# Each test sample becomes two lines sharing the same history: first the
# positive target with label 1, then the negative target with label 0.
with open("paddle_test.txt", "w") as fout:
    for _, history, target in test_set:
        cate = [cate_list[item] for item in history]
        pos_item, neg_item = target[0], target[1]

        for item, label_line in ((pos_item, "1\n"), (neg_item, "0\n")):
            print_to_file(history, fout)
            print_to_file(cate, fout)
            fout.write(str(item) + ";")
            fout.write(str(cate_list[item]) + ";")
            fout.write(label_line)

print("make config data")
# config.txt: user, item and category counts, one per line, in that order.
with open('config.txt', 'w') as f:
    f.writelines(
        str(count) + "\n" for count in (user_count, item_count, cate_count))