提交 40e28de4 编写于 作者: M minqiyang

Add train in mnist ut

test=release/1.4
上级 3bdbb94f
...@@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase):
train_reader = paddle.batch( train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128, drop_last=True) paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
mnist.train()
dy_param_init_value = {} dy_param_init_value = {}
for epoch in range(epoch_num): for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册