提交 af5ac2c4 编写于 作者: G gongweibao

merge with upstream develop

...@@ -27,6 +27,7 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -27,6 +27,7 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET)
include(simd) include(simd)
...@@ -48,6 +49,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF) ...@@ -48,6 +49,7 @@ option(COVERALLS_UPLOAD "Package code coverage data to coveralls" OFF)
option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF) option(ON_TRAVIS "Exclude special unit test on Travis CI" OFF)
option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF) option(WITH_C_API "Compile PaddlePaddle with C-API(Prediction)" OFF)
option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
# CMAKE_BUILD_TYPE # CMAKE_BUILD_TYPE
if(NOT CMAKE_BUILD_TYPE) if(NOT CMAKE_BUILD_TYPE)
...@@ -110,6 +112,7 @@ include_directories("${PROJ_ROOT}") ...@@ -110,6 +112,7 @@ include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient")
include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS set(EXTERNAL_LIBS
${GFLAGS_LIBRARIES} ${GFLAGS_LIBRARIES}
...@@ -127,15 +130,21 @@ if(WITH_GPU) ...@@ -127,15 +130,21 @@ if(WITH_GPU)
endif(NOT WITH_DSO) endif(NOT WITH_DSO)
endif(WITH_GPU) endif(WITH_GPU)
if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIB} ${PTHREADPOOL_LIB} "rt")
endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
add_subdirectory(paddle)
add_subdirectory(python)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)
#TODO (add go/master/c back when fixed) add_subdirectory(go/master/c)
add_subdirectory(go/pserver/cclient) add_subdirectory(go/pserver/cclient)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle)
add_subdirectory(python)
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()
...@@ -69,3 +69,27 @@ endif(NOT WITH_GPU) ...@@ -69,3 +69,27 @@ endif(NOT WITH_GPU)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SIMD_FLAG}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SIMD_FLAG}")
if(WITH_GOLANG)
# we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go
# get ./... will download a new Paddle repo from Github,
# without the changes in our current Paddle repo that we
# want to build.
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle")
add_custom_target(go_path)
add_custom_command(TARGET go_path
# Symlink Paddle directory into GOPATH
COMMAND mkdir -p ${PADDLE_IN_GOPATH}
COMMAND rm -rf ${PADDLE_IN_GOPATH}
COMMAND ln -sf ${CMAKE_SOURCE_DIR} ${PADDLE_IN_GOPATH}
# Automatically get all dependencies specified in the source code
# We can't run `go get -d ./...` for every target, because
# multiple `go get` can not run concurrently, but make need to be
# able to run with multiple jobs.
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./go/...
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
endif(WITH_GOLANG)
...@@ -7,8 +7,17 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) ...@@ -7,8 +7,17 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
ExternalProject_Add( ExternalProject_Add(
eigen3 eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" # for latest version, please get from official website
URL_MD5 "1a47e78efe365a97de0c022d127607c3" # URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
......
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
# limitations under the License. # limitations under the License.
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
# Always invoke `FIND_PACKAGE(Protobuf)` for importing function protobuf_generate_cpp
FIND_PACKAGE(Protobuf QUIET)
SET(PROTOBUF_FOUND "OFF")
# Print and set the protobuf library information, # Print and set the protobuf library information,
# finish this cmake process and exit from this file. # finish this cmake process and exit from this file.
...@@ -39,12 +43,19 @@ macro(PROMPT_PROTOBUF_LIB) ...@@ -39,12 +43,19 @@ macro(PROMPT_PROTOBUF_LIB)
ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL) ADD_LIBRARY(protobuf_lite ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY}) SET_PROPERTY(TARGET protobuf_lite PROPERTY IMPORTED_LOCATION ${PROTOBUF_LITE_LIBRARY})
ADD_LIBRARY(protoc ${protobuf_LIBTYPE} IMPORTED GLOBAL) ADD_LIBRARY(libprotoc ${protobuf_LIBTYPE} IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY}) SET_PROPERTY(TARGET libprotoc PROPERTY IMPORTED_LOCATION ${PROTOC_LIBRARY})
ADD_EXECUTABLE(protoc IMPORTED GLOBAL)
SET_PROPERTY(TARGET protoc PROPERTY IMPORTED_LOCATION ${PROTOBUF_PROTOC_EXECUTABLE})
# FIND_Protobuf.cmake uses `Protobuf_PROTOC_EXECUTABLE`.
# make `protobuf_generate_cpp` happy.
SET(Protobuf_PROTOC_EXECUTABLE ${PROTOBUF_PROTOC_EXECUTABLE})
FOREACH(dep ${protobuf_DEPS}) FOREACH(dep ${protobuf_DEPS})
ADD_DEPENDENCIES(protobuf ${dep}) ADD_DEPENDENCIES(protobuf ${dep})
ADD_DEPENDENCIES(protobuf_lite ${dep}) ADD_DEPENDENCIES(protobuf_lite ${dep})
ADD_DEPENDENCIES(libprotoc ${dep})
ADD_DEPENDENCIES(protoc ${dep}) ADD_DEPENDENCIES(protoc ${dep})
ENDFOREACH() ENDFOREACH()
...@@ -133,18 +144,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST) ...@@ -133,18 +144,7 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
ENDFUNCTION() ENDFUNCTION()
SET(PROTOBUF_VERSION 3.1) SET(PROTOBUF_VERSION 3.1)
IF(NOT CMAKE_CROSSCOMPILING) IF(CMAKE_CROSSCOMPILING)
FIND_PACKAGE(Protobuf ${PROTOBUF_VERSION})
IF(PROTOBUF_FOUND)
SET_PROTOBUF_VERSION()
IF("${PROTOBUF_VERSION}" VERSION_LESS "3.1.0")
SET(PROTOBUF_FOUND OFF)
ELSE()
PROMPT_PROTOBUF_LIB()
ENDIF()
ENDIF(PROTOBUF_FOUND)
ELSE()
build_protobuf(protobuf_host TRUE) build_protobuf(protobuf_host TRUE)
LIST(APPEND external_project_dependencies protobuf_host) LIST(APPEND external_project_dependencies protobuf_host)
......
...@@ -32,193 +32,6 @@ IF(PYTHONINTERP_FOUND) ...@@ -32,193 +32,6 @@ IF(PYTHONINTERP_FOUND)
MESSAGE(FATAL_ERROR "Found Python Protobuf ${PY_GOOGLE.PROTOBUF_VERSION} < 3.0.0, " MESSAGE(FATAL_ERROR "Found Python Protobuf ${PY_GOOGLE.PROTOBUF_VERSION} < 3.0.0, "
"please use pip to upgrade protobuf. pip install -U protobuf") "please use pip to upgrade protobuf. pip install -U protobuf")
ENDIF() ENDIF()
ELSE(PYTHONINTERP_FOUND)
MESSAGE(FATAL_ERROR "Please install python 2.7 before building PaddlePaddle.")
##################################### PYTHON ########################################
SET(PYTHON_SOURCES_DIR ${THIRD_PARTY_PATH}/python)
SET(PYTHON_INSTALL_DIR ${THIRD_PARTY_PATH}/install/python)
SET(_python_DIR ${PYTHON_INSTALL_DIR})
IF(UNIX)
SET(PYTHON_FOUND ON)
SET(PYTHON_INCLUDE_DIR "${PYTHON_INSTALL_DIR}/include/python2.7" CACHE PATH "Python include dir" FORCE)
SET(PYTHON_LIBRARIES "${PYTHON_INSTALL_DIR}/lib/libpython2.7.a" CACHE FILEPATH "Python library" FORCE)
SET(PYTHON_EXECUTABLE ${PYTHON_INSTALL_DIR}/bin/python CACHE FILEPATH "Python executable" FORCE)
SET(PY_SITE_PACKAGES_PATH "${PYTHON_INSTALL_DIR}/lib/python2.7/site-packages" CACHE PATH "Python site-packages path" FORCE)
ELSEIF(WIN32)
SET(PYTHON_FOUND ON)
SET(PYTHON_INCLUDE_DIR "${PYTHON_INSTALL_DIR}/include" CACHE PATH "Python include dir" FORCE)
SET(PYTHON_LIBRARIES "${PYTHON_INSTALL_DIR}/libs/python27.lib" CACHE FILEPATH "Python library" FORCE)
SET(PYTHON_EXECUTABLE "${PYTHON_INSTALL_DIR}/bin/python.exe" CACHE FILEPATH "Python executable" FORCE)
SET(PY_SITE_PACKAGES_PATH "${PYTHON_INSTALL_DIR}/Lib/site-packages" CACHE PATH "Python site-packages path" FORCE)
ELSE()
MESSAGE(FATAL_ERROR "Unknown system !")
ENDIF()
IF(APPLE)
LIST(APPEND EXTERNAL_PROJECT_OPTIONAL_CMAKE_ARGS
-DCMAKE_BUILD_WITH_INSTALL_RPATH:BOOL=ON
)
ENDIF()
SET(EXTERNAL_PROJECT_OPTIONAL_CMAKE_CACHE_ARGS)
# Force Python build to "Release".
IF(CMAKE_CONFIGURATION_TYPES)
SET(SAVED_CMAKE_CFG_INTDIR ${CMAKE_CFG_INTDIR})
SET(CMAKE_CFG_INTDIR "Release")
ELSE()
LIST(APPEND EXTERNAL_PROJECT_OPTIONAL_CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=Release
)
ENDIF()
ExternalProject_Add(python
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/python-cmake-buildsystem/python-cmake-buildsystem.git"
PREFIX ${PYTHON_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DPYTHON_VERSION=2.7.12
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
CMAKE_ARGS -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
CMAKE_CACHE_ARGS
-DCMAKE_INSTALL_PREFIX:PATH=${PYTHON_INSTALL_DIR}
-DBUILD_LIBPYTHON_SHARED:BOOL=OFF
-DUSE_SYSTEM_LIBRARIES:BOOL=OFF
-DZLIB_ROOT:FILEPATH=${ZLIB_ROOT}
-DZLIB_INCLUDE_DIR:PATH=${ZLIB_INCLUDE_DIR}
-DZLIB_LIBRARY:FILEPATH=${ZLIB_LIBRARIES}
-DDOWNLOAD_SOURCES:BOOL=ON
-DINSTALL_WINDOWS_TRADITIONAL:BOOL=OFF
${EXTERNAL_PROJECT_OPTIONAL_CMAKE_CACHE_ARGS}
${EXTERNAL_PROJECT_OPTIONAL_CMAKE_ARGS}
DEPENDS zlib
)
SET(py_env
PATH=${PYTHON_INSTALL_DIR}/bin
PYTHONHOME=${PYTHON_INSTALL_DIR}
PYTHONPATH=${PYTHON_INSTALL_DIR}/lib:${PYTHON_INSTALL_DIR}/lib/python2.7:${PY_SITE_PACKAGES_PATH})
####################################################################################
##################################### SETUPTOOLS ###################################
SET(SETUPTOOLS_SOURCES_DIR ${PYTHON_SOURCES_DIR}/setuptools)
ExternalProject_Add(setuptools
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${SETUPTOOLS_SOURCES_DIR}
URL "https://pypi.python.org/packages/source/s/setuptools/setuptools-18.3.2.tar.gz"
BUILD_IN_SOURCE 1
PATCH_COMMAND ""
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
INSTALL_COMMAND ""
BUILD_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
DEPENDS python zlib
)
#####################################################################################
##################################### SIX ###########################################
SET(SIX_SOURCES_DIR ${PYTHON_SOURCES_DIR}/six)
ExternalProject_Add(six
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${SIX_SOURCES_DIR}
URL https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz
BUILD_IN_SOURCE 1
PATCH_COMMAND ""
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
INSTALL_COMMAND ""
BUILD_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
DEPENDS python setuptools
)
#####################################################################################
##################################### CYTHON ########################################
SET(CYTHON_SOURCES_DIR ${PYTHON_SOURCES_DIR}/cython)
ExternalProject_Add(cython
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${CYTHON_SOURCES_DIR}
URL https://github.com/cython/cython/archive/0.25.2.tar.gz
GIT_TAG 0.25.2
BUILD_IN_SOURCE 1
CONFIGURE_COMMAND ""
PATCH_COMMAND ""
UPDATE_COMMAND ""
INSTALL_COMMAND ""
BUILD_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
DEPENDS python
)
####################################################################################
##################################### NUMPY ########################################
SET(NUMPY_SOURCES_DIR ${PYTHON_SOURCES_DIR}/numpy)
SET(NUMPY_TAG_VERSION "v1.11.3")
SET(NUMPY_VERSION "1.11.3")
SET(EGG_NAME "")
SET(PYTHON_NUMPY_INCLUDE_DIR "")
IF(WIN32)
SET(EGG_NAME "numpy-${NUMPY_VERSION}-py2.7-${HOST_SYSTEM}.egg")
ELSE(WIN32)
IF(APPLE)
SET(EGG_NAME "numpy-${NUMPY_VERSION}-py2.7-${HOST_SYSTEM}-${MACOS_VERSION}")
ELSE(APPLE)
SET(EGG_NAME "numpy-${NUMPY_VERSION}-py2.7-linux")
SET(EGG_NAME "numpy-${NUMPY_VERSION}-py2.7-linux")
ENDIF(APPLE)
FOREACH(suffix x86_64 intel fat64 fat32 universal)
LIST(APPEND PYTHON_NUMPY_INCLUDE_DIR ${PY_SITE_PACKAGES_PATH}/${EGG_NAME}-${suffix}.egg/numpy/core/include)
ENDFOREACH()
ENDIF(WIN32)
ExternalProject_Add(numpy
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY https://github.com/numpy/numpy.git
GIT_TAG ${NUMPY_TAG_VERSION}
CONFIGURE_COMMAND ""
UPDATE_COMMAND ""
PREFIX ${NUMPY_SOURCES_DIR}
BUILD_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py build
INSTALL_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
BUILD_IN_SOURCE 1
DEPENDS python setuptools cython
)
####################################################################################
##################################### WHEEL ########################################
SET(WHEEL_SOURCES_DIR ${PYTHON_SOURCES_DIR}/wheel)
ExternalProject_Add(wheel
${EXTERNAL_PROJECT_LOG_ARGS}
URL https://pypi.python.org/packages/source/w/wheel/wheel-0.29.0.tar.gz
PREFIX ${WHEEL_SOURCES_DIR}
CONFIGURE_COMMAND ""
UPDATE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
BUILD_IN_SOURCE 1
DEPENDS python setuptools
)
####################################################################################
################################### PROTOBUF #######################################
SET(PY_PROTOBUF_SOURCES_DIR ${PYTHON_SOURCES_DIR}/protobuf)
ExternalProject_Add(python-protobuf
${EXTERNAL_PROJECT_LOG_ARGS}
URL https://pypi.python.org/packages/e0/b0/0a1b364fe8a7d177b4b7d4dca5b798500dc57a7273b93cca73931b305a6a/protobuf-3.1.0.post1.tar.gz
URL_MD5 38b5fb160c768d2f8444d0c6d637ff91
PREFIX ${PY_PROTOBUF_SOURCES_DIR}
BUILD_IN_SOURCE 1
PATCH_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py build
INSTALL_COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py install
DEPENDS python setuptools six
)
####################################################################################
LIST(APPEND external_project_dependencies python setuptools six cython wheel python-protobuf numpy)
ENDIF(PYTHONINTERP_FOUND) ENDIF(PYTHONINTERP_FOUND)
IF(WITH_PYTHON) IF(WITH_PYTHON)
......
...@@ -87,6 +87,9 @@ ...@@ -87,6 +87,9 @@
# go_library(example SHARED) # go_library(example SHARED)
# #
# including binary directory for generated headers.
include_directories(${CMAKE_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
...@@ -98,23 +101,16 @@ function(merge_static_libs TARGET_NAME) ...@@ -98,23 +101,16 @@ function(merge_static_libs TARGET_NAME)
# First get the file names of the libraries to be merged # First get the file names of the libraries to be merged
foreach(lib ${libs}) foreach(lib ${libs})
get_target_property(libtype ${lib} TYPE)
if(NOT libtype STREQUAL "STATIC_LIBRARY")
message(FATAL_ERROR "merge_static_libs can only process static libraries")
endif()
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>) set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach() endforeach()
if(APPLE) # Use OSX's libtool to merge archives if(APPLE) # Use OSX's libtool to merge archives
add_custom_target(${TARGET_NAME}_archive set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
COMMAND libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles} file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} add_library(${TARGET_NAME} STATIC ${dummyfile})
DEPENDS ${libs} add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
) COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
add_library(${TARGET_NAME} STATIC IMPORTED GLOBAL) COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
set_property(TARGET ${TARGET_NAME} PROPERTY
IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a")
add_dependencies(${TARGET_NAME} ${TARGET_NAME}_archive)
else() # general UNIX: use "ar" to extract objects and re-add to a common lib else() # general UNIX: use "ar" to extract objects and re-add to a common lib
foreach(lib ${libs}) foreach(lib ${libs})
set(objlistfile ${lib}.objlist) # list of objects in the input library set(objlistfile ${lib}.objlist) # list of objects in the input library
...@@ -143,9 +139,9 @@ function(merge_static_libs TARGET_NAME) ...@@ -143,9 +139,9 @@ function(merge_static_libs TARGET_NAME)
set(outlibfile "$<TARGET_FILE:${TARGET_NAME}>") set(outlibfile "$<TARGET_FILE:${TARGET_NAME}>")
foreach(lib ${libs}) foreach(lib ${libs})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND ${CMAKE_AR} ru ${outlibfile} @"../${objlistfile}" COMMAND ${CMAKE_AR} ru ${outlibfile} @"../${lib}.objlist"
WORKING_DIRECTORY ${objdir}) WORKING_DIRECTORY ${lib}.objdir)
endforeach() endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
...@@ -253,10 +249,6 @@ function(nv_test TARGET_NAME) ...@@ -253,10 +249,6 @@ function(nv_test TARGET_NAME)
endif() endif()
endfunction(nv_test) endfunction(nv_test)
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle")
function(go_library TARGET_NAME) function(go_library TARGET_NAME)
set(options STATIC static SHARED shared) set(options STATIC static SHARED shared)
set(oneValueArgs "") set(oneValueArgs "")
...@@ -265,10 +257,10 @@ function(go_library TARGET_NAME) ...@@ -265,10 +257,10 @@ function(go_library TARGET_NAME)
if (go_library_SHARED OR go_library_shared) if (go_library_SHARED OR go_library_shared)
set(BUILD_MODE "-buildmode=c-shared") set(BUILD_MODE "-buildmode=c-shared")
set(LIB_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
else() else()
set(BUILD_MODE "-buildmode=c-archive") set(BUILD_MODE "-buildmode=c-archive")
set(LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}")
endif() endif()
# Add dummy code to support `make target_name` under Terminal Command # Add dummy code to support `make target_name` under Terminal Command
...@@ -283,25 +275,17 @@ function(go_library TARGET_NAME) ...@@ -283,25 +275,17 @@ function(go_library TARGET_NAME)
add_dependencies(${TARGET_NAME} ${go_library_DEPS}) add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS) endif(go_library_DEPS)
# we need to symlink Paddle directory into GOPATH. If we set(${TARGET_NAME}_LIB_PATH "${CMAKE_CURRENT_BINARY_DIR}/${${TARGET_NAME}_LIB_NAME}" CACHE STRING "output library path for target ${TARGET_NAME}")
# don't do it and we have code that depends on Paddle, go
# get ./... will download a new Paddle repo from Github,
# without the changes in our current Paddle repo that we
# want to build.
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Symlink Paddle directory into GOPATH
COMMAND mkdir -p ${PADDLE_IN_GOPATH}
COMMAND rm -rf ${PADDLE_IN_GOPATH}
COMMAND ln -sf ${CMAKE_SOURCE_DIR} ${PADDLE_IN_GOPATH}
# Automatically get all dependencies specified in the source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./...
# Golang build source code # Golang build source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" -o "${${TARGET_NAME}_LIB_PATH}"
${GO_SOURCE} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_dependencies(${TARGET_NAME} go_path)
endfunction(go_library) endfunction(go_library)
function(go_binary TARGET_NAME) function(go_binary TARGET_NAME)
...@@ -331,3 +315,13 @@ function(go_test TARGET_NAME) ...@@ -331,3 +315,13 @@ function(go_test TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(proto_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(proto_srcs)
set(proto_hdrs)
protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS})
cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS protobuf)
endfunction()
...@@ -27,10 +27,6 @@ sphinx_add_target(paddle_docs ...@@ -27,10 +27,6 @@ sphinx_add_target(paddle_docs
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_EN}) ${SPHINX_HTML_DIR_EN})
add_dependencies(paddle_docs
gen_proto_py)
# configured documentation tools and intermediate build results # configured documentation tools and intermediate build results
set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build") set(BINARY_BUILD_DIR_CN "${CMAKE_CURRENT_BINARY_DIR}/cn/_build")
...@@ -51,6 +47,3 @@ sphinx_add_target(paddle_docs_cn ...@@ -51,6 +47,3 @@ sphinx_add_target(paddle_docs_cn
${SPHINX_CACHE_DIR_CN} ${SPHINX_CACHE_DIR_CN}
${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}
${SPHINX_HTML_DIR_CN}) ${SPHINX_HTML_DIR_CN})
add_dependencies(paddle_docs_cn
gen_proto_py)
...@@ -105,3 +105,48 @@ shared_library(api ...@@ -105,3 +105,48 @@ shared_library(api
### Implementation ### Implementation
As above example CMakeLists.txt executes, each function invocation adds "nodes" to a dependency graph. It also use this graph to generate CMake commands including `add_executable`, `add_dependencies`, `target_link_libraries`, and `add_test`. As above example CMakeLists.txt executes, each function invocation adds "nodes" to a dependency graph. It also use this graph to generate CMake commands including `add_executable`, `add_dependencies`, `target_link_libraries`, and `add_test`.
### Using Package Manager For Go
Building Go binaries and libraries need to satisfy their dependencies, generally
we can do `go get ./...` to download and compile all external dependencies. The
problems are:
1. `go get` will always get the latest code from the default branch of the
remote repo, so changes of dependents might break the build. This is very
different with what we already have in `cmake/external` which download a
specific version or commit id of the dependency.
1. Some locations can not access external dependencies through the internet, as mentioned
in https://github.com/PaddlePaddle/Paddle/issues/2605. Using package management
tools can package the dependencies as a "vendor" package, which can be mirrored
at many cloud file hosting, so users what to compile paddle by themselves can
download this "vendor" package from a mirror site.
#### Choose A Suitable Tool
As mentioned by @wangkuiyi, [Here](https://github.com/golang/go/wiki/PackageManagementTools)
list dozens of Go package managers. We choose the tool using following principles:
- Most "active" projects with more stars, more pull requests or commits
- Widely used project
After comparing all these projects, we shall choose between the most popular
tools: Godep and Glide.
Here's a brief comparison between Godep and Glide
: https://github.com/Masterminds/glide/wiki/Go-Package-Manager-Comparison. There are
also many complaints about using `Godep`. There's also a new "official" pakcage
management tool has been started at: https://github.com/golang/dep to resolve
such problems, but it's currently at Alpha stage. So the best choice now is
glide obviously.
#### Manage Go Packages
- Dependencies: `go/glide.yaml` will store the dependencies and their versions which
is directly imported by paddle. `go/glide.lock` will store all dependencies recursively
with their commit id. Builds will "lock" to these packages if we don't `glide up`
them
- Vendor package: `go/vendor` directory will generated when running `cmake` command. `cmake`
will download the code corresponding to `go/glide.lock`. If we put a vendor folder
under `go/`, cmake will just check the commit id to the packages under the folder,
if commit id matches, there will be no download at all.
# Design Doc: Save Model
## Overview
The model is the output of the training process. There are two
ways from which user can obtain a model:
- Save model triggered by user code: user code asks PaddlePaddle to
save a model.
- Convert model from the checkpoint: model being converted from
pservers' periodic checkpoint. In this way, the user can cancel a
job at any time, and still have a relatively fresh model (we
checkpoint around every 5 minutes).
### Trainer Saving Model vs. Pservers Saving Model
Both trainers and pservers have access to the model. So the model can
be saved from a trainer or pservers. We need to decide where the model
is saved from.
#### Dense Update vs. Sparse Update
There are two types of model update methods: dense update and sparse
update (when the model parameter is configured to be sparse).
- Dense update
Every trainer has it's own full copy of the model. Every model
update will update the entire model.
- Sparse update
The training input is sparse, and the trainer does not have the
entire model. It will only download the sub-model necessary related
to the input. When updating the model, only the sub-model related to
the training input is updated.
#### Pservers Saving Model
The benefit of letting pservers save model is they have the entire
model all the time. However, since pservers are on different nodes, it
requires a merging process to merge model shards into the same
model. Thus requires the pservers to write models to a distributed
filesystem, making the checkpoint shards visible to the merge program.
#### Trainer Saving Model
The benefit of letting one trainer to save the model is it does not
require a distributed filesystem. And it's reusing the same save model
logic when training locally - except when doing sparse update, the
trainer needs to download the entire model during the saving process.
#### Conclusion
Given trainer saving model does not require a distributed filesystem,
and is an intuitive extension to trainer saving model when training
locally, we decide to let the trainer save the model when doing
distributed training.
### Convert Model from Checkpoint
TODO
## Timeline
We first implement trainer save the model. Converting the latest
snapshot to a model will be a TODO for future.
## Trainer Save Model
### Trainer Election
One trainer will be elected as the one to save the model. When using
etcd, trainer ID is a randomly generated UUID, we will utilize etcd to
elect one trainer. When not using etcd, unique trainer IDs will be
given by the administrator, the trainer whose ID is "0" is elected to
save the model.
### Model Save Path
Each trainer will be given the directory to save the model. The
elected trainer will save the model to
`given-directory/trainerID`. Since the trainer ID is unique, this
would prevent concurrent save to the same file when multiple trainers
are elected to save the model when split-brain problem happens.
### What Happens When Model Is Saving
It takes some time to save model, we need to define what will happen
when save model is taking place.
When doing dense update, the trainer uses the local model. Pservers
does not need to pause model update.
When doing sparse update. The trainer needs to download the entire
model while saving. To get the most accurate model, the model update
needs to be paused before the download starts and resumed after the
download finishes. Otherwise, the trainer gets a model that is
"polluted": some part of the model is old, some part of the model is
new.
It's unclear that the "polluted" model will be inferior due to the
stochastic nature of deep learning, and pausing the model update will
add more complexity to the system. Since supporting sparse update is a
TODO item. We defer the evaluation of pause the model update or not
during saving model to the future.
...@@ -41,7 +41,7 @@ class Scope { ...@@ -41,7 +41,7 @@ class Scope {
const Variable* GetVariable(const std::string& name) const; const Variable* GetVariable(const std::string& name) const;
private: private:
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_; std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
}; };
``` ```
...@@ -59,9 +59,9 @@ class Scope { ...@@ -59,9 +59,9 @@ class Scope {
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {} Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const { Variable* GetVariable(const std::string& name) const {
Variable* var = GetVarLocally(name); auto it = vars_.find(name);
if (var != nullptr) { if (it != vars_.end()) {
return var; return it->second.get();
} else if (parent_ != nullptr) { } else if (parent_ != nullptr) {
return parent_->GetVariable(name); return parent_->GetVariable(name);
} else { } else {
...@@ -97,8 +97,8 @@ class Scope { ...@@ -97,8 +97,8 @@ class Scope {
// return nullptr if not found. // return nullptr if not found.
Variable* GetVariable(const std::string& name) const; Variable* GetVariable(const std::string& name) const;
// return Error if already contains same name variable. // return if already contains same name variable.
Error CreateVariable(const std::string& name); Variable* CreateVariable(const std::string& name);
private: private:
std::shared_ptr<Scope> parent_; std::shared_ptr<Scope> parent_;
......
...@@ -31,7 +31,7 @@ def event_handler(event): ...@@ -31,7 +31,7 @@ def event_handler(event):
# define training dataset reader # define training dataset reader
def train_reader(): def train_reader():
train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]]) train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]])
train_y = np.array([-2, -3, -7, -7]) train_y = np.array([[-2], [-3], [-7], [-7]])
def reader(): def reader():
for i in xrange(train_y.shape[0]): for i in xrange(train_y.shape[0]):
......
...@@ -30,7 +30,13 @@ func main() { ...@@ -30,7 +30,13 @@ func main() {
log.SetLevel(level) log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout)) timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
if err != nil {
panic(err)
}
s, err := pserver.NewService(idx)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
cmake_minimum_required(VERSION 3.0) cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) go_library(paddle_master SHARED)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
#include(golang)
include(flags)
set(MASTER_LIB_NAME "paddle_master")
go_library(${MASTER_LIB_NAME} SHARED)
if(PROJ_ROOT)
add_custom_command(OUTPUT ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so
COMMAND rm ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.h
COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/lib${MASTER_LIB_NAME}.so ${PROJ_ROOT}/python/paddle/v2/master/
DEPENDS ${MASTER_LIB_NAME})
add_custom_target(paddle_master_shared ALL DEPENDS ${PROJ_ROOT}/python/paddle/v2/master/lib${MASTER_LIB_NAME}.so)
endif(PROJ_ROOT)
...@@ -13,10 +13,13 @@ typedef int paddle_master_client; ...@@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C" import "C"
import ( import (
"strings"
"sync" "sync"
"time"
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
return h return h
} }
type addresser string //export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
func (a addresser) Address() string { p := C.GoString(etcdEndpoints)
return string(a) cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize) ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
......
...@@ -2,18 +2,12 @@ package master ...@@ -2,18 +2,12 @@ package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Addresser provide the address of the master server.
type Addresser interface {
Address() string
}
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
...@@ -29,11 +23,11 @@ type record struct { ...@@ -29,11 +23,11 @@ type record struct {
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addr Addresser, bufSize int) *Client { func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan record, bufSize) c.ch = make(chan record, bufSize)
go c.monitorMaster(addr) go c.monitorMaster(addrCh)
go c.getRecords() go c.getRecords()
return c return c
} }
...@@ -78,12 +72,10 @@ func (c *Client) getRecords() { ...@@ -78,12 +72,10 @@ func (c *Client) getRecords() {
} }
} }
func (c *Client) monitorMaster(addr Addresser) { func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := "" lastMaster := ""
monitor := func() { for curMaster := range addrCh {
// get the lastest address of the master server,
// connect to the new address once address changed. // connect to the new address once address changed.
curMaster := addr.Address()
if curMaster != lastMaster { if curMaster != lastMaster {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
...@@ -100,18 +92,10 @@ func (c *Client) monitorMaster(addr Addresser) { ...@@ -100,18 +92,10 @@ func (c *Client) monitorMaster(addr Addresser) {
// to retry next time. // to retry next time.
curMaster = lastMaster curMaster = lastMaster
} }
} }
} }
lastMaster = curMaster lastMaster = curMaster
} }
monitor()
ticker := time.NewTicker(10 * time.Second)
for _ = range ticker.C {
monitor()
}
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset set dataset for the master server to dispatch.
......
...@@ -26,12 +26,6 @@ func init() { ...@@ -26,12 +26,6 @@ func init() {
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
} }
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"
...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) { ...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if err != nil {
...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) { ...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
// Manually intialize client to avoid calling c.getRecords() // Manually intialize client to avoid calling c.getRecords()
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
......
...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) { ...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
total = 50 total = 50
) )
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) { ...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil { if err != nil {
...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) { ...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
} }
w.Close() w.Close()
f.Close() f.Close()
curAddr := make(chan string, 1)
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
......
...@@ -18,8 +18,8 @@ const ( ...@@ -18,8 +18,8 @@ const (
DefaultAddrPath = "/master/addr" DefaultAddrPath = "/master/addr"
) )
// EtcdClient is the etcd client that master uses for fault tolerance // EtcdClient is the etcd client that the master uses for fault
// and service registry. // tolerance and service registry.
type EtcdClient struct { type EtcdClient struct {
lockPath string lockPath string
statePath string statePath string
...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
state := kvs[0].Value state := kvs[0].Value
return state, nil return state, nil
} }
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}
cc_library(main SRCS main.c DEPS paddle_pserver_cclient) cc_binary(main SRCS main.c DEPS paddle_pserver_cclient)
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient) cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
package pserver package pserver
import ( import (
"errors"
"hash/fnv" "hash/fnv"
"sort" "sort"
"time" "time"
...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { ...@@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads)) errCh := make(chan error, len(grads))
for _, g := range grads { for _, g := range grads {
go func(g Gradient) { go func(g Gradient) {
......
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
...@@ -31,7 +30,7 @@ func init() { ...@@ -31,7 +30,7 @@ func init() {
port[i] = p port[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package pserver
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
return &EtcdClient{
etcdTimeout: timeout,
numPservers: numPservers,
etcdEndpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return 0, err
}
// initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout)
continue
}
e.etcdClient = cli
log.Debugf("inited client to %s", e.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
var pserverIdx int
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
idx = i
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
return 0, err
}
return idx, nil
}
package pserver package pserver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -55,160 +46,25 @@ type Gradient Parameter ...@@ -55,160 +46,25 @@ type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int
mu sync.Mutex mu sync.Mutex
opt *optimizer opt *optimizer
paramMap map[string]Parameter paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { func NewService(idx int) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{
idx: idx,
opt: newOptimizer(sgd, 0.005),
}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil return s, nil
} }
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// InitParam initializes a parameter. // InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
......
...@@ -10,7 +10,7 @@ import ( ...@@ -10,7 +10,7 @@ import (
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -75,7 +75,7 @@ func TestFull(t *testing.T) { ...@@ -75,7 +75,7 @@ func TestFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) { ...@@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
...@@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) { ...@@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) {
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService("", time.Second*5) s, err := pserver.NewService(0)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
...@@ -9,17 +9,10 @@ add_subdirectory(pserver) ...@@ -9,17 +9,10 @@ add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer) add_subdirectory(optimizer)
add_subdirectory(strings) add_subdirectory(string)
# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)
# add_subdirectory(go)
# endif()
find_package(Boost QUIET)
if(Boost_FOUND) if(Boost_FOUND)
include_directories(${Boost_INCLUDE_DIRS}) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
endif() endif()
......
...@@ -16,7 +16,7 @@ set(API_HEADER ...@@ -16,7 +16,7 @@ set(API_HEADER
Internal.h) Internal.h)
add_library(paddle_api STATIC ${API_SOURCES}) add_library(paddle_api STATIC ${API_SOURCES})
add_dependencies(paddle_api gen_proto_cpp paddle_trainer_lib) add_dependencies(paddle_api paddle_proto paddle_trainer_lib)
INCLUDE(${SWIG_USE_FILE}) INCLUDE(${SWIG_USE_FILE})
INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle) INCLUDE_DIRECTORIES(${PROJ_ROOT}/paddle)
......
...@@ -26,7 +26,7 @@ target_include_directories(paddle_capi PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) ...@@ -26,7 +26,7 @@ target_include_directories(paddle_capi PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
add_style_check_target(paddle_capi ${CAPI_SOURCES} ${CAPI_HEADER} add_style_check_target(paddle_capi ${CAPI_SOURCES} ${CAPI_HEADER}
${CAPI_PRIVATE_HEADER}) ${CAPI_PRIVATE_HEADER})
add_dependencies(paddle_capi gen_proto_cpp) add_dependencies(paddle_capi paddle_proto)
# combine all paddle static libraries together, into libpaddle_capi_whole.a # combine all paddle static libraries together, into libpaddle_capi_whole.a
......
...@@ -83,7 +83,7 @@ else() ...@@ -83,7 +83,7 @@ else()
${CUDA_CXX_SOURCES}) ${CUDA_CXX_SOURCES})
endif() endif()
add_dependencies(paddle_cuda ${external_project_dependencies}) add_dependencies(paddle_cuda paddle_proto ${external_project_dependencies})
add_style_check_target(paddle_cuda add_style_check_target(paddle_cuda
${CUDA_SOURCES} ${CUDA_SOURCES}
......
# ddim lib
cc_library(ddim SRCS ddim.cc) cc_library(ddim SRCS ddim.cc)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_test(scope_test SRCS scope_test.cc)
cc_test(enforce_test SRCS enforce_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>
namespace paddle {
namespace framework {
/**
* @brief Enforce exception. Inherits std::exception
*
* All enforce condition not met, will throw an EnforceNotMet exception.
*/
class EnforceNotMet : public std::exception {
public:
EnforceNotMet(const std::string& msg, const char* file, int fileline) {
std::ostringstream sout;
sout << msg << " at [" << file << ":" << fileline << "];";
all_msg_ = sout.str();
}
const char* what() const noexcept override { return all_msg_.c_str(); }
private:
std::string all_msg_;
};
// From https://stackoverflow.com/questions/30130930/
// __buildin_expect is in C++ 11 standard. Since the condition which enforced
// should be true in most situation, it will make the compiler generate faster
// code by adding `UNLIKELY` macro.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
/**
* @brief Throw a EnforceNotMet exception, automatically filled __FILE__ &
* __LINE__
*
* This macro take __VA_ARGS__, user can pass any type if that type can
* serialize to std::ostream
*/
#define PADDLE_THROW(...) \
do { \
throw ::paddle::framework::EnforceNotMet( \
::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
} while (0)
/**
* @brief Enforce a condition, otherwise throw an EnforceNotMet
*/
#define PADDLE_ENFORCE(condition, ...) \
do { \
if (UNLIKELY(!(condition))) { \
PADDLE_THROW(__VA_ARGS__); \
} \
} while (0)
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <paddle/framework/enforce.h>
TEST(ENFORCE, OK) {
PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345);
size_t val = 1;
const size_t limit = 10;
PADDLE_ENFORCE(val < limit, "Enforce is OK too");
}
TEST(ENFORCE, FAILED) {
bool in_catch = false;
try {
PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123);
} catch (paddle::framework::EnforceNotMet err) {
in_catch = true;
std::string msg = "Enforce is not ok 123 at all";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(in_catch);
}
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/variable.h"
namespace paddle {
namespace framework {
/**
* @brief Scope that manage all variables.
*
* Scope is an association of a name to Variable. All variables belong to
* Scope. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`.
* One net can run in different scopes and update different variable in the
* scope.
*/
class Scope {
public:
/**
* @brief Initialize s Scope without parent.
*/
Scope() {}
/**
* @brief Initialize a Scope with parent.
*/
explicit Scope(const std::shared_ptr<Scope>& parent) : parent_(parent) {}
/**
* @brief Create Variable
*
* Create Variable in this Scope. Return the exist one if Variable already
* been created.
*/
Variable* CreateVariable(const std::string& name) {
auto var = GetVariable(name);
if (var) {
return var;
} else {
vars_[name] = std::unique_ptr<Variable>(new Variable());
return GetVariable(name);
}
}
/**
* @brief Get Variable.
*
* Get Variable from this Scope, this function will recursive find Variable
* from it's parent scope. Return nullptr if not found.
*/
Variable* GetVariable(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) {
return it->second.get();
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
/**
* @brief If this scope has a Var named name.
*
* Find if there is a Variable in this scope and it's parent scope
*/
bool HasVariable(const std::string& name) const {
return (vars_.find(name) != vars_.end() ||
(parent_ && parent_->HasVariable(name)));
}
private:
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
std::shared_ptr<Scope> parent_{nullptr};
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/scope.h"
#include "gtest/gtest.h"
TEST(Scope, Create) {
using paddle::framework::Scope;
using paddle::framework::Variable;
auto scope = std::make_shared<Scope>();
Variable* var0 = scope->CreateVariable("");
EXPECT_NE(var0, nullptr);
/// GetVariable will return nullptr if not exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var1, nullptr);
/// CreateVariable will return one.
Variable* var2 = scope->CreateVariable("a");
EXPECT_NE(var2, nullptr);
/// Get the created variable.
Variable* var3 = scope->GetVariable("a");
EXPECT_EQ(var2, var3);
/// CreateVariable will just return the variable if it's
/// already exist.
Variable* var4 = scope->CreateVariable("a");
EXPECT_EQ(var4, var2);
}
TEST(Scope, Parent) {
using paddle::framework::Scope;
using paddle::framework::Variable;
auto parent_scope = std::make_shared<Scope>();
auto scope = std::make_shared<Scope>(parent_scope);
Variable* var0 = parent_scope->CreateVariable("a");
EXPECT_NE(var0, nullptr);
/// GetVariable will get Variable from parent scope if exist.
Variable* var1 = scope->GetVariable("a");
EXPECT_EQ(var0, var1);
}
...@@ -10,9 +10,17 @@ if(WITH_GPU) ...@@ -10,9 +10,17 @@ if(WITH_GPU)
cuda_compile(cu_objs ${cu_files}) cuda_compile(cu_objs ${cu_files})
endif() endif()
if(USE_NNPACK)
include(nnpack/nnpack.cmake)
list(APPEND cpp_files nnpack/NNPACKConvOp.cpp)
if(WITH_TESTING)
add_unittest(NNPACKConvOpTest nnpack/NNPACKConvOpTest.cpp)
endif()
endif()
add_library(paddle_function STATIC ${cpp_files} ${cu_objs}) add_library(paddle_function STATIC ${cpp_files} ${cu_objs})
add_dependencies(paddle_function ${external_project_dependencies}) add_dependencies(paddle_function ${external_project_dependencies})
add_dependencies(paddle_function gen_proto_cpp) add_dependencies(paddle_function paddle_proto)
if(WITH_TESTING) if(WITH_TESTING)
if(WITH_GPU) if(WITH_GPU)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "nnpack.h"
#include "paddle/function/ConvOp.h"
DEFINE_bool(nnpack_allocate_outside,
false,
"Allocate and free workspace memory outside the NNPACK interface.");
DEFINE_int32(nnpack_num_threads,
0,
"The number of nnpack threads"
"default: 0; 0 to disable threadpool.");
namespace paddle {
nnp_convolution_algorithm get_nnp_convolution_algorithm(
const std::string& algorithm) {
if (algorithm == "auto") {
return nnp_convolution_algorithm_auto;
} else if (algorithm == "ft8x8") {
return nnp_convolution_algorithm_ft8x8;
} else if (algorithm == "ft16x16") {
return nnp_convolution_algorithm_ft16x16;
} else if (algorithm == "wt8x8") {
return nnp_convolution_algorithm_wt8x8;
} else if (algorithm == "implicit-gemm") {
return nnp_convolution_algorithm_implicit_gemm;
} else if (algorithm == "direct") {
return nnp_convolution_algorithm_direct;
} else {
return nnp_convolution_algorithm_auto;
}
}
template <DeviceType Device>
class NNPACKConvFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
CHECK_EQ(groups_, (size_t)1);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
// algorithm_ = nnp_convolution_algorithm_auto;
transform_strategy_ = nnp_convolution_transform_strategy_compute;
nnp_status status = nnp_initialize();
CHECK_EQ(status, nnp_status_success);
workspaceBuffer_ = nullptr;
workspaceSize_ = 0;
threadpool_ = nullptr;
if (FLAGS_nnpack_num_threads) {
threadpool_ = pthreadpool_create(FLAGS_nnpack_num_threads);
VLOG(3) << "Number of threads "
<< pthreadpool_get_threads_count(threadpool_);
}
}
~NNPACKConvFunction() {
if (threadpool_) {
pthreadpool_destroy(threadpool_);
}
if (workspaceBuffer_) {
free(workspaceBuffer_);
}
}
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
check(inputs, outputs);
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
// size_t outputHeight = output[2];
// size_t outputWidth = output[3];
nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
nnp_padding padding = {.top = (size_t)paddingH(),
.right = (size_t)paddingW(),
.bottom = (size_t)paddingH(),
.left = (size_t)paddingW()};
nnp_size kernelSize = {.width = filterWidth, .height = filterHeight};
nnp_size outputSubsampling = {.width = (size_t)strideW(),
.height = (size_t)strideH()};
float* inputData = inputs[0].data<float>();
float* filterData = inputs[1].data<float>();
float* outputData = outputs[0].data<float>();
void* bufferPtr = nullptr;
size_t* sizePtr = nullptr;
size_t needSize;
if (FLAGS_nnpack_allocate_outside) {
if (batchSize == 1) {
nnp_status status = nnp_convolution_inference(algorithm_,
transform_strategy_,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
outputSubsampling,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&needSize,
nnp_activation_identity,
nullptr,
nullptr,
nullptr);
CHECK_EQ(status, nnp_status_success);
} else {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
&needSize,
nnp_activation_identity,
nullptr,
nullptr,
nullptr);
CHECK_EQ(status, nnp_status_success);
}
VLOG(3) << "workspace size is " << needSize;
if (needSize > workspaceSize_) {
workspaceSize_ = needSize;
if (workspaceBuffer_) {
free(workspaceBuffer_);
} else {
posix_memalign(&workspaceBuffer_, 64, needSize);
}
}
if (needSize) {
bufferPtr = workspaceBuffer_;
sizePtr = &needSize;
}
}
if (batchSize == 1) {
nnp_status status =
nnp_convolution_inference(algorithm_,
transform_strategy_,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
outputSubsampling,
inputData,
filterData,
nullptr, /* bias */
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
} else {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
inputData,
filterData,
nullptr, /* bias */
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
}
private:
nnp_convolution_algorithm algorithm_;
nnp_convolution_transform_strategy transform_strategy_;
void* workspaceBuffer_;
size_t workspaceSize_;
pthreadpool_t threadpool_;
};
REGISTER_TYPED_FUNC(NNPACKConv, CPU, NNPACKConvFunction);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/function/Function.h"
#include "paddle/function/FunctionTest.h"
DEFINE_string(algo,
"auto",
"The algorithm (auto, ft8x8, ft16x16, wt8x8, "
"implicit-gemm, or direct) for computing convolution of NNPACK.");
namespace paddle {
#define IS_NNPACK_SUPPORT(algo, filterSize, stride) \
if (algo == "direct" && filterSize != 1) continue; \
if (algo == "direct" && batchSize != 1) continue; \
if (algo == "wt8x8" && filterSize != 3) continue; \
if (algo == "implicit-gemm" && batchSize != 1) continue; \
if (algo != "auto" && algo != "implicit-gemm" && stride > 1) continue;
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 64}) {
for (size_t outputChannels : {3, 64, 128}) {
if (inputChannels < outputChannels) break;
for (size_t stride : {1, 2}) {
// if batchSize > 1 NNPACKConv only supports stride = 1
if (batchSize > 1 && stride > 1) break;
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
IS_NNPACK_SUPPORT(algo, filterSize, stride);
LOG(INFO) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{
batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2));
test.run();
}
}
}
}
}
}
}
}
};
TEST(Convolution, NNPACK) {
// NNPACK only supports stride = 1
ConvolutionTest test("GemmConv-CPU", "NNPACKConv-CPU", FLAGS_algo);
}
} // namespace paddle
# Find the NNPACK library
# NNPACK_ROOT - where to find NNPACK include and library.
#
set(NNPACK_FOUND OFF)
set(NNPACK_ROOT $ENV{NNPACK_ROOT} CACHE PATH "Folder contains NNPACK")
find_path(NNPACK_INC_DIR nnpack.h PATHS ${NNPACK_ROOT}/include)
find_library(NNPACK_LIB NAMES nnpack PATHS ${NNPACK_ROOT}/lib)
find_library(PTHREADPOOL_LIB NAMES pthreadpool PATHS ${NNPACK_ROOT}/lib)
if(NNPACK_INC_DIR AND NNPACK_LIB AND PTHREADPOOL_LIB)
set(NNPACK_FOUND ON)
INCLUDE_DIRECTORIES(${NNPACK_INC_DIR})
else()
message(FATAL_ERROR "Cannot find NNPACK in (${NNPACK_ROOT})")
endif()
...@@ -58,7 +58,7 @@ endif() ...@@ -58,7 +58,7 @@ endif()
add_style_check_target(paddle_gserver ${GSERVER_SOURCES}) add_style_check_target(paddle_gserver ${GSERVER_SOURCES})
add_style_check_target(paddle_gserver ${GSERVER_HEADER}) add_style_check_target(paddle_gserver ${GSERVER_HEADER})
add_dependencies(paddle_gserver gen_proto_cpp) add_dependencies(paddle_gserver paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -601,7 +601,7 @@ void TrainerThread::backward() { ...@@ -601,7 +601,7 @@ void TrainerThread::backward() {
void TrainerThread::backwardCallback(Parameter* para) { void TrainerThread::backwardCallback(Parameter* para) {
// CPU parameters are merged in the end // CPU parameters are merged in the end
if (!para->useGpu()) return; if (!para->useGpu() || para->isStatic()) return;
int paramId = para->getID(); int paramId = para->getID();
if (multiMachine_->getNumThreads() == 1) { if (multiMachine_->getNumThreads() == 1) {
......
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
DEFINE_bool(use_nnpack,
false,
"Whether to use nnpack for convolution calculation.");
namespace paddle { namespace paddle {
/* /*
...@@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, ...@@ -37,26 +41,38 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
for (int i = 0; i < config_.inputs_size(); i++) { for (int i = 0; i < config_.inputs_size(); i++) {
std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]}; std::vector<size_t> paddings = {(size_t)paddingY_[i], (size_t)padding_[i]};
std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]}; std::vector<size_t> strides = {(size_t)strideY_[i], (size_t)stride_[i]};
createFunction(forward_,
!isDeconv_ ? "GemmConv" : "GemmConvGradInput", if (FLAGS_use_nnpack) {
FuncConfig() CHECK_EQ(isDeconv_, false);
.set("paddings", paddings) createFunction(forward_,
.set("strides", strides) "NNPACKConv",
.set("groups", (size_t)groups_[i])); FuncConfig()
.set("paddings", paddings)
createFunction(backward_, .set("strides", strides)
!isDeconv_ ? "GemmConvGradInput" : "GemmConv", .set("groups", (size_t)groups_[i])
FuncConfig() .set("algo", std::string("auto")));
.set("paddings", paddings) } else {
.set("strides", strides) createFunction(forward_,
.set("groups", (size_t)groups_[i])); !isDeconv_ ? "GemmConv" : "GemmConvGradInput",
FuncConfig()
createFunction(backward_, .set("paddings", paddings)
"GemmConvGradFilter", .set("strides", strides)
FuncConfig() .set("groups", (size_t)groups_[i]));
.set("paddings", paddings)
.set("strides", strides) createFunction(backward_,
.set("groups", (size_t)groups_[i])); !isDeconv_ ? "GemmConvGradInput" : "GemmConv",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
createFunction(backward_,
"GemmConvGradFilter",
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)groups_[i]));
}
} }
return true; return true;
} }
......
...@@ -33,7 +33,7 @@ endif() ...@@ -33,7 +33,7 @@ endif()
add_style_check_target(paddle_math ${MATH_SOURCES}) add_style_check_target(paddle_math ${MATH_SOURCES})
add_style_check_target(paddle_math ${MATH_HEADERS}) add_style_check_target(paddle_math ${MATH_HEADERS})
add_dependencies(paddle_math gen_proto_cpp) # depends add_dependencies(paddle_math paddle_proto ${external_project_dependencies}) # depends
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
add_subdirectory(detail)
...@@ -97,6 +97,7 @@ class BuddyAllocator { ...@@ -97,6 +97,7 @@ class BuddyAllocator {
struct Block { struct Block {
size_t size; size_t size;
Block* left, right; Block* left, right;
size_t index; // allocator id
}; };
... ...
}; };
......
if(${WITH_GPU})
nv_library(system_allocator SRCS system_allocator.cc DEPS gflags)
nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
else(${WITH_GPU})
cc_library(system_allocator SRCS system_allocator.cc DEPS gflags)
cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags)
endif(${WITH_GPU})
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/detail/buddy_allocator.h"
namespace paddle {
namespace memory {
namespace detail {
BuddyAllocator::BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator)
: pool_size_(pool_size),
max_pools_(max_pools),
system_allocator_(system_allocator) {
PADDLE_ASSERT(pool_size > 0);
PADDLE_ASSERT(max_pools > 0);
PADDLE_ASSERT(system_allocator != nullptr);
}
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/detail/system_allocator.h"
#include <mutex>
#include <vector>
namespace paddle {
namespace memory {
namespace detail {
class BuddyAllocator {
public:
BuddyAllocator(size_t pool_size, size_t max_pools,
SystemAllocator* system_allocator);
~BuddyAllocator();
void* Alloc(size_t size);
void Free(void*);
size_t Used();
private:
struct Block {
size_t size_;
Block* left_; // left buddy
Block* right_; // right buddy
};
// Initially, there is only one pool. If a Alloc founds not enough
// memory from that pool, and there has not been max_num_pools_,
// create a new pool by calling system_allocator_.Alloc(pool_size_).
std::vector<void*> pools_;
size_t pool_size_; // the size of each pool;
size_t max_num_pools_; // the size of all pools;
SystemAllocator* system_allocator_;
std::mutex mutex_;
// Disable copy and assignment.
BuddyAllocator(const BuddyAllocator&) = delete;
BuddyAllocator& operator=(const BuddyAllocator&) = delete;
};
BuddyAllocator<CPUAllocator>* GetCPUBuddyAllocator() {
static BuddyAllocator<CPUAllocator>* a = nullptr;
if (a == nullptr) {
a = new BuddyAllocator<CPUAllocator>();
}
return a;
}
#ifndef PADDLE_ONLY_CPU // The following code are for CUDA.
BuddyAllocator<GPUAllocator>* GetGPUBuddyAllocator(int gpu_id) {
static BuddyAllocator<GPUAllocator>** as = NULL;
if (as == NULL) {
int gpu_num = platform::GetDeviceCount();
as = new BuddyAllocator<GPUAllocator>*[gpu_num];
for (int gpu = 0; gpu < gpu_num; gpu++) {
as[gpu] = new BuddyAllocator<GPUAllocator>();
}
}
return as[gpu_id];
}
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/detail/system_allocator.h"
#include <stdlib.h> // for malloc and free
#include <sys/mman.h> // for mlock and munlock
#include "gflags/gflags.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/cuda.h"
// If use_pinned_memory is true, CPUAllocator calls mlock, which
// returns pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the amount
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, false,
"If set, allocate cpu/gpu pinned memory.");
namespace paddle {
namespace memory {
namespace detail {
void* CPUAllocator::Alloc(size_t size) {
// According to http://www.cplusplus.com/reference/cstdlib/malloc/,
// malloc might not return nullptr if size is zero, but the returned
// pointer shall not be dereferenced -- so we make it nullptr.
if (size <= 0) return nullptr;
void* p = malloc(size);
if (p != nullptr && FLAGS_use_pinned_memory) {
mlock(p, size);
}
return p;
}
void CPUAllocator::Free(void* p, size_t size) {
if (p != nullptr && FLAGS_use_pinned_memory) {
munlock(p, size);
}
free(p);
}
#ifndef PADDLE_ONLY_CPU
void* GPUAllocator::Alloc(size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr
// if size is 0. We just make sure it does.
if (size <= 0) {
return nullptr;
}
void* p = 0;
cudaError_t result =
FLAGS_use_pinned_memory ? cudaMallocHost(&p, size) : cudaMalloc(&p, size);
if (result != cudaSuccess) {
cudaGetLastError(); // clear error if there is any.
}
return result == cudaSuccess ? p : nullptr;
}
void GPUAllocator::Free(void* p, size_t size) {
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// cudaFree succeeds.
cudaError_t err = FLAGS_use_pinned_memory ? cudaFreeHost(p) : cudaFree(p);
if (err != cudaErrorCudartUnloading) {
platform::throw_on_error(err, "cudaFree{Host} failed");
}
}
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stddef.h> // for size_t
namespace paddle {
namespace memory {
namespace detail {
// SystemAllocator is the parent class of CPUAllocator and
// GPUAllocator. A BuddyAllocator object uses a SystemAllocator*
// pointing to the underlying system allocator. An alternative to
// this class hierarchy is to pass a system allocator class to
// BuddyAllocator as a template parameter. This approach makes
// BuddyAllocator a class template, and it's very complicated
// algorithm would make the buddy_allocator.h messy.
class SystemAllocator {
public:
virtual ~SystemAllocator() {}
virtual void* Alloc(size_t size) = 0;
virtual void Free(void* p, size_t size) = 0;
};
class CPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#ifndef PADDLE_ONLY_CPU
class GPUAllocator : public SystemAllocator {
public:
virtual void* Alloc(size_t size);
virtual void Free(void* p, size_t size);
};
#endif // PADDLE_ONLY_CPU
} // namespace detail
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/detail/system_allocator.h"
#include <memory>
#include <vector>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
DECLARE_bool(use_pinned_memory);
void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) {
bool freed = false;
{
void* p = a.Alloc(size);
if (size > 0) {
EXPECT_NE(p, nullptr);
} else {
EXPECT_EQ(p, nullptr);
}
int* i = static_cast<int*>(p);
std::shared_ptr<int> ptr(i, [&](void* p) {
freed = true;
a.Free(p, size);
});
}
EXPECT_TRUE(freed);
}
TEST(CPUAllocator, NoLockMem) {
FLAGS_use_pinned_memory = false;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(CPUAllocator, LockMem) {
FLAGS_use_pinned_memory = true;
paddle::memory::detail::CPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#ifndef PADDLE_ONLY_CPU
TEST(GPUAllocator, NoStaging) {
FLAGS_use_pinned_memory = false;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
TEST(GPUAllocator, Staging) {
FLAGS_use_pinned_memory = true;
paddle::memory::detail::GPUAllocator a;
TestAllocator(a, 2048);
TestAllocator(a, 0);
}
#endif // PADDLE_ONLY_CPU
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/memory.h"
#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/assert.h"
#include <boost/variant.hpp>
namespace paddle {
namespace memory {
void* Alloc(platform::Place pl, size_t size) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
return detail::GetGPUBuddyAllocator(gpu_id)->Alloc(size);
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
return detail::GetCPUBuddyAllocator()->Alloc(size);
}
void Free(paddle::platform::Place pl, void* p) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
detail::GetGPUBuddyAllocator(gpu_id)->Free(p);
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
detail::GetCPUBuddyAllocator()->Free(p);
}
size_t Used(paddle::platform::Place pl) {
#ifndef PADDLE_ONLY_CPU
if (paddle::platform::is_gpu_place(pl)) {
size_t gpu_id = boost::get<platform::GPUPlace>(pl).device;
return detail::GetGPUBuddyAllocator(gpu_id)->Used();
}
#endif // PADDLE_ONLY_CPU
PADDLE_ASSERT(paddle::platform::is_cpu_place(pl));
return detail::GetCPUBuddyAllocator()->Used();
}
} // namespace memory
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -10,17 +13,15 @@ See the License for the specific language governing permissions and ...@@ -10,17 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
/**
* __must_check macro. It make the function's return value must be used, #include "paddle/platform/place.h"
* otherwise it will raise a compile warning. And also Paddle treat all compile
* warnings as errors. namespace paddle {
*/ namespace memory {
#ifdef __GNUC__
#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400 void* Alloc(paddle::platform::Place, size_t);
#define __must_check __attribute__((warn_unused_result)) void Free(paddle::platform::Place, void*);
#else size_t Used(paddle::platform::Place);
#define __must_check
#endif } // namespace memory
#else } // namespace paddle
#define __must_check
#endif
...@@ -10,7 +10,7 @@ set(OPITMIZER_SRCS ...@@ -10,7 +10,7 @@ set(OPITMIZER_SRCS
) )
add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS}) add_library(paddle_optimizer STATIC ${OPITMIZER_SRCS})
add_dependencies(paddle_optimizer gen_proto_cpp) add_dependencies(paddle_optimizer paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_simple_unittest(serialization_test) add_simple_unittest(serialization_test)
......
...@@ -7,7 +7,7 @@ add_library(paddle_parameter STATIC ...@@ -7,7 +7,7 @@ add_library(paddle_parameter STATIC
${PARAMETERS_SOURCES}) ${PARAMETERS_SOURCES})
add_style_check_target(paddle_parameter ${PARAMETERS_SOURCES}) add_style_check_target(paddle_parameter ${PARAMETERS_SOURCES})
add_style_check_target(paddle_parameter ${PARAMETERS_HEADERS}) add_style_check_target(paddle_parameter ${PARAMETERS_HEADERS})
add_dependencies(paddle_parameter gen_proto_cpp) add_dependencies(paddle_parameter paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -2,4 +2,3 @@ nv_test(cuda_test SRCS cuda_test.cu) ...@@ -2,4 +2,3 @@ nv_test(cuda_test SRCS cuda_test.cu)
cc_library(place SRCS place.cc) cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags) cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
cc_test(must_check_test SRCS must_check_test.cc)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef PADDLE_ONLY_CPU
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
namespace paddle {
namespace platform {
inline void throw_on_error(cudaError_t e, const char* message) {
if (e) {
throw thrust::system_error(e, thrust::cuda_category(), message);
}
}
int GetDeviceCount(void) {
int count;
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count;
}
} // namespace platform
} // namespace paddle
#endif // PADDLE_ONLY_CPU
#include <gtest/gtest.h>
#include <paddle/platform/must_check.h>
int __must_check SomeFunctionMustCheck() { return 0; }
TEST(MustCheck, all) {
// This line should not be compiled, because the
// return value of SomeFunctionMustCheck marked as __must_check
// SomeFunctionMustCheck();
}
\ No newline at end of file
...@@ -8,8 +8,8 @@ namespace detail { ...@@ -8,8 +8,8 @@ namespace detail {
class PlacePrinter : public boost::static_visitor<> { class PlacePrinter : public boost::static_visitor<> {
public: public:
PlacePrinter(std::ostream &os) : os_(os) {} PlacePrinter(std::ostream &os) : os_(os) {}
void operator()(const CpuPlace &) { os_ << "CpuPlace"; } void operator()(const CPUPlace &) { os_ << "CPUPlace"; }
void operator()(const GpuPlace &p) { os_ << "GpuPlace(" << p.device << ")"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; }
private: private:
std::ostream &os_; std::ostream &os_;
...@@ -22,14 +22,14 @@ static Place the_default_place; ...@@ -22,14 +22,14 @@ static Place the_default_place;
void set_place(const Place &place) { the_default_place = place; } void set_place(const Place &place) { the_default_place = place; }
const Place &get_place() { return the_default_place; } const Place &get_place() { return the_default_place; }
const GpuPlace default_gpu() { return GpuPlace(0); } const GPUPlace default_gpu() { return GPUPlace(0); }
const CpuPlace default_cpu() { return CpuPlace(); } const CPUPlace default_cpu() { return CPUPlace(); }
bool is_gpu_place(const Place &p) { bool is_gpu_place(const Place &p) {
return boost::apply_visitor(IsGpuPlace(), p); return boost::apply_visitor(IsGPUPlace(), p);
} }
bool is_cpu_place(const Place &p) { bool is_cpu_place(const Place &p) {
return !boost::apply_visitor(IsGpuPlace(), p); return !boost::apply_visitor(IsGPUPlace(), p);
} }
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once #pragma once
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <iostream> #include <iostream>
namespace paddle { namespace paddle {
namespace platform { namespace platform {
struct CpuPlace { struct CPUPlace {
// WORKAROUND: for some reason, omitting this constructor // WORKAROUND: for some reason, omitting this constructor
// causes errors with boost 1.59 and OSX // causes errors with boost 1.59 and OSX
CpuPlace() {} CPUPlace() {}
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const CpuPlace &) const { return true; } inline bool operator==(const CPUPlace &) const { return true; }
inline bool operator!=(const CpuPlace &) const { return false; } inline bool operator!=(const CPUPlace &) const { return false; }
}; };
struct GpuPlace { struct GPUPlace {
GpuPlace() : GpuPlace(0) {} GPUPlace() : GPUPlace(0) {}
GpuPlace(int d) : device(d) {} GPUPlace(int d) : device(d) {}
// needed for variant equality comparison // needed for variant equality comparison
inline bool operator==(const GpuPlace &o) const { return device == o.device; } inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GpuPlace &o) const { return !(*this == o); } inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
int device; int device;
}; };
struct IsGpuPlace : public boost::static_visitor<bool> { struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CpuPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GpuPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
}; };
typedef boost::variant<GpuPlace, CpuPlace> Place; typedef boost::variant<GPUPlace, CPUPlace> Place;
void set_place(const Place &); void set_place(const Place &);
const Place &get_place(); const Place &get_place();
const GpuPlace default_gpu(); const GPUPlace default_gpu();
const CpuPlace default_cpu(); const CPUPlace default_cpu();
bool is_gpu_place(const Place &); bool is_gpu_place(const Place &);
bool is_cpu_place(const Place &); bool is_cpu_place(const Place &);
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
TEST(Place, Equality) { TEST(Place, Equality) {
paddle::platform::CpuPlace cpu; paddle::platform::CPUPlace cpu;
paddle::platform::GpuPlace g0(0), g1(1), gg0(0); paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
EXPECT_EQ(cpu, cpu); EXPECT_EQ(cpu, cpu);
EXPECT_EQ(g0, g0); EXPECT_EQ(g0, g0);
...@@ -22,19 +22,19 @@ TEST(Place, Default) { ...@@ -22,19 +22,19 @@ TEST(Place, Default) {
EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu())); EXPECT_TRUE(paddle::platform::is_gpu_place(paddle::platform::default_gpu()));
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::default_cpu()));
paddle::platform::set_place(paddle::platform::CpuPlace()); paddle::platform::set_place(paddle::platform::CPUPlace());
EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place())); EXPECT_TRUE(paddle::platform::is_cpu_place(paddle::platform::get_place()));
} }
TEST(Place, Print) { TEST(Place, Print) {
{ {
std::stringstream ss; std::stringstream ss;
ss << paddle::platform::GpuPlace(1); ss << paddle::platform::GPUPlace(1);
EXPECT_EQ("GpuPlace(1)", ss.str()); EXPECT_EQ("GPUPlace(1)", ss.str());
} }
{ {
std::stringstream ss; std::stringstream ss;
ss << paddle::platform::CpuPlace(); ss << paddle::platform::CPUPlace();
EXPECT_EQ("CpuPlace", ss.str()); EXPECT_EQ("CPUPlace", ss.str());
} }
} }
...@@ -17,7 +17,7 @@ add_library(paddle_network STATIC ...@@ -17,7 +17,7 @@ add_library(paddle_network STATIC
add_style_check_target(paddle_network ${NETWORK_SOURCES}) add_style_check_target(paddle_network ${NETWORK_SOURCES})
add_style_check_target(paddle_network ${NETWORK_HEADERS}) add_style_check_target(paddle_network ${NETWORK_HEADERS})
add_dependencies(paddle_network gen_proto_cpp) add_dependencies(paddle_network paddle_proto ${external_project_dependencies})
################### paddle_pserver ###################### ################### paddle_pserver ######################
set(PSERVER_SOURCES set(PSERVER_SOURCES
...@@ -40,7 +40,7 @@ add_library(paddle_pserver STATIC ...@@ -40,7 +40,7 @@ add_library(paddle_pserver STATIC
add_style_check_target(paddle_pserver ${PSERVER_SOURCES}) add_style_check_target(paddle_pserver ${PSERVER_SOURCES})
add_style_check_target(paddle_pserver ${PSERVER_HEADERS}) add_style_check_target(paddle_pserver ${PSERVER_HEADERS})
add_dependencies(paddle_pserver gen_proto_cpp) add_dependencies(paddle_pserver paddle_proto ${external_project_dependencies})
set(PSERVER_MAIN_SOURCES set(PSERVER_MAIN_SOURCES
ParameterServer2Main.cpp) ParameterServer2Main.cpp)
......
...@@ -144,7 +144,7 @@ class DenseScanner(IScanner): ...@@ -144,7 +144,7 @@ class DenseScanner(IScanner):
if len(self.__shape__) > 1: if len(self.__shape__) > 1:
# The last-two dimenstions are the frame height and width. # The last-two dimenstions are the frame height and width.
# For example, the layout is CHW for 3-D feature of image. # For example, the layout is CHW for 3-D feature of image.
# The H and W are the fram height and width. # The H and W are the frame height and width.
h, w = self.__shape__[-2:] h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w) argument.setSlotFrameWidth(self.pos, w)
......
cc_library(stringpiece SRCS piece.cc)
cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags)
cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags)
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
#include "paddle/strings/stringpiece.h" #include "paddle/string/piece.h"
#include <string.h> #include <string.h>
...@@ -23,29 +23,25 @@ ...@@ -23,29 +23,25 @@
#include <stdexcept> #include <stdexcept>
namespace paddle { namespace paddle {
namespace string {
StringPiece::StringPiece() : data_(NULL), size_(0) {} Piece::Piece() : data_(NULL), size_(0) {}
StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) { Piece::Piece(const char* d, size_t n) : data_(d), size_(n) {
if (d == NULL && n != 0) if (d == NULL && n != 0)
throw std::invalid_argument( throw std::invalid_argument("Piece requires len to be 0 for NULL data");
"StringPiece requires len to be 0 for NULL data");
} }
StringPiece::StringPiece(const char* s) : data_(s) { Piece::Piece(const char* s) : data_(s) { size_ = (s == NULL) ? 0 : strlen(s); }
size_ = (s == NULL) ? 0 : strlen(s);
}
StringPiece::StringPiece(const std::string& s) Piece::Piece(const std::string& s) : data_(s.data()), size_(s.size()) {}
: data_(s.data()), size_(s.size()) {}
char StringPiece::operator[](size_t n) const { char Piece::operator[](size_t n) const {
if (n >= len()) if (n >= len()) throw std::invalid_argument("index out of Piece length");
throw std::invalid_argument("index out of StringPiece length");
return data_[n]; return data_[n];
} }
int Compare(StringPiece a, StringPiece b) { int Compare(Piece a, Piece b) {
const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); const size_t min_len = (a.len() < b.len()) ? a.len() : b.len();
int r = memcmp(a.data(), b.data(), min_len); int r = memcmp(a.data(), b.data(), min_len);
if (r == 0) { if (r == 0) {
...@@ -57,85 +53,86 @@ int Compare(StringPiece a, StringPiece b) { ...@@ -57,85 +53,86 @@ int Compare(StringPiece a, StringPiece b) {
return r; return r;
} }
bool operator==(StringPiece x, StringPiece y) { bool operator==(Piece x, Piece y) {
return ((x.len() == y.len()) && return ((x.len() == y.len()) &&
(x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0));
} }
bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } bool operator!=(Piece x, Piece y) { return !(x == y); }
bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; } bool operator<(Piece x, Piece y) { return Compare(x, y) < 0; }
bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; } bool operator>(Piece x, Piece y) { return Compare(x, y) > 0; }
bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; } bool operator<=(Piece x, Piece y) { return Compare(x, y) <= 0; }
bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; } bool operator>=(Piece x, Piece y) { return Compare(x, y) >= 0; }
bool HasPrefix(StringPiece s, StringPiece x) { bool HasPrefix(Piece s, Piece x) {
return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0));
} }
bool HasSuffix(StringPiece s, StringPiece x) { bool HasSuffix(Piece s, Piece x) {
return ((s.len() >= x.len()) && return ((s.len() >= x.len()) &&
(memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0));
} }
StringPiece SkipPrefix(StringPiece s, size_t n) { Piece SkipPrefix(Piece s, size_t n) {
if (n > s.len()) if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length"); throw std::invalid_argument("Skip distance larger than Piece length");
return StringPiece(s.data() + n, s.len() - n); return Piece(s.data() + n, s.len() - n);
} }
StringPiece SkipSuffix(StringPiece s, size_t n) { Piece SkipSuffix(Piece s, size_t n) {
if (n > s.len()) if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length"); throw std::invalid_argument("Skip distance larger than Piece length");
return StringPiece(s.data(), s.len() - n); return Piece(s.data(), s.len() - n);
} }
StringPiece TrimPrefix(StringPiece s, StringPiece x) { Piece TrimPrefix(Piece s, Piece x) {
return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s;
} }
StringPiece TrimSuffix(StringPiece s, StringPiece x) { Piece TrimSuffix(Piece s, Piece x) {
return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s;
} }
bool Contains(StringPiece s, StringPiece sub) { bool Contains(Piece s, Piece sub) {
return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end();
} }
size_t Index(StringPiece s, StringPiece sub) { size_t Index(Piece s, Piece sub) {
auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end());
return e != s.end() ? e - s.data() : StringPiece::npos; return e != s.end() ? e - s.data() : Piece::npos;
} }
size_t Find(StringPiece s, char c, size_t pos) { size_t Find(Piece s, char c, size_t pos) {
if (pos >= s.len()) { if (pos >= s.len()) {
return StringPiece::npos; return Piece::npos;
} }
const char* result = const char* result =
reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos)); reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos));
return result != nullptr ? result - s.data() : StringPiece::npos; return result != nullptr ? result - s.data() : Piece::npos;
} }
size_t RFind(StringPiece s, char c, size_t pos) { size_t RFind(Piece s, char c, size_t pos) {
if (s.len() == 0) return StringPiece::npos; if (s.len() == 0) return Piece::npos;
for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data();
p--) { p--) {
if (*p == c) { if (*p == c) {
return p - s.data(); return p - s.data();
} }
} }
return StringPiece::npos; return Piece::npos;
} }
StringPiece SubStr(StringPiece s, size_t pos, size_t n) { Piece SubStr(Piece s, size_t pos, size_t n) {
if (pos > s.len()) pos = s.len(); if (pos > s.len()) pos = s.len();
if (n > s.len() - pos) n = s.len() - pos; if (n > s.len() - pos) n = s.len() - pos;
return StringPiece(s.data() + pos, n); return Piece(s.data() + pos, n);
} }
std::ostream& operator<<(std::ostream& o, StringPiece piece) { std::ostream& operator<<(std::ostream& o, Piece piece) {
return o << piece.ToString(); return o << piece.ToString();
} }
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -20,33 +20,34 @@ ...@@ -20,33 +20,34 @@
#include <string> #include <string>
namespace paddle { namespace paddle {
namespace string {
// StringPiece points into a std::string object but doesn't own the // Piece points into a std::string object but doesn't own the
// string. It is for efficient access to strings. Like Go's string // string. It is for efficient access to strings. Like Go's string
// type. Not that StringPiece doesn't mutate the underlying string, // type. Not that Piece doesn't mutate the underlying string,
// so it is thread-safe given that the underlying string doesn't // so it is thread-safe given that the underlying string doesn't
// change. Because StringPiece contains a little data members, and // change. Because Piece contains a little data members, and
// its syntax is simple as it doesn't own/manage the string, it is // its syntax is simple as it doesn't own/manage the string, it is
// cheap to construct StringPieces and pass them around. // cheap to construct Pieces and pass them around.
class StringPiece { class Piece {
public: public:
static const size_t npos = static_cast<size_t>(-1); static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit singleton constructors so users can // We provide non-explicit singleton constructors so users can
// pass in a "const char*" or a "string" wherever a "StringPiece" // pass in a "const char*" or a "string" wherever a "Piece"
// is expected. These contructors ensure that if data_ is NULL, // is expected. These contructors ensure that if data_ is NULL,
// size_ is 0. // size_ is 0.
StringPiece(); Piece();
StringPiece(const char* d, size_t n); Piece(const char* d, size_t n);
StringPiece(const char* d); Piece(const char* d);
StringPiece(const std::string& s); Piece(const std::string& s);
const char* data() const { return data_; } const char* data() const { return data_; }
size_t len() const { return size_; } size_t len() const { return size_; }
char operator[](size_t n) const; char operator[](size_t n) const;
// StringPiece doesn't own the string, so both iterator and const // Piece doesn't own the string, so both iterator and const
// iterator are const char* indeed. // iterator are const char* indeed.
typedef const char* const_iterator; typedef const char* const_iterator;
typedef const char* iterator; typedef const char* iterator;
...@@ -63,43 +64,44 @@ private: ...@@ -63,43 +64,44 @@ private:
// Intentionally copyable // Intentionally copyable
}; };
int Compare(StringPiece a, StringPiece b); int Compare(Piece a, Piece b);
bool operator==(StringPiece x, StringPiece y); bool operator==(Piece x, Piece y);
bool operator!=(StringPiece x, StringPiece y); bool operator!=(Piece x, Piece y);
bool operator<(StringPiece x, StringPiece y); bool operator<(Piece x, Piece y);
bool operator>(StringPiece x, StringPiece y); bool operator>(Piece x, Piece y);
bool operator<=(StringPiece x, StringPiece y); bool operator<=(Piece x, Piece y);
bool operator>=(StringPiece x, StringPiece y); bool operator>=(Piece x, Piece y);
bool HasPrefix(StringPiece s, StringPiece prefix); bool HasPrefix(Piece s, Piece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix); bool HasSuffix(Piece s, Piece suffix);
StringPiece SkipPrefix(StringPiece s, size_t n); Piece SkipPrefix(Piece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n); Piece SkipSuffix(Piece s, size_t n);
// Skip the prefix (or suffix) if it matches with the string. // Skip the prefix (or suffix) if it matches with the string.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix); Piece TrimPrefix(Piece s, Piece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix); Piece TrimSuffix(Piece s, Piece suffix);
// Returns if s contains sub. Any s except for empty s contains an // Returns if s contains sub. Any s except for empty s contains an
// empty sub. // empty sub.
bool Contains(StringPiece s, StringPiece sub); bool Contains(Piece s, Piece sub);
// Return the first occurrence of sub in s, or npos. If both s and // Return the first occurrence of sub in s, or npos. If both s and
// sub is empty, it returns npos; otherwise, if only sub is empty, it // sub is empty, it returns npos; otherwise, if only sub is empty, it
// returns 0. // returns 0.
size_t Index(StringPiece s, StringPiece sub); size_t Index(Piece s, Piece sub);
// Return the first occurrence of c in s[pos:end], or npos. // Return the first occurrence of c in s[pos:end], or npos.
size_t Find(StringPiece s, char c, size_t pos); size_t Find(Piece s, char c, size_t pos);
// Search range is [0..pos] inclusive. If pos == npos, search everything. // Search range is [0..pos] inclusive. If pos == npos, search everything.
size_t RFind(StringPiece s, char c, size_t pos); size_t RFind(Piece s, char c, size_t pos);
StringPiece SubStr(StringPiece s, size_t pos, size_t n); Piece SubStr(Piece s, size_t pos, size_t n);
// allow StringPiece to be logged // allow Piece to be logged
std::ostream& operator<<(std::ostream& o, StringPiece piece); std::ostream& operator<<(std::ostream& o, Piece piece);
} // namespace string
} // namespace paddle } // namespace paddle
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
limitations under the License. limitations under the License.
*/ */
#include "paddle/strings/stringpiece.h" #include "paddle/string/piece.h"
#include <sstream> #include <sstream>
...@@ -22,42 +22,44 @@ ...@@ -22,42 +22,44 @@
TEST(StringPiece, Construct) { TEST(StringPiece, Construct) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(NULL, s.data()); EXPECT_EQ(NULL, s.data());
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); }
{ {
paddle::StringPiece s(NULL); EXPECT_THROW(paddle::string::Piece s(NULL, 10000U), std::invalid_argument);
}
{
paddle::string::Piece s(NULL);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
{ {
std::string a; std::string a;
EXPECT_EQ(0U, a.size()); EXPECT_EQ(0U, a.size());
paddle::StringPiece s(a); paddle::string::Piece s(a);
EXPECT_EQ(0U, s.len()); EXPECT_EQ(0U, s.len());
} }
} }
TEST(StringPiece, CopyAndAssign) { TEST(StringPiece, CopyAndAssign) {
paddle::StringPiece empty; paddle::string::Piece empty;
EXPECT_EQ(0U, empty.len()); EXPECT_EQ(0U, empty.len());
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b = a; paddle::string::Piece b = a;
EXPECT_EQ(b.len(), strlen("hello")); EXPECT_EQ(b.len(), strlen("hello"));
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
std::string storage("hello"); std::string storage("hello");
paddle::StringPiece c(storage); paddle::string::Piece c(storage);
EXPECT_EQ(a, c); EXPECT_EQ(a, c);
EXPECT_NE(a.data(), c.data()); EXPECT_NE(a.data(), c.data());
} }
TEST(StringPiece, Compare) { TEST(StringPiece, Compare) {
{ {
paddle::StringPiece a("hello"); paddle::string::Piece a("hello");
paddle::StringPiece b("world"); paddle::string::Piece b("world");
EXPECT_TRUE(a != b); EXPECT_TRUE(a != b);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a < b); EXPECT_TRUE(a < b);
...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) { ...@@ -68,7 +70,7 @@ TEST(StringPiece, Compare) {
EXPECT_GT(Compare(b, a), 0); EXPECT_GT(Compare(b, a), 0);
} }
{ {
paddle::StringPiece a, b; paddle::string::Piece a, b;
EXPECT_TRUE(a == b); EXPECT_TRUE(a == b);
EXPECT_FALSE(a != b); EXPECT_FALSE(a != b);
EXPECT_FALSE(a < b); EXPECT_FALSE(a < b);
...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) { ...@@ -82,31 +84,31 @@ TEST(StringPiece, Compare) {
TEST(StringPiece, ToString) { TEST(StringPiece, ToString) {
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s(NULL); paddle::string::Piece s(NULL);
EXPECT_EQ(std::string(""), s.ToString()); EXPECT_EQ(std::string(""), s.ToString());
} }
{ {
paddle::StringPiece s("hello"); paddle::string::Piece s("hello");
EXPECT_EQ(std::string("hello"), s.ToString()); EXPECT_EQ(std::string("hello"), s.ToString());
} }
} }
TEST(StringPiece, HasPrefixSuffix) { TEST(StringPiece, HasPrefixSuffix) {
using paddle::HasPrefix; using paddle::string::HasPrefix;
using paddle::HasSuffix; using paddle::string::HasSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(HasPrefix(s, "something")); EXPECT_FALSE(HasPrefix(s, "something"));
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_FALSE(HasSuffix(s, "something")); EXPECT_FALSE(HasSuffix(s, "something"));
EXPECT_TRUE(HasSuffix(s, "")); EXPECT_TRUE(HasSuffix(s, ""));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_TRUE(HasPrefix(s, "a")); EXPECT_TRUE(HasPrefix(s, "a"));
EXPECT_TRUE(HasPrefix(s, "ap")); EXPECT_TRUE(HasPrefix(s, "ap"));
...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) { ...@@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) {
} }
TEST(StringPiece, SkipPrefixSuffix) { TEST(StringPiece, SkipPrefixSuffix) {
using paddle::SkipPrefix; using paddle::string::SkipPrefix;
using paddle::SkipSuffix; using paddle::string::SkipSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SkipPrefix(s, 0)); EXPECT_EQ("", SkipPrefix(s, 0));
EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument); EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument);
...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) {
EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument); EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument);
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", SkipPrefix(s, 0)); EXPECT_EQ("app", SkipPrefix(s, 0));
EXPECT_EQ("pp", SkipPrefix(s, 1)); EXPECT_EQ("pp", SkipPrefix(s, 1));
EXPECT_EQ("p", SkipPrefix(s, 2)); EXPECT_EQ("p", SkipPrefix(s, 2));
...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) { ...@@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) {
} }
TEST(StringPiece, TrimPrefixSuffix) { TEST(StringPiece, TrimPrefixSuffix) {
using paddle::TrimPrefix; using paddle::string::TrimPrefix;
using paddle::TrimSuffix; using paddle::string::TrimSuffix;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", TrimPrefix(s, "")); EXPECT_EQ("", TrimPrefix(s, ""));
EXPECT_EQ("", TrimPrefix(s, "something")); EXPECT_EQ("", TrimPrefix(s, "something"));
...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) {
EXPECT_EQ("", TrimSuffix(s, "something")); EXPECT_EQ("", TrimSuffix(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("app", TrimPrefix(s, "")); EXPECT_EQ("app", TrimPrefix(s, ""));
EXPECT_EQ("pp", TrimPrefix(s, "a")); EXPECT_EQ("pp", TrimPrefix(s, "a"));
EXPECT_EQ("p", TrimPrefix(s, "ap")); EXPECT_EQ("p", TrimPrefix(s, "ap"));
...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) { ...@@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) {
} }
TEST(StringPiece, Contains) { TEST(StringPiece, Contains) {
using paddle::Contains; using paddle::string::Contains;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_FALSE(Contains(s, "")); EXPECT_FALSE(Contains(s, ""));
EXPECT_FALSE(Contains(s, "something")); EXPECT_FALSE(Contains(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_TRUE(Contains(s, "")); EXPECT_TRUE(Contains(s, ""));
EXPECT_TRUE(Contains(s, "a")); EXPECT_TRUE(Contains(s, "a"));
EXPECT_TRUE(Contains(s, "p")); EXPECT_TRUE(Contains(s, "p"));
...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) { ...@@ -193,15 +195,15 @@ TEST(StringPiece, Contains) {
} }
TEST(StringPiece, Index) { TEST(StringPiece, Index) {
using paddle::Index; using paddle::string::Index;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Index(s, "")); EXPECT_EQ(npos, Index(s, ""));
EXPECT_EQ(npos, Index(s, "something")); EXPECT_EQ(npos, Index(s, "something"));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Index(s, "")); EXPECT_EQ(0U, Index(s, ""));
EXPECT_EQ(0U, Index(s, "a")); EXPECT_EQ(0U, Index(s, "a"));
EXPECT_EQ(1U, Index(s, "p")); EXPECT_EQ(1U, Index(s, "p"));
...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) { ...@@ -213,14 +215,14 @@ TEST(StringPiece, Index) {
} }
TEST(StringPiece, Find) { TEST(StringPiece, Find) {
using paddle::Find; using paddle::string::Find;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, Find(s, 'a', 0U)); EXPECT_EQ(npos, Find(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(0U, Find(s, 'a', 0U)); EXPECT_EQ(0U, Find(s, 'a', 0U));
EXPECT_EQ(1U, Find(s, 'p', 0U)); EXPECT_EQ(1U, Find(s, 'p', 0U));
EXPECT_EQ(1U, Find(s, 'p', 1U)); EXPECT_EQ(1U, Find(s, 'p', 1U));
...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) { ...@@ -230,14 +232,14 @@ TEST(StringPiece, Find) {
} }
TEST(StringPiece, RFind) { TEST(StringPiece, RFind) {
using paddle::RFind; using paddle::string::RFind;
auto npos = paddle::StringPiece::npos; auto npos = paddle::string::Piece::npos;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ(npos, RFind(s, 'a', 0U)); EXPECT_EQ(npos, RFind(s, 'a', 0U));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ(2U, RFind(s, 'p', 2U)); EXPECT_EQ(2U, RFind(s, 'p', 2U));
EXPECT_EQ(0U, RFind(s, 'a', 2U)); EXPECT_EQ(0U, RFind(s, 'a', 2U));
EXPECT_EQ(1U, RFind(s, 'p', 1U)); EXPECT_EQ(1U, RFind(s, 'p', 1U));
...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) { ...@@ -247,15 +249,15 @@ TEST(StringPiece, RFind) {
} }
TEST(StringPiece, SubStr) { TEST(StringPiece, SubStr) {
using paddle::SubStr; using paddle::string::SubStr;
{ {
paddle::StringPiece s; paddle::string::Piece s;
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 0, 1)); EXPECT_EQ("", SubStr(s, 0, 1));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
} }
{ {
paddle::StringPiece s("app"); paddle::string::Piece s("app");
EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 1, 0));
EXPECT_EQ("", SubStr(s, 2, 0)); EXPECT_EQ("", SubStr(s, 2, 0));
...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) { ...@@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) {
} }
TEST(StringPiece, StreamOutput) { TEST(StringPiece, StreamOutput) {
using paddle::StringPiece; using paddle::string::Piece;
std::stringstream o; std::stringstream o;
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("", o.str()); EXPECT_EQ("", o.str());
o << StringPiece("hello"); o << paddle::string::Piece("hello");
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
o << StringPiece(); o << paddle::string::Piece();
EXPECT_EQ("hello", o.str()); EXPECT_EQ("hello", o.str());
} }
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Compared with std::stringstream, there are primary purpose of
// string::Printf:
//
// 1. Type-safe printing, with why and how explained in
// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999.
// Implementation includes
//
// https://github.com/c42f/tinyformat
// boost::format
// std::stringstream
//
// std::stringstream is not convenient enough in many cases. For example:
//
// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n";
//
// boost::format is the most convenient one. We can have
//
// std::cout << format("%2% %1%") % 36 % 77;
//
// or
//
// format fmter("%2% %1%");
// fmter % 36; fmter % 77;
// std::cout << fmter.c_str();
//
// But the overloading of % might be overkilling and it would be
// more efficient if it can write to std::cout directly.
//
// tinyformat has an interface compatible with the C-printf style,
// and it can writes to a stream or returns a std::string:
//
// std::cout << tfm::printf(
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// or
//
// tfm::format(std::cout,
// "%s, %s %d, %.2d:%.2d\n",
// weekday, month, day, hour, min);
//
// 2. High-performance -- most printed strings are not too long and
// doens't need dynamic memory allocation. Many StringPrintf
// implementations doesn't enforce type-safe, but are
// high-performance, including
//
// https://developers.google.com/optimization/reference/base/stringprintf/
// https://github.com/adobe/chromium/blob/master/base/stringprintf.h
// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h
//
// According to
// https://github.com/c42f/tinyformat#compile-time-and-code-bloat,
// boost::format runs too slow and results in large executable binary
// files. So here we port tinyformat.
#pragma once
#include <iostream>
#include <sstream>
#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat
namespace paddle {
namespace string {
template <typename... Args>
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
}
template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
}
template <typename... Args>
void Printf(const char* fmt, const Args&... args) {
Fprintf(std::cout, fmt, args...);
}
} // namespace string
} // namespace paddle
#include "paddle/string/printf.h"
#include <string>
#include "gtest/gtest.h"
TEST(StringPrintf, StringPrintf) {
std::string weekday = "Wednesday";
const char* month = "July";
size_t day = 27;
long hour = 14;
int min = 44;
EXPECT_EQ(std::string("Wednesday, July 27, 14:44"),
paddle::string::Sprintf(
"%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min));
}
此差异已折叠。
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
if(WITH_TESTING) if(WITH_TESTING)
add_library(paddle_test_main STATIC TestMain.cpp) add_library(paddle_test_main STATIC TestMain.cpp)
add_dependencies(paddle_test_main gen_proto_cpp) add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies})
add_library(paddle_test_util STATIC TestUtil.cpp) add_library(paddle_test_util STATIC TestUtil.cpp)
add_dependencies(paddle_test_util gen_proto_cpp) add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies})
endif() endif()
...@@ -41,7 +41,8 @@ add_style_check_target(paddle_trainer_lib ...@@ -41,7 +41,8 @@ add_style_check_target(paddle_trainer_lib
add_style_check_target(paddle_trainer_lib add_style_check_target(paddle_trainer_lib
${TRAINER_HEADERS}) ${TRAINER_HEADERS})
add_dependencies(paddle_trainer_lib add_dependencies(paddle_trainer_lib
gen_proto_cpp) paddle_proto
${external_project_dependencies})
macro(add_paddle_exe TARGET_NAME) macro(add_paddle_exe TARGET_NAME)
add_executable(${TARGET_NAME} ${ARGN}) add_executable(${TARGET_NAME} ${ARGN})
...@@ -72,6 +73,6 @@ endif() ...@@ -72,6 +73,6 @@ endif()
if(WITH_GOLANG) if(WITH_GOLANG)
add_dependencies(paddle_trainer_lib paddle_pserver_cclient) add_dependencies(paddle_trainer_lib paddle_pserver_cclient)
target_link_libraries(paddle_trainer ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a) target_link_libraries(paddle_trainer paddle_pserver_cclient)
target_link_libraries(paddle_trainer_lib ${CMAKE_BINARY_DIR}/go/pserver/cclient/libpaddle_pserver_cclient.a) target_link_libraries(paddle_trainer_lib paddle_pserver_cclient)
endif(WITH_GOLANG) endif(WITH_GOLANG)
...@@ -17,7 +17,7 @@ add_library(paddle_utils STATIC ...@@ -17,7 +17,7 @@ add_library(paddle_utils STATIC
add_style_check_target(paddle_utils ${UTIL_HEADERS}) add_style_check_target(paddle_utils ${UTIL_HEADERS})
add_style_check_target(paddle_utils ${UTIL_SOURCES} add_style_check_target(paddle_utils ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES}) ${UTIL_ARCH_SOURCES})
add_dependencies(paddle_utils gen_proto_cpp) add_dependencies(paddle_utils paddle_proto ${external_project_dependencies})
if(WITH_TESTING) if(WITH_TESTING)
add_subdirectory(tests) add_subdirectory(tests)
endif() endif()
...@@ -19,7 +19,21 @@ limitations under the License. */ ...@@ -19,7 +19,21 @@ limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/platform/must_check.h"
/**
* __must_check macro. It make the function's return value must be used,
* otherwise it will raise a compile warning. And also Paddle treat all compile
* warnings as errors.
*/
#ifdef __GNUC__
#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400
#define __must_check __attribute__((warn_unused_result))
#else
#define __must_check
#endif
#else
#define __must_check
#endif
namespace paddle { namespace paddle {
......
set(proto_filenames file(GLOB proto_filenames . *.proto)
DataConfig.proto include_directories(${CMAKE_CURRENT_BINARY_DIR})
DataFormat.proto proto_library(paddle_proto SRCS ${proto_filenames})
ModelConfig.proto
ParameterConfig.proto
ParameterService.proto
TrainerConfig.proto
OptimizerConfig.proto
ParameterServerConfig.proto)
set(PROTO_GEN) set(PROTO_GEN)
set(PROTO_GEN_PY) set(PROTO_GEN_PY)
foreach(filename ${proto_filenames}) foreach(filename ${proto_filenames})
get_filename_component(base_filename ${filename} NAME_WE) get_filename_component(ABS_FIL ${filename} ABSOLUTE)
set(CUR_PROTO_GEN get_filename_component(FIL_WE ${filename} NAME_WE)
${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.h
${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.pb.cc)
set(PROTO_GEN
${PROTO_GEN}
${CUR_PROTO_GEN})
add_custom_command(OUTPUT ${CUR_PROTO_GEN}
COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE}
--cpp_out ${CMAKE_CURRENT_BINARY_DIR}
--proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename}
DEPENDS ${filename} ${external_project_dependencies})
set(CUR_PROTO_GEN_PY set(CUR_PROTO_GEN_PY
${PROJ_ROOT}/paddle/python/paddle/proto/${base_filename}_pb2.py) ${PROJ_ROOT}/paddle/python/paddle/proto/${FIL_WE}_pb2.py)
set(PROTO_GEN_PY set(PROTO_GEN_PY
${CUR_PROTO_GEN_PY} ${CUR_PROTO_GEN_PY}
${PROTO_GEN_PY}) ${PROTO_GEN_PY})
add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY} add_custom_command(OUTPUT ${CUR_PROTO_GEN_PY}
COMMAND env ${py_env} ${PROTOBUF_PROTOC_EXECUTABLE} --python_out ${PROJ_ROOT}/python/paddle/proto COMMAND ${PROTOBUF_PROTOC_EXECUTABLE}
--proto_path ${PROJ_ROOT}/proto ${PROJ_ROOT}/proto/${filename} ARGS "--python_out=${PROJ_ROOT}/python/paddle/proto"
DEPENDS ${filename} ${external_project_dependencies}) "-I" ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${external_project_dependencies})
endforeach() endforeach()
add_custom_target(gen_proto_cpp ALL DEPENDS ${PROTO_GEN})
add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY}) add_custom_target(gen_proto_py ALL DEPENDS ${PROTO_GEN_PY})
add_library(paddle_proto STATIC ${PROTO_GEN})
target_include_directories(paddle_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
...@@ -7,10 +7,21 @@ file(GLOB UTILS_PY_FILES . ./paddle/utils/*.py) ...@@ -7,10 +7,21 @@ file(GLOB UTILS_PY_FILES . ./paddle/utils/*.py)
file(GLOB_RECURSE V2_PY_FILES ./paddle/v2/ *.py) file(GLOB_RECURSE V2_PY_FILES ./paddle/v2/ *.py)
set(PY_FILES paddle/__init__.py set(PY_FILES paddle/__init__.py
${TRAINER_PY_FILES} ${TRAINER_PY_FILES}
${HELPERS_PY_FILES} ${HELPERS_PY_FILES}
${UTILS_PY_FILES} ${UTILS_PY_FILES}
${V2_PY_FILES}) ${V2_PY_FILES})
add_custom_target(copy_paddle_master)
SET(COPY_PADDLE_MASTER "")
if(WITH_GOLANG)
SET(COPY_PADDLE_MASTER "copy_paddle_master")
add_custom_command(TARGET ${COPY_PADDLE_MASTER}
COMMAND cp ${paddle_master_LIB_PATH} ${PROJ_ROOT}/python/paddle/v2/master/
)
add_dependencies(copy_paddle_master paddle_master)
endif(WITH_GOLANG)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
${CMAKE_CURRENT_BINARY_DIR}/setup.py) ${CMAKE_CURRENT_BINARY_DIR}/setup.py)
...@@ -18,7 +29,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in ...@@ -18,7 +29,7 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/setup.py.in
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel COMMAND env ${py_env} ${PYTHON_EXECUTABLE} setup.py bdist_wheel
COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp COMMAND ${CMAKE_COMMAND} -E touch ${OUTPUT_DIR}/.timestamp
DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies}) DEPENDS gen_proto_py ${PY_FILES} ${external_project_dependencies} ${COPY_PADDLE_MASTER})
add_custom_target(paddle_python ALL DEPENDS add_custom_target(paddle_python ALL DEPENDS
${OUTPUT_DIR}/.timestamp) ${OUTPUT_DIR}/.timestamp)
......
...@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase): ...@@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase): class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs): def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__( super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs) name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert( config_assert(
len(self.inputs) == 1, len(self.inputs) == 1,
'TransLayer must have one and only one input') 'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0) input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length row_conv_conf.context_length = context_length
......
...@@ -1149,10 +1149,10 @@ def pooling_layer(input, ...@@ -1149,10 +1149,10 @@ def pooling_layer(input,
@layer_support(DROPOUT) @layer_support(DROPOUT)
def lstmemory(input, def lstmemory(input,
name=None, name=None,
size=None,
reverse=False, reverse=False,
act=None, act=None,
gate_act=None, gate_act=None,
size=None,
state_act=None, state_act=None,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
...@@ -1194,6 +1194,8 @@ def lstmemory(input, ...@@ -1194,6 +1194,8 @@ def lstmemory(input,
:param name: The lstmemory layer name. :param name: The lstmemory layer name.
:type name: basestring :type name: basestring
:param size: DEPRECATED. size of the lstm cell
:type size: int
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param reverse: is sequence process reversed or not. :param reverse: is sequence process reversed or not.
...@@ -1220,15 +1222,15 @@ def lstmemory(input, ...@@ -1220,15 +1222,15 @@ def lstmemory(input,
assert state_act.support_hppl assert state_act.support_hppl
assert act.support_hppl assert act.support_hppl
assert input.size is not None and input.size % 4 == 0 assert input.size is not None and input.size % 4 == 0
if size is not None: if size is not None:
if input.size / 4 == size: if input.size / 4 == size:
plog = logger.warning plog = logger.warning
else: else:
plog = logger.fatal plog = logger.fatal
plog("size of lstmemory layer: %s is automatically set to "
plog("NOTE: The lstmemory layer[%s]'s size is set by previous input " "size of input layer / 4. The parameter size passing to "
"layer. The lstm size should be equal with input layer size/4. The" "this layer is ignored." % (name))
" size which is set explicitly will be ignored." % name)
Layer( Layer(
name=name, name=name,
...@@ -1255,11 +1257,11 @@ def lstmemory(input, ...@@ -1255,11 +1257,11 @@ def lstmemory(input,
@wrap_name_default("gru") @wrap_name_default("gru")
@layer_support(DROPOUT) @layer_support(DROPOUT)
def grumemory(input, def grumemory(input,
size=None,
name=None, name=None,
reverse=False, reverse=False,
act=None, act=None,
gate_act=None, gate_act=None,
size=None,
bias_attr=None, bias_attr=None,
param_attr=None, param_attr=None,
layer_attr=None): layer_attr=None):
...@@ -1318,6 +1320,8 @@ def grumemory(input, ...@@ -1318,6 +1320,8 @@ def grumemory(input,
:type name: None|basestring :type name: None|basestring
:param input: input layer. :param input: input layer.
:type input: LayerOutput. :type input: LayerOutput.
:param size: DEPRECATED. size of the gru cell
:type size: int
:param reverse: Whether sequence process is reversed or not. :param reverse: Whether sequence process is reversed or not.
:type reverse: bool :type reverse: bool
:param act: activation type, TanhActivation by default. This activation :param act: activation type, TanhActivation by default. This activation
...@@ -1334,9 +1338,6 @@ def grumemory(input, ...@@ -1334,9 +1338,6 @@ def grumemory(input,
:type param_attr: ParameterAttribute|None|False :type param_attr: ParameterAttribute|None|False
:param layer_attr: Extra Layer attribute :param layer_attr: Extra Layer attribute
:type layer_attr: ExtraLayerAttribute|None :type layer_attr: ExtraLayerAttribute|None
:param size: Stub parameter of size, but actually not used. If set this size
will get a warning.
:type size: None
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
...@@ -1348,9 +1349,9 @@ def grumemory(input, ...@@ -1348,9 +1349,9 @@ def grumemory(input,
plog = logger.warning plog = logger.warning
else: else:
plog = logger.fatal plog = logger.fatal
plog("NOTE: the gru memory layer's size is set by previous input layer," plog("size of grumemory layer: %s is automatically set to "
" and should be input size / 3. Set size explicitly will be " "size of input layer / 3. The parameter size passing to this "
"ignored.") "layer is ignored." % (name))
Layer( Layer(
name=name, name=name,
...@@ -2524,8 +2525,8 @@ def img_cmrnorm_layer(input, ...@@ -2524,8 +2525,8 @@ def img_cmrnorm_layer(input,
@wrap_bias_attr_default() @wrap_bias_attr_default()
@wrap_param_attr_default(default_factory=lambda _: ParamAttr(initial_mean=1.0, @wrap_param_attr_default(
initial_std=0.)) default_factory=lambda _: ParamAttr(initial_mean=1.0, initial_std=0.))
@wrap_act_default(act=ReluActivation()) @wrap_act_default(act=ReluActivation())
@wrap_name_default("batch_norm") @wrap_name_default("batch_norm")
@layer_support(DROPOUT) @layer_support(DROPOUT)
...@@ -3013,25 +3014,25 @@ def lstm_step_layer(input, ...@@ -3013,25 +3014,25 @@ def lstm_step_layer(input,
bias_attr=None, bias_attr=None,
layer_attr=None): layer_attr=None):
""" """
LSTM Step Layer. It used in recurrent_group. The lstm equations are shown LSTM Step Layer. This function is used only in recurrent_group.
as follow. The lstm equations are shown as follows.
.. math:: .. math::
i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c)
o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
h_t & = o_t tanh(c_t) h_t & = o_t tanh(c_t)
The input of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use The input of lstm step is :math:`Wx_t + Wh_{t-1}`, and user should use
:code:`mixed_layer` and :code:`full_matrix_projection` to calculate these :code:`mixed_layer` and :code:`full_matrix_projection` to calculate these
input vector. input vectors.
The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do The state of lstm step is :math:`c_{t-1}`. And lstm step layer will do
...@@ -3042,14 +3043,14 @@ def lstm_step_layer(input, ...@@ -3042,14 +3043,14 @@ def lstm_step_layer(input,
... ...
This layer contains two outputs. Default output is :math:`h_t`. The other This layer has two outputs. Default output is :math:`h_t`. The other
output is :math:`o_t`, which name is 'state' and can use output is :math:`o_t`, whose name is 'state' and can use
:code:`get_output_layer` to extract this output. :code:`get_output_layer` to extract this output.
:param name: Layer's name. :param name: Layer's name.
:type name: basestring :type name: basestring
:param size: Layer's size. NOTE: lstm layer's size, should be equal as :param size: Layer's size. NOTE: lstm layer's size, should be equal to
:code:`input.size/4`, and should be equal as :code:`input.size/4`, and should be equal to
:code:`state.size`. :code:`state.size`.
:type size: int :type size: int
:param input: input layer. :math:`Wx_t + Wh_{t-1}` :param input: input layer. :math:`Wx_t + Wh_{t-1}`
......
...@@ -614,6 +614,7 @@ def simple_lstm(input, ...@@ -614,6 +614,7 @@ def simple_lstm(input,
@wrap_name_default('lstm_unit') @wrap_name_default('lstm_unit')
def lstmemory_unit(input, def lstmemory_unit(input,
memory_boot=None,
name=None, name=None,
size=None, size=None,
param_attr=None, param_attr=None,
...@@ -626,9 +627,9 @@ def lstmemory_unit(input, ...@@ -626,9 +627,9 @@ def lstmemory_unit(input,
lstm_layer_attr=None, lstm_layer_attr=None,
get_output_layer_attr=None): get_output_layer_attr=None):
""" """
Define calculations that a LSTM unit performs in a single time step. Define calculations that a LSTM unit performs during a single time step.
This function itself is not a recurrent layer, so that it can not be This function itself is not a recurrent layer, so it can not be
directly applied to sequence input. This function is always used in directly used to process sequence inputs. This function is always used in
recurrent_group (see layers.py for more details) to implement attention recurrent_group (see layers.py for more details) to implement attention
mechanism. mechanism.
...@@ -638,13 +639,13 @@ def lstmemory_unit(input, ...@@ -638,13 +639,13 @@ def lstmemory_unit(input,
.. math:: .. math::
i_t & = \\sigma(W_{xi}x_{t} + W_{hi}h_{t-1} + W_{ci}c_{t-1} + b_i) i_t & = \\sigma(W_{x_i}x_{t} + W_{h_i}h_{t-1} + W_{c_i}c_{t-1} + b_i)
f_t & = \\sigma(W_{xf}x_{t} + W_{hf}h_{t-1} + W_{cf}c_{t-1} + b_f) f_t & = \\sigma(W_{x_f}x_{t} + W_{h_f}h_{t-1} + W_{c_f}c_{t-1} + b_f)
c_t & = f_tc_{t-1} + i_t tanh (W_{xc}x_t+W_{hc}h_{t-1} + b_c) c_t & = f_tc_{t-1} + i_t tanh (W_{x_c}x_t+W_{h_c}h_{t-1} + b_c)
o_t & = \\sigma(W_{xo}x_{t} + W_{ho}h_{t-1} + W_{co}c_t + b_o) o_t & = \\sigma(W_{x_o}x_{t} + W_{h_o}h_{t-1} + W_{c_o}c_t + b_o)
h_t & = o_t tanh(c_t) h_t & = o_t tanh(c_t)
...@@ -661,6 +662,8 @@ def lstmemory_unit(input, ...@@ -661,6 +662,8 @@ def lstmemory_unit(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: lstmemory unit name. :param name: lstmemory unit name.
:type name: basestring :type name: basestring
:param size: lstmemory unit size. :param size: lstmemory unit size.
...@@ -692,7 +695,8 @@ def lstmemory_unit(input, ...@@ -692,7 +695,8 @@ def lstmemory_unit(input,
assert input.size % 4 == 0 assert input.size % 4 == 0
size = input.size / 4 size = input.size / 4
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size)
state_mem = memory(name="%s_state" % name, size=size) state_mem = memory(
name="%s_state" % name, size=size, boot_layer=memory_boot)
with mixed_layer( with mixed_layer(
name="%s_input_recurrent" % name, name="%s_input_recurrent" % name,
...@@ -726,6 +730,7 @@ def lstmemory_unit(input, ...@@ -726,6 +730,7 @@ def lstmemory_unit(input,
def lstmemory_group(input, def lstmemory_group(input,
size=None, size=None,
name=None, name=None,
memory_boot=None,
reverse=False, reverse=False,
param_attr=None, param_attr=None,
act=None, act=None,
...@@ -737,7 +742,7 @@ def lstmemory_group(input, ...@@ -737,7 +742,7 @@ def lstmemory_group(input,
lstm_layer_attr=None, lstm_layer_attr=None,
get_output_layer_attr=None): get_output_layer_attr=None):
""" """
lstm_group is a recurrent layer group version of Long Short Term Memory. It lstm_group is a recurrent_group version of Long Short Term Memory. It
does exactly the same calculation as the lstmemory layer (see lstmemory in does exactly the same calculation as the lstmemory layer (see lstmemory in
layers.py for the maths) does. A promising benefit is that LSTM memory layers.py for the maths) does. A promising benefit is that LSTM memory
cell states, or hidden states in every time step are accessible to the cell states, or hidden states in every time step are accessible to the
...@@ -748,8 +753,8 @@ def lstmemory_group(input, ...@@ -748,8 +753,8 @@ def lstmemory_group(input,
NOTE: In PaddlePaddle's implementation, the following input-to-hidden NOTE: In PaddlePaddle's implementation, the following input-to-hidden
multiplications: multiplications:
:math:`W_{xi}x_{t}` , :math:`W_{xf}x_{t}`, :math:`W_{x_i}x_{t}` , :math:`W_{x_f}x_{t}`,
:math:`W_{xc}x_t`, :math:`W_{xo}x_{t}` are not done in lstmemory_unit to :math:`W_{x_c}x_t`, :math:`W_{x_o}x_{t}` are not done in lstmemory_unit to
speed up the calculations. Consequently, an additional mixed_layer with speed up the calculations. Consequently, an additional mixed_layer with
full_matrix_projection must be included before lstmemory_unit is called. full_matrix_projection must be included before lstmemory_unit is called.
...@@ -765,10 +770,12 @@ def lstmemory_group(input, ...@@ -765,10 +770,12 @@ def lstmemory_group(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param name: lstmemory group name.
:type name: basestring
:param size: lstmemory group size. :param size: lstmemory group size.
:type size: int :type size: int
:param name: name of the lstmemory group.
:type name: basestring
:param memory_boot: the initialization state of LSTM cell.
:type memory_boot: LayerOutput | None
:param reverse: is lstm reversed :param reverse: is lstm reversed
:type reverse: bool :type reverse: bool
:param param_attr: Parameter config, None if use default. :param param_attr: Parameter config, None if use default.
...@@ -798,6 +805,7 @@ def lstmemory_group(input, ...@@ -798,6 +805,7 @@ def lstmemory_group(input,
def __lstm_step__(ipt): def __lstm_step__(ipt):
return lstmemory_unit( return lstmemory_unit(
input=ipt, input=ipt,
memory_boot=memory_boot,
name=name, name=name,
size=size, size=size,
mixed_bias_attr=mixed_bias_attr, mixed_bias_attr=mixed_bias_attr,
...@@ -819,6 +827,7 @@ def lstmemory_group(input, ...@@ -819,6 +827,7 @@ def lstmemory_group(input,
@wrap_name_default('gru_unit') @wrap_name_default('gru_unit')
def gru_unit(input, def gru_unit(input,
memory_boot=None,
size=None, size=None,
name=None, name=None,
gru_bias_attr=None, gru_bias_attr=None,
...@@ -829,8 +838,8 @@ def gru_unit(input, ...@@ -829,8 +838,8 @@ def gru_unit(input,
naive=False): naive=False):
""" """
Define calculations that a gated recurrent unit performs in a single time Define calculations that a gated recurrent unit performs in a single time
step. This function itself is not a recurrent layer, so that it can not be step. This function itself is not a recurrent layer, so it can not be
directly applied to sequence input. This function is almost always used in directly used to process sequence inputs. This function is always used in
the recurrent_group (see layers.py for more details) to implement attention the recurrent_group (see layers.py for more details) to implement attention
mechanism. mechanism.
...@@ -838,6 +847,8 @@ def gru_unit(input, ...@@ -838,6 +847,8 @@ def gru_unit(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: name of the gru group. :param name: name of the gru group.
:type name: basestring :type name: basestring
:param size: hidden size of the gru. :param size: hidden size of the gru.
...@@ -856,7 +867,7 @@ def gru_unit(input, ...@@ -856,7 +867,7 @@ def gru_unit(input,
if size is None: if size is None:
size = input.size / 3 size = input.size / 3
out_mem = memory(name=name, size=size) out_mem = memory(name=name, size=size, boot_layer=memory_boot)
if naive: if naive:
__step__ = gru_step_naive_layer __step__ = gru_step_naive_layer
...@@ -878,6 +889,7 @@ def gru_unit(input, ...@@ -878,6 +889,7 @@ def gru_unit(input,
@wrap_name_default('gru_group') @wrap_name_default('gru_group')
def gru_group(input, def gru_group(input,
memory_boot=None,
size=None, size=None,
name=None, name=None,
reverse=False, reverse=False,
...@@ -888,7 +900,7 @@ def gru_group(input, ...@@ -888,7 +900,7 @@ def gru_group(input,
gru_layer_attr=None, gru_layer_attr=None,
naive=False): naive=False):
""" """
gru_group is a recurrent layer group version of Gated Recurrent Unit. It gru_group is a recurrent_group version of Gated Recurrent Unit. It
does exactly the same calculation as the grumemory layer does. A promising does exactly the same calculation as the grumemory layer does. A promising
benefit is that gru hidden states are accessible to the user. This is benefit is that gru hidden states are accessible to the user. This is
especially useful in attention model. If you do not need to access especially useful in attention model. If you do not need to access
...@@ -908,6 +920,8 @@ def gru_group(input, ...@@ -908,6 +920,8 @@ def gru_group(input,
:param input: input layer name. :param input: input layer name.
:type input: LayerOutput :type input: LayerOutput
:param memory_boot: the initialization state of the LSTM cell.
:type memory_boot: LayerOutput | None
:param name: name of the gru group. :param name: name of the gru group.
:type name: basestring :type name: basestring
:param size: hidden size of the gru. :param size: hidden size of the gru.
...@@ -929,6 +943,7 @@ def gru_group(input, ...@@ -929,6 +943,7 @@ def gru_group(input,
def __gru_step__(ipt): def __gru_step__(ipt):
return gru_unit( return gru_unit(
input=ipt, input=ipt,
memory_boot=memory_boot,
name=name, name=name,
size=size, size=size,
gru_bias_attr=gru_bias_attr, gru_bias_attr=gru_bias_attr,
...@@ -1083,7 +1098,6 @@ def simple_gru2(input, ...@@ -1083,7 +1098,6 @@ def simple_gru2(input,
return grumemory( return grumemory(
name=name, name=name,
size=size,
input=m, input=m,
reverse=reverse, reverse=reverse,
bias_attr=gru_bias_attr, bias_attr=gru_bias_attr,
......
...@@ -7,7 +7,7 @@ layers { ...@@ -7,7 +7,7 @@ layers {
} }
layers { layers {
name: "__row_conv_layer_0__" name: "__row_conv_layer_0__"
type: "maxout" type: "row_conv"
size: 2560 size: 2560
active_type: "relu" active_type: "relu"
inputs { inputs {
......
...@@ -56,6 +56,7 @@ __all__ = [ ...@@ -56,6 +56,7 @@ __all__ = [
'plot', 'plot',
'evaluator', 'evaluator',
'image', 'image',
'master',
] ]
......
...@@ -25,8 +25,9 @@ import uci_housing ...@@ -25,8 +25,9 @@ import uci_housing
import sentiment import sentiment
import wmt14 import wmt14
import mq2007 import mq2007
import flowers
__all__ = [ __all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007' 'uci_housing', 'wmt14', 'mq2007', 'flowers'
] ]
此差异已折叠。
...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): ...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def test_train(self): def test_train(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train()) paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020) self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_test(self): def test_test(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test()) paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149) self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_valid(self): def test_valid(self):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
""" """
UCI Housing dataset. UCI Housing dataset.
This module will paddle.v2.dataset.common.download dataset from This module will download dataset from
https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and https://archive.ics.uci.edu/ml/machine-learning-databases/housing/ and
parse training set and test set into paddle reader creators. parse training set and test set into paddle reader creators.
""" """
......
...@@ -262,7 +262,12 @@ def left_right_flip(im): ...@@ -262,7 +262,12 @@ def left_right_flip(im):
return im[:, ::-1, :] return im[:, ::-1, :]
def simple_transform(im, resize_size, crop_size, is_train, is_color=True): def simple_transform(im,
resize_size,
crop_size,
is_train,
is_color=True,
mean=None):
""" """
Simply data argumentation for training. These operations include Simply data argumentation for training. These operations include
resizing, croping and flipping. resizing, croping and flipping.
...@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True): ...@@ -288,7 +293,19 @@ def simple_transform(im, resize_size, crop_size, is_train, is_color=True):
im = left_right_flip(im) im = left_right_flip(im)
else: else:
im = center_crop(im, crop_size) im = center_crop(im, crop_size)
im = to_chw(im) if len(im.shape) == 3:
im = to_chw(im)
im = im.astype('float32')
if mean is not None:
mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel
if mean.ndim == 1:
mean = mean[:, np.newaxis, np.newaxis]
else:
# elementwise mean
assert len(mean.shape) == len(im)
im -= mean
return im return im
...@@ -297,7 +314,8 @@ def load_and_transform(filename, ...@@ -297,7 +314,8 @@ def load_and_transform(filename,
resize_size, resize_size,
crop_size, crop_size,
is_train, is_train,
is_color=True): is_color=True,
mean=None):
""" """
Load image from the input file `filename` and transform image for Load image from the input file `filename` and transform image for
data argumentation. Please refer to the `simple_transform` interface data argumentation. Please refer to the `simple_transform` interface
...@@ -318,5 +336,5 @@ def load_and_transform(filename, ...@@ -318,5 +336,5 @@ def load_and_transform(filename,
:type is_train: bool :type is_train: bool
""" """
im = load_image(filename) im = load_image(filename)
im = simple_transform(im, resize_size, crop_size, is_train, is_color) im = simple_transform(im, resize_size, crop_size, is_train, is_color, mean)
return im return im
...@@ -51,7 +51,7 @@ class Parameters(object): ...@@ -51,7 +51,7 @@ class Parameters(object):
def __init__(self): def __init__(self):
self.__param_conf__ = dict() self.__param_conf__ = dict()
self.__gradient_machines__ = [] self.__gradient_machines__ = []
self.__tmp_params__ = [] self.__tmp_params__ = dict()
def __append_config__(self, param_conf): def __append_config__(self, param_conf):
""" """
...@@ -128,13 +128,10 @@ class Parameters(object): ...@@ -128,13 +128,10 @@ class Parameters(object):
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy. # create new parameter in python numpy.
if len(self.__tmp_params__) != 0: if key in self.__tmp_params__:
ret_list = [ return self.__tmp_params__[key]
mat for name, mat in self.__tmp_params__ if name == key else:
] return np.ndarray(shape=shape, dtype=np.float32)
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32)
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
param = __get_parameter_in_gradient_machine__( param = __get_parameter_in_gradient_machine__(
...@@ -187,7 +184,7 @@ class Parameters(object): ...@@ -187,7 +184,7 @@ class Parameters(object):
(shape, value.shape)) (shape, value.shape))
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
self.__tmp_params__.append((key, value)) self.__tmp_params__[key] = value
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
__copy_parameter_to_gradient_machine__(each_gradient_machine, __copy_parameter_to_gradient_machine__(each_gradient_machine,
...@@ -231,7 +228,7 @@ class Parameters(object): ...@@ -231,7 +228,7 @@ class Parameters(object):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
if len(self.__tmp_params__) != 0: if len(self.__tmp_params__) != 0:
for name, val in self.__tmp_params__: for name, val in self.__tmp_params__.iteritems():
try: try:
__copy_parameter_to_gradient_machine__(gradient_machine, __copy_parameter_to_gradient_machine__(gradient_machine,
name, val) name, val)
...@@ -287,6 +284,18 @@ class Parameters(object): ...@@ -287,6 +284,18 @@ class Parameters(object):
@staticmethod @staticmethod
def from_tar(f): def from_tar(f):
"""
Create a `Parameters` object from the given file. And
the `Parameters` only contains the parameters in this
file. It is adapted the parameters are same in the
defined network and the given file. For example, it
can be used in the inference.
:param f: the initialized model file.
:type f: tar file
:return: A Parameters object.
:rtype: Parameters.
"""
params = Parameters() params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r') tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar: for finfo in tar:
...@@ -302,6 +311,21 @@ class Parameters(object): ...@@ -302,6 +311,21 @@ class Parameters(object):
params.deserialize(param_name, f) params.deserialize(param_name, f)
return params return params
def init_from_tar(self, f):
"""
Different from `from_tar`, this interface can be used to
init partial network parameters from another saved model.
:param f: the initialized model file.
:type f: tar file
:return: Nothing.
"""
tar_param = Parameters.from_tar(f)
for pname in tar_param.names():
if pname in self.names():
self.set(pname, tar_param.get(pname))
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
""" """
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册