# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import json import sys import os from elastic_ctr_api import ElasticCTRAPI BATCH_SIZE = 3 SERVING_IP = "127.0.0.1" SLOT_CONF_FILE = "./conf/slot.conf" CTR_EMBEDDING_TABLE_SIZE = 400000001 SLOTS = [] def str2long(str): if sys.version_info[0] == 2: return long(str) elif sys.version_info[0] == 3: return int(str) def data_reader(data_file, samples, labels): if not os.path.exists(data_file): print("Path %s not exist" % data_file) return -1 with open(data_file, "r") as f: for line in f: sample = {} line = line.rstrip('\n') feature_slots = line.split(' ') feature_slots = feature_slots[2:] feature_slot_maps = [x.split(':') for x in feature_slots] features = [x[0] for x in feature_slot_maps] slots = [x[1] for x in feature_slot_maps] for i in range(0, len(features)): if slots[i] in sample: sample[slots[i]] = [ sample[slots[i]] + str2long(features[i]) % CTR_EMBEDDING_TABLE_SIZE ] else: sample[slots[i]] = [ str2long(features[i]) % CTR_EMBEDDING_TABLE_SIZE ] for x in SLOTS: if not x in sample: sample[x] = [0] samples.append(sample) if __name__ == "__main__": """ main """ if len(sys.argv) != 5: print( "Usage: python elastic_ctr.py SERVING_IP SERVING_PORT SLOT_CONF_FILE DATA_FILE" ) sys.exit(-1) samples = [] labels = [] SERVING_IP = sys.argv[1] SERVING_PORT = sys.argv[2] SLOT_CONF_FILE = sys.argv[3] api = ElasticCTRAPI(SERVING_IP, SERVING_PORT) ret = api.read_slots_conf(SLOT_CONF_FILE) if ret != 0: sys.exit(-1) ret = data_reader(sys.argv[4], samples, labels) for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE): batch = samples[i:i + BATCH_SIZE] instances = [] for sample in batch: instance = api.add_instance() kv = [] if sys.version_info[0] == 2: for k, v in sample.iteritems(): api.add_slot(instance, k, v) elif sys.version_info[0] == 3: for k, v in sample.items(): api.add_slot(instance, k, v) ret = api.inference() print(ret) sys.exit(0)