elastic_ctr.py 3.1 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wangguibao 已提交
15
from __future__ import print_function
W
wangguibao 已提交
16 17 18 19
import json
import sys
import os

W
wangguibao 已提交
20 21
from elastic_ctr_api import ElasticCTRAPI

W
wangguibao 已提交
22 23 24 25 26 27 28
# Number of samples the main script groups into one inference request.
BATCH_SIZE = 3
# Default serving address; overwritten from sys.argv[1] when run as a script.
SERVING_IP = "127.0.0.1"
# Default slot configuration path; overwritten from sys.argv[3] when run as a
# script.
SLOT_CONF_FILE = "./conf/slot.conf"
# Modulus applied to every feature id parsed by data_reader.
# NOTE(review): presumably the row count of the CTR embedding table -- confirm.
CTR_EMBEDDING_TABLE_SIZE = 400000001
# Slot names that data_reader fills with [0] when a sample lacks them. Empty
# here; nothing in the visible part of this file populates it.
SLOTS = []


W
wangguibao 已提交
29 30 31 32 33 34 35
def str2long(str):
    """Convert a numeric string to an integer on both Python 2 and Python 3.

    Python 2 uses ``long`` so large feature ids are not limited to the
    platform ``int`` width; Python 3 uses ``int``.  On any other major
    version the function falls through and returns ``None``.
    """
    major_version = sys.version_info[0]
    if major_version == 2:
        return long(str)
    if major_version == 3:
        return int(str)


W
wangguibao 已提交
36 37
def data_reader(data_file, samples, labels):
    """Parse a sample file and append one slot -> feature-id-list dict per line.

    Each line is split on single spaces.  The first two fields are skipped
    and every remaining field is read as ``feature:slot``.  The feature is
    converted with ``str2long`` and reduced modulo
    ``CTR_EMBEDDING_TABLE_SIZE``; all ids that share a slot on the same line
    are collected in one list.  Slots listed in ``SLOTS`` that do not occur
    on a line are added with the value ``[0]``.

    Args:
        data_file: Path of the text file to read.
        samples: List that receives one dict per input line, mapping the slot
            name (str) to a list of feature ids.
        labels: Accepted for interface compatibility; this function neither
            reads nor modifies it.

    Returns:
        -1 if ``data_file`` does not exist (a message is printed), 0 once the
        whole file has been parsed.
    """
    if not os.path.exists(data_file):
        print("Path %s not exist" % data_file)
        return -1

    with open(data_file, "r") as f:
        for line in f:
            sample = {}
            # The first two space-separated fields are not feature tokens.
            tokens = line.rstrip('\n').split(' ')[2:]
            for token in tokens:
                if not token:
                    # Empty token produced by a trailing or doubled space.
                    continue
                parts = token.split(':')
                feature, slot = parts[0], parts[1]
                feature_id = str2long(feature) % CTR_EMBEDDING_TABLE_SIZE
                # A slot may occur several times on one line: accumulate its
                # ids in a single list (previously list + int -> TypeError).
                sample.setdefault(slot, []).append(feature_id)

            # Give every configured slot a value so all samples share keys.
            for slot in SLOTS:
                sample.setdefault(slot, [0])
            samples.append(sample)

    return 0


if __name__ == "__main__":
    # Command-line demo client: load the slot configuration, parse the data
    # file, send a batch of samples to the serving endpoint and print the
    # response.
    if len(sys.argv) != 5:
        print(
            "Usage: python elastic_ctr.py SERVING_IP SERVING_PORT SLOT_CONF_FILE DATA_FILE"
        )
        sys.exit(-1)

    samples = []
    labels = []

    SERVING_IP = sys.argv[1]
    SERVING_PORT = sys.argv[2]
    SLOT_CONF_FILE = sys.argv[3]

    api = ElasticCTRAPI(SERVING_IP, SERVING_PORT)
    ret = api.read_slots_conf(SLOT_CONF_FILE)
    if ret != 0:
        sys.exit(-1)

    ret = data_reader(sys.argv[4], samples, labels)
    # data_reader signals a missing data file with -1 (and has already
    # printed the reason); stop instead of silently exiting with status 0.
    if ret == -1:
        sys.exit(-1)

    # "+ 1" so the last full batch is included: with exactly BATCH_SIZE
    # samples the previous bound produced an empty range and sent nothing.
    # Only full batches are sent; a trailing partial batch is still skipped.
    for i in range(0, len(samples) - BATCH_SIZE + 1, BATCH_SIZE):
        batch = samples[i:i + BATCH_SIZE]
        for sample in batch:
            instance = api.add_instance()
            # dict.items() is available on both Python 2 and Python 3.
            for k, v in sample.items():
                api.add_slot(instance, k, v)

        ret = api.inference()
        print(ret)
        # NOTE(review): exiting here means only the first batch is ever sent.
        # Kept as in the original; confirm whether this is intentional.
        sys.exit(0)