diff --git a/elastic-ctr/client/demo/elastic_ctr.py b/elastic-ctr/client/demo/elastic_ctr.py index 77ae6879979aa0e25ae84422c78d1e9023228979..75684052b19909aae53c7146d3de909605fd06c9 100644 --- a/elastic-ctr/client/demo/elastic_ctr.py +++ b/elastic-ctr/client/demo/elastic_ctr.py @@ -14,7 +14,6 @@ from __future__ import print_function import json -import httplib import sys import os @@ -27,6 +26,13 @@ CTR_EMBEDDING_TABLE_SIZE = 400000001 SLOTS = [] +def str2long(str): + if sys.version_info[0] == 2: + return long(str) + elif sys.version_info[0] == 3: + return int(str) + + def data_reader(data_file, samples, labels): if not os.path.exists(data_file): print("Path %s not exist" % data_file) @@ -45,10 +51,14 @@ def data_reader(data_file, samples, labels): for i in range(0, len(features)): if slots[i] in sample: - sample[slots[i]] = [sample[slots[i]] + long(features[i])] + sample[slots[i]] = [ + sample[slots[i]] + str2long(features[i]) % + CTR_EMBEDDING_TABLE_SIZE + ] else: - sample[slots[ - i]] = [long(features[i]) % CTR_EMBEDDING_TABLE_SIZE] + sample[slots[i]] = [ + str2long(features[i]) % CTR_EMBEDDING_TABLE_SIZE + ] for x in SLOTS: if not x in sample: @@ -85,8 +95,12 @@ if __name__ == "__main__": for sample in batch: instance = api.add_instance() kv = [] - for k, v in sample.iteritems(): - api.add_slot(instance, k, v) + if sys.version_info[0] == 2: + for k, v in sample.iteritems(): + api.add_slot(instance, k, v) + elif sys.version_info[0] == 3: + for k, v in sample.items(): + api.add_slot(instance, k, v) ret = api.inference() print(ret)