提交 19de8ae1 编写于 作者: X Xinghai Sun

Fixed a error in mnist unitest.

上级 d8046da0
...@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None): ...@@ -128,7 +128,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def cross_entropy_layer(net, input, label): def cross_entropy_layer(net, input, label):
cost_name = "cross_entropy_%d" % uniq_id() cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator( cross_entropy_op = Operator(
"cross_entropy", X=input, label=label, Y=cost_name) "cross_entropy", X=input, Label=label, Y=cost_name)
net.append_op(cross_entropy_op) net.append_op(cross_entropy_op)
scope.new_var(cost_name) scope.new_var(cost_name)
net.infer_shape(scope) net.infer_shape(scope)
...@@ -181,7 +181,7 @@ def error_rate(predict, label): ...@@ -181,7 +181,7 @@ def error_rate(predict, label):
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784]) images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE]) labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid") fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid") fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax") predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
...@@ -215,6 +215,7 @@ def test(cost_name): ...@@ -215,6 +215,7 @@ def test(cost_name):
for data in test_reader(): for data in test_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data) feed_data(images, image_data)
feed_data(labels, label_data) feed_data(labels, label_data)
...@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM): ...@@ -235,6 +236,7 @@ for pass_id in range(PASS_NUM):
for data in train_reader(): for data in train_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32") image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32") label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data) feed_data(images, image_data)
feed_data(labels, label_data) feed_data(labels, label_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册