提交 c4ac7fab 编写于 作者: D Dong Zhihong

'add f1 test'

上级 8d9b3341
...@@ -121,18 +121,14 @@ class Accuracy(Evaluator): ...@@ -121,18 +121,14 @@ class Accuracy(Evaluator):
return executor.run(eval_program, fetch_list=[eval_out]) return executor.run(eval_program, fetch_list=[eval_out])
# This is demo for composing low level op to compute metric # Demo for composing low level op to compute the F1 metric
class F1(Evaluator): class F1(Evaluator):
def __init__(self, input, label, **kwargs): def __init__(self, input, label, **kwargs):
super(F1, self).__init__("F1", **kwargs) super(F1, self).__init__("F1", **kwargs)
super(Accuracy, self).__init__("accuracy", **kwargs) g_tp = helper.create_global_variable(
g_total = helper.create_global_variable( name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
name=unique_name("Total"), g_fp = helper.create_global_variable(
persistable=True, name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])
dtype="int64",
shape=[1]) self._states["Tp"] = g_tp
g_correct = helper.create_global_variable( self._states["Fp"] = g_fp
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
...@@ -61,6 +61,7 @@ PASS_NUM = 100 ...@@ -61,6 +61,7 @@ PASS_NUM = 100
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", main_program=main_program) save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
load_persistables(exe, "./fit_a_line.model/", main_program=main_program) load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
accuracy.reset(exe)
for data in train_reader(): for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32") x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32")
...@@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM): ...@@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM):
outs = exe.run(main_program, outs = exe.run(main_program,
feed={'x': tensor_x, feed={'x': tensor_x,
'y': tensor_y}, 'y': tensor_y},
fetch_list=[avg_cost]) fetch_list=[avg_cost, accuracy])
out = np.array(outs[0]) out = np.array(outs[0])
pass_acc = accuracy.eval(exe)
print pass_acc
if out[0] < 10.0: if out[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good. exit(0) # if avg cost less than 10.0, we think our code is good.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册