From 7f668ee6bfa6af450b2d76f412277d4d9be4b0df Mon Sep 17 00:00:00 2001 From: wangguibao Date: Fri, 22 Nov 2019 21:20:58 +0800 Subject: [PATCH] elastic_ctr --- .../api/python/elasticctr/elastic_ctr_api.py | 3 +++ elastic-ctr/client/demo/elastic_ctr.py | 24 +++++++++++++++++-- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py b/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py index a8101f8b..b9ef7ef4 100644 --- a/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py +++ b/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py @@ -50,6 +50,9 @@ class ElasticCTRAPI(object): self._instances += instance return instance + def clear(self): + self._instances = [] + def add_slot(self, instance, slot, feasigns): if not isinstance(instance, list): print("add slot: parameter invalid: instance should be list") diff --git a/elastic-ctr/client/demo/elastic_ctr.py b/elastic-ctr/client/demo/elastic_ctr.py index 30873437..91695fd6 100644 --- a/elastic-ctr/client/demo/elastic_ctr.py +++ b/elastic-ctr/client/demo/elastic_ctr.py @@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels): sample = {} line = line.rstrip('\n') feature_slots = line.split(' ') + labels.append(int(feature_slots[1])) feature_slots = feature_slots[2:] feature_slot_maps = [x.split(':') for x in feature_slots] @@ -89,7 +90,9 @@ if __name__ == "__main__": ret = data_reader(sys.argv[4], samples, labels) + correct = 0 for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE): + api.clear() batch = samples[i:i + BATCH_SIZE] instances = [] for sample in batch: @@ -102,5 +105,22 @@ if __name__ == "__main__": api.add_slot(instance, k, v) ret = api.inference() - print(ret) - sys.exit(0) + ret = json.loads(ret) + predictions = ret["predictions"] + + idx = 0 + for x in predictions: + if x["prob0"] >= x["prob1"]: + pred = 0 + else: + pred = 1 + + if labels[i + idx] == pred: + correct += 1 + else: + print("id=%d predict incorrect: pred=%d label=%d (%f %f)" % + (i + idx, pred, labels[i + idx], x["prob0"], x["prob1"])) + + idx = idx + 1 + + print("Acc=%f" % (float(correct) / len(samples))) -- GitLab