提交 9e1799cb 编写于 作者: D Dong Zhihong

"fix based on comments"

上级 cfbc92e6
...@@ -42,14 +42,14 @@ class Evaluator(object): ...@@ -42,14 +42,14 @@ class Evaluator(object):
""" """
def reset(self, executor, program=None): def reset(self, executor, reset_program=None):
""" """
Reset metric states at the begin of each pass/user specified batch number. Reset metric states at the begin of each pass/user specified batch number.
Execute the reset_program to reset the states. Execute the reset_program to reset the states.
""" """
def eval(self, executor, program=None): def eval(self, executor, eval_program=None):
""" """
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
Execute the eval_program and return the result. Execute the eval_program and return the result.
......
...@@ -39,7 +39,7 @@ class Evaluator(object): ...@@ -39,7 +39,7 @@ class Evaluator(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def reset(self, executor, program=None): def reset(self, executor, reset_program=None):
""" """
Clear metric states at the begin of each pass/user specified batch Clear metric states at the begin of each pass/user specified batch
""" """
...@@ -63,7 +63,7 @@ class Evaluator(object): ...@@ -63,7 +63,7 @@ class Evaluator(object):
type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
executor.run(reset_program, fetch_list=self._states.values()) executor.run(reset_program, fetch_list=self._states.values())
def eval(self, executor, program=None): def eval(self, executor, eval_program=None):
""" """
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
""" """
......
...@@ -6,7 +6,6 @@ import paddle.v2.framework.optimizer as optimizer ...@@ -6,7 +6,6 @@ import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program from paddle.v2.framework.framework import Program, g_main_program
from paddle.v2.framework.io import save_persistables, load_persistables from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor from paddle.v2.framework.executor import Executor
from paddle.v2.framework.evaluator import Accuracy
import numpy as np import numpy as np
...@@ -32,8 +31,6 @@ y = layers.data( ...@@ -32,8 +31,6 @@ y = layers.data(
main_program=main_program, main_program=main_program,
startup_program=startup_program) startup_program=startup_program)
accuracy = evaluator.Accuracy(input=y_predict, label=y)
cost = layers.square_error_cost( cost = layers.square_error_cost(
input=y_predict, input=y_predict,
label=y, label=y,
...@@ -61,7 +58,6 @@ PASS_NUM = 100 ...@@ -61,7 +58,6 @@ PASS_NUM = 100
for pass_id in range(PASS_NUM): for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", main_program=main_program) save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
load_persistables(exe, "./fit_a_line.model/", main_program=main_program) load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
accuracy.reset(exe)
for data in train_reader(): for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32") x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("float32")
...@@ -76,10 +72,8 @@ for pass_id in range(PASS_NUM): ...@@ -76,10 +72,8 @@ for pass_id in range(PASS_NUM):
outs = exe.run(main_program, outs = exe.run(main_program,
feed={'x': tensor_x, feed={'x': tensor_x,
'y': tensor_y}, 'y': tensor_y},
fetch_list=[avg_cost, accuracy]) fetch_list=[avg_cost])
out = np.array(outs[0]) out = np.array(outs[0])
pass_acc = accuracy.eval(exe)
print pass_acc
if out[0] < 10.0: if out[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good. exit(0) # if avg cost less than 10.0, we think our code is good.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册