From 9e1799cb43c217b8a4cc0b52b19b8a2062c5e5c6 Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Thu, 9 Nov 2017 17:35:13 -0800 Subject: [PATCH] "fix based on comments" --- doc/design/evaluator.md | 4 ++-- python/paddle/v2/framework/evaluator.py | 4 ++-- python/paddle/v2/framework/tests/test_fit_a_line.py | 8 +------- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/doc/design/evaluator.md b/doc/design/evaluator.md index f43bad1839..a62d75ffef 100644 --- a/doc/design/evaluator.md +++ b/doc/design/evaluator.md @@ -42,14 +42,14 @@ class Evaluator(object): """ - def reset(self, executor, program=None): + def reset(self, executor, reset_program=None): """ Reset metric states at the begin of each pass/user specified batch number. Execute the reset_program to reset the states. """ - def eval(self, executor, program=None): + def eval(self, executor, eval_program=None): """ Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. Execute the eval_program and return the result. diff --git a/python/paddle/v2/framework/evaluator.py b/python/paddle/v2/framework/evaluator.py index 664f65422c..89290abb83 100644 --- a/python/paddle/v2/framework/evaluator.py +++ b/python/paddle/v2/framework/evaluator.py @@ -39,7 +39,7 @@ class Evaluator(object): """ raise NotImplementedError() - def reset(self, executor, program=None): + def reset(self, executor, reset_program=None): """ Clear metric states at the begin of each pass/user specified batch """ @@ -63,7 +63,7 @@ class Evaluator(object): type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) executor.run(reset_program, fetch_list=self._states.values()) - def eval(self, executor, program=None): + def eval(self, executor, eval_program=None): """ Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. """ diff --git a/python/paddle/v2/framework/tests/test_fit_a_line.py b/python/paddle/v2/framework/tests/test_fit_a_line.py index 28588506a6..174ee74c3b 100644 --- a/python/paddle/v2/framework/tests/test_fit_a_line.py +++ b/python/paddle/v2/framework/tests/test_fit_a_line.py @@ -6,7 +6,6 @@ import paddle.v2.framework.optimizer as optimizer from paddle.v2.framework.framework import Program, g_main_program from paddle.v2.framework.io import save_persistables, load_persistables from paddle.v2.framework.executor import Executor -from paddle.v2.framework.evaluator import Accuracy import numpy as np @@ -32,8 +31,6 @@ y = layers.data( main_program=main_program, startup_program=startup_program) -accuracy = evaluator.Accuracy(input=y_predict, label=y) - cost = layers.square_error_cost( input=y_predict, label=y, @@ -61,7 +58,6 @@ PASS_NUM = 100 for pass_id in range(PASS_NUM): save_persistables(exe, "./fit_a_line.model/", main_program=main_program) load_persistables(exe, "./fit_a_line.model/", main_program=main_program) - accuracy.reset(exe) for data in train_reader(): x_data = np.array(map(lambda x: x[0], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32") @@ -76,10 +72,8 @@ for pass_id in range(PASS_NUM): outs = exe.run(main_program, feed={'x': tensor_x, 'y': tensor_y}, - fetch_list=[avg_cost, accuracy]) + fetch_list=[avg_cost]) out = np.array(outs[0]) - pass_acc = accuracy.eval(exe) - print pass_acc if out[0] < 10.0: exit(0) # if avg cost less than 10.0, we think our code is good. -- GitLab