提交 87e5ae3e 编写于 作者: W wangguibao

Elastic CTR

上级 80cc091a
...@@ -81,8 +81,6 @@ install(FILES ${CMAKE_CURRENT_LIST_DIR}/api/python/elasticctr/__init__.py ...@@ -81,8 +81,6 @@ install(FILES ${CMAKE_CURRENT_LIST_DIR}/api/python/elasticctr/__init__.py
DESTINATION DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/api/lib) ${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/api/lib)
install(FILES ${CMAKE_CURRENT_LIST_DIR}/demo/demo.py install(FILES ${CMAKE_CURRENT_LIST_DIR}/demo/demo.py
DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/python_client) DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/client/bin)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/demo/conf DESTINATION install(FILES ${CMAKE_CURRENT_LIST_DIR}/demo/elastic_ctr.py
${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/python_client/) DESTINATION ${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/client/bin)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/demo/data/ctr_prediction DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/elastic_ctr/python_client/data)
...@@ -67,7 +67,7 @@ int read_samples(const char* file) { ...@@ -67,7 +67,7 @@ int read_samples(const char* file) {
if (sample.slots.find(slot_name) == sample.slots.end()) { if (sample.slots.find(slot_name) == sample.slots.end()) {
std::vector<uint64_t> values; std::vector<uint64_t> values;
values.push_back(x % 400000001); values.push_back(x % CTR_EMBEDDING_TABLE_SIZE);
sample.slots[slot_name] = values; sample.slots[slot_name] = values;
} else { } else {
auto it = sample.slots.find(slot_name); auto it = sample.slots.find(slot_name);
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import httplib
import sys
import os
BATCH_SIZE = 3
SERVING_IP = "127.0.0.1"
SLOT_CONF_FILE = "./conf/slot.conf"
CTR_EMBEDDING_TABLE_SIZE = 400000001
SLOTS = []
def data_reader(data_file, samples, labels):
if not os.path.exists(data_file):
print "Path %s not exist" % data_file
return -1
with open(data_file, "r") as f:
for line in f:
sample = {}
line = line.rstrip('\n')
feature_slots = line.split(' ')
feature_slots = feature_slots[2:]
feature_slot_maps = [x.split(':') for x in feature_slots]
features = [x[0] for x in feature_slot_maps]
slots = [x[1] for x in feature_slot_maps]
for i in range(0, len(features)):
if slots[i] in sample:
sample[slots[i]] = [sample[slots[i]] + long(features[i])]
else:
sample[slots[
i]] = [long(features[i]) % CTR_EMBEDDING_TABLE_SIZE]
for x in SLOTS:
if not x in sample:
sample[x] = [0]
samples.append(sample)
def read_slots_conf(slots_conf_file, slots):
if not os.path.exists(slots_conf_file):
print "Path %s not exist" % sltos_conf_file
return -1
with open(slots_conf_file, "r") as f:
for line in f:
slots.append(line.rstrip('\n'))
print slots
return 0
if __name__ == "__main__":
""" main
"""
if len(sys.argv) != 4:
print "Usage: python elastic_ctr.py SERVING_IP SLOT_CONF_FILE DATA_FILE"
sys.exit(-1)
samples = []
labels = []
SERVING_IP = sys.argv[1]
SLOT_CONF_FILE = sys.argv[2]
ret = read_slots_conf(SLOT_CONF_FILE, SLOTS)
if ret != 0:
sys.exit(-1)
print SLOTS
ret = data_reader(sys.argv[3], samples, labels)
conn = httplib.HTTPConnection(SERVING_IP, 8010)
for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE):
batch = samples[i:i + BATCH_SIZE]
instances = []
for sample in batch:
instance = []
kv = []
for k, v in sample.iteritems():
kv += [{"slot_name": k, "feasigns": v}]
print kv
instance = [{"slots": kv}]
instances += instance
req = {"instances": instances}
request_json = json.dumps(req)
print request_json
try:
conn.request('POST', "/ElasticCTRPredictionService/inference",
request_json, {"Content-Type": "application/json"})
response = conn.getresponse()
print response.read()
except httplib.HTTPException as e:
print e.reason
sys.exit(0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册