提交 76677f25 编写于 作者: Q qiaolongfei

add test

上级 6f4b968f
......@@ -52,7 +52,7 @@ def grad_var_name(var_name):
return var_name + "@GRAD"
def sgd_optimizer(net, param_name, learning_rate=0.01):
def sgd_optimizer(net, param_name, learning_rate=0.005):
grad_name = grad_var_name(param_name)
optimize_op = Operator(
"sgd",
......@@ -166,9 +166,9 @@ def set_cost():
cost_grad.set(numpy.ones(cost_shape).astype("float32"), place)
def print_cost():
def mean_cost():
cost_data = numpy.array(scope.find_var("cross_entropy_3").get_tensor())
print(cost_data.sum() / len(cost_data))
return cost_data.sum() / len(cost_data)
def error_rate(predict, label):
......@@ -176,7 +176,7 @@ def error_rate(predict, label):
axis=1)
label = numpy.array(scope.find_var(label).get_tensor())
error_num = numpy.sum(predict_var != label)
print(error_num / float(len(label)))
return error_num / float(len(label))
images = data_layer(name='pixel', dims=[BATCH_SIZE, 784])
......@@ -198,16 +198,35 @@ print_inputs_outputs(forward_network)
print_inputs_outputs(backward_net)
print_inputs_outputs(optimize_net)
reader = paddle.batch(
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
def test():
test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
cost = []
error = []
for data in test_reader():
image = numpy.array(map(lambda x: x[0], data)).astype("float32")
label = numpy.array(map(lambda x: x[1], data)).astype("int32")
feed_data("pixel", image)
feed_data("label", label)
forward_network.infer_shape(scope)
forward_network.run(scope, dev_ctx)
cost.append(mean_cost())
error.append(error_rate(predict, "label"))
print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str(
sum(error) / float(len(error))))
PASS_NUM = 1000
for pass_id in range(PASS_NUM):
batch_id = 0
for data in reader():
for data in train_reader():
image = numpy.array(map(lambda x: x[0], data)).astype("float32")
label = numpy.array(map(lambda x: x[1], data)).astype("int32")
feed_data("pixel", image)
......@@ -222,7 +241,8 @@ for pass_id in range(PASS_NUM):
optimize_net.run(scope, dev_ctx)
if batch_id % 100 == 0:
print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]")
print_cost()
error_rate(predict, "label")
test()
# print(mean_cost())
# print(error_rate(predict, "label"))
batch_id = batch_id + 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册