diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py index ba2a06187890e2016872f7464639df1047a0b7cf..4f8e6fd4884307eeddfd6eda88019f5bb72a6cd6 100644 --- a/python/paddle/v2/framework/evaluator.py +++ b/python/paddle/v2/framework/evaluator.py @@ -121,18 +121,14 @@ class Accuracy(Evaluator): return executor.run(eval_program, fetch_list=[eval_out]) -# This is demo for composing low level op to compute metric +# Demo for composing low level op to compute the F1 metric class F1(Evaluator): def __init__(self, input, label, **kwargs): super(F1, self).__init__("F1", **kwargs) - super(Accuracy, self).__init__("accuracy", **kwargs) - g_total = helper.create_global_variable( - name=unique_name("Total"), - persistable=True, - dtype="int64", - shape=[1]) - g_correct = helper.create_global_variable( - name=unique_name("Correct"), - persistable=True, - dtype="int64", - shape=[1]) + g_tp = helper.create_global_variable( + name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1]) + g_fp = helper.create_global_variable( + name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1]) + + self._states["Tp"] = g_tp + self._states["Fp"] = g_fp diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/test_fit_a_line.py index aba1f27ad64b2e510a957736fb566c7616d2b435..28588506a675bd6fbf121a2f95c0fda13a8b8934 100644 --- a/python/paddle/v2/framework/tests/test_fit_a_line.py +++ b/python/paddle/v2/framework/tests/test_fit_a_line.py @@ -61,6 +61,7 @@ PASS_NUM = 100 for pass_id in range(PASS_NUM): save_persistables(exe, "./fit_a_line.model/", main_program=main_program) load_persistables(exe, "./fit_a_line.model/", main_program=main_program) + accuracy.reset(exe) for data in train_reader(): x_data = np.array(map(lambda x: x[0], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32") @@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM): outs = exe.run(main_program, feed={'x': tensor_x, 'y': tensor_y}, - fetch_list=[avg_cost]) + fetch_list=[avg_cost, accuracy]) out = np.array(outs[0]) + pass_acc = accuracy.eval(exe) + print pass_acc if out[0] < 10.0: exit(0) # if avg cost less than 10.0, we think our code is good.