未验证 提交 c9a96575 编写于 作者: Q Qiao Longfei 提交者: GitHub

py_test and test_image_classification_train support argument (#5934)

* py_test support argument, test_image_classification_train support argument

* use REMOVE_ITEM to rm item from list in cmake
上级 9de8ce17
......@@ -459,11 +459,11 @@ function(py_test TARGET_NAME)
if(WITH_TESTING)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(py_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_test(NAME ${TARGET_NAME}
COMMAND env PYTHONPATH=${PADDLE_PYTHON_BUILD_DIR}/lib-python
${PYTHON_EXECUTABLE} ${py_test_SRCS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endfunction()
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
list(REMOVE_ITEM TEST_OPS test_image_classification_train)
py_test(test_image_classification_train_resnet SRCS test_image_classification_train.py ARGS resnet)
py_test(test_image_classification_train_vgg SRCS test_image_classification_train.py ARGS vgg)
# default test
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import sys
def resnet_cifar10(input, depth=32):
......@@ -80,11 +82,18 @@ data_shape = [3, 32, 32]
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Add neural network config
# option 1. resnet
# net = resnet_cifar10(images, 32)
# option 2. vgg
net = vgg16_bn_drop(images)
net_type = "vgg"
if len(sys.argv) >= 2:
net_type = sys.argv[1]
if net_type == "vgg":
print("train vgg net")
net = vgg16_bn_drop(images)
elif net_type == "resnet":
print("train resnet")
net = resnet_cifar10(images, 32)
else:
raise ValueError("%s network is not supported" % net_type)
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册