提交 c4ac7fab 编写于 作者: D Dong Zhihong

'add f1 test'

上级 8d9b3341
......@@ -121,18 +121,14 @@ class Accuracy(Evaluator):
return executor.run(eval_program, fetch_list=[eval_out])
# This is demo for composing low level op to compute metric
# Demo for composing low level op to compute the F1 metric
class F1(Evaluator):
def __init__(self, input, label, **kwargs):
super(F1, self).__init__("F1", **kwargs)
super(Accuracy, self).__init__("accuracy", **kwargs)
g_total = helper.create_global_variable(
name=unique_name("Total"),
persistable=True,
dtype="int64",
shape=[1])
g_correct = helper.create_global_variable(
name=unique_name("Correct"),
persistable=True,
dtype="int64",
shape=[1])
g_tp = helper.create_global_variable(
name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
g_fp = helper.create_global_variable(
name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])
self._states["Tp"] = g_tp
self._states["Fp"] = g_fp
......@@ -61,6 +61,7 @@ PASS_NUM = 100
for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
accuracy.reset(exe)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
......@@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM):
outs = exe.run(main_program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
fetch_list=[avg_cost, accuracy])
out = np.array(outs[0])
pass_acc = accuracy.eval(exe)
print pass_acc
if out[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部