diff --git a/.travis.yml b/.travis.yml index 64961adcf28cbb24ead67ca6989fd2700956f2d5..a53bd1809416d6f14a1ec7f603622d3303d1ab28 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,3 @@ -group: deprecated-2017Q2 language: cpp cache: directories: diff --git a/CMakeLists.txt b/CMakeLists.txt index b779caefb9a8c9e67953b909e6a61c53a45ac13e..24a7066adc57c510030b0926c81849daa4caa6ca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,6 +93,7 @@ include(external/openblas) # download, build, install openblas include(external/swig) # download, build, install swig include(external/warpctc) # download, build, install warpctc include(external/any) # download libn::any +include(external/eigen) # download eigen3 include(generic) # simplify cmake module include(package) # set paddle packages diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake new file mode 100644 index 0000000000000000000000000000000000000000..253d436bcc04d8e0db78f6a4a2c67a050f456bba --- /dev/null +++ b/cmake/external/eigen.cmake @@ -0,0 +1,20 @@ +INCLUDE(ExternalProject) + +SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3) + +INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3) + +ExternalProject_Add( + eigen3 + ${EXTERNAL_PROJECT_LOG_ARGS} + URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz" + URL_MD5 "1a47e78efe365a97de0c022d127607c3" + PREFIX ${EIGEN_SOURCE_DIR} + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) + +LIST(APPEND external_project_dependencies eigen3) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 69e8164a00d1fb57b79c63ba88c2846d30d80cd2..11c1f677ae5b308558b54bf49caf168cf6023444 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -77,6 +77,15 @@ # /cmake/external/*.cmake: # # cc_test(example_test SRCS example_test.cc DEPS example glog gflags) +# +# To build a go static library using Golang, use the go_ prefixed version: +# +# go_library(example STATIC) +# +# To build a go shared library using Golang, use the go_ prefixed version: +# +# go_library(example SHARED) +# if(NOT APPLE) find_package(Threads REQUIRED) @@ -246,42 +255,53 @@ endfunction(nv_test) set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") file(MAKE_DIRECTORY ${GOPATH}) +set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle") -# Because api.go defines a GO wrapper to ops and tensor, it depends on -# both. This implies that if any of tensor.{h,cc}, ops.{h,cu}, or -# api.go is changed, api need to be re-built. -# go_library(api -# SRCS -# api.go -# DEPS -# tensor # Because ops depend on tensor, this line is optional. -# ops) function(go_library TARGET_NAME) - set(options OPTIONAL) + set(options STATIC static SHARED shared) set(oneValueArgs "") - set(multiValueArgs SRCS DEPS) + set(multiValueArgs DEPS) cmake_parse_arguments(go_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - if (${go_library_OPTIONAL} STREQUAL "SHARED") + + if (go_library_SHARED OR go_library_shared) set(BUILD_MODE "-buildmode=c-shared") - if(APPLE) - set(LIB_NAME "lib${TARGET_NAME}.dylib") - else() - set(LIB_NAME "lib${TARGET_NAME}.so") - endif() + set(LIB_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}") else() set(BUILD_MODE "-buildmode=c-archive") - set(LIB_NAME "lib${TARGET_NAME}.a") + set(LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}") endif() - add_custom_command(OUTPUT ${TARGET_NAME}_timestamp + + # Add dummy code to support `make target_name` under Terminal Command + set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) + file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") + if (go_library_SHARED OR go_library_shared) + add_library(${TARGET_NAME} SHARED ${dummyfile}) + else() + add_library(${TARGET_NAME} STATIC ${dummyfile}) + endif() + if(go_library_DEPS) + add_dependencies(${TARGET_NAME} ${go_library_DEPS}) + endif(go_library_DEPS) + + # we need to symlink Paddle directory into GOPATH. If we + # don't do it and we have code that depends on Paddle, go + # get ./... will download a new Paddle repo from Github, + # without the changes in our current Paddle repo that we + # want to build. + file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" + # Symlink Paddle directory into GOPATH + COMMAND mkdir -p ${PADDLE_IN_GOPATH} + COMMAND rm -rf ${PADDLE_IN_GOPATH} + COMMAND ln -sf ${CMAKE_SOURCE_DIR} ${PADDLE_IN_GOPATH} + # Automatically get all dependencies specified in the source code + COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./... + # Golang build source code COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" - ${go_library_SRCS} + ${GO_SOURCE} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS}) - add_library(${TARGET_NAME} STATIC IMPORTED) - set_property(TARGET ${TARGET_NAME} PROPERTY - IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}") - add_dependencies(${TARGET_NAME} ${TARGET_NAME}_lib) endfunction(go_library) function(go_binary TARGET_NAME) @@ -311,10 +331,3 @@ function(go_test TARGET_NAME) add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS}) add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}) endfunction(go_test) - -# go_extern will download extern go project. -# go_extern(target_name extern_source) -# go_extern(go_redis github.com/hoisie/redis) -function(go_extern TARGET_NAME) - add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN}) -endfunction(go_extern) diff --git a/cmake/system.cmake b/cmake/system.cmake index 3b5cbfdd631b42ada49d0e1486824373dc69e519..adf5e2c539740076ad1808353522c7467d765e64 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -33,6 +33,7 @@ ELSE(WIN32) SET(CMAKE_OSX_DEPLOYMENT_TARGET ${MACOS_VERSION} CACHE STRING "Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value.") ENDIF() + set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") ELSE(APPLE) IF(EXISTS "/etc/issue") diff --git a/doc/design/scope.md b/doc/design/scope.md new file mode 100644 index 0000000000000000000000000000000000000000..2ff416f06e8ada48b1d4922f8869a106f35799e2 --- /dev/null +++ b/doc/design/scope.md @@ -0,0 +1,124 @@ +# Design of Scope in Paddle + +## Overview + +Scope is an important concept in programming languages, which defines a program region that a set of bindings between names and entities applies. In a specific scope, a valid name is uniquely associated with an entity, such as a variable. And in another scope, this name may refer to other entity or nothing at all. It clearly restricts the visibility and validity of names in a program. Hence **Scope** is introduced to PaddlePaddle to manage variables in context. But different from the original abstract concept, Scope now becomes an object with two important attributes: + +- Scope is an association of a name to variable. +- Variables in a parent scope can be retrieved from local scope. + +A detailed explanation of these two attributes goes as following. + + +## Scope is an association of a name to variable. + +Scope is an association of a name to variable. All variables belong to `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can run in different scopes and update different variable in the scope. + + +1. Scope only contains a map of a name to variable. + + All parameters, data, states in a Net should be variables and stored inside a scope. Each op should get inputs and outputs to do computation from a scope, such as data buffer, state(momentum) etc. + +1. Variable can only be created by Scope and a variable can only be got from Scope. User cannot create or get a variable outside a scope. This is a constraints of our framework, and will keep our framework simple and clear. + +1. Scope only contains methods that are used to Create and Get Variables. Scope do not contain Operators and have no information to run them. + `Net` is designed to drive the computation and Scope only contains a map of variables. There is no computation logic inside a `Scope`. Scope just handles the lifetime management of variables. + - `Create` is used to create a Variable by its name and add the mapping relation. + - `Get` is used to find a Variable by name. + +1. Every variable only belongs to one certain Scope. + + Variable can not belong to many scopes. If you want to use variables from parent scope, you can use `parent scope`. + +1. Scope should destruct all Variables inside it when itself is destructed. User can never store `Variable` pointer somewhere else. + + Because Variable can only be got from Scope. When destroying Scope, we also need to destroy all the Variables in it. If user store `Variable` pointer to private data member or some global variable, the pointer will be a invalid pointer when associated `Scope` is destroyed. + +```cpp +class Scope { + public: + Variable* CreateVariable(const std::string& name); + const Variable* GetVariable(const std::string& name) const; + + private: + std::unordered_map> vars_; +}; +``` + + +## Parent scope and local scope + +Just like [scope](https://en.wikipedia.org/wiki/Scope_(computer_science)) in programming languages, `Scope` in the neural network can also be a local scope. There are two attributes about local scope. + +1. We can create local variables in a local scope. When that local scope are destroyed, all local variables should also be destroyed. +2. Variables in a parent scope can be retrieved from local scopes of that parent scope, i.e., when user get a variable from a scope, it will try to search this variable in current scope. If there is no such variable in the local scope, `scope` will keep searching from its parent, until the variable is found or there is no parent. + +```cpp +class Scope { + public: + Scope(const std::shared_ptr& scope): parent_(scope) {} + + Variable* GetVariable(const std::string& name) const { + Variable* var = GetVarLocally(name); + if (var != nullptr) { + return var; + } else if (parent_ != nullptr) { + return parent_->GetVariable(name); + } else { + return nullptr; + } + } + + private: + std::shared_ptr parent_ {nullptr}; +}; +``` + +In `Scope` class, there is a private data member called `parent_`. `parent_` is a smart pointer to its parent scope. When user `Get` a variable by its `name`, the `name` will be searched inside the current scope. If the variable cannot be found locally and parent scope is not a `nullptr`, the variable will be searched inside that parent scope. `parent_` pointer's default value is `nullptr`. It means that the scope is a global scope when `parent_` is nullptr. + +A local scope is very useful when we implement Recurrent Neural Network. Each timestep of an RNN should be a `Net`. Each `Net` of timestep (`StepNet` for short) should use an independent local scope. Just like variables in a while loop is inside a local scope in programming languages. By using a single `StepNet` and changing local scope, we can implement an RNN easily. + +# Interface Design + +```cpp +class Variable { + private: + Variable() = default; + friend class Scope; +}; + +class Scope { + private: + Scope(const std::shared_ptr& parent = nullptr); + + public: + static std::shared_ptr Create(const std::shared_ptr& parent = nullptr); + + // return nullptr if not found. + Variable* GetVariable(const std::string& name) const; + + // return Error if already contains same name variable. + Error CreateVariable(const std::string& name); + + private: + std::shared_ptr parent_; + std::unordered_map> vars_; +}; +``` +## Only scope can create a variable + +To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`. + +## When scope destroyed, all variables inside this scope should be destroyed together + +The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together. + +## Sharing a parent scope + +Local scope contains a `parent_` pointer. It is a linked-list for scopes. Using a `shared_ptr` because when a local scope is using, its parents cannot be destroyed. + +Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shared pointer. We cannot construct a scope variable, because it cannot be passed to other scope as `parent` pointer. + +## Orthogonal interface + +`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily. diff --git a/doc/getstarted/concepts/use_concepts_cn.rst b/doc/getstarted/concepts/use_concepts_cn.rst index e63ca11102c8ce457afcc3c262fa5f159361c01d..f15b11bd780402a3ec1755900e8c648f5d2a7bc5 100644 --- a/doc/getstarted/concepts/use_concepts_cn.rst +++ b/doc/getstarted/concepts/use_concepts_cn.rst @@ -111,7 +111,7 @@ PaddlePaddle支持不同类型的输入数据,主要包括四种类型,和 # define training dataset reader def train_reader(): train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]]) - train_y = np.array([-2, -3, -7, -7]) + train_y = np.array([[-2], [-3], [-7], [-7]]) def reader(): for i in xrange(train_y.shape[0]): yield train_x[i], train_y[i] diff --git a/go/cmake/CMakeDetermineGoCompiler.cmake b/go/cmake/CMakeDetermineGoCompiler.cmake deleted file mode 100644 index a9bb6906c7440782bd648bb7505a548248a11bb0..0000000000000000000000000000000000000000 --- a/go/cmake/CMakeDetermineGoCompiler.cmake +++ /dev/null @@ -1,44 +0,0 @@ -if(NOT CMAKE_Go_COMPILER) - if(NOT $ENV{GO_COMPILER} STREQUAL "") - get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT) - - if(CMAKE_Go_FLAGS_ENV_INIT) - set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler") - endif() - - if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT}) - message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.") - endif() - - endif() - - set(Go_BIN_PATH - $ENV{GOPATH} - $ENV{GOROOT} - $ENV{GOROOT}/../bin - $ENV{GO_COMPILER} - /usr/bin - /usr/local/bin - ) - - if(CMAKE_Go_COMPILER_INIT) - set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler") - else() - find_program(CMAKE_Go_COMPILER - NAMES go - PATHS ${Go_BIN_PATH} - ) - EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION) - STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}") - message("-- The Golang compiler identification is ${VERSION}") - message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}") - endif() - -endif() - -mark_as_advanced(CMAKE_Go_COMPILER) - -configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in - ${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY) - -set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") diff --git a/go/cmake/CMakeGoCompiler.cmake.in b/go/cmake/CMakeGoCompiler.cmake.in deleted file mode 100644 index a71f08e064656fbaad8cfa77aea6f216515712ef..0000000000000000000000000000000000000000 --- a/go/cmake/CMakeGoCompiler.cmake.in +++ /dev/null @@ -1,8 +0,0 @@ -set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@") -set(CMAKE_Go_COMPILER_LOADED 1) - -set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go) -set(CMAKE_Go_LINKER_PREFERENCE 40) -set(CMAKE_Go_OUTPUT_EXTENSION .o) -set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1) -set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER") diff --git a/go/cmake/CMakeGoInformation.cmake b/go/cmake/CMakeGoInformation.cmake deleted file mode 100644 index ba51ac93fcd429478f324b66bd5129d94ea2a8f4..0000000000000000000000000000000000000000 --- a/go/cmake/CMakeGoInformation.cmake +++ /dev/null @@ -1,7 +0,0 @@ -if(NOT CMAKE_Go_COMPILE_OBJECT) - set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o ") -endif() - -if(NOT CMAKE_Go_LINK_EXECUTABLE) - set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o ") -endif() diff --git a/go/cmake/CMakeTestGoCompiler.cmake b/go/cmake/CMakeTestGoCompiler.cmake deleted file mode 100644 index b9891b015baced05b51e34dba562fd98a84fe14c..0000000000000000000000000000000000000000 --- a/go/cmake/CMakeTestGoCompiler.cmake +++ /dev/null @@ -1 +0,0 @@ -set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "") diff --git a/go/cmake/flags.cmake b/go/cmake/flags.cmake deleted file mode 100644 index a167c432a920e9ee93878603f3b946e8593412f6..0000000000000000000000000000000000000000 --- a/go/cmake/flags.cmake +++ /dev/null @@ -1,45 +0,0 @@ -# Setting Paddle Compile Flags -include(CheckCXXCompilerFlag) -include(CheckCCompilerFlag) -include(CheckCXXSymbolExists) -include(CheckTypeSize) - -function(CheckCompilerCXX11Flag) - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) - message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") - endif() - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" - # Apple Clang is a different compiler than upstream Clang which havs different version numbers. - # https://gist.github.com/yamaya/2924292 - if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X - if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1) - message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.") - endif() - else() - if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3) - message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.") - endif() - endif() - endif() -endfunction() - -CheckCompilerCXX11Flag() -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") - -# Common gpu architectures: Kepler, Maxwell -foreach(capability 30 35 50) - list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}") -endforeach() - -if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0") - list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52") -endif() - -# Modern gpu architectures: Pascal -if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0") - list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60") -endif() - -set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS}) diff --git a/go/cmake/golang.cmake b/go/cmake/golang.cmake deleted file mode 100644 index a5a43886f887e495500fa26b3c26fa69c63eded0..0000000000000000000000000000000000000000 --- a/go/cmake/golang.cmake +++ /dev/null @@ -1,48 +0,0 @@ -set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go") -file(MAKE_DIRECTORY ${GOPATH}) -set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle") -file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH}) - -function(GO_LIBRARY NAME BUILD_TYPE) - if(BUILD_TYPE STREQUAL "STATIC") - set(BUILD_MODE -buildmode=c-archive) - set(LIB_NAME "lib${NAME}.a") - else() - set(BUILD_MODE -buildmode=c-shared) - if(APPLE) - set(LIB_NAME "lib${NAME}.dylib") - else() - set(LIB_NAME "lib${NAME}.so") - endif() - endif() - - file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") - file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}) - - # find Paddle directory. - get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) - get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) - get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY) - - # automatically get all dependencies specified in the source code - # for given target. - add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...) - - # make a symlink that references Paddle inside $GOPATH, so go get - # will use the local changes in Paddle rather than checkout Paddle - # in github. - add_custom_target(${NAME}_copyPaddle - COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle - COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle) - add_dependencies(${NAME}_goGet ${NAME}_copyPaddle) - - add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp - COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} - -o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}" - ${CMAKE_GO_FLAGS} ${GO_SOURCE} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - - add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN}) - add_dependencies(${NAME} ${NAME}_goGet) - -endfunction(GO_LIBRARY) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 6c85b1804bb9c5f3a8bc46bb3f54cc62c56cca70..8a42d4f8af1713e246f9efaf5dc7ba878c3b271e 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -30,7 +30,13 @@ func main() { log.SetLevel(level) timeout := time.Second * time.Duration((*etcdTimeout)) - s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) + e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + idx, err := e.Register() + if err != nil { + panic(err) + } + + s, err := pserver.NewService(idx) if err != nil { panic(err) } diff --git a/go/master/c/client.go b/go/master/c/client.go index b186474dc33138aeb02a2ffe34418b379b7a2db0..9e35e986002c0ae3b7593150ece96dba29a1521b 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -13,10 +13,13 @@ typedef int paddle_master_client; import "C" import ( + "strings" "sync" + "time" "unsafe" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/coreos/etcd/clientv3" log "github.com/sirupsen/logrus" ) @@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client { return h } -type addresser string - -func (a addresser) Address() string { - return string(a) +//export paddle_new_etcd_master_client +func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client { + p := C.GoString(etcdEndpoints) + cli, err := clientv3.New(clientv3.Config{ + Endpoints: strings.Split(p, ","), + DialTimeout: time.Second * time.Duration(timeout), + }) + if err != nil { + panic(err) + } + ch := make(chan string, 1) + a, err := master.GetKey(cli, master.DefaultAddrPath, timeout) + if err != nil { + panic(err) + } + ch <- a + go master.WatchKey(cli, master.DefaultAddrPath, ch) + c := master.NewClient(ch, bufSize) + return add(c) } //export paddle_new_master_client func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { a := C.GoString(addr) - c := master.NewClient(addresser(a), bufSize) + ch := make(chan string, 1) + ch <- a + c := master.NewClient(ch, bufSize) return add(c) } diff --git a/go/master/client.go b/go/master/client.go index 8451820c1963dd5a4eff0c3ab7763eb6a8e05ba4..d3bea49d0a8166420e83478076cc7bc81e48598d 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -2,18 +2,12 @@ package master import ( "os" - "time" "github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/recordio" log "github.com/sirupsen/logrus" ) -// Addresser provide the address of the master server. -type Addresser interface { - Address() string -} - // Client is the client of the master server. type Client struct { conn *connection.Conn @@ -24,11 +18,11 @@ type Client struct { // // bufSize is the record buffer size. NextRecord will read from this // buffer. -func NewClient(addr Addresser, bufSize int) *Client { +func NewClient(addrCh <-chan string, bufSize int) *Client { c := &Client{} c.conn = connection.New() c.ch = make(chan []byte, bufSize) - go c.monitorMaster(addr) + go c.monitorMaster(addrCh) go c.getRecords() return c } @@ -72,12 +66,10 @@ func (c *Client) getRecords() { } } -func (c *Client) monitorMaster(addr Addresser) { +func (c *Client) monitorMaster(addrCh <-chan string) { lastMaster := "" - monitor := func() { - // get the lastest address of the master server, + for curMaster := range addrCh { // connect to the new address once address changed. - curMaster := addr.Address() if curMaster != lastMaster { if curMaster == "" { err := c.conn.Close() @@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) { // to retry next time. curMaster = lastMaster } - } } - lastMaster = curMaster } - - monitor() - ticker := time.NewTicker(10 * time.Second) - for _ = range ticker.C { - monitor() - } } // SetDataset set dataset for the master server to dispatch. diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 251225780ae3077f90655b4e874d03b4f3794525..364dce7b58cf6366af711bde9107559a762563a4 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -26,12 +26,6 @@ func init() { log.SetLevel(log.ErrorLevel) } -type TestAddresser string - -func (a TestAddresser) Address() string { - return string(a) -} - func TestGetFinishTask(t *testing.T) { const path = "/tmp/master_client_test_0" @@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) { if err != nil { panic(err) } - go func(l net.Listener) { s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) if err != nil { @@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) { // Manually intialize client to avoid calling c.getRecords() c := &Client{} c.conn = connection.New() - go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) + addr := fmt.Sprintf(":%d", p) + ch := make(chan string, 1) + ch <- addr + go c.monitorMaster(ch) c.SetDataset([]string{path}) - checkOnePass := func(i int) { var tasks []Task for idx := 0; idx < totalTask; idx++ { diff --git a/go/master/client_test.go b/go/master/client_test.go index 85a86761c2e5897e3e89cbebfd32f7666c4a9f7f..c00aeebfd5d1fef6de4a8c67bf7f998a42ee863b 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) { path = "/tmp/master_client_TestFull" total = 50 ) - l, err := net.Listen("tcp", ":0") if err != nil { panic(err) @@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) { if err != nil { panic(err) } - go func(l net.Listener) { s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) if err != nil { @@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) { } w.Close() f.Close() - - c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) + curAddr := make(chan string, 1) + curAddr <- fmt.Sprintf(":%d", p) + c := master.NewClient(curAddr, 10) c.SetDataset([]string{path}) - for pass := 0; pass < 50; pass++ { received := make(map[byte]bool) for i := 0; i < total; i++ { diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index b7293a759896f113d630d57d14b4b4ac8963f54a..e27c014792f31ca27fe1a1636d69acccc4206ea3 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -18,8 +18,8 @@ const ( DefaultAddrPath = "/master/addr" ) -// EtcdClient is the etcd client that master uses for fault tolerance -// and service registry. +// EtcdClient is the etcd client that the master uses for fault +// tolerance and service registry. type EtcdClient struct { lockPath string statePath string @@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) { state := kvs[0].Value return state, nil } + +// GetKey gets the value by the specify key. +func GetKey(c *clientv3.Client, key string, timeout int) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + resp, err := c.Get(ctx, key) + cancel() + if err != nil { + return "", err + } + kvs := resp.Kvs + if len(kvs) == 0 { + return "", nil + } + v := kvs[0].Value + return string(v), nil +} + +// WatchKey watches the specify key and send to valChan if there is some event. +func WatchKey(c *clientv3.Client, key string, valChan chan<- string) { + rch := c.Watch(context.Background(), key) + for wresp := range rch { + for _, ev := range wresp.Events { + // if received event is DELETE, the value will be an empty string + log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) + valChan <- string(ev.Kv.Value) + } + } +} diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/cclient/CMakeLists.txt index fff7ae78582732c1b7af7a757c340804e91316d6..d2c339d68866bd5c91403227e97af2c97bb30eeb 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/cclient/CMakeLists.txt @@ -1,14 +1,3 @@ -cmake_minimum_required(VERSION 3.0) - -get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY) -get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY) -set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake") - -project(cxx_go C Go) - -include(golang) -include(flags) - go_library(paddle_pserver_cclient STATIC) add_subdirectory(test) diff --git a/go/pserver/cclient/test/CMakeLists.txt b/go/pserver/cclient/test/CMakeLists.txt index 1a3dd7e5e9e0ff3273fc2be67c48461797b4a6b3..916e4e99a24ea7f76f1935fc7d281cd158ac5061 100644 --- a/go/pserver/cclient/test/CMakeLists.txt +++ b/go/pserver/cclient/test/CMakeLists.txt @@ -1,22 +1,3 @@ -cmake_minimum_required(VERSION 3.0) -add_executable(main main.c) -add_dependencies(main paddle_pserver_cclient) -add_executable(test_cclient test_cclient.c) -add_dependencies(test_cclient paddle_pserver_cclient) - -if(APPLE) - set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security") -else() - set(CMAKE_EXE_LINKER_FLAGS "-pthread") -endif() - -if(PROJ_ROOT) - include_directories(${CMAKE_CURRENT_BINARY_DIR}/..) - target_link_libraries(main ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread) - target_link_libraries(test_cclient ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread) -else(PROJ_ROOT) - include_directories(${CMAKE_BINARY_DIR}) - target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread) - target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread) -endif(PROJ_ROOT) +cc_library(main SRCS main.c DEPS paddle_pserver_cclient) +cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient) diff --git a/go/pserver/client.go b/go/pserver/client.go index dda915977282d4880ddcc8c18ef6fd80ede9e01b..6938b9d5ce6f6d73c05bd6e3154777023965c319 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -1,6 +1,7 @@ package pserver import ( + "errors" "hash/fnv" "sort" "time" @@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error { // SendGrads sends gradients to parameter servers for updating // parameters. func (c *Client) SendGrads(grads []Gradient) error { + if len(grads) == 0 { + return errors.New("no gradient received") + } errCh := make(chan error, len(grads)) for _, g := range grads { go func(g Gradient) { diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index 6ecf1fa08a02ed2ce04fae0903cebd46a7b768a4..5bd16118a7f70b766016abfce55f6bb2adf8cc60 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -7,7 +7,6 @@ import ( "strconv" "strings" "testing" - "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) @@ -31,7 +30,7 @@ func init() { port[i] = p go func(l net.Listener) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go new file mode 100644 index 0000000000000000000000000000000000000000..4d88243edd4aa817ddc263ba316a3f6be9e1e67f --- /dev/null +++ b/go/pserver/etcd_client.go @@ -0,0 +1,181 @@ +package pserver + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +// EtcdClient is the etcd client that the pserver uses for fault +// tolerance, service registry and coordination. +type EtcdClient struct { + numPservers int + etcdEndpoints string + etcdClient *clientv3.Client + // etcdTimeout is also used as retry intervals. + etcdTimeout time.Duration + // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. + externalIP string + // desired number of pservers in the job. + // assume desired will not change during one training job. + desired int +} + +// NewEtcdClient creates an EtcdClient +func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { + return &EtcdClient{ + etcdTimeout: timeout, + numPservers: numPservers, + etcdEndpoints: endpoints, + } +} + +// Register registers the pserver on etcd +// +// Register returns the index of the current pserver. +func (e *EtcdClient) Register() (int, error) { + + var err error + e.externalIP, err = networkhelper.GetExternalIP() + if err != nil { + return 0, err + } + + // initialize connection to etcd. + ep := strings.Split(e.etcdEndpoints, ",") + for { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: e.etcdTimeout, + }) + if err != nil { + log.Errorf("connect to etcd error: %v", err) + time.Sleep(e.etcdTimeout) + continue + } + e.etcdClient = cli + log.Debugf("inited client to %s", e.etcdEndpoints) + break + } + // init /ps_desired using transaction, for multiple pservers may want to write + // it at the same time. + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := e.initDesiredPsercers(ctx, e.numPservers) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + // TODO: when implementing extending or reducing pservers, /ps_desired is + // changed, then we need to watch /ps_desired node for events. For now, just + // write once when init and read from it. + // wait and set s.desired init value + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := e.etcdClient.Get(ctx, PsDesired) + cancel() + if err != nil { + log.Errorf("getting %s error: %v", PsDesired, err) + time.Sleep(e.etcdTimeout) + continue + } + if len(resp.Kvs) != 0 { + e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(e.etcdTimeout) + // NOTE: wait util ps_desired value change + continue + } + break + } + } + + var pserverIdx int + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + var err error + pserverIdx, err = e.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + + return pserverIdx, nil +} + +func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + dsStr := c.Get(PsDesired) + if dsStr == "" { + c.Put(PsDesired, strconv.Itoa(numPservers)) + } + return nil + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) +} + +// registerPserverEtcd registers pserver node on etcd using transaction. +func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { + var idx int + _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + registered := false + for i := 0; i < e.desired; i++ { + psKey := "/ps/" + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + ps := c.Get(psKey) + log.Debugf("got value (%s) for key: %s", ps, psKey) + + if ps == "" { + resp, err := e.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } + // find the first id and write info + c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) + log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) + ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) + if kaerr != nil { + log.Errorf("keepalive etcd node error: %v", kaerr) + return kaerr + } + + // Eat the keep alive message so etcd + // will not expire the lease. + go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) + log.Debug("register finished") + idx = i + registered = true + break + } + } + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) + + if err != nil { + return 0, err + } + + return idx, nil +} diff --git a/go/pserver/service.go b/go/pserver/service.go index f966595fdccbf23e23f94a857503ce05815164ef..f386ebea1eb8659a988de2a807303bb6687fa429 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,18 +1,9 @@ package pserver import ( - "context" "errors" "fmt" - "strconv" - "strings" "sync" - "time" - - "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -55,160 +46,25 @@ type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} + idx int mu sync.Mutex opt *optimizer paramMap map[string]Parameter - - etcdEndpoints string - etcdClient *clientv3.Client - // etcdTimeout is also used as retry intervals. - etcdTimeout time.Duration - // desired number of pservers in the job. - // assume desired will not change during one training job. - desired int - // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. - externalIP string } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { - s := &Service{opt: newOptimizer(sgd, 0.005)} +func NewService(idx int) (*Service, error) { + s := &Service{ + idx: idx, + opt: newOptimizer(sgd, 0.005), + } s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) - s.etcdEndpoints = endpoints - s.etcdTimeout = timeout - - var err error - s.externalIP, err = networkhelper.GetExternalIP() - if err != nil { - return nil, err - } - - if endpoints != "" { - // initialize connection to etcd, try - ep := strings.Split(s.etcdEndpoints, ",") - for { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: s.etcdTimeout, - }) - if err != nil { - log.Errorf("connect to etcd error: %v", err) - time.Sleep(s.etcdTimeout) - continue - } - s.etcdClient = cli - log.Debugf("inited client to %s", s.etcdEndpoints) - break - } - // init /ps_desired using transaction, for multiple pservers may want to write - // it at the same time. - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.initDesiredPsercers(ctx, numPservers) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - // TODO: when implementing extending or reducing pservers, /ps_desired is - // changed, then we need to watch /ps_desired node for events. For now, just - // write once when init and read from it. - // wait and set s.desired init value - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := s.etcdClient.Get(ctx, PsDesired) - cancel() - if err != nil { - log.Errorf("getting %s error: %v", PsDesired, err) - time.Sleep(s.etcdTimeout) - continue - } - if len(resp.Kvs) != 0 { - s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) - if err != nil { - log.Errorf("value of %s invalid %v\n", PsDesired, err) - time.Sleep(s.etcdTimeout) - // NOTE: wait util ps_desired value change - continue - } - break - } - } - // try register pserver node on etcd - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.registerPserverEtcd(ctx) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - } // if endpoints != "" - // Bypass etcd registration if no endpoints specified return s, nil } -func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - dsStr := c.Get(PsDesired) - if dsStr == "" { - c.Put(PsDesired, strconv.Itoa(numPservers)) - } - return nil - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - -// registerPserverEtcd registers pserver node on etcd using transaction. -func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - registered := false - for i := 0; i < s.desired; i++ { - psKey := "/ps/" + strconv.Itoa(i) - log.Debugf("checking %s", psKey) - ps := c.Get(psKey) - log.Debugf("got value (%s) for key: %s", ps, psKey) - - if ps == "" { - resp, err := s.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } - // find the first id and write info - c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) - log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) - if kaerr != nil { - log.Errorf("keepalive etcd node error: %v", kaerr) - return kaerr - } - - // Eat the keep alive message so etcd - // will not expire the lease. - go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) - log.Debug("register finished") - registered = true - break - } - } - if registered == true { - return nil - } - return errors.New("not registerd, may due to already have enough pservers") - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { select { diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index f317535592165b921491120888badd30c6795c12..d9d887cffd462eed48b972466a7d83bae35d9a1c 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestFull(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -75,7 +75,7 @@ func TestFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService("", time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 979b68e8272184d80b45cd50a1b606cb76056224..307e99bbe3a833f1fe26057ec38d0b96e04bc0fe 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -9,7 +9,7 @@ add_subdirectory(pserver) add_subdirectory(trainer) add_subdirectory(scripts) add_subdirectory(optimizer) -add_subdirectory(strings) +add_subdirectory(string) if(Boost_FOUND) add_subdirectory(memory) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index e3c3155aa902c941058ea1b15488360df6c06175..b06ecc26286de1385f6ea4eabc01396c07d7aa52 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,6 +1,5 @@ cc_library(ddim SRCS ddim.cc) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) - nv_test(dim_test SRCS dim_test.cu DEPS ddim) - cc_test(variable_test SRCS variable_test.cc) +cc_test(enforce_test SRCS enforce_test.cc) diff --git a/paddle/framework/enforce.h b/paddle/framework/enforce.h new file mode 100644 index 0000000000000000000000000000000000000000..56cb7f95647e81efef58b156002d0d378ee22820 --- /dev/null +++ b/paddle/framework/enforce.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +namespace paddle { +namespace framework { + +/** + * @brief Enforce exception. Inherits std::exception + * + * All enforce condition not met, will throw an EnforceNotMet exception. + */ +class EnforceNotMet : public std::exception { + public: + EnforceNotMet(const std::string& msg, const char* file, int fileline) { + std::ostringstream sout; + sout << msg << " at [" << file << ":" << fileline << "];"; + all_msg_ = sout.str(); + } + + const char* what() const noexcept override { return all_msg_.c_str(); } + + private: + std::string all_msg_; +}; + +// From https://stackoverflow.com/questions/30130930/ +// __buildin_expect is in C++ 11 standard. Since the condition which enforced +// should be true in most situation, it will make the compiler generate faster +// code by adding `UNLIKELY` macro. +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) + +/** + * @brief Throw a EnforceNotMet exception, automatically filled __FILE__ & + * __LINE__ + * + * This macro take __VA_ARGS__, user can pass any type if that type can + * serialize to std::ostream + */ +#define PADDLE_THROW(...) \ + do { \ + throw ::paddle::framework::EnforceNotMet( \ + ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + } while (0) + +/** + * @brief Enforce a condition, otherwise throw an EnforceNotMet + */ +#define PADDLE_ENFORCE(condition, ...) \ + do { \ + if (UNLIKELY(!(condition))) { \ + PADDLE_THROW(__VA_ARGS__); \ + } \ + } while (0) + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/enforce_test.cc b/paddle/framework/enforce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f8da1a192f63a54324d80725c9d2f156fb11a481 --- /dev/null +++ b/paddle/framework/enforce_test.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +TEST(ENFORCE, OK) { + PADDLE_ENFORCE(true, "Enforce is ok %d now %f", 123, 0.345); + size_t val = 1; + const size_t limit = 10; + PADDLE_ENFORCE(val < limit, "Enforce is OK too"); +} + +TEST(ENFORCE, FAILED) { + bool in_catch = false; + try { + PADDLE_ENFORCE(false, "Enforce is not ok %d at all", 123); + } catch (paddle::framework::EnforceNotMet err) { + in_catch = true; + std::string msg = "Enforce is not ok 123 at all"; + const char* what = err.what(); + for (size_t i = 0; i < msg.length(); ++i) { + ASSERT_EQ(what[i], msg[i]); + } + } + ASSERT_TRUE(in_catch); +} \ No newline at end of file diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h index b33e10e6820129a874f5355d14d8a3e990186025..72c4a7a2a1d1cf93a784f24e687727ee8481484c 100644 --- a/paddle/framework/variable.h +++ b/paddle/framework/variable.h @@ -25,21 +25,24 @@ class Variable { public: template const T& Get() const { - PADDLE_ASSERT(holder_ != nullptr); - PADDLE_ASSERT(std::type_index(typeid(T)) == - std::type_index(holder_->Type())); + PADDLE_ASSERT(IsType()); return *static_cast(holder_->Ptr()); } template T* GetMutable() { - if (holder_ == nullptr || - std::type_index(typeid(T)) != std::type_index(holder_->Type())) { + if (!IsType()) { holder_.reset(new PlaceholderImpl(new T())); } return static_cast(holder_->Ptr()); } + template + bool IsType() const { + return holder_ != nullptr && + std::type_index(typeid(T)) == std::type_index(holder_->Type()); + } + private: struct Placeholder { virtual ~Placeholder() {} diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 8ef5e9d0c116dd088b5c5c318dfb47c245b471fa..018da6c76dc27a74b074ec52c18347beba8164fc 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -601,7 +601,7 @@ void TrainerThread::backward() { void TrainerThread::backwardCallback(Parameter* para) { // CPU parameters are merged in the end - if (!para->useGpu()) return; + if (!para->useGpu() || para->isStatic()) return; int paramId = para->getID(); if (multiMachine_->getNumThreads() == 1) { diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index 125aaf947f3c9d976b117667d1d1b7700a029cc6..4b92b5d163ad107c0783beae45f8c936112fcccf 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -191,6 +191,11 @@ void Layer::addOutputArgument(int deviceId) { void Layer::copyOutputToOtherDevice() { for (size_t i = 0; i != outputOtherDevice_.size(); i++) { SetDevice device(outputOtherDevice_[i].deviceId); + // If outputOtherDevice_[i].value is a CpuMatrix, + // the copyFrom is a synchronous interface. + // If outputOtherDevice_[i].value is a GpuMatrix, since subsequent + // calculations are all on HPPL_STREAM_DEFAULT, + // copyFrom can be an asynchronous interface. outputOtherDevice_[i].value->copyFrom(*getOutputValue(), HPPL_STREAM_DEFAULT); outputOtherDevice_[i].sequenceStartPositions = diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index c910146164ebfb0737583c72c48ce6dbc5b49939..4431d613f655c1d0c8da13bb5ac9225980c650ad 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -1565,6 +1565,8 @@ void CpuMatrix::copyFrom(const Matrix& src, hl_stream_t stream) { const_cast(src.getData()), sizeof(real) * elementCnt_, stream); + // There is a need to add synchronization to ensure that the data is copied. + hl_stream_synchronize(stream); } else if (typeid(src) == typeid(CpuMatrix)) { memcpy(data_, src.getData(), sizeof(real) * elementCnt_); } else { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 748be850b4c902d1b48c1dafbb0d5ea2bf197e6e..7dfd593225065e18830b2b0c0ce854fe7a2d5178 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -239,7 +239,8 @@ public: LOG(FATAL) << "Not implemented"; } - // asynchronous copy + // For GpuMatrix this is an asynchronous copy interface + // For CpuMatrix this is an synchronous copy interface virtual void copyFrom(const Matrix& src, hl_stream_t stream) { LOG(FATAL) << "Not implemented"; } diff --git a/paddle/math/Vector.cpp b/paddle/math/Vector.cpp index c519ca500afb1dbfdff6e8d211786f4e18ccf1fd..eb87ee9bb7936d27c0c32a1a4b35ff49871c0a10 100644 --- a/paddle/math/Vector.cpp +++ b/paddle/math/Vector.cpp @@ -657,6 +657,8 @@ void CpuVectorT::copyFrom(const VectorT& src, hl_stream_t stream) { (void*)src.getData(), sizeof(T) * this->getSize(), stream); + // There is a need to add synchronization to ensure that the data is copied. + hl_stream_synchronize(stream); } else { src.copyTo(this); } diff --git a/paddle/math/Vector.h b/paddle/math/Vector.h index 9af6e30c9e13895ad95653a787ec1c1ad77a248f..80b9775fccf10c57bb48145ef56165ec7c86d8b8 100644 --- a/paddle/math/Vector.h +++ b/paddle/math/Vector.h @@ -168,11 +168,11 @@ public: virtual void copyFrom(const VectorT& src) = 0; /** - * If use_gpu, this function will push the copy-task to the specifed-stream - * and return immediately. + * If GpuVector, this function is an asynchronous interface, + * will push the copy-task to the specifed-stream and return immediately. * - * If not use GPU, this function is same as - * the copyFrom(const VectorT& src), which use stream HPPL_STREAM_DEFAULT. + * If CpuVector, this function is an synchronous interface, + * same as the copyFrom(const VectorT& src). */ virtual void copyFrom(const VectorT& src, hl_stream_t stream) = 0; diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 5a0dffe086c4e265d17c79dba435b66c0873e3c7..354f58df39365410ff9aec2576c768e58db9e0d2 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -1127,4 +1127,18 @@ TEST(Matrix, MaxOutFwdBwd) { } } +TEST(CpuMatrix, copyFrom) { + const size_t height = 1000; + const size_t width = 1000; + CpuMatrix cpu(height, width); + GpuMatrix gpu(height, width); + CpuMatrix copy(height, width); + + cpu.randomizeUniform(); + gpu.copyFrom(cpu); + copy.copyFrom(gpu, HPPL_STREAM_DEFAULT); + + TensorCheckEqual(cpu, copy); +} + #endif diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 17342356d6018c0a5dfedb5543d2df1ce33c1b50..d0bedf6ba921ad1e90f737623caf111ed290317f 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -6,4 +6,3 @@ nv_test(cuda_test SRCS cuda_test.cu) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) -cc_test(must_check_test SRCS must_check_test.cc) diff --git a/paddle/platform/must_check.h b/paddle/platform/must_check.h deleted file mode 100644 index 4fcc62afc05b14949fc43266f0d05be1f1b7891a..0000000000000000000000000000000000000000 --- a/paddle/platform/must_check.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -/** - * __must_check macro. It make the function's return value must be used, - * otherwise it will raise a compile warning. And also Paddle treat all compile - * warnings as errors. - */ -#ifdef __GNUC__ -#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400 -#define __must_check __attribute__((warn_unused_result)) -#else -#define __must_check -#endif -#else -#define __must_check -#endif diff --git a/paddle/platform/must_check_test.cc b/paddle/platform/must_check_test.cc deleted file mode 100644 index 6ee3ea49acdc4384b5d5df353bfa1290856e982c..0000000000000000000000000000000000000000 --- a/paddle/platform/must_check_test.cc +++ /dev/null @@ -1,10 +0,0 @@ -#include -#include - -int __must_check SomeFunctionMustCheck() { return 0; } - -TEST(MustCheck, all) { - // This line should not be compiled, because the - // return value of SomeFunctionMustCheck marked as __must_check - // SomeFunctionMustCheck(); -} \ No newline at end of file diff --git a/paddle/py_paddle/dataprovider_converter.py b/paddle/py_paddle/dataprovider_converter.py index edc2e0292378fea0cd904d7f017762c1dade6caf..43614b9779d21795f1f274589ea93639e923ce75 100644 --- a/paddle/py_paddle/dataprovider_converter.py +++ b/paddle/py_paddle/dataprovider_converter.py @@ -109,6 +109,10 @@ class DenseScanner(IScanner): if len(self.__shape__) > 3: raise ValueError( "The dimension of input cannot be greater than 3.") + if len(self.__shape__) == 0: + raise ValueError( + "The input should be a vector, please check your input data." + ) self.__dim__ = reduce(lambda x, y: x * y, self.__shape__) if len(self.__shape__) == 1 and self.__dim__ != self.input_type.dim: raise ValueError( @@ -140,7 +144,7 @@ class DenseScanner(IScanner): if len(self.__shape__) > 1: # The last-two dimenstions are the frame height and width. # For example, the layout is CHW for 3-D feature of image. - # The H and W are the fram height and width. + # The H and W are the frame height and width. h, w = self.__shape__[-2:] argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameWidth(self.pos, w) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 2b48e4dc0f875be9a87797fa14885926999a5010..a182e5f4aef9de8c6f20681328d5ba6c0e6944ef 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -31,6 +31,7 @@ Configuring cmake in /paddle/build ... -DWITH_DOC=OFF -DWITH_GPU=${WITH_GPU:-OFF} -DWITH_AVX=${WITH_AVX:-OFF} + -DWITH_GOLANG=${WITH_GOLANG:-OFF} -DWITH_SWIG_PY=ON -DCUDNN_ROOT=/usr/ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} @@ -43,6 +44,7 @@ cmake .. \ -DWITH_DOC=OFF \ -DWITH_GPU=${WITH_GPU:-OFF} \ -DWITH_AVX=${WITH_AVX:-OFF} \ + -DWITH_GOLANG=${WITH_GOLANG:-OFF} \ -DWITH_SWIG_PY=ON \ -DCUDNN_ROOT=/usr/ \ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} \ diff --git a/paddle/string/CMakeLists.txt b/paddle/string/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5becf62672d0c606c98ea1a1a4383df97088ab05 --- /dev/null +++ b/paddle/string/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library(stringpiece SRCS piece.cc) +cc_test(stringpiece_test SRCS piece_test.cc DEPS stringpiece glog gflags) + +cc_test(stringprintf_test SRCS printf_test.cc DEPS glog gflags) diff --git a/paddle/string/piece.cc b/paddle/string/piece.cc new file mode 100644 index 0000000000000000000000000000000000000000..b80afdec82d642fd3a8245b96ce1bb2bea17cbae --- /dev/null +++ b/paddle/string/piece.cc @@ -0,0 +1,138 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include "paddle/string/piece.h" + +#include + +#include +#include +#include + +namespace paddle { +namespace string { + +Piece::Piece() : data_(NULL), size_(0) {} + +Piece::Piece(const char* d, size_t n) : data_(d), size_(n) { + if (d == NULL && n != 0) + throw std::invalid_argument("Piece requires len to be 0 for NULL data"); +} + +Piece::Piece(const char* s) : data_(s) { size_ = (s == NULL) ? 0 : strlen(s); } + +Piece::Piece(const std::string& s) : data_(s.data()), size_(s.size()) {} + +char Piece::operator[](size_t n) const { + if (n >= len()) throw std::invalid_argument("index out of Piece length"); + return data_[n]; +} + +int Compare(Piece a, Piece b) { + const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); + int r = memcmp(a.data(), b.data(), min_len); + if (r == 0) { + if (a.len() < b.len()) + return -1; + else if (a.len() > b.len()) + return 1; + } + return r; +} + +bool operator==(Piece x, Piece y) { + return ((x.len() == y.len()) && + (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); +} + +bool operator!=(Piece x, Piece y) { return !(x == y); } + +bool operator<(Piece x, Piece y) { return Compare(x, y) < 0; } +bool operator>(Piece x, Piece y) { return Compare(x, y) > 0; } + +bool operator<=(Piece x, Piece y) { return Compare(x, y) <= 0; } +bool operator>=(Piece x, Piece y) { return Compare(x, y) >= 0; } + +bool HasPrefix(Piece s, Piece x) { + return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); +} + +bool HasSuffix(Piece s, Piece x) { + return ((s.len() >= x.len()) && + (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); +} + +Piece SkipPrefix(Piece s, size_t n) { + if (n > s.len()) + throw std::invalid_argument("Skip distance larger than Piece length"); + return Piece(s.data() + n, s.len() - n); +} + +Piece SkipSuffix(Piece s, size_t n) { + if (n > s.len()) + throw std::invalid_argument("Skip distance larger than Piece length"); + return Piece(s.data(), s.len() - n); +} + +Piece TrimPrefix(Piece s, Piece x) { + return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; +} + +Piece TrimSuffix(Piece s, Piece x) { + return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; +} + +bool Contains(Piece s, Piece sub) { + return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); +} + +size_t Index(Piece s, Piece sub) { + auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); + return e != s.end() ? e - s.data() : Piece::npos; +} + +size_t Find(Piece s, char c, size_t pos) { + if (pos >= s.len()) { + return Piece::npos; + } + const char* result = + reinterpret_cast(memchr(s.data() + pos, c, s.len() - pos)); + return result != nullptr ? result - s.data() : Piece::npos; +} + +size_t RFind(Piece s, char c, size_t pos) { + if (s.len() == 0) return Piece::npos; + for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); + p--) { + if (*p == c) { + return p - s.data(); + } + } + return Piece::npos; +} + +Piece SubStr(Piece s, size_t pos, size_t n) { + if (pos > s.len()) pos = s.len(); + if (n > s.len() - pos) n = s.len() - pos; + return Piece(s.data() + pos, n); +} + +std::ostream& operator<<(std::ostream& o, Piece piece) { + return o << piece.ToString(); +} + +} // namespace string +} // namespace paddle diff --git a/paddle/strings/stringpiece.h b/paddle/string/piece.h similarity index 57% rename from paddle/strings/stringpiece.h rename to paddle/string/piece.h index adff713e86f49349b8f189c1d24584bfc1bb8aa7..db7c3e69804a6a8f0510ba376432fe560ae74442 100644 --- a/paddle/strings/stringpiece.h +++ b/paddle/string/piece.h @@ -20,33 +20,34 @@ #include namespace paddle { +namespace string { -// StringPiece points into a std::string object but doesn't own the +// Piece points into a std::string object but doesn't own the // string. It is for efficient access to strings. Like Go's string -// type. Not that StringPiece doesn't mutate the underlying string, +// type. Not that Piece doesn't mutate the underlying string, // so it is thread-safe given that the underlying string doesn't -// change. Because StringPiece contains a little data members, and +// change. Because Piece contains a little data members, and // its syntax is simple as it doesn't own/manage the string, it is -// cheap to construct StringPieces and pass them around. -class StringPiece { +// cheap to construct Pieces and pass them around. +class Piece { public: static const size_t npos = static_cast(-1); // We provide non-explicit singleton constructors so users can - // pass in a "const char*" or a "string" wherever a "StringPiece" + // pass in a "const char*" or a "string" wherever a "Piece" // is expected. These contructors ensure that if data_ is NULL, // size_ is 0. - StringPiece(); - StringPiece(const char* d, size_t n); - StringPiece(const char* d); - StringPiece(const std::string& s); + Piece(); + Piece(const char* d, size_t n); + Piece(const char* d); + Piece(const std::string& s); const char* data() const { return data_; } size_t len() const { return size_; } char operator[](size_t n) const; - // StringPiece doesn't own the string, so both iterator and const + // Piece doesn't own the string, so both iterator and const // iterator are const char* indeed. typedef const char* const_iterator; typedef const char* iterator; @@ -63,43 +64,44 @@ private: // Intentionally copyable }; -int Compare(StringPiece a, StringPiece b); +int Compare(Piece a, Piece b); -bool operator==(StringPiece x, StringPiece y); -bool operator!=(StringPiece x, StringPiece y); -bool operator<(StringPiece x, StringPiece y); -bool operator>(StringPiece x, StringPiece y); -bool operator<=(StringPiece x, StringPiece y); -bool operator>=(StringPiece x, StringPiece y); +bool operator==(Piece x, Piece y); +bool operator!=(Piece x, Piece y); +bool operator<(Piece x, Piece y); +bool operator>(Piece x, Piece y); +bool operator<=(Piece x, Piece y); +bool operator>=(Piece x, Piece y); -bool HasPrefix(StringPiece s, StringPiece prefix); -bool HasSuffix(StringPiece s, StringPiece suffix); +bool HasPrefix(Piece s, Piece prefix); +bool HasSuffix(Piece s, Piece suffix); -StringPiece SkipPrefix(StringPiece s, size_t n); -StringPiece SkipSuffix(StringPiece s, size_t n); +Piece SkipPrefix(Piece s, size_t n); +Piece SkipSuffix(Piece s, size_t n); // Skip the prefix (or suffix) if it matches with the string. -StringPiece TrimPrefix(StringPiece s, StringPiece prefix); -StringPiece TrimSuffix(StringPiece s, StringPiece suffix); +Piece TrimPrefix(Piece s, Piece prefix); +Piece TrimSuffix(Piece s, Piece suffix); // Returns if s contains sub. Any s except for empty s contains an // empty sub. -bool Contains(StringPiece s, StringPiece sub); +bool Contains(Piece s, Piece sub); // Return the first occurrence of sub in s, or npos. If both s and // sub is empty, it returns npos; otherwise, if only sub is empty, it // returns 0. -size_t Index(StringPiece s, StringPiece sub); +size_t Index(Piece s, Piece sub); // Return the first occurrence of c in s[pos:end], or npos. -size_t Find(StringPiece s, char c, size_t pos); +size_t Find(Piece s, char c, size_t pos); // Search range is [0..pos] inclusive. If pos == npos, search everything. -size_t RFind(StringPiece s, char c, size_t pos); +size_t RFind(Piece s, char c, size_t pos); -StringPiece SubStr(StringPiece s, size_t pos, size_t n); +Piece SubStr(Piece s, size_t pos, size_t n); -// allow StringPiece to be logged -std::ostream& operator<<(std::ostream& o, StringPiece piece); +// allow Piece to be logged +std::ostream& operator<<(std::ostream& o, Piece piece); +} // namespace string } // namespace paddle diff --git a/paddle/strings/stringpiece_test.cc b/paddle/string/piece_test.cc similarity index 77% rename from paddle/strings/stringpiece_test.cc rename to paddle/string/piece_test.cc index 2ba66a04f641c3457efa713383484491a213668f..cf5152ff5a3cb0a2afae0c90b787abf291122fa3 100644 --- a/paddle/strings/stringpiece_test.cc +++ b/paddle/string/piece_test.cc @@ -14,7 +14,7 @@ limitations under the License. */ -#include "paddle/strings/stringpiece.h" +#include "paddle/string/piece.h" #include @@ -22,42 +22,44 @@ TEST(StringPiece, Construct) { { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(NULL, s.data()); EXPECT_EQ(0U, s.len()); } - { EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); } { - paddle::StringPiece s(NULL); + EXPECT_THROW(paddle::string::Piece s(NULL, 10000U), std::invalid_argument); + } + { + paddle::string::Piece s(NULL); EXPECT_EQ(0U, s.len()); } { std::string a; EXPECT_EQ(0U, a.size()); - paddle::StringPiece s(a); + paddle::string::Piece s(a); EXPECT_EQ(0U, s.len()); } } TEST(StringPiece, CopyAndAssign) { - paddle::StringPiece empty; + paddle::string::Piece empty; EXPECT_EQ(0U, empty.len()); - paddle::StringPiece a("hello"); - paddle::StringPiece b = a; + paddle::string::Piece a("hello"); + paddle::string::Piece b = a; EXPECT_EQ(b.len(), strlen("hello")); EXPECT_EQ(a, b); std::string storage("hello"); - paddle::StringPiece c(storage); + paddle::string::Piece c(storage); EXPECT_EQ(a, c); EXPECT_NE(a.data(), c.data()); } TEST(StringPiece, Compare) { { - paddle::StringPiece a("hello"); - paddle::StringPiece b("world"); + paddle::string::Piece a("hello"); + paddle::string::Piece b("world"); EXPECT_TRUE(a != b); EXPECT_FALSE(a == b); EXPECT_TRUE(a < b); @@ -68,7 +70,7 @@ TEST(StringPiece, Compare) { EXPECT_GT(Compare(b, a), 0); } { - paddle::StringPiece a, b; + paddle::string::Piece a, b; EXPECT_TRUE(a == b); EXPECT_FALSE(a != b); EXPECT_FALSE(a < b); @@ -82,31 +84,31 @@ TEST(StringPiece, Compare) { TEST(StringPiece, ToString) { { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(std::string(""), s.ToString()); } { - paddle::StringPiece s(NULL); + paddle::string::Piece s(NULL); EXPECT_EQ(std::string(""), s.ToString()); } { - paddle::StringPiece s("hello"); + paddle::string::Piece s("hello"); EXPECT_EQ(std::string("hello"), s.ToString()); } } TEST(StringPiece, HasPrefixSuffix) { - using paddle::HasPrefix; - using paddle::HasSuffix; + using paddle::string::HasPrefix; + using paddle::string::HasSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_FALSE(HasPrefix(s, "something")); EXPECT_TRUE(HasPrefix(s, "")); EXPECT_FALSE(HasSuffix(s, "something")); EXPECT_TRUE(HasSuffix(s, "")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_TRUE(HasPrefix(s, "")); EXPECT_TRUE(HasPrefix(s, "a")); EXPECT_TRUE(HasPrefix(s, "ap")); @@ -120,10 +122,10 @@ TEST(StringPiece, HasPrefixSuffix) { } TEST(StringPiece, SkipPrefixSuffix) { - using paddle::SkipPrefix; - using paddle::SkipSuffix; + using paddle::string::SkipPrefix; + using paddle::string::SkipSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", SkipPrefix(s, 0)); EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument); @@ -131,7 +133,7 @@ TEST(StringPiece, SkipPrefixSuffix) { EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("app", SkipPrefix(s, 0)); EXPECT_EQ("pp", SkipPrefix(s, 1)); EXPECT_EQ("p", SkipPrefix(s, 2)); @@ -147,10 +149,10 @@ TEST(StringPiece, SkipPrefixSuffix) { } TEST(StringPiece, TrimPrefixSuffix) { - using paddle::TrimPrefix; - using paddle::TrimSuffix; + using paddle::string::TrimPrefix; + using paddle::string::TrimSuffix; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", TrimPrefix(s, "")); EXPECT_EQ("", TrimPrefix(s, "something")); @@ -158,7 +160,7 @@ TEST(StringPiece, TrimPrefixSuffix) { EXPECT_EQ("", TrimSuffix(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("app", TrimPrefix(s, "")); EXPECT_EQ("pp", TrimPrefix(s, "a")); EXPECT_EQ("p", TrimPrefix(s, "ap")); @@ -174,14 +176,14 @@ TEST(StringPiece, TrimPrefixSuffix) { } TEST(StringPiece, Contains) { - using paddle::Contains; + using paddle::string::Contains; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_FALSE(Contains(s, "")); EXPECT_FALSE(Contains(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_TRUE(Contains(s, "")); EXPECT_TRUE(Contains(s, "a")); EXPECT_TRUE(Contains(s, "p")); @@ -193,15 +195,15 @@ TEST(StringPiece, Contains) { } TEST(StringPiece, Index) { - using paddle::Index; - auto npos = paddle::StringPiece::npos; + using paddle::string::Index; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, Index(s, "")); EXPECT_EQ(npos, Index(s, "something")); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(0U, Index(s, "")); EXPECT_EQ(0U, Index(s, "a")); EXPECT_EQ(1U, Index(s, "p")); @@ -213,14 +215,14 @@ TEST(StringPiece, Index) { } TEST(StringPiece, Find) { - using paddle::Find; - auto npos = paddle::StringPiece::npos; + using paddle::string::Find; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, Find(s, 'a', 0U)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(0U, Find(s, 'a', 0U)); EXPECT_EQ(1U, Find(s, 'p', 0U)); EXPECT_EQ(1U, Find(s, 'p', 1U)); @@ -230,14 +232,14 @@ TEST(StringPiece, Find) { } TEST(StringPiece, RFind) { - using paddle::RFind; - auto npos = paddle::StringPiece::npos; + using paddle::string::RFind; + auto npos = paddle::string::Piece::npos; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ(npos, RFind(s, 'a', 0U)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ(2U, RFind(s, 'p', 2U)); EXPECT_EQ(0U, RFind(s, 'a', 2U)); EXPECT_EQ(1U, RFind(s, 'p', 1U)); @@ -247,15 +249,15 @@ TEST(StringPiece, RFind) { } TEST(StringPiece, SubStr) { - using paddle::SubStr; + using paddle::string::SubStr; { - paddle::StringPiece s; + paddle::string::Piece s; EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 0, 1)); EXPECT_EQ("", SubStr(s, 1, 0)); } { - paddle::StringPiece s("app"); + paddle::string::Piece s("app"); EXPECT_EQ("", SubStr(s, 0, 0)); EXPECT_EQ("", SubStr(s, 1, 0)); EXPECT_EQ("", SubStr(s, 2, 0)); @@ -279,15 +281,15 @@ TEST(StringPiece, SubStr) { } TEST(StringPiece, StreamOutput) { - using paddle::StringPiece; + using paddle::string::Piece; std::stringstream o; - o << StringPiece(); + o << paddle::string::Piece(); EXPECT_EQ("", o.str()); - o << StringPiece("hello"); + o << paddle::string::Piece("hello"); EXPECT_EQ("hello", o.str()); - o << StringPiece(); + o << paddle::string::Piece(); EXPECT_EQ("hello", o.str()); } diff --git a/paddle/string/printf.h b/paddle/string/printf.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5ce63a8e8dfe77962ff1e7415911d381a28aac --- /dev/null +++ b/paddle/string/printf.h @@ -0,0 +1,99 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +// Compared with std::stringstream, there are primary purpose of +// string::Printf: +// +// 1. Type-safe printing, with why and how explained in +// http://www.drdobbs.com/stringprintf-a-typesafe-printf-family-fo/184401999. +// Implementation includes +// +// https://github.com/c42f/tinyformat +// boost::format +// std::stringstream +// +// std::stringstream is not convenient enough in many cases. For example: +// +// std::cout << std::setprecision(2) << std::fixed << 1.23456 << "\n"; +// +// boost::format is the most convenient one. We can have +// +// std::cout << format("%2% %1%") % 36 % 77; +// +// or +// +// format fmter("%2% %1%"); +// fmter % 36; fmter % 77; +// std::cout << fmter.c_str(); +// +// But the overloading of % might be overkilling and it would be +// more efficient if it can write to std::cout directly. +// +// tinyformat has an interface compatible with the C-printf style, +// and it can writes to a stream or returns a std::string: +// +// std::cout << tfm::printf( +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// or +// +// tfm::format(std::cout, +// "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// 2. High-performance -- most printed strings are not too long and +// doens't need dynamic memory allocation. Many StringPrintf +// implementations doesn't enforce type-safe, but are +// high-performance, including +// +// https://developers.google.com/optimization/reference/base/stringprintf/ +// https://github.com/adobe/chromium/blob/master/base/stringprintf.h +// https://github.com/google/protobuf/blob/master/src/google/protobuf/stubs/stringprintf.h +// +// According to +// https://github.com/c42f/tinyformat#compile-time-and-code-bloat, +// boost::format runs too slow and results in large executable binary +// files. So here we port tinyformat. + +#pragma once + +#include +#include +#include "paddle/string/tinyformat/tinyformat.h" // https://github.com/c42f/tinyformat + +namespace paddle { +namespace string { + +template +void Fprintf(std::ostream& out, const char* fmt, const Args&... args) { + tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...)); +} + +template +std::string Sprintf(const char* fmt, const Args&... args) { + std::ostringstream oss; + Fprintf(oss, fmt, args...); + return oss.str(); +} + +template +void Printf(const char* fmt, const Args&... args) { + Fprintf(std::cout, fmt, args...); +} + +} // namespace string +} // namespace paddle diff --git a/paddle/string/printf_test.cc b/paddle/string/printf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8f2454165d741b3937f908dcfd87501940750d5 --- /dev/null +++ b/paddle/string/printf_test.cc @@ -0,0 +1,16 @@ +#include "paddle/string/printf.h" + +#include + +#include "gtest/gtest.h" + +TEST(StringPrintf, StringPrintf) { + std::string weekday = "Wednesday"; + const char* month = "July"; + size_t day = 27; + long hour = 14; + int min = 44; + EXPECT_EQ(std::string("Wednesday, July 27, 14:44"), + paddle::string::Sprintf( + "%s, %s %d, %.2d:%.2d", weekday, month, day, hour, min)); +} diff --git a/paddle/string/tinyformat/tinyformat.h b/paddle/string/tinyformat/tinyformat.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e5e0160fb018b813c1dade727da2861a295147 --- /dev/null +++ b/paddle/string/tinyformat/tinyformat.h @@ -0,0 +1,902 @@ +// tinyformat.h +// Copyright (C) 2011, Chris Foster [chris42f (at) gmail (d0t) com] +// +// Boost Software License - Version 1.0 +// +// Permission is hereby granted, free of charge, to any person or organization +// obtaining a copy of the software and accompanying documentation covered by +// this license (the "Software") to use, reproduce, display, distribute, +// execute, and transmit the Software, and to prepare derivative works of the +// Software, and to permit third-parties to whom the Software is furnished to +// do so, all subject to the following: +// +// The copyright notices in the Software and this entire statement, including +// the above license grant, this restriction and the following disclaimer, +// must be included in all copies of the Software, in whole or in part, and +// all derivative works of the Software, unless such copies or derivative +// works are solely in the form of machine-executable object code generated by +// a source language processor. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +//------------------------------------------------------------------------------ +// Tinyformat: A minimal type safe printf replacement +// +// tinyformat.h is a type safe printf replacement library in a single C++ +// header file. Design goals include: +// +// * Type safety and extensibility for user defined types. +// * C99 printf() compatibility, to the extent possible using std::ostream +// * Simplicity and minimalism. A single header file to include and distribute +// with your projects. +// * Augment rather than replace the standard stream formatting mechanism +// * C++98 support, with optional C++11 niceties +// +// +// Main interface example usage +// ---------------------------- +// +// To print a date to std::cout: +// +// std::string weekday = "Wednesday"; +// const char* month = "July"; +// size_t day = 27; +// long hour = 14; +// int min = 44; +// +// tfm::printf("%s, %s %d, %.2d:%.2d\n", weekday, month, day, hour, min); +// +// The strange types here emphasize the type safety of the interface; it is +// possible to print a std::string using the "%s" conversion, and a +// size_t using the "%d" conversion. A similar result could be achieved +// using either of the tfm::format() functions. One prints on a user provided +// stream: +// +// tfm::format(std::cerr, "%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// +// The other returns a std::string: +// +// std::string date = tfm::format("%s, %s %d, %.2d:%.2d\n", +// weekday, month, day, hour, min); +// std::cout << date; +// +// These are the three primary interface functions. There is also a +// convenience function printfln() which appends a newline to the usual result +// of printf() for super simple logging. +// +// +// User defined format functions +// ----------------------------- +// +// Simulating variadic templates in C++98 is pretty painful since it requires +// writing out the same function for each desired number of arguments. To make +// this bearable tinyformat comes with a set of macros which are used +// internally to generate the API, but which may also be used in user code. +// +// The three macros TINYFORMAT_ARGTYPES(n), TINYFORMAT_VARARGS(n) and +// TINYFORMAT_PASSARGS(n) will generate a list of n argument types, +// type/name pairs and argument names respectively when called with an integer +// n between 1 and 16. We can use these to define a macro which generates the +// desired user defined function with n arguments. To generate all 16 user +// defined function bodies, use the macro TINYFORMAT_FOREACH_ARGNUM. For an +// example, see the implementation of printf() at the end of the source file. +// +// Sometimes it's useful to be able to pass a list of format arguments through +// to a non-template function. The FormatList class is provided as a way to do +// this by storing the argument list in a type-opaque way. Continuing the +// example from above, we construct a FormatList using makeFormatList(): +// +// FormatListRef formatList = tfm::makeFormatList(weekday, month, day, hour, +// min); +// +// The format list can now be passed into any non-template function and used +// via a call to the vformat() function: +// +// tfm::vformat(std::cout, "%s, %s %d, %.2d:%.2d\n", formatList); +// +// +// Additional API information +// -------------------------- +// +// Error handling: Define TINYFORMAT_ERROR to customize the error handling for +// format strings which are unsupported or have the wrong number of format +// specifiers (calls assert() by default). +// +// User defined types: Uses operator<< for user defined types by default. +// Overload formatValue() for more control. + +#pragma once + +#include +#include +#include +#include + +namespace paddle { +namespace string { +namespace tinyformat { + +#ifndef TINYFORMAT_ERROR +#define TINYFORMAT_ERROR(reason) assert(0 && reason) +#endif + +//------------------------------------------------------------------------------ +namespace detail { + +// Test whether type T1 is convertible to type T2 +template +struct is_convertible { +private: + // two types of different size + struct fail { + char dummy[2]; + }; + struct succeed { + char dummy; + }; + // Try to convert a T1 to a T2 by plugging into tryConvert + static fail tryConvert(...); + static succeed tryConvert(const T2 &); + static const T1 &makeT1(); + +public: + // Standard trick: the (...) version of tryConvert will be chosen from + // the overload set only if the version taking a T2 doesn't match. + // Then we compare the sizes of the return types to check which + // function matched. Very neat, in a disgusting kind of way :) + static const bool value = sizeof(tryConvert(makeT1())) == sizeof(succeed); +}; + +// Format the value by casting to type fmtT. This default implementation +// should never be called. +template ::value> +struct formatValueAsType { + static void invoke(std::ostream & /*out*/, const T & /*value*/) { assert(0); } +}; +// Specialized version for types that can actually be converted to fmtT, as +// indicated by the "convertible" template parameter. +template +struct formatValueAsType { + static void invoke(std::ostream &out, const T &value) { + out << static_cast(value); + } +}; + +// Convert an arbitrary type to integer. The version with convertible=false +// throws an error. +template ::value> +struct convertToInt { + static int invoke(const T & /*value*/) { + TINYFORMAT_ERROR( + "tinyformat: Cannot convert from argument type to " + "integer for use as variable width or precision"); + return 0; + } +}; +// Specialization for convertToInt when conversion is possible +template +struct convertToInt { + static int invoke(const T &value) { return static_cast(value); } +}; + +// Format at most ntrunc characters to the given stream. +template +inline void formatTruncated(std::ostream &out, const T &value, int ntrunc) { + std::ostringstream tmp; + tmp << value; + std::string result = tmp.str(); + out.write(result.c_str(), + (std::min)(ntrunc, static_cast(result.size()))); +} +#define TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(type) \ + inline void formatTruncated(std::ostream &out, type *value, int ntrunc) { \ + std::streamsize len = 0; \ + while (len < ntrunc && value[len] != 0) ++len; \ + out.write(value, len); \ + } +// Overload for const char* and char*. Could overload for signed & unsigned +// char too, but these are technically unneeded for printf compatibility. +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(const char) +TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR(char) +#undef TINYFORMAT_DEFINE_FORMAT_TRUNCATED_CSTR + +} // namespace detail + +//------------------------------------------------------------------------------ +// Variable formatting functions. May be overridden for user-defined types if +// desired. + +/// Format a value into a stream, delegating to operator<< by default. +/// +/// Users may override this for their own types. When this function is called, +/// the stream flags will have been modified according to the format string. +/// The format specification is provided in the range [fmtBegin, fmtEnd). For +/// truncating conversions, ntrunc is set to the desired maximum number of +/// characters, for example "%.7s" calls formatValue with ntrunc = 7. +/// +/// By default, formatValue() uses the usual stream insertion operator +/// operator<< to format the type T, with special cases for the %c and %p +/// conversions. +template +inline void formatValue(std::ostream &out, + const char * /*fmtBegin*/, + const char *fmtEnd, + int ntrunc, + const T &value) { + // The mess here is to support the %c and %p conversions: if these + // conversions are active we try to convert the type to a char or const + // void* respectively and format that instead of the value itself. For the + // %p conversion it's important to avoid dereferencing the pointer, which + // could otherwise lead to a crash when printing a dangling (const char*). + const bool canConvertToChar = detail::is_convertible::value; + const bool canConvertToVoidPtr = + detail::is_convertible::value; + if (canConvertToChar && *(fmtEnd - 1) == 'c') + detail::formatValueAsType::invoke(out, value); + else if (canConvertToVoidPtr && *(fmtEnd - 1) == 'p') + detail::formatValueAsType::invoke(out, value); + else if (ntrunc >= 0) { + // Take care not to overread C strings in truncating conversions like + // "%.4s" where at most 4 characters may be read. + detail::formatTruncated(out, value, ntrunc); + } else + out << value; +} + +// Overloaded version for char types to support printing as an integer +#define TINYFORMAT_DEFINE_FORMATVALUE_CHAR(charType) \ + inline void formatValue(std::ostream &out, \ + const char * /*fmtBegin*/, \ + const char *fmtEnd, \ + int /**/, \ + charType value) { \ + switch (*(fmtEnd - 1)) { \ + case 'u': \ + case 'd': \ + case 'i': \ + case 'o': \ + case 'X': \ + case 'x': \ + out << static_cast(value); \ + break; \ + default: \ + out << value; \ + break; \ + } \ + } +// per 3.9.1: char, signed char and unsigned char are all distinct types +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(signed char) +TINYFORMAT_DEFINE_FORMATVALUE_CHAR(unsigned char) +#undef TINYFORMAT_DEFINE_FORMATVALUE_CHAR + +//------------------------------------------------------------------------------ +// Tools for emulating variadic templates in C++98. The basic idea here is +// stolen from the boost preprocessor metaprogramming library and cut down to +// be just general enough for what we need. + +#define TINYFORMAT_ARGTYPES(n) TINYFORMAT_ARGTYPES_##n +#define TINYFORMAT_VARARGS(n) TINYFORMAT_VARARGS_##n +#define TINYFORMAT_PASSARGS(n) TINYFORMAT_PASSARGS_##n +#define TINYFORMAT_PASSARGS_TAIL(n) TINYFORMAT_PASSARGS_TAIL_##n + +// To keep it as transparent as possible, the macros below have been generated +// using python via the excellent cog.py code generation script. This avoids +// the need for a bunch of complex (but more general) preprocessor tricks as +// used in boost.preprocessor. +// +// To rerun the code generation in place, use `cog.py -r tinyformat.h` +// (see http://nedbatchelder.com/code/cog). Alternatively you can just create +// extra versions by hand. + +/*[[[cog +maxParams = 16 + +def makeCommaSepLists(lineTemplate, elemTemplate, startInd=1): + for j in range(startInd,maxParams+1): + list = ', '.join([elemTemplate % {'i':i} for i in range(startInd,j+1)]) + cog.outl(lineTemplate % {'j':j, 'list':list}) + +makeCommaSepLists('#define TINYFORMAT_ARGTYPES_%(j)d %(list)s', + 'class T%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_VARARGS_%(j)d %(list)s', + 'const T%(i)d& v%(i)d') + +cog.outl() +makeCommaSepLists('#define TINYFORMAT_PASSARGS_%(j)d %(list)s', 'v%(i)d') + +cog.outl() +cog.outl('#define TINYFORMAT_PASSARGS_TAIL_1') +makeCommaSepLists('#define TINYFORMAT_PASSARGS_TAIL_%(j)d , %(list)s', + 'v%(i)d', startInd = 2) + +cog.outl() +cog.outl('#define TINYFORMAT_FOREACH_ARGNUM(m) \\\n ' + + ' '.join(['m(%d)' % (j,) for j in range(1,maxParams+1)])) +]]]*/ +#define TINYFORMAT_ARGTYPES_1 class T1 +#define TINYFORMAT_ARGTYPES_2 class T1, class T2 +#define TINYFORMAT_ARGTYPES_3 class T1, class T2, class T3 +#define TINYFORMAT_ARGTYPES_4 class T1, class T2, class T3, class T4 +#define TINYFORMAT_ARGTYPES_5 class T1, class T2, class T3, class T4, class T5 +#define TINYFORMAT_ARGTYPES_6 \ + class T1, class T2, class T3, class T4, class T5, class T6 +#define TINYFORMAT_ARGTYPES_7 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7 +#define TINYFORMAT_ARGTYPES_8 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, class T8 +#define TINYFORMAT_ARGTYPES_9 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9 +#define TINYFORMAT_ARGTYPES_10 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10 +#define TINYFORMAT_ARGTYPES_11 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11 +#define TINYFORMAT_ARGTYPES_12 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12 +#define TINYFORMAT_ARGTYPES_13 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13 +#define TINYFORMAT_ARGTYPES_14 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14 +#define TINYFORMAT_ARGTYPES_15 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15 +#define TINYFORMAT_ARGTYPES_16 \ + class T1, class T2, class T3, class T4, class T5, class T6, class T7, \ + class T8, class T9, class T10, class T11, class T12, class T13, \ + class T14, class T15, class T16 + +#define TINYFORMAT_VARARGS_1 const T1 &v1 +#define TINYFORMAT_VARARGS_2 const T1 &v1, const T2 &v2 +#define TINYFORMAT_VARARGS_3 const T1 &v1, const T2 &v2, const T3 &v3 +#define TINYFORMAT_VARARGS_4 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4 +#define TINYFORMAT_VARARGS_5 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5 +#define TINYFORMAT_VARARGS_6 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6 +#define TINYFORMAT_VARARGS_7 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7 +#define TINYFORMAT_VARARGS_8 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8 +#define TINYFORMAT_VARARGS_9 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9 +#define TINYFORMAT_VARARGS_10 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10 +#define TINYFORMAT_VARARGS_11 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11 +#define TINYFORMAT_VARARGS_12 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12 +#define TINYFORMAT_VARARGS_13 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13 +#define TINYFORMAT_VARARGS_14 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14 +#define TINYFORMAT_VARARGS_15 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15 +#define TINYFORMAT_VARARGS_16 \ + const T1 &v1, const T2 &v2, const T3 &v3, const T4 &v4, const T5 &v5, \ + const T6 &v6, const T7 &v7, const T8 &v8, const T9 &v9, const T10 &v10, \ + const T11 &v11, const T12 &v12, const T13 &v13, const T14 &v14, \ + const T15 &v15, const T16 &v16 + +#define TINYFORMAT_PASSARGS_1 v1 +#define TINYFORMAT_PASSARGS_2 v1, v2 +#define TINYFORMAT_PASSARGS_3 v1, v2, v3 +#define TINYFORMAT_PASSARGS_4 v1, v2, v3, v4 +#define TINYFORMAT_PASSARGS_5 v1, v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_6 v1, v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_7 v1, v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_8 v1, v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_9 v1, v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_10 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_11 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_12 v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_13 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_14 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_15 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_16 \ + v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_PASSARGS_TAIL_1 +#define TINYFORMAT_PASSARGS_TAIL_2 , v2 +#define TINYFORMAT_PASSARGS_TAIL_3 , v2, v3 +#define TINYFORMAT_PASSARGS_TAIL_4 , v2, v3, v4 +#define TINYFORMAT_PASSARGS_TAIL_5 , v2, v3, v4, v5 +#define TINYFORMAT_PASSARGS_TAIL_6 , v2, v3, v4, v5, v6 +#define TINYFORMAT_PASSARGS_TAIL_7 , v2, v3, v4, v5, v6, v7 +#define TINYFORMAT_PASSARGS_TAIL_8 , v2, v3, v4, v5, v6, v7, v8 +#define TINYFORMAT_PASSARGS_TAIL_9 , v2, v3, v4, v5, v6, v7, v8, v9 +#define TINYFORMAT_PASSARGS_TAIL_10 , v2, v3, v4, v5, v6, v7, v8, v9, v10 +#define TINYFORMAT_PASSARGS_TAIL_11 , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11 +#define TINYFORMAT_PASSARGS_TAIL_12 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12 +#define TINYFORMAT_PASSARGS_TAIL_13 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13 +#define TINYFORMAT_PASSARGS_TAIL_14 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14 +#define TINYFORMAT_PASSARGS_TAIL_15 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15 +#define TINYFORMAT_PASSARGS_TAIL_16 \ + , v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16 + +#define TINYFORMAT_FOREACH_ARGNUM(m) \ + m(1) m(2) m(3) m(4) m(5) m(6) m(7) m(8) m(9) m(10) m(11) m(12) m(13) m(14) \ + m(15) m(16) +//[[[end]]] + +namespace detail { + +// Type-opaque holder for an argument to format(), with associated actions on +// the type held as explicit function pointers. This allows FormatArg's for +// each argument to be allocated as a homogenous array inside FormatList +// whereas a naive implementation based on inheritance does not. +class FormatArg { +public: + FormatArg() {} + + template + FormatArg(const T &value) + : m_value(static_cast(&value)), + m_formatImpl(&formatImpl), + m_toIntImpl(&toIntImpl) {} + + void format(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc) const { + m_formatImpl(out, fmtBegin, fmtEnd, ntrunc, m_value); + } + + int toInt() const { return m_toIntImpl(m_value); } + +private: + template + static void formatImpl(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value) { + formatValue(out, fmtBegin, fmtEnd, ntrunc, *static_cast(value)); + } + + template + static int toIntImpl(const void *value) { + return convertToInt::invoke(*static_cast(value)); + } + + const void *m_value; + void (*m_formatImpl)(std::ostream &out, + const char *fmtBegin, + const char *fmtEnd, + int ntrunc, + const void *value); + int (*m_toIntImpl)(const void *value); +}; + +// Parse and return an integer from the string c, as atoi() +// On return, c is set to one past the end of the integer. +inline int parseIntAndAdvance(const char *&c) { + int i = 0; + for (; *c >= '0' && *c <= '9'; ++c) i = 10 * i + (*c - '0'); + return i; +} + +// Print literal part of format string and return next format spec +// position. +// +// Skips over any occurrences of '%%', printing a literal '%' to the +// output. The position of the first % character of the next +// nontrivial format spec is returned, or the end of string. +inline const char *printFormatStringLiteral(std::ostream &out, + const char *fmt) { + const char *c = fmt; + for (;; ++c) { + switch (*c) { + case '\0': + out.write(fmt, c - fmt); + return c; + case '%': + out.write(fmt, c - fmt); + if (*(c + 1) != '%') return c; + // for "%%", tack trailing % onto next literal section. + fmt = ++c; + break; + default: + break; + } + } +} + +// Parse a format string and set the stream state accordingly. +// +// The format mini-language recognized here is meant to be the one from C99, +// with the form "%[flags][width][.precision][length]type". +// +// Formatting options which can't be natively represented using the ostream +// state are returned in spacePadPositive (for space padded positive numbers) +// and ntrunc (for truncating conversions). argIndex is incremented if +// necessary to pull out variable width and precision . The function returns a +// pointer to the character after the end of the current format spec. +inline const char *streamStateFromFormat(std::ostream &out, + bool &spacePadPositive, + int &ntrunc, + const char *fmtStart, + const detail::FormatArg *formatters, + int &argIndex, + int numFormatters) { + if (*fmtStart != '%') { + TINYFORMAT_ERROR( + "tinyformat: Not enough conversion specifiers in format string"); + return fmtStart; + } + // Reset stream state to defaults. + out.width(0); + out.precision(6); + out.fill(' '); + // Reset most flags; ignore irrelevant unitbuf & skipws. + out.unsetf(std::ios::adjustfield | std::ios::basefield | + std::ios::floatfield | std::ios::showbase | std::ios::boolalpha | + std::ios::showpoint | std::ios::showpos | std::ios::uppercase); + bool precisionSet = false; + bool widthSet = false; + int widthExtra = 0; + const char *c = fmtStart + 1; + // 1) Parse flags + for (;; ++c) { + switch (*c) { + case '#': + out.setf(std::ios::showpoint | std::ios::showbase); + continue; + case '0': + // overridden by left alignment ('-' flag) + if (!(out.flags() & std::ios::left)) { + // Use internal padding so that numeric values are + // formatted correctly, eg -00010 rather than 000-10 + out.fill('0'); + out.setf(std::ios::internal, std::ios::adjustfield); + } + continue; + case '-': + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + continue; + case ' ': + // overridden by show positive sign, '+' flag. + if (!(out.flags() & std::ios::showpos)) spacePadPositive = true; + continue; + case '+': + out.setf(std::ios::showpos); + spacePadPositive = false; + widthExtra = 1; + continue; + default: + break; + } + break; + } + // 2) Parse width + if (*c >= '0' && *c <= '9') { + widthSet = true; + out.width(parseIntAndAdvance(c)); + } + if (*c == '*') { + widthSet = true; + int width = 0; + if (argIndex < numFormatters) + width = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable width"); + if (width < 0) { + // negative widths correspond to '-' flag set + out.fill(' '); + out.setf(std::ios::left, std::ios::adjustfield); + width = -width; + } + out.width(width); + ++c; + } + // 3) Parse precision + if (*c == '.') { + ++c; + int precision = 0; + if (*c == '*') { + ++c; + if (argIndex < numFormatters) + precision = formatters[argIndex++].toInt(); + else + TINYFORMAT_ERROR( + "tinyformat: Not enough arguments to read variable precision"); + } else { + if (*c >= '0' && *c <= '9') + precision = parseIntAndAdvance(c); + else if (*c == '-') // negative precisions ignored, treated as zero. + parseIntAndAdvance(++c); + } + out.precision(precision); + precisionSet = true; + } + // 4) Ignore any C99 length modifier + while (*c == 'l' || *c == 'h' || *c == 'L' || *c == 'j' || *c == 'z' || + *c == 't') + ++c; + // 5) We're up to the conversion specifier character. + // Set stream flags based on conversion specifier (thanks to the + // boost::format class for forging the way here). + bool intConversion = false; + switch (*c) { + case 'u': + case 'd': + case 'i': + out.setf(std::ios::dec, std::ios::basefield); + intConversion = true; + break; + case 'o': + out.setf(std::ios::oct, std::ios::basefield); + intConversion = true; + break; + case 'X': + out.setf(std::ios::uppercase); + case 'x': + case 'p': + out.setf(std::ios::hex, std::ios::basefield); + intConversion = true; + break; + case 'E': + out.setf(std::ios::uppercase); + case 'e': + out.setf(std::ios::scientific, std::ios::floatfield); + out.setf(std::ios::dec, std::ios::basefield); + break; + case 'F': + out.setf(std::ios::uppercase); + case 'f': + out.setf(std::ios::fixed, std::ios::floatfield); + break; + case 'G': + out.setf(std::ios::uppercase); + case 'g': + out.setf(std::ios::dec, std::ios::basefield); + // As in boost::format, let stream decide float format. + out.flags(out.flags() & ~std::ios::floatfield); + break; + case 'a': + case 'A': + TINYFORMAT_ERROR( + "tinyformat: the %a and %A conversion specs " + "are not supported"); + break; + case 'c': + // Handled as special case inside formatValue() + break; + case 's': + if (precisionSet) ntrunc = static_cast(out.precision()); + // Make %s print booleans as "true" and "false" + out.setf(std::ios::boolalpha); + break; + case 'n': + // Not supported - will cause problems! + TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported"); + break; + case '\0': + TINYFORMAT_ERROR( + "tinyformat: Conversion spec incorrectly " + "terminated by end of string"); + return c; + default: + break; + } + if (intConversion && precisionSet && !widthSet) { + // "precision" for integers gives the minimum number of digits (to be + // padded with zeros on the left). This isn't really supported by the + // iostreams, but we can approximately simulate it with the width if + // the width isn't otherwise used. + out.width(out.precision() + widthExtra); + out.setf(std::ios::internal, std::ios::adjustfield); + out.fill('0'); + } + return c + 1; +} + +//------------------------------------------------------------------------------ +inline void formatImpl(std::ostream &out, + const char *fmt, + const detail::FormatArg *formatters, + int numFormatters) { + // Saved stream state + std::streamsize origWidth = out.width(); + std::streamsize origPrecision = out.precision(); + std::ios::fmtflags origFlags = out.flags(); + char origFill = out.fill(); + + for (int argIndex = 0; argIndex < numFormatters; ++argIndex) { + // Parse the format string + fmt = printFormatStringLiteral(out, fmt); + bool spacePadPositive = false; + int ntrunc = -1; + const char *fmtEnd = streamStateFromFormat(out, + spacePadPositive, + ntrunc, + fmt, + formatters, + argIndex, + numFormatters); + if (argIndex >= numFormatters) { + // Check args remain after reading any variable width/precision + TINYFORMAT_ERROR("tinyformat: Not enough format arguments"); + return; + } + const FormatArg &arg = formatters[argIndex]; + // Format the arg into the stream. + if (!spacePadPositive) + arg.format(out, fmt, fmtEnd, ntrunc); + else { + // The following is a special case with no direct correspondence + // between stream formatting and the printf() behaviour. Simulate + // it crudely by formatting into a temporary string stream and + // munging the resulting string. + std::ostringstream tmpStream; + tmpStream.copyfmt(out); + tmpStream.setf(std::ios::showpos); + arg.format(tmpStream, fmt, fmtEnd, ntrunc); + std::string result = tmpStream.str(); // allocates... yuck. + for (size_t i = 0, iend = result.size(); i < iend; ++i) + if (result[i] == '+') result[i] = ' '; + out << result; + } + fmt = fmtEnd; + } + + // Print remaining part of format string. + fmt = printFormatStringLiteral(out, fmt); + if (*fmt != '\0') + TINYFORMAT_ERROR( + "tinyformat: Too many conversion specifiers in format string"); + + // Restore stream state + out.width(origWidth); + out.precision(origPrecision); + out.flags(origFlags); + out.fill(origFill); +} + +} // namespace detail + +/// List of template arguments format(), held in a type-opaque way. +/// +/// A const reference to FormatList (typedef'd as FormatListRef) may be +/// conveniently used to pass arguments to non-template functions: All type +/// information has been stripped from the arguments, leaving just enough of a +/// common interface to perform formatting as required. +class FormatList { +public: + FormatList(detail::FormatArg *formatters, int N) + : m_formatters(formatters), m_N(N) {} + + friend void vformat(std::ostream &out, + const char *fmt, + const FormatList &list); + +private: + const detail::FormatArg *m_formatters; + int m_N; +}; + +/// Reference to type-opaque format list for passing to vformat() +typedef const FormatList &FormatListRef; + +namespace detail { + +// Format list subclass with fixed storage to avoid dynamic allocation +template +class FormatListN : public FormatList { +public: + template + FormatListN(const Args &... args) + : FormatList(&m_formatterStore[0], N), + m_formatterStore{FormatArg(args)...} { + static_assert(sizeof...(args) == N, "Number of args must be N"); + } + +private: + FormatArg m_formatterStore[N]; +}; + +// Special 0-arg version - MSVC says zero-sized C array in struct is nonstandard +template <> +class FormatListN<0> : public FormatList { +public: + FormatListN() : FormatList(0, 0) {} +}; + +} // namespace detail + +//------------------------------------------------------------------------------ +// Primary API functions + +/// Make type-agnostic format list from list of template arguments. +/// +/// The exact return type of this function is an implementation detail and +/// shouldn't be relied upon. Instead it should be stored as a FormatListRef: +/// +/// FormatListRef formatList = makeFormatList( /*...*/ ); +template +detail::FormatListN makeFormatList(const Args &... args) { + return detail::FormatListN(args...); +} + +/// Format list of arguments to the stream according to the given format string. +/// +/// The name vformat() is chosen for the semantic similarity to vprintf(): the +/// list of format arguments is held in a single function argument. +inline void vformat(std::ostream &out, const char *fmt, FormatListRef list) { + detail::formatImpl(out, fmt, list.m_formatters, list.m_N); +} + +/// Format list of arguments to the stream according to given format string. +template +void format(std::ostream &out, const char *fmt, const Args &... args) { + vformat(out, fmt, makeFormatList(args...)); +} + +/// Format list of arguments according to the given format string and return +/// the result as a string. +template +std::string format(const char *fmt, const Args &... args) { + std::ostringstream oss; + format(oss, fmt, args...); + return oss.str(); +} + +/// Format list of arguments to std::cout, according to the given format string +template +void printf(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); +} + +template +void printfln(const char *fmt, const Args &... args) { + format(std::cout, fmt, args...); + std::cout << '\n'; +} + +} // namespace tinyformat +} // namespace string +} // namespace paddle diff --git a/paddle/strings/CMakeLists.txt b/paddle/strings/CMakeLists.txt deleted file mode 100644 index 4e55eecd484c0e218ecd51bbd19b3eb4f6f92a25..0000000000000000000000000000000000000000 --- a/paddle/strings/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -cc_library(stringpiece SRCS stringpiece.cc) -cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags) diff --git a/paddle/strings/stringpiece.cc b/paddle/strings/stringpiece.cc deleted file mode 100644 index 415b3558d5dfffde26275bcb16ea3922424ca9f3..0000000000000000000000000000000000000000 --- a/paddle/strings/stringpiece.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* - Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "paddle/strings/stringpiece.h" - -#include - -#include -#include -#include - -namespace paddle { - -StringPiece::StringPiece() : data_(NULL), size_(0) {} - -StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) { - if (d == NULL && n != 0) - throw std::invalid_argument( - "StringPiece requires len to be 0 for NULL data"); -} - -StringPiece::StringPiece(const char* s) : data_(s) { - size_ = (s == NULL) ? 0 : strlen(s); -} - -StringPiece::StringPiece(const std::string& s) - : data_(s.data()), size_(s.size()) {} - -char StringPiece::operator[](size_t n) const { - if (n >= len()) - throw std::invalid_argument("index out of StringPiece length"); - return data_[n]; -} - -int Compare(StringPiece a, StringPiece b) { - const size_t min_len = (a.len() < b.len()) ? a.len() : b.len(); - int r = memcmp(a.data(), b.data(), min_len); - if (r == 0) { - if (a.len() < b.len()) - return -1; - else if (a.len() > b.len()) - return 1; - } - return r; -} - -bool operator==(StringPiece x, StringPiece y) { - return ((x.len() == y.len()) && - (x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0)); -} - -bool operator!=(StringPiece x, StringPiece y) { return !(x == y); } - -bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; } -bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; } - -bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; } -bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; } - -bool HasPrefix(StringPiece s, StringPiece x) { - return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0)); -} - -bool HasSuffix(StringPiece s, StringPiece x) { - return ((s.len() >= x.len()) && - (memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0)); -} - -StringPiece SkipPrefix(StringPiece s, size_t n) { - if (n > s.len()) - throw std::invalid_argument("Skip distance larger than StringPiece length"); - return StringPiece(s.data() + n, s.len() - n); -} - -StringPiece SkipSuffix(StringPiece s, size_t n) { - if (n > s.len()) - throw std::invalid_argument("Skip distance larger than StringPiece length"); - return StringPiece(s.data(), s.len() - n); -} - -StringPiece TrimPrefix(StringPiece s, StringPiece x) { - return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s; -} - -StringPiece TrimSuffix(StringPiece s, StringPiece x) { - return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s; -} - -bool Contains(StringPiece s, StringPiece sub) { - return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end(); -} - -size_t Index(StringPiece s, StringPiece sub) { - auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end()); - return e != s.end() ? e - s.data() : StringPiece::npos; -} - -size_t Find(StringPiece s, char c, size_t pos) { - if (pos >= s.len()) { - return StringPiece::npos; - } - const char* result = - reinterpret_cast(memchr(s.data() + pos, c, s.len() - pos)); - return result != nullptr ? result - s.data() : StringPiece::npos; -} - -size_t RFind(StringPiece s, char c, size_t pos) { - if (s.len() == 0) return StringPiece::npos; - for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data(); - p--) { - if (*p == c) { - return p - s.data(); - } - } - return StringPiece::npos; -} - -StringPiece SubStr(StringPiece s, size_t pos, size_t n) { - if (pos > s.len()) pos = s.len(); - if (n > s.len() - pos) n = s.len() - pos; - return StringPiece(s.data() + pos, n); -} - -std::ostream& operator<<(std::ostream& o, StringPiece piece) { - return o << piece.ToString(); -} - -} // namespace paddle diff --git a/paddle/utils/Error.h b/paddle/utils/Error.h index f3d535c69c53fa350612459560dd9ac7c279aa72..27ddaab3f003110a2684a871a2de17afb473d660 100644 --- a/paddle/utils/Error.h +++ b/paddle/utils/Error.h @@ -19,7 +19,21 @@ limitations under the License. */ #include #include #include -#include "paddle/platform/must_check.h" + +/** + * __must_check macro. It make the function's return value must be used, + * otherwise it will raise a compile warning. And also Paddle treat all compile + * warnings as errors. + */ +#ifdef __GNUC__ +#if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) >= 30400 +#define __must_check __attribute__((warn_unused_result)) +#else +#define __must_check +#endif +#else +#define __must_check +#endif namespace paddle { diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 1bf59ed4840ae69afc5bce49c86a08b60e9603ee..67154a8d7d366bd983b4426da87e0b33307fced4 100755 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1381,7 +1381,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(*[l.name for l in layers]) + Inputs(* [l.name for l in layers]) def outputs(layers, *args): @@ -1424,7 +1424,7 @@ def outputs(layers, *args): assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(*[l.name for l in layers]) + Outputs(* [l.name for l in layers]) return # just return outputs. if len(layers) != 1: diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 26252d5bbd77ddb70b4f03843679e4737e2f96d3..2e4beb6882789249db09705f3f4d6c5c19e492cd 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -25,8 +25,9 @@ import uci_housing import sentiment import wmt14 import mq2007 +import flowers __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007' + 'uci_housing', 'wmt14', 'mq2007', 'flowers' ] diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 07c13cf719ae0c864c23fef51f0bd7d47f265759..158cfe158c4f1c8d82d157301adcfbe0351c55df 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -13,18 +13,18 @@ # limitations under the License. """ This module will download dataset from -http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html +http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html and parse train/test set intopaddle reader creators. -This set contains images of flowers belonging to 102 different categories. +This set contains images of flowers belonging to 102 different categories. The images were acquired by searching the web and taking pictures. There are a minimum of 40 images for each category. The database was used in: Nilsback, M-E. and Zisserman, A. Automated flower classification over a large - number of classes.Proceedings of the Indian Conference on Computer Vision, -Graphics and Image Processing (2008) + number of classes.Proceedings of the Indian Conference on Computer Vision, +Graphics and Image Processing (2008) http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. """ @@ -34,9 +34,9 @@ from common import download import tarfile import scipy.io as scio from paddle.v2.image import * +from paddle.v2.reader import * import os import numpy as np -import paddle.v2 as paddle from multiprocessing import cpu_count __all__ = ['train', 'test', 'valid'] @@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' +# In official 'readme', tstid is the flag of test data +# and trnid is the flag of train data. But test data is more than train data. +# So we exchange the train data and test data. +TRAIN_FLAG = 'tstid' +TEST_FLAG = 'trnid' +VALID_FLAG = 'valid' def default_mapper(sample): @@ -53,8 +59,8 @@ def default_mapper(sample): map image bytes data to type needed by model input layer ''' img, label = sample - img = paddle.image.load_image_bytes(img) - img = paddle.image.simple_transform(img, 256, 224, True) + img = load_image_bytes(img) + img = simple_transform(img, 256, 224, True) return img.flatten().astype('float32'), label @@ -63,22 +69,23 @@ def reader_creator(data_file, setid_file, dataset_name, mapper=default_mapper, - buffered_size=1024): + buffered_size=1024, + use_xmap=True): ''' - 1. read images from tar file and + 1. read images from tar file and merge images into batch files in 102flowers.tgz_batch/ 2. get a reader to read sample from batch file - - :param data_file: downloaded data file + + :param data_file: downloaded data file :type data_file: string - :param label_file: downloaded label file + :param label_file: downloaded label file :type label_file: string :param setid_file: downloaded setid file containing information about how to split dataset :type setid_file: string :param dataset_name: data set name (tstid|trnid|valid) :type dataset_name: string - :param mapper: a function to map image bytes data to type + :param mapper: a function to map image bytes data to type needed by model input layer :type mapper: callable :param buffered_size: the size of buffer used to process images @@ -105,15 +112,17 @@ def reader_creator(data_file, for sample, label in itertools.izip(data, batch['label']): yield sample, int(label) - return paddle.reader.xmap_readers(mapper, reader, - cpu_count(), buffered_size) + if use_xmap: + return xmap_readers(mapper, reader, cpu_count(), buffered_size) + else: + return map_readers(mapper, reader) -def train(mapper=default_mapper, buffered_size=1024): +def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers training set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers training set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper, + buffered_size, use_xmap) -def test(mapper=default_mapper, buffered_size=1024): +def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers test set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers test set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper, + buffered_size, use_xmap) -def valid(mapper=default_mapper, buffered_size=1024): +def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' - Create flowers validation set reader. - It returns a reader, each sample in the reader is - image pixels in [0, 1] and label in [1, 102] + Create flowers validation set reader. + It returns a reader, each sample in the reader is + image pixels in [0, 1] and label in [1, 102] translated from original color image by steps: 1. resize to 256*256 2. random crop to 224*224 @@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, - buffered_size) + download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper, + buffered_size, use_xmap) def fetch(): diff --git a/python/paddle/v2/dataset/tests/flowers_test.py b/python/paddle/v2/dataset/tests/flowers_test.py index cc0626f4feae287d18dfb227cc69a4174da055da..a8ae9a07acc22eb9d3c0cc5ebb07f8f11ed21233 100644 --- a/python/paddle/v2/dataset/tests/flowers_test.py +++ b/python/paddle/v2/dataset/tests/flowers_test.py @@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): def test_train(self): instances, max_label_value = self.check_reader( paddle.v2.dataset.flowers.train()) - self.assertEqual(instances, 1020) + self.assertEqual(instances, 6149) self.assertEqual(max_label_value, 102) def test_test(self): instances, max_label_value = self.check_reader( paddle.v2.dataset.flowers.test()) - self.assertEqual(instances, 6149) + self.assertEqual(instances, 1020) self.assertEqual(max_label_value, 102) def test_valid(self): diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index ad20241b98302f136326ae491c6723a6c12ae284..bbaf8bfa979fbbf460561ebf7077b75b9c41a11a 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -51,7 +51,7 @@ class Parameters(object): def __init__(self): self.__param_conf__ = dict() self.__gradient_machines__ = [] - self.__tmp_params__ = [] + self.__tmp_params__ = dict() def __append_config__(self, param_conf): """ @@ -128,13 +128,10 @@ class Parameters(object): if len(self.__gradient_machines__) == 0: # create new parameter in python numpy. - if len(self.__tmp_params__) != 0: - ret_list = [ - mat for name, mat in self.__tmp_params__ if name == key - ] - if len(ret_list) == 1: - return ret_list[0] - return np.ndarray(shape=shape, dtype=np.float32) + if key in self.__tmp_params__: + return self.__tmp_params__[key] + else: + return np.ndarray(shape=shape, dtype=np.float32) else: for each_gradient_machine in self.__gradient_machines__: param = __get_parameter_in_gradient_machine__( @@ -187,7 +184,7 @@ class Parameters(object): (shape, value.shape)) if len(self.__gradient_machines__) == 0: - self.__tmp_params__.append((key, value)) + self.__tmp_params__[key] = value else: for each_gradient_machine in self.__gradient_machines__: __copy_parameter_to_gradient_machine__(each_gradient_machine, @@ -231,7 +228,7 @@ class Parameters(object): raise ValueError("gradient_machine should be api.GradientMachine") if len(self.__tmp_params__) != 0: - for name, val in self.__tmp_params__: + for name, val in self.__tmp_params__.iteritems(): try: __copy_parameter_to_gradient_machine__(gradient_machine, name, val) @@ -287,6 +284,18 @@ class Parameters(object): @staticmethod def from_tar(f): + """ + Create a `Parameters` object from the given file. And + the `Parameters` only contains the parameters in this + file. It is adapted the parameters are same in the + defined network and the given file. For example, it + can be used in the inference. + + :param f: the initialized model file. + :type f: tar file + :return: A Parameters object. + :rtype: Parameters. + """ params = Parameters() tar = tarfile.TarFile(fileobj=f, mode='r') for finfo in tar: @@ -302,6 +311,21 @@ class Parameters(object): params.deserialize(param_name, f) return params + def init_from_tar(self, f): + """ + Different from `from_tar`, this interface can be used to + init partial network parameters from another saved model. + + :param f: the initialized model file. + :type f: tar file + :return: Nothing. + """ + + tar_param = Parameters.from_tar(f) + for pname in tar_param.names(): + if pname in self.names(): + self.set(pname, tar_param.get(pname)) + def __get_parameter_in_gradient_machine__(gradient_machine, name): """ diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 07142056f872db5113acdd296b17c52b343c1be6..9f888b16d6b2fbf457ee4f4fe94fcb51b6f37fc9 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could be used in user program. """ -__all__ = ['np_array', 'text_file'] +__all__ = ['np_array', 'text_file', "recordio"] def np_array(x): @@ -55,3 +55,24 @@ def text_file(path): f.close() return reader + + +def recordio(path): + """ + Creates a data reader that outputs record one one by one from given recordio file + :path: path of recordio file + :returns: data reader of recordio file + """ + + import recordio as rec + + def reader(): + f = rec.reader(path) + while True: + r = f.read() + if r is None: + break + yield r + f.close() + + return reader diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index e432003129d2b8dea60138d08f13ec5e9d29a7ad..45a4288751e37b99dd1005ec78f30a98044926ff 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -166,12 +166,12 @@ def buffered(reader, size): The buffered data reader will read and save data entries into a buffer. Reading from the buffered data reader will proceed as long as the buffer is not empty. - + :param reader: the data reader to read from. :type reader: callable :param size: max buffer size. :type size: int - + :returns: the buffered data reader. """ @@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): :type mapper: callable :param reader: the data reader to read from :type reader: callable - :param process_num: process number to handle original sample + :param process_num: process number to handle original sample :type process_num: int :param buffer_size: max buffer size :type buffer_size: int @@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): :rtype: callable """ end = XmapEndSignal() - in_queue = Queue(buffer_size) - out_queue = Queue(buffer_size) - out_order = [0] # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): @@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): in_order += 1 in_queue.put(end) - # start a read worker in a thread - target = order_read_worker if order else read_worker - t = Thread(target=target, args=(reader, in_queue)) - t.daemon = True - t.start() - # define a worker to handle samples from in_queue by mapper # and put mapped samples into out_queue def handle_worker(in_queue, out_queue, mapper): @@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): in_queue.put(end) out_queue.put(end) - # start several handle_workers - target = order_handle_worker if order else handle_worker - args = (in_queue, out_queue, mapper, out_order) if order else ( - in_queue, out_queue, mapper) - workers = [] - for i in xrange(process_num): - worker = Thread(target=target, args=args) - worker.daemon = True - workers.append(worker) - for w in workers: - w.start() - def xreader(): + in_queue = Queue(buffer_size) + out_queue = Queue(buffer_size) + out_order = [0] + # start a read worker in a thread + target = order_read_worker if order else read_worker + t = Thread(target=target, args=(reader, in_queue)) + t.daemon = True + t.start() + # start several handle_workers + target = order_handle_worker if order else handle_worker + args = (in_queue, out_queue, mapper, out_order) if order else ( + in_queue, out_queue, mapper) + workers = [] + for i in xrange(process_num): + worker = Thread(target=target, args=args) + worker.daemon = True + workers.append(worker) + for w in workers: + w.start() + sample = out_queue.get() while not isinstance(sample, XmapEndSignal): yield sample diff --git a/python/paddle/v2/reader/tests/creator_test.py b/python/paddle/v2/reader/tests/creator_test.py index 359f3eeefbe8efeb343cc875c707c9251a7087fb..ba4f558874a0155d276fcb0e0d2d9258f0903f0e 100644 --- a/python/paddle/v2/reader/tests/creator_test.py +++ b/python/paddle/v2/reader/tests/creator_test.py @@ -34,5 +34,14 @@ class TestTextFile(unittest.TestCase): self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1)) +class TestRecordIO(unittest.TestCase): + def test_recordio(self): + path = os.path.join( + os.path.dirname(__file__), "test_recordio_creator.dat") + reader = paddle.v2.reader.creator.recordio(path) + for idx, r in enumerate(reader()): + self.assertSequenceEqual(r, str(idx)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/reader/tests/decorator_test.py b/python/paddle/v2/reader/tests/decorator_test.py index bb3c5d220b9ce1552d2fc429abb1863930cd4d17..5a92951b100fa51ab6df7039d9c6b54d1f9d963e 100644 --- a/python/paddle/v2/reader/tests/decorator_test.py +++ b/python/paddle/v2/reader/tests/decorator_test.py @@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase): for order in orders: for tNum in thread_nums: for size in buffered_size: - result = [] - for i in paddle.v2.reader.xmap_readers(mapper, + reader = paddle.v2.reader.xmap_readers(mapper, reader_creator_10(0), - tNum, size, order)(): - result.append(i) - if not order: - result.sort() - for idx, e in enumerate(result): - self.assertEqual(e, mapper(idx)) + tNum, size, order) + for n in xrange(3): + result = [] + for i in reader(): + result.append(i) + if not order: + result.sort() + for idx, e in enumerate(result): + self.assertEqual(e, mapper(idx)) if __name__ == '__main__': diff --git a/python/paddle/v2/reader/tests/test_recordio_creator.dat b/python/paddle/v2/reader/tests/test_recordio_creator.dat new file mode 100644 index 0000000000000000000000000000000000000000..17aa89b6796184407e83246d3f342a55a66b4a69 Binary files /dev/null and b/python/paddle/v2/reader/tests/test_recordio_creator.dat differ diff --git a/python/paddle/v2/tests/test_parameters.py b/python/paddle/v2/tests/test_parameters.py index 45372e7dd0ec7cbdd6a2eb5c0397ef7e74284cd0..7ba8a939fbd1a949d61a007b40c054e7543c0cbc 100644 --- a/python/paddle/v2/tests/test_parameters.py +++ b/python/paddle/v2/tests/test_parameters.py @@ -20,14 +20,17 @@ import cStringIO import numpy -def __rand_param_config__(name): +def __rand_param_config__(name, psize=None): conf = ParameterConfig() conf.name = name size = 1 - for i in xrange(2): - dim = random.randint(1, 1000) - conf.dims.append(dim) - size *= dim + if psize is None: + for i in xrange(2): + dim = random.randint(1, 1000) + conf.dims.append(dim) + size *= dim + else: + size = psize conf.size = size assert conf.IsInitialized() return conf @@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase): expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32) assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6)) + def test_init_from_tar(self): + def get_param(names, size): + p = parameters.Parameters() + for k, v in zip(names, size): + p.__append_config__(__rand_param_config__(k, v)) + for name in p.names(): + param = p.get(name) + param[:] = numpy.random.uniform( + -1.0, 1.0, size=p.get_shape(name)) + p.set(name, param) + return p + + def get_parames(): + name1 = ['param_0', 'param_1'] + size1 = [128, 256] + p1 = get_param(name1, size1) + file1 = cStringIO.StringIO() + p1.to_tar(file1) + file1.seek(0) + + name2 = ['param_0', 'param_1', 'param_2'] + size2 = [128, 256, 288] + p2 = get_param(name2, size2) + file2 = cStringIO.StringIO() + p2.to_tar(file2) + file2.seek(0) + return p1, file1, p2, file2 + + p1, file1, p2, file2 = get_parames() + p2.init_from_tar(file1) + for name in p1.names(): + self.assertEqual(p1.get_shape(name), p2.get_shape(name)) + v1 = p1.get(name) + v2 = p2.get(name) + self.assertTrue(numpy.isclose(v1, v2).all()) + + p1, file1, p2, file2 = get_parames() + p1.init_from_tar(file2) + for name in p1.names(): + self.assertEqual(p1.get_shape(name), p2.get_shape(name)) + v1 = p1.get(name) + v2 = p2.get(name) + self.assertTrue(numpy.isclose(v1, v2).all()) + if __name__ == '__main__': unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index 86fc0fc5c0318b03659bf84f8ad9e2a114467c74..aa6771709cad0bb4dd4ce39c81de7e6ab1ad4c73 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -15,7 +15,8 @@ setup_requires=["requests", "protobuf==3.1", "recordio", "matplotlib", - "rarfile"] + "rarfile", + "scipy>=0.19.0"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: setup_requires+=["opencv-python"]