提交 9e1799cb 编写于 作者: D Dong Zhihong

"fix based on comments"

上级 cfbc92e6
......@@ -42,14 +42,14 @@ class Evaluator(object):
"""
def reset(self, executor, program=None):
def reset(self, executor, reset_program=None):
"""
Reset metric states at the begin of each pass/user specified batch number.
Execute the reset_program to reset the states.
"""
def eval(self, executor, program=None):
def eval(self, executor, eval_program=None):
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
Execute the eval_program and return the result.
......
......@@ -39,7 +39,7 @@ class Evaluator(object):
"""
raise NotImplementedError()
def reset(self, executor, program=None):
def reset(self, executor, reset_program=None):
"""
Clear metric states at the begin of each pass/user specified batch
"""
......@@ -63,7 +63,7 @@ class Evaluator(object):
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
executor.run(reset_program, fetch_list=self._states.values())
def eval(self, executor, program=None):
def eval(self, executor, eval_program=None):
"""
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
"""
......
......@@ -6,7 +6,6 @@ import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_main_program
from paddle.v2.framework.io import save_persistables, load_persistables
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.evaluator import Accuracy
import numpy as np
......@@ -32,8 +31,6 @@ y = layers.data(
main_program=main_program,
startup_program=startup_program)
accuracy = evaluator.Accuracy(input=y_predict, label=y)
cost = layers.square_error_cost(
input=y_predict,
label=y,
......@@ -61,7 +58,6 @@ PASS_NUM = 100
for pass_id in range(PASS_NUM):
save_persistables(exe, "./fit_a_line.model/", main_program=main_program)
load_persistables(exe, "./fit_a_line.model/", main_program=main_program)
accuracy.reset(exe)
for data in train_reader():
x_data = np.array(map(lambda x: x[0], data)).astype("float32")
y_data = np.array(map(lambda x: x[1], data)).astype("float32")
......@@ -76,10 +72,8 @@ for pass_id in range(PASS_NUM):
outs = exe.run(main_program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost, accuracy])
fetch_list=[avg_cost])
out = np.array(outs[0])
pass_acc = accuracy.eval(exe)
print pass_acc
if out[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册