# elastic_ctr.py
# Last committed by wangguibao.
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import httplib
import sys
import os

# Number of samples grouped into one inference request by the script body.
BATCH_SIZE = 3
# Default serving host; overwritten from sys.argv[1] when run as a script.
SERVING_IP = "127.0.0.1"
# Default slot config path; overwritten from sys.argv[2] when run as a script.
SLOT_CONF_FILE = "./conf/slot.conf"
# Feature ids are reduced modulo this value in data_reader.
CTR_EMBEDDING_TABLE_SIZE = 400000001
# Slot names; filled in place by read_slots_conf when run as a script.
SLOTS = []


def data_reader(data_file, samples, labels):
    """Parse a sample file into per-line ``{slot: [feasign, ...]}`` dicts.

    Each line is split on single spaces. The first two fields are skipped
    (presumably label/metadata columns -- confirm against the data format)
    and every remaining field is read as a ``feature:slot`` pair. Feature
    ids are reduced modulo CTR_EMBEDDING_TABLE_SIZE and grouped by slot.
    Every slot listed in SLOTS that does not occur on a line is filled in
    with ``[0]``.

    Args:
        data_file: Path of the text file to read.
        samples: List that is extended in place with one dict per line.
        labels: Unused; kept so existing callers keep working.

    Returns:
        0 on success, -1 if ``data_file`` does not exist.

    Raises:
        IndexError: If a non-empty field contains no ``:`` separator.
        ValueError: If a feature id is not an integer literal.
    """
    if not os.path.exists(data_file):
        print("Path %s not exist" % data_file)
        return -1

    with open(data_file, "r") as f:
        for line in f:
            sample = {}
            fields = line.rstrip('\r\n').split(' ')[2:]

            for field in fields:
                # Doubled or trailing spaces produce empty tokens; skip them
                # instead of failing on the missing ':' separator.
                if not field:
                    continue
                parts = field.split(':')
                feature, slot = parts[0], parts[1]
                # int() transparently yields a long for big values on
                # Python 2, so it is a drop-in replacement for long().
                feasign = int(feature) % CTR_EMBEDDING_TABLE_SIZE
                # Append to the slot's list. The previous code evaluated
                # ``list + long`` here, which raises TypeError as soon as a
                # slot appears twice on one line.
                sample.setdefault(slot, []).append(feasign)

            for slot in SLOTS:
                if slot not in sample:
                    sample[slot] = [0]
            samples.append(sample)
    return 0


def read_slots_conf(slots_conf_file, slots):
    """Load slot names, one per line, from a config file.

    Args:
        slots_conf_file: Path of the slot configuration file.
        slots: List that is extended in place with the slot names read.

    Returns:
        0 on success, -1 if ``slots_conf_file`` does not exist.
    """
    if not os.path.exists(slots_conf_file):
        # The original referenced a misspelled variable here, which raised
        # NameError instead of reporting the missing file.
        print("Path %s not exist" % slots_conf_file)
        return -1
    with open(slots_conf_file, "r") as f:
        for line in f:
            name = line.rstrip('\r\n')
            # Blank lines would otherwise register an empty slot name.
            if name:
                slots.append(name)
    print(slots)
    return 0


if __name__ == "__main__":
    # Script entry point: load the slot list and the sample file named on
    # the command line, then POST the samples in batches of BATCH_SIZE to
    # the ElasticCTR prediction service on port 8010 of SERVING_IP.
    if len(sys.argv) != 4:
        print("Usage: python elastic_ctr.py SERVING_IP SLOT_CONF_FILE "
              "DATA_FILE")
        sys.exit(-1)

    samples = []
    labels = []

    SERVING_IP = sys.argv[1]
    SLOT_CONF_FILE = sys.argv[2]

    ret = read_slots_conf(SLOT_CONF_FILE, SLOTS)
    if ret != 0:
        sys.exit(-1)
    print(SLOTS)

    ret = data_reader(sys.argv[3], samples, labels)
    # data_reader signals a missing data file with -1; stop instead of
    # silently running with no samples.
    if ret == -1:
        sys.exit(-1)

    conn = httplib.HTTPConnection(SERVING_IP, 8010)
    try:
        # Step over every sample, including a trailing partial batch. The
        # previous bound (len(samples) - BATCH_SIZE) skipped the last batch.
        for i in range(0, len(samples), BATCH_SIZE):
            batch = samples[i:i + BATCH_SIZE]
            instances = []
            for sample in batch:
                kv = [{"slot_name": k, "feasigns": v}
                      for k, v in sample.items()]
                print(kv)
                instances.append({"slots": kv})
            req = {"instances": instances}

            request_json = json.dumps(req)
            print(request_json)

            try:
                conn.request('POST', "/ElasticCTRPredictionService/inference",
                             request_json,
                             {"Content-Type": "application/json"})
                response = conn.getresponse()
                print(response.read())
            except httplib.HTTPException as e:
                # HTTPException carries no ``reason`` attribute; print the
                # exception itself.
                print(e)
                # Reset the connection so the next batch can reconnect.
                conn.close()
    finally:
        conn.close()
    # Exit only after all batches were sent; exiting inside the loop stopped
    # the client after the first request.
    sys.exit(0)