提交 06fd660d 编写于 作者: W wangguibao

elastic_ctr

上级 d97b6795
...@@ -50,6 +50,9 @@ class ElasticCTRAPI(object): ...@@ -50,6 +50,9 @@ class ElasticCTRAPI(object):
self._instances += instance self._instances += instance
return instance return instance
def clear(self):
self._instances = []
def add_slot(self, instance, slot, feasigns): def add_slot(self, instance, slot, feasigns):
if not isinstance(instance, list): if not isinstance(instance, list):
print("add slot: parameter invalid: instance should be list") print("add slot: parameter invalid: instance should be list")
......
...@@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels): ...@@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels):
sample = {} sample = {}
line = line.rstrip('\n') line = line.rstrip('\n')
feature_slots = line.split(' ') feature_slots = line.split(' ')
labels.append(int(feature_slots[1]))
feature_slots = feature_slots[2:] feature_slots = feature_slots[2:]
feature_slot_maps = [x.split(':') for x in feature_slots] feature_slot_maps = [x.split(':') for x in feature_slots]
...@@ -89,7 +90,9 @@ if __name__ == "__main__": ...@@ -89,7 +90,9 @@ if __name__ == "__main__":
ret = data_reader(sys.argv[4], samples, labels) ret = data_reader(sys.argv[4], samples, labels)
correct = 0
for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE): for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE):
api.clear()
batch = samples[i:i + BATCH_SIZE] batch = samples[i:i + BATCH_SIZE]
instances = [] instances = []
for sample in batch: for sample in batch:
...@@ -102,5 +105,22 @@ if __name__ == "__main__": ...@@ -102,5 +105,22 @@ if __name__ == "__main__":
api.add_slot(instance, k, v) api.add_slot(instance, k, v)
ret = api.inference() ret = api.inference()
print(ret) ret = json.loads(ret)
sys.exit(0) predictions = ret["predictions"]
idx = 0
for x in predictions:
if x["prob0"] >= x["prob1"]:
pred = 0
else:
pred = 1
if labels[i + idx] == pred:
correct += 1
else:
print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
(i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))
idx = idx + 1
print("Acc=%f" % (float(correct) / len(samples)))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册