提交 19de8ae1 编写于 作者: X Xinghai Sun

Fixed a error in mnist unitest.

上级 d8046da0
......@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def cross_entropy_layer(net, input, label):
cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"cross_entropy", X=input, label=label, Y=cost_name)
"cross_entropy", X=input, Label=label, Y=cost_name)
net.append_op(cross_entropy_op)
scope.new_var(cost_name)
net.infer_shape(scope)
......@@ -181,7 +181,7 @@ def error_rate(predict, label):
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE])
labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
......@@ -215,6 +215,7 @@ def test(cost_name):
for data in test_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
......@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM):
for data in train_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册