From c4ac7fab5ecfc11023fc314b0030d5662fb396ce Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Tue, 7 Nov 2017 00:24:22 -0800 Subject: [PATCH] 'add f1 test' --- python/paddle/v2/framework/evaluator.py | 20 ++++++++----------- .../v2/framework/tests/test_fit_a_line.py | 5 ++++- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py index ba2a06187..4f8e6fd48 100644 --- a/python/paddle/v2/framework/evaluator.py +++ b/python/paddle/v2/framework/evaluator.py @@ -121,18 +121,14 @@ class Accuracy(Evaluator): return executor.run(eval_program, fetch_list=[eval_out]) -# This is demo for composing low level op to compute metric +# Demo for composing low level op to compute the F1 metric class F1(Evaluator): def __init__(self, input, label, **kwargs): super(F1, self).__init__("F1", **kwargs) - super(Accuracy, self).__init__("accuracy", **kwargs) - g_total = helper.create_global_variable( - name=unique_name("Total"), - persistable=True, - dtype="int64", - shape=[1]) - g_correct = helper.create_global_variable( - name=unique_name("Correct"), - persistable=True, - dtype="int64", - shape=[1]) + g_tp = helper.create_global_variable( + name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1]) + g_fp = helper.create_global_variable( + name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1]) + + self._states["Tp"] = g_tp + self._states["Fp"] = g_fp diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/test_fit_a_line.py index aba1f27ad..28588506a 100644 --- a/python/paddle/v2/framework/tests/test_fit_a_line.py +++ b/python/paddle/v2/framework/tests/test_fit_a_line.py @@ -61,6 +61,7 @@ PASS_NUM = 100 for pass_id in range(PASS_NUM): save_persistables(exe, "./fit_a_line.model/", main_program=main_program) load_persistables(exe, "./fit_a_line.model/", main_program=main_program) + accuracy.reset(exe) for data in train_reader(): x_data = np.array(map(lambda x: x[0], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32") @@ -75,8 +76,10 @@ for pass_id in range(PASS_NUM): outs = exe.run(main_program, feed={'x': tensor_x, 'y': tensor_y}, - fetch_list=[avg_cost]) + fetch_list=[avg_cost, accuracy]) out = np.array(outs[0]) + pass_acc = accuracy.eval(exe) + print pass_acc if out[0] < 10.0: exit(0) # if avg cost less than 10.0, we think our code is good. -- GitLab