# elastic_ctr.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import json
import sys
import os

from elastic_ctr_api import ElasticCTRAPI

# Number of samples grouped into one inference request by the script below.
BATCH_SIZE = 3
# Default serving address; the __main__ section overwrites it from sys.argv.
SERVING_IP = "127.0.0.1"
# Default slot configuration path; the __main__ section overwrites it from
# sys.argv.
SLOT_CONF_FILE = "./conf/slot.conf"
# Feature ids are reduced modulo this value by data_reader.
CTR_EMBEDDING_TABLE_SIZE = 100000001
# Slot ids that data_reader fills with [0] when a sample does not mention them.
# NOTE(review): nothing in the visible code adds entries to this list, so the
# fill step is currently a no-op -- confirm whether it should be populated
# from the slot configuration.
SLOTS = []


def str2long(str):
    """Convert a numeric string to an integer on Python 2 or Python 3.

    Python 3 uses ``int`` and Python 2 uses ``long``. On any other major
    version no branch matches and the function returns None.
    """
    major_version = sys.version_info[0]
    if major_version == 3:
        return int(str)
    if major_version == 2:
        return long(str)


def data_reader(data_file, samples, labels):
    """Parse a sample file into per-slot feature-id lists and integer labels.

    Each non-blank line is split on single spaces. The first field is
    ignored, the second field is the integer label, and every remaining
    field has the form ``feature:slot``. Feature ids are reduced modulo
    CTR_EMBEDDING_TABLE_SIZE and grouped by slot; any slot listed in SLOTS
    that the line does not mention is filled with ``[0]``.

    Args:
        data_file: Path of the text file to read.
        samples: List extended in place with one ``{slot: [feature_id, ...]}``
            dict per parsed line.
        labels: List extended in place with one int label per parsed line.

    Returns:
        -1 if ``data_file`` does not exist, 0 once the file has been read.
    """
    if not os.path.exists(data_file):
        print("Path %s not exist" % data_file)
        return -1

    with open(data_file, "r") as f:
        for line in f:
            line = line.rstrip('\n')
            if not line.strip():
                # A blank line has no label field; skip it instead of
                # failing on the index lookup below.
                continue

            fields = line.split(' ')
            labels.append(int(fields[1]))

            sample = {}
            for token in fields[2:]:
                if not token:
                    # Trailing or doubled spaces produce empty tokens.
                    continue
                parts = token.split(':')
                feature_id = str2long(parts[0]) % CTR_EMBEDDING_TABLE_SIZE
                # Append to the slot's list; concatenating the list with an
                # int would raise TypeError when a slot appears twice.
                sample.setdefault(parts[1], []).append(feature_id)

            for slot in SLOTS:
                if slot not in sample:
                    sample[slot] = [0]
            samples.append(sample)

    return 0


if __name__ == "__main__":
    # Script entry point: read the slot configuration, load labelled samples,
    # send them to the serving endpoint in batches of BATCH_SIZE, and print
    # every misprediction followed by the overall accuracy.
    if len(sys.argv) != 5:
        print(
            "Usage: python elastic_ctr.py SERVING_IP SERVING_PORT SLOT_CONF_FILE DATA_FILE"
        )
        sys.exit(-1)

    samples = []
    labels = []

    SERVING_IP = sys.argv[1]
    SERVING_PORT = sys.argv[2]
    SLOT_CONF_FILE = sys.argv[3]

    api = ElasticCTRAPI(SERVING_IP, SERVING_PORT)
    ret = api.read_slots_conf(SLOT_CONF_FILE)
    if ret != 0:
        sys.exit(-1)

    ret = data_reader(sys.argv[4], samples, labels)
    if ret == -1:
        # data_reader has already reported the missing file.
        sys.exit(-1)

    correct = 0
    evaluated = 0
    # Walk over every sample. Stopping at len(samples) - BATCH_SIZE would
    # always leave the final batch unscored while still counting it in the
    # accuracy denominator. The last batch may hold fewer than BATCH_SIZE
    # samples.
    for i in range(0, len(samples), BATCH_SIZE):
        api.clear()
        batch = samples[i:i + BATCH_SIZE]
        for sample in batch:
            instance = api.add_instance()
            # dict.items() is available on both Python 2 and Python 3.
            for k, v in sample.items():
                api.add_slot(instance, k, v)

        ret = api.inference()
        ret = json.loads(ret)
        predictions = ret["predictions"]

        for idx, x in enumerate(predictions):
            # Ties go to class 0.
            if x["prob0"] >= x["prob1"]:
                pred = 0
            else:
                pred = 1

            if labels[i + idx] == pred:
                correct += 1
            else:
                print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
                      (i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))

            evaluated += 1

    # Divide by the number of predictions actually scored, and avoid a
    # ZeroDivisionError when the data file yields no samples.
    if evaluated:
        print("Acc=%f" % (float(correct) / evaluated))
    else:
        print("No samples were evaluated")