diff --git a/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py b/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py index a8101f8b9394f9f143ebf7a77e25d2d2ed6652d0..b9ef7ef402bcb5dc92d1a6f21ec875e6ef85c8d7 100644 --- a/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py +++ b/elastic-ctr/client/api/python/elasticctr/elastic_ctr_api.py @@ -50,6 +50,9 @@ class ElasticCTRAPI(object): self._instances += instance return instance + def clear(self): + self._instances = [] + def add_slot(self, instance, slot, feasigns): if not isinstance(instance, list): print("add slot: parameter invalid: instance should be list") diff --git a/elastic-ctr/client/demo/elastic_ctr.py b/elastic-ctr/client/demo/elastic_ctr.py index 30873437c67d58c3a816ad344019bbeaa7503d54..91695fd6656345b26a76807e0d3c258ba3adf81f 100644 --- a/elastic-ctr/client/demo/elastic_ctr.py +++ b/elastic-ctr/client/demo/elastic_ctr.py @@ -43,6 +43,7 @@ def data_reader(data_file, samples, labels): sample = {} line = line.rstrip('\n') feature_slots = line.split(' ') + labels.append(int(feature_slots[1])) feature_slots = feature_slots[2:] feature_slot_maps = [x.split(':') for x in feature_slots] @@ -89,7 +90,9 @@ if __name__ == "__main__": ret = data_reader(sys.argv[4], samples, labels) + correct = 0 for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE): + api.clear() batch = samples[i:i + BATCH_SIZE] instances = [] for sample in batch: @@ -102,5 +105,22 @@ if __name__ == "__main__": api.add_slot(instance, k, v) ret = api.inference() - print(ret) - sys.exit(0) + ret = json.loads(ret) + predictions = ret["predictions"] + + idx = 0 + for x in predictions: + if x["prob0"] >= x["prob1"]: + pred = 0 + else: + pred = 1 + + if labels[i + idx] == pred: + correct += 1 + else: + print("id=%d predict incorrect: pred=%d label=%d (%f %f)" % + (i + idx, pred, labels[i + idx], x["prob0"], x["prob1"])) + + idx = idx + 1 + + print("Acc=%f" % (float(correct) / len(samples)))