import math
import heapq  # for retrieval topK
import multiprocessing
import numpy as np
import paddle.fluid as fluid
import os
from gmf import GMF
from mlp import MLP
from neumf import NeuMF
from Dataset import Dataset
import logging
import paddle
import args
import utils
import time

# Global variables that are shared across processes
_model = None
_testRatings = None
_testNegatives = None
_K = None
_args = None
_model_path = None


def run_infer(args, model_path, test_data_path):
    """Load the saved inference model and score the (user, item) pairs read
    from test_data_path; returns the prediction scores of the first batch."""
    test_data_generator = utils.CriteoDataset()

    with fluid.scope_guard(fluid.Scope()):
        test_reader = fluid.io.batch(
            test_data_generator.test(test_data_path, False),
            batch_size=args.test_batch_size)

        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        infer_program, feed_target_names, fetch_vars = fluid.io.load_inference_model(
            model_path, exe)

        for data in test_reader():
            user_input = np.array([dat[0] for dat in data])
            item_input = np.array([dat[1] for dat in data])

            pred_val = exe.run(
                infer_program,
                feed={"user_input": user_input,
                      "item_input": item_input},
                fetch_list=fetch_vars,
                return_numpy=True)

            # Return after the first batch: all candidate items for a single
            # test user are expected to fit in one test_batch_size batch.
            return pred_val[0].reshape(1, -1).tolist()[0]
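
# Standalone check of run_infer (sketch, assumptions flagged): run_infer only
# reads use_gpu and test_batch_size from `args`, so a SimpleNamespace stand-in
# is enough for a smoke test; "model_dir/epoch_0" is a hypothetical directory
# holding a model saved with fluid.io.save_inference_model, and "Data/test.txt"
# is the candidate file that eval_one_rating writes below.
#
#     from types import SimpleNamespace
#     flags = SimpleNamespace(use_gpu=False, test_batch_size=256)
#     scores = run_infer(flags, "model_dir/epoch_0", "Data/test.txt")
#     print(len(scores), scores[:5])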


def evaluate_model(args, testRatings, testNegatives, K, model_path):
    """
    Evaluate the performance (Hit_Ratio, NDCG) of top-K recommendation
    Return: score of each test rating.
    """
    global _model
    global _testRatings
    global _testNegatives
    global _K
    global _model_path
    global _args

    _args = args
    _model_path = model_path
    _testRatings = testRatings
    _testNegatives = testNegatives
    _K = K

    hits, ndcgs = [], []
    for idx in range(len(_testRatings)):
        (hr, ndcg) = eval_one_rating(idx)
        hits.append(hr)
        ndcgs.append(ndcg)
    return (hits, ndcgs)
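
# Typical driver (sketch, assumptions flagged): the Dataset class imported at
# the top is assumed to expose testRatings / testNegatives lists as in the
# reference NCF implementation; `flags` stands for the parsed command-line
# arguments and "model_dir/epoch_0" is a hypothetical checkpoint directory.
#
#     dataset = Dataset("Data/ml-1m")  # constructor argument is an assumption
#     hits, ndcgs = evaluate_model(flags, dataset.testRatings,
#                                  dataset.testNegatives, K=10,
#                                  model_path="model_dir/epoch_0")
#     print("HR@10 = %.4f, NDCG@10 = %.4f" % (np.mean(hits), np.mean(ndcgs)))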


def eval_one_rating(idx):
    # Leave-one-out evaluation: rank one test rating's ground-truth item
    # against its sampled negative items and report HR@K / NDCG@K.
    rating = _testRatings[idx]
    items = _testNegatives[idx]
    u = rating[0]
    gtItem = rating[1]
    items.append(gtItem)
    # Get prediction scores
    map_item_score = {}
    users = np.full(len(items), u, dtype='int32')
    users = users.reshape(-1, 1)
    items_array = np.array(items).reshape(-1, 1)
    # Write the (user, item) candidate pairs to disk and score them with the
    # saved inference model; _args.test_data_path is expected to point at the
    # same "Data/test.txt" file written here.
    temp = np.hstack((users, items_array))
    np.savetxt("Data/test.txt", temp, fmt='%d', delimiter=',')
    predictions = run_infer(_args, _model_path, _args.test_data_path)

    for i in range(len(items)):
        item = items[i]
        map_item_score[item] = predictions[i]
    items.pop()  # remove gtItem again so the shared _testNegatives list is unchanged

    # Evaluate top rank list
    ranklist = heapq.nlargest(_K, map_item_score, key=map_item_score.get)
    hr = getHitRatio(ranklist, gtItem)
    ndcg = getNDCG(ranklist, gtItem)

    return (hr, ndcg)


def getHitRatio(ranklist, gtItem):
    for item in ranklist:
        if item == gtItem:
            return 1
    return 0


def getNDCG(ranklist, gtItem):
    # Single relevant item, so IDCG = 1 and NDCG reduces to 1 / log2(rank + 1).
    for i in range(len(ranklist)):
        item = ranklist[i]
        if item == gtItem:
            return math.log(2) / math.log(i + 2)
    return 0
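

# Quick self-check of the ranking metrics (sketch, not part of the training or
# inference flow). With the ground-truth item at 0-based position i in the
# top-K list, getNDCG returns log(2) / log(i + 2): 1.0 at rank 1, 0.5 at rank 3.
if __name__ == "__main__":
    ranklist = [42, 7, 99, 3, 15]  # toy top-5 ranking, highest score first
    assert getHitRatio(ranklist, 99) == 1            # ground truth in the list -> hit
    assert getHitRatio(ranklist, 123) == 0           # ground truth missing -> miss
    assert abs(getNDCG(ranklist, 42) - 1.0) < 1e-9   # rank 1: log(2)/log(2)
    assert abs(getNDCG(ranklist, 99) - 0.5) < 1e-9   # rank 3: log(2)/log(4)
    print("metric self-check passed")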