未验证 提交 e3893556 编写于 作者: W Wang Guibao 提交者: GitHub

Merge pull request #103 from wangguibao/elastic_ctr

Elastic ctr
...@@ -50,6 +50,9 @@ class ElasticCTRAPI(object): ...@@ -50,6 +50,9 @@ class ElasticCTRAPI(object):
self._instances += instance self._instances += instance
return instance return instance
def clear(self):
self._instances = []
def add_slot(self, instance, slot, feasigns): def add_slot(self, instance, slot, feasigns):
if not isinstance(instance, list): if not isinstance(instance, list):
print("add slot: parameter invalid: instance should be list") print("add slot: parameter invalid: instance should be list")
......
...@@ -30,7 +30,7 @@ DEFINE_int32(batch_size, 10, "Infernce batch_size"); ...@@ -30,7 +30,7 @@ DEFINE_int32(batch_size, 10, "Infernce batch_size");
DEFINE_string(test_file, "", "test file"); DEFINE_string(test_file, "", "test file");
const int VARIABLE_NAME_LEN = 256; const int VARIABLE_NAME_LEN = 256;
const int CTR_EMBEDDING_TABLE_SIZE = 400000001; const int CTR_EMBEDDING_TABLE_SIZE = 100000001;
struct Sample { struct Sample {
std::map<std::string, std::vector<uint64_t>> slots; std::map<std::string, std::vector<uint64_t>> slots;
......
...@@ -22,7 +22,7 @@ from elastic_ctr_api import ElasticCTRAPI ...@@ -22,7 +22,7 @@ from elastic_ctr_api import ElasticCTRAPI
BATCH_SIZE = 3 BATCH_SIZE = 3
SERVING_IP = "127.0.0.1" SERVING_IP = "127.0.0.1"
SLOT_CONF_FILE = "./conf/slot.conf" SLOT_CONF_FILE = "./conf/slot.conf"
CTR_EMBEDDING_TABLE_SIZE = 400000001 CTR_EMBEDDING_TABLE_SIZE = 100000001
SLOTS = [] SLOTS = []
...@@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels): ...@@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels):
sample = {} sample = {}
line = line.rstrip('\n') line = line.rstrip('\n')
feature_slots = line.split(' ') feature_slots = line.split(' ')
labels.append(int(feature_slots[1]))
feature_slots = feature_slots[2:] feature_slots = feature_slots[2:]
feature_slot_maps = [x.split(':') for x in feature_slots] feature_slot_maps = [x.split(':') for x in feature_slots]
...@@ -89,7 +90,9 @@ if __name__ == "__main__": ...@@ -89,7 +90,9 @@ if __name__ == "__main__":
ret = data_reader(sys.argv[4], samples, labels) ret = data_reader(sys.argv[4], samples, labels)
correct = 0
for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE): for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE):
api.clear()
batch = samples[i:i + BATCH_SIZE] batch = samples[i:i + BATCH_SIZE]
instances = [] instances = []
for sample in batch: for sample in batch:
...@@ -102,5 +105,22 @@ if __name__ == "__main__": ...@@ -102,5 +105,22 @@ if __name__ == "__main__":
api.add_slot(instance, k, v) api.add_slot(instance, k, v)
ret = api.inference() ret = api.inference()
print(ret) ret = json.loads(ret)
sys.exit(0) predictions = ret["predictions"]
idx = 0
for x in predictions:
if x["prob0"] >= x["prob1"]:
pred = 0
else:
pred = 1
if labels[i + idx] == pred:
correct += 1
else:
print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
(i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))
idx = idx + 1
print("Acc=%f" % (float(correct) / len(samples)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册