提交 7f668ee6 编写于 作者: W wangguibao

elastic_ctr

上级 e041e650
......@@ -50,6 +50,9 @@ class ElasticCTRAPI(object):
self._instances += instance
return instance
def clear(self):
self._instances = []
def add_slot(self, instance, slot, feasigns):
if not isinstance(instance, list):
print("add slot: parameter invalid: instance should be list")
......
......@@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels):
sample = {}
line = line.rstrip('\n')
feature_slots = line.split(' ')
labels.append(int(feature_slots[1]))
feature_slots = feature_slots[2:]
feature_slot_maps = [x.split(':') for x in feature_slots]
......@@ -89,7 +90,9 @@ if __name__ == "__main__":
ret = data_reader(sys.argv[4], samples, labels)
correct = 0
for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE):
api.clear()
batch = samples[i:i + BATCH_SIZE]
instances = []
for sample in batch:
......@@ -102,5 +105,22 @@ if __name__ == "__main__":
api.add_slot(instance, k, v)
ret = api.inference()
print(ret)
sys.exit(0)
ret = json.loads(ret)
predictions = ret["predictions"]
idx = 0
for x in predictions:
if x["prob0"] >= x["prob1"]:
pred = 0
else:
pred = 1
if labels[i + idx] == pred:
correct += 1
else:
print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
(i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))
idx = idx + 1
print("Acc=%f" % (float(correct) / len(samples)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册