提交 a305cb21 编写于 作者: Y Yu Yang

Use new program when unittest

上级 9d4c93a0
......@@ -21,7 +21,7 @@ import paddle.v2 as paddle
class TestRecordIO(unittest.TestCase):
def setUp(self):
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program()):
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=32)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
......@@ -35,24 +35,26 @@ class TestRecordIO(unittest.TestCase):
'./mnist.recordio', reader, feeder)
def test_main(self):
data_file = fluid.layers.open_recordio_file(
'./mnist.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_file)
# use new program
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
'./mnist.recordio',
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
img, label = fluid.layers.read_file(data_file)
hidden = fluid.layers.fc(input=img, size=100, act='tanh')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
hidden = fluid.layers.fc(input=img, size=100, act='tanh')
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_loss)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
avg_loss_np = []
for i in xrange(100): # train 100 mini-batch
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
avg_loss_np = []
for i in xrange(100): # train 100 mini-batch
tmp, = exe.run(fetch_list=[avg_loss])
avg_loss_np.append(tmp)
self.assertLess(avg_loss_np[-1], avg_loss_np[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册