未验证 提交 3ea47c0b 编写于 作者: W Wang Guibao 提交者: GitHub

Merge pull request #105 from suoych/develop

fixed client
......@@ -19,7 +19,7 @@ import os
from elastic_ctr_api import ElasticCTRAPI
BATCH_SIZE = 3
BATCH_SIZE = 10
SERVING_IP = "127.0.0.1"
SLOT_CONF_FILE = "./conf/slot.conf"
CTR_EMBEDDING_TABLE_SIZE = 100000001
......@@ -33,6 +33,59 @@ def str2long(str):
return int(str)
def tied_rank(x):
"""
Computes the tied rank of elements in x.
This function computes the tied rank of elements in x.
Parameters
----------
x : list of numbers, numpy array
Returns
-------
score : list of numbers
The tied rank f each element in x
"""
sorted_x = sorted(zip(x,range(len(x))))
r = [0 for k in x]
cur_val = sorted_x[0][0]
last_rank = 0
for i in range(len(sorted_x)):
if cur_val != sorted_x[i][0]:
cur_val = sorted_x[i][0]
for j in range(last_rank, i):
r[sorted_x[j][1]] = float(last_rank+1+i)/2.0
last_rank = i
if i==len(sorted_x)-1:
for j in range(last_rank, i+1):
r[sorted_x[j][1]] = float(last_rank+i+2)/2.0
return r
def auc(actual, posterior):
"""
Computes the area under the receiver-operater characteristic (AUC)
This function computes the AUC error metric for binary classification.
Parameters
----------
actual : list of binary numbers, numpy array
The ground truth value
posterior : same type as actual
Defines a ranking on the binary numbers, from most likely to
be positive to least likely to be positive.
Returns
-------
score : double
The mean squared error between actual and posterior
"""
r = tied_rank(posterior)
num_positive = len([0 for x in actual if x==1])
num_negative = len(actual)-num_positive
sum_positive = sum([r[i] for i in range(len(r)) if actual[i]==1])
auc = ((sum_positive - num_positive*(num_positive+1)/2.0) /
(num_negative*num_positive))
return auc
def data_reader(data_file, samples, labels):
if not os.path.exists(data_file):
print("Path %s not exist" % data_file)
......@@ -66,7 +119,7 @@ def data_reader(data_file, samples, labels):
sample[x] = [0]
samples.append(sample)
if __name__ == "__main__":
""" main
"""
......@@ -89,8 +142,10 @@ if __name__ == "__main__":
sys.exit(-1)
ret = data_reader(sys.argv[4], samples, labels)
print(len(samples))
correct = 0
wrong_label_1_count = 0
result_list = []
for i in range(0, len(samples) - BATCH_SIZE, BATCH_SIZE):
api.clear()
batch = samples[i:i + BATCH_SIZE]
......@@ -105,11 +160,13 @@ if __name__ == "__main__":
api.add_slot(instance, k, v)
ret = api.inference()
continue
ret = json.loads(ret)
predictions = ret["predictions"]
idx = 0
for x in predictions:
result_list.append(x["prob1"])
if x["prob0"] >= x["prob1"]:
pred = 0
else:
......@@ -118,9 +175,14 @@ if __name__ == "__main__":
if labels[i + idx] == pred:
correct += 1
else:
print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
(i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))
#if labels[i + idx] == 1:
# wrong_label_1_count += 1
# print("error label=1 count", wrong_label_1_count)
#print("id=%d predict incorrect: pred=%d label=%d (%f %f)" %
# (i + idx, pred, labels[i + idx], x["prob0"], x["prob1"]))
pass
idx = idx + 1
print("Acc=%f" % (float(correct) / len(samples)))
#print("Acc=%f" % (float(correct) / len(samples)))
print("auc = ", auc(labels, result_list) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册