From c9a96575d5aa89d143025d36ce105b05ed572be3 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 27 Nov 2017 16:42:08 +0800 Subject: [PATCH] py_test and test_image_classification_train support argument (#5934) * py_test support argument, test_image_classification_train support argument * use REMOVE_ITEM to rm item from list in cmake --- cmake/generic.cmake | 6 +++--- .../paddle/v2/fluid/tests/book/CMakeLists.txt | 6 ++++++ .../book/test_image_classification_train.py | 19 ++++++++++++++----- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 404717187d0..7b82d409a3b 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -459,11 +459,11 @@ function(py_test TARGET_NAME) if(WITH_TESTING) set(options STATIC static SHARED shared) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) - cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + set(multiValueArgs SRCS DEPS ARGS) + cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_test(NAME ${TARGET_NAME} COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python - ${PYTHON_EXECUTABLE} ${py_test_SRCS} + ${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() endfunction() diff --git a/python/paddle/v2/fluid/tests/book/CMakeLists.txt b/python/paddle/v2/fluid/tests/book/CMakeLists.txt index 4d7664469e4..a35abe3e0c4 100644 --- a/python/paddle/v2/fluid/tests/book/CMakeLists.txt +++ b/python/paddle/v2/fluid/tests/book/CMakeLists.txt @@ -1,5 +1,11 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +list(REMOVE_ITEM TEST_OPS test_image_classification_train) +py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet) +py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg) + +# default test foreach(src ${TEST_OPS}) py_test(${src} SRCS ${src}.py) endforeach() diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index 690c5339719..cc45b10b908 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -1,7 +1,9 @@ from __future__ import print_function + import numpy as np import paddle.v2 as paddle import paddle.v2.fluid as fluid +import sys def resnet_cifar10(input, depth=32): @@ -80,11 +82,18 @@ data_shape = [3, 32, 32] images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') -# Add neural network config -# option 1. resnet -# net = resnet_cifar10(images, 32) -# option 2. vgg -net = vgg16_bn_drop(images) +net_type = "vgg" +if len(sys.argv) >= 2: + net_type = sys.argv[1] + +if net_type == "vgg": + print("train vgg net") + net = vgg16_bn_drop(images) +elif net_type == "resnet": + print("train resnet") + net = resnet_cifar10(images, 32) +else: + raise ValueError("%s network is not supported" % net_type) predict = fluid.layers.fc(input=net, size=classdim, act='softmax') cost = fluid.layers.cross_entropy(input=predict, label=label) -- GitLab