提交 6f306f09 编写于 作者: C chengduoZH

refine unit test

上级 676dfd18
...@@ -722,9 +722,10 @@ class TestCRFModel(unittest.TestCase): ...@@ -722,9 +722,10 @@ class TestCRFModel(unittest.TestCase):
# test fetch all the variables of global_block # test fetch all the variables of global_block
import paddle.dataset.flowers as flowers import paddle.dataset.flowers as flowers
import math
def lenet(data, class_dim): def Lenet(data, class_dim):
conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None) conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None)
bn1 = fluid.layers.batch_norm(conv1, act='relu') bn1 = fluid.layers.batch_norm(conv1, act='relu')
pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2) pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2)
...@@ -774,25 +775,25 @@ class TestFetchOp(unittest.TestCase): ...@@ -774,25 +775,25 @@ class TestFetchOp(unittest.TestCase):
fetch_list = [] fetch_list = []
all_vars = main.global_block().vars all_vars = main.global_block().vars
for k, v in all_vars.iteritems(): for k, v in all_vars.iteritems():
if 'velocity' not in k: if 'tmp' not in k and k[0] is not '_' or v.persistable:
fetch_list.append(k) fetch_list.append(k)
for data in train_inputs: for data in train_inputs:
ret = pe.run(fetch_list, feed=feeder.feed(data)) ret = pe.run(fetch_list, feed=feeder.feed(data))
for i in range(len(fetch_list)): for i in range(len(fetch_list)):
print("%s - %s" % (fetch_list[i], np.sum(ret[i]))) assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i]))
def test_update_sparse_parameter(self): def test_update_sparse_parameter(self):
tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16) tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
tst_reader_iter = tst_reader() tst_reader_iter = tst_reader()
seed = 100 iters = 3
iters = 4
train_inputs = [] train_inputs = []
for i in range(iters): for i in range(iters):
train_inputs.append(tst_reader_iter.next()) train_inputs.append(tst_reader_iter.next())
self.parallel_exe(train_inputs, seed) self.parallel_exe(train_inputs, seed=1)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册