未验证 提交 b3580ecd 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #8000 from reyoung/feature/make_recognize_digits_normal_unittest

Make recognize digits as a normal python unittest
...@@ -4,4 +4,4 @@ cc_test(test_inference_recognize_digits_mlp ...@@ -4,4 +4,4 @@ cc_test(test_inference_recognize_digits_mlp
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model) ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model)
set_tests_properties(test_inference_recognize_digits_mlp set_tests_properties(test_inference_recognize_digits_mlp
PROPERTIES DEPENDS test_recognize_digits_mlp_cpu) PROPERTIES DEPENDS test_recognize_digits)
recognize_digits_*.inference.model
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_recognize_digits)
py_test(test_recognize_digits_mlp_cpu
SRCS test_recognize_digits.py
ARGS mlp)
py_test(test_recognize_digits_mlp_cuda
SRCS test_recognize_digits.py
ARGS mlp --use_cuda)
py_test(test_recognize_digits_conv_cpu
SRCS test_recognize_digits.py
ARGS conv)
py_test(test_recognize_digits_conv_cuda
SRCS test_recognize_digits.py
ARGS conv --use_cuda)
py_test(test_recognize_digits_mlp_cpu_parallel
SRCS test_recognize_digits.py
ARGS mlp --parallel)
py_test(test_recognize_digits_mlp_cuda_parallel
SRCS test_recognize_digits.py
ARGS mlp --use_cuda --parallel)
py_test(test_recognize_digits_conv_cpu_parallel
SRCS test_recognize_digits.py
ARGS conv --parallel)
py_test(test_recognize_digits_conv_cuda_parallel
SRCS test_recognize_digits.py
ARGS conv --use_cuda --parallel)
# default test # default test
foreach(src ${TEST_OPS}) foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py) py_test(${src} SRCS ${src}.py)
......
...@@ -17,6 +17,7 @@ import paddle.v2.fluid as fluid ...@@ -17,6 +17,7 @@ import paddle.v2.fluid as fluid
import paddle.v2 as paddle import paddle.v2 as paddle
import sys import sys
import numpy import numpy
import unittest
def parse_arg(): def parse_arg():
...@@ -74,18 +75,18 @@ def conv_net(img, label): ...@@ -74,18 +75,18 @@ def conv_net(img, label):
return loss_net(conv_pool_2, label) return loss_net(conv_pool_2, label)
def train(args, save_dirname=None): def train(nn_type, use_cuda, parallel, save_dirname):
print("recognize digits with args: {0}".format(" ".join(sys.argv[1:]))) if use_cuda and not fluid.core.is_compiled_with_cuda():
return
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32') img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if args.nn_type == 'mlp': if nn_type == 'mlp':
net_conf = mlp net_conf = mlp
else: else:
net_conf = conv_net net_conf = conv_net
if args.parallel: if parallel:
places = fluid.layers.get_places() places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places) pd = fluid.layers.ParallelDo(places)
with pd.do(): with pd.do():
...@@ -107,7 +108,7 @@ def train(args, save_dirname=None): ...@@ -107,7 +108,7 @@ def train(args, save_dirname=None):
optimizer = fluid.optimizer.Adam(learning_rate=0.001) optimizer = fluid.optimizer.Adam(learning_rate=0.001)
optimizer.minimize(avg_loss) optimizer.minimize(avg_loss)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
...@@ -147,13 +148,14 @@ def train(args, save_dirname=None): ...@@ -147,13 +148,14 @@ def train(args, save_dirname=None):
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'. 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
format(pass_id, batch_id + 1, format(pass_id, batch_id + 1,
float(avg_loss_val), float(acc_val))) float(avg_loss_val), float(acc_val)))
raise AssertionError("Loss of recognize digits is too large")
def infer(args, save_dirname=None): def infer(use_cuda, save_dirname=None):
if save_dirname is None: if save_dirname is None:
return return
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# Use fluid.io.load_inference_model to obtain the inference program desc, # Use fluid.io.load_inference_model to obtain the inference program desc,
...@@ -174,11 +176,48 @@ def infer(args, save_dirname=None): ...@@ -174,11 +176,48 @@ def infer(args, save_dirname=None):
print("infer results: ", results[0]) print("infer results: ", results[0])
if __name__ == '__main__': def main(use_cuda, parallel, nn_type):
args = parse_arg() if not use_cuda and not parallel:
if not args.use_cuda and not args.parallel: save_dirname = "recognize_digits_" + nn_type + ".inference.model"
save_dirname = "recognize_digits_" + args.nn_type + ".inference.model"
else: else:
save_dirname = None save_dirname = None
train(args, save_dirname)
infer(args, save_dirname) train(
nn_type=nn_type,
use_cuda=use_cuda,
parallel=parallel,
save_dirname=save_dirname)
infer(use_cuda=use_cuda, save_dirname=save_dirname)
class TestRecognizeDigits(unittest.TestCase):
pass
def inject_test_method(use_cuda, parallel, nn_type):
def __impl__(self):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
main(use_cuda, parallel, nn_type)
fn = 'test_{0}_{1}_{2}'.format(nn_type, 'cuda'
if use_cuda else 'cpu', 'parallel'
if parallel else 'normal')
setattr(TestRecognizeDigits, fn, __impl__)
def inject_all_tests():
for use_cuda in (False, True):
for parallel in (False, True):
for nn_type in ('mlp', 'conv'):
inject_test_method(use_cuda, parallel, nn_type)
inject_all_tests()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册